diff --git a/third_party/ibis/ibis_addon/operations.py b/third_party/ibis/ibis_addon/operations.py index 963f7b090..e18c52781 100644 --- a/third_party/ibis/ibis_addon/operations.py +++ b/third_party/ibis/ibis_addon/operations.py @@ -22,27 +22,39 @@ Ibis as an override, though it would not apply for Pandas and other non-textual languages. """ +import google.cloud.bigquery as bq import ibis import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.rules as rlz import sqlalchemy as sa -from ibis.backends.base.sql.alchemy.registry import \ - fixed_arity as sa_fixed_arity +from ibis.backends.base.sql.alchemy.registry import fixed_arity as sa_fixed_arity from ibis.backends.base.sql.alchemy.translator import AlchemyExprTranslator from ibis.backends.base.sql.compiler.translator import ExprTranslator from ibis.backends.base.sql.registry import fixed_arity -from ibis.backends.bigquery.client import _DTYPE_TO_IBIS_TYPE +from ibis.backends.bigquery.client import ( + _DTYPE_TO_IBIS_TYPE as _BQ_DTYPE_TO_IBIS_TYPE, + _LEGACY_TO_STANDARD as _BQ_LEGACY_TO_STANDARD, +) from ibis.backends.bigquery.compiler import BigQueryExprTranslator -from ibis.backends.bigquery.registry import \ - STRFTIME_FORMAT_FUNCTIONS as BQ_STRFTIME_FORMAT_FUNCTIONS +from ibis.backends.bigquery.registry import ( + STRFTIME_FORMAT_FUNCTIONS as BQ_STRFTIME_FORMAT_FUNCTIONS, +) from ibis.backends.impala.compiler import ImpalaExprTranslator from ibis.backends.mssql.compiler import MsSqlExprTranslator from ibis.backends.mysql.compiler import MySQLExprTranslator from ibis.backends.postgres.compiler import PostgreSQLExprTranslator -from ibis.expr.operations import (Cast, Comparison, HashBytes, IfNull, - RandomScalar, Strftime, StringJoin, - Value, ExtractEpochSeconds) +from ibis.expr.operations import ( + Cast, + Comparison, + HashBytes, + IfNull, + RandomScalar, + Strftime, + StringJoin, + Value, + ExtractEpochSeconds, +) from ibis.expr.types import NumericValue, TemporalValue import 
class ToChar(Value):
    """Operation node pairing a value argument with a string format.

    ``arg`` is validated against any one of the Decimal, float64, Date,
    Time, or Timestamp value rules; ``fmt`` is validated as a string.
    ``output_type`` is declared as ``rlz.shape_like("arg")``.
    """

    # One value rule per accepted input type; ``rlz.one_of`` takes the list.
    arg = rlz.one_of(
        [
            rlz.value(accepted)
            for accepted in (dt.Decimal, dt.float64, dt.Date, dt.Time, dt.Timestamp)
        ]
    )
    fmt = rlz.string
    output_type = rlz.shape_like("arg")
def sa_format_hashbytes_db2(translator, op):
    """Compile a HashBytes operation into a lower-cased hex hash expression.

    Args:
        translator: Expression translator used to compile ``op.arg``.
        op: The operation being compiled; only ``op.arg`` is read.

    Returns:
        A SQLAlchemy function expression of the form
        ``lower(hex(hash(<arg>, 2)))``.
    """
    compiled_arg = translator.translate(op.arg)
    # NOTE(review): algorithm id 2 presumably selects SHA-256, in line with
    # the other hashbytes formatters in this module - confirm against the
    # Db2 HASH documentation.
    hash_expr = sa.func.hash(compiled_arg, sa.sql.literal_column("2"))
    # Named ``hex_expr`` rather than ``hex`` so the builtin is not shadowed.
    hex_expr = sa.func.hex(hash_expr)
    return sa.func.lower(hex_expr)
@dt.dtype.register(bq.schema.SchemaField)
def _bigquery_field_to_ibis_dtype(field):
    """Convert a BigQuery schema ``field`` to an Ibis data type.

    Adapted from ``ibis.backends.bigquery.client`` for issue
    https://github.com/GoogleCloudPlatform/professional-services-data-validator/issues/926
    so that mapped scalar types carry the nullability of the BigQuery field.
    The nullable problem appears to be fixed in the latest Ibis, but we
    cannot upgrade due to Redshift.

    Args:
        field: The ``bq.schema.SchemaField`` to convert.

    Returns:
        A ``dt.Struct`` for ``RECORD`` fields; otherwise the mapped Ibis type
        with ``nullable`` taken from ``field.is_nullable``, or the
        (legacy-to-standard translated) type name unchanged when no mapping
        exists. ``REPEATED`` fields are additionally wrapped in ``dt.Array``.

    Raises:
        AssertionError: If a ``RECORD`` field has no sub-fields.
    """
    typ = field.field_type
    if typ == "RECORD":
        sub_fields = field.fields
        if not sub_fields:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; type and message are unchanged.
            raise AssertionError("RECORD fields are empty")
        # Each sub-field goes back through the ``dt.dtype`` dispatcher that
        # this function is registered on, so nested RECORDs are converted too.
        ibis_type = dt.Struct({el.name: dt.dtype(el) for el in sub_fields})
    else:
        # Translate legacy type names first; unknown names pass through as-is.
        ibis_type = _BQ_LEGACY_TO_STANDARD.get(typ, typ)
        mapped = _BQ_DTYPE_TO_IBIS_TYPE.get(ibis_type)
        if mapped is not None:
            ibis_type = mapped(nullable=field.is_nullable)
    if field.mode == "REPEATED":
        ibis_type = dt.Array(ibis_type)
    return ibis_type