diff --git a/data_validation/config_manager.py b/data_validation/config_manager.py
index 259f59d31..09fe1c13a 100644
--- a/data_validation/config_manager.py
+++ b/data_validation/config_manager.py
@@ -593,11 +593,14 @@ def build_and_append_pre_agg_calc_config(
     def append_pre_agg_calc_field(
         self, source_column, target_column, agg_type, column_type, column_position
     ) -> dict:
-        """Append calculated field for length(string) or epoch_seconds(timestamp) for preprocessing before column validation aggregation."""
+        """Append calculated field for length(string | binary) or epoch_seconds(timestamp) for preprocessing before column validation aggregation."""
         depth, cast_type = 0, None
         if column_type in ["string", "!string"]:
             calc_func = "length"
 
+        elif column_type in ["binary", "!binary"]:
+            calc_func = "byte_length"
+
         elif column_type in ["timestamp", "!timestamp", "date", "!date"]:
             if (
                 self.source_client.name == "bigquery"
@@ -669,6 +672,38 @@ def decimal_too_big_for_64bit(
                 )
             )
 
+        def require_pre_agg_calc_field(
+            column_type: str, agg_type: str, cast_to_bigint: bool
+        ) -> bool:
+            """Return True if the column needs a pre-aggregation calculated field."""
+            if column_type in ["string", "!string"]:
+                return True
+            elif column_type in ["binary", "!binary"]:
+                if agg_type == "count":
+                    # Oracle BLOB is invalid for use with SQL COUNT function.
+                    return bool(
+                        self.source_client.name == "oracle"
+                        or self.target_client.name == "oracle"
+                    )
+                else:
+                    # Convert to length for any min/max/sum on binary columns.
+                    return True
+            elif cast_to_bigint and column_type in ["int32", "!int32"]:
+                return True
+            elif column_type in [
+                "timestamp",
+                "!timestamp",
+                "date",
+                "!date",
+            ] and agg_type in (
+                "sum",
+                "avg",
+                "bit_xor",
+            ):
+                # For timestamps: do not convert to epoch seconds for min/max
+                return True
+            return False
+
         aggregate_configs = []
         source_table = self.get_source_ibis_calculated_table()
         target_table = self.get_target_ibis_calculated_table()
@@ -687,6 +722,8 @@ def decimal_too_big_for_64bit(
                 "!timestamp",
                 "date",
                 "!date",
+                "binary",
+                "!binary",
             ]
 
         allowlist_columns = arg_value or casefold_source_columns
@@ -715,19 +752,7 @@ def decimal_too_big_for_64bit(
                 casefold_target_columns[column]
             ].type()
 
-            if (
-                column_type in ["string", "!string"]
-                or (cast_to_bigint and column_type in ["int32", "!int32"])
-                or (
-                    column_type in ["timestamp", "!timestamp", "date", "!date"]
-                    and agg_type
-                    in (
-                        "sum",
-                        "avg",
-                        "bit_xor",
-                    )  # For timestamps: do not convert to epoch seconds for min/max
-                )
-            ):
+            if require_pre_agg_calc_field(column_type, agg_type, cast_to_bigint):
                 aggregate_config = self.append_pre_agg_calc_field(
                     casefold_source_columns[column],
                     casefold_target_columns[column],
@@ -735,7 +760,6 @@ def decimal_too_big_for_64bit(
                     column_type,
                     column_position,
                 )
-
             else:
                 aggregate_config = {
                     consts.CONFIG_SOURCE_COLUMN: casefold_source_columns[column],
diff --git a/data_validation/query_builder/query_builder.py b/data_validation/query_builder/query_builder.py
index 604a94c0c..03c593a45 100644
--- a/data_validation/query_builder/query_builder.py
+++ b/data_validation/query_builder/query_builder.py
@@ -341,6 +341,15 @@ def length(config, fields):
             fields,
         )
 
+    @staticmethod
+    def byte_length(config, fields):
+        """Return a CalculatedField that applies BinaryValue.byte_length."""
+        return CalculatedField(
+            ibis.expr.types.BinaryValue.byte_length,
+            config,
+            fields,
+        )
+
     @staticmethod
     def rstrip(config, fields):
         return CalculatedField(
diff --git a/third_party/ibis/ibis_addon/operations.py b/third_party/ibis/ibis_addon/operations.py
index 90403a03a..0fe7a741a 100644
--- a/third_party/ibis/ibis_addon/operations.py
+++ b/third_party/ibis/ibis_addon/operations.py
@@ -55,7 +55,7 @@
     Value,
     ExtractEpochSeconds,
 )
-from ibis.expr.types import NumericValue, TemporalValue
+from ibis.expr.types import BinaryValue, NumericValue, TemporalValue
 
 import third_party.ibis.ibis_mysql.compiler
 import third_party.ibis.ibis_postgres.client
@@ -76,6 +76,14 @@
     SnowflakeExprTranslator = None
 
 
+class BinaryLength(Value):
+    """Operation returning the length of a binary value."""
+
+    arg = rlz.one_of([rlz.value(dt.Binary)])
+    output_dtype = dt.int32
+    output_shape = rlz.shape_like("arg")
+
+
 class ToChar(Value):
     arg = rlz.one_of(
         [
@@ -94,6 +102,11 @@ class RawSQL(Comparison):
     pass
 
 
+def compile_binary_length(binary_value):
+    """Return a BinaryLength expression for ``binary_value``."""
+    return BinaryLength(binary_value).to_expr()
+
+
 def compile_to_char(numeric_value, fmt):
     return ToChar(numeric_value, fmt=fmt).to_expr()
 
@@ -292,6 +305,26 @@ def sa_format_to_char(translator, op):
     return sa.func.to_char(arg, fmt)
 
 
+def format_binary_length(translator, op):
+    """Compile BinaryLength to SQL text for string-based translators."""
+    # String compilers must return SQL text; a SQLAlchemy function would render
+    # the already-translated column string as a bind parameter.
+    arg = translator.translate(op.arg)
+    return f"LENGTH({arg})"
+
+
+def sa_format_binary_length(translator, op):
+    """Compile BinaryLength to LENGTH() for SQLAlchemy-based translators."""
+    arg = translator.translate(op.arg)
+    return sa.func.length(arg)
+
+
+def sa_format_binary_length_mssql(translator, op):
+    """Compile BinaryLength to DATALENGTH() for the MsSql translator."""
+    arg = translator.translate(op.arg)
+    return sa.func.datalength(arg)
+
+
 def sa_cast_postgres(t, op):
     # Add cast from numeric to string
     arg = op.arg
@@ -392,12 +425,15 @@ def _bigquery_field_to_ibis_dtype(field):
     return ibis_type
 
 
+BinaryValue.byte_length = compile_binary_length
+
 NumericValue.to_char = compile_to_char
 TemporalValue.to_char = compile_to_char
 
 BigQueryExprTranslator._registry[HashBytes] = format_hashbytes_bigquery
 BigQueryExprTranslator._registry[RawSQL] = format_raw_sql
 BigQueryExprTranslator._registry[Strftime] = strftime_bigquery
+BigQueryExprTranslator._registry[BinaryLength] = format_binary_length
 
 AlchemyExprTranslator._registry[RawSQL] = format_raw_sql
 AlchemyExprTranslator._registry[HashBytes] = format_hashbytes_alchemy
@@ -408,15 +444,18 @@ def _bigquery_field_to_ibis_dtype(field):
 ImpalaExprTranslator._registry[HashBytes] = format_hashbytes_hive
 ImpalaExprTranslator._registry[RandomScalar] = fixed_arity("RAND", 0)
 ImpalaExprTranslator._registry[Strftime] = strftime_impala
+ImpalaExprTranslator._registry[BinaryLength] = format_binary_length
 
 OracleExprTranslator._registry[RawSQL] = sa_format_raw_sql
 OracleExprTranslator._registry[HashBytes] = sa_format_hashbytes_oracle
 OracleExprTranslator._registry[ToChar] = sa_format_to_char
+OracleExprTranslator._registry[BinaryLength] = sa_format_binary_length
 
 PostgreSQLExprTranslator._registry[HashBytes] = sa_format_hashbytes_postgres
 PostgreSQLExprTranslator._registry[RawSQL] = sa_format_raw_sql
 PostgreSQLExprTranslator._registry[ToChar] = sa_format_to_char
 PostgreSQLExprTranslator._registry[Cast] = sa_cast_postgres
+PostgreSQLExprTranslator._registry[BinaryLength] = sa_format_binary_length
 
 MsSqlExprTranslator._registry[HashBytes] = sa_format_hashbytes_mssql
 MsSqlExprTranslator._registry[RawSQL] = sa_format_raw_sql
@@ -424,20 +463,25 @@ def _bigquery_field_to_ibis_dtype(field):
 MsSqlExprTranslator._registry[StringJoin] = _sa_string_join
 MsSqlExprTranslator._registry[RandomScalar] = sa_format_new_id
 MsSqlExprTranslator._registry[Strftime] = strftime_mssql
+MsSqlExprTranslator._registry[BinaryLength] = sa_format_binary_length_mssql
 
 MySQLExprTranslator._registry[RawSQL] = sa_format_raw_sql
 MySQLExprTranslator._registry[HashBytes] = sa_format_hashbytes_mysql
 MySQLExprTranslator._registry[Strftime] = strftime_mysql
+MySQLExprTranslator._registry[BinaryLength] = sa_format_binary_length
 
 RedShiftExprTranslator._registry[HashBytes] = sa_format_hashbytes_redshift
 RedShiftExprTranslator._registry[RawSQL] = sa_format_raw_sql
+RedShiftExprTranslator._registry[BinaryLength] = sa_format_binary_length
 
 Db2ExprTranslator._registry[HashBytes] = sa_format_hashbytes_db2
 Db2ExprTranslator._registry[RawSQL] = sa_format_raw_sql
+Db2ExprTranslator._registry[BinaryLength] = sa_format_binary_length
 
 if TeradataExprTranslator:
     TeradataExprTranslator._registry[RawSQL] = format_raw_sql
     TeradataExprTranslator._registry[HashBytes] = format_hashbytes_teradata
+    TeradataExprTranslator._registry[BinaryLength] = format_binary_length
 
 if SnowflakeExprTranslator:
     SnowflakeExprTranslator._registry[HashBytes] = sa_format_hashbytes_snowflake
@@ -445,3 +489,4 @@ def _bigquery_field_to_ibis_dtype(field):
     SnowflakeExprTranslator._registry[IfNull] = sa_fixed_arity(sa.func.ifnull, 2)
     SnowflakeExprTranslator._registry[ExtractEpochSeconds] = sa_epoch_time_snowflake
     SnowflakeExprTranslator._registry[RandomScalar] = sa_format_random
+    SnowflakeExprTranslator._registry[BinaryLength] = sa_format_binary_length