Skip to content

Commit

Permalink
fix: Prevent Oracle blob throwing exceptions during column validation (#1005)
Browse files Browse the repository at this point in the history

* fix: Add byte_length for count on Oracle binary columns. Also extended to be used for min/max/sum on any binary column

* fix: Allow binary columns to be used for aggregation validation if requested explicitly
  • Loading branch information
nj1973 committed Sep 28, 2023
1 parent 0bd48a2 commit 8df1cfa
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 16 deletions.
53 changes: 38 additions & 15 deletions data_validation/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,11 +593,14 @@ def build_and_append_pre_agg_calc_config(
def append_pre_agg_calc_field(
self, source_column, target_column, agg_type, column_type, column_position
) -> dict:
"""Append calculated field for length(string) or epoch_seconds(timestamp) for preprocessing before column validation aggregation."""
"""Append calculated field for length(string | binary) or epoch_seconds(timestamp) for preprocessing before column validation aggregation."""
depth, cast_type = 0, None
if column_type in ["string", "!string"]:
calc_func = "length"

elif column_type in ["binary", "!binary"]:
calc_func = "byte_length"

elif column_type in ["timestamp", "!timestamp", "date", "!date"]:
if (
self.source_client.name == "bigquery"
Expand Down Expand Up @@ -669,6 +672,37 @@ def decimal_too_big_for_64bit(
)
)

def require_pre_agg_calc_field(
    column_type: str, agg_type: str, cast_to_bigint: bool
) -> bool:
    """Return True if the aggregation should be applied to a calculated
    field (e.g. length/byte_length/epoch seconds) instead of the raw column.

    Args:
        column_type: Ibis type name of the column, e.g. "string"/"!string".
        agg_type: Aggregation name, e.g. "count", "min", "max", "sum".
        cast_to_bigint: Whether int32 columns are being cast up to 64 bit.
    """
    if column_type in ("string", "!string"):
        return True
    if column_type in ("binary", "!binary"):
        if agg_type == "count":
            # Oracle BLOB is invalid for use with SQL COUNT function.
            # Note: `==`/`or` already yield a bool, no bool() wrapper needed.
            return (
                self.source_client.name == "oracle"
                or self.target_client.name == "oracle"
            )
        # Convert to length for any min/max/sum on binary columns.
        return True
    if cast_to_bigint and column_type in ("int32", "!int32"):
        return True
    if column_type in (
        "timestamp",
        "!timestamp",
        "date",
        "!date",
    ) and agg_type in (
        "sum",
        "avg",
        "bit_xor",
    ):
        # For timestamps: convert to epoch seconds for sum/avg/bit_xor only,
        # do not convert for min/max.
        return True
    return False

aggregate_configs = []
source_table = self.get_source_ibis_calculated_table()
target_table = self.get_target_ibis_calculated_table()
Expand All @@ -687,6 +721,8 @@ def decimal_too_big_for_64bit(
"!timestamp",
"date",
"!date",
"binary",
"!binary",
]

allowlist_columns = arg_value or casefold_source_columns
Expand Down Expand Up @@ -715,27 +751,14 @@ def decimal_too_big_for_64bit(
casefold_target_columns[column]
].type()

if (
column_type in ["string", "!string"]
or (cast_to_bigint and column_type in ["int32", "!int32"])
or (
column_type in ["timestamp", "!timestamp", "date", "!date"]
and agg_type
in (
"sum",
"avg",
"bit_xor",
) # For timestamps: do not convert to epoch seconds for min/max
)
):
if require_pre_agg_calc_field(column_type, agg_type, cast_to_bigint):
aggregate_config = self.append_pre_agg_calc_field(
casefold_source_columns[column],
casefold_target_columns[column],
agg_type,
column_type,
column_position,
)

else:
aggregate_config = {
consts.CONFIG_SOURCE_COLUMN: casefold_source_columns[column],
Expand Down
8 changes: 8 additions & 0 deletions data_validation/query_builder/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,14 @@ def length(config, fields):
fields,
)

@staticmethod
def byte_length(config, fields):
    """Return a CalculatedField that applies ibis BinaryValue.byte_length
    to the given fields (byte length of a binary column)."""
    byte_length_fn = ibis.expr.types.BinaryValue.byte_length
    return CalculatedField(byte_length_fn, config, fields)

@staticmethod
def rstrip(config, fields):
return CalculatedField(
Expand Down
34 changes: 33 additions & 1 deletion third_party/ibis/ibis_addon/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
Value,
ExtractEpochSeconds,
)
from ibis.expr.types import NumericValue, TemporalValue
from ibis.expr.types import BinaryValue, NumericValue, TemporalValue

import third_party.ibis.ibis_mysql.compiler
import third_party.ibis.ibis_postgres.client
Expand All @@ -76,6 +76,12 @@
SnowflakeExprTranslator = None


class BinaryLength(Value):
    """Custom ibis value operation: length in bytes of a binary value."""

    # Accepts a single argument of Binary datatype.
    arg = rlz.one_of([rlz.value(dt.Binary)])
    # Result is an int32 whose shape (scalar/column) follows the input arg.
    output_dtype = dt.int32
    output_shape = rlz.shape_like("arg")


class ToChar(Value):
arg = rlz.one_of(
[
Expand All @@ -94,6 +100,10 @@ class RawSQL(Comparison):
pass


def compile_binary_length(binary_value):
    """Wrap *binary_value* in a BinaryLength op and return it as an ibis expression."""
    length_op = BinaryLength(binary_value)
    return length_op.to_expr()


def compile_to_char(numeric_value, fmt):
return ToChar(numeric_value, fmt=fmt).to_expr()

Expand Down Expand Up @@ -292,6 +302,16 @@ def sa_format_to_char(translator, op):
return sa.func.to_char(arg, fmt)


def sa_format_binary_length(translator, op):
    """Translate a BinaryLength op to a SQLAlchemy length() function call."""
    translated_arg = translator.translate(op.arg)
    return sa.func.length(translated_arg)


def sa_format_binary_length_mssql(translator, op):
    """Translate a BinaryLength op for SQL Server using DATALENGTH()."""
    translated_arg = translator.translate(op.arg)
    return sa.func.datalength(translated_arg)


def sa_cast_postgres(t, op):
# Add cast from numeric to string
arg = op.arg
Expand Down Expand Up @@ -392,12 +412,15 @@ def _bigquery_field_to_ibis_dtype(field):
return ibis_type


BinaryValue.byte_length = compile_binary_length

NumericValue.to_char = compile_to_char
TemporalValue.to_char = compile_to_char

BigQueryExprTranslator._registry[HashBytes] = format_hashbytes_bigquery
BigQueryExprTranslator._registry[RawSQL] = format_raw_sql
BigQueryExprTranslator._registry[Strftime] = strftime_bigquery
BigQueryExprTranslator._registry[BinaryLength] = sa_format_binary_length

AlchemyExprTranslator._registry[RawSQL] = format_raw_sql
AlchemyExprTranslator._registry[HashBytes] = format_hashbytes_alchemy
Expand All @@ -408,40 +431,49 @@ def _bigquery_field_to_ibis_dtype(field):
ImpalaExprTranslator._registry[HashBytes] = format_hashbytes_hive
ImpalaExprTranslator._registry[RandomScalar] = fixed_arity("RAND", 0)
ImpalaExprTranslator._registry[Strftime] = strftime_impala
ImpalaExprTranslator._registry[BinaryLength] = sa_format_binary_length

OracleExprTranslator._registry[RawSQL] = sa_format_raw_sql
OracleExprTranslator._registry[HashBytes] = sa_format_hashbytes_oracle
OracleExprTranslator._registry[ToChar] = sa_format_to_char
OracleExprTranslator._registry[BinaryLength] = sa_format_binary_length

PostgreSQLExprTranslator._registry[HashBytes] = sa_format_hashbytes_postgres
PostgreSQLExprTranslator._registry[RawSQL] = sa_format_raw_sql
PostgreSQLExprTranslator._registry[ToChar] = sa_format_to_char
PostgreSQLExprTranslator._registry[Cast] = sa_cast_postgres
PostgreSQLExprTranslator._registry[BinaryLength] = sa_format_binary_length

MsSqlExprTranslator._registry[HashBytes] = sa_format_hashbytes_mssql
MsSqlExprTranslator._registry[RawSQL] = sa_format_raw_sql
MsSqlExprTranslator._registry[IfNull] = sa_fixed_arity(sa.func.isnull, 2)
MsSqlExprTranslator._registry[StringJoin] = _sa_string_join
MsSqlExprTranslator._registry[RandomScalar] = sa_format_new_id
MsSqlExprTranslator._registry[Strftime] = strftime_mssql
MsSqlExprTranslator._registry[BinaryLength] = sa_format_binary_length_mssql

MySQLExprTranslator._registry[RawSQL] = sa_format_raw_sql
MySQLExprTranslator._registry[HashBytes] = sa_format_hashbytes_mysql
MySQLExprTranslator._registry[Strftime] = strftime_mysql
MySQLExprTranslator._registry[BinaryLength] = sa_format_binary_length

RedShiftExprTranslator._registry[HashBytes] = sa_format_hashbytes_redshift
RedShiftExprTranslator._registry[RawSQL] = sa_format_raw_sql
RedShiftExprTranslator._registry[BinaryLength] = sa_format_binary_length

Db2ExprTranslator._registry[HashBytes] = sa_format_hashbytes_db2
Db2ExprTranslator._registry[RawSQL] = sa_format_raw_sql
Db2ExprTranslator._registry[BinaryLength] = sa_format_binary_length

if TeradataExprTranslator:
TeradataExprTranslator._registry[RawSQL] = format_raw_sql
TeradataExprTranslator._registry[HashBytes] = format_hashbytes_teradata
TeradataExprTranslator._registry[BinaryLength] = sa_format_binary_length

if SnowflakeExprTranslator:
SnowflakeExprTranslator._registry[HashBytes] = sa_format_hashbytes_snowflake
SnowflakeExprTranslator._registry[RawSQL] = sa_format_raw_sql
SnowflakeExprTranslator._registry[IfNull] = sa_fixed_arity(sa.func.ifnull, 2)
SnowflakeExprTranslator._registry[ExtractEpochSeconds] = sa_epoch_time_snowflake
SnowflakeExprTranslator._registry[RandomScalar] = sa_format_random
SnowflakeExprTranslator._registry[BinaryLength] = sa_format_binary_length

0 comments on commit 8df1cfa

Please sign in to comment.