Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support for custom query #390

Merged
merged 12 commits into from
Mar 23, 2022
13 changes: 12 additions & 1 deletion data_validation/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,15 @@ def build_config_from_args(args, config_manager):

# TODO(GH#18): Add query filter config logic


if config_manager.validation_type == consts.CUSTOM_QUERY:
config_manager.append_aggregates(get_aggregate_config(args, config_manager))
if args.source_query_file is not None:
query_file = cli_tools.get_arg_list(args.source_query_file)
config_manager.append_source_query_file(query_file)
if args.target_query_file is not None:
query_file = cli_tools.get_arg_list(args.target_query_file)
config_manager.append_target_query_file(query_file)
return config_manager


Expand All @@ -165,6 +174,8 @@ def build_config_managers_from_args(args):
config_type = consts.COLUMN_VALIDATION
elif validate_cmd == "Row":
config_type = consts.ROW_VALIDATION
elif validate_cmd == "Custom-query":
config_type = consts.CUSTOM_QUERY
else:
raise ValueError(f"Unknown Validation Type: {validate_cmd}")
else:
Expand Down Expand Up @@ -429,7 +440,7 @@ def run_validation_configs(args):

def validate(args):
    """Run commands related to data validation.

    Dispatches supported ``validate`` sub-commands to ``run``.

    Raises:
        ValueError: if ``args.validate_cmd`` is not a supported command.
    """
    # "custom-query" is accepted alongside the pre-existing commands.
    if args.validate_cmd in ["column", "row", "schema", "custom-query"]:
        run(args)
    else:
        raise ValueError(f"Validation Argument '{args.validate_cmd}' is not supported")
Expand Down
84 changes: 83 additions & 1 deletion data_validation/cli_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ def configure_arg_parser():
_configure_raw_query(subparsers)
_configure_run_parser(subparsers)
_configure_beta_parser(subparsers)

return parser


Expand Down Expand Up @@ -385,6 +384,10 @@ def _configure_validate_parser(subparsers):
)
_configure_schema_parser(schema_parser)

custom_query_parser = validate_subparsers.add_parser(
"custom-query", help="Run a custom query validation"
)
_configure_custom_query_parser(custom_query_parser)

def _configure_row_parser(row_parser):
"""Configure arguments to run row level validations."""
Expand Down Expand Up @@ -530,6 +533,85 @@ def _configure_schema_parser(schema_parser):
"""Configure arguments to run column level validations."""
_add_common_arguments(schema_parser)

def _configure_custom_query_parser(custom_query_parser):
    """Configure arguments to run custom-query validations.

    Args:
        custom_query_parser: argparse sub-parser for the "custom-query" command.
    """
    _add_common_arguments(custom_query_parser)
    custom_query_parser.add_argument(
        "--source-query-file",
        "-sqf",
        help="File containing the source sql commands",
    )
    custom_query_parser.add_argument(
        "--target-query-file",
        "-tqf",
        # Fixed copy-paste: this flag supplies the *target* query, not the source.
        help="File containing the target sql commands",
    )
    custom_query_parser.add_argument(
        "--count",
        "-count",
        help="Comma separated list of columns for count 'col_a,col_b' or * for all columns",
    )
    custom_query_parser.add_argument(
        "--sum",
        "-sum",
        help="Comma separated list of columns for sum 'col_a,col_b' or * for all columns",
    )
    custom_query_parser.add_argument(
        "--avg",
        "-avg",
        help="Comma separated list of columns for avg 'col_a,col_b' or * for all columns",
    )
    custom_query_parser.add_argument(
        "--min",
        "-min",
        help="Comma separated list of columns for min 'col_a,col_b' or * for all columns",
    )
    custom_query_parser.add_argument(
        "--max",
        "-max",
        help="Comma separated list of columns for max 'col_a,col_b' or * for all columns",
    )
    custom_query_parser.add_argument(
        "--bit_xor",
        "-bit_xor",
        help="Comma separated list of columns for hashing a concatenate 'col_a,col_b' or * for all columns",
    )
    custom_query_parser.add_argument(
        "--hash",
        "-hash",
        help="Comma separated list of columns for hashing a concatenate 'col_a,col_b' or * for all columns",
    )
    custom_query_parser.add_argument(
        "--filters",
        "-filters",
        help="Filters in the format source_filter:target_filter",
    )
    custom_query_parser.add_argument(
        "--labels", "-l", help="Key value pair labels for validation run"
    )
    custom_query_parser.add_argument(
        "--threshold",
        "-th",
        type=threshold_float,
        help="Float max threshold for percent difference",
    )
    custom_query_parser.add_argument(
        "--use-random-row",
        "-rr",
        action="store_true",
        help="Finds a set of random rows of the first primary key supplied.",
    )
    custom_query_parser.add_argument(
        "--random-row-batch-size",
        "-rbs",
        help="Row batch size used for random row filters (default 10,000).",
    )
    custom_query_parser.add_argument(
        "--primary-keys",
        "-pk",
        help="Comma separated list of primary key columns 'col_a,col_b'",
    )


