fix: validate column sum/min/max issue for decimals with precision beyond int64/float64 (#918)

* fix: Column validation casts decimal aggs to string if precision > 18

* tests: Add regression test for issue-900
nj1973 committed Jul 28, 2023
1 parent f9db68f commit 5a8d691
Showing 3 changed files with 76 additions and 22 deletions.
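A note on the failure mode this commit addresses: int64 holds at most 19 digits and float64 carries only ~15-17 significant digits, so two wide decimals that differ in a trailing digit can collapse to the same 64-bit value and validate as a false success. A minimal plain-Python sketch (the 25-digit values are the ones used in the regression test below):

    # float64 keeps ~15-17 significant digits, so these 25-digit values,
    # which differ only in the final digit, round to the same double.
    a = 1234567890123456789012345
    b = 1234567890123456789012340
    print(float(a) == float(b))  # True  -> a false "success" under a float64 cast
    # Comparing the aggregates as strings preserves every digit.
    print(str(a) == str(b))      # False -> the mismatch is detected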
20 changes: 6 additions & 14 deletions data_validation/combiner.py
@@ -115,8 +115,12 @@ def _calculate_difference(field_differences, datatype, validation, is_value_comp
     target_value = field_differences["differences_target_value"]
 
     # Does not calculate difference between agg values for row hash due to int64 overflow
-    if is_value_comparison:
-        difference = pct_difference = ibis.null()
+    if is_value_comparison or isinstance(datatype, ibis.expr.datatypes.String):
+        # String data types, i.e. "None", can be returned for NULL timestamp/datetime aggs
+        if is_value_comparison:
+            difference = pct_difference = ibis.null()
+        else:
+            difference = pct_difference = ibis.null().cast("float64")
         validation_status = (
             ibis.case()
             .when(
@@ -130,18 +134,6 @@ def _calculate_difference(field_differences, datatype, validation, is_value_comp
             .else_(consts.VALIDATION_STATUS_FAIL)
             .end()
         )
-    # String data types i.e "None" can be returned for NULL timestamp/datetime aggs
-    elif isinstance(datatype, ibis.expr.datatypes.String):
-        difference = pct_difference = ibis.null().cast("float64")
-        validation_status = (
-            ibis.case()
-            .when(
-                target_value.isnull() & source_value.isnull(),
-                consts.VALIDATION_STATUS_SUCCESS,
-            )
-            .else_(consts.VALIDATION_STATUS_FAIL)
-            .end()
-        )
     else:
         difference = (target_value - source_value).cast("float64")

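The net effect of the merged branch above: for value comparisons and string-typed aggregates alike, no numeric difference is computed (subtracting strings is meaningless) and the status reduces to an equality check in which two NULLs count as a success. A rough plain-Python sketch of that decision logic, assuming the elided middle of the case() treats equal values as a success, as elsewhere in _calculate_difference:

    SUCCESS, FAIL = "success", "fail"

    def status_for_string_aggs(source_value, target_value):
        # Mirrors the ibis case(): both NULL -> success, equal -> success, else fail;
        # difference/pct_difference stay NULL rather than being computed.
        if source_value is None and target_value is None:
            return SUCCESS
        return SUCCESS if source_value == target_value else FAIL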
41 changes: 33 additions & 8 deletions data_validation/config_manager.py
@@ -592,7 +592,7 @@ def build_and_append_pre_agg_calc_config(
 
     def append_pre_agg_calc_field(
         self, source_column, target_column, agg_type, column_type, column_position
-    ):
+    ) -> dict:
         """Append calculated field for length(string) or epoch_seconds(timestamp) for preprocessing before column validation aggregation."""
         depth, cast_type = 0, None
         if column_type == "string":
@@ -624,10 +624,6 @@ def append_pre_agg_calc_field(
             calc_func = "cast"
             cast_type = "int64"
 
-        elif column_type == "decimal" or column_type == "!decimal":
-            calc_func = "cast"
-            cast_type = "float64"
-
         else:
             raise ValueError(f"Unsupported column type: {column_type}")
 
@@ -652,6 +648,27 @@ def build_config_column_aggregates(
         self, agg_type, arg_value, supported_types, cast_to_bigint=False
     ):
         """Return list of aggregate objects of given agg_type."""
+
+        def decimal_too_big_for_64bit(
+            source_column_ibis_type, target_column_ibis_type
+        ) -> bool:
+            return bool(
+                (
+                    isinstance(source_column_ibis_type, dt.Decimal)
+                    and (
+                        source_column_ibis_type.precision is None
+                        or source_column_ibis_type.precision > 18
+                    )
+                )
+                and (
+                    isinstance(target_column_ibis_type, dt.Decimal)
+                    and (
+                        target_column_ibis_type.precision is None
+                        or target_column_ibis_type.precision > 18
+                    )
+                )
+            )
+
         aggregate_configs = []
         source_table = self.get_source_ibis_calculated_table()
         target_table = self.get_target_ibis_calculated_table()
@@ -667,8 +684,13 @@ def build_config_column_aggregates(
         allowlist_columns = arg_value or casefold_source_columns
         for column_position, column in enumerate(casefold_source_columns):
             # Get column type and remove precision/scale attributes
-            column_type_str = str(source_table[casefold_source_columns[column]].type())
-            column_type = column_type_str.split("(")[0]
+            source_column_ibis_type = source_table[
+                casefold_source_columns[column]
+            ].type()
+            target_column_ibis_type = target_table[
+                casefold_target_columns[column]
+            ].type()
+            column_type = str(source_column_ibis_type).split("(")[0]
 
             if column not in allowlist_columns:
                 continue
@@ -699,7 +721,6 @@
                         "bit_xor",
                     )  # For timestamps: do not convert to epoch seconds for min/max
                 )
-                or (column_type == "decimal" or column_type == "!decimal")
             ):
                 aggregate_config = self.append_pre_agg_calc_field(
                     casefold_source_columns[column],
@@ -718,6 +739,10 @@
                 ),
                 consts.CONFIG_TYPE: agg_type,
             }
+            if decimal_too_big_for_64bit(
+                source_column_ibis_type, target_column_ibis_type
+            ):
+                aggregate_config[consts.CONFIG_CAST] = "string"
 
             aggregate_configs.append(aggregate_config)
 
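For illustration, the new predicate can be exercised directly with ibis decimal types; a small sketch, assuming decimal_too_big_for_64bit were lifted out of build_config_column_aggregates (in the commit it is a nested helper):

    import ibis.expr.datatypes as dt

    # Both sides wider than 18 digits -> aggregates get CONFIG_CAST = "string".
    print(decimal_too_big_for_64bit(dt.Decimal(38, 9), dt.Decimal(38, 9)))   # True
    # Unknown precision is treated as potentially too big.
    print(decimal_too_big_for_64bit(dt.Decimal(), dt.Decimal(38, 9)))        # True
    # If either side fits in 18 digits, the existing 64-bit path is kept.
    print(decimal_too_big_for_64bit(dt.Decimal(18, 2), dt.Decimal(38, 9)))   # False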
37 changes: 37 additions & 0 deletions tests/system/data_sources/test_oracle.py
@@ -277,3 +277,40 @@ def test_custom_query_validation_core_types():
     df = validator.execute()
     # With filter on failures the data frame should be empty
     assert len(df) == 0
+
+
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    new=mock_get_connection_config,
+)
+def test_custom_query_invalid_long_decimal():
+    """Oracle to BigQuery comparison of decimals that exceed a precision of 18 (beyond int64 & float64).
+    We previously saw false successes because long numbers lost precision when cast to 64-bit
+    types and looked identical even when they differed slightly.
+    See: https://github.com/GoogleCloudPlatform/professional-services-data-validator/issues/900
+    This is the regression test.
+    """
+    parser = cli_tools.configure_arg_parser()
+    # Notice the two numeric values below have a different final digit; we expect a failure status.
+    args = parser.parse_args(
+        [
+            "validate",
+            "custom-query",
+            "column",
+            "-sc=mock-conn",
+            "-tc=bq-conn",
+            "--source-query=select to_number(1234567890123456789012345) as dec_25 from dual",
+            "--target-query=select cast('1234567890123456789012340' as numeric) as dec_25",
+            "--filter-status=fail",
+            "--min=dec_25",
+            "--max=dec_25",
+            "--sum=dec_25",
+        ]
+    )
+    config_managers = main.build_config_managers_from_args(args)
+    assert len(config_managers) == 1
+    config_manager = config_managers[0]
+    validator = data_validation.DataValidation(config_manager.config, verbose=False)
+    df = validator.execute()
+    # With filter on failures the data frame should be populated
+    assert len(df) > 0
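The parsed args in the test correspond roughly to the following CLI invocation (the connection names mock-conn and bq-conn are test fixtures, not defaults):

    data-validation validate custom-query column \
      -sc=mock-conn -tc=bq-conn \
      --source-query="select to_number(1234567890123456789012345) as dec_25 from dual" \
      --target-query="select cast('1234567890123456789012340' as numeric) as dec_25" \
      --filter-status=fail \
      --min=dec_25 --max=dec_25 --sum=dec_25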
