From 3d78ee578b4222a9bdb19c7091445f48f413b9a0 Mon Sep 17 00:00:00 2001 From: Mike Hilton Date: Wed, 23 Feb 2022 13:51:41 -0800 Subject: [PATCH] feat: first class support for row level hashing (#345) * adding scaffolding for calc field builder in config manager * exposing cast via calculated fields. Don't know if we necessarily need this just adding for consistency * diff check * config file generating as expected * expanding cli for row level validations * splitting out comparison fields from aggregates * row comparisons operational (sort of) * re-enabling aggregate validations * cohabitation of validation types! * figuring out why unit tests are borked * continuing field split * stash before merge * testing diff * tests passing * removing extra print statements * tests and lint * adding fail tests * first round of requested changes * change requests round two. * refactor CLI and lint * swapping out farm fingerprint for sha256 as default * changes per CR * fixing text result tests * adding docs * hash example * linting * think I found the broken test * fixed tests * setting default for depth length * relaxing system test --- README.md | 51 ++++ data_validation/__main__.py | 92 +++++-- data_validation/cli_tools.py | 122 ++++++--- data_validation/combiner.py | 97 ++++---- data_validation/config_manager.py | 132 ++++++++-- data_validation/consts.py | 2 + data_validation/data_validation.py | 20 +- .../query_builder/query_builder.py | 109 ++++++-- data_validation/result_handlers/text.py | 11 + data_validation/validation_builder.py | 68 ++++- docs/examples.md | 7 +- tests/system/data_sources/test_bigquery.py | 7 +- tests/unit/result_handlers/test_text.py | 25 +- tests/unit/test_config_manager.py | 35 +-- tests/unit/test_data_validation.py | 232 +++++++++++++++--- 15 files changed, 798 insertions(+), 212 deletions(-) diff --git a/README.md b/README.md index 8d34b88d5..2d9842f94 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,51 @@ sum , min, etc.) is provided, the default aggregation will run. The [Examples](docs/examples.md) page provides many examples of how a tool can used to run powerful validations without writing any queries. +#### Row Validations + +Below is the command syntax for row validations. In order to run row level +validations you need to pass a `--primary-key` flag which defines what field(s) +the validation will be compared along, as well as a `--comparison-fields` flag +which specifies the values (e.g. columns) whose raw values will be compared +based on the primary key join. Additionally you can use +[Calculated Fields](#calculated-fields) to compare derived values such as string +counts and hashes of multiple columns. + +``` +data-validation (--verbose or -v) validate row + --source-conn or -sc SOURCE_CONN + Source connection details + See: *Data Source Configurations* section for each data source + --target-conn or -tc TARGET_CONN + Target connection details + See: *Connections* section for each data source + --tables-list or -tbls SOURCE_SCHEMA.SOURCE_TABLE=TARGET_SCHEMA.TARGET_TABLE + Comma separated list of tables in the form schema.table=target_schema.target_table + Target schema name and table name are optional. + i.e 'bigquery-public-data.new_york_citibike.citibike_trips' + [--primary-keys or -pk PRIMARY_KEYS] + Comma separated list of columns to use as primary keys + [--comparison-fields or -fields comparison-fields] + Comma separated list of columns to compare. 
Can either be a physical column or an alias + See: *Calculated Fields* section for details + [--hash COLUMNS] Comma separated list of columns to perform a hash operation on or * for all columns + [--bq-result-handler or -bqrh PROJECT_ID.DATASET.TABLE] + BigQuery destination for validation results. Defaults to stdout. + See: *Validation Reports* section + [--service-account or -sa PATH_TO_SA_KEY] + Service account to use for BigQuery result handler output. + [--filters SOURCE_FILTER:TARGET_FILTER] + Colon spearated string values of source and target filters. + If target filter is not provided, the source filter will run on source and target tables. + See: *Filters* section + [--config-file or -c CONFIG_FILE] + YAML Config File Path to be used for storing validations. + [--labels or -l KEY1=VALUE1,KEY2=VALUE2] + Comma-separated key value pair labels for the run. + [--format or -fmt] Format for stdout output. Supported formats are (text, csv, json, table). + Defaults to table. +``` + #### Schema Validations Below is the syntax for schema validations. These can be used to compare column @@ -289,6 +334,12 @@ Grouped Columns contain the fields you want your aggregations to be broken out by, e.g. `SELECT last_updated::DATE, COUNT(*) FROM my.table` will produce a resultset that breaks down the count of rows per calendar date. +### Comparison Fields + +For row validations you need to specify the specific columns that you want to +compare. These values will be compared via a JOIN on their corresponding primary +key and will be evaluated for an exact match. + ### Calculated Fields Sometimes direct comparisons are not feasible between databases due to diff --git a/data_validation/__main__.py b/data_validation/__main__.py index 389632c6e..35c5758ae 100644 --- a/data_validation/__main__.py +++ b/data_validation/__main__.py @@ -13,9 +13,9 @@ # limitations under the License. import os - -import logging import json +import sys +from yaml import Dumper, dump from data_validation import ( cli_tools, @@ -27,8 +27,9 @@ from data_validation.config_manager import ConfigManager from data_validation.data_validation import DataValidation -from yaml import dump -import sys + +# by default yaml dumps lists as pointers. This disables that feature +Dumper.ignore_aliases = lambda *args: True def _get_arg_config_file(args): @@ -78,30 +79,73 @@ def get_aggregate_config(args, config_manager): aggregate_configs += config_manager.build_config_column_aggregates( "max", col_args, consts.NUMERIC_DATA_TYPES ) + if args.bit_xor: + col_args = None if args.bit_xor == "*" else cli_tools.get_arg_list(args.bit_xor) + aggregate_configs += config_manager.build_config_column_aggregates( + "bit_xor", col_args, consts.NUMERIC_DATA_TYPES + ) return aggregate_configs +def get_calculated_config(args, config_manager): + """Return list of formatted calculated objects. + + Args: + config_manager(ConfigManager): Validation config manager instance. 
+ """ + calculated_configs = [] + fields = [] + if args.hash: + fields = config_manager._build_dependent_aliases("hash") + if len(fields) > 0: + max_depth = max([x["depth"] for x in fields]) + else: + max_depth = 0 + for field in fields: + calculated_configs.append( + config_manager.build_config_calculated_fields( + field["reference"], + field["calc_type"], + field["name"], + field["depth"], + None, + ) + ) + if args.hash: + config_manager.append_comparison_fields( + config_manager.build_config_comparison_fields( + ["hash__all"], depth=max_depth + ) + ) + return calculated_configs + + def build_config_from_args(args, config_manager): """Return config manager object ready to execute. Args: config_manager (ConfigManager): Validation config manager instance. """ - config_manager.append_aggregates(get_aggregate_config(args, config_manager)) - if args.primary_keys and not args.grouped_columns: - if not args.grouped_columns and not config_manager.use_random_rows(): - logging.warning( - "No Grouped columns or Random Rows specified, ignoring primary keys." + config_manager.append_calculated_fields(get_calculated_config(args, config_manager)) + if config_manager.validation_type == consts.COLUMN_VALIDATION: + config_manager.append_aggregates(get_aggregate_config(args, config_manager)) + if args.grouped_columns is not None: + grouped_columns = cli_tools.get_arg_list(args.grouped_columns) + config_manager.append_query_groups( + config_manager.build_config_grouped_columns(grouped_columns) ) - if args.grouped_columns: - grouped_columns = cli_tools.get_arg_list(args.grouped_columns) - config_manager.append_query_groups( - config_manager.build_config_grouped_columns(grouped_columns) - ) - if args.primary_keys: - primary_keys = cli_tools.get_arg_list(args.primary_keys, default_value=[]) + elif config_manager.validation_type == consts.ROW_VALIDATION: + if args.comparison_fields is not None: + comparison_fields = cli_tools.get_arg_list( + args.comparison_fields, default_value=[] + ) + config_manager.append_comparison_fields( + config_manager.build_config_comparison_fields(comparison_fields) + ) + if args.primary_keys is not None: + primary_keys = cli_tools.get_arg_list(args.primary_keys) config_manager.append_primary_keys( - config_manager.build_config_grouped_columns(primary_keys) + config_manager.build_config_comparison_fields(primary_keys) ) # TODO(GH#18): Add query filter config logic @@ -118,11 +162,9 @@ def build_config_managers_from_args(args): if validate_cmd == "Schema": config_type = consts.SCHEMA_VALIDATION elif validate_cmd == "Column": - # TODO: We need to discuss how GroupedColumn and Row are differentiated. 
- if args.grouped_columns: - config_type = consts.GROUPED_COLUMN_VALIDATION - else: - config_type = consts.COLUMN_VALIDATION + config_type = consts.COLUMN_VALIDATION + elif validate_cmd == "Row": + config_type = consts.ROW_VALIDATION else: raise ValueError(f"Unknown Validation Type: {validate_cmd}") else: @@ -140,7 +182,7 @@ def build_config_managers_from_args(args): # Schema validation will not accept filters, labels, or threshold as flags filter_config, labels, threshold = [], [], 0.0 - if config_type != consts.SCHEMA_VALIDATION: + if config_type != consts.COLUMN_VALIDATION: if args.filters: filter_config = cli_tools.get_filters(args.filters) if args.threshold: @@ -386,8 +428,8 @@ def run_validation_configs(args): def validate(args): - """Run commands related to data validation.""" - if args.validate_cmd == "column" or args.validate_cmd == "schema": + """ Run commands related to data validation.""" + if args.validate_cmd in ["column", "row", "schema"]: run(args) else: raise ValueError(f"Validation Argument '{args.validate_cmd}' is not supported") diff --git a/data_validation/cli_tools.py b/data_validation/cli_tools.py index c8bf4e42e..b50ba3e9a 100644 --- a/data_validation/cli_tools.py +++ b/data_validation/cli_tools.py @@ -61,7 +61,6 @@ ["port", "Teradata port to connect on"], ["user_name", "User used to connect"], ["password", "Password for supplied user"], - ["logmech", "Log on mechanism"], ], "Oracle": [ ["host", "Desired Oracle host"], @@ -269,41 +268,6 @@ def _configure_run_parser(subparsers): "-tbls", help="Comma separated tables list in the form 'schema.table=target_schema.target_table'", ) - run_parser.add_argument( - "--count", - "-count", - help="Comma separated list of columns for count 'col_a,col_b' or * for all columns", - ) - run_parser.add_argument( - "--sum", - "-sum", - help="Comma separated list of columns for sum 'col_a,col_b' or * for all columns", - ) - run_parser.add_argument( - "--avg", - "-avg", - help="Comma separated list of columns for avg 'col_a,col_b' or * for all columns", - ) - run_parser.add_argument( - "--min", - "-min", - help="Comma separated list of columns for min 'col_a,col_b' or * for all columns", - ) - run_parser.add_argument( - "--max", - "-max", - help="Comma separated list of columns for max 'col_a,col_b' or * for all columns", - ) - run_parser.add_argument( - "--grouped-columns", - "-gc", - help="Comma separated list of columns to use in GroupBy 'col_a,col_b'", - ) - run_parser.add_argument( - "--primary-keys", - "-pk", - help="Comma separated list of primary key columns 'col_a,col_b'", - ) run_parser.add_argument( "--result-handler-config", "-rc", help="Result handler config details" ) @@ -318,6 +282,11 @@ def _configure_run_parser(subparsers): run_parser.add_argument( "--labels", "-l", help="Key value pair labels for validation run", ) + run_parser.add_argument( + "--hash", + "-hash", + help="Comma separated list of columns for hash 'col_a,col_b' or * for all columns", + ) run_parser.add_argument( "--service-account", "-sa", @@ -408,12 +377,70 @@ def _configure_validate_parser(subparsers): ) _configure_column_parser(column_parser) + row_parser = validate_subparsers.add_parser("row", help="Run a row validation") + _configure_row_parser(row_parser) + schema_parser = validate_subparsers.add_parser( "schema", help="Run a schema validation" ) _configure_schema_parser(schema_parser) +def _configure_row_parser(row_parser): + """Configure arguments to run row level validations.""" + _add_common_arguments(row_parser) + row_parser.add_argument( + 
"--hash", + "-hash", + help="Comma separated list of columns for hash 'col_a,col_b' or * for all columns", + ) + row_parser.add_argument( + "--comparison-fields", + "-comp-fields", + help="Individual columns to compare. If comparing a calculated field use the column alias.", + ) + row_parser.add_argument( + "--calculated-fields", + "-calc-fields", + help="list of calculated fields to generate.", + ) + row_parser.add_argument( + "--primary-keys", + "-pk", + help="Comma separated list of primary key columns 'col_a,col_b'", + ) + row_parser.add_argument( + "--labels", "-l", help="Key value pair labels for validation run" + ) + row_parser.add_argument( + "--threshold", + "-th", + type=threshold_float, + help="Float max threshold for percent difference", + ) + row_parser.add_argument( + "--grouped-columns", + "-gc", + help="Comma separated list of columns to use in GroupBy 'col_a,col_b'", + ) + row_parser.add_argument( + "--filters", + "-filters", + help="Filters in the format source_filter:target_filter", + ) + row_parser.add_argument( + "--use-random-row", + "-rr", + action="store_true", + help="Finds a set of random rows of the first primary key supplied.", + ) + row_parser.add_argument( + "--random-row-batch-size", + "-rbs", + help="Row batch size used for random row filters (default 10,000).", + ) + + def _configure_column_parser(column_parser): """Configure arguments to run column level validations.""" _add_common_arguments(column_parser) @@ -442,6 +469,26 @@ def _configure_column_parser(column_parser): "-max", help="Comma separated list of columns for max 'col_a,col_b' or * for all columns", ) + column_parser.add_argument( + "--hash", + "-hash", + help="Comma separated list of columns for hashing a concatenate 'col_a,col_b' or * for all columns", + ) + column_parser.add_argument( + "--bit_xor", + "-bit_xor", + help="Comma separated list of columns for hashing a concatenate 'col_a,col_b' or * for all columns", + ) + column_parser.add_argument( + "--comparison-fields", + "-comp-fields", + help="list of fields to perform exact comparisons to. Use column aliases if this is calculated.", + ) + column_parser.add_argument( + "--calculated-fields", + "-calc-fields", + help="list of calculated fields to generate.", + ) column_parser.add_argument( "--grouped-columns", "-gc", @@ -713,6 +760,9 @@ def get_arg_list(arg_value, default_value=None): return default_value try: + if isinstance(arg_value, list): + arg_value = str(arg_value) + # arg_value = "hash_all" arg_list = json.loads(arg_value) except json.decoder.JSONDecodeError: arg_list = arg_value.split(",") diff --git a/data_validation/combiner.py b/data_validation/combiner.py index 3aeac5a94..bd446621d 100644 --- a/data_validation/combiner.py +++ b/data_validation/combiner.py @@ -52,7 +52,7 @@ def generate_report( join_on_fields (Sequence[str]): A collection of column names to use to join source and target. These are the columns that both the source and target queries - grouped by. + are grouped by. is_value_comparison (boolean): Boolean representing if source and target agg values should be compared with 'equals to' rather than a 'difference' comparison. 
@@ -72,7 +72,6 @@ def generate_report( "Expected source and target to have same schema, got " f"source: {source_names} target: {target_names}" ) - differences_pivot = _calculate_differences( source, target, join_on_fields, run_metadata.validations, is_value_comparison ) @@ -99,11 +98,11 @@ def _calculate_difference(field_differences, datatype, validation, is_value_comp pct_threshold = ibis.literal(validation.threshold) if isinstance(datatype, ibis.expr.datatypes.Timestamp): - source_value = field_differences["differences_source_agg_value"].epoch_seconds() - target_value = field_differences["differences_target_agg_value"].epoch_seconds() + source_value = field_differences["differences_source_value"].epoch_seconds() + target_value = field_differences["differences_target_value"].epoch_seconds() else: - source_value = field_differences["differences_source_agg_value"] - target_value = field_differences["differences_target_agg_value"] + source_value = field_differences["differences_source_value"] + target_value = field_differences["differences_target_value"] # Does not calculate difference between agg values for row hash due to int64 overflow if is_value_comparison: @@ -154,8 +153,6 @@ def _calculate_differences( difference calculation would fail if done after that step. """ schema = source.schema() - all_fields = frozenset(schema.names) - validation_fields = all_fields - frozenset(join_on_fields) if join_on_fields: # Use an inner join because a row must be present in source and target @@ -167,30 +164,27 @@ def _calculate_differences( differences_joined = source.cross_join(target) differences_pivots = [] - for field, field_type in schema.items(): - if field not in validation_fields: + if field not in validations: continue - - validation = validations[field] - - field_differences = differences_joined.projection( - [ - source[field].name("differences_source_agg_value"), - target[field].name("differences_target_agg_value"), - ] - + [source[join_field] for join_field in join_on_fields] - ) - differences_pivots.append( - field_differences[ - (ibis.literal(field).name("validation_name"),) - + join_on_fields - + _calculate_difference( - field_differences, field_type, validation, is_value_comparison - ) - ] - ) - + else: + validation = validations[field] + field_differences = differences_joined.projection( + [ + source[field].name("differences_source_value"), + target[field].name("differences_target_value"), + ] + + [source[join_field] for join_field in join_on_fields] + ) + differences_pivots.append( + field_differences[ + (ibis.literal(field).name("validation_name"),) + + join_on_fields + + _calculate_difference( + field_differences, field_type, validation, is_value_comparison + ) + ] + ) differences_pivot = functools.reduce( lambda pivot1, pivot2: pivot1.union(pivot2), differences_pivots ) @@ -203,26 +197,33 @@ def _pivot_result(result, join_on_fields, validations, result_type): pivots = [] for field in validation_fields: - validation = validations[field] - pivots.append( - result.projection( - ( - ibis.literal(field).name("validation_name"), - ibis.literal(validation.validation_type).name("validation_type"), - ibis.literal(validation.aggregation_type).name("aggregation_type"), - ibis.literal(validation.get_table_name(result_type)).name( - "table_name" - ), - # Cast to string to ensure types match, even when column - # name is NULL (such as for count aggregations). 
- ibis.literal(validation.get_column_name(result_type)) - .cast("string") - .name("column_name"), - result[field].cast("string").name("agg_value"), + if field not in validations: + continue + else: + validation = validations[field] + pivots.append( + result.projection( + ( + ibis.literal(field).name("validation_name"), + ibis.literal(validation.validation_type).name( + "validation_type" + ), + ibis.literal(validation.aggregation_type).name( + "aggregation_type" + ), + ibis.literal(validation.get_table_name(result_type)).name( + "table_name" + ), + # Cast to string to ensure types match, even when column + # name is NULL (such as for count aggregations). + ibis.literal(validation.get_column_name(result_type)) + .cast("string") + .name("column_name"), + result[field].cast("string").name("agg_value"), + ) + + join_on_fields ) - + join_on_fields ) - ) pivot = functools.reduce(lambda pivot1, pivot2: pivot1.union(pivot2), pivots) return pivot diff --git a/data_validation/config_manager.py b/data_validation/config_manager.py index d63ecbfe9..9969bd6e5 100644 --- a/data_validation/config_manager.py +++ b/data_validation/config_manager.py @@ -53,8 +53,6 @@ def __init__(self, config, source_client=None, target_client=None, verbose=False self.target_client = target_client or clients.get_data_client( self.get_target_connection() ) - if not self.process_in_memory(): - self.target_client = self.source_client self.verbose = verbose if self.validation_type not in consts.CONFIG_TYPES: @@ -104,12 +102,7 @@ def random_row_batch_size(self): ) def process_in_memory(self): - if ( - self.validation_type == "Row" - and self.get_source_connection() == self.get_target_connection() - ): - return False - + """Return whether to process in memory or on a remote platform.""" return True @property @@ -128,7 +121,7 @@ def append_aggregates(self, aggregate_configs): @property def calculated_fields(self): - return self._config.get(consts.CONFIG_CALCULATED_FIELDS) + return self._config.get(consts.CONFIG_CALCULATED_FIELDS, []) def append_calculated_fields(self, calculated_configs): self._config[consts.CONFIG_CALCULATED_FIELDS] = ( @@ -157,6 +150,17 @@ def append_primary_keys(self, primary_key_configs): self.primary_keys + primary_key_configs ) + @property + def comparison_fields(self): + """ Return fields from Config """ + return self._config.get(consts.CONFIG_COMPARISON_FIELDS, []) + + def append_comparison_fields(self, field_configs): + """Append field configs to existing config.""" + self._config[consts.CONFIG_COMPARISON_FIELDS] = ( + self.comparison_fields + field_configs + ) + @property def filters(self): """Return Filters from Config """ @@ -232,12 +236,13 @@ def get_source_ibis_table(self): ) return self._source_ibis_table - def get_source_ibis_calculated_table(self): - """Return mutated IbisTable from source""" + def get_source_ibis_calculated_table(self, depth=None): + """Return mutated IbisTable from source + n: Int the depth of subquery requested""" table = self.get_source_ibis_table() vb = ValidationBuilder(self) calculated_table = table.mutate( - vb.source_builder.compile_calculated_fields(table) + vb.source_builder.compile_calculated_fields(table, n=depth) ) return calculated_table @@ -250,12 +255,13 @@ def get_target_ibis_table(self): ) return self._target_ibis_table - def get_target_ibis_calculated_table(self): - """Return mutated IbisTable from target""" + def get_target_ibis_calculated_table(self, depth=None): + """Return mutated IbisTable from target + n: Int the depth of subquery requested""" table = 
self.get_target_ibis_table() vb = ValidationBuilder(self) calculated_table = table.mutate( - vb.target_builder.compile_calculated_fields(table) + vb.target_builder.compile_calculated_fields(table, n=depth) ) return calculated_table @@ -347,6 +353,19 @@ def build_config_manager( verbose=verbose, ) + def build_config_comparison_fields(self, fields, depth=None): + """Return list of field config objects.""" + field_configs = [] + for field in fields: + column_config = { + consts.CONFIG_SOURCE_COLUMN: field.casefold(), + consts.CONFIG_TARGET_COLUMN: field.casefold(), + consts.CONFIG_FIELD_ALIAS: field, + consts.CONFIG_CAST: None, + } + field_configs.append(column_config) + return field_configs + def build_config_grouped_columns(self, grouped_columns): """Return list of grouped column config objects.""" grouped_column_configs = [] @@ -423,3 +442,86 @@ def build_config_column_aggregates(self, agg_type, arg_value, supported_types): aggregate_configs.append(aggregate_config) return aggregate_configs + + def build_config_calculated_fields( + self, reference, calc_type, alias, depth, supported_types, arg_value=None + ): + """Returns list of calculated fields""" + source_table = self.get_source_ibis_calculated_table(depth=depth) + target_table = self.get_target_ibis_calculated_table(depth=depth) + + casefold_source_columns = {x.casefold(): str(x) for x in source_table.columns} + casefold_target_columns = {x.casefold(): str(x) for x in target_table.columns} + + allowlist_columns = arg_value or casefold_source_columns + for column in casefold_source_columns: + column_type_str = str(source_table[casefold_source_columns[column]].type()) + column_type = column_type_str.split("(")[0] + if column not in allowlist_columns: + continue + elif column not in casefold_target_columns: + logging.info( + f"Skipping Calc {calc_type}: {source_table.op().name}.{column} {column_type}" + ) + continue + elif supported_types and column_type not in supported_types: + if self.verbose: + msg = f"Skipping Calc {calc_type}: {source_table.op().name}.{column} {column_type}" + print(msg) + continue + + calculated_config = { + consts.CONFIG_CALCULATED_SOURCE_COLUMNS: reference, + consts.CONFIG_CALCULATED_TARGET_COLUMNS: reference, + consts.CONFIG_FIELD_ALIAS: alias, + consts.CONFIG_TYPE: calc_type, + consts.CONFIG_DEPTH: depth, + } + return calculated_config + + def _build_dependent_aliases(self, calc_type): + """This is a utility function for determining the required depth of all fields""" + order_of_operations = [] + source_table = self.get_source_ibis_calculated_table() + casefold_source_columns = {x.casefold(): str(x) for x in source_table.columns} + if calc_type == "hash": + order_of_operations = [ + "cast", + "ifnull", + "rstrip", + "upper", + "concat", + "hash", + ] + column_aliases = {} + col_names = [] + for i, calc in enumerate(order_of_operations): + if i == 0: + previous_level = [x for x in casefold_source_columns.values()] + else: + previous_level = [k for k, v in column_aliases.items() if v == i - 1] + if calc in ["concat", "hash"]: + col = {} + col["reference"] = previous_level + col["name"] = f"{calc}__all" + col["calc_type"] = calc + col["depth"] = i + name = col["name"] + # need to capture all aliases at the previous level. 
probably name concat__all + column_aliases[name] = i + col_names.append(col) + else: + for ( + column + ) in ( + previous_level + ): # this needs to be the previous manifest of columns + col = {} + col["reference"] = [column] + col["name"] = f"{calc}__" + column + col["calc_type"] = calc + col["depth"] = i + name = col["name"] + column_aliases[name] = i + col_names.append(col) + return col_names diff --git a/data_validation/consts.py b/data_validation/consts.py index 6766992e8..8e0d7c3db 100644 --- a/data_validation/consts.py +++ b/data_validation/consts.py @@ -26,6 +26,7 @@ CONFIG_TARGET_SCHEMA_NAME = "target_schema_name" CONFIG_TARGET_TABLE_NAME = "target_table_name" CONFIG_LABELS = "labels" +CONFIG_COMPARISON_FIELDS = "comparison_fields" CONFIG_FIELD_ALIAS = "field_alias" CONFIG_AGGREGATES = "aggregates" CONFIG_CALCULATED_FIELDS = "calculated_fields" @@ -39,6 +40,7 @@ CONFIG_TARGET_COLUMN = "target_column" CONFIG_THRESHOLD = "threshold" CONFIG_CAST = "cast" +CONFIG_DEPTH = "depth" CONFIG_FORMAT = "format" CONFIG_LIMIT = "limit" CONFIG_FILTERS = "filters" diff --git a/data_validation/data_validation.py b/data_validation/data_validation.py index 2f4edbc32..95f650667 100644 --- a/data_validation/data_validation.py +++ b/data_validation/data_validation.py @@ -176,7 +176,6 @@ def execute_recursive_validation(self, validation_builder, grouped_fields): """ process_in_memory = self.config_manager.process_in_memory() past_results = [] - if len(grouped_fields) > 0: validation_builder.add_query_group(grouped_fields[0]) result_df = self._execute_validation( @@ -214,13 +213,17 @@ def execute_recursive_validation(self, validation_builder, grouped_fields): recursive_validation_builder, grouped_fields[1:] ) ) - elif self.config_manager.primary_keys: - validation_builder.add_config_query_groups(self.config_manager.primary_keys) + elif self.config_manager.primary_keys and len(grouped_fields) == 0: past_results.append( self._execute_validation( validation_builder, process_in_memory=process_in_memory ) ) + + # elif self.config_manager.primary_keys: + # validation_builder.add_config_query_groups(self.config_manager.primary_keys) + # validation_builder.add_config_query_groups(grouped_fields) + else: warnings.warn( "WARNING: No Primary Keys Suppplied in Row Validation", UserWarning @@ -283,10 +286,16 @@ def _execute_validation(self, validation_builder, process_in_memory=True): source_query = validation_builder.get_source_query() target_query = validation_builder.get_target_query() - join_on_fields = validation_builder.get_group_aliases() + join_on_fields = ( + set(validation_builder.get_primary_keys()) + if self.config_manager.validation_type == consts.ROW_VALIDATION + else set(validation_builder.get_group_aliases()) + ) # If row validation from YAML, compare source and target agg values - is_value_comparison = self.config_manager.validation_type == "Row" + is_value_comparison = ( + self.config_manager.validation_type == consts.ROW_VALIDATION + ) if process_in_memory: source_df = self.config_manager.source_client.execute(source_query) @@ -294,7 +303,6 @@ def _execute_validation(self, validation_builder, process_in_memory=True): pd_schema = self._get_pandas_schema( source_df, target_df, join_on_fields, verbose=self.verbose ) - pandas_client = ibis.backends.pandas.connect( {combiner.DEFAULT_SOURCE: source_df, combiner.DEFAULT_TARGET: target_df} ) diff --git a/data_validation/query_builder/query_builder.py b/data_validation/query_builder/query_builder.py index f0da255fb..94ef61e95 100644 --- 
a/data_validation/query_builder/query_builder.py +++ b/data_validation/query_builder/query_builder.py @@ -16,7 +16,7 @@ from ibis.expr.types import StringScalar from third_party.ibis.ibis_addon import operations -from data_validation import clients +from data_validation import clients, consts class AggregateField(object): @@ -192,6 +192,28 @@ def compile(self, ibis_table): return self.expr(self.left, self.right) +class ComparisonField(object): + def __init__(self, field_name, alias=None, cast=None): + """A representation of a comparison field used to build a query. + + Args: + field_name (String): A field to act on in the table + alias (String): An alias to use for the group + cast (String): A cast on the column if required + """ + self.field_name = field_name + self.alias = alias + self.cast = cast + + def compile(self, ibis_table): + # Fields are supplied on compile or on build + comparison_field = ibis_table[self.field_name] + alias = self.alias or self.field_name + comparison_field = comparison_field.name(alias) + + return comparison_field + + class GroupedField(object): def __init__(self, field_name, alias=None, cast=None): """A representation of a group by field used to build a query. @@ -272,14 +294,24 @@ def concat(config, fields): @staticmethod def hash(config, fields): if config.get("default_hash_function") is None: + how = "sha256" + return CalculatedField( + ibis.expr.api.StringValue.hashbytes, config, fields, how=how, + ) + else: how = "farm_fingerprint" - return CalculatedField(ibis.expr.api.ValueExpr.hash, config, fields, how=how,) + return CalculatedField( + ibis.expr.api.ValueExpr.hash, config, fields, how=how, + ) @staticmethod def ifnull(config, fields): - if config.get("default_null_string") is None: - config["default_string"] = ibis.literal("DEFAULT_REPLACEMENT_STRING") - fields = [config["default_string"], fields[0]] + config["default_string"] = ( + ibis.literal("DEFAULT_REPLACEMENT_STRING") + if config.get("default_null_string") is None + else config.get("default_null_string") + ) + fields = [fields[0], config["default_string"]] return CalculatedField(ibis.expr.api.ValueExpr.fillna, config, fields,) @staticmethod @@ -294,6 +326,14 @@ def rstrip(config, fields): def upper(config, fields): return CalculatedField(ibis.expr.api.StringValue.upper, config, fields,) + @staticmethod + def cast(config, fields): + if config.get("default_cast") is None: + target_type = "string" + return CalculatedField( + ibis.expr.api.ValueExpr.cast, config, fields, target_type=target_type, + ) + @staticmethod def custom(expr): """ Returns a CalculatedField instance built for any custom SQL using a supported operator. 
@@ -315,7 +355,6 @@ def _compile_fields(self, ibis_table, fields): compiled_fields.append(ibis_table[field].cast(self.cast)) else: compiled_fields.append(ibis_table[field]) - return compiled_fields def compile(self, ibis_table): @@ -329,7 +368,13 @@ def compile(self, ibis_table): class QueryBuilder(object): def __init__( - self, aggregate_fields, calculated_fields, filters, grouped_fields, limit=None + self, + aggregate_fields, + calculated_fields, + filters, + grouped_fields, + comparison_fields, + limit=None, ): """ Build a QueryBuilder object which can be used to build queries easily @@ -344,6 +389,7 @@ def __init__( self.calculated_fields = calculated_fields self.filters = filters self.grouped_fields = grouped_fields + self.comparison_fields = comparison_fields self.limit = limit @staticmethod @@ -352,12 +398,14 @@ def build_count_validator(limit=None): aggregate_fields = [] filters = [] grouped_fields = [] + comparison_fields = [] calculated_fields = [] return QueryBuilder( aggregate_fields, filters=filters, grouped_fields=grouped_fields, + comparison_fields=comparison_fields, calculated_fields=calculated_fields, ) @@ -372,15 +420,23 @@ def compile_filter_fields(self, table): def compile_group_fields(self, table): return [field.compile(table) for field in self.grouped_fields] - def compile_calculated_fields(self, table, n=None): - if n is not None: - return [ - field.compile(table) - for field in self.calculated_fields - if field.config["depth"] == n - ] - else: - return [field.compile(table) for field in self.calculated_fields] + def compile_comparison_fields(self, table): + return [field.compile(table) for field in self.comparison_fields] + + def compile_calculated_fields(self, table, n=0): + return [ + field.compile(table) + for field in self.calculated_fields + if field.config[consts.CONFIG_DEPTH] == n + ] + # if n is not None: + # return [ + # field.compile(table) + # for field in self.calculated_fields + # if field.config[consts.CONFIG_DEPTH] == n + # ] + # else: + # return [field.compile(table) for field in self.calculated_fields] def compile(self, data_client, schema_name, table_name): """Return an Ibis query object @@ -396,7 +452,8 @@ def compile(self, data_client, schema_name, table_name): calc_table = table if self.calculated_fields: depth_limit = max( - field.config.get("depth", 0) for field in self.calculated_fields + field.config.get(consts.CONFIG_DEPTH, 0) + for field in self.calculated_fields ) for n in range(0, (depth_limit + 1)): calc_table = calc_table.mutate( @@ -406,15 +463,18 @@ def compile(self, data_client, schema_name, table_name): filtered_table = ( calc_table.filter(compiled_filters) if compiled_filters else calc_table ) - compiled_groups = self.compile_group_fields(filtered_table) grouped_table = ( filtered_table.groupby(compiled_groups) if compiled_groups else filtered_table ) - - query = grouped_table.aggregate(self.compile_aggregate_fields(filtered_table)) + if self.aggregate_fields: + query = grouped_table.aggregate( + self.compile_aggregate_fields(filtered_table) + ) + else: + query = grouped_table if self.limit: query = query.limit(self.limit) @@ -430,6 +490,15 @@ def add_aggregate_field(self, aggregate_field): """ self.aggregate_fields.append(aggregate_field) + def add_comparison_field(self, comparison_field): + """Add an ComparisonField instance to the query which + will be used when compiling your query (ie. 
SUM(a)) + + Args: + comparison_field (ComparisonField): An ComparisonField instance + """ + self.comparison_fields.append(comparison_field) + def add_grouped_field(self, grouped_field): """Add a GroupedField instance to the query which represents adding a column to group by in the diff --git a/data_validation/result_handlers/text.py b/data_validation/result_handlers/text.py index 73735ffb1..9e3a2bf5e 100644 --- a/data_validation/result_handlers/text.py +++ b/data_validation/result_handlers/text.py @@ -23,6 +23,7 @@ """ from data_validation import consts +import pandas as pd class TextResultHandler(object): @@ -35,6 +36,16 @@ def print_formatted_(self, result_df): Utility for printing formatted results :param result_df """ + # the text transformer chokes on bytestring results (e.g. SHA256) this + # dataframe slice is to remove the source and target values + mask = result_df["validation_type"] == consts.ROW_VALIDATION + row_result_df = result_df[mask] + other_result_df = result_df[~mask] + row_result_df["source_agg_value"] = None + row_result_df["target_agg_value"] = None + frames = [row_result_df, other_result_df] + result_df = pd.concat(frames) + if self.format == "text": print(result_df.to_string(index=False)) elif self.format == "csv": diff --git a/data_validation/validation_builder.py b/data_validation/validation_builder.py index 7477525e1..2af152e40 100644 --- a/data_validation/validation_builder.py +++ b/data_validation/validation_builder.py @@ -45,13 +45,17 @@ def __init__(self, config_manager): self.source_builder = self.get_query_builder(self.validation_type) self.target_builder = self.get_query_builder(self.validation_type) + self.primary_keys = {} self.group_aliases = {} self.calculated_aliases = {} + self.comparison_fields = {} self.add_config_aggregates() self.add_config_query_groups() self.add_config_calculated_fields() + self.add_comparison_fields() self.add_config_filters() + self.add_primary_keys() self.add_query_limit() def clone(self): @@ -61,6 +65,7 @@ def clone(self): cloned_builder.target_builder = deepcopy(self.target_builder) cloned_builder.group_aliases = deepcopy(self.group_aliases) cloned_builder.calculated_aliases = deepcopy(self.calculated_aliases) + cloned_builder.comparison_fields = deepcopy(self.comparison_fields) cloned_builder._metadata = deepcopy(self._metadata) return cloned_builder @@ -90,12 +95,21 @@ def get_metadata(self): def get_group_aliases(self): """ Return List of String Aliases """ - return self.group_aliases.keys() + return list(self.group_aliases.keys()) + + def get_primary_keys(self): + """ Return List of String Aliases """ + # do we need this? 
+ return list(self.primary_keys.keys()) def get_calculated_aliases(self): """ Return List of String Aliases """ return self.calculated_aliases.keys() + def get_comparison_fields(self): + """ Return List of String Aliases """ + return self.comparison_fields.keys() + def get_grouped_alias_source_column(self, alias): return self.group_aliases[alias][consts.CONFIG_SOURCE_COLUMN] @@ -115,6 +129,16 @@ def add_config_calculated_fields(self): for calc_field in calc_fields: self.add_calc(calc_field) + def add_primary_keys(self, primary_keys=None): + primary_keys = primary_keys or self.config_manager.primary_keys + for field in primary_keys: + self.add_primary_key(field) + + def add_comparison_fields(self, comparison_fields=None): + comparison_fields = comparison_fields or self.config_manager.comparison_fields + for field in comparison_fields: + self.add_comparison_field(field) + def add_config_query_groups(self, query_groups=None): """ Add Grouped Columns to Query """ grouped_fields = query_groups or self.config_manager.query_groups @@ -192,6 +216,21 @@ def add_query_group(self, grouped_field): self.target_builder.add_grouped_field(target_field) self.group_aliases[alias] = grouped_field + def add_primary_key(self, primary_key): + """ Add ComparionField to Queries + + Args: + primary_key (Dict): An object with source, target, and cast info + """ + source_field_name = primary_key[consts.CONFIG_SOURCE_COLUMN] + target_field_name = primary_key[consts.CONFIG_TARGET_COLUMN] + # grab calc field metadata + alias = primary_key[consts.CONFIG_FIELD_ALIAS] + # check if valid calc field and return correct object + self.source_builder.add_comparison_field(source_field_name) + self.target_builder.add_comparison_field(target_field_name) + self.primary_keys[alias] = primary_key + def add_filter(self, filter_field): """Add FilterField to Queries @@ -228,6 +267,31 @@ def add_filter(self, filter_field): self.source_builder.add_filter_field(source_filter) self.target_builder.add_filter_field(target_filter) + def add_comparison_field(self, comparison_field): + """ Add ComparionField to Queries + + Args: + comparison_field (Dict): An object with source, target, and cast info + """ + source_field_name = comparison_field[consts.CONFIG_SOURCE_COLUMN] + target_field_name = comparison_field[consts.CONFIG_TARGET_COLUMN] + # grab calc field metadata + alias = comparison_field[consts.CONFIG_FIELD_ALIAS] + # check if valid calc field and return correct object + self.source_builder.add_comparison_field(source_field_name) + self.target_builder.add_comparison_field(target_field_name) + self._metadata[alias] = metadata.ValidationMetadata( + aggregation_type=None, + validation_type=self.validation_type, + source_table_schema=self.config_manager.source_schema, + source_table_name=self.config_manager.source_table, + target_table_schema=self.config_manager.target_schema, + target_table_name=self.config_manager.target_table, + source_column_name=source_field_name, + target_column_name=target_field_name, + threshold=self.config_manager.threshold, + ) + def add_calc(self, calc_field): """ Add CalculatedField to Queries @@ -265,6 +329,7 @@ def get_source_query(self): } query = self.source_builder.compile(**source_config) if self.verbose: + print(source_config) print("-- ** Source Query ** --") print(query.compile()) @@ -279,6 +344,7 @@ def get_target_query(self): } query = self.target_builder.compile(**target_config) if self.verbose: + print(target_config) print("-- ** Target Query ** --") print(query.compile()) diff --git a/docs/examples.md 
b/docs/examples.md index b3019caff..ca634e278 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -41,6 +41,11 @@ data-validation validate column -sc my_bq_conn -tc my_bq_conn -tbls bigquery-pub data-validation validate column -sc my_bq_conn -tc my_bq_conn -tbls bigquery-public-data.new_york_citibike.citibike_trips --count bikeid,gender ```` +#### Run a checksum validation for all rows +````shell script +data-validation validate row -sc my_bq_conn -tc my_bq_conn -tbls bigquery-public-data.new_york_citibike.citibike_trips --primary-keys station_id --hash '*' +```` + #### Store results in a BigQuery table ````shell script data-validation validate column -sc my_bq_conn -tc my_bq_conn -tbls bigquery-public-data.new_york_citibike.citibike_trips --count tripduration,start_station_name -bqrh $YOUR_PROJECT_ID.pso_data_validator.results @@ -178,4 +183,4 @@ validations: target_table_name: citibike_stations threshold: 0.0 type: Column - ``` \ No newline at end of file + ``` diff --git a/tests/system/data_sources/test_bigquery.py b/tests/system/data_sources/test_bigquery.py index ebba9593e..68a37bc7a 100644 --- a/tests/system/data_sources/test_bigquery.py +++ b/tests/system/data_sources/test_bigquery.py @@ -146,9 +146,8 @@ os.environ["PROJECT_ID"], ] CLI_STORE_COLUMN_ARGS = [ - "run", - "--type", - "Column", + "validate", + "column", "--source-conn", BQ_CONN_NAME, "--target-conn", @@ -166,7 +165,7 @@ "--config-file", CLI_CONFIG_FILE, ] -EXPECTED_NUM_YAML_LINES = 35 # Expected number of lines for validation config geenrated by CLI_STORE_COLUMN_ARGS +EXPECTED_NUM_YAML_LINES = 33 # Expected number of lines for validation config geenrated by CLI_STORE_COLUMN_ARGS CLI_RUN_CONFIG_ARGS = ["run-config", "--config-file", CLI_CONFIG_FILE] CLI_CONFIGS_RUN_ARGS = ["configs", "run", "--config-file", CLI_CONFIG_FILE] diff --git a/tests/unit/result_handlers/test_text.py b/tests/unit/result_handlers/test_text.py index 4f0fddf33..bec0305c1 100644 --- a/tests/unit/result_handlers/test_text.py +++ b/tests/unit/result_handlers/test_text.py @@ -17,9 +17,27 @@ from pandas import DataFrame SAMPLE_CONFIG = {} -SAMPLE_RESULT_DATA = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] -SAMPLE_RESULT_COLUMNS = ["A", "B", "C", "D"] -SAMPLE_RESULT_COLUMNS_FILTER_LIST = ["B", "D"] +SAMPLE_RESULT_DATA = [ + [0, 1, 2, 3, "Column", "source", "target"], + [4, 5, 6, 7, "Column", "source", "target"], + [8, 9, 10, 11, "Column", "source", "target"], +] +SAMPLE_RESULT_COLUMNS = [ + "A", + "B", + "C", + "D", + "validation_type", + "source_agg_value", + "target_agg_value", +] +SAMPLE_RESULT_COLUMNS_FILTER_LIST = [ + "B", + "D", + "validation_type", + "source_agg_value", + "target_agg_value", +] @pytest.fixture @@ -70,6 +88,7 @@ def test_columns_to_print(module_under_test, capsys): grid_text = "││A│C││0│0│2││1│4│6││2│8│10│" printed_text = capsys.readouterr().out + print(printed_text) printed_text = ( printed_text.replace("\n", "") .replace("'", "") diff --git a/tests/unit/test_config_manager.py b/tests/unit/test_config_manager.py index 0b985d9ab..86c2f6460 100644 --- a/tests/unit/test_config_manager.py +++ b/tests/unit/test_config_manager.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import pytest from data_validation import consts @@ -159,7 +158,9 @@ def test_get_threshold_property(module_under_test): def test_process_in_memory(module_under_test): - """Test process in memory for normal validations.""" + """Test process in memory for normal validations. 
+ TODO: emceehilton Re-enable opposite test once option is available + """ config_manager = module_under_test.ConfigManager( SAMPLE_CONFIG, MockIbisClient(), MockIbisClient(), verbose=False ) @@ -167,21 +168,21 @@ def test_process_in_memory(module_under_test): assert config_manager.process_in_memory() is True -def test_do_not_process_in_memory(module_under_test): - """Test process in memory for normal validations.""" - config_manager = module_under_test.ConfigManager( - copy.deepcopy(SAMPLE_CONFIG), MockIbisClient(), MockIbisClient(), verbose=False - ) - config_manager._config[consts.CONFIG_TYPE] = consts.ROW_VALIDATION - config_manager._config[consts.CONFIG_PRIMARY_KEYS] = [ - { - consts.CONFIG_FIELD_ALIAS: "id", - consts.CONFIG_SOURCE_COLUMN: "id", - consts.CONFIG_TARGET_COLUMN: "id", - consts.CONFIG_CAST: None, - }, - ] - assert config_manager.process_in_memory() is False +# def test_do_not_process_in_memory(module_under_test): +# """Test process in memory for normal validations.""" +# config_manager = module_under_test.ConfigManager( +# copy.deepcopy(SAMPLE_CONFIG), MockIbisClient(), MockIbisClient(), verbose=False +# ) +# config_manager._config[consts.CONFIG_TYPE] = consts.ROW_VALIDATION +# config_manager._config[consts.CONFIG_PRIMARY_KEYS] = [ +# { +# consts.CONFIG_FIELD_ALIAS: "id", +# consts.CONFIG_SOURCE_COLUMN: "id", +# consts.CONFIG_TARGET_COLUMN: "id", +# consts.CONFIG_CAST: None, +# }, +# ] +# assert config_manager.process_in_memory() is True def test_get_table_info(module_under_test): diff --git a/tests/unit/test_data_validation.py b/tests/unit/test_data_validation.py index 04f616299..cca6fbf69 100644 --- a/tests/unit/test_data_validation.py +++ b/tests/unit/test_data_validation.py @@ -102,12 +102,12 @@ } # Grouped Column Row confg -SAMPLE_GC_ROW_CONFIG = { +SAMPLE_GC_CONFIG = { # BigQuery Specific Connection Config "source_conn": SOURCE_CONN_CONFIG, "target_conn": TARGET_CONN_CONFIG, # Validation Type - consts.CONFIG_TYPE: consts.ROW_VALIDATION, + consts.CONFIG_TYPE: consts.COLUMN_VALIDATION, consts.CONFIG_MAX_RECURSIVE_QUERY_SIZE: 50, # Configuration Required Depending on Validator Type "schema_name": None, @@ -142,12 +142,59 @@ consts.CONFIG_FORMAT: "table", } -SAMPLE_GC_ROW_CALC_CONFIG = { +# Grouped Column Row confg +SAMPLE_MULTI_GC_CONFIG = { # BigQuery Specific Connection Config "source_conn": SOURCE_CONN_CONFIG, "target_conn": TARGET_CONN_CONFIG, # Validation Type - consts.CONFIG_TYPE: "Column", + consts.CONFIG_TYPE: consts.COLUMN_VALIDATION, + consts.CONFIG_MAX_RECURSIVE_QUERY_SIZE: 50, + # Configuration Required Depending on Validator Type + "schema_name": None, + "table_name": "my_table", + "target_schema_name": None, + "target_table_name": "my_table", + consts.CONFIG_GROUPED_COLUMNS: [ + { + consts.CONFIG_FIELD_ALIAS: "date_value", + consts.CONFIG_SOURCE_COLUMN: "date_value", + consts.CONFIG_TARGET_COLUMN: "date_value", + consts.CONFIG_CAST: "date", + }, + { + consts.CONFIG_FIELD_ALIAS: "id", + consts.CONFIG_SOURCE_COLUMN: "id", + consts.CONFIG_TARGET_COLUMN: "id", + consts.CONFIG_CAST: None, + }, + ], + consts.CONFIG_PRIMARY_KEYS: [ + { + consts.CONFIG_FIELD_ALIAS: "id", + consts.CONFIG_SOURCE_COLUMN: "id", + consts.CONFIG_TARGET_COLUMN: "id", + consts.CONFIG_CAST: None, + } + ], + consts.CONFIG_AGGREGATES: [ + { + "source_column": "text_value", + "target_column": "text_value", + "field_alias": "count_text_value", + "type": "count", + }, + ], + consts.CONFIG_RESULT_HANDLER: None, + consts.CONFIG_FORMAT: "table", +} + +SAMPLE_GC_CALC_CONFIG = { + # BigQuery 
Specific Connection Config + "source_conn": SOURCE_CONN_CONFIG, + "target_conn": TARGET_CONN_CONFIG, + # Validation Type + consts.CONFIG_TYPE: consts.COLUMN_VALIDATION, consts.CONFIG_MAX_RECURSIVE_QUERY_SIZE: 50, # Configuration Required Depending on Validator Type "schema_name": None, @@ -207,12 +254,12 @@ }, ], consts.CONFIG_AGGREGATES: [ - # { - # "source_column": "text_value", - # "target_column": "text_value", - # "field_alias": "count_text_value", - # "type": "count", - # }, + { + "source_column": "text_value", + "target_column": "text_value", + "field_alias": "count_text_value", + "type": "count", + }, { "source_column": "length_text_constant", "target_column": "length_text_constant", @@ -237,10 +284,83 @@ consts.CONFIG_FORMAT: "table", } +# Row confg +SAMPLE_ROW_CONFIG = { + # BigQuery Specific Connection Config + "source_conn": SOURCE_CONN_CONFIG, + "target_conn": TARGET_CONN_CONFIG, + # Validation Type + consts.CONFIG_TYPE: consts.ROW_VALIDATION, + # Configuration Required Depending on Validator Type + "schema_name": None, + "table_name": "my_table", + "target_schema_name": None, + "target_table_name": "my_table", + consts.CONFIG_PRIMARY_KEYS: [ + { + consts.CONFIG_FIELD_ALIAS: "id", + consts.CONFIG_SOURCE_COLUMN: "id", + consts.CONFIG_TARGET_COLUMN: "id", + consts.CONFIG_CAST: None, + }, + ], + consts.CONFIG_COMPARISON_FIELDS: [ + { + consts.CONFIG_FIELD_ALIAS: "int_value", + consts.CONFIG_SOURCE_COLUMN: "int_value", + consts.CONFIG_TARGET_COLUMN: "int_value", + consts.CONFIG_CAST: None, + }, + { + consts.CONFIG_FIELD_ALIAS: "text_value", + consts.CONFIG_SOURCE_COLUMN: "text_value", + consts.CONFIG_TARGET_COLUMN: "text_value", + consts.CONFIG_CAST: None, + }, + ], + consts.CONFIG_RESULT_HANDLER: None, + consts.CONFIG_FORMAT: "table", +} + +# Row confg +SAMPLE_JSON_ROW_CONFIG = { + # BigQuery Specific Connection Config + "source_conn": SOURCE_CONN_CONFIG, + "target_conn": TARGET_CONN_CONFIG, + # Validation Type + consts.CONFIG_TYPE: consts.ROW_VALIDATION, + # Configuration Required Depending on Validator Type + "schema_name": None, + "table_name": "my_table", + "target_schema_name": None, + "target_table_name": "my_table", + consts.CONFIG_PRIMARY_KEYS: [ + { + consts.CONFIG_FIELD_ALIAS: "pkey", + consts.CONFIG_SOURCE_COLUMN: "pkey", + consts.CONFIG_TARGET_COLUMN: "pkey", + consts.CONFIG_CAST: None, + }, + ], + consts.CONFIG_COMPARISON_FIELDS: [ + { + consts.CONFIG_FIELD_ALIAS: "col_b", + consts.CONFIG_SOURCE_COLUMN: "col_b", + consts.CONFIG_TARGET_COLUMN: "col_b", + consts.CONFIG_CAST: None, + }, + ], + consts.CONFIG_RESULT_HANDLER: None, + consts.CONFIG_FORMAT: "table", +} -JSON_DATA = """[{"col_a":0,"col_b":"a"},{"col_a":1,"col_b":"b"}]""" +JSON_DATA = """[{"col_a":1,"col_b":"a"},{"col_a":1,"col_b":"b"}]""" JSON_COLA_ZERO_DATA = """[{"col_a":null,"col_b":"a"}]""" JSON_BAD_DATA = """[{"col_a":0,"col_b":"a"},{"col_a":1,"col_b":"b"},{"col_a":2,"col_b":"c"},{"col_a":3,"col_b":"d"},{"col_a":4,"col_b":"e"}]""" +JSON_PK_DATA = ( + """[{"pkey":1, "col_a":1,"col_b":"a"},{"pkey":2, "col_a":1,"col_b":"b"}]""" +) +JSON_PK_BAD_DATA = """[{"pkey":1, "col_a":0,"col_b":"b"},{"pkey":2, "col_a":1,"col_b":"c"},{"pkey":3, "col_a":2,"col_b":"d"},{"pkey":4, "col_a":3,"col_b":"e"},{"pkey":5, "col_a":4,"col_b":"f"}]""" STRING_CONSTANT = "constant" @@ -405,7 +525,6 @@ def test_status_fail_validation(module_under_test, fs): client = module_under_test.DataValidation(SAMPLE_CONFIG) result_df = client.execute() - col_a_result_df = result_df[result_df.validation_name == "count_col_a"] col_a_pct_threshold = 
col_a_result_df.pct_threshold.values[0] col_a_status = col_a_result_df.status.values[0] @@ -420,7 +539,6 @@ def test_threshold_equals_diff(module_under_test, fs): client = module_under_test.DataValidation(SAMPLE_THRESHOLD_CONFIG) result_df = client.execute() - col_a_result_df = result_df[result_df.validation_name == "count_col_a"] col_a_pct_diff = col_a_result_df.pct_difference.values[0] col_a_pct_threshold = col_a_result_df.pct_threshold.values[0] @@ -431,14 +549,14 @@ def test_threshold_equals_diff(module_under_test, fs): assert col_a_status == "success" -def test_row_level_validation_perfect_match(module_under_test, fs): +def test_grouped_column_level_validation_perfect_match(module_under_test, fs): data = _generate_fake_data(second_range=0) json_data = _get_fake_json_data(data) _create_table_file(SOURCE_TABLE_FILE_PATH, json_data) _create_table_file(TARGET_TABLE_FILE_PATH, json_data) - client = module_under_test.DataValidation(SAMPLE_GC_ROW_CONFIG) + client = module_under_test.DataValidation(SAMPLE_GC_CONFIG) result_df = client.execute() expected_date_result = '{"date_value": "%s"}' % str(datetime.now().date()) @@ -455,7 +573,7 @@ def test_calc_field_validation_calc_match(module_under_test, fs): _create_table_file(SOURCE_TABLE_FILE_PATH, json_data) _create_table_file(TARGET_TABLE_FILE_PATH, json_data) - client = module_under_test.DataValidation(SAMPLE_GC_ROW_CALC_CONFIG) + client = module_under_test.DataValidation(SAMPLE_GC_CALC_CONFIG) result_df = client.execute() calc_val_df = result_df[result_df["validation_name"] == "sum_length"] calc_val_df2 = result_df[result_df["validation_name"] == "sum_concat_length"] @@ -470,34 +588,26 @@ def test_calc_field_validation_calc_match(module_under_test, fs): assert calc_val_df3["source_agg_value"].sum() == str(num_rows * 2) -def test_row_level_validation_non_matching(module_under_test, fs): +def test_grouped_column_level_validation_non_matching(module_under_test, fs): data = _generate_fake_data(rows=10, second_range=0) trg_data = _generate_fake_data(initial_id=11, rows=1, second_range=0) - source_json_data = _get_fake_json_data(data) target_json_data = _get_fake_json_data(data + trg_data) _create_table_file(SOURCE_TABLE_FILE_PATH, source_json_data) _create_table_file(TARGET_TABLE_FILE_PATH, target_json_data) - - client = module_under_test.DataValidation(SAMPLE_GC_ROW_CONFIG, verbose=True) + client = module_under_test.DataValidation(SAMPLE_GC_CONFIG) result_df = client.execute() validation_df = result_df[result_df["validation_name"] == "count_text_value"] + # TODO: this value is 0 because a COUNT() on no rows returns Null + assert result_df["difference"].sum() == 1 - # TODO: this value is 0 because a COUNT() on no rows returns Null. 
- # When calc fields is released, we could COALESCE(COUNT(), 0) to avoid this - assert result_df["difference"].sum() == 0 - - expected_date_result = '{"date_value": "%s", "id": "11"}' % str( - datetime.now().date() - ) - grouped_column = validation_df[validation_df["source_table_name"].isnull()][ - "group_by_columns" - ].max() + expected_date_result = '{"date_value": "%s"}' % str(datetime.now().date()) + grouped_column = validation_df["group_by_columns"].max() assert expected_date_result == grouped_column -def test_row_level_validation_smart_count(module_under_test, fs): +def test_grouped_column_level_validation_smart_count(module_under_test, fs): data = _generate_fake_data(rows=100, second_range=0) source_json_data = _get_fake_json_data(data) @@ -506,7 +616,7 @@ def test_row_level_validation_smart_count(module_under_test, fs): _create_table_file(SOURCE_TABLE_FILE_PATH, source_json_data) _create_table_file(TARGET_TABLE_FILE_PATH, target_json_data) - client = module_under_test.DataValidation(SAMPLE_GC_ROW_CONFIG) + client = module_under_test.DataValidation(SAMPLE_GC_CONFIG) result_df = client.execute() expected_date_result = '{"date_value": "%s"}' % str(datetime.now().date()) @@ -517,7 +627,7 @@ def test_row_level_validation_smart_count(module_under_test, fs): assert smart_count_df["target_agg_value"].astype(int).sum() == 200 -def test_row_level_validation_multiple_aggregations(module_under_test, fs): +def test_grouped_column_level_validation_multiple_aggregations(module_under_test): data = _generate_fake_data(rows=10, second_range=0) trg_data = _generate_fake_data(initial_id=11, rows=1, second_range=0) @@ -527,11 +637,61 @@ def test_row_level_validation_multiple_aggregations(module_under_test, fs): _create_table_file(SOURCE_TABLE_FILE_PATH, source_json_data) _create_table_file(TARGET_TABLE_FILE_PATH, target_json_data) - client = module_under_test.DataValidation(SAMPLE_GC_ROW_CONFIG, verbose=True) + client = module_under_test.DataValidation(SAMPLE_MULTI_GC_CONFIG) result_df = client.execute() - validation_df = result_df[result_df["validation_name"] == "count_text_value"] - + validation_df = result_df # [result_df["validation_name"] == "count_text_value"] # Expect 11 rows, one for each PK value assert len(validation_df) == 11 assert validation_df["source_agg_value"].astype(float).sum() == 10 assert validation_df["target_agg_value"].astype(float).sum() == 11 + + +def test_row_level_validation(module_under_test, fs): + data = _generate_fake_data(rows=100, second_range=0) + + source_json_data = _get_fake_json_data(data) + target_json_data = _get_fake_json_data(data) + + _create_table_file(SOURCE_TABLE_FILE_PATH, source_json_data) + _create_table_file(TARGET_TABLE_FILE_PATH, target_json_data) + + client = module_under_test.DataValidation(SAMPLE_ROW_CONFIG) + result_df = client.execute() + + str_comparison_df = result_df[result_df["validation_name"] == "text_value"] + int_comparison_df = result_df[result_df["validation_name"] == "int_value"] + + assert len(result_df) == 200 + assert len(str_comparison_df) == 100 + assert len(int_comparison_df) == 100 + + +def test_fail_row_level_validation(module_under_test, fs): + _create_table_file(SOURCE_TABLE_FILE_PATH, JSON_PK_DATA) + _create_table_file(TARGET_TABLE_FILE_PATH, JSON_PK_BAD_DATA) + + client = module_under_test.DataValidation(SAMPLE_JSON_ROW_CONFIG) + result_df = client.execute() + + # based on shared keys + fail_df = result_df[result_df["status"] == "fail"] + assert len(fail_df) == 5 + + +def 
test_bad_join_row_level_validation(module_under_test, fs): + data = _generate_fake_data(rows=100, second_range=0) + target_data = _generate_fake_data(initial_id=100, rows=1, second_range=0) + + source_json_data = _get_fake_json_data(data) + target_json_data = _get_fake_json_data(target_data) + + _create_table_file(SOURCE_TABLE_FILE_PATH, source_json_data) + _create_table_file(TARGET_TABLE_FILE_PATH, target_json_data) + + client = module_under_test.DataValidation(SAMPLE_ROW_CONFIG) + result_df = client.execute() + + comparison_df = result_df[result_df["status"] == "fail"] + # 2 validations * (100 source + 1 target) + assert len(result_df) == 202 + assert len(comparison_df) == 202
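
For reference, the `--hash '*'` flag added in this patch expands into a depth-ordered chain of calculated fields (`cast` → `ifnull` → `rstrip` → `upper` → `concat` → `hash`) built by `_build_dependent_aliases("hash")`, ending in a single `hash__all` comparison field. The snippet below is a minimal Python sketch of the equivalent per-row computation, assuming the default `sha256` hash function and the `DEFAULT_REPLACEMENT_STRING` null replacement seen in `query_builder.py`; the tool itself compiles these steps through ibis expressions, so the exact SQL emitted per backend may differ.

```python
# Illustrative sketch only: approximates the row hash produced by the
# generated calculated-field chain (cast -> ifnull -> rstrip -> upper ->
# concat -> hash) when --hash '*' is used with the default sha256 function.
import hashlib

DEFAULT_REPLACEMENT_STRING = "DEFAULT_REPLACEMENT_STRING"

def row_hash(row: dict) -> str:
    parts = []
    for value in row.values():
        # cast to string, then replace NULLs with the default string (ifnull)
        text = DEFAULT_REPLACEMENT_STRING if value is None else str(value)
        # rstrip and upper-case each column before concatenation
        parts.append(text.rstrip().upper())
    concatenated = "".join(parts)          # concat__all
    return hashlib.sha256(concatenated.encode()).hexdigest()  # hash__all

print(row_hash({"station_id": 42, "name": "Central Park ", "capacity": None}))
```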
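
Row validation results are produced by inner-joining the source and target queries on the primary key(s) and checking each comparison field for an exact match (the `is_value_comparison` path in `combiner.py`). The pandas sketch below illustrates that join-and-compare behaviour with made-up data; it is not the combiner's actual ibis pipeline, and the column and status names are used only to mirror the report fields seen in the tests.

```python
# Minimal pandas sketch of the row comparison: source and target are
# inner-joined on the primary key and each comparison field is checked
# for an exact match, yielding a success/fail status per row.
import pandas as pd

source = pd.DataFrame({"id": [1, 2, 3], "text_value": ["a", "b", "c"]})
target = pd.DataFrame({"id": [1, 2, 4], "text_value": ["a", "x", "d"]})

joined = source.merge(target, on="id", suffixes=("_source", "_target"), how="inner")
joined["status"] = (
    joined["text_value_source"] == joined["text_value_target"]
).map({True: "success", False: "fail"})

# Rows missing from either side drop out of the inner join, which is why the
# bad-join test above sees only failures for the keys the two sides share.
print(joined[["id", "text_value_source", "text_value_target", "status"]])
```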