Skip to content

Commit

Permalink
fix: Support PKs with different casing for generate-partitions (#1142)
Browse files Browse the repository at this point in the history
* fix: support different casing PKs for generate partitions

* remove invalid test

---------

Co-authored-by: sundar-mudupalli-work <[email protected]>
  • Loading branch information
nehanene15 and sundar-mudupalli-work committed May 24, 2024
1 parent 210c352 commit 021ce75
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 33 deletions.
52 changes: 26 additions & 26 deletions data_validation/partition_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def __init__(self, config_managers: List[ConfigManager], args: Namespace) -> Non
self.table_count = len(config_managers)
self.args = args
self.config_dir = self._get_arg_config_dir()
self.primary_keys = self._get_primary_keys()

def _get_arg_config_dir(self) -> str:
"""Return String yaml config folder path."""
Expand All @@ -41,10 +40,6 @@ def _get_arg_config_dir(self) -> str:

return self.args.config_dir

def _get_primary_keys(self) -> str:
"""Return the Primary Keys"""
return cli_tools.get_arg_list(self.args.primary_keys)

def _get_yaml_from_config(self, config_manager: ConfigManager) -> Dict:
"""Return dict objects formatted for yaml validations.
Expand Down Expand Up @@ -106,16 +101,21 @@ def _get_partition_key_filters(self) -> List[List[List[str]]]:
for config_manager in self.config_managers: # For each pair of tables
validation_builder = ValidationBuilder(config_manager)

source_pks, target_pks = [], []
for pk in config_manager.primary_keys:
source_pks.append(pk["source_column"])
target_pks.append(pk["target_column"])

source_partition_row_builder = PartitionRowBuilder(
self.primary_keys,
source_pks,
config_manager.source_client,
config_manager.source_schema,
config_manager.source_table,
validation_builder.source_builder,
)
source_table = source_partition_row_builder.query
target_partition_row_builder = PartitionRowBuilder(
self.primary_keys,
target_pks,
config_manager.target_client,
config_manager.target_schema,
config_manager.target_table,
Expand Down Expand Up @@ -150,9 +150,9 @@ def _get_partition_key_filters(self) -> List[List[List[str]]]:
# First we number each row in the source table. Using row_number instead of ntile since it is
# available on all platforms (Teradata does not support NTILE). For our purposes, it is likely
# more efficient
window1 = ibis.window(order_by=self.primary_keys)
window1 = ibis.window(order_by=source_pks)
row_number = (ibis.row_number().over(window1) + 1).name(consts.DVT_POS_COL)
dvt_keys = self.primary_keys.copy()
dvt_keys = source_pks.copy()
dvt_keys.append(row_number)
rownum_table = source_table.select(dvt_keys)
# Rownum table is just the primary key columns in the source table along with
Expand Down Expand Up @@ -182,7 +182,7 @@ def _get_partition_key_filters(self) -> List[List[List[str]]]:
)
> 0
)
first_keys_table = rownum_table[cond].order_by(self.primary_keys)
first_keys_table = rownum_table[cond].order_by(source_pks)

# Up until this point, we have built the table expression, have not executed the query yet.
# The query is now executed to find the first element of each partition
Expand Down Expand Up @@ -219,13 +219,13 @@ def _get_partition_key_filters(self) -> List[List[List[str]]]:

filter_source_clause = less_than_value(
source_table,
self.primary_keys,
first_elements[1, : len(self.primary_keys)],
source_pks,
first_elements[1, : len(source_pks)],
)
filter_target_clause = less_than_value(
target_table,
self.primary_keys,
first_elements[1, : len(self.primary_keys)],
target_pks,
first_elements[1, : len(target_pks)],
)
source_where_list.append(
self._extract_where(
Expand All @@ -243,21 +243,21 @@ def _get_partition_key_filters(self) -> List[List[List[str]]]:
for i in range(1, first_elements.shape[0] - 1):
filter_source_clause = geq_value(
source_table,
self.primary_keys,
first_elements[i, : len(self.primary_keys)],
source_pks,
first_elements[i, : len(source_pks)],
) & less_than_value(
source_table,
self.primary_keys,
first_elements[i + 1, : len(self.primary_keys)],
source_pks,
first_elements[i + 1, : len(source_pks)],
)
filter_target_clause = geq_value(
target_table,
self.primary_keys,
first_elements[i, : len(self.primary_keys)],
target_pks,
first_elements[i, : len(target_pks)],
) & less_than_value(
target_table,
self.primary_keys,
first_elements[i + 1, : len(self.primary_keys)],
target_pks,
first_elements[i + 1, : len(target_pks)],
)
source_where_list.append(
self._extract_where(
Expand All @@ -273,13 +273,13 @@ def _get_partition_key_filters(self) -> List[List[List[str]]]:
)
filter_source_clause = geq_value(
source_table,
self.primary_keys,
first_elements[len(first_elements) - 1, : len(self.primary_keys)],
source_pks,
first_elements[len(first_elements) - 1, : len(source_pks)],
)
filter_target_clause = geq_value(
target_table,
self.primary_keys,
first_elements[len(first_elements) - 1, : len(self.primary_keys)],
target_pks,
first_elements[len(first_elements) - 1, : len(target_pks)],
)
source_where_list.append(
self._extract_where(
Expand Down
7 changes: 0 additions & 7 deletions tests/unit/test_partition_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,6 @@ def test_class_object_creation(module_under_test):
args = parser.parse_args(CLI_ARGS_SINGLE_KEY)
builder = module_under_test.PartitionBuilder(config_managers, args)
assert builder.table_count == len(config_managers)
assert builder.primary_keys == ["id"]

# multiple primary keys are present
args = parser.parse_args(CLI_ARGS_MULTIPLE_KEYS)
builder = module_under_test.PartitionBuilder(config_managers, args)
assert builder.table_count == len(config_managers)
assert builder.primary_keys == ["region_id", "station_id"]


def test_add_partition_filters_to_config(module_under_test):
Expand Down

0 comments on commit 021ce75

Please sign in to comment.