From f730ac92bd9038fdb7c7dc232a52bec645e3fdf2 Mon Sep 17 00:00:00 2001
From: Yogesh Tewari
Date: Fri, 30 Jul 2021 15:32:41 -0400
Subject: [PATCH] feat: Allow user to specify a format for stdout, files reformatted with black (#242)

---
 data_validation/__main__.py | 4 +-
 data_validation/cli_tools.py | 28 ++++----
 data_validation/clients.py | 2 +-
 data_validation/combiner.py | 7 +-
 data_validation/config_manager.py | 20 +++---
 data_validation/data_validation.py | 28 ++++----
 .../query_builder/query_builder.py | 70 ++++++++++++++-----
 data_validation/result_handlers/text.py | 6 +-
 data_validation/schema_validation.py | 4 +-
 data_validation/validation_builder.py | 22 +++---
 .../cloudsql_resource_manager.py | 12 ++--
 .../deploy_cloudsql/gcloud_context.py | 2 +-
 tests/system/data_sources/test_bigquery.py | 19 +++--
 tests/system/data_sources/test_mysql.py | 3 +-
 tests/system/data_sources/test_postgres.py | 7 +-
 tests/system/data_sources/test_spanner.py | 10 ++-
 tests/system/data_sources/test_sql_server.py | 7 +-
 tests/unit/result_handlers/test_text.py | 4 +-
 tests/unit/test__main.py | 3 +-
 tests/unit/test_cli_tools.py | 25 ++++---
 tests/unit/test_config_manager.py | 4 +-
 tests/unit/test_data_validation.py | 6 +-
 tests/unit/test_metadata.py | 10 ++-
 tests/unit/test_schema_validation.py | 2 +-
 24 files changed, 193 insertions(+), 112 deletions(-)

diff --git a/data_validation/__main__.py b/data_validation/__main__.py
index beaf0ecbf..870cfc94a 100644
--- a/data_validation/__main__.py
+++ b/data_validation/__main__.py
@@ -146,7 +146,7 @@ def build_config_managers_from_args(args): result_handler_config=result_handler_config, filter_config=filter_config, verbose=args.verbose, - format=args.format + format=args.format, ) configs.append(build_config_from_args(args, config_manager))
@@ -326,7 +326,7 @@ def run(args): def run_connections(args): - """ Run commands related to connection management.""" + """Run commands related to connection management.""" if args.connect_cmd == "list": cli_tools.list_connections() elif args.connect_cmd == "add":
diff --git a/data_validation/cli_tools.py b/data_validation/cli_tools.py
index 41e189e8f..090c69862 100644
--- a/data_validation/cli_tools.py
+++ b/data_validation/cli_tools.py
@@ -128,7 +128,7 @@ def get_parsed_args(): - """ Return ArgParser with configured CLI arguments.""" + """Return ArgParser with configured CLI arguments.""" parser = configure_arg_parser() return parser.parse_args()
@@ -194,7 +194,7 @@ def _configure_raw_query(subparsers): def _configure_run_config_parser(subparsers): - """ Configure arguments to run a data validation YAML config.""" + """Configure arguments to run a data validation YAML config.""" run_config_parser = subparsers.add_parser( "run-config", help="Run validations stored in a YAML config file" )
@@ -206,7 +206,7 @@ def _configure_run_parser(subparsers): - """ Configure arguments to run a data validation.""" + """Configure arguments to run a data validation.""" # subparsers = parser.add_subparsers(dest="command")
@@ -273,7 +273,9 @@ def _configure_run_parser(subparsers): help="Store the validation in the YAML Config File Path specified", ) run_parser.add_argument( - "--labels", "-l", help="Key value pair labels for validation run", + "--labels", + "-l", + help="Key value pair labels for validation run", ) run_parser.add_argument( "--service-account",
@@ -300,7 +302,7 @@ def _configure_run_parser(subparsers): def _configure_connection_parser(subparsers): - """ Configure the Parser for Connection Management. """ + """Configure the Parser for Connection Management.""" connection_parser = subparsers.add_parser( "connections", help="Manage & Store connections to your Databases" )
@@ -335,7 +337,7 @@ def _configure_database_specific_parsers(parser): def get_connection_config_from_args(args): - """ Return dict with connection config supplied.""" + """Return dict with connection config supplied.""" config = {consts.SOURCE_TYPE: args.connect_type} if args.connect_type == "Raw":
@@ -388,7 +390,7 @@ def _generate_random_name(conn): def store_connection(connection_name, conn): - """ Store the connection config under the given name.""" + """Store the connection config under the given name.""" connection_name = connection_name or _generate_random_name(conn) file_path = _get_connection_file(connection_name)
@@ -397,7 +399,7 @@ def store_connection(connection_name, conn): def get_connections(): - """ Return dict with connection name and path key pairs.""" + """Return dict with connection name and path key pairs.""" connections = {} dir_path = _get_data_validation_directory()
@@ -413,7 +415,7 @@ def get_connections(): def list_connections(): - """ List all saved connections.""" + """List all saved connections.""" connections = get_connections() for conn_name in connections:
@@ -421,7 +423,7 @@ def list_connections(): def get_connection(connection_name): - """ Return dict connection details for a specific connection.""" + """Return dict connection details for a specific connection.""" file_path = _get_connection_file(connection_name) with open(file_path, "r") as file: conn_str = file.read()
@@ -430,7 +432,7 @@ def get_connection(connection_name): def get_labels(arg_labels): - """ Return list of tuples representing key-value label pairs. """ + """Return list of tuples representing key-value label pairs.""" labels = [] if arg_labels: pairs = arg_labels.split(",")
@@ -515,7 +517,7 @@ def get_arg_list(arg_value, default_value=None): def get_tables_list(arg_tables, default_value=None, is_filesystem=False): - """ Returns dictionary of tables. Backwards compatible for JSON input. + """Returns dictionary of tables. Backwards compatible for JSON input. arg_table (str): tables_list argument specified default_value (Any): A default value to supply when arg_value is empty.
@@ -571,7 +573,7 @@ def get_tables_list(arg_tables, default_value=None, is_filesystem=False): def split_table(table_ref, schema_required=True): - """ Returns schema and table name given list of input values. + """Returns schema and table name given list of input values. table_ref (List): Table reference i.e ['my.schema.my_table'] scehma_required (boolean): Indicates whether schema is required. A source
diff --git a/data_validation/clients.py b/data_validation/clients.py
index 734488647..b07261172 100644
--- a/data_validation/clients.py
+++ b/data_validation/clients.py
@@ -179,7 +179,7 @@ def get_all_tables(client, allowed_schemas=None): def get_data_client(connection_config): - """ Return DataClient client from given configuration """ + """Return DataClient client from given configuration""" connection_config = copy.deepcopy(connection_config) source_type = connection_config.pop(consts.SOURCE_TYPE)
diff --git a/data_validation/combiner.py b/data_validation/combiner.py
index 9055cd3e7..251511668 100644
--- a/data_validation/combiner.py
+++ b/data_validation/combiner.py
@@ -32,7 +32,12 @@ def generate_report( - client, run_metadata, source, target, join_on_fields=(), verbose=False, + client, + run_metadata, + source, + target, + join_on_fields=(), + verbose=False, ): """Combine results into a report.
diff --git a/data_validation/config_manager.py b/data_validation/config_manager.py
index 7adae8c5e..1f10d44c9 100644
--- a/data_validation/config_manager.py
+++ b/data_validation/config_manager.py
@@ -82,12 +82,12 @@ def process_in_memory(self): @property def max_recursive_query_size(self): - """Return Aggregates from Config """ + """Return Aggregates from Config""" return self._config.get(consts.CONFIG_MAX_RECURSIVE_QUERY_SIZE, 50000) @property def aggregates(self): - """Return Aggregates from Config """ + """Return Aggregates from Config""" return self._config.get(consts.CONFIG_AGGREGATES, []) def append_aggregates(self, aggregate_configs):
@@ -105,7 +105,7 @@ def append_calculated_fields(self, calculated_configs): @property def query_groups(self): - """ Return Query Groups from Config """ + """Return Query Groups from Config""" return self._config.get(consts.CONFIG_GROUPED_COLUMNS, []) def append_query_groups(self, grouped_column_configs):
@@ -116,7 +116,7 @@ @property def primary_keys(self): - """ Return Query Groups from Config """ + """Return Query Groups from Config""" return self._config.get(consts.CONFIG_PRIMARY_KEYS, []) def append_primary_keys(self, primary_key_configs):
@@ -127,7 +127,7 @@ @property def filters(self): - """Return Filters from Config """ + """Return Filters from Config""" return self._config.get(consts.CONFIG_FILTERS, []) @property
@@ -185,12 +185,12 @@ def query_limit(self): @property def threshold(self): - """Return threshold from Config """ + """Return threshold from Config""" return self._config.get(consts.CONFIG_THRESHOLD, 0.0) @property def format(self): - """Return threshold from Config """ + """Return threshold from Config""" return self._config.get(consts.CONFIG_FORMAT, "table") def get_source_ibis_table(self):
@@ -253,8 +253,10 @@ def get_result_handler(self): consts.GOOGLE_SERVICE_ACCOUNT_KEY_PATH ) if key_path: - credentials = google.oauth2.service_account.Credentials.from_service_account_file( - key_path + credentials = ( + google.oauth2.service_account.Credentials.from_service_account_file( + key_path + ) ) else: credentials = None
diff --git a/data_validation/data_validation.py b/data_validation/data_validation.py
index 38c7daddc..ea6f5fd8b 100644
--- a/data_validation/data_validation.py
+++ b/data_validation/data_validation.py
@@ -90,14 +90,14 @@ def __init__( # TODO(dhercher) we planned on shifting this to use an Execution Handler. # Leaving to to swast on the design of how this should look. def execute(self): - """ Execute Queries and Store Results """ + """Execute Queries and Store Results""" if self.config_manager.validation_type == consts.ROW_VALIDATION: grouped_fields = self.validation_builder.pop_grouped_fields() result_df = self.execute_recursive_validation( self.validation_builder, grouped_fields ) elif self.config_manager.validation_type == consts.SCHEMA_VALIDATION: - """ Perform only schema validation """ + """Perform only schema validation""" result_df = self.schema_validator.execute() else: result_df = self._execute_validation(
@@ -108,15 +108,15 @@ def execute(self): return self.result_handler.execute(self.config, self.format, result_df) def query_too_large(self, rows_df, grouped_fields): - """ Return bool to dictate if another level of recursion - would create a too large result set. - - Rules to define too large are: - - If any grouped fields remain, return False. - (assumes user added logical sized groups) - - Else, if next group size is larger - than the limit, return True. - - Finally return False if no covered case occured. + """Return bool to dictate if another level of recursion + would create a too large result set. + + Rules to define too large are: + - If any grouped fields remain, return False. + (assumes user added logical sized groups) + - Else, if next group size is larger + than the limit, return True. + - Finally return False if no covered case occured. """ if len(grouped_fields) > 1: return False
@@ -203,7 +203,7 @@ def execute_recursive_validation(self, validation_builder, grouped_fields): return pandas.concat(past_results) def _add_recursive_validation_filter(self, validation_builder, row): - """ Return ValidationBuilder Configured for Next Recursive Search """ + """Return ValidationBuilder Configured for Next Recursive Search""" group_by_columns = json.loads(row[consts.GROUP_BY_COLUMNS]) for alias, value in group_by_columns.items(): filter_field = {
@@ -250,7 +250,7 @@ def _get_pandas_schema(self, source_df, target_df, join_on_fields, verbose=False return pd_schema def _execute_validation(self, validation_builder, process_in_memory=True): - """ Execute Against a Supplied Validation Builder """ + """Execute Against a Supplied Validation Builder""" self.run_metadata.validations = validation_builder.get_metadata() source_query = validation_builder.get_source_query()
@@ -301,7 +301,7 @@ def _execute_validation(self, validation_builder, process_in_memory=True): return result_df def combine_data(self, source_df, target_df, join_on_fields): - """ TODO: Return List of Dictionaries """ + """TODO: Return List of Dictionaries""" # Clean Data to Standardize if join_on_fields: df = source_df.merge(
diff --git a/data_validation/query_builder/query_builder.py b/data_validation/query_builder/query_builder.py
index bee93ee16..fb5cb1c87 100644
--- a/data_validation/query_builder/query_builder.py
+++ b/data_validation/query_builder/query_builder.py
@@ -37,11 +37,15 @@ def __init__(self, ibis_expr, field_name=None, alias=None): def count(field_name=None, alias=None): if field_name: return AggregateField( - ibis.expr.types.ColumnExpr.count, field_name=field_name, alias=alias, + ibis.expr.types.ColumnExpr.count, + field_name=field_name, + alias=alias, ) else: return AggregateField( - ibis.expr.types.TableExpr.count, field_name=field_name, alias=alias, + ibis.expr.types.TableExpr.count, + field_name=field_name, + alias=alias, ) @staticmethod
@@ -59,19 +63,25 @@ def avg(field_name=None, alias=None): @staticmethod def max(field_name=None, alias=None): return AggregateField( - ibis.expr.types.ColumnExpr.max, field_name=field_name, alias=alias, + ibis.expr.types.ColumnExpr.max, + field_name=field_name, + alias=alias, ) @staticmethod def sum(field_name=None, alias=None): return AggregateField( - ibis.expr.api.IntegerColumn.sum, field_name=field_name, alias=alias, + ibis.expr.api.IntegerColumn.sum, + field_name=field_name, + alias=alias, ) @staticmethod def bit_xor(field_name=None, alias=None): return AggregateField( - ibis.expr.api.IntegerColumn.bit_xor, field_name=field_name, alias=alias, + ibis.expr.api.IntegerColumn.bit_xor, + field_name=field_name, + alias=alias, ) def compile(self, ibis_table):
@@ -196,7 +206,7 @@ def compile(self, ibis_table): class ColumnReference(object): def __init__(self, column_name): - """ A representation of an calculated field to build a query. + """A representation of an calculated field to build a query. Args: column_name (String): The column name used in a complex expr """ self.column_name = column_name def compile(self, ibis_table): - """ Return an ibis object referencing the column. + """Return an ibis object referencing the column. Args: ibis_table (IbisTable): The table obj reference
@@ -215,7 +225,7 @@ class CalculatedField(object): def __init__(self, ibis_expr, config, fields, cast=None, **kwargs): - """ A representation of an calculated field to build a query. + """A representation of an calculated field to build a query. Args: config dict: Configurations object explaining calc field details
@@ -234,37 +244,61 @@ def concat(config, fields): fields = [config["default_concat_separator"], fields] cast = "string" return CalculatedField( - ibis.expr.api.StringValue.join, config, fields, cast=cast, + ibis.expr.api.StringValue.join, + config, + fields, + cast=cast, ) @staticmethod def hash(config, fields): if config.get("default_hash_function") is None: how = "farm_fingerprint" - return CalculatedField(ibis.expr.api.ValueExpr.hash, config, fields, how=how,) + return CalculatedField( + ibis.expr.api.ValueExpr.hash, + config, + fields, + how=how, + ) @staticmethod def ifnull(config, fields): if config.get("default_null_string") is None: config["default_string"] = ibis.literal("DEFAULT_REPLACEMENT_STRING") fields = [config["default_string"], fields[0]] - return CalculatedField(ibis.expr.api.ValueExpr.fillna, config, fields,) + return CalculatedField( + ibis.expr.api.ValueExpr.fillna, + config, + fields, + ) @staticmethod def length(config, fields): - return CalculatedField(ibis.expr.api.StringValue.length, config, fields,) + return CalculatedField( + ibis.expr.api.StringValue.length, + config, + fields, + ) @staticmethod def rstrip(config, fields): - return CalculatedField(ibis.expr.api.StringValue.rstrip, config, fields,) + return CalculatedField( + ibis.expr.api.StringValue.rstrip, + config, + fields, + ) @staticmethod def upper(config, fields): - return CalculatedField(ibis.expr.api.StringValue.upper, config, fields,) + return CalculatedField( + ibis.expr.api.StringValue.upper, + config, + fields, + ) @staticmethod def custom(expr): - """ Returns a CalculatedField instance built for any custom SQL using a supported operator. + """Returns a CalculatedField instance built for any custom SQL using a supported operator. Args: expr (Str): A custom SQL expression used to filter a query """
@@ -299,7 +333,7 @@ class QueryBuilder(object): def __init__( self, aggregate_fields, calculated_fields, filters, grouped_fields, limit=None ): - """ Build a QueryBuilder object which can be used to build queries easily + """Build a QueryBuilder object which can be used to build queries easily Args: aggregate_fields (list[AggregateField]): AggregateField instances with Ibis expressions
@@ -316,7 +350,7 @@ def __init__( @staticmethod def build_count_validator(limit=None): - """ Return a basic template builder for most validations """ + """Return a basic template builder for most validations""" aggregate_fields = [] filters = [] grouped_fields = []
@@ -419,7 +453,7 @@ def add_filter_field(self, filter_obj): self.filters.append(filter_obj) def add_calculated_field(self, calculated_field): - """ Add a CalculatedField instance to your query which + """Add a CalculatedField instance to your query which will add the desired scalar function to your compiled query (ie. CONCAT(field_a, field_b)) Args:
diff --git a/data_validation/result_handlers/text.py b/data_validation/result_handlers/text.py
index 6e28d7e8a..da7e2dbf0 100644
--- a/data_validation/result_handlers/text.py
+++ b/data_validation/result_handlers/text.py
@@ -38,8 +38,10 @@ def print_formatted_(format, result_df): elif format == "table": print(result_df.to_markdown(tablefmt="fancy_grid")) else: - error_msg = f"format [{format}] not supported, results printed in default(table) mode. " \ - f"Supported formats are [text, csv, json, table]" + error_msg = ( + f"format [{format}] not supported, results printed in default(table) mode. " + f"Supported formats are [text, csv, json, table]" + ) print(result_df.to_markdown(tablefmt="fancy_grid")) raise ValueError(error_msg)
diff --git a/data_validation/schema_validation.py b/data_validation/schema_validation.py
index cdb60087e..702555d09 100644
--- a/data_validation/schema_validation.py
+++ b/data_validation/schema_validation.py
@@ -32,7 +32,7 @@ def __init__(self, config_manager, run_metadata=None, verbose=False): self.run_metadata = run_metadata or metadata.RunMetadata() def execute(self): - """ Performs a validation between source and a target schema""" + """Performs a validation between source and a target schema""" ibis_source_schema = self.config_manager.source_client.get_schema( self.config_manager.source_table, self.config_manager.source_schema )
@@ -86,7 +86,7 @@ def execute(self): def schema_validation_matching(source_fields, target_fields): - """ Compare schemas between two dictionary objects """ + """Compare schemas between two dictionary objects""" results = [] # Go through each source and check if target exists and matches for source_field_name, source_field_type in source_fields.items():
diff --git a/data_validation/validation_builder.py b/data_validation/validation_builder.py
index 259647510..e7232b58a 100644
--- a/data_validation/validation_builder.py
+++ b/data_validation/validation_builder.py
@@ -67,7 +67,7 @@ def clone(self): @staticmethod def get_query_builder(validation_type): - """ Return Query Builder object given validation type """ + """Return Query Builder object given validation type""" if validation_type in ["Column", "GroupedColumn", "Row", "Schema"]: builder = QueryBuilder.build_count_validator() else:
@@ -89,11 +89,11 @@ def get_metadata(self): return self._metadata def get_group_aliases(self): - """ Return List of String Aliases """ + """Return List of String Aliases""" return self.group_aliases.keys() def get_calculated_aliases(self): - """ Return List of String Aliases """ + """Return List of String Aliases""" return self.calculated_aliases.keys() def get_grouped_alias_source_column(self, alias):
@@ -103,26 +103,26 @@ def get_grouped_alias_target_column(self, alias): return self.group_aliases[alias][consts.CONFIG_TARGET_COLUMN] def add_config_aggregates(self): - """ Add Aggregations to Query """ + """Add Aggregations to Query""" aggregate_fields = self.config_manager.aggregates for aggregate_field in aggregate_fields: self.add_aggregate(aggregate_field) def add_config_calculated_fields(self): - """ Add calculated fields to Query """ + """Add calculated fields to Query""" calc_fields = self.config_manager.calculated_fields if calc_fields is not None: for calc_field in calc_fields: self.add_calc(calc_field) def add_config_query_groups(self, query_groups=None): - """ Add Grouped Columns to Query """ + """Add Grouped Columns to Query""" grouped_fields = query_groups or self.config_manager.query_groups for grouped_field in grouped_fields: self.add_query_group(grouped_field) def add_config_filters(self): - """ Add Filters to Query """ + """Add Filters to Query""" filter_fields = self.config_manager.filters for filter_field in filter_fields: self.add_filter(filter_field)
@@ -162,7 +162,7 @@ def add_aggregate(self, aggregate_field): ) def pop_grouped_fields(self): - """ Return grouped fields and reset configs.""" + """Return grouped fields and reset configs.""" self.source_builder.grouped_fields = [] self.target_builder.grouped_fields = [] self.group_aliases = {}
@@ -219,7 +219,7 @@ def add_filter(self, filter_field): self.target_builder.add_filter_field(target_filter) def add_calc(self, calc_field): - """ Add CalculatedField to Queries + """Add CalculatedField to Queries Args: calc_field (Dict): An object with source, target, and cast info
@@ -247,7 +247,7 @@ def add_calc(self, calc_field): self.calculated_aliases[alias] = calc_field def get_source_query(self): - """ Return query for source validation """ + """Return query for source validation""" source_config = { "data_client": self.source_client, "schema_name": self.config_manager.source_schema,
@@ -261,7 +261,7 @@ def get_source_query(self): return query def get_target_query(self): - """ Return query for source validation """ + """Return query for source validation""" target_config = { "data_client": self.target_client, "schema_name": self.config_manager.target_schema,
diff --git a/tests/system/data_sources/deploy_cloudsql/cloudsql_resource_manager.py b/tests/system/data_sources/deploy_cloudsql/cloudsql_resource_manager.py
index 8d3a03958..eef054e0c 100644
--- a/tests/system/data_sources/deploy_cloudsql/cloudsql_resource_manager.py
+++ b/tests/system/data_sources/deploy_cloudsql/cloudsql_resource_manager.py
@@ -36,7 +36,7 @@ def __init__( enable_bin_logs=True, already_exists=False, ): - """Initialize a CloudSQLResourceManager """ + """Initialize a CloudSQLResourceManager""" if database_type not in DATABASE_TYPES: raise ValueError( f"Invalid database type. Must be of the form {str(DATABASE_TYPES)}" )
@@ -55,14 +55,14 @@ def __init__( self.db = {} def describe(self): - """ Returns description of resource manager instance """ + """Returns description of resource manager instance""" print( f"Creates a {self._database_type} instance in project {self._project_id} with " f"database_id: {self._database_id}, instance_id: {self._instance_id}." ) def setup(self): - """ Creates Cloud SQL instance and database """ + """Creates Cloud SQL instance and database""" with GCloudContext(self._project_id) as gcloud: if self._already_exists: json_describe = gcloud.Run(
@@ -114,7 +114,7 @@ def setup(self): return self.db["PRIMARY_ADDRESS"] def add_data(self, gcs_data_path): - """ Adds data to Cloud SQL database """ + """Adds data to Cloud SQL database""" if self._already_exists: return with GCloudContext(self._project_id) as gcloud:
@@ -129,7 +129,7 @@ def add_data(self, gcs_data_path): ) def teardown(self): - """ Deletes Cloud SQL instance """ + """Deletes Cloud SQL instance""" # If instance is deleted per integration test, instance_id will need a random # suffix appended since Cloud SQL cannot re-use the same instance name until # 1 week after deletion.
@@ -137,7 +137,7 @@ def teardown(self): gcloud.Run("--quiet", "sql", "instances", "delete", self._instance_id) def _get_random_string(self, length=5): - """ Returns random string + """Returns random string Args: length (int): Desired length of random string""" return "".join(random.choice(string.ascii_lowercase) for i in range(length))
diff --git a/tests/system/data_sources/deploy_cloudsql/gcloud_context.py b/tests/system/data_sources/deploy_cloudsql/gcloud_context.py
index 5175a761f..be150daa6 100644
--- a/tests/system/data_sources/deploy_cloudsql/gcloud_context.py
+++ b/tests/system/data_sources/deploy_cloudsql/gcloud_context.py
@@ -33,7 +33,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback): pass def Run(self, *args, **kwargs): - """ Runs gcloud command and returns output""" + """Runs gcloud command and returns output""" env = kwargs.pop("env", None) if not env: env = os.environ.copy()
diff --git a/tests/system/data_sources/test_bigquery.py b/tests/system/data_sources/test_bigquery.py
index 21d9263d5..2d7b95901 100644
--- a/tests/system/data_sources/test_bigquery.py
+++ b/tests/system/data_sources/test_bigquery.py
@@ -179,7 +179,9 @@ def test_count_validator(): - validator = data_validation.DataValidation(CONFIG_COUNT_VALID, format="text", verbose=True) + validator = data_validation.DataValidation( + CONFIG_COUNT_VALID, format="text", verbose=True + ) df = validator.execute() count_value = df[df["validation_name"] == "count"]["source_agg_value"].values[0]
@@ -202,12 +204,15 @@ def test_count_validator(): assert float(max_birth_year_value) > 0 assert float(min_birth_year_value) > 0 assert ( - df["source_agg_value"].astype(float).sum() == df["target_agg_value"].astype(float).sum() + df["source_agg_value"].astype(float).sum() + == df["target_agg_value"].astype(float).sum() ) def test_grouped_count_validator(): - validator = data_validation.DataValidation(CONFIG_GROUPED_COUNT_VALID, format="csv", verbose=True) + validator = data_validation.DataValidation( + CONFIG_GROUPED_COUNT_VALID, format="csv", verbose=True + ) df = validator.execute() rows = list(df[df["validation_name"] == "count"].iterrows())
@@ -223,7 +228,9 @@ def test_grouped_count_validator(): def test_numeric_types(): - validator = data_validation.DataValidation(CONFIG_NUMERIC_AGG_VALID, format="json", verbose=True) + validator = data_validation.DataValidation( + CONFIG_NUMERIC_AGG_VALID, format="json", verbose=True + ) df = validator.execute() for validation in df.to_dict(orient="records"):
@@ -282,7 +289,9 @@ def _remove_bq_conn(): def test_unsupported_result_format(): with pytest.raises(ValueError): - validator = data_validation.DataValidation(CONFIG_GROUPED_COUNT_VALID, format="foobar", verbose=True) + validator = data_validation.DataValidation( + CONFIG_GROUPED_COUNT_VALID, format="foobar", verbose=True + ) df = validator.execute() rows = list(df[df["validation_name"] == "count"].iterrows()) assert len(rows) > 1
diff --git a/tests/system/data_sources/test_mysql.py b/tests/system/data_sources/test_mysql.py
index 626f17cd1..e6563c6a4 100644
--- a/tests/system/data_sources/test_mysql.py
+++ b/tests/system/data_sources/test_mysql.py
@@ -52,7 +52,8 @@ def test_mysql_count_invalid_host(): try: data_validator = data_validation.DataValidation( - CONFIG_COUNT_VALID, verbose=False, + CONFIG_COUNT_VALID, + verbose=False, ) df = data_validator.execute() assert df["source_agg_value"][0] == df["target_agg_value"][0]
diff --git a/tests/system/data_sources/test_postgres.py b/tests/system/data_sources/test_postgres.py
index 5295bf43c..bccf5d1a0 100644
--- a/tests/system/data_sources/test_postgres.py
+++ b/tests/system/data_sources/test_postgres.py
@@ -28,7 +28,7 @@ def test_postgres_count(): - """ Test count validation on Postgres instance """ + """Test count validation on Postgres instance""" postgres_instance = CloudSQLResourceManager( PROJECT_ID, "POSTGRES_12",
@@ -76,6 +76,9 @@ def test_postgres_count(): ], } - data_validator = data_validation.DataValidation(config_count_valid, verbose=False,) + data_validator = data_validation.DataValidation( + config_count_valid, + verbose=False, + ) df = data_validator.execute() assert df["source_agg_value"][0] == df["target_agg_value"][0]
diff --git a/tests/system/data_sources/test_spanner.py b/tests/system/data_sources/test_spanner.py
index 24aafd080..ee74deddb 100644
--- a/tests/system/data_sources/test_spanner.py
+++ b/tests/system/data_sources/test_spanner.py
@@ -179,9 +179,13 @@ def test_count_validator(count_config): assert float(count_value) > 0 assert float(count_string_value) > 0 assert float(avg_float_value) > 0 - assert datetime.datetime.strptime( - max_timestamp_value, "%Y-%m-%d %H:%M:%S%z", - ) > datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc) + assert ( + datetime.datetime.strptime( + max_timestamp_value, + "%Y-%m-%d %H:%M:%S%z", + ) + > datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc) + ) assert float(min_int_value) > 0
diff --git a/tests/system/data_sources/test_sql_server.py b/tests/system/data_sources/test_sql_server.py
index 6382cc65d..fa58c2adc 100644
--- a/tests/system/data_sources/test_sql_server.py
+++ b/tests/system/data_sources/test_sql_server.py
@@ -28,7 +28,7 @@ def test_sql_server_count(): - """ Test count validation on SQL Server instance """ + """Test count validation on SQL Server instance""" mssql_instance = CloudSQLResourceManager( PROJECT_ID, "SQLSERVER_2017_STANDARD",
@@ -76,6 +76,9 @@ def test_sql_server_count(): ], } - data_validator = data_validation.DataValidation(config_count_valid, verbose=False,) + data_validator = data_validation.DataValidation( + config_count_valid, + verbose=False, + ) df = data_validator.execute() assert df["source_agg_value"][0] == df["target_agg_value"][0]
diff --git a/tests/unit/result_handlers/test_text.py b/tests/unit/result_handlers/test_text.py
index cd11b1b0d..540c20cbc 100644
--- a/tests/unit/result_handlers/test_text.py
+++ b/tests/unit/result_handlers/test_text.py
@@ -32,12 +32,12 @@ def module_under_test(): def test_import(module_under_test): - """Test import cleanly """ + """Test import cleanly""" assert module_under_test is not None def test_basic_result_handler(module_under_test): - """Test basic handler executes """ + """Test basic handler executes""" format = "json" result_df = DataFrame(SAMPLE_RESULT_DATA) result_handler = module_under_test.TextResultHandler()
diff --git a/tests/unit/test__main.py b/tests/unit/test__main.py
index 4e1077807..b386ca11c 100644
--- a/tests/unit/test__main.py
+++ b/tests/unit/test__main.py
@@ -58,7 +58,8 @@ @mock.patch( - "argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**CLI_ARGS), + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace(**CLI_ARGS), ) def test_configure_arg_parser(mock_args): """Test arg parser values."""
diff --git a/tests/unit/test_cli_tools.py b/tests/unit/test_cli_tools.py
index ff7321cb0..5175ebf11 100644
--- a/tests/unit/test_cli_tools.py
+++ b/tests/unit/test_cli_tools.py
@@ -56,7 +56,8 @@ @mock.patch( - "argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**CLI_ARGS), + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace(**CLI_ARGS), ) def test_get_parsed_args(mock_args): """Test arg parser values."""
@@ -138,13 +139,14 @@ def test_get_labels(test_input, expected): ], ) def test_get_labels_err(test_input): - """Ensure that Value Error is raised when incorrect label argument is provided. """ + """Ensure that Value Error is raised when incorrect label argument is provided.""" with pytest.raises(ValueError): cli_tools.get_labels(test_input) @pytest.mark.parametrize( - "test_input,expected", [(0, 0.0), (50, 50.0), (100, 100.0)], + "test_input,expected", + [(0, 0.0), (50, 50.0), (100, 100.0)], ) def test_threshold_float(test_input, expected): """Test threshold float function."""
@@ -153,7 +155,8 @@ def test_threshold_float(test_input, expected): @pytest.mark.parametrize( - "test_input", [(-4), (float("nan")), (float("inf")), ("string")], + "test_input", + [(-4), (float("nan")), (float("inf")), ("string")], ) def test_threshold_float_err(test_input): """Test that threshold float only accepts positive floats."""
@@ -300,16 +303,17 @@ def test_get_result_handler(test_input, expected): ], ) def test_get_filters(test_input, expected): - """ Test get filters from file function. """ + """Test get filters from file function.""" res = cli_tools.get_filters(test_input) assert res == expected @pytest.mark.parametrize( - "test_input", [("source:"), ("invalid:filter:count")], + "test_input", + [("source:"), ("invalid:filter:count")], ) def test_get_filters_err(test_input): - """ Test get filters function returns error. """ + """Test get filters function returns error.""" with pytest.raises(ValueError): cli_tools.get_filters(test_input)
@@ -336,9 +340,12 @@ def test_split_table_no_schema(): @pytest.mark.parametrize( - "test_input", [(["table"])], + "test_input", + [(["table"])], ) -def test_split_table_err(test_input,): +def test_split_table_err( + test_input, +): """Test split table throws the right errors.""" with pytest.raises(ValueError): cli_tools.split_table(test_input)
diff --git a/tests/unit/test_config_manager.py b/tests/unit/test_config_manager.py
index 35ba739df..6e026dca3 100644
--- a/tests/unit/test_config_manager.py
+++ b/tests/unit/test_config_manager.py
@@ -87,7 +87,7 @@ def module_under_test(): def test_import(module_under_test): - """Test import cleanly """ + """Test import cleanly""" assert module_under_test is not None
@@ -170,7 +170,7 @@ def test_process_in_memory(module_under_test): def test_get_table_info(module_under_test): - """Test basic handler executes """ + """Test basic handler executes""" config_manager = module_under_test.ConfigManager( SAMPLE_CONFIG, MockIbisClient(), MockIbisClient(), verbose=False )
diff --git a/tests/unit/test_data_validation.py b/tests/unit/test_data_validation.py
index d2a79308f..88b553f2b 100644
--- a/tests/unit/test_data_validation.py
+++ b/tests/unit/test_data_validation.py
@@ -258,7 +258,7 @@ def module_under_test(): def _create_table_file(table_path, data): - """ Create JSON File """ + """Create JSON File""" with open(table_path, "w") as f: f.write(data)
@@ -312,7 +312,7 @@ def test_import(module_under_test): def test_data_validation_client(module_under_test, fs): - """ Test getting a Data Validation Client """ + """Test getting a Data Validation Client""" _create_table_file(SOURCE_TABLE_FILE_PATH, JSON_DATA) _create_table_file(TARGET_TABLE_FILE_PATH, JSON_DATA)
@@ -322,7 +322,7 @@ def test_data_validation_client(module_under_test, fs): def test_get_pandas_schema(module_under_test): - """ Test extracting pandas schema from dataframes for Ibis Pandas.""" + """Test extracting pandas schema from dataframes for Ibis Pandas.""" pandas_schema = module_under_test.DataValidation._get_pandas_schema( SOURCE_DF, SOURCE_DF, JOIN_ON_FIELDS )
diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py
index 639df169f..afcafd42f 100644
--- a/tests/unit/test_metadata.py
+++ b/tests/unit/test_metadata.py
@@ -43,7 +43,15 @@ def test_get_column_name( def test_get_column_name_with_unexpected_result_type(module_under_test): validation = module_under_test.ValidationMetadata( - "", "", "", "", "", "", "", "", "", + "", + "", + "", + "", + "", + "", + "", + "", + "", ) with pytest.raises(ValueError, match="Unexpected result_type"): validation.get_column_name("oops_i_goofed")
diff --git a/tests/unit/test_schema_validation.py b/tests/unit/test_schema_validation.py
index da515798e..46e55b6c0 100644
--- a/tests/unit/test_schema_validation.py
+++ b/tests/unit/test_schema_validation.py
@@ -77,7 +77,7 @@ def module_under_test(): def _create_table_file(table_path, data): - """ Create JSON File """ + """Create JSON File""" with open(table_path, "w") as f: f.write(data)