def _add_common_arguments(parser):
parser.add_argument("--source-conn", "-sc", help="Source connection name")
Expand Down
3 changes: 3 additions & 0 deletions data_validation/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def generate_report(
"Expected source and target to have same schema, got "
f"source: {source_names} target: {target_names}"
)

differences_pivot = _calculate_differences(
source, target, join_on_fields, run_metadata.validations, is_value_comparison
)
Expand Down Expand Up @@ -185,6 +186,7 @@ def _calculate_differences(
)
]
)

differences_pivot = functools.reduce(
lambda pivot1, pivot2: pivot1.union(pivot2), differences_pivots
)
Expand Down Expand Up @@ -224,6 +226,7 @@ def _pivot_result(result, join_on_fields, validations, result_type):
+ join_on_fields
)
)

pivot = functools.reduce(lambda pivot1, pivot2: pivot1.union(pivot2), pivots)
return pivot

Expand Down
22 changes: 22 additions & 0 deletions data_validation/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,28 @@ def append_query_groups(self, grouped_column_configs):
self.query_groups + grouped_column_configs
)

@property
def source_query_file(self):
    """Return the source SQL query file entries from the config (empty list if unset)."""
    no_files = []
    return self._config.get(consts.CONFIG_SOURCE_QUERY_FILE, no_files)

def append_source_query_file(self, query_file_configs):
    """Append source query file entries to the existing config."""
    combined = self.source_query_file + query_file_configs
    self._config[consts.CONFIG_SOURCE_QUERY_FILE] = combined

@property
def target_query_file(self):
    """Return the target SQL query file entries from the config (empty list if unset)."""
    no_files = []
    return self._config.get(consts.CONFIG_TARGET_QUERY_FILE, no_files)

def append_target_query_file(self, query_file_configs):
    """Append target query file entries to the existing config."""
    combined = self.target_query_file + query_file_configs
    self._config[consts.CONFIG_TARGET_QUERY_FILE] = combined

@property
def primary_keys(self):
""" Return Primary keys from Config """
Expand Down
4 changes: 4 additions & 0 deletions data_validation/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
CONFIG_FILTER_SOURCE = "source"
CONFIG_FILTER_TARGET = "target"
CONFIG_MAX_RECURSIVE_QUERY_SIZE = "max_recursive_query_size"
# Config keys holding the path(s) of user-supplied SQL files for the
# Custom-query validation type (see append_source_query_file / append_target_query_file).
CONFIG_SOURCE_QUERY_FILE = "source_query_file"
CONFIG_TARGET_QUERY_FILE = "target_query_file"

CONFIG_FILTER_SOURCE_COLUMN = "source_column"
CONFIG_FILTER_SOURCE_VALUE = "source_value"
Expand All @@ -71,12 +73,14 @@
GROUPED_COLUMN_VALIDATION = "GroupedColumn"
ROW_VALIDATION = "Row"
SCHEMA_VALIDATION = "Schema"
# Validation type for user-supplied SQL; value matches the "Custom-query" CLI command.
CUSTOM_QUERY = "Custom-query"

# All validation types accepted in a validation config.
CONFIG_TYPES = [
    COLUMN_VALIDATION,
    GROUPED_COLUMN_VALIDATION,
    ROW_VALIDATION,
    SCHEMA_VALIDATION,
    CUSTOM_QUERY
]

# State Manager Fields
Expand Down
12 changes: 11 additions & 1 deletion data_validation/data_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def _execute_validation(self, validation_builder, process_in_memory=True):
pd_schema = self._get_pandas_schema(
source_df, target_df, join_on_fields, verbose=self.verbose
)
source_df = self._adjust_schema(source_df)
target_df = self._adjust_schema(target_df)
pandas_client = ibis.backends.pandas.connect(
{combiner.DEFAULT_SOURCE: source_df, combiner.DEFAULT_TARGET: target_df}
)
Expand Down Expand Up @@ -338,7 +340,7 @@ def _execute_validation(self, validation_builder, process_in_memory=True):
)

return result_df

