Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Prevent Oracle blob throwing exceptions during column validation #1005

Merged
merged 4 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 38 additions & 15 deletions data_validation/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,11 +593,14 @@ def build_and_append_pre_agg_calc_config(
def append_pre_agg_calc_field(
self, source_column, target_column, agg_type, column_type, column_position
) -> dict:
"""Append calculated field for length(string) or epoch_seconds(timestamp) for preprocessing before column validation aggregation."""
"""Append calculated field for length(string | binary) or epoch_seconds(timestamp) for preprocessing before column validation aggregation."""
depth, cast_type = 0, None
if column_type in ["string", "!string"]:
calc_func = "length"

elif column_type in ["binary", "!binary"]:
calc_func = "byte_length"

elif column_type in ["timestamp", "!timestamp", "date", "!date"]:
if (
self.source_client.name == "bigquery"
Expand Down Expand Up @@ -669,6 +672,37 @@ def decimal_too_big_for_64bit(
)
)

def require_pre_agg_calc_field(
    column_type: str, agg_type: str, cast_to_bigint: bool
) -> bool:
    """Return True when this column/aggregation pair needs a calculated field.

    A pre-aggregation calculated field (e.g. length() or epoch seconds)
    is injected before the aggregate for types that cannot be aggregated
    directly, or that need normalising for cross-engine comparison.
    """
    string_types = ("string", "!string")
    binary_types = ("binary", "!binary")
    int32_types = ("int32", "!int32")
    datetime_types = ("timestamp", "!timestamp", "date", "!date")

    if column_type in string_types:
        return True
    if column_type in binary_types:
        if agg_type != "count":
            # Convert to length for any min/max/sum on binary columns.
            return True
        # Oracle BLOB is invalid for use with SQL COUNT function, so count
        # aggregations only need the transformation when Oracle is involved
        # on either side of the validation.
        return "oracle" in (self.source_client.name, self.target_client.name)
    if cast_to_bigint and column_type in int32_types:
        return True
    if column_type in datetime_types and agg_type in ("sum", "avg", "bit_xor"):
        # For timestamps: do not convert to epoch seconds for min/max.
        return True
    return False

aggregate_configs = []
source_table = self.get_source_ibis_calculated_table()
target_table = self.get_target_ibis_calculated_table()
Expand All @@ -687,6 +721,8 @@ def decimal_too_big_for_64bit(
"!timestamp",
"date",
"!date",
"binary",
"!binary",
]

allowlist_columns = arg_value or casefold_source_columns
Expand Down Expand Up @@ -715,27 +751,14 @@ def decimal_too_big_for_64bit(
casefold_target_columns[column]
].type()

if (
column_type in ["string", "!string"]
or (cast_to_bigint and column_type in ["int32", "!int32"])
or (
column_type in ["timestamp", "!timestamp", "date", "!date"]
and agg_type
in (
"sum",
"avg",
"bit_xor",
) # For timestamps: do not convert to epoch seconds for min/max
)
):
if require_pre_agg_calc_field(column_type, agg_type, cast_to_bigint):
aggregate_config = self.append_pre_agg_calc_field(
casefold_source_columns[column],
casefold_target_columns[column],
agg_type,
column_type,
column_position,
)

else:
aggregate_config = {
consts.CONFIG_SOURCE_COLUMN: casefold_source_columns[column],
Expand Down
8 changes: 8 additions & 0 deletions data_validation/query_builder/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,14 @@ def length(config, fields):
fields,
)

@staticmethod
def byte_length(config, fields):
    """Build a CalculatedField returning the byte length of a binary column.

    Mirrors the string ``length`` calculated field above; the
    ``BinaryValue.byte_length`` expression is provided by the ibis_addon
    operations module.
    """
    length_fn = ibis.expr.types.BinaryValue.byte_length
    return CalculatedField(length_fn, config, fields)

@staticmethod
def rstrip(config, fields):
return CalculatedField(
Expand Down
34 changes: 33 additions & 1 deletion third_party/ibis/ibis_addon/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
Value,
ExtractEpochSeconds,
)
from ibis.expr.types import NumericValue, TemporalValue
from ibis.expr.types import BinaryValue, NumericValue, TemporalValue

import third_party.ibis.ibis_mysql.compiler
import third_party.ibis.ibis_postgres.client
Expand All @@ -76,6 +76,12 @@
SnowflakeExprTranslator = None


class BinaryLength(Value):
    # Custom ibis operation: length in bytes of a binary/BLOB value.
    # Accepts a single binary-typed argument.
    arg = rlz.one_of([rlz.value(dt.Binary)])
    # Result is an int32 scalar/column, matching ibis's string Length op.
    output_dtype = dt.int32
    output_shape = rlz.shape_like("arg")


class ToChar(Value):
arg = rlz.one_of(
[
Expand All @@ -94,6 +100,10 @@ class RawSQL(Comparison):
pass


def compile_binary_length(binary_value):
    """Wrap *binary_value* in a BinaryLength op and return it as an expression.

    Bound to ``BinaryValue.byte_length`` so it can be used as a calculated
    field on binary columns.
    """
    length_op = BinaryLength(binary_value)
    return length_op.to_expr()


def compile_to_char(numeric_value, fmt):
return ToChar(numeric_value, fmt=fmt).to_expr()

Expand Down Expand Up @@ -292,6 +302,16 @@ def sa_format_to_char(translator, op):
return sa.func.to_char(arg, fmt)


def sa_format_binary_length(translator, op):
    """Compile a BinaryLength op to the generic SQL LENGTH() function."""
    translated_arg = translator.translate(op.arg)
    return sa.func.length(translated_arg)


def sa_format_binary_length_mssql(translator, op):
    """Compile a BinaryLength op for SQL Server, which uses DATALENGTH()."""
    translated_arg = translator.translate(op.arg)
    return sa.func.datalength(translated_arg)


def sa_cast_postgres(t, op):
# Add cast from numeric to string
arg = op.arg
Expand Down Expand Up @@ -392,12 +412,15 @@ def _bigquery_field_to_ibis_dtype(field):
return ibis_type


BinaryValue.byte_length = compile_binary_length

NumericValue.to_char = compile_to_char
TemporalValue.to_char = compile_to_char

BigQueryExprTranslator._registry[HashBytes] = format_hashbytes_bigquery
BigQueryExprTranslator._registry[RawSQL] = format_raw_sql
BigQueryExprTranslator._registry[Strftime] = strftime_bigquery
BigQueryExprTranslator._registry[BinaryLength] = sa_format_binary_length

AlchemyExprTranslator._registry[RawSQL] = format_raw_sql
AlchemyExprTranslator._registry[HashBytes] = format_hashbytes_alchemy
Expand All @@ -408,40 +431,49 @@ def _bigquery_field_to_ibis_dtype(field):
ImpalaExprTranslator._registry[HashBytes] = format_hashbytes_hive
ImpalaExprTranslator._registry[RandomScalar] = fixed_arity("RAND", 0)
ImpalaExprTranslator._registry[Strftime] = strftime_impala
ImpalaExprTranslator._registry[BinaryLength] = sa_format_binary_length

OracleExprTranslator._registry[RawSQL] = sa_format_raw_sql
OracleExprTranslator._registry[HashBytes] = sa_format_hashbytes_oracle
OracleExprTranslator._registry[ToChar] = sa_format_to_char
OracleExprTranslator._registry[BinaryLength] = sa_format_binary_length

PostgreSQLExprTranslator._registry[HashBytes] = sa_format_hashbytes_postgres
PostgreSQLExprTranslator._registry[RawSQL] = sa_format_raw_sql
PostgreSQLExprTranslator._registry[ToChar] = sa_format_to_char
PostgreSQLExprTranslator._registry[Cast] = sa_cast_postgres
PostgreSQLExprTranslator._registry[BinaryLength] = sa_format_binary_length

MsSqlExprTranslator._registry[HashBytes] = sa_format_hashbytes_mssql
MsSqlExprTranslator._registry[RawSQL] = sa_format_raw_sql
MsSqlExprTranslator._registry[IfNull] = sa_fixed_arity(sa.func.isnull, 2)
MsSqlExprTranslator._registry[StringJoin] = _sa_string_join
MsSqlExprTranslator._registry[RandomScalar] = sa_format_new_id
MsSqlExprTranslator._registry[Strftime] = strftime_mssql
MsSqlExprTranslator._registry[BinaryLength] = sa_format_binary_length_mssql

MySQLExprTranslator._registry[RawSQL] = sa_format_raw_sql
MySQLExprTranslator._registry[HashBytes] = sa_format_hashbytes_mysql
MySQLExprTranslator._registry[Strftime] = strftime_mysql
MySQLExprTranslator._registry[BinaryLength] = sa_format_binary_length

RedShiftExprTranslator._registry[HashBytes] = sa_format_hashbytes_redshift
RedShiftExprTranslator._registry[RawSQL] = sa_format_raw_sql
RedShiftExprTranslator._registry[BinaryLength] = sa_format_binary_length

Db2ExprTranslator._registry[HashBytes] = sa_format_hashbytes_db2
Db2ExprTranslator._registry[RawSQL] = sa_format_raw_sql
Db2ExprTranslator._registry[BinaryLength] = sa_format_binary_length

if TeradataExprTranslator:
TeradataExprTranslator._registry[RawSQL] = format_raw_sql
TeradataExprTranslator._registry[HashBytes] = format_hashbytes_teradata
TeradataExprTranslator._registry[BinaryLength] = sa_format_binary_length

if SnowflakeExprTranslator:
SnowflakeExprTranslator._registry[HashBytes] = sa_format_hashbytes_snowflake
SnowflakeExprTranslator._registry[RawSQL] = sa_format_raw_sql
SnowflakeExprTranslator._registry[IfNull] = sa_fixed_arity(sa.func.ifnull, 2)
SnowflakeExprTranslator._registry[ExtractEpochSeconds] = sa_epoch_time_snowflake
SnowflakeExprTranslator._registry[RandomScalar] = sa_format_random
SnowflakeExprTranslator._registry[BinaryLength] = sa_format_binary_length