From f057fe8d690c78219f6341d210ba9719d4510fd6 Mon Sep 17 00:00:00 2001
From: Robby <40561036+Robby29@users.noreply.github.com>
Date: Thu, 28 Apr 2022 13:59:14 +0530
Subject: [PATCH] feat!: Adds custom query row level hash validation feature. (#440)

* added custom-query sub-option to validate command
* add source and target query option in custom query
* added min,max,sum aggregates with custom query
* fixed hive t0 column name addition issue
* added empty query file check
* linting fixes
* added unit tests
* incorporated black linting changes
* incorporated flake8 linter changes
* Fixed result schema status to validation_status to avoid duplicate column names
* Fixed linting on tests folder
* BREAKING CHANGE: update BQ results schema column name 'status' to 'validation_status'
* Added script to update BigQuery schema
* Moved bq_utils to right folder
* Updated bash script path and formatting
* Added custom query row validation feature
* Incorporated black and flake8 linting changes
* Added wildcard-include-string-len sub-option
* Fixed custom query column bug
* Made changes as per review from @dhercher
* new changes according to Neha's review requests
* changed custom query type from list to string
* made custom query type argument required=true
* typo fixes

Co-authored-by: raniksingh
Co-authored-by: Neha Nene
---
 README.md                                     |  54 +++-
 data_validation/__main__.py                   |   6 +
 data_validation/cli_tools.py                  |  12 +
 data_validation/combiner.py                   |  11 +-
 data_validation/config_manager.py             |  11 +
 data_validation/consts.py                     |   2 +-
 data_validation/data_validation.py            |   9 +
 .../query_builder/custom_query_builder.py     | 242 ++++++++++++++++++
 .../query_builder/query_builder.py            |   2 -
 .../query_builder/random_row_builder.py       |   2 +-
 data_validation/validation_builder.py         |  92 +++----
 docs/examples.md                              |  13 +-
 .../test_custom_query_builder.py              |  77 ++++++
 tests/unit/test_validation_builder.py         |  29 ---
 14 files changed, 473 insertions(+), 89 deletions(-)
 create mode 100644 data_validation/query_builder/custom_query_builder.py
 create mode 100644 tests/unit/query_builder/test_custom_query_builder.py

diff --git a/README.md b/README.md
index 67a94971f..c25f48093 100644
--- a/README.md
+++ b/README.md
@@ -230,9 +230,9 @@ data-validation (--verbose or -v) validate schema
                         Defaults to table.
 ```
 
-#### Custom Query Validations
+### Custom Query Column Validations
 
-Below is the command syntax for custom query validations.
+Below is the command syntax for custom query column validations.
 
 ```
 data-validation (--verbose or -v) validate custom-query
   --source-conn or -sc SOURCE_CONN
                         Source connection details
                         See: *Data Source Configurations* section for each data source
   --target-conn or -tc TARGET_CONN
                         Target connection details
                         See: *Connections* section for each data source
   --tables-list or -tbls SOURCE_SCHEMA.SOURCE_TABLE=TARGET_SCHEMA.TARGET_TABLE
                         Comma separated list of tables in the form schema.table=target_schema.target_table
                         Target schema name and table name are optional.
                         i.e 'bigquery-public-data.new_york_citibike.citibike_trips'
-    --source-query-file SOURCE_QUERY_FILE, -sqf SOURCE_QUERY_FILE
+    --custom-query-type CUSTOM_QUERY_TYPE, -cqt CUSTOM_QUERY_TYPE
+                        Type of custom query validation: ('row'|'column')
+                        Enter 'column' for custom query column validation
+    --source-query-file SOURCE_QUERY_FILE, -sqf SOURCE_QUERY_FILE
                         File containing the source sql commands
     --target-query-file TARGET_QUERY_FILE, -tqf TARGET_QUERY_FILE
                         File containing the target sql commands
@@ -273,6 +276,51 @@
 The [Examples](docs/examples.md) page provides a few examples of how this tool can
 be used to run custom query validations.
 
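+For example, a custom query column validation with a `sum` aggregation might look
+like the following sketch, where `my_bq_conn` and the two query files are
+placeholders for your own connection name and SQL files:
+
+```
+data-validation validate custom-query --custom-query-type column \
+  -sc my_bq_conn -tc my_bq_conn \
+  -tbls bigquery-public-data.new_york_citibike.citibike_stations \
+  --source-query-file source_query.sql \
+  --target-query-file target_query.sql \
+  --sum num_bikes_available
+```
+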
+### Custom Query Row Validations
+
+#### (Note: Row hash validation is currently only supported for BigQuery, Impala/Hive and Teradata)
+
+Below is the command syntax for custom query row validations. To run row-level
+validations, pass the `--hash` flag with the value `'*'`, which means all fields
+of the custom query result will be concatenated and hashed.
+
+```
+data-validation (--verbose or -v) validate custom-query
+  --source-conn or -sc SOURCE_CONN
+                        Source connection details
+                        See: *Data Source Configurations* section for each data source
+  --target-conn or -tc TARGET_CONN
+                        Target connection details
+                        See: *Connections* section for each data source
+  --tables-list or -tbls SOURCE_SCHEMA.SOURCE_TABLE=TARGET_SCHEMA.TARGET_TABLE
+                        Comma separated list of tables in the form schema.table=target_schema.target_table
+                        Target schema name and table name are optional.
+                        i.e 'bigquery-public-data.new_york_citibike.citibike_trips'
+  --custom-query-type CUSTOM_QUERY_TYPE, -cqt CUSTOM_QUERY_TYPE
+                        Type of custom query validation: ('row'|'column')
+                        Enter 'row' for custom query row validation
+  --source-query-file SOURCE_QUERY_FILE, -sqf SOURCE_QUERY_FILE
+                        File containing the source sql commands
+  --target-query-file TARGET_QUERY_FILE, -tqf TARGET_QUERY_FILE
+                        File containing the target sql commands
+  --hash '*'            '*' to hash all columns.
+  [--bq-result-handler or -bqrh PROJECT_ID.DATASET.TABLE]
+                        BigQuery destination for validation results. Defaults to stdout.
+                        See: *Validation Reports* section
+  [--service-account or -sa PATH_TO_SA_KEY]
+                        Service account to use for BigQuery result handler output.
+  [--labels or -l KEY1=VALUE1,KEY2=VALUE2]
+                        Comma-separated key value pair labels for the run.
+  [--format or -fmt]    Format for stdout output. Supported formats are (text, csv, json, table).
+                        Defaults to table.
+```
+
+The [Examples](docs/examples.md) page provides a few examples of how this tool can
+be used to run custom query row validations.
+
+
 ### Running Custom SQL Exploration
 
 There are many occasions where you need to explore a data source while running
diff --git a/data_validation/__main__.py b/data_validation/__main__.py
index 081523308..922b8bab1 100644
--- a/data_validation/__main__.py
+++ b/data_validation/__main__.py
@@ -177,6 +177,12 @@ def build_config_from_args(args, config_manager):
 
     if config_manager.validation_type == consts.CUSTOM_QUERY:
         config_manager.append_aggregates(get_aggregate_config(args, config_manager))
+        if args.custom_query_type is not None:
+            config_manager.append_custom_query_type(args.custom_query_type)
+        else:
+            raise ValueError(
+                "Expected custom query type to be given, got None."
+ ) if args.source_query_file is not None: query_file = cli_tools.get_arg_list(args.source_query_file) config_manager.append_source_query_file(query_file) diff --git a/data_validation/cli_tools.py b/data_validation/cli_tools.py index 288a33c87..f9fd22e68 100644 --- a/data_validation/cli_tools.py +++ b/data_validation/cli_tools.py @@ -537,6 +537,12 @@ def _configure_schema_parser(schema_parser): def _configure_custom_query_parser(custom_query_parser): """Configure arguments to run custom-query validations.""" _add_common_arguments(custom_query_parser) + custom_query_parser.add_argument( + "--custom-query-type", + "-cqt", + required=True, + help="Which type of custom query (row/column)", + ) custom_query_parser.add_argument( "--source-query-file", "-sqf", @@ -609,6 +615,12 @@ def _configure_custom_query_parser(custom_query_parser): "-pk", help="Comma separated list of primary key columns 'col_a,col_b'", ) + custom_query_parser.add_argument( + "--wildcard-include-string-len", + "-wis", + action="store_true", + help="Include string fields for wildcard aggregations.", + ) def _add_common_arguments(parser): diff --git a/data_validation/combiner.py b/data_validation/combiner.py index bbc29ada5..b2d430409 100644 --- a/data_validation/combiner.py +++ b/data_validation/combiner.py @@ -75,9 +75,11 @@ def generate_report( differences_pivot = _calculate_differences( source, target, join_on_fields, run_metadata.validations, is_value_comparison ) + source_pivot = _pivot_result( source, join_on_fields, run_metadata.validations, consts.RESULT_TYPE_SOURCE ) + target_pivot = _pivot_result( target, join_on_fields, run_metadata.validations, consts.RESULT_TYPE_TARGET ) @@ -161,7 +163,6 @@ def _calculate_difference(field_differences, datatype, validation, is_value_comp .else_(consts.VALIDATION_STATUS_SUCCESS) .end() ) - return ( difference.name("difference"), pct_difference.name("pct_difference"), @@ -190,7 +191,6 @@ def _calculate_differences( # When no join_on_fields are present, we expect only one row per table. # This is validated in generate_report before this function is called. 
     differences_joined = source.cross_join(target)
-
     differences_pivots = []
     for field, field_type in schema.items():
         if field not in validations:
@@ -213,7 +213,6 @@ def _calculate_differences(
             )
         ]
     )
-
     differences_pivot = functools.reduce(
         lambda pivot1, pivot2: pivot1.union(pivot2), differences_pivots
     )
@@ -222,7 +221,11 @@ def _calculate_differences(
 
 def _pivot_result(result, join_on_fields, validations, result_type):
     all_fields = frozenset(result.schema().names)
-    validation_fields = all_fields - frozenset(join_on_fields)
+    validation_fields = (
+        all_fields - frozenset(join_on_fields)
+        if "hash__all" not in join_on_fields
+        else all_fields
+    )
     pivots = []
 
     for field in validation_fields:
diff --git a/data_validation/config_manager.py b/data_validation/config_manager.py
index 5bcb6b95b..31387cc49 100644
--- a/data_validation/config_manager.py
+++ b/data_validation/config_manager.py
@@ -150,6 +150,17 @@ def append_query_groups(self, grouped_column_configs):
             self.query_groups + grouped_column_configs
         )
 
+    @property
+    def custom_query_type(self):
+        """Return custom query type from config"""
+        return self._config.get(consts.CONFIG_CUSTOM_QUERY_TYPE, "")
+
+    def append_custom_query_type(self, custom_query_type):
+        """Append custom query type config to existing config."""
+        self._config[consts.CONFIG_CUSTOM_QUERY_TYPE] = (
+            self.custom_query_type + custom_query_type
+        )
+
     @property
     def source_query_file(self):
         """Return SQL Query File from Config"""
diff --git a/data_validation/consts.py b/data_validation/consts.py
index 8b69ca882..142afe7ae 100644
--- a/data_validation/consts.py
+++ b/data_validation/consts.py
@@ -50,7 +50,7 @@
 CONFIG_MAX_RECURSIVE_QUERY_SIZE = "max_recursive_query_size"
 CONFIG_SOURCE_QUERY_FILE = "source_query_file"
 CONFIG_TARGET_QUERY_FILE = "target_query_file"
-
+CONFIG_CUSTOM_QUERY_TYPE = "custom_query_type"
 CONFIG_FILTER_SOURCE_COLUMN = "source_column"
 CONFIG_FILTER_SOURCE_VALUE = "source_value"
 CONFIG_FILTER_TARGET_COLUMN = "target_column"
diff --git a/data_validation/data_validation.py b/data_validation/data_validation.py
index f28e734cb..562a802e2 100644
--- a/data_validation/data_validation.py
+++ b/data_validation/data_validation.py
@@ -291,10 +291,19 @@ def _execute_validation(self, validation_builder, process_in_memory=True):
             if self.config_manager.validation_type == consts.ROW_VALIDATION
             else set(validation_builder.get_group_aliases())
         )
+        if (
+            self.config_manager.validation_type == consts.CUSTOM_QUERY
+            and self.config_manager.custom_query_type == "row"
+        ):
+            join_on_fields = set(["hash__all"])
 
         # If row validation from YAML, compare source and target agg values
         is_value_comparison = (
             self.config_manager.validation_type == consts.ROW_VALIDATION
+            or (
+                self.config_manager.validation_type == consts.CUSTOM_QUERY
+                and self.config_manager.custom_query_type == "row"
+            )
         )
 
         if process_in_memory:
diff --git a/data_validation/query_builder/custom_query_builder.py b/data_validation/query_builder/custom_query_builder.py
new file mode 100644
index 000000000..3644c0875
--- /dev/null
+++ b/data_validation/query_builder/custom_query_builder.py
@@ -0,0 +1,242 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The QueryBuilder for building custom query row|column validations."""
+
+
+class CustomQueryBuilder(object):
+    def __init__(self):
+        """Build a CustomQueryBuilder object which is ready to build a custom query row|column validation nested query."""
+
+    def get_aggregation_query(self, agg_type, column_name):
+        """Return aggregation query"""
+        aggregation_query = ""
+        if column_name is None:
+            aggregation_query = agg_type + "(*) as " + agg_type + ","
+        else:
+            aggregation_query = (
+                agg_type
+                + "("
+                + column_name
+                + ") as "
+                + agg_type
+                + "__"
+                + column_name
+                + ","
+            )
+        return aggregation_query
+
+    def get_wrapper_aggregation_query(self, aggregate_query, base_query):
+        """Return wrapper aggregation query"""
+        # Strip the trailing comma from the aggregate list before wrapping.
+        return (
+            aggregate_query[: len(aggregate_query) - 1]
+            + " FROM ("
+            + base_query
+            + ") as base_query"
+        )
+
+    def compile_custom_query(self, input_query, client_config):
+        """Returns the nested sql query calculated from the input query
+        by adding calculated fields.
+        Args:
+            input_query (InputQuery): User provided sql query
+            client_config (dict): Client configuration holding the data client
+        """
+        base_tbl_expr = self.get_table_expression(input_query, client_config)
+        base_df = self.get_data_frame(base_tbl_expr, client_config)
+        base_df_columns = self.compile_df_fields(base_df)
+        calculated_columns = self.get_calculated_columns(base_df_columns)
+        cast_query = self.compile_cast_df_fields(
+            calculated_columns, input_query, base_df
+        )
+        ifnull_query = self.compile_ifnull_df_fields(
+            calculated_columns, cast_query, client_config
+        )
+        rstrip_query = self.compile_rstrip_df_fields(calculated_columns, ifnull_query)
+        upper_query = self.compile_upper_df_fields(calculated_columns, rstrip_query)
+        concat_query = self.compile_concat_df_fields(
+            calculated_columns, upper_query, client_config
+        )
+        sha2_query = self.compile_sha2_df_fields(concat_query, client_config)
+        return sha2_query
+
+    def get_table_expression(self, input_query, client_config):
+        """Returns the ibis table expression for the input query."""
+        return client_config["data_client"].sql(input_query)
+
+    def get_data_frame(self, base_tbl_expr, client_config):
+        """Returns the data frame for the table expression."""
+        return client_config["data_client"].execute(base_tbl_expr)
+
+    def compile_df_fields(self, data_frame):
+        """Returns the list of columns in the dataframe.
+        Args:
+            data_frame (DataFrame): Pandas DataFrame
+        """
+        return list(data_frame.columns.values)
+
+    def get_calculated_columns(self, df_columns):
+        """Returns the dictionary containing the calculated fields."""
+
+        calculated_columns = {}
+        calculated_columns["columns"] = df_columns
+        calculated_columns["cast"] = []
+        for column in df_columns:
+            current_column = "cast__" + column
+            calculated_columns["cast"].append(current_column)
+
+        calculated_columns["ifnull"] = []
+        for column in calculated_columns["cast"]:
+            current_column = "ifnull__" + column
+            calculated_columns["ifnull"].append(current_column)
+
+        calculated_columns["rstrip"] = []
+        for column in calculated_columns["ifnull"]:
+            current_column = "rstrip__" + column
+            calculated_columns["rstrip"].append(current_column)
+
+        calculated_columns["upper"] = []
+        for column in calculated_columns["rstrip"]:
+            current_column = "upper__" + column
+            calculated_columns["upper"].append(current_column)
+
+        return calculated_columns
+
+    def compile_cast_df_fields(self, calculated_columns, input_query, data_frame):
+        """Returns the wrapper cast query for the input query."""
+
+        query = "SELECT "
+        for column in calculated_columns["cast"]:
+            df_column = column[len("cast__") :]
+            df_column_dtype = data_frame[df_column].dtype.name
+            if df_column_dtype != "object" and df_column_dtype != "string":
+                query = (
+                    query + "CAST(" + df_column + " AS string)" + " AS " + column + ","
+                )
+            else:
+                query += df_column + " AS " + column + ","
+
+        query = query[: len(query) - 1] + " FROM (" + input_query + ") AS base_query"
+        return query
+
+    def compile_ifnull_df_fields(self, calculated_columns, cast_query, client_config):
+        """Returns the wrapper ifnull query for the input cast_query."""
+
+        client = client_config["data_client"]._source_type
+        if client == "Impala":
+            operation = "COALESCE"
+        elif client == "BigQuery":
+            operation = "IFNULL"
+        query = "SELECT "
+        for column in calculated_columns["ifnull"]:
+            query = (
+                query
+                + operation
+                + "("
+                + column[len("ifnull__") :]
+                + ",'DEFAULT_REPLACEMENT_STRING')"
+                + " AS "
+                + column
+                + ","
+            )
+        query = query[: len(query) - 1] + " FROM (" + cast_query + ") AS cast_query"
+        return query
+
+    def compile_rstrip_df_fields(self, calculated_columns, ifnull_query):
+        """Returns the wrapper rstrip query for the input ifnull_query."""
+
+        operation = "RTRIM"
+        query = "SELECT "
+        for column in calculated_columns["rstrip"]:
+            query = (
+                query
+                + operation
+                + "("
+                + column[len("rstrip__") :]
+                + ")"
+                + " AS "
+                + column
+                + ","
+            )
+        query = query[: len(query) - 1] + " FROM (" + ifnull_query + ") AS ifnull_query"
+        return query
+
+    def compile_upper_df_fields(self, calculated_columns, rstrip_query):
+        """Returns the wrapper upper query for the input rstrip_query."""
+
+        query = "SELECT "
+        for column in calculated_columns["upper"]:
+            query = (
+                query
+                + "UPPER("
+                + column[len("upper__") :]
+                + ")"
+                + " AS "
+                + column
+                + ","
+            )
+        query = query[: len(query) - 1] + " FROM (" + rstrip_query + ") AS rstrip_query"
+        return query
+
+    def compile_concat_df_fields(self, calculated_columns, upper_query, client_config):
+        """Returns the wrapper concat query for the input upper_query."""
+
+        client = client_config["data_client"]._source_type
+        if client == "Impala":
+            operation = "CONCAT_WS"
+            query = "SELECT " + operation + "(',',"
+            for column in calculated_columns["upper"]:
+                query += column + ","
+            query = (
+                query[: len(query) - 1]
+                + ") AS concat__all FROM("
+                + upper_query
+                + ") AS upper_query"
+            )
+        elif client == "BigQuery":
+            operation = "ARRAY_TO_STRING"
+            query = "SELECT " + operation + "(["
+            for column in calculated_columns["upper"]:
+                query += column + ","
+            query = (
+                query[: len(query) - 1]
+                + "],',') AS concat__all FROM("
+                + upper_query
+                + ") AS upper_query"
+            )
+        return query
+
+    def compile_sha2_df_fields(self, concat_query, client_config):
+        """Returns the wrapper sha2 query for the input concat_query."""
+
+        client = client_config["data_client"]._source_type
+        if client == "Impala":
+            operation = "SHA2"
+            query = (
+                "SELECT "
+                + operation
+                + "(concat__all,256) AS hash__all FROM ("
+                + concat_query
+                + ") AS concat_query"
+            )
+        elif client == "BigQuery":
+            operation = "TO_HEX"
+            query = (
+                "SELECT "
+                + operation
+                + "(SHA256(concat__all)) AS hash__all FROM ("
+                + concat_query
+                + ") AS concat_query"
+            )
+        return query
diff --git a/data_validation/query_builder/query_builder.py b/data_validation/query_builder/query_builder.py
index 8b34ad4c5..e83942a0c 100644
--- a/data_validation/query_builder/query_builder.py
+++ b/data_validation/query_builder/query_builder.py
@@ -538,7 +538,6 @@ def add_grouped_field(self, grouped_field):
         """Add a GroupedField instance to the query which
         represents adding a column to group by in the
         query being built.
-
         Args:
             grouped_field (GroupedField): A GroupedField instance
         """
@@ -548,7 +547,6 @@ def add_filter_field(self, filter_obj):
         """Add a FilterField instance to your query which
         will add the desired filter to your compiled query
         (ie. WHERE query_filter=True)
-
         Args:
             filter_obj (FilterField): A FilterField instance
         """
diff --git a/data_validation/query_builder/random_row_builder.py b/data_validation/query_builder/random_row_builder.py
index f85385848..c94173e7c 100644
--- a/data_validation/query_builder/random_row_builder.py
+++ b/data_validation/query_builder/random_row_builder.py
@@ -73,7 +73,7 @@ def resolve_name(self):
 
 class RandomRowBuilder(object):
     def __init__(self, primary_keys: List[str], batch_size: int):
-        """Build a RandomRowBuilder objct which is ready to build a random row filter query.
+        """Build a RandomRowBuilder object which is ready to build a random row filter query.
 
         Args:
             primary_keys: A list of primary key field strings used to find random rows.
diff --git a/data_validation/validation_builder.py b/data_validation/validation_builder.py
index d7a3d3590..561c266fe 100644
--- a/data_validation/validation_builder.py
+++ b/data_validation/validation_builder.py
@@ -21,6 +21,7 @@
     CalculatedField,
     FilterField,
 )
+from data_validation.query_builder.custom_query_builder import CustomQueryBuilder
 
 
 class ValidationBuilder(object):
@@ -337,15 +338,30 @@ def get_source_query(self):
             source_input_query = self.get_query_from_file(
                 self.config_manager.source_query_file[0]
             )
-            source_aggregate_query = "SELECT "
-            for aggregate in self.config_manager.aggregates:
-                source_aggregate_query += self.get_aggregation_query(
-                    aggregate.get("type"), aggregate.get("target_column")
+            if self.config_manager.custom_query_type == "row":
+                calculated_query = CustomQueryBuilder().compile_custom_query(
+                    source_input_query, source_config
+                )
+                query = self.source_client.sql(calculated_query)
+            elif self.config_manager.custom_query_type == "column":
+                source_aggregate_query = "SELECT "
+                for aggregate in self.config_manager.aggregates:
+                    source_aggregate_query += (
+                        CustomQueryBuilder().get_aggregation_query(
+                            aggregate.get("type"), aggregate.get("target_column")
+                        )
+                    )
+                source_aggregate_query = (
+                    CustomQueryBuilder().get_wrapper_aggregation_query(
+                        source_aggregate_query, source_input_query
+                    )
+                )
+                query = self.source_client.sql(source_aggregate_query)
+            else:
+                raise ValueError(
+                    "Expected custom query type to be column or row, got an unacceptable value. "
+                    f"Input custom query type: {self.config_manager.custom_query_type}"
                 )
-            source_aggregate_query = self.get_wrapper_aggregation_query(
-                source_aggregate_query, source_input_query
-            )
-            query = self.source_client.sql(source_aggregate_query)
         else:
             query = self.source_builder.compile(**source_config)
         if self.verbose:
@@ -366,16 +382,30 @@ def get_target_query(self):
             target_input_query = self.get_query_from_file(
                 self.config_manager.target_query_file[0]
             )
-            target_aggregate_query = "SELECT "
-            for aggregate in self.config_manager.aggregates:
-                target_aggregate_query += self.get_aggregation_query(
-                    aggregate.get("type"), aggregate.get("target_column")
+            if self.config_manager.custom_query_type == "row":
+                calculated_query = CustomQueryBuilder().compile_custom_query(
+                    target_input_query, target_config
+                )
+                query = self.target_client.sql(calculated_query)
+            elif self.config_manager.custom_query_type == "column":
+                target_aggregate_query = "SELECT "
+                for aggregate in self.config_manager.aggregates:
+                    target_aggregate_query += (
+                        CustomQueryBuilder().get_aggregation_query(
+                            aggregate.get("type"), aggregate.get("target_column")
+                        )
+                    )
+                target_aggregate_query = (
+                    CustomQueryBuilder().get_wrapper_aggregation_query(
+                        target_aggregate_query, target_input_query
+                    )
+                )
+                query = self.target_client.sql(target_aggregate_query)
+            else:
+                raise ValueError(
+                    "Expected custom query type to be column or row, got an unacceptable value. "
" + f"Input custom query type: {self.config_manager.custom_query_type}" ) - - target_aggregate_query = self.get_wrapper_aggregation_query( - target_aggregate_query, target_input_query - ) - query = self.target_client.sql(target_aggregate_query) else: query = self.target_builder.compile(**target_config) if self.verbose: @@ -410,31 +440,3 @@ def get_query_from_file(self, filename): ) file.close() return query - - def get_aggregation_query(self, agg_type, column_name): - """Return aggregation query""" - aggregation_query = "" - if column_name is None: - aggregation_query = agg_type + "(*) as " + agg_type + "," - else: - aggregation_query = ( - agg_type - + "(" - + column_name - + ") as " - + agg_type - + "__" - + column_name - + "," - ) - return aggregation_query - - def get_wrapper_aggregation_query(self, aggregate_query, base_query): - """Return wrapper aggregation query""" - - return ( - aggregate_query[: len(aggregate_query) - 1] - + " FROM (" - + base_query - + ") as base_query" - ) diff --git a/docs/examples.md b/docs/examples.md index 650703e67..20f24835a 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -185,19 +185,24 @@ validations: type: Column ``` -#### Run a custom query validation +#### Run a custom query column validation ````shell script -data-validation validate custom-query --source-query-file source_query.sql --target-query-file target_query.sql -sc my_bq_conn -tc my_bq_conn -tbls bigquery-public-data.new_york_citibike.citibike_stations +data-validation validate --custom-query-type column custom-query --source-query-file source_query.sql --target-query-file target_query.sql -sc my_bq_conn -tc my_bq_conn -tbls bigquery-public-data.new_york_citibike.citibike_stations ```` #### Run a custom query validation with sum aggregation ````shell script -data-validation validate custom-query --source-query-file source_query.sql --target-query-file target_query.sql -sc my_bq_conn -tc my_bq_conn -tbls bigquery-public-data.new_york_citibike.citibike_stations --sum num_bikes_available +data-validation validate custom-query --custom-query-type column --source-query-file source_query.sql --target-query-file target_query.sql -sc my_bq_conn -tc my_bq_conn -tbls bigquery-public-data.new_york_citibike.citibike_stations --sum num_bikes_available ```` #### Run a custom query validation with max aggregation ````shell script -data-validation validate custom-query --source-query-file source_query.sql --target-query-file target_query.sql -sc my_bq_conn -tc my_bq_conn -tbls bigquery-public-data.new_york_citibike.citibike_stations --max num_bikes_available +data-validation validate custom-query --custom-query-type column --source-query-file source_query.sql --target-query-file target_query.sql -sc my_bq_conn -tc my_bq_conn -tbls bigquery-public-data.new_york_citibike.citibike_stations --max num_bikes_available +```` + +#### Run a custom query row validation +````shell script +data-validation validate custom-query --custom-query-type row --source-query-file source_query.sql --target-query-file target_query.sql -sc my_bq_conn -tc my_bq_conn -tbls bigquery-public-data.new_york_citibike.citibike_stations --hash \'*\' ```` Please replace source_query.sql and target_query.sql with the correct files containing sql query for source and target database respectively. 
\ No newline at end of file
diff --git a/tests/unit/query_builder/test_custom_query_builder.py b/tests/unit/query_builder/test_custom_query_builder.py
new file mode 100644
index 000000000..11d6f8e72
--- /dev/null
+++ b/tests/unit/query_builder/test_custom_query_builder.py
@@ -0,0 +1,77 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from data_validation import consts
+
+INPUT_QUERY = "SELECT b.mascot, count(*) as count from dvt_testing.mascot b group by mascot order by 2 desc"
+DATAFRAME_COLUMNS = ["mascot", "count"]
+CALCULATED_COLUMNS = {
+    "columns": ["mascot", "count"],
+    "cast": ["cast__mascot", "cast__count"],
+    "ifnull": ["ifnull__cast__mascot", "ifnull__cast__count"],
+    "rstrip": ["rstrip__ifnull__cast__mascot", "rstrip__ifnull__cast__count"],
+    "upper": [
+        "upper__rstrip__ifnull__cast__mascot",
+        "upper__rstrip__ifnull__cast__count",
+    ],
+}
+AGGREGATES_TEST = [
+    {
+        consts.CONFIG_FIELD_ALIAS: "sum_starttime",
+        consts.CONFIG_SOURCE_COLUMN: "starttime",
+        consts.CONFIG_TARGET_COLUMN: "starttime",
+        consts.CONFIG_TYPE: "sum",
+    }
+]
+AGGREGATION_QUERY = "sum(starttime) as sum_starttime,"
+BASE_QUERY = "SELECT * FROM bigquery-public-data.usa_names.usa_1910_2013"
+
+
+@pytest.fixture
+def module_under_test():
+    import data_validation.query_builder.custom_query_builder
+
+    return data_validation.query_builder.custom_query_builder
+
+
+def test_import(module_under_test):
+    assert module_under_test is not None
+
+
+def test_get_calculated_columns(module_under_test):
+    calculated_columns = module_under_test.CustomQueryBuilder().get_calculated_columns(
+        DATAFRAME_COLUMNS
+    )
+    assert calculated_columns == CALCULATED_COLUMNS
+
+
+def test_custom_query_get_aggregation_query(module_under_test):
+    aggregation_query = module_under_test.CustomQueryBuilder().get_aggregation_query(
+        AGGREGATES_TEST[0]["type"], AGGREGATES_TEST[0]["source_column"]
+    )
+    assert aggregation_query == "sum(starttime) as sum__starttime,"
+
+
+def test_custom_query_get_wrapper_aggregation_query(module_under_test):
+    wrapper_query = (
+        module_under_test.CustomQueryBuilder().get_wrapper_aggregation_query(
+            AGGREGATION_QUERY, BASE_QUERY
+        )
+    )
+    assert (
+        wrapper_query
+        == "sum(starttime) as sum_starttime FROM (SELECT * FROM bigquery-public-data.usa_names.usa_1910_2013) as base_query"
+    )
diff --git a/tests/unit/test_validation_builder.py b/tests/unit/test_validation_builder.py
index a537fae42..e32bb510c 100644
--- a/tests/unit/test_validation_builder.py
+++ b/tests/unit/test_validation_builder.py
@@ -254,32 +254,3 @@ def test_custom_query_get_query_from_file(module_under_test):
     builder = module_under_test.ValidationBuilder(mock_config_manager)
     query = builder.get_query_from_file(builder.config_manager.source_query_file)
     assert query == "SELECT * FROM bigquery-public-data.usa_names.usa_1910_2013"
-
-
-def test_custom_query_get_aggregation_query(module_under_test):
-    mock_config_manager = ConfigManager(
-        CUSTOM_QUERY_VALIDATION_CONFIG,
-        MockIbisClient(),
-        MockIbisClient(),
-        verbose=False,
-    )
-    builder = module_under_test.ValidationBuilder(mock_config_manager)
-    aggregation_query = builder.get_aggregation_query(
-        AGGREGATES_TEST[0]["type"], AGGREGATES_TEST[0]["source_column"]
-    )
-    assert aggregation_query == "sum(starttime) as sum__starttime,"
-
-
-def test_custom_query_get_wrapper_aggregation_query(module_under_test):
-    mock_config_manager = ConfigManager(
-        CUSTOM_QUERY_VALIDATION_CONFIG,
-        MockIbisClient(),
-        MockIbisClient(),
-        verbose=False,
-    )
-    builder = module_under_test.ValidationBuilder(mock_config_manager)
-    wrapper_query = builder.get_wrapper_aggregation_query(AGGREGATION_QUERY, BASE_QUERY)
-    assert (
-        wrapper_query
-        == "sum(starttime) as sum_starttime FROM (SELECT * FROM bigquery-public-data.usa_names.usa_1910_2013) as base_query"
-    )
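
Note on the generated SQL: for a hypothetical single-column BigQuery input query `SELECT station_id FROM dataset.stations` (table and column names illustrative only), the `CustomQueryBuilder.compile_custom_query` chain above nests its cast → ifnull → rstrip → upper → concat → sha2 layers into roughly the following statement. The final `hash__all` column is what `_execute_validation` joins source and target on when `custom_query_type == "row"`.

```sql
SELECT TO_HEX(SHA256(concat__all)) AS hash__all FROM (
  SELECT ARRAY_TO_STRING([upper__rstrip__ifnull__cast__station_id], ',') AS concat__all FROM (
    SELECT UPPER(rstrip__ifnull__cast__station_id) AS upper__rstrip__ifnull__cast__station_id FROM (
      SELECT RTRIM(ifnull__cast__station_id) AS rstrip__ifnull__cast__station_id FROM (
        SELECT IFNULL(cast__station_id, 'DEFAULT_REPLACEMENT_STRING') AS ifnull__cast__station_id FROM (
          -- Non-string columns are cast to string before the NULL/trim/case normalization layers.
          SELECT CAST(station_id AS string) AS cast__station_id FROM (
            SELECT station_id FROM dataset.stations
          ) AS base_query
        ) AS cast_query
      ) AS ifnull_query
    ) AS rstrip_query
  ) AS upper_query
) AS concat_query
```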