feat(generate-table-partitions): Works on all 7 platforms - BigQuery, Hive, MySQL, Oracle, Postgres, SQL Server and Teradata. (#922)

* Changes to test suites to test generate-partitions.
Teradata does not work.

* Updated generate-partitions to use a simpler row_number method, which also works on Teradata.
Tested with Postgres and updated the Postgres tests. Other tests will likely need to be modified.
Requires the ibis upgrade, hence merging with develop before finishing testing.

* Fixing some unintended changes introduced by the use of git stash.

* Updated test suites to add a test for generate partitions

* Lint updates

* Updated failing tests

* tests: Updated failing SQL Server tests

---------

Co-authored-by: nj1973 <[email protected]>
sundar-mudupalli-work and nj1973 committed Aug 1, 2023
1 parent aed1505 commit aa84d7a
Showing 10 changed files with 358 additions and 57 deletions.
9 changes: 3 additions & 6 deletions data_validation/consts.py
@@ -168,9 +168,6 @@
"pct_threshold",
]

# Constants for named columns used in generate partitions
# these cannot conflict with primary key column names
DVT_NTILE_COL = "dvt_ntile"
DVT_PART_NO = "dvt_part_no"
DVT_FIRST_PRE = "dvt_first_" # prefix for first_element_column names
DVT_LAST_PRE = "dvt_last_" # prefix for last_element_column names
# Constants for the named column used in generate partitions
# this cannot conflict with primary key column names
DVT_POS_COL = "dvt_pos_num"
97 changes: 54 additions & 43 deletions data_validation/partition_builder.py
@@ -14,6 +14,7 @@

 import os
 import ibis
+import pandas
 import logging
 from typing import List, Dict
 from argparse import Namespace
@@ -156,6 +157,13 @@ def _get_partition_key_filters(self) -> List[List[str]]:
         source_count = source_partition_row_builder.get_count()
         target_count = target_partition_row_builder.get_count()
 
+        # For some reason Teradata connector returns a dataframe with the count element,
+        # while the other connectors return a numpy.int64 value
+        if isinstance(source_count, pandas.DataFrame):
+            source_count = source_count.values[0][0]
+        if isinstance(target_count, pandas.DataFrame):
+            target_count = target_count.values[0][0]
+
         if abs(source_count - target_count) > source_count * 0.1:
             logging.warning(
                 "Source and Target table row counts vary by more than 10%,"
@@ -169,71 +177,74 @@ def _get_partition_key_filters(self) -> List[List[str]]:
             else source_count
         )
 
-        # First we use the ntile aggregate function and divide assign a partition
-        # number to each row in the source table
+        # First we number each row in the source table. Using row_number instead of ntile since it is
+        # available on all platforms (Teradata does not support NTILE). For our purposes, it is likely
+        # more efficient
         window1 = ibis.window(order_by=self.primary_keys)
-        nt = (
-            source_table[self.primary_keys[0]]
-            .ntile(buckets=number_of_part)
-            .over(window1)
-            .name(consts.DVT_NTILE_COL)
-        )
-        dvt_nt = self.primary_keys.copy()
-        dvt_nt.append(nt)
-        partitioned_table = source_table.select(dvt_nt)
-        # Partitioned table is just the primary key columns in the source table along with
-        # an additional column with the partition number associated with each row.
-
-        # We are interested in only the primary key values at the begining of
-        # each partitition - the following window groups by partition number
-        window2 = ibis.window(
-            order_by=self.primary_keys, group_by=[consts.DVT_NTILE_COL]
-        )
-        first_pkys = [
-            partitioned_table[primary_key]
-            .first()
-            .over(window2)
-            .name(consts.DVT_FIRST_PRE + primary_key)
-            for primary_key in self.primary_keys
-        ]
-        partition_no = (
-            partitioned_table[consts.DVT_NTILE_COL]
-            .first()
-            .over(window2)
-            .name(consts.DVT_PART_NO)
-        )
-        column_list = [partition_no] + first_pkys
-        partition_boundary = (
-            partitioned_table.select(column_list)
-            .sort_by([consts.DVT_PART_NO])
-            .distinct()
-        )
+        row_number = (ibis.row_number().over(window1) + 1).name(consts.DVT_POS_COL)
+        dvt_keys = self.primary_keys.copy()
+        dvt_keys.append(row_number)
+        rownum_table = source_table.select(dvt_keys)
+        # Rownum table is just the primary key columns in the source table along with
+        # an additional column with the row number associated with each row.
+
+        # This rather complicated expression below is a filter (where) clause condition that filters the row numbers
+        # that correspond to the first element of the partition. The number of a partition is
+        # ceiling(row number * # of partitions / total number of rows). The first element of the partition is where
+        # the remainder, i.e. row number * # of partitions % total number of rows is > 0 and <= number of partitions.
+        # The remainder function does not work well with Teradata, hence writing that out explicitly.
+        cond = (
+            (
+                rownum_table[consts.DVT_POS_COL] * number_of_part
+                - (
+                    rownum_table[consts.DVT_POS_COL] * number_of_part / source_count
+                ).floor()
+                * source_count
+            )
+            <= number_of_part
+        ) & (
+            (
+                rownum_table[consts.DVT_POS_COL] * number_of_part
+                - (
+                    rownum_table[consts.DVT_POS_COL] * number_of_part / source_count
+                ).floor()
+                * source_count
+            )
+            > 0
+        )
+        first_keys_table = rownum_table[cond].order_by(self.primary_keys)
 
         # Up until this point, we have built the table expression, have not executed the query yet.
-        # The query is now executed to find the first and last element of each partition
-        first_elements = partition_boundary.execute().to_numpy()
+        # The query is now executed to find the first element of each partition
+        first_elements = first_keys_table.execute().to_numpy()
 
         # Once we have the first element of each partition, we can generate the where clause
         # i.e. greater than or equal to first element and less than first element of next partition
         # The first and the last partitions have special where clauses - less than first element of second
         # partition and greater than or equal to the first element of the last partition respectively
         filter_clause_list = []
         filter_clause_list.append(
-            self._less_than_value(self.primary_keys, first_elements[1, 1:])
+            self._less_than_value(
+                self.primary_keys, first_elements[1, : len(self.primary_keys)]
+            )
         )
         for i in range(1, first_elements.shape[0] - 1):
             filter_clause_list.append(
                 "("
-                + self._geq_value(self.primary_keys, first_elements[i, 1:])
+                + self._geq_value(
+                    self.primary_keys, first_elements[i, : len(self.primary_keys)]
+                )
                 + ") AND ("
                 + self._less_than_value(
-                    self.primary_keys, first_elements[i + 1, 1:]
+                    self.primary_keys,
+                    first_elements[i + 1, : len(self.primary_keys)],
                 )
                 + ")"
             )
         filter_clause_list.append(
             self._geq_value(
-                self.primary_keys, first_elements[len(first_elements) - 1, 1:]
+                self.primary_keys,
+                first_elements[len(first_elements) - 1, : len(self.primary_keys)],
             )
         )

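The arithmetic in the new filter condition is worth spelling out: row pos lands in partition ceiling(pos * number_of_part / source_count), and pos is the first row of its partition exactly when the explicitly written remainder, pos * number_of_part - floor(pos * number_of_part / source_count) * source_count, is greater than 0 and at most number_of_part. A minimal standalone sketch in Python (hypothetical helper names, not code from this commit) checks that the remainder test picks out the same rows as the ceiling definition:

    import math

    def first_rows(total_rows: int, num_parts: int) -> list:
        """Row numbers (1-based) that start each partition, using the explicit
        multiply/floor/subtract remainder from partition_builder.py (per the
        comment in the diff, % does not work well on Teradata)."""
        firsts = []
        for pos in range(1, total_rows + 1):
            rem = pos * num_parts - math.floor(pos * num_parts / total_rows) * total_rows
            if 0 < rem <= num_parts:
                firsts.append(pos)
        return firsts

    def first_rows_by_ceiling(total_rows: int, num_parts: int) -> list:
        """Same rows, derived straight from the ceiling definition of partition numbers."""
        parts = [math.ceil(pos * num_parts / total_rows) for pos in range(1, total_rows + 1)]
        return [pos for pos in range(1, total_rows + 1) if pos == 1 or parts[pos - 2] != parts[pos - 1]]

    # 10 rows split 4 ways: partitions begin at rows 1, 3, 6 and 8.
    assert first_rows(10, 4) == first_rows_by_ceiling(10, 4) == [1, 3, 6, 8]

One side effect of dropping the partition-number column explains the numpy slicing change: the old partition_boundary table put DVT_PART_NO in column 0 with the key columns after it, so the code sliced first_elements[i, 1:], while the new first_keys_table selects the primary keys first with dvt_pos_num last, so the slice becomes first_elements[i, : len(self.primary_keys)].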
2 changes: 1 addition & 1 deletion data_validation/query_builder/partition_row_builder.py
@@ -62,4 +62,4 @@ def _compile_query(

     def get_count(self) -> int:
         """Return a count of rows of primary keys - they should be all distinct"""
-        return self.query.select(self.primary_keys).count().execute()
+        return self.query[self.primary_keys].count().execute()
49 changes: 49 additions & 0 deletions tests/system/data_sources/test_bigquery.py
@@ -18,6 +18,7 @@
 from data_validation import __main__ as main
 from data_validation import cli_tools, clients, consts, data_validation, state_manager
 from data_validation.query_builder import random_row_builder
+from data_validation.partition_builder import PartitionBuilder
 from data_validation.query_builder.query_builder import QueryBuilder


@@ -1094,6 +1095,54 @@ def test_custom_query():
     assert result_df.source_agg_value.equals(result_df.target_agg_value)
 
 
+# Expected result from partitioning table on 3 keys
+EXPECTED_PARTITION_FILTER = [
+    "course_id < 'ALG001' OR course_id = 'ALG001' AND (quarter_id < 3 OR quarter_id = 3 AND (student_id < 1234))",
+    "(course_id > 'ALG001' OR course_id = 'ALG001' AND (quarter_id > 3 OR quarter_id = 3 AND (student_id >= 1234)))"
+    + " AND (course_id < 'GEO001' OR course_id = 'GEO001' AND (quarter_id < 2 OR quarter_id = 2 AND (student_id < 5678)))",
+    "(course_id > 'GEO001' OR course_id = 'GEO001' AND (quarter_id > 2 OR quarter_id = 2 AND (student_id >= 5678)))"
+    + " AND (course_id < 'TRI001' OR course_id = 'TRI001' AND (quarter_id < 1 OR quarter_id = 1 AND (student_id < 9012)))",
+    "course_id > 'TRI001' OR course_id = 'TRI001' AND (quarter_id > 1 OR quarter_id = 1 AND (student_id >= 9012))",
+]
+
+
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    return_value=BQ_CONN,
+)
+def test_bigquery_generate_table_partitions(mock_conn):
+    """Test generate table partitions on BigQuery
+    The unit tests, specifically test_add_partition_filters_to_config and test_store_yaml_partitions_local,
+    check that yaml configurations are created and saved in local storage. Partitions can only be created with
+    a database that can handle SQL window functions (e.g. row_number), hence doing this as part of system testing.
+    What we are checking:
+    1. the shape of the partition list is (1, number of partitions) - only one table in the list
+    2. the value of the partition list matches what we expect.
+    """
+    parser = cli_tools.configure_arg_parser()
+    args = parser.parse_args(
+        [
+            "generate-table-partitions",
+            "-sc=mock-conn",
+            "-tc=mock-conn",
+            "-tbls=pso_data_validator.test_generate_partitions=pso_data_validator.test_generate_partitions",
+            "-pk=course_id,quarter_id,student_id",
+            "-hash=*",
+            "-cdir=/home/users/yaml",
+            "-pn=4",
+        ]
+    )
+    config_managers = main.build_config_managers_from_args(args, consts.ROW_VALIDATION)
+    partition_builder = PartitionBuilder(config_managers, args)
+    partition_filters = partition_builder._get_partition_key_filters()
+
+    assert len(partition_filters) == 1  # only one pair of tables
+    assert (
+        len(partition_filters[0]) == partition_builder.args.partition_num
+    )  # assume no of table rows > partition_num
+    assert partition_filters[0] == EXPECTED_PARTITION_FILTER
+
+
 @mock.patch(
     "data_validation.state_manager.StateManager.get_connection_config",
     return_value=BQ_CONN,
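The EXPECTED_PARTITION_FILTER strings encode lexicographic comparisons over the compound key (course_id, quarter_id, student_id): each boundary filter means "this key tuple is >= the first row of the partition" and, for the middle partitions, "< the first row of the next partition", expanded into nested OR/AND SQL. A short recursive sketch (a hypothetical helper, not necessarily how _geq_value and _less_than_value are implemented in the repo) generates the same shape:

    def geq_value(keys: list, values: list) -> str:
        """Render a lexicographic (keys >= values) comparison as a SQL condition.

        Uses repr() as a stand-in for real SQL literal rendering; it happens to
        match the quoting in the expected filters above for str and int values.
        """
        key, val = keys[0], repr(values[0])
        if len(keys) == 1:
            return f"{key} >= {val}"
        # Either the leading key decides strictly, or it ties and the
        # remaining keys decide recursively.
        return f"{key} > {val} OR {key} = {val} AND ({geq_value(keys[1:], values[1:])})"

    print(geq_value(["course_id", "quarter_id", "student_id"], ["TRI001", 1, 9012]))
    # Output (wrapped): course_id > 'TRI001' OR course_id = 'TRI001' AND
    #   (quarter_id > 1 OR quarter_id = 1 AND (student_id >= 9012))

That output matches the last entry of EXPECTED_PARTITION_FILTER; the middle entries AND a >= boundary together with a < boundary for the next partition's first key. The same literal list is asserted in the Hive and MySQL tests below, since every backend should produce identical filters for identical data.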
49 changes: 49 additions & 0 deletions tests/system/data_sources/test_hive.py
@@ -17,6 +17,7 @@

 from data_validation import __main__ as main
 from data_validation import cli_tools, data_validation, consts
+from data_validation.partition_builder import PartitionBuilder
 from tests.system.data_sources.test_bigquery import BQ_CONN


@@ -98,6 +99,54 @@ def disabled_test_schema_validation_core_types():
     assert len(df) == 0
 
 
+# Expected result from partitioning table on 3 keys
+EXPECTED_PARTITION_FILTER = [
+    "course_id < 'ALG001' OR course_id = 'ALG001' AND (quarter_id < 3 OR quarter_id = 3 AND (student_id < 1234))",
+    "(course_id > 'ALG001' OR course_id = 'ALG001' AND (quarter_id > 3 OR quarter_id = 3 AND (student_id >= 1234)))"
+    + " AND (course_id < 'GEO001' OR course_id = 'GEO001' AND (quarter_id < 2 OR quarter_id = 2 AND (student_id < 5678)))",
+    "(course_id > 'GEO001' OR course_id = 'GEO001' AND (quarter_id > 2 OR quarter_id = 2 AND (student_id >= 5678)))"
+    + " AND (course_id < 'TRI001' OR course_id = 'TRI001' AND (quarter_id < 1 OR quarter_id = 1 AND (student_id < 9012)))",
+    "course_id > 'TRI001' OR course_id = 'TRI001' AND (quarter_id > 1 OR quarter_id = 1 AND (student_id >= 9012))",
+]
+
+
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    new=mock_get_connection_config,
+)
+def test_hive_generate_table_partitions():
+    """Test generate table partitions on Hive
+    The unit tests, specifically test_add_partition_filters_to_config and test_store_yaml_partitions_local,
+    check that yaml configurations are created and saved in local storage. Partitions can only be created with
+    a database that can handle SQL window functions (e.g. row_number), hence doing this as part of system testing.
+    What we are checking:
+    1. the shape of the partition list is (1, number of partitions) - only one table in the list
+    2. the value of the partition list matches what we expect.
+    """
+    parser = cli_tools.configure_arg_parser()
+    args = parser.parse_args(
+        [
+            "generate-table-partitions",
+            "-sc=hive-conn",
+            "-tc=hive-conn",
+            "-tbls=pso_data_validator.test_generate_partitions=pso_data_validator.test_generate_partitions",
+            "-pk=course_id,quarter_id,student_id",
+            "-hash=*",
+            "-cdir=/home/users/yaml",
+            "-pn=4",
+        ]
+    )
+    config_managers = main.build_config_managers_from_args(args, consts.ROW_VALIDATION)
+    partition_builder = PartitionBuilder(config_managers, args)
+    partition_filters = partition_builder._get_partition_key_filters()
+
+    assert len(partition_filters) == 1  # only one pair of tables
+    assert (
+        len(partition_filters[0]) == partition_builder.args.partition_num
+    )  # assume no of table rows > partition_num
+    assert partition_filters[0] == EXPECTED_PARTITION_FILTER
+
+
 @mock.patch(
     "data_validation.state_manager.StateManager.get_connection_config",
     new=mock_get_connection_config,
50 changes: 49 additions & 1 deletion tests/system/data_sources/test_mysql.py
@@ -17,7 +17,7 @@

 from data_validation import __main__ as main
 from data_validation import cli_tools, data_validation, consts, exceptions
-
+from data_validation.partition_builder import PartitionBuilder
 
 MYSQL_HOST = os.getenv("MYSQL_HOST", "localhost")
 MYSQL_USER = os.getenv("MYSQL_USER", "dvt")
@@ -426,6 +426,54 @@ def test_mysql_row():
         pass
 
 
+# Expected result from partitioning table on 3 keys
+EXPECTED_PARTITION_FILTER = [
+    "course_id < 'ALG001' OR course_id = 'ALG001' AND (quarter_id < 3 OR quarter_id = 3 AND (student_id < 1234))",
+    "(course_id > 'ALG001' OR course_id = 'ALG001' AND (quarter_id > 3 OR quarter_id = 3 AND (student_id >= 1234)))"
+    + " AND (course_id < 'GEO001' OR course_id = 'GEO001' AND (quarter_id < 2 OR quarter_id = 2 AND (student_id < 5678)))",
+    "(course_id > 'GEO001' OR course_id = 'GEO001' AND (quarter_id > 2 OR quarter_id = 2 AND (student_id >= 5678)))"
+    + " AND (course_id < 'TRI001' OR course_id = 'TRI001' AND (quarter_id < 1 OR quarter_id = 1 AND (student_id < 9012)))",
+    "course_id > 'TRI001' OR course_id = 'TRI001' AND (quarter_id > 1 OR quarter_id = 1 AND (student_id >= 9012))",
+]
+
+
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    return_value=CONN,
+)
+def test_mysql_generate_table_partitions(mock_conn):
+    """Test generate table partitions on MySQL
+    The unit tests, specifically test_add_partition_filters_to_config and test_store_yaml_partitions_local,
+    check that yaml configurations are created and saved in local storage. Partitions can only be created with
+    a database that can handle SQL window functions (e.g. row_number), hence doing this as part of system testing.
+    What we are checking:
+    1. the shape of the partition list is (1, number of partitions) - only one table in the list
+    2. the value of the partition list matches what we expect.
+    """
+    parser = cli_tools.configure_arg_parser()
+    args = parser.parse_args(
+        [
+            "generate-table-partitions",
+            "-sc=mock-conn",
+            "-tc=mock-conn",
+            "-tbls=pso_data_validator.test_generate_partitions=pso_data_validator.test_generate_partitions",
+            "-pk=course_id,quarter_id,student_id",
+            "-hash=*",
+            "-cdir=/home/users/yaml",
+            "-pn=4",
+        ]
+    )
+    config_managers = main.build_config_managers_from_args(args, consts.ROW_VALIDATION)
+    partition_builder = PartitionBuilder(config_managers, args)
+    partition_filters = partition_builder._get_partition_key_filters()
+
+    assert len(partition_filters) == 1  # only one pair of tables
+    assert (
+        len(partition_filters[0]) == partition_builder.args.partition_num
+    )  # assume no of table rows > partition_num
+    assert partition_filters[0] == EXPECTED_PARTITION_FILTER
+
+
 @mock.patch(
     "data_validation.state_manager.StateManager.get_connection_config",
     return_value=CONN,