Skip to content

Commit

Permalink
feat: Support cast to BIGINT before aggregation (#461)
Browse files Browse the repository at this point in the history
* feat: support for default_cast in YAML calc field

* feat: add support for casting to BIGINT before aggregations
  • Loading branch information
nehanene15 committed May 4, 2022
1 parent f057fe8 commit ca598a0
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 17 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ data-validation (--verbose or -v) validate column
Service account to use for BigQuery result handler output.
[--wildcard-include-string-len or -wis]
If flag is present, include string columns in aggregation as len(string_col)
[--cast-to-bigint or -ctb]
If flag is present, cast all int32 columns to int64 before aggregation
[--filters SOURCE_FILTER:TARGET_FILTER]
Colon separated string values of source and target filters.
If target filter is not provided, the source filter will run on source and target tables.
Expand Down
14 changes: 8 additions & 6 deletions data_validation/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,35 +66,37 @@ def get_aggregate_config(args, config_manager):
if args.wildcard_include_string_len:
supported_data_types.append("string")

cast_to_bigint = True if args.cast_to_bigint else False

if args.count:
col_args = None if args.count == "*" else cli_tools.get_arg_list(args.count)
aggregate_configs += config_manager.build_config_column_aggregates(
"count", col_args, None
"count", col_args, None, cast_to_bigint=cast_to_bigint
)
if args.sum:
col_args = None if args.sum == "*" else cli_tools.get_arg_list(args.sum)
aggregate_configs += config_manager.build_config_column_aggregates(
"sum", col_args, supported_data_types
"sum", col_args, supported_data_types, cast_to_bigint=cast_to_bigint
)
if args.avg:
col_args = None if args.avg == "*" else cli_tools.get_arg_list(args.avg)
aggregate_configs += config_manager.build_config_column_aggregates(
"avg", col_args, supported_data_types
"avg", col_args, supported_data_types, cast_to_bigint=cast_to_bigint
)
if args.min:
col_args = None if args.min == "*" else cli_tools.get_arg_list(args.min)
aggregate_configs += config_manager.build_config_column_aggregates(
"min", col_args, supported_data_types
"min", col_args, supported_data_types, cast_to_bigint=cast_to_bigint
)
if args.max:
col_args = None if args.max == "*" else cli_tools.get_arg_list(args.max)
aggregate_configs += config_manager.build_config_column_aggregates(
"max", col_args, supported_data_types
"max", col_args, supported_data_types, cast_to_bigint=cast_to_bigint
)
if args.bit_xor:
col_args = None if args.bit_xor == "*" else cli_tools.get_arg_list(args.bit_xor)
aggregate_configs += config_manager.build_config_column_aggregates(
"bit_xor", col_args, supported_data_types
"bit_xor", col_args, supported_data_types, cast_to_bigint=cast_to_bigint
)
return aggregate_configs

Expand Down
18 changes: 18 additions & 0 deletions data_validation/cli_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,12 @@ def _configure_run_parser(subparsers):
action="store_true",
help="Include string fields for wildcard aggregations.",
)
run_parser.add_argument(
"--cast-to-bigint",
"-ctb",
action="store_true",
help="Cast any int32 fields to int64 for large aggregations.",
)


def _configure_connection_parser(subparsers):
Expand Down Expand Up @@ -527,6 +533,12 @@ def _configure_column_parser(column_parser):
action="store_true",
help="Include string fields for wildcard aggregations.",
)
column_parser.add_argument(
"--cast-to-bigint",
"-ctb",
action="store_true",
help="Cast any int32 fields to int64 for large aggregations.",
)


def _configure_schema_parser(schema_parser):
Expand Down Expand Up @@ -621,6 +633,12 @@ def _configure_custom_query_parser(custom_query_parser):
action="store_true",
help="Include string fields for wildcard aggregations.",
)
custom_query_parser.add_argument(
"--cast-to-bigint",
"-ctb",
action="store_true",
help="Cast any int32 fields to int64 for large aggregations.",
)


def _add_common_arguments(parser):
Expand Down
4 changes: 4 additions & 0 deletions data_validation/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def _calculate_difference(field_differences, datatype, validation, is_value_comp
difference = pct_difference = ibis.null()
validation_status = (
ibis.case()
.when(
target_value.isnull() & source_value.isnull(),
consts.VALIDATION_STATUS_SUCCESS,
)
.when(target_value == source_value, consts.VALIDATION_STATUS_SUCCESS)
.else_(consts.VALIDATION_STATUS_FAIL)
.end()
Expand Down
29 changes: 20 additions & 9 deletions data_validation/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,8 @@ def append_pre_agg_calc_field(self, column, agg_type, column_type):
calc_func = "length"
elif column_type == "timestamp":
calc_func = "epoch_seconds"
elif column_type == "int32":
calc_func = "cast"
else:
raise ValueError(f"Unsupported column type: {column_type}")

Expand All @@ -470,6 +472,9 @@ def append_pre_agg_calc_field(self, column, agg_type, column_type):
consts.CONFIG_DEPTH: 0,
}

if column_type == "int32":
calculated_config["default_cast"] = "int64"

existing_calc_fields = [
x[consts.CONFIG_FIELD_ALIAS] for x in self.calculated_fields
]
Expand All @@ -484,7 +489,9 @@ def append_pre_agg_calc_field(self, column, agg_type, column_type):
}
return aggregate_config

def build_config_column_aggregates(self, agg_type, arg_value, supported_types):
def build_config_column_aggregates(
self, agg_type, arg_value, supported_types, cast_to_bigint=False
):
"""Return list of aggregate objects of given agg_type."""
aggregate_configs = []
source_table = self.get_source_ibis_calculated_table()
Expand Down Expand Up @@ -516,14 +523,18 @@ def build_config_column_aggregates(self, agg_type, arg_value, supported_types):
)
continue

if column_type == "string" or (
column_type == "timestamp"
and agg_type
in (
"sum",
"avg",
"bit_xor",
) # timestamps: do not convert to epoch seconds for min/max
if (
column_type == "string"
or (cast_to_bigint and column_type == "int32")
or (
column_type == "timestamp"
and agg_type
in (
"sum",
"avg",
"bit_xor",
) # timestamps: do not convert to epoch seconds for min/max
)
):
aggregate_config = self.append_pre_agg_calc_field(
column, agg_type, column_type
Expand Down
3 changes: 1 addition & 2 deletions data_validation/query_builder/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,7 @@ def epoch_seconds(config, fields):

@staticmethod
def cast(config, fields):
if config.get("default_cast") is None:
target_type = "string"
target_type = config.get("default_cast", "string")
return CalculatedField(
ibis.expr.api.ValueExpr.cast,
config,
Expand Down

0 comments on commit ca598a0

Please sign in to comment.