Skip to content

Commit

Permalink
fix: BigQuery now sets nullable correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
nj1973 committed Aug 7, 2023
1 parent dc9fb56 commit 773da9c
Showing 1 changed file with 96 additions and 21 deletions.
117 changes: 96 additions & 21 deletions third_party/ibis/ibis_addon/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,39 @@
Ibis as an override, though it would not apply for Pandas and other
non-textual languages.
"""
import google.cloud.bigquery as bq
import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.rules as rlz
import sqlalchemy as sa
from ibis.backends.base.sql.alchemy.registry import \
fixed_arity as sa_fixed_arity
from ibis.backends.base.sql.alchemy.registry import fixed_arity as sa_fixed_arity
from ibis.backends.base.sql.alchemy.translator import AlchemyExprTranslator
from ibis.backends.base.sql.compiler.translator import ExprTranslator
from ibis.backends.base.sql.registry import fixed_arity
from ibis.backends.bigquery.client import _DTYPE_TO_IBIS_TYPE
from ibis.backends.bigquery.client import (
_DTYPE_TO_IBIS_TYPE as _BQ_DTYPE_TO_IBIS_TYPE,
_LEGACY_TO_STANDARD as _BQ_LEGACY_TO_STANDARD,
)
from ibis.backends.bigquery.compiler import BigQueryExprTranslator
from ibis.backends.bigquery.registry import \
STRFTIME_FORMAT_FUNCTIONS as BQ_STRFTIME_FORMAT_FUNCTIONS
from ibis.backends.bigquery.registry import (
STRFTIME_FORMAT_FUNCTIONS as BQ_STRFTIME_FORMAT_FUNCTIONS,
)
from ibis.backends.impala.compiler import ImpalaExprTranslator
from ibis.backends.mssql.compiler import MsSqlExprTranslator
from ibis.backends.mysql.compiler import MySQLExprTranslator
from ibis.backends.postgres.compiler import PostgreSQLExprTranslator
from ibis.expr.operations import (Cast, Comparison, HashBytes, IfNull,
RandomScalar, Strftime, StringJoin,
Value, ExtractEpochSeconds)
from ibis.expr.operations import (
Cast,
Comparison,
HashBytes,
IfNull,
RandomScalar,
Strftime,
StringJoin,
Value,
ExtractEpochSeconds,
)
from ibis.expr.types import NumericValue, TemporalValue

import third_party.ibis.ibis_mysql.compiler
Expand All @@ -63,11 +75,21 @@
except Exception:
SnowflakeExprTranslator = None


class ToChar(Value):
    """Ibis value operation representing SQL ``TO_CHAR(arg, fmt)``.

    ``arg`` accepts the numeric and temporal types listed below; ``fmt`` is the
    backend format string. Compiled per-backend (see ``sa_format_to_char``).
    """

    # Removed a stale duplicate single-line `arg = rlz.one_of([...])`
    # assignment left over from the diff; only this definition remains.
    arg = rlz.one_of(
        [
            rlz.value(dt.Decimal),
            rlz.value(dt.float64),
            rlz.value(dt.Date),
            rlz.value(dt.Time),
            rlz.value(dt.Timestamp),
        ]
    )
    fmt = rlz.string
    # Result has the same shape (scalar/column) as the input argument.
    output_type = rlz.shape_like("arg")


class RawSQL(Comparison):
    # Marker operation that carries a raw SQL string through Ibis so it can
    # be emitted verbatim by the per-backend `format_raw_sql` /
    # `sa_format_raw_sql` handlers. Subclasses Comparison only to satisfy
    # Ibis's operation machinery; it adds no behavior of its own.
    pass

Expand Down Expand Up @@ -126,6 +148,7 @@ def strftime_bigquery(translator, op):
strftime_format_func_name, fmt_string, arg_formatted
)


def strftime_mysql(translator, op):
arg = op.arg
format_string = op.format_str
Expand All @@ -136,6 +159,7 @@ def strftime_mysql(translator, op):
fmt_string = "%Y-%m-%d %H:%i:%S"
return sa.func.date_format(arg_formatted, fmt_string)


def strftime_mssql(translator, op):
"""Use MS SQL CONVERT() in place of STRFTIME().
Expand All @@ -144,17 +168,17 @@ def strftime_mssql(translator, op):
to string in order to complete row data comparison."""
arg, pattern = map(translator.translate, op.args)
supported_convert_styles = {
"%Y-%m-%d": 23, # ISO8601
"%Y-%m-%d %H:%M:%S": 20, # ODBC canonical
"%Y-%m-%d %H:%M:%S.%f": 21, # ODBC canonical (with milliseconds)
"%Y-%m-%d": 23, # ISO8601
"%Y-%m-%d %H:%M:%S": 20, # ODBC canonical
"%Y-%m-%d %H:%M:%S.%f": 21, # ODBC canonical (with milliseconds)
}
try:
convert_style = supported_convert_styles[pattern.value]
except KeyError:
raise NotImplementedError(
f'strftime format {pattern.value} not supported for SQL Server.'
f"strftime format {pattern.value} not supported for SQL Server."
)
result = sa.func.convert(sa.text('VARCHAR(32)'), arg, convert_style)
result = sa.func.convert(sa.text("VARCHAR(32)"), arg, convert_style)
return result


Expand All @@ -179,6 +203,7 @@ def format_hashbytes_hive(translator, op):
else:
raise ValueError(f"unexpected value for 'how': {op.how}")


def format_hashbytes_alchemy(translator, op):
arg = translator.translate(op.arg)
if op.how == "sha256":
Expand All @@ -188,10 +213,12 @@ def format_hashbytes_alchemy(translator, op):
else:
raise ValueError(f"unexpected value for 'how': {op.how}")


def format_hashbytes_base(translator, op):
    """Render a backend-generic SHA-256 expression for ``op.arg``.

    Fallback used by string-compiling translators that expose a
    ``sha2(expr, bits)`` function.
    """
    compiled = translator.translate(op.arg)
    return "sha2({}, 256)".format(compiled)


def compile_raw_sql(table, sql):
    """Wrap ``sql`` in a RawSQL operation anchored to ``table``.

    The first column of ``table`` (cast to string) is used only as an anchor
    expression; the raw SQL text travels as an Ibis literal.
    """
    anchor = table[table.columns[0]].cast(dt.string)
    return RawSQL(anchor, ibis.literal(sql)).to_expr()
Expand All @@ -206,48 +233,59 @@ def sa_format_raw_sql(translator, op):
rand_col, raw_sql = op.args
return sa.text(raw_sql.args[0])


def sa_format_hashbytes_mssql(translator, op):
    """SQL Server SHA-256 digest of ``op.arg`` as a lower-case hex string.

    Casts the argument to VARCHAR(MAX), hashes with HASHBYTES('SHA2_256'),
    then converts the binary digest to CHAR(64) using style 2 (hex digits
    without the 0x prefix — TODO confirm against SQL Server CONVERT docs).
    """
    # Removed a stale duplicate single-line `hash_to_string = ...` left over
    # from the diff; only the multi-line form below remains.
    arg = translator.translate(op.arg)
    cast_arg = sa.func.convert(sa.sql.literal_column("VARCHAR(MAX)"), arg)
    hash_func = sa.func.hashbytes(sa.sql.literal_column("'SHA2_256'"), cast_arg)
    hash_to_string = sa.func.convert(
        sa.sql.literal_column("CHAR(64)"), hash_func, sa.sql.literal_column("2")
    )
    return sa.func.lower(hash_to_string)


def sa_format_hashbytes_oracle(translator, op):
    """Oracle SHA-256 digest of ``op.arg``, lower-cased.

    Converts the argument to UTF8 before hashing with STANDARD_HASH so the
    bytes fed to the hash are encoding-stable.
    """
    compiled = translator.translate(op.arg)
    utf8_arg = sa.func.convert(compiled, sa.sql.literal_column("'UTF8'"))
    digest = sa.func.standard_hash(utf8_arg, sa.sql.literal_column("'SHA256'"))
    return sa.func.lower(digest)


def sa_format_hashbytes_mysql(translator, op):
    """MySQL SHA-256 digest of ``op.arg`` via SHA2(expr, '256')."""
    compiled = translator.translate(op.arg)
    return sa.func.sha2(compiled, sa.sql.literal_column("'256'"))


def sa_format_hashbytes_db2(translator, op):
    """DB2 SHA-256 digest of ``op.arg`` as a lower-case hex string.

    Uses HASH(expr, 2) — algorithm code 2 selects SHA-256 per the DB2
    HASH scalar function (TODO confirm against Db2 docs) — then HEX()
    and LOWER() to normalize the output.
    """
    # Removed a stale duplicate `hashfunc = ...` line (missing a space after
    # the comma) left over from the diff; the formatted line below remains.
    compiled_arg = translator.translate(op.arg)
    hashfunc = sa.func.hash(compiled_arg, sa.sql.literal_column("2"))
    # Renamed from `hex`, which shadowed the builtin.
    hex_digest = sa.func.hex(hashfunc)
    return sa.func.lower(hex_digest)


def sa_format_hashbytes_redshift(translator, op):
    """Redshift SHA-256 digest of ``op.arg``.

    Emitted as a literal column because the sha2() call is spliced into the
    SQL text directly rather than built from SQLAlchemy functions.
    """
    compiled = translator.translate(op.arg)
    return sa.sql.literal_column("sha2({}, 256)".format(compiled))


def sa_format_hashbytes_postgres(translator, op):
    """PostgreSQL SHA-256 digest of ``op.arg``, hex-encoded.

    Converts the text to UTF8 bytes first so the hash input is
    encoding-stable, then hex-encodes the binary digest.
    """
    compiled = translator.translate(op.arg)
    utf8_bytes = sa.func.convert_to(compiled, sa.sql.literal_column("'UTF8'"))
    digest = sa.func.sha256(utf8_bytes)
    return sa.func.encode(digest, sa.sql.literal_column("'hex'"))


def sa_format_hashbytes_snowflake(translator, op):
    """Snowflake SHA2 digest of ``op.arg`` (SHA2 defaults to 256 bits)."""
    return sa.func.sha2(translator.translate(op.arg))


def sa_epoch_time_snowflake(translator, op):
    """Snowflake epoch seconds of ``op.arg`` via DATE_PART(epoch_seconds, ...)."""
    compiled = translator.translate(op.arg)
    return sa.func.date_part(sa.sql.literal_column("epoch_seconds"), compiled)


def sa_format_to_char(translator, op):
arg = translator.translate(op.arg)
fmt = translator.translate(op.fmt)
Expand All @@ -263,7 +301,12 @@ def sa_cast_postgres(t, op):
sa_arg = t.translate(arg)

# Specialize going from numeric(p,s>0) to string
if arg_dtype.is_decimal() and arg_dtype.scale and arg_dtype.scale > 0 and typ.is_string():
if (
arg_dtype.is_decimal()
and arg_dtype.scale
and arg_dtype.scale > 0
and typ.is_string()
):
# When casting a number to string PostgreSQL includes the full scale, e.g.:
# SELECT CAST(CAST(100 AS DECIMAL(5,2)) AS VARCHAR(10));
# 100.00
Expand All @@ -273,15 +316,17 @@ def sa_cast_postgres(t, op):
# Would have liked to use trim_scale but this is only available in PostgreSQL 13+
# return (sa.cast(sa.func.trim_scale(arg), typ))
precision = arg_dtype.precision or 38
fmt = "FM" + ("9" * (precision - arg_dtype.scale)) + "." + ("9" * arg_dtype.scale)
fmt = (
"FM" + ("9" * (precision - arg_dtype.scale)) + "." + ("9" * arg_dtype.scale)
)
return sa.func.rtrim(sa.func.to_char(sa_arg, fmt), ".")

# specialize going from an integer type to a timestamp
if arg_dtype.is_integer() and typ.is_timestamp():
return t.integer_to_timestamp(sa_arg, tz=typ.timezone)

if arg_dtype.is_binary() and typ.is_string():
return sa.func.encode(sa_arg, 'escape')
return sa.func.encode(sa_arg, "escape")

if typ.is_binary():
# decode yields a column of memoryview which is annoying to deal with
Expand All @@ -293,19 +338,49 @@ def sa_cast_postgres(t, op):

return sa.cast(sa_arg, t.get_sqla_type(typ))


def _sa_string_join(t, op):
    """Join the translated string parts of ``op.arg`` with SQL CONCAT()."""
    translated = [t.translate(part) for part in op.arg]
    return sa.func.concat(*translated)


def sa_format_new_id(t, op):
    # Map Ibis RandomScalar to SQL Server NEWID(); registered below for the
    # MSSQL translator. NEWID() yields a new uniqueidentifier per row.
    return sa.func.NEWID()


_BQ_DTYPE_TO_IBIS_TYPE["TIMESTAMP"] = dt.Timestamp(timezone="UTC")


@dt.dtype.register(bq.schema.SchemaField)
def _bigquery_field_to_ibis_dtype(field):
    """Convert a BigQuery ``SchemaField`` to an Ibis datatype.

    Adapted from ibis.backends.bigquery.client for issue:
    https://github.com/GoogleCloudPlatform/professional-services-data-validator/issues/926
    Unlike the upstream version it propagates ``field.is_nullable`` onto the
    resulting Ibis type. The nullable problem appears to be fixed in the
    latest Ibis but we cannot upgrade due to Redshift.
    """
    typ = field.field_type
    if typ == "RECORD":
        # STRUCT column: recurse into the nested fields.
        fields = field.fields
        assert fields, "RECORD fields are empty"
        names = [el.name for el in fields]
        ibis_types = list(map(dt.dtype, fields))
        ibis_type = dt.Struct(dict(zip(names, ibis_types)))
    else:
        # Normalize legacy type names (e.g. FLOAT -> FLOAT64) before lookup.
        ibis_type = _BQ_LEGACY_TO_STANDARD.get(typ, typ)
        if ibis_type in _BQ_DTYPE_TO_IBIS_TYPE:
            # Apply BigQuery nullability to the mapped Ibis type; unknown
            # type names fall through unchanged. (Removed a no-op
            # `else: ibis_type = ibis_type` branch.)
            ibis_type = _BQ_DTYPE_TO_IBIS_TYPE[ibis_type](nullable=field.is_nullable)
    if field.mode == "REPEATED":
        ibis_type = dt.Array(ibis_type)
    return ibis_type


# Attach TO_CHAR compilation to numeric and temporal expressions.
NumericValue.to_char = compile_to_char
TemporalValue.to_char = compile_to_char

# BigQuery translator registrations.
BigQueryExprTranslator._registry[HashBytes] = format_hashbytes_bigquery
BigQueryExprTranslator._registry[RawSQL] = format_raw_sql
BigQueryExprTranslator._registry[Strftime] = strftime_bigquery
# Removed stale `_DTYPE_TO_IBIS_TYPE["TIMESTAMP"] = ...` line: that import
# alias was renamed to _BQ_DTYPE_TO_IBIS_TYPE (the old name is undefined
# here and would raise NameError), and the UTC override is already applied
# above under the new name.

# Default registrations shared by all SQLAlchemy-based backend translators;
# specific backends below override these where they need dialect SQL.
AlchemyExprTranslator._registry[RawSQL] = format_raw_sql
AlchemyExprTranslator._registry[HashBytes] = format_hashbytes_alchemy
Expand All @@ -328,7 +403,7 @@ def sa_format_new_id(t, op):

# SQL Server translator registrations.
# Removed a stale duplicate `IfNull` registration line (the unformatted
# `sa_fixed_arity(sa.func.isnull,2)` variant) left over from the diff.
MsSqlExprTranslator._registry[HashBytes] = sa_format_hashbytes_mssql
MsSqlExprTranslator._registry[RawSQL] = sa_format_raw_sql
MsSqlExprTranslator._registry[IfNull] = sa_fixed_arity(sa.func.isnull, 2)
MsSqlExprTranslator._registry[StringJoin] = _sa_string_join
MsSqlExprTranslator._registry[RandomScalar] = sa_format_new_id
MsSqlExprTranslator._registry[Strftime] = strftime_mssql
Expand Down

0 comments on commit 773da9c

Please sign in to comment.