diff --git a/data_validation/__main__.py b/data_validation/__main__.py
index 870cfc94a..beaf0ecbf 100644
--- a/data_validation/__main__.py
+++ b/data_validation/__main__.py
@@ -146,7 +146,7 @@ def build_config_managers_from_args(args):
             result_handler_config=result_handler_config,
             filter_config=filter_config,
             verbose=args.verbose,
-            format=args.format,
+            format=args.format
         )
         configs.append(build_config_from_args(args, config_manager))
@@ -326,7 +326,7 @@ def run(args):


 def run_connections(args):
-    """Run commands related to connection management."""
+    """ Run commands related to connection management."""
     if args.connect_cmd == "list":
         cli_tools.list_connections()
     elif args.connect_cmd == "add":
diff --git a/data_validation/cli_tools.py b/data_validation/cli_tools.py
index 090c69862..41e189e8f 100644
--- a/data_validation/cli_tools.py
+++ b/data_validation/cli_tools.py
@@ -128,7 +128,7 @@ def get_parsed_args():
-    """Return ArgParser with configured CLI arguments."""
+    """ Return ArgParser with configured CLI arguments."""
     parser = configure_arg_parser()
     return parser.parse_args()
@@ -194,7 +194,7 @@ def _configure_raw_query(subparsers):


 def _configure_run_config_parser(subparsers):
-    """Configure arguments to run a data validation YAML config."""
+    """ Configure arguments to run a data validation YAML config."""
     run_config_parser = subparsers.add_parser(
         "run-config", help="Run validations stored in a YAML config file"
     )
@@ -206,7 +206,7 @@ def _configure_run_parser(subparsers):


 def _configure_run_parser(subparsers):
-    """Configure arguments to run a data validation."""
+    """ Configure arguments to run a data validation."""

     # subparsers = parser.add_subparsers(dest="command")