def combine_data(self, source_df, target_df, join_on_fields):
""" TODO: Return List of Dictionaries """
# Clean Data to Standardize
Expand All @@ -357,3 +359,11 @@ def combine_data(self, source_df, target_df, join_on_fields):
rsuffix=consts.OUTPUT_SUFFIX,
)
return df

def _adjust_schema(self,data_frame):
Robby29 marked this conversation as resolved.
Show resolved Hide resolved
""" Fix schema differences introduced because of ibis sql function """
substitutions = {}
for column_name in data_frame.columns:
if 't0.' in column_name:
substitutions[column_name] = column_name[3:]
return data_frame.rename(columns=substitutions)
72 changes: 69 additions & 3 deletions data_validation/validation_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

from copy import deepcopy

from sqlalchemy import column
from data_validation import consts, metadata
from data_validation.query_builder.query_builder import (
QueryBuilder,
Expand Down Expand Up @@ -73,7 +75,7 @@ def clone(self):
@staticmethod
def get_query_builder(validation_type):
""" Return Query Builder object given validation type """
if validation_type in ["Column", "GroupedColumn", "Row", "Schema"]:
if validation_type in ["Column", "GroupedColumn", "Row", "Schema", "Custom-query"]:
builder = QueryBuilder.build_count_validator()
else:
msg = "Validation Builder supplied unknown type: %s" % validation_type
Expand Down Expand Up @@ -327,7 +329,21 @@ def get_source_query(self):
"schema_name": self.config_manager.source_schema,
"table_name": self.config_manager.source_table,
}
query = self.source_builder.compile(**source_config)
if self.validation_type == consts.CUSTOM_QUERY:
source_input_query = self.get_query_from_file(self.config_manager.source_query_file[0])
source_aggregate_query = "SELECT "
for aggregate in self.config_manager.aggregates:
source_aggregate_query += self.get_aggregation_query(
aggregate.get("type"),
aggregate.get("target_column")
)
source_aggregate_query = self.get_wrapper_aggregation_query(
source_aggregate_query,
source_input_query
)
query = self.source_client.sql(source_aggregate_query)
else:
query = self.source_builder.compile(**source_config)
if self.verbose:
print(source_config)
print("-- ** Source Query ** --")
Expand All @@ -342,7 +358,22 @@ def get_target_query(self):
"schema_name": self.config_manager.target_schema,
"table_name": self.config_manager.target_table,
}
query = self.target_builder.compile(**target_config)
if self.validation_type == consts.CUSTOM_QUERY:
target_input_query = self.get_query_from_file(self.config_manager.target_query_file[0])
target_aggregate_query = "SELECT "
for aggregate in self.config_manager.aggregates:
target_aggregate_query += self.get_aggregation_query(
aggregate.get("type"),
aggregate.get("target_column")
)

target_aggregate_query = self.get_wrapper_aggregation_query(
target_aggregate_query,
target_input_query
)
query = self.target_client.sql(target_aggregate_query)
else:
query = self.target_builder.compile(**target_config)
if self.verbose:
print(target_config)
print("-- ** Target Query ** --")
Expand All @@ -358,3 +389,38 @@ def add_query_limit(self):
limit = self.config_manager.query_limit
self.source_builder.limit = limit
self.target_builder.limit = limit

def get_query_from_file(self, filename):
    """Return the SQL query text read from the given file.

    Args:
        filename: path to a file containing a SQL query.

    Raises:
        ValueError: if the file cannot be read, is empty, or contains
            only whitespace.
    """
    query = ""
    try:
        # Context manager guarantees the handle is closed on every path;
        # the old file.close() sat after the raise and never ran.
        with open(filename, "r") as query_file:
            query = query_file.read()
    except IOError:
        print("Cannot read query file: ", filename)

    if not query or query.isspace():
        raise ValueError(
            "Expected file with sql query, got empty file or file with white spaces. "
            # Interpolate the actual path; the original f-string had no placeholder.
            f"input file: {filename}"
        )
    return query

def get_aggregation_query(self, agg_type, column_name):
    """Build one trailing-comma select-list entry for an aggregate.

    A None column aggregates over all rows, e.g. "count(*) as count,";
    otherwise e.g. "sum(col_a) as sum__col_a,".
    """
    if column_name is None:
        return f"{agg_type}(*) as {agg_type},"
    return f"{agg_type}({column_name}) as {agg_type}__{column_name},"

def get_wrapper_aggregation_query(self, aggregate_query, base_query):
    """Wrap the custom base query as a derived table under the aggregate select.

    Drops the select list's trailing comma, then appends
    " FROM (<base_query>) as base_query".
    """
    select_list = aggregate_query[:-1]  # strip trailing comma
    return f"{select_list} FROM ({base_query}) as base_query"