From ca598a0a1ca80541f98b5108d3fd358081af2c0b Mon Sep 17 00:00:00 2001 From: Neha Nene Date: Wed, 4 May 2022 00:03:45 -0500 Subject: [PATCH] feat: Support cast to BIGINT before aggregation (#461) * feat: support for default_cast in YAML calc field * feat: add support for casting to BIGINT before aggregations --- README.md | 2 ++ data_validation/__main__.py | 14 +++++---- data_validation/cli_tools.py | 18 ++++++++++++ data_validation/combiner.py | 4 +++ data_validation/config_manager.py | 29 +++++++++++++------ .../query_builder/query_builder.py | 3 +- 6 files changed, 53 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index c25f48093..6061f955c 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,8 @@ data-validation (--verbose or -v) validate column Service account to use for BigQuery result handler output. [--wildcard-include-string-len or -wis] If flag is present, include string columns in aggregation as len(string_col) + [--cast-to-bigint or -ctb] + If flag is present, cast all int32 columns to int64 before aggregation [--filters SOURCE_FILTER:TARGET_FILTER] Colon separated string values of source and target filters. If target filter is not provided, the source filter will run on source and target tables. diff --git a/data_validation/__main__.py b/data_validation/__main__.py index 922b8bab1..117e2cb8b 100644 --- a/data_validation/__main__.py +++ b/data_validation/__main__.py @@ -66,35 +66,37 @@ def get_aggregate_config(args, config_manager): if args.wildcard_include_string_len: supported_data_types.append("string") + cast_to_bigint = True if args.cast_to_bigint else False + if args.count: col_args = None if args.count == "*" else cli_tools.get_arg_list(args.count) aggregate_configs += config_manager.build_config_column_aggregates( - "count", col_args, None + "count", col_args, None, cast_to_bigint=cast_to_bigint ) if args.sum: col_args = None if args.sum == "*" else cli_tools.get_arg_list(args.sum) aggregate_configs += config_manager.build_config_column_aggregates( - "sum", col_args, supported_data_types + "sum", col_args, supported_data_types, cast_to_bigint=cast_to_bigint ) if args.avg: col_args = None if args.avg == "*" else cli_tools.get_arg_list(args.avg) aggregate_configs += config_manager.build_config_column_aggregates( - "avg", col_args, supported_data_types + "avg", col_args, supported_data_types, cast_to_bigint=cast_to_bigint ) if args.min: col_args = None if args.min == "*" else cli_tools.get_arg_list(args.min) aggregate_configs += config_manager.build_config_column_aggregates( - "min", col_args, supported_data_types + "min", col_args, supported_data_types, cast_to_bigint=cast_to_bigint ) if args.max: col_args = None if args.max == "*" else cli_tools.get_arg_list(args.max) aggregate_configs += config_manager.build_config_column_aggregates( - "max", col_args, supported_data_types + "max", col_args, supported_data_types, cast_to_bigint=cast_to_bigint ) if args.bit_xor: col_args = None if args.bit_xor == "*" else cli_tools.get_arg_list(args.bit_xor) aggregate_configs += config_manager.build_config_column_aggregates( - "bit_xor", col_args, supported_data_types + "bit_xor", col_args, supported_data_types, cast_to_bigint=cast_to_bigint ) return aggregate_configs diff --git a/data_validation/cli_tools.py b/data_validation/cli_tools.py index f9fd22e68..03103f8bf 100644 --- a/data_validation/cli_tools.py +++ b/data_validation/cli_tools.py @@ -328,6 +328,12 @@ def _configure_run_parser(subparsers): action="store_true", help="Include string fields for wildcard aggregations.", ) + run_parser.add_argument( + "--cast-to-bigint", + "-ctb", + action="store_true", + help="Cast any int32 fields to int64 for large aggregations.", + ) def _configure_connection_parser(subparsers): @@ -527,6 +533,12 @@ def _configure_column_parser(column_parser): action="store_true", help="Include string fields for wildcard aggregations.", ) + column_parser.add_argument( + "--cast-to-bigint", + "-ctb", + action="store_true", + help="Cast any int32 fields to int64 for large aggregations.", + ) def _configure_schema_parser(schema_parser): @@ -621,6 +633,12 @@ def _configure_custom_query_parser(custom_query_parser): action="store_true", help="Include string fields for wildcard aggregations.", ) + custom_query_parser.add_argument( + "--cast-to-bigint", + "-ctb", + action="store_true", + help="Cast any int32 fields to int64 for large aggregations.", + ) def _add_common_arguments(parser): diff --git a/data_validation/combiner.py b/data_validation/combiner.py index b2d430409..e4e5dd413 100644 --- a/data_validation/combiner.py +++ b/data_validation/combiner.py @@ -114,6 +114,10 @@ def _calculate_difference(field_differences, datatype, validation, is_value_comp difference = pct_difference = ibis.null() validation_status = ( ibis.case() + .when( + target_value.isnull() & source_value.isnull(), + consts.VALIDATION_STATUS_SUCCESS, + ) .when(target_value == source_value, consts.VALIDATION_STATUS_SUCCESS) .else_(consts.VALIDATION_STATUS_FAIL) .end() diff --git a/data_validation/config_manager.py b/data_validation/config_manager.py index 31387cc49..fc7b984c6 100644 --- a/data_validation/config_manager.py +++ b/data_validation/config_manager.py @@ -459,6 +459,8 @@ def append_pre_agg_calc_field(self, column, agg_type, column_type): calc_func = "length" elif column_type == "timestamp": calc_func = "epoch_seconds" + elif column_type == "int32": + calc_func = "cast" else: raise ValueError(f"Unsupported column type: {column_type}") @@ -470,6 +472,9 @@ def append_pre_agg_calc_field(self, column, agg_type, column_type): consts.CONFIG_DEPTH: 0, } + if column_type == "int32": + calculated_config["default_cast"] = "int64" + existing_calc_fields = [ x[consts.CONFIG_FIELD_ALIAS] for x in self.calculated_fields ] @@ -484,7 +489,9 @@ def append_pre_agg_calc_field(self, column, agg_type, column_type): } return aggregate_config - def build_config_column_aggregates(self, agg_type, arg_value, supported_types): + def build_config_column_aggregates( + self, agg_type, arg_value, supported_types, cast_to_bigint=False + ): """Return list of aggregate objects of given agg_type.""" aggregate_configs = [] source_table = self.get_source_ibis_calculated_table() @@ -516,14 +523,18 @@ def build_config_column_aggregates(self, agg_type, arg_value, supported_types): ) continue - if column_type == "string" or ( - column_type == "timestamp" - and agg_type - in ( - "sum", - "avg", - "bit_xor", - ) # timestamps: do not convert to epoch seconds for min/max + if ( + column_type == "string" + or (cast_to_bigint and column_type == "int32") + or ( + column_type == "timestamp" + and agg_type + in ( + "sum", + "avg", + "bit_xor", + ) # timestamps: do not convert to epoch seconds for min/max + ) ): aggregate_config = self.append_pre_agg_calc_field( column, agg_type, column_type diff --git a/data_validation/query_builder/query_builder.py b/data_validation/query_builder/query_builder.py index e83942a0c..69ad5ca2d 100644 --- a/data_validation/query_builder/query_builder.py +++ b/data_validation/query_builder/query_builder.py @@ -360,8 +360,7 @@ def epoch_seconds(config, fields): @staticmethod def cast(config, fields): - if config.get("default_cast") is None: - target_type = "string" + target_type = config.get("default_cast", "string") return CalculatedField( ibis.expr.api.ValueExpr.cast, config,