Skip to content

Commit

Permalink
fix: Mysql fix to support row hash validations, random row validation…
Browse files Browse the repository at this point in the history
…, and filter (#812)
  • Loading branch information
kanhaPrayas committed Apr 21, 2023
1 parent 056275b commit ae07fa4
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 9 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ perform this task.

DVT supports the following validations:
* Column validation (count, sum, avg, min, max, group by)
* Row validation (BQ, Hive, Teradata, Oracle, SQL Server, Postgres only)
* Row validation (BQ, Hive, Teradata, Oracle, SQL Server, Postgres, Mysql only)
* Schema validation
* Custom Query validation
* Ad hoc SQL exploration
Expand Down Expand Up @@ -136,7 +136,7 @@ The [Examples](https://github.com/GoogleCloudPlatform/professional-services-data

#### Row Validations

(Note: Row hash validation is currently supported for BigQuery, Teradata, Impala/Hive, Oracle, SQL Server, Postgres, Db2 and Alloy DB. Struct and array data types are not currently supported.
(Note: Row hash validation is currently supported for BigQuery, Teradata, Impala/Hive, Oracle, SQL Server, Postgres, Mysql, Db2 and Alloy DB. Struct and array data types are not currently supported.
In addition, please note that SHA256 is not a supported function on Teradata systems.
If you wish to perform this comparison on Teradata you will need to
[deploy a UDF to perform the conversion](https://github.com/akuroda/teradata-udf-sha2/blob/master/src/sha256.c).)
Expand Down
2 changes: 2 additions & 0 deletions data_validation/query_builder/random_row_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ibis.backends.impala.client import ImpalaClient
from ibis.backends.pandas.client import PandasClient
from ibis.backends.postgres.client import PostgreSQLClient
from ibis.backends.mysql.client import MySQLClient
from ibis.expr.signature import Argument as Arg
from data_validation import clients
from data_validation.query_builder.query_builder import QueryBuilder
Expand All @@ -53,6 +54,7 @@
PostgreSQLClient: "RANDOM()",
clients.MSSQLClient: "NEWID()",
clients.DB2Client: "RAND()",
MySQLClient: "RAND()",
}


Expand Down
6 changes: 0 additions & 6 deletions tests/system/data_sources/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,6 @@ def test_schema_validation():

def test_mysql_row():
"""Test row validation on MySQL"""
# This test is disabled.
# When issue-776 is resolved we can remove these comments and the return statement below.
return
try:
config_row_valid = {
consts.CONFIG_SOURCE_CONN: CONN,
Expand Down Expand Up @@ -487,9 +484,6 @@ def test_column_validation_core_types(mock_conn):
return_value=CONN,
)
def test_row_validation_core_types(mock_conn):
# This test is disabled.
# When issue-776 is resolved we can remove these comments and the return statement below.
return
parser = cli_tools.configure_arg_parser()
args = parser.parse_args(
[
Expand Down
175 changes: 174 additions & 1 deletion third_party/ibis/ibis_addon/base_sqlalchemy/alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,72 @@
import ibis.expr.schema as sch
import ibis.expr.operations as ops

import re
import ibis.expr.datatypes as dt
from ibis.backends.base_sqlalchemy.alchemy import AlchemyClient

from typing import Iterable

_type_codes = {
0: "DECIMAL",
1: "TINY",
2: "SHORT",
3: "LONG",
4: "FLOAT",
5: "DOUBLE",
6: "NULL",
7: "TIMESTAMP",
8: "LONGLONG",
9: "INT24",
10: "DATE",
11: "TIME",
12: "DATETIME",
13: "YEAR",
15: "VARCHAR",
16: "BIT",
245: "JSON",
246: "NEWDECIMAL",
247: "ENUM",
248: "SET",
249: "TINY_BLOB",
250: "MEDIUM_BLOB",
251: "LONG_BLOB",
252: "BLOB",
253: "VAR_STRING",
254: "STRING",
255: "GEOMETRY",
}

_type_mapping = {
"DECIMAL": dt.Decimal,
"TINY": dt.Int8,
"SHORT": dt.Int16,
"LONG": dt.Int32,
"FLOAT": dt.Float32,
"DOUBLE": dt.Float64,
"NULL": dt.Null,
"TIMESTAMP": lambda nullable: dt.Timestamp(timezone="UTC", nullable=nullable),
"LONGLONG": dt.Int64,
"INT24": dt.Int32,
"DATE": dt.Date,
"TIME": dt.Time,
"DATETIME": dt.Timestamp,
"YEAR": dt.Int8,
"VARCHAR": dt.String,
"JSON": dt.JSON,
"NEWDECIMAL": dt.Decimal,
"ENUM": dt.String,
"SET": lambda nullable: dt.Set(dt.string, nullable=nullable),
"TINY_BLOB": dt.Binary,
"MEDIUM_BLOB": dt.Binary,
"LONG_BLOB": dt.Binary,
"BLOB": dt.Binary,
"VAR_STRING": dt.String,
"STRING": dt.String,
"GEOMETRY": dt.Geometry,
}
MY_CHARSET_BIN = 63


def _schema_to_sqlalchemy_columns(schema: sch.Schema):
return [sa.column(n, _to_sqla_type(t)) for n, t in schema.items()]
Expand Down Expand Up @@ -75,5 +141,112 @@ def _format_table_new(self, expr):
ctx.set_table(expr, result)
return result

class _FieldFlags:
"""Flags used to disambiguate field types.
Gaps in the flag numbers are because we do not map in flags that are
of no use in determining the field's type, such as whether the field
is a primary key or not.
"""

UNSIGNED = 1 << 5
SET = 1 << 11
NUM = 1 << 15

__slots__ = ("value",)

def __init__(self, value: int) -> None:
self.value = value

@property
def is_unsigned(self) -> bool:
return (self.UNSIGNED & self.value) != 0

@property
def is_set(self) -> bool:
return (self.SET & self.value) != 0

@property
def is_num(self) -> bool:
return (self.NUM & self.value) != 0


def _decimal_length_to_precision(*, length: int, scale: int, is_unsigned: bool) -> int:
return length - (scale > 0) - (not (is_unsigned or not length))


def _type_from_cursor_info(descr, field) -> dt.DataType:
"""Construct an ibis type from MySQL field descr and field result metadata.
This method is complex because the MySQL protocol is complex.
Types are not encoded in a self contained way, meaning you need
multiple pieces of information coming from the result set metadata to
determine the most precise type for a field. Even then, the decoding is
not high fidelity in some cases: UUIDs for example are decoded as
strings, because the protocol does not appear to preserve the logical
type, only the physical type.
"""
from pymysql.connections import TEXT_TYPES

_, type_code, _, _, field_length, scale, _ = descr
flags = _FieldFlags(field.flags)
typename = _type_codes.get(type_code)
if typename is None:
raise NotImplementedError(f"MySQL type code {type_code:d} is not supported")

if typename in ("DECIMAL", "NEWDECIMAL"):
precision = _decimal_length_to_precision(
length=field_length,
scale=scale,
is_unsigned=flags.is_unsigned,
)
typ = partial(_type_mapping[typename], precision=precision, scale=scale)
elif typename == "BIT":
if field_length <= 8:
typ = dt.int8
elif field_length <= 16:
typ = dt.int16
elif field_length <= 32:
typ = dt.int32
elif field_length <= 64:
typ = dt.int64
else:
raise AssertionError('invalid field length for BIT type')
elif flags.is_set:
# sets are limited to strings
typ = dt.Set(dt.string)
elif flags.is_unsigned and flags.is_num:
typ = getattr(dt, f"U{typ.__name__}")
elif type_code in TEXT_TYPES:
# binary text
if field.charsetnr == MY_CHARSET_BIN:
typ = dt.Binary
else:
typ = dt.String
else:
typ = _type_mapping[typename]

# projection columns are always nullable
return typ(nullable=True)


def _metadata(self, query: str):
if (
re.search(r"^\s*SELECT\s", query, flags=re.MULTILINE | re.IGNORECASE)
is not None
):
query = f"({query})"

with self.begin() as con:
result = con.execute(f"SELECT * FROM {query} _ LIMIT 0")
cursor = result.cursor
yield from (
(field.name, _type_from_cursor_info(descr, field))
for descr, field in zip(cursor.description, cursor._result.fields)
)

def _get_schema_using_query(self, query: str) -> sch.Schema:
"""Return an ibis Schema from a backend-specific SQL string."""
return sch.Schema.from_tuples(self._metadata(query))

_AlchemyTableSet._format_table = _format_table_new
AlchemyClient._get_schema_using_query = _get_schema_using_query
AlchemyClient._metadata = _metadata
_AlchemyTableSet._format_table = _format_table_new
17 changes: 17 additions & 0 deletions third_party/ibis/ibis_addon/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@
import ibis.expr.api
import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
import ibis.expr.operations as ops

from data_validation.clients import _raise_missing_client_error
from ibis_bigquery.compiler import (
reduction as bq_reduction,
BigQueryExprTranslator,
STRFTIME_FORMAT_FUNCTIONS as BQ_STRFTIME_FORMAT_FUNCTIONS
)
from ibis.backends.base_sqlalchemy.alchemy import fixed_arity
from ibis.expr.operations import Arg, Cast, Comparison, Reduction, Strftime, ValueOp
from ibis.expr.types import BinaryValue, IntegerColumn, StringValue, NumericValue, TemporalValue
from ibis.backends.impala.compiler import ImpalaExprTranslator
Expand All @@ -47,6 +49,7 @@
from third_party.ibis.ibis_teradata.compiler import TeradataExprTranslator
from third_party.ibis.ibis_mssql.compiler import MSSQLExprTranslator
from ibis.backends.postgres.compiler import PostgreSQLExprTranslator
from ibis.backends.mysql.compiler import MySQLExprTranslator

# avoid errors if Db2 is not installed and not needed
try:
Expand Down Expand Up @@ -220,6 +223,12 @@ def sa_format_hashbytes_oracle(translator, expr):
hash_func = sa.func.standard_hash(compiled_arg, sa.sql.literal_column("'SHA256'"))
return sa.func.lower(hash_func)

def sa_format_hashbytes_mysql(translator, expr):
arg, how = expr.op().args
compiled_arg = translator.translate(arg)
hash_func = sa.func.sha2(compiled_arg, sa.sql.literal_column("'256'"))
return hash_func

def sa_format_hashbytes_db2(translator, expr):
arg, how = expr.op().args
compiled_arg = translator.translate(arg)
Expand All @@ -240,6 +249,10 @@ def sa_format_to_char(translator, expr):
compiled_fmt = translator.translate(fmt)
return sa.func.to_char(compiled_arg, compiled_fmt)

def sa_format_to_stringjoin(translator, expr):
sep, elements = expr.op().args
return sa.func.concat_ws(translator.translate(sep), *map(translator.translate, elements))


def sa_cast_postgres(t, expr):
arg, typ = expr.op().args
Expand Down Expand Up @@ -306,6 +319,10 @@ def sa_cast_postgres(t, expr):
PostgreSQLExprTranslator._registry[RawSQL] = sa_format_raw_sql
PostgreSQLExprTranslator._registry[ToChar] = sa_format_to_char
PostgreSQLExprTranslator._registry[Cast] = sa_cast_postgres
MySQLExprTranslator._registry[RawSQL] = sa_format_raw_sql
MySQLExprTranslator._registry[HashBytes] = sa_format_hashbytes_mysql
MySQLExprTranslator._registry[ops.IfNull] = fixed_arity(sa.func.ifnull, 2)
MySQLExprTranslator._registry[ops.StringJoin] = sa_format_to_stringjoin

if DB2ExprTranslator: #check if Db2 driver is loaded
DB2ExprTranslator._registry[HashBytes] = sa_format_hashbytes_db2

0 comments on commit ae07fa4

Please sign in to comment.