From ae07fa429e7c94991c570231b46a54423b2e56e5 Mon Sep 17 00:00:00 2001 From: Prayas Purusottam Date: Fri, 21 Apr 2023 06:19:04 +0530 Subject: [PATCH] fix: Mysql fix to support row hash validations, random row validation, and filter (#812) --- README.md | 4 +- .../query_builder/random_row_builder.py | 2 + tests/system/data_sources/test_mysql.py | 6 - .../ibis_addon/base_sqlalchemy/alchemy.py | 175 +++++++++++++++++- third_party/ibis/ibis_addon/operations.py | 17 ++ 5 files changed, 195 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 0987b817a..7ae543393 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ perform this task. DVT supports the following validations: * Column validation (count, sum, avg, min, max, group by) -* Row validation (BQ, Hive, Teradata, Oracle, SQL Server, Postgres only) +* Row validation (BQ, Hive, Teradata, Oracle, SQL Server, Postgres, Mysql only) * Schema validation * Custom Query validation * Ad hoc SQL exploration @@ -136,7 +136,7 @@ The [Examples](https://github.com/GoogleCloudPlatform/professional-services-data #### Row Validations -(Note: Row hash validation is currently supported for BigQuery, Teradata, Impala/Hive, Oracle, SQL Server, Postgres, Db2 and Alloy DB. Struct and array data types are not currently supported. +(Note: Row hash validation is currently supported for BigQuery, Teradata, Impala/Hive, Oracle, SQL Server, Postgres, Mysql, Db2 and Alloy DB. Struct and array data types are not currently supported. In addition, please note that SHA256 is not a supported function on Teradata systems. If you wish to perform this comparison on Teradata you will need to [deploy a UDF to perform the conversion](https://github.com/akuroda/teradata-udf-sha2/blob/master/src/sha256.c).) diff --git a/data_validation/query_builder/random_row_builder.py b/data_validation/query_builder/random_row_builder.py index 869b62d57..6e8809e5d 100644 --- a/data_validation/query_builder/random_row_builder.py +++ b/data_validation/query_builder/random_row_builder.py @@ -29,6 +29,7 @@ from ibis.backends.impala.client import ImpalaClient from ibis.backends.pandas.client import PandasClient from ibis.backends.postgres.client import PostgreSQLClient +from ibis.backends.mysql.client import MySQLClient from ibis.expr.signature import Argument as Arg from data_validation import clients from data_validation.query_builder.query_builder import QueryBuilder @@ -53,6 +54,7 @@ PostgreSQLClient: "RANDOM()", clients.MSSQLClient: "NEWID()", clients.DB2Client: "RAND()", + MySQLClient: "RAND()", } diff --git a/tests/system/data_sources/test_mysql.py b/tests/system/data_sources/test_mysql.py index 995663aa9..f39f0cbe2 100644 --- a/tests/system/data_sources/test_mysql.py +++ b/tests/system/data_sources/test_mysql.py @@ -95,9 +95,6 @@ def test_schema_validation(): def test_mysql_row(): """Test row validation on MySQL""" - # This test is disabled. - # When issue-776 is resolved we can remove these comments and the return statement below. - return try: config_row_valid = { consts.CONFIG_SOURCE_CONN: CONN, @@ -487,9 +484,6 @@ def test_column_validation_core_types(mock_conn): return_value=CONN, ) def test_row_validation_core_types(mock_conn): - # This test is disabled. - # When issue-776 is resolved we can remove these comments and the return statement below. - return parser = cli_tools.configure_arg_parser() args = parser.parse_args( [ diff --git a/third_party/ibis/ibis_addon/base_sqlalchemy/alchemy.py b/third_party/ibis/ibis_addon/base_sqlalchemy/alchemy.py index 665bb2263..ff5e40686 100644 --- a/third_party/ibis/ibis_addon/base_sqlalchemy/alchemy.py +++ b/third_party/ibis/ibis_addon/base_sqlalchemy/alchemy.py @@ -21,6 +21,72 @@ import ibis.expr.schema as sch import ibis.expr.operations as ops +import re +import ibis.expr.datatypes as dt +from ibis.backends.base_sqlalchemy.alchemy import AlchemyClient + +from typing import Iterable + +_type_codes = { + 0: "DECIMAL", + 1: "TINY", + 2: "SHORT", + 3: "LONG", + 4: "FLOAT", + 5: "DOUBLE", + 6: "NULL", + 7: "TIMESTAMP", + 8: "LONGLONG", + 9: "INT24", + 10: "DATE", + 11: "TIME", + 12: "DATETIME", + 13: "YEAR", + 15: "VARCHAR", + 16: "BIT", + 245: "JSON", + 246: "NEWDECIMAL", + 247: "ENUM", + 248: "SET", + 249: "TINY_BLOB", + 250: "MEDIUM_BLOB", + 251: "LONG_BLOB", + 252: "BLOB", + 253: "VAR_STRING", + 254: "STRING", + 255: "GEOMETRY", +} + +_type_mapping = { + "DECIMAL": dt.Decimal, + "TINY": dt.Int8, + "SHORT": dt.Int16, + "LONG": dt.Int32, + "FLOAT": dt.Float32, + "DOUBLE": dt.Float64, + "NULL": dt.Null, + "TIMESTAMP": lambda nullable: dt.Timestamp(timezone="UTC", nullable=nullable), + "LONGLONG": dt.Int64, + "INT24": dt.Int32, + "DATE": dt.Date, + "TIME": dt.Time, + "DATETIME": dt.Timestamp, + "YEAR": dt.Int8, + "VARCHAR": dt.String, + "JSON": dt.JSON, + "NEWDECIMAL": dt.Decimal, + "ENUM": dt.String, + "SET": lambda nullable: dt.Set(dt.string, nullable=nullable), + "TINY_BLOB": dt.Binary, + "MEDIUM_BLOB": dt.Binary, + "LONG_BLOB": dt.Binary, + "BLOB": dt.Binary, + "VAR_STRING": dt.String, + "STRING": dt.String, + "GEOMETRY": dt.Geometry, +} +MY_CHARSET_BIN = 63 + def _schema_to_sqlalchemy_columns(schema: sch.Schema): return [sa.column(n, _to_sqla_type(t)) for n, t in schema.items()] @@ -75,5 +141,112 @@ def _format_table_new(self, expr): ctx.set_table(expr, result) return result +class _FieldFlags: + """Flags used to disambiguate field types. + Gaps in the flag numbers are because we do not map in flags that are + of no use in determining the field's type, such as whether the field + is a primary key or not. + """ + + UNSIGNED = 1 << 5 + SET = 1 << 11 + NUM = 1 << 15 + + __slots__ = ("value",) + + def __init__(self, value: int) -> None: + self.value = value + + @property + def is_unsigned(self) -> bool: + return (self.UNSIGNED & self.value) != 0 + + @property + def is_set(self) -> bool: + return (self.SET & self.value) != 0 + + @property + def is_num(self) -> bool: + return (self.NUM & self.value) != 0 + + +def _decimal_length_to_precision(*, length: int, scale: int, is_unsigned: bool) -> int: + return length - (scale > 0) - (not (is_unsigned or not length)) + + +def _type_from_cursor_info(descr, field) -> dt.DataType: + """Construct an ibis type from MySQL field descr and field result metadata. + This method is complex because the MySQL protocol is complex. + Types are not encoded in a self contained way, meaning you need + multiple pieces of information coming from the result set metadata to + determine the most precise type for a field. Even then, the decoding is + not high fidelity in some cases: UUIDs for example are decoded as + strings, because the protocol does not appear to preserve the logical + type, only the physical type. + """ + from pymysql.connections import TEXT_TYPES + + _, type_code, _, _, field_length, scale, _ = descr + flags = _FieldFlags(field.flags) + typename = _type_codes.get(type_code) + if typename is None: + raise NotImplementedError(f"MySQL type code {type_code:d} is not supported") + + if typename in ("DECIMAL", "NEWDECIMAL"): + precision = _decimal_length_to_precision( + length=field_length, + scale=scale, + is_unsigned=flags.is_unsigned, + ) + typ = partial(_type_mapping[typename], precision=precision, scale=scale) + elif typename == "BIT": + if field_length <= 8: + typ = dt.int8 + elif field_length <= 16: + typ = dt.int16 + elif field_length <= 32: + typ = dt.int32 + elif field_length <= 64: + typ = dt.int64 + else: + raise AssertionError('invalid field length for BIT type') + elif flags.is_set: + # sets are limited to strings + typ = dt.Set(dt.string) + elif flags.is_unsigned and flags.is_num: + typ = getattr(dt, f"U{typ.__name__}") + elif type_code in TEXT_TYPES: + # binary text + if field.charsetnr == MY_CHARSET_BIN: + typ = dt.Binary + else: + typ = dt.String + else: + typ = _type_mapping[typename] + + # projection columns are always nullable + return typ(nullable=True) + + +def _metadata(self, query: str): + if ( + re.search(r"^\s*SELECT\s", query, flags=re.MULTILINE | re.IGNORECASE) + is not None + ): + query = f"({query})" + + with self.begin() as con: + result = con.execute(f"SELECT * FROM {query} _ LIMIT 0") + cursor = result.cursor + yield from ( + (field.name, _type_from_cursor_info(descr, field)) + for descr, field in zip(cursor.description, cursor._result.fields) + ) + +def _get_schema_using_query(self, query: str) -> sch.Schema: + """Return an ibis Schema from a backend-specific SQL string.""" + return sch.Schema.from_tuples(self._metadata(query)) -_AlchemyTableSet._format_table = _format_table_new +AlchemyClient._get_schema_using_query = _get_schema_using_query +AlchemyClient._metadata = _metadata +_AlchemyTableSet._format_table = _format_table_new \ No newline at end of file diff --git a/third_party/ibis/ibis_addon/operations.py b/third_party/ibis/ibis_addon/operations.py index 768673e71..2dd4cce54 100644 --- a/third_party/ibis/ibis_addon/operations.py +++ b/third_party/ibis/ibis_addon/operations.py @@ -29,6 +29,7 @@ import ibis.expr.api import ibis.expr.datatypes as dt import ibis.expr.rules as rlz +import ibis.expr.operations as ops from data_validation.clients import _raise_missing_client_error from ibis_bigquery.compiler import ( @@ -36,6 +37,7 @@ BigQueryExprTranslator, STRFTIME_FORMAT_FUNCTIONS as BQ_STRFTIME_FORMAT_FUNCTIONS ) +from ibis.backends.base_sqlalchemy.alchemy import fixed_arity from ibis.expr.operations import Arg, Cast, Comparison, Reduction, Strftime, ValueOp from ibis.expr.types import BinaryValue, IntegerColumn, StringValue, NumericValue, TemporalValue from ibis.backends.impala.compiler import ImpalaExprTranslator @@ -47,6 +49,7 @@ from third_party.ibis.ibis_teradata.compiler import TeradataExprTranslator from third_party.ibis.ibis_mssql.compiler import MSSQLExprTranslator from ibis.backends.postgres.compiler import PostgreSQLExprTranslator +from ibis.backends.mysql.compiler import MySQLExprTranslator # avoid errors if Db2 is not installed and not needed try: @@ -220,6 +223,12 @@ def sa_format_hashbytes_oracle(translator, expr): hash_func = sa.func.standard_hash(compiled_arg, sa.sql.literal_column("'SHA256'")) return sa.func.lower(hash_func) +def sa_format_hashbytes_mysql(translator, expr): + arg, how = expr.op().args + compiled_arg = translator.translate(arg) + hash_func = sa.func.sha2(compiled_arg, sa.sql.literal_column("'256'")) + return hash_func + def sa_format_hashbytes_db2(translator, expr): arg, how = expr.op().args compiled_arg = translator.translate(arg) @@ -240,6 +249,10 @@ def sa_format_to_char(translator, expr): compiled_fmt = translator.translate(fmt) return sa.func.to_char(compiled_arg, compiled_fmt) +def sa_format_to_stringjoin(translator, expr): + sep, elements = expr.op().args + return sa.func.concat_ws(translator.translate(sep), *map(translator.translate, elements)) + def sa_cast_postgres(t, expr): arg, typ = expr.op().args @@ -306,6 +319,10 @@ def sa_cast_postgres(t, expr): PostgreSQLExprTranslator._registry[RawSQL] = sa_format_raw_sql PostgreSQLExprTranslator._registry[ToChar] = sa_format_to_char PostgreSQLExprTranslator._registry[Cast] = sa_cast_postgres +MySQLExprTranslator._registry[RawSQL] = sa_format_raw_sql +MySQLExprTranslator._registry[HashBytes] = sa_format_hashbytes_mysql +MySQLExprTranslator._registry[ops.IfNull] = fixed_arity(sa.func.ifnull, 2) +MySQLExprTranslator._registry[ops.StringJoin] = sa_format_to_stringjoin if DB2ExprTranslator: #check if Db2 driver is loaded DB2ExprTranslator._registry[HashBytes] = sa_format_hashbytes_db2