@@ -273,9 +273,7 @@ def _configure_run_parser(subparsers):
         help="Store the validation in the YAML Config File Path specified",
     )
     run_parser.add_argument(
-        "--labels",
-        "-l",
-        help="Key value pair labels for validation run",
+        "--labels", "-l", help="Key value pair labels for validation run",
     )
     run_parser.add_argument(
         "--service-account",
@@ -302,7 +300,7 @@ def _configure_run_parser(subparsers):


 def _configure_connection_parser(subparsers):
-    """Configure the Parser for Connection Management."""
+    """ Configure the Parser for Connection Management.
""" connection_parser = subparsers.add_parser( "connections", help="Manage & Store connections to your Databases" ) @@ -337,7 +335,7 @@ def _configure_database_specific_parsers(parser): def get_connection_config_from_args(args): - """Return dict with connection config supplied.""" + """ Return dict with connection config supplied.""" config = {consts.SOURCE_TYPE: args.connect_type} if args.connect_type == "Raw": @@ -390,7 +388,7 @@ def _generate_random_name(conn): def store_connection(connection_name, conn): - """Store the connection config under the given name.""" + """ Store the connection config under the given name.""" connection_name = connection_name or _generate_random_name(conn) file_path = _get_connection_file(connection_name) @@ -399,7 +397,7 @@ def store_connection(connection_name, conn): def get_connections(): - """Return dict with connection name and path key pairs.""" + """ Return dict with connection name and path key pairs.""" connections = {} dir_path = _get_data_validation_directory() @@ -415,7 +413,7 @@ def get_connections(): def list_connections(): - """List all saved connections.""" + """ List all saved connections.""" connections = get_connections() for conn_name in connections: @@ -423,7 +421,7 @@ def list_connections(): def get_connection(connection_name): - """Return dict connection details for a specific connection.""" + """ Return dict connection details for a specific connection.""" file_path = _get_connection_file(connection_name) with open(file_path, "r") as file: conn_str = file.read() @@ -432,7 +430,7 @@ def get_connection(connection_name): def get_labels(arg_labels): - """Return list of tuples representing key-value label pairs.""" + """ Return list of tuples representing key-value label pairs. """ labels = [] if arg_labels: pairs = arg_labels.split(",") @@ -517,7 +515,7 @@ def get_arg_list(arg_value, default_value=None): def get_tables_list(arg_tables, default_value=None, is_filesystem=False): - """Returns dictionary of tables. Backwards compatible for JSON input. + """ Returns dictionary of tables. Backwards compatible for JSON input. arg_table (str): tables_list argument specified default_value (Any): A default value to supply when arg_value is empty. @@ -573,7 +571,7 @@ def get_tables_list(arg_tables, default_value=None, is_filesystem=False): def split_table(table_ref, schema_required=True): - """Returns schema and table name given list of input values. + """ Returns schema and table name given list of input values. table_ref (List): Table reference i.e ['my.schema.my_table'] scehma_required (boolean): Indicates whether schema is required. A source diff --git a/data_validation/clients.py b/data_validation/clients.py index b07261172..734488647 100644 --- a/data_validation/clients.py +++ b/data_validation/clients.py @@ -179,7 +179,7 @@ def get_all_tables(client, allowed_schemas=None): def get_data_client(connection_config): - """Return DataClient client from given configuration""" + """ Return DataClient client from given configuration """ connection_config = copy.deepcopy(connection_config) source_type = connection_config.pop(consts.SOURCE_TYPE) diff --git a/data_validation/combiner.py b/data_validation/combiner.py index 251511668..9055cd3e7 100644 --- a/data_validation/combiner.py +++ b/data_validation/combiner.py @@ -32,12 +32,7 @@ def generate_report( - client, - run_metadata, - source, - target, - join_on_fields=(), - verbose=False, + client, run_metadata, source, target, join_on_fields=(), verbose=False, ): """Combine results into a report. 
diff --git a/data_validation/config_manager.py b/data_validation/config_manager.py
index 1f10d44c9..7adae8c5e 100644
--- a/data_validation/config_manager.py
+++ b/data_validation/config_manager.py
@@ -82,12 +82,12 @@ def process_in_memory(self):
     @property
     def max_recursive_query_size(self):
-        """Return Aggregates from Config"""
+        """Return Aggregates from Config """
         return self._config.get(consts.CONFIG_MAX_RECURSIVE_QUERY_SIZE, 50000)

     @property
     def aggregates(self):
-        """Return Aggregates from Config"""
+        """Return Aggregates from Config """
         return self._config.get(consts.CONFIG_AGGREGATES, [])

     def append_aggregates(self, aggregate_configs):
@@ -105,7 +105,7 @@ def append_calculated_fields(self, calculated_configs):
     @property
     def query_groups(self):
-        """Return Query Groups from Config"""
+        """ Return Query Groups from Config """
         return self._config.get(consts.CONFIG_GROUPED_COLUMNS, [])

     def append_query_groups(self, grouped_column_configs):
@@ -116,7 +116,7 @@ def append_query_groups(self, grouped_column_configs):
     @property
     def primary_keys(self):
-        """Return Query Groups from Config"""
+        """ Return Query Groups from Config """
         return self._config.get(consts.CONFIG_PRIMARY_KEYS, [])

     def append_primary_keys(self, primary_key_configs):
@@ -127,7 +127,7 @@ def append_primary_keys(self, primary_key_configs):
     @property
     def filters(self):
-        """Return Filters from Config"""
+        """Return Filters from Config """
         return self._config.get(consts.CONFIG_FILTERS, [])

     @property
@@ -185,12 +185,12 @@ def query_limit(self):
     @property
     def threshold(self):
-        """Return threshold from Config"""
+        """Return threshold from Config """
         return self._config.get(consts.CONFIG_THRESHOLD, 0.0)

     @property
     def format(self):
-        """Return threshold from Config"""
+        """Return threshold from Config """
         return self._config.get(consts.CONFIG_FORMAT, "table")

     def get_source_ibis_table(self):
@@ -253,10 +253,8 @@ def get_result_handler(self):
             consts.GOOGLE_SERVICE_ACCOUNT_KEY_PATH
         )
         if key_path:
-            credentials = (
-                google.oauth2.service_account.Credentials.from_service_account_file(
-                    key_path
-                )
+            credentials = google.oauth2.service_account.Credentials.from_service_account_file(
+                key_path
             )
         else:
             credentials = None
diff --git a/data_validation/data_validation.py b/data_validation/data_validation.py
index ea6f5fd8b..38c7daddc 100644
--- a/data_validation/data_validation.py
+++ b/data_validation/data_validation.py
@@ -90,14 +90,14 @@ def __init__(
     # TODO(dhercher) we planned on shifting this to use an Execution Handler.
     # Leaving to to swast on the design of how this should look.
     def execute(self):
-        """Execute Queries and Store Results"""
+        """ Execute Queries and Store Results """
         if self.config_manager.validation_type == consts.ROW_VALIDATION:
             grouped_fields = self.validation_builder.pop_grouped_fields()
             result_df = self.execute_recursive_validation(
                 self.validation_builder, grouped_fields
             )
         elif self.config_manager.validation_type == consts.SCHEMA_VALIDATION:
-            """Perform only schema validation"""
+            """ Perform only schema validation """
             result_df = self.schema_validator.execute()
         else:
             result_df = self._execute_validation(
@@ -108,15 +108,15 @@ def execute(self):
         return self.result_handler.execute(self.config, self.format, result_df)

     def query_too_large(self, rows_df, grouped_fields):
-        """Return bool to dictate if another level of recursion
-        would create a too large result set.
-
-        Rules to define too large are:
-          - If any grouped fields remain, return False.
-            (assumes user added logical sized groups)
-          - Else, if next group size is larger
-            than the limit, return True.
-          - Finally return False if no covered case occured.
+        """ Return bool to dictate if another level of recursion
+        would create a too large result set.
+
+        Rules to define too large are:
+          - If any grouped fields remain, return False.
+            (assumes user added logical sized groups)
+          - Else, if next group size is larger
+            than the limit, return True.
+          - Finally return False if no covered case occured.
         """
         if len(grouped_fields) > 1:
             return False
@@ -203,7 +203,7 @@ def execute_recursive_validation(self, validation_builder, grouped_fields):
         return pandas.concat(past_results)

     def _add_recursive_validation_filter(self, validation_builder, row):
-        """Return ValidationBuilder Configured for Next Recursive Search"""
+        """ Return ValidationBuilder Configured for Next Recursive Search """
         group_by_columns = json.loads(row[consts.GROUP_BY_COLUMNS])
         for alias, value in group_by_columns.items():
             filter_field = {
@@ -250,7 +250,7 @@ def _get_pandas_schema(self, source_df, target_df, join_on_fields, verbose=False
         return pd_schema

     def _execute_validation(self, validation_builder, process_in_memory=True):
-        """Execute Against a Supplied Validation Builder"""
+        """ Execute Against a Supplied Validation Builder """
         self.run_metadata.validations = validation_builder.get_metadata()

         source_query = validation_builder.get_source_query()
@@ -301,7 +301,7 @@ def _execute_validation(self, validation_builder, process_in_memory=True):
         return result_df

     def combine_data(self, source_df, target_df, join_on_fields):
-        """TODO: Return List of Dictionaries"""
+        """ TODO: Return List of Dictionaries """
         # Clean Data to Standardize
         if join_on_fields:
             df = source_df.merge(
diff --git a/data_validation/query_builder/query_builder.py b/data_validation/query_builder/query_builder.py
index fb5cb1c87..bee93ee16 100644
--- a/data_validation/query_builder/query_builder.py
+++ b/data_validation/query_builder/query_builder.py
@@ -37,15 +37,11 @@ def __init__(self, ibis_expr, field_name=None, alias=None):
     def count(field_name=None, alias=None):
         if field_name:
             return AggregateField(
-                ibis.expr.types.ColumnExpr.count,
-                field_name=field_name,
-                alias=alias,
+                ibis.expr.types.ColumnExpr.count, field_name=field_name, alias=alias,
             )
         else:
             return AggregateField(
-                ibis.expr.types.TableExpr.count,
-                field_name=field_name,
-                alias=alias,
+                ibis.expr.types.TableExpr.count, field_name=field_name, alias=alias,
             )

     @staticmethod
@@ -63,25 +59,19 @@ def avg(field_name=None, alias=None):
     @staticmethod
     def max(field_name=None, alias=None):
         return AggregateField(
-            ibis.expr.types.ColumnExpr.max,
-            field_name=field_name,
-            alias=alias,
+            ibis.expr.types.ColumnExpr.max, field_name=field_name, alias=alias,
         )

     @staticmethod
     def sum(field_name=None, alias=None):
         return AggregateField(
-            ibis.expr.api.IntegerColumn.sum,
-            field_name=field_name,
-            alias=alias,
+            ibis.expr.api.IntegerColumn.sum, field_name=field_name, alias=alias,
         )

     @staticmethod
     def bit_xor(field_name=None, alias=None):
         return AggregateField(
-            ibis.expr.api.IntegerColumn.bit_xor,
-            field_name=field_name,
-            alias=alias,
+            ibis.expr.api.IntegerColumn.bit_xor, field_name=field_name, alias=alias,
         )

     def compile(self, ibis_table):
@@ -206,7 +196,7 @@ def compile(self, ibis_table):

 class ColumnReference(object):
     def __init__(self, column_name):
-        """A representation of an calculated field to build a query.
+        """ A representation of an calculated field to build a query.

         Args:
             column_name (String): The column name used in a complex expr
@@ -214,7 +204,7 @@ def __init__(self, column_name):
         """
         self.column_name = column_name

     def compile(self, ibis_table):
-        """Return an ibis object referencing the column.
+        """ Return an ibis object referencing the column.

         Args:
             ibis_table (IbisTable): The table obj reference
@@ -225,7 +215,7 @@ class CalculatedField(object):
     def __init__(self, ibis_expr, config, fields, cast=None, **kwargs):
-        """A representation of an calculated field to build a query.
+        """ A representation of an calculated field to build a query.

         Args:
             config dict: Configurations object explaining calc field details
@@ -244,61 +234,37 @@ def concat(config, fields):
             fields = [config["default_concat_separator"], fields]
             cast = "string"
         return CalculatedField(
-            ibis.expr.api.StringValue.join,
-            config,
-            fields,
-            cast=cast,
+            ibis.expr.api.StringValue.join, config, fields, cast=cast,
         )

     @staticmethod
     def hash(config, fields):
         if config.get("default_hash_function") is None:
             how = "farm_fingerprint"
-        return CalculatedField(
-            ibis.expr.api.ValueExpr.hash,
-            config,
-            fields,
-            how=how,
-        )
+        return CalculatedField(ibis.expr.api.ValueExpr.hash, config, fields, how=how,)

     @staticmethod
     def ifnull(config, fields):
         if config.get("default_null_string") is None:
             config["default_string"] = ibis.literal("DEFAULT_REPLACEMENT_STRING")
             fields = [config["default_string"], fields[0]]
-        return CalculatedField(
-            ibis.expr.api.ValueExpr.fillna,
-            config,
-            fields,
-        )
+        return CalculatedField(ibis.expr.api.ValueExpr.fillna, config, fields,)

     @staticmethod
     def length(config, fields):
-        return CalculatedField(
-            ibis.expr.api.StringValue.length,
-            config,
-            fields,
-        )
+        return CalculatedField(ibis.expr.api.StringValue.length, config, fields,)

     @staticmethod
     def rstrip(config, fields):
-        return CalculatedField(
-            ibis.expr.api.StringValue.rstrip,
-            config,
-            fields,
-        )
+        return CalculatedField(ibis.expr.api.StringValue.rstrip, config, fields,)

     @staticmethod
     def upper(config, fields):
-        return CalculatedField(
-            ibis.expr.api.StringValue.upper,
-            config,
-            fields,
-        )
+        return CalculatedField(ibis.expr.api.StringValue.upper, config, fields,)

     @staticmethod
     def custom(expr):
-        """Returns a CalculatedField instance built for any custom SQL using a supported operator.
+        """ Returns a CalculatedField instance built for any custom SQL using a supported operator.

         Args:
             expr (Str): A custom SQL expression used to filter a query
         """
@@ -333,7 +299,7 @@ class QueryBuilder(object):
     def __init__(
         self, aggregate_fields, calculated_fields, filters, grouped_fields, limit=None
     ):
-        """Build a QueryBuilder object which can be used to build queries easily
+        """ Build a QueryBuilder object which can be used to build queries easily

         Args:
             aggregate_fields (list[AggregateField]): AggregateField instances with Ibis expressions
@@ -350,7 +316,7 @@ def __init__(
     @staticmethod
     def build_count_validator(limit=None):
-        """Return a basic template builder for most validations"""
+        """ Return a basic template builder for most validations """
         aggregate_fields = []
         filters = []
         grouped_fields = []
@@ -453,7 +419,7 @@ def add_filter_field(self, filter_obj):
         self.filters.append(filter_obj)

     def add_calculated_field(self, calculated_field):
-        """Add a CalculatedField instance to your query which
+        """ Add a CalculatedField instance to your query which
         will add the desired scalar function to your compiled query (ie. CONCAT(field_a, field_b))

         Args:
diff --git a/data_validation/result_handlers/text.py b/data_validation/result_handlers/text.py
index da7e2dbf0..6e28d7e8a 100644
--- a/data_validation/result_handlers/text.py
+++ b/data_validation/result_handlers/text.py
@@ -38,10 +38,8 @@ def print_formatted_(format, result_df):
     elif format == "table":
         print(result_df.to_markdown(tablefmt="fancy_grid"))
     else:
-        error_msg = (
-            f"format [{format}] not supported, results printed in default(table) mode. "
-            f"Supported formats are [text, csv, json, table]"
-        )
+        error_msg = f"format [{format}] not supported, results printed in default(table) mode. " \
+            f"Supported formats are [text, csv, json, table]"
         print(result_df.to_markdown(tablefmt="fancy_grid"))
         raise ValueError(error_msg)
diff --git a/data_validation/schema_validation.py b/data_validation/schema_validation.py
index 702555d09..cdb60087e 100644
--- a/data_validation/schema_validation.py
+++ b/data_validation/schema_validation.py
@@ -32,7 +32,7 @@ def __init__(self, config_manager, run_metadata=None, verbose=False):
         self.run_metadata = run_metadata or metadata.RunMetadata()

     def execute(self):
-        """Performs a validation between source and a target schema"""
+        """ Performs a validation between source and a target schema"""
         ibis_source_schema = self.config_manager.source_client.get_schema(
             self.config_manager.source_table, self.config_manager.source_schema
         )
@@ -86,7 +86,7 @@ def execute(self):


 def schema_validation_matching(source_fields, target_fields):
-    """Compare schemas between two dictionary objects"""
+    """ Compare schemas between two dictionary objects """
     results = []
     # Go through each source and check if target exists and matches
     for source_field_name, source_field_type in source_fields.items():
diff --git a/data_validation/validation_builder.py b/data_validation/validation_builder.py
index e7232b58a..259647510 100644
--- a/data_validation/validation_builder.py
+++ b/data_validation/validation_builder.py
@@ -67,7 +67,7 @@ def clone(self):
     @staticmethod
     def get_query_builder(validation_type):
-        """Return Query Builder object given validation type"""
+        """ Return Query Builder object given validation type """
         if validation_type in ["Column", "GroupedColumn", "Row", "Schema"]:
             builder = QueryBuilder.build_count_validator()
         else:
@@ -89,11 +89,11 @@ def get_metadata(self):
         return self._metadata

     def get_group_aliases(self):
-        """Return List of String Aliases"""
+        """ Return List of String Aliases """
         return self.group_aliases.keys()

     def get_calculated_aliases(self):
-        """Return List of String Aliases"""
+        """ Return List of String Aliases """
         return self.calculated_aliases.keys()

     def get_grouped_alias_source_column(self, alias):
@@ -103,26 +103,26 @@ def get_grouped_alias_target_column(self, alias):
         return self.group_aliases[alias][consts.CONFIG_TARGET_COLUMN]

     def add_config_aggregates(self):
-        """Add Aggregations to Query"""
+        """ Add Aggregations to Query """
         aggregate_fields = self.config_manager.aggregates
         for aggregate_field in aggregate_fields:
             self.add_aggregate(aggregate_field)

     def add_config_calculated_fields(self):
-        """Add calculated fields to Query"""
+        """ Add calculated fields to Query """
         calc_fields = self.config_manager.calculated_fields
         if calc_fields is not None:
             for calc_field in calc_fields:
                 self.add_calc(calc_field)

     def add_config_query_groups(self, query_groups=None):
-        """Add Grouped Columns to Query"""
+        """ Add Grouped Columns to Query """
         grouped_fields = query_groups or self.config_manager.query_groups
         for grouped_field in grouped_fields:
             self.add_query_group(grouped_field)

     def add_config_filters(self):
-        """Add Filters to Query"""
+        """ Add Filters to Query """
         filter_fields = self.config_manager.filters
         for filter_field in filter_fields:
             self.add_filter(filter_field)
@@ -162,7 +162,7 @@ def add_aggregate(self, aggregate_field):
         )

     def pop_grouped_fields(self):
-        """Return grouped fields and reset configs."""
+        """ Return grouped fields and reset configs."""
         self.source_builder.grouped_fields = []
         self.target_builder.grouped_fields = []
         self.group_aliases = {}
@@ -219,7 +219,7 @@ def add_filter(self, filter_field):
         self.target_builder.add_filter_field(target_filter)

     def add_calc(self, calc_field):
-        """Add CalculatedField to Queries
+        """ Add CalculatedField to Queries

         Args:
             calc_field (Dict): An object with source, target, and cast info
@@ -247,7 +247,7 @@ def add_calc(self, calc_field):
         self.calculated_aliases[alias] = calc_field

     def get_source_query(self):
-        """Return query for source validation"""
+        """ Return query for source validation """
         source_config = {
             "data_client": self.source_client,
             "schema_name": self.config_manager.source_schema,
@@ -261,7 +261,7 @@ def get_source_query(self):
         return query

     def get_target_query(self):
-        """Return query for source validation"""
+        """ Return query for source validation """
         target_config = {
             "data_client": self.target_client,
             "schema_name": self.config_manager.target_schema,
diff --git a/tests/system/data_sources/deploy_cloudsql/cloudsql_resource_manager.py b/tests/system/data_sources/deploy_cloudsql/cloudsql_resource_manager.py
index eef054e0c..8d3a03958 100644
--- a/tests/system/data_sources/deploy_cloudsql/cloudsql_resource_manager.py
+++ b/tests/system/data_sources/deploy_cloudsql/cloudsql_resource_manager.py
@@ -36,7 +36,7 @@ def __init__(
         enable_bin_logs=True,
         already_exists=False,
     ):
-        """Initialize a CloudSQLResourceManager"""
+        """Initialize a CloudSQLResourceManager """
         if database_type not in DATABASE_TYPES:
             raise ValueError(
                 f"Invalid database type. Must be of the form {str(DATABASE_TYPES)}"
             )
@@ -55,14 +55,14 @@ def __init__(
         self.db = {}

     def describe(self):
-        """Returns description of resource manager instance"""
+        """ Returns description of resource manager instance """
         print(
             f"Creates a {self._database_type} instance in project {self._project_id} with "
             f"database_id: {self._database_id}, instance_id: {self._instance_id}."
         )

     def setup(self):
-        """Creates Cloud SQL instance and database"""
+        """ Creates Cloud SQL instance and database """
         with GCloudContext(self._project_id) as gcloud:
             if self._already_exists:
                 json_describe = gcloud.Run(
@@ -114,7 +114,7 @@ def setup(self):
         return self.db["PRIMARY_ADDRESS"]

     def add_data(self, gcs_data_path):
-        """Adds data to Cloud SQL database"""
+        """ Adds data to Cloud SQL database """
         if self._already_exists:
             return

         with GCloudContext(self._project_id) as gcloud:
@@ -129,7 +129,7 @@ def add_data(self, gcs_data_path):
             )

     def teardown(self):
-        """Deletes Cloud SQL instance"""
+        """ Deletes Cloud SQL instance """
         # If instance is deleted per integration test, instance_id will need a random
         # suffix appended since Cloud SQL cannot re-use the same instance name until
         # 1 week after deletion.
@@ -137,7 +137,7 @@ def teardown(self):
             gcloud.Run("--quiet", "sql", "instances", "delete", self._instance_id)

     def _get_random_string(self, length=5):
-        """Returns random string
+        """ Returns random string

         Args:
             length (int): Desired length of random string"""
         return "".join(random.choice(string.ascii_lowercase) for i in range(length))
diff --git a/tests/system/data_sources/deploy_cloudsql/gcloud_context.py b/tests/system/data_sources/deploy_cloudsql/gcloud_context.py
index be150daa6..5175a761f 100644
--- a/tests/system/data_sources/deploy_cloudsql/gcloud_context.py
+++ b/tests/system/data_sources/deploy_cloudsql/gcloud_context.py
@@ -33,7 +33,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
         pass

     def Run(self, *args, **kwargs):
-        """Runs gcloud command and returns output"""
+        """ Runs gcloud command and returns output"""
         env = kwargs.pop("env", None)
         if not env:
             env = os.environ.copy()
diff --git a/tests/system/data_sources/test_bigquery.py b/tests/system/data_sources/test_bigquery.py
index 2d7b95901..21d9263d5 100644
--- a/tests/system/data_sources/test_bigquery.py
+++ b/tests/system/data_sources/test_bigquery.py
@@ -179,9 +179,7 @@


 def test_count_validator():
-    validator = data_validation.DataValidation(
-        CONFIG_COUNT_VALID, format="text", verbose=True
-    )
+    validator = data_validation.DataValidation(CONFIG_COUNT_VALID, format="text", verbose=True)
     df = validator.execute()

     count_value = df[df["validation_name"] == "count"]["source_agg_value"].values[0]
@@ -204,15 +202,12 @@ def test_count_validator():
     assert float(max_birth_year_value) > 0
     assert float(min_birth_year_value) > 0
     assert (
-        df["source_agg_value"].astype(float).sum()
-        == df["target_agg_value"].astype(float).sum()
+        df["source_agg_value"].astype(float).sum() == df["target_agg_value"].astype(float).sum()
     )


 def test_grouped_count_validator():
-    validator = data_validation.DataValidation(
-        CONFIG_GROUPED_COUNT_VALID, format="csv", verbose=True
-    )
+    validator = data_validation.DataValidation(CONFIG_GROUPED_COUNT_VALID, format="csv", verbose=True)
     df = validator.execute()
     rows = list(df[df["validation_name"] == "count"].iterrows())
@@ -228,9 +223,7 @@ def test_grouped_count_validator():


 def test_numeric_types():
-    validator = data_validation.DataValidation(
-        CONFIG_NUMERIC_AGG_VALID, format="json", verbose=True
-    )
+    validator = data_validation.DataValidation(CONFIG_NUMERIC_AGG_VALID, format="json", verbose=True)
     df = validator.execute()

     for validation in df.to_dict(orient="records"):
@@ -289,9 +282,7 @@ def _remove_bq_conn():

 def test_unsupported_result_format():
     with pytest.raises(ValueError):
-        validator = data_validation.DataValidation(
-            CONFIG_GROUPED_COUNT_VALID, format="foobar", verbose=True
-        )
+        validator = data_validation.DataValidation(CONFIG_GROUPED_COUNT_VALID, format="foobar", verbose=True)
         df = validator.execute()
         rows = list(df[df["validation_name"] == "count"].iterrows())
         assert len(rows) > 1
diff --git a/tests/system/data_sources/test_mysql.py b/tests/system/data_sources/test_mysql.py
index e6563c6a4..626f17cd1 100644
--- a/tests/system/data_sources/test_mysql.py
+++ b/tests/system/data_sources/test_mysql.py
@@ -52,8 +52,7 @@ def test_mysql_count_invalid_host():
     try:
         data_validator = data_validation.DataValidation(
-            CONFIG_COUNT_VALID,
-            verbose=False,
+            CONFIG_COUNT_VALID, verbose=False,
         )
         df = data_validator.execute()
         assert df["source_agg_value"][0] == df["target_agg_value"][0]
diff --git a/tests/system/data_sources/test_postgres.py b/tests/system/data_sources/test_postgres.py
index bccf5d1a0..5295bf43c 100644
--- a/tests/system/data_sources/test_postgres.py
+++ b/tests/system/data_sources/test_postgres.py
@@ -28,7 +28,7 @@


 def test_postgres_count():
-    """Test count validation on Postgres instance"""
+    """ Test count validation on Postgres instance """
     postgres_instance = CloudSQLResourceManager(
         PROJECT_ID,
         "POSTGRES_12",
@@ -76,9 +76,6 @@ def test_postgres_count():
         ],
     }

-    data_validator = data_validation.DataValidation(
-        config_count_valid,
-        verbose=False,
-    )
+    data_validator = data_validation.DataValidation(config_count_valid, verbose=False,)
     df = data_validator.execute()
     assert df["source_agg_value"][0] == df["target_agg_value"][0]
diff --git a/tests/system/data_sources/test_spanner.py b/tests/system/data_sources/test_spanner.py
index ee74deddb..24aafd080 100644
--- a/tests/system/data_sources/test_spanner.py
+++ b/tests/system/data_sources/test_spanner.py
@@ -179,13 +179,9 @@ def test_count_validator(count_config):
     assert float(count_value) > 0
     assert float(count_string_value) > 0
     assert float(avg_float_value) > 0
-    assert (
-        datetime.datetime.strptime(
-            max_timestamp_value,
-            "%Y-%m-%d %H:%M:%S%z",
-        )
-        > datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
-    )
+    assert datetime.datetime.strptime(
+        max_timestamp_value, "%Y-%m-%d %H:%M:%S%z",
+    ) > datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
     assert float(min_int_value) > 0
diff --git a/tests/system/data_sources/test_sql_server.py b/tests/system/data_sources/test_sql_server.py
index fa58c2adc..6382cc65d 100644
--- a/tests/system/data_sources/test_sql_server.py
+++ b/tests/system/data_sources/test_sql_server.py
@@ -28,7 +28,7 @@


 def test_sql_server_count():
-    """Test count validation on SQL Server instance"""
+    """ Test count validation on SQL Server instance """
     mssql_instance = CloudSQLResourceManager(
         PROJECT_ID,
         "SQLSERVER_2017_STANDARD",
@@ -76,9 +76,6 @@ def test_sql_server_count():
         ],
     }

-    data_validator = data_validation.DataValidation(
-        config_count_valid,
-        verbose=False,
-    )
+    data_validator = data_validation.DataValidation(config_count_valid, verbose=False,)
     df = data_validator.execute()
     assert df["source_agg_value"][0] == df["target_agg_value"][0]
diff --git a/tests/unit/result_handlers/test_text.py b/tests/unit/result_handlers/test_text.py
index 540c20cbc..cd11b1b0d 100644
--- a/tests/unit/result_handlers/test_text.py
+++ b/tests/unit/result_handlers/test_text.py
@@ -32,12 +32,12 @@ def module_under_test():


 def test_import(module_under_test):
-    """Test import cleanly"""
+    """Test import cleanly """
     assert module_under_test is not None


 def test_basic_result_handler(module_under_test):
-    """Test basic handler executes"""
+    """Test basic handler executes """
     format = "json"
     result_df = DataFrame(SAMPLE_RESULT_DATA)
     result_handler = module_under_test.TextResultHandler()
diff --git a/tests/unit/test__main.py b/tests/unit/test__main.py
index b386ca11c..4e1077807 100644
--- a/tests/unit/test__main.py
+++ b/tests/unit/test__main.py
@@ -58,8 +58,7 @@


 @mock.patch(
-    "argparse.ArgumentParser.parse_args",
-    return_value=argparse.Namespace(**CLI_ARGS),
+    "argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**CLI_ARGS),
 )
 def test_configure_arg_parser(mock_args):
     """Test arg parser values."""
diff --git a/tests/unit/test_cli_tools.py b/tests/unit/test_cli_tools.py
index 5175ebf11..ff7321cb0 100644
--- a/tests/unit/test_cli_tools.py
+++ b/tests/unit/test_cli_tools.py
@@ -56,8 +56,7 @@


 @mock.patch(
-    "argparse.ArgumentParser.parse_args",
-    return_value=argparse.Namespace(**CLI_ARGS),
+    "argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**CLI_ARGS),
 )
 def test_get_parsed_args(mock_args):
     """Test arg parser values."""
@@ -139,14 +138,13 @@ def test_get_labels(test_input, expected):
     ],
 )
 def test_get_labels_err(test_input):
-    """Ensure that Value Error is raised when incorrect label argument is provided."""
+    """Ensure that Value Error is raised when incorrect label argument is provided. """
     with pytest.raises(ValueError):
         cli_tools.get_labels(test_input)


 @pytest.mark.parametrize(
-    "test_input,expected",
-    [(0, 0.0), (50, 50.0), (100, 100.0)],
+    "test_input,expected", [(0, 0.0), (50, 50.0), (100, 100.0)],
 )
 def test_threshold_float(test_input, expected):
     """Test threshold float function."""
@@ -155,8 +153,7 @@ def test_threshold_float(test_input, expected):


 @pytest.mark.parametrize(
-    "test_input",
-    [(-4), (float("nan")), (float("inf")), ("string")],
+    "test_input", [(-4), (float("nan")), (float("inf")), ("string")],
 )
 def test_threshold_float_err(test_input):
     """Test that threshold float only accepts positive floats."""
@@ -303,17 +300,16 @@ def test_get_result_handler(test_input, expected):
     ],
 )
 def test_get_filters(test_input, expected):
-    """Test get filters from file function."""
+    """ Test get filters from file function. """
     res = cli_tools.get_filters(test_input)
     assert res == expected


 @pytest.mark.parametrize(
-    "test_input",
-    [("source:"), ("invalid:filter:count")],
+    "test_input", [("source:"), ("invalid:filter:count")],
 )
 def test_get_filters_err(test_input):
-    """Test get filters function returns error."""
+    """ Test get filters function returns error. """
     with pytest.raises(ValueError):
         cli_tools.get_filters(test_input)
@@ -340,12 +336,9 @@ def test_split_table_no_schema():


 @pytest.mark.parametrize(
-    "test_input",
-    [(["table"])],
+    "test_input", [(["table"])],
 )
-def test_split_table_err(
-    test_input,
-):
+def test_split_table_err(test_input,):
     """Test split table throws the right errors."""
     with pytest.raises(ValueError):
         cli_tools.split_table(test_input)
diff --git a/tests/unit/test_config_manager.py b/tests/unit/test_config_manager.py
index 6e026dca3..35ba739df 100644
--- a/tests/unit/test_config_manager.py
+++ b/tests/unit/test_config_manager.py
@@ -87,7 +87,7 @@ def module_under_test():


 def test_import(module_under_test):
-    """Test import cleanly"""
+    """Test import cleanly """
     assert module_under_test is not None
@@ -170,7 +170,7 @@ def test_process_in_memory(module_under_test):


 def test_get_table_info(module_under_test):
-    """Test basic handler executes"""
+    """Test basic handler executes """
     config_manager = module_under_test.ConfigManager(
         SAMPLE_CONFIG, MockIbisClient(), MockIbisClient(), verbose=False
     )
diff --git a/tests/unit/test_data_validation.py b/tests/unit/test_data_validation.py
index 88b553f2b..d2a79308f 100644
--- a/tests/unit/test_data_validation.py
+++ b/tests/unit/test_data_validation.py
@@ -258,7 +258,7 @@ def module_under_test():


 def _create_table_file(table_path, data):
-    """Create JSON File"""
+    """ Create JSON File """
     with open(table_path, "w") as f:
         f.write(data)
@@ -312,7 +312,7 @@ def test_import(module_under_test):


 def test_data_validation_client(module_under_test, fs):
-    """Test getting a Data Validation Client"""
+    """ Test getting a Data Validation Client """
     _create_table_file(SOURCE_TABLE_FILE_PATH, JSON_DATA)
     _create_table_file(TARGET_TABLE_FILE_PATH, JSON_DATA)
@@ -322,7 +322,7 @@ def test_data_validation_client(module_under_test, fs):


 def test_get_pandas_schema(module_under_test):
-    """Test extracting pandas schema from dataframes for Ibis Pandas."""
+    """ Test extracting pandas schema from dataframes for Ibis Pandas."""
     pandas_schema = module_under_test.DataValidation._get_pandas_schema(
         SOURCE_DF, SOURCE_DF, JOIN_ON_FIELDS
     )
diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py
index afcafd42f..639df169f 100644
--- a/tests/unit/test_metadata.py
+++ b/tests/unit/test_metadata.py
@@ -43,15 +43,7 @@ def test_get_column_name(

 def test_get_column_name_with_unexpected_result_type(module_under_test):
     validation = module_under_test.ValidationMetadata(
-        "",
-        "",
-        "",
-        "",
-        "",
-        "",
-        "",
-        "",
-        "",
+        "", "", "", "", "", "", "", "", "",
     )
     with pytest.raises(ValueError, match="Unexpected result_type"):
         validation.get_column_name("oops_i_goofed")
diff --git a/tests/unit/test_schema_validation.py b/tests/unit/test_schema_validation.py
index 46e55b6c0..da515798e 100644
--- a/tests/unit/test_schema_validation.py
+++ b/tests/unit/test_schema_validation.py
@@ -77,7 +77,7 @@ def module_under_test():


 def _create_table_file(table_path, data):
-    """Create JSON File"""
+    """ Create JSON File """
     with open(table_path, "w") as f:
         f.write(data)