diff --git a/data_validation/partition_builder.py b/data_validation/partition_builder.py index b6f431de..31d6ab19 100644 --- a/data_validation/partition_builder.py +++ b/data_validation/partition_builder.py @@ -32,7 +32,6 @@ def __init__(self, config_managers: List[ConfigManager], args: Namespace) -> Non self.table_count = len(config_managers) self.args = args self.config_dir = self._get_arg_config_dir() - self.primary_keys = self._get_primary_keys() def _get_arg_config_dir(self) -> str: """Return String yaml config folder path.""" @@ -41,10 +40,6 @@ def _get_arg_config_dir(self) -> str: return self.args.config_dir - def _get_primary_keys(self) -> str: - """Return the Primary Keys""" - return cli_tools.get_arg_list(self.args.primary_keys) - def _get_yaml_from_config(self, config_manager: ConfigManager) -> Dict: """Return dict objects formatted for yaml validations. @@ -106,8 +101,13 @@ def _get_partition_key_filters(self) -> List[List[List[str]]]: for config_manager in self.config_managers: # For each pair of tables validation_builder = ValidationBuilder(config_manager) + source_pks, target_pks = [], [] + for pk in config_manager.primary_keys: + source_pks.append(pk["source_column"]) + target_pks.append(pk["target_column"]) + source_partition_row_builder = PartitionRowBuilder( - self.primary_keys, + source_pks, config_manager.source_client, config_manager.source_schema, config_manager.source_table, @@ -115,7 +115,7 @@ def _get_partition_key_filters(self) -> List[List[List[str]]]: ) source_table = source_partition_row_builder.query target_partition_row_builder = PartitionRowBuilder( - self.primary_keys, + target_pks, config_manager.target_client, config_manager.target_schema, config_manager.target_table, @@ -150,9 +150,9 @@ def _get_partition_key_filters(self) -> List[List[List[str]]]: # First we number each row in the source table. Using row_number instead of ntile since it is # available on all platforms (Teradata does not support NTILE). For our purposes, it is likely # more efficient - window1 = ibis.window(order_by=self.primary_keys) + window1 = ibis.window(order_by=source_pks) row_number = (ibis.row_number().over(window1) + 1).name(consts.DVT_POS_COL) - dvt_keys = self.primary_keys.copy() + dvt_keys = source_pks.copy() dvt_keys.append(row_number) rownum_table = source_table.select(dvt_keys) # Rownum table is just the primary key columns in the source table along with @@ -182,7 +182,7 @@ def _get_partition_key_filters(self) -> List[List[List[str]]]: ) > 0 ) - first_keys_table = rownum_table[cond].order_by(self.primary_keys) + first_keys_table = rownum_table[cond].order_by(source_pks) # Up until this point, we have built the table expression, have not executed the query yet. # The query is now executed to find the first element of each partition @@ -219,13 +219,13 @@ def _get_partition_key_filters(self) -> List[List[List[str]]]: filter_source_clause = less_than_value( source_table, - self.primary_keys, - first_elements[1, : len(self.primary_keys)], + source_pks, + first_elements[1, : len(source_pks)], ) filter_target_clause = less_than_value( target_table, - self.primary_keys, - first_elements[1, : len(self.primary_keys)], + target_pks, + first_elements[1, : len(target_pks)], ) source_where_list.append( self._extract_where( @@ -243,21 +243,21 @@ def _get_partition_key_filters(self) -> List[List[List[str]]]: for i in range(1, first_elements.shape[0] - 1): filter_source_clause = geq_value( source_table, - self.primary_keys, - first_elements[i, : len(self.primary_keys)], + source_pks, + first_elements[i, : len(source_pks)], ) & less_than_value( source_table, - self.primary_keys, - first_elements[i + 1, : len(self.primary_keys)], + source_pks, + first_elements[i + 1, : len(source_pks)], ) filter_target_clause = geq_value( target_table, - self.primary_keys, - first_elements[i, : len(self.primary_keys)], + target_pks, + first_elements[i, : len(target_pks)], ) & less_than_value( target_table, - self.primary_keys, - first_elements[i + 1, : len(self.primary_keys)], + target_pks, + first_elements[i + 1, : len(target_pks)], ) source_where_list.append( self._extract_where( @@ -273,13 +273,13 @@ def _get_partition_key_filters(self) -> List[List[List[str]]]: ) filter_source_clause = geq_value( source_table, - self.primary_keys, - first_elements[len(first_elements) - 1, : len(self.primary_keys)], + source_pks, + first_elements[len(first_elements) - 1, : len(source_pks)], ) filter_target_clause = geq_value( target_table, - self.primary_keys, - first_elements[len(first_elements) - 1, : len(self.primary_keys)], + target_pks, + first_elements[len(first_elements) - 1, : len(target_pks)], ) source_where_list.append( self._extract_where( diff --git a/tests/unit/test_partition_builder.py b/tests/unit/test_partition_builder.py index 9236a7f0..4a582c60 100644 --- a/tests/unit/test_partition_builder.py +++ b/tests/unit/test_partition_builder.py @@ -381,13 +381,6 @@ def test_class_object_creation(module_under_test): args = parser.parse_args(CLI_ARGS_SINGLE_KEY) builder = module_under_test.PartitionBuilder(config_managers, args) assert builder.table_count == len(config_managers) - assert builder.primary_keys == ["id"] - - # multiple primary keys are present - args = parser.parse_args(CLI_ARGS_MULTIPLE_KEYS) - builder = module_under_test.PartitionBuilder(config_managers, args) - assert builder.table_count == len(config_managers) - assert builder.primary_keys == ["region_id", "station_id"] def test_add_partition_filters_to_config(module_under_test):