diff --git a/.gitignore b/.gitignore index 5bc01cac8..a2b202091 100644 --- a/.gitignore +++ b/.gitignore @@ -76,3 +76,7 @@ terraform.rc # Custom *.yaml + +# Test temp files +source_table_data.json +target_table_data.json diff --git a/data_validation/__main__.py b/data_validation/__main__.py index 0a35e7664..aa13cf2ea 100644 --- a/data_validation/__main__.py +++ b/data_validation/__main__.py @@ -96,7 +96,16 @@ def get_calculated_config(args, config_manager): calculated_configs = [] fields = [] if args.hash: - fields = config_manager._build_dependent_aliases("hash") + col_list = None if args.hash == "*" else cli_tools.get_arg_list(args.hash) + fields = config_manager._build_dependent_aliases("hash", col_list) + aliases = [field["name"] for field in fields] + + # Add to list of necessary columns for selective hashing in order to drop + # excess columns with invalid data types (i.e structs) when generating source/target DFs + if col_list: + config_manager.append_dependent_aliases(col_list) + config_manager.append_dependent_aliases(aliases) + if len(fields) > 0: max_depth = max([x["depth"] for x in fields]) else: @@ -142,11 +151,13 @@ def build_config_from_args(args, config_manager): config_manager.append_comparison_fields( config_manager.build_config_comparison_fields(comparison_fields) ) + config_manager.append_dependent_aliases(comparison_fields) if args.primary_keys is not None: primary_keys = cli_tools.get_arg_list(args.primary_keys) config_manager.append_primary_keys( config_manager.build_config_comparison_fields(primary_keys) ) + config_manager.append_dependent_aliases(primary_keys) # TODO(GH#18): Add query filter config logic diff --git a/data_validation/config_manager.py b/data_validation/config_manager.py index 939217df7..dccebe0b6 100644 --- a/data_validation/config_manager.py +++ b/data_validation/config_manager.py @@ -128,6 +128,17 @@ def append_calculated_fields(self, calculated_configs): self.calculated_fields + calculated_configs ) + @property + def dependent_aliases(self): + """ Return all columns that are needed in final dataframe for row validations. """ + return self._config.get(consts.CONFIG_DEPENDENT_ALIASES, []) + + def append_dependent_aliases(self, dependent_aliases): + """ Appends columns that are needed in final dataframe for row validations. """ + self._config[consts.CONFIG_DEPENDENT_ALIASES] = ( + self.dependent_aliases + dependent_aliases + ) + @property def query_groups(self): """ Return Query Groups from Config """ @@ -507,11 +518,16 @@ def build_config_calculated_fields( } return calculated_config - def _build_dependent_aliases(self, calc_type): + def _build_dependent_aliases(self, calc_type, col_list=None): """This is a utility function for determining the required depth of all fields""" order_of_operations = [] - source_table = self.get_source_ibis_calculated_table() - casefold_source_columns = {x.casefold(): str(x) for x in source_table.columns} + if col_list is None: + source_table = self.get_source_ibis_calculated_table() + casefold_source_columns = { + x.casefold(): str(x) for x in source_table.columns + } + else: + casefold_source_columns = {x.casefold(): str(x) for x in col_list} if calc_type == "hash": order_of_operations = [ "cast", diff --git a/data_validation/consts.py b/data_validation/consts.py index 5743e511b..ad77edc46 100644 --- a/data_validation/consts.py +++ b/data_validation/consts.py @@ -30,6 +30,7 @@ CONFIG_FIELD_ALIAS = "field_alias" CONFIG_AGGREGATES = "aggregates" CONFIG_CALCULATED_FIELDS = "calculated_fields" +CONFIG_DEPENDENT_ALIASES = "dependent_aliases" CONFIG_GROUPED_COLUMNS = "grouped_columns" CONFIG_CALCULATED_SOURCE_COLUMNS = "source_calculated_columns" CONFIG_CALCULATED_TARGET_COLUMNS = "target_calculated_columns" diff --git a/data_validation/data_validation.py b/data_validation/data_validation.py index 95f650667..b5b3e1c4e 100644 --- a/data_validation/data_validation.py +++ b/data_validation/data_validation.py @@ -300,9 +300,27 @@ def _execute_validation(self, validation_builder, process_in_memory=True): if process_in_memory: source_df = self.config_manager.source_client.execute(source_query) target_df = self.config_manager.target_client.execute(target_query) + + # Drop excess fields for row validation to avoid pandas errors for unsupported column data types (i.e structs) + if ( + self.config_manager.validation_type == consts.ROW_VALIDATION + and self.config_manager.dependent_aliases + ): + source_df.drop( + source_df.columns.difference(self.config_manager.dependent_aliases), + axis=1, + inplace=True, + ) + target_df.drop( + target_df.columns.difference(self.config_manager.dependent_aliases), + axis=1, + inplace=True, + ) + pd_schema = self._get_pandas_schema( source_df, target_df, join_on_fields, verbose=self.verbose ) + pandas_client = ibis.backends.pandas.connect( {combiner.DEFAULT_SOURCE: source_df, combiner.DEFAULT_TARGET: target_df} ) diff --git a/tests/unit/test_config_manager.py b/tests/unit/test_config_manager.py index 86c2f6460..214125962 100644 --- a/tests/unit/test_config_manager.py +++ b/tests/unit/test_config_manager.py @@ -43,6 +43,22 @@ ], } +SAMPLE_ROW_CONFIG = { + # BigQuery Specific Connection Config + consts.CONFIG_SOURCE_CONN: {"type": "DNE connection"}, + consts.CONFIG_TARGET_CONN: {"type": "DNE connection"}, + # Validation Type + consts.CONFIG_TYPE: "Row", + # Configuration Required Depending on Validator Type + consts.CONFIG_SCHEMA_NAME: "bigquery-public-data.new_york_citibike", + consts.CONFIG_TABLE_NAME: "citibike_trips", + consts.CONFIG_GROUPED_COLUMNS: [], + consts.CONFIG_THRESHOLD: 0.0, + consts.CONFIG_PRIMARY_KEYS: "id", + consts.CONFIG_CALCULATED_FIELDS: ["name", "station_id"], + consts.CONFIG_DEPENDENT_ALIASES: ["id", "name", "station_id"], +} + AGGREGATE_CONFIG_A = { consts.CONFIG_SOURCE_COLUMN: "a", consts.CONFIG_TARGET_COLUMN: "a", @@ -277,3 +293,18 @@ def test_get_result_handler(module_under_test): handler = config_manager.get_result_handler() assert handler._table_id == "dataset.table_name" + + +def test_dependent_aliases(module_under_test): + config_manager = module_under_test.ConfigManager( + SAMPLE_ROW_CONFIG, MockIbisClient(), MockIbisClient(), verbose=False + ) + config_manager.append_dependent_aliases(["location", "bike"]) + + assert config_manager.dependent_aliases == [ + "id", + "name", + "station_id", + "location", + "bike", + ]