fix: spanner hash function to return string instead of bytes (#1062)
nehanene15 committed Nov 30, 2023
1 parent dd62baa commit 722dff9
Showing 2 changed files with 150 additions and 123 deletions.
271 changes: 148 additions & 123 deletions tests/system/data_sources/test_spanner.py
@@ -21,65 +21,45 @@
 from data_validation import cli_tools, consts, data_validation
 from data_validation import __main__ as main
-from data_validation import state_manager
-from third_party.ibis.ibis_cloud_spanner.tests import conftest
+
+from tests.system.data_sources.test_bigquery import BQ_CONN
 
 
 PROJECT_ID = os.environ["PROJECT_ID"]
-SPANNER_CONN_NAME = "spanner-integration-test"
 CLI_FIND_TABLES_ARGS = [
     "find-tables",
     "--source-conn",
-    SPANNER_CONN_NAME,
+    "mock-conn",
     "--target-conn",
-    SPANNER_CONN_NAME,
+    "mock-conn",
 ]
 
+SPANNER_CONN = {
+    "source_type": "Spanner",
+    "project_id": PROJECT_ID,
+    "instance_id": "span1",
+    "database_id": "pso_data_validator",
+}
+
-# Copy text fixtures from Spanner Ibis client tests, because it's a bit verbose
-# to create a Spanner instance and load data to it. Relevant test fixtures can
-# be copied here after clients are upstreamed to Ibis.
-spanner_client = conftest.spanner_client
-instance_id = conftest.instance_id
-database_id = conftest.database_id
-
-
-@pytest.fixture
-def spanner_connection_config(instance_id, database_id):
-    return {
-        "source_type": "Spanner",
-        "project_id": os.environ["PROJECT_ID"],
-        "instance_id": instance_id,
-        "database_id": database_id,
-    }
+
+def mock_get_connection_config(*args):
+    if args[1] in ("spanner-conn", "mock-conn"):
+        return SPANNER_CONN
+    elif args[1] == "bq-conn":
+        return BQ_CONN
 
 
-@pytest.fixture
-def spanner_connection_args(instance_id, database_id):
-    return [
-        "connections",
-        "add",
-        "--connection-name",
-        SPANNER_CONN_NAME,
-        "Spanner",
-        "--project-id",
-        os.environ["PROJECT_ID"],
-        "--instance-id",
-        instance_id,
-        "--database-id",
-        database_id,
-    ]
-
-
 @pytest.fixture
-def count_config(spanner_connection_config, database_id):
+def count_config():
     return {
         # Connection Name
-        consts.CONFIG_SOURCE_CONN: spanner_connection_config,
-        consts.CONFIG_TARGET_CONN: spanner_connection_config,
+        consts.CONFIG_SOURCE_CONN: SPANNER_CONN,
+        consts.CONFIG_TARGET_CONN: SPANNER_CONN,
         # Validation Type
         consts.CONFIG_TYPE: "Column",
         # Configuration Required Depending on Validator Type
-        consts.CONFIG_SCHEMA_NAME: database_id,
+        consts.CONFIG_SCHEMA_NAME: "pso_data_validator",
         consts.CONFIG_TABLE_NAME: "functional_alltypes",
         consts.CONFIG_GROUPED_COLUMNS: [],
         consts.CONFIG_AGGREGATES: [
@@ -120,15 +100,15 @@ def count_config(spanner_connection_config, database_id):
 
 
 @pytest.fixture
-def grouped_config(spanner_connection_config, database_id):
+def grouped_config():
     return {
         # Connection Name
-        consts.CONFIG_SOURCE_CONN: spanner_connection_config,
-        consts.CONFIG_TARGET_CONN: spanner_connection_config,
+        consts.CONFIG_SOURCE_CONN: SPANNER_CONN,
+        consts.CONFIG_TARGET_CONN: SPANNER_CONN,
         # Validation Type
         consts.CONFIG_TYPE: "GroupedColumn",
         # Configuration Required Depending on Validator Type
-        consts.CONFIG_SCHEMA_NAME: database_id,
+        consts.CONFIG_SCHEMA_NAME: "pso_data_validator",
         consts.CONFIG_TABLE_NAME: "functional_alltypes",
         consts.CONFIG_AGGREGATES: [
             {
@@ -208,130 +188,179 @@ def test_grouped_count_validator(grouped_config):
         assert row["source_agg_value"] == row["target_agg_value"]
 
 
-def test_cli_find_tables(spanner_connection_args, database_id):
-    _store_spanner_conn(spanner_connection_args)
-
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    new=mock_get_connection_config,
+)
+def test_cli_find_tables():
     parser = cli_tools.configure_arg_parser()
     args = parser.parse_args(CLI_FIND_TABLES_ARGS)
     tables_json = main.find_tables_using_string_matching(args)
     tables = json.loads(tables_json)
     assert isinstance(tables_json, str)
     assert {
-        "schema_name": database_id,
+        "schema_name": "pso_data_validator",
         "table_name": "array_table",
-        "target_schema_name": database_id,
+        "target_schema_name": "pso_data_validator",
         "target_table_name": "array_table",
     } in tables
     assert {
-        "schema_name": database_id,
+        "schema_name": "pso_data_validator",
         "table_name": "functional_alltypes",
-        "target_schema_name": database_id,
+        "target_schema_name": "pso_data_validator",
         "target_table_name": "functional_alltypes",
     } in tables
     assert {
-        "schema_name": database_id,
+        "schema_name": "pso_data_validator",
         "table_name": "students_pointer",
-        "target_schema_name": database_id,
+        "target_schema_name": "pso_data_validator",
         "target_table_name": "students_pointer",
     } in tables
 
-    _remove_spanner_conn()
-
 
-def _store_spanner_conn(spanner_connection_args):
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    new=mock_get_connection_config,
+)
+def test_schema_validation_core_types():
     parser = cli_tools.configure_arg_parser()
-    mock_args = parser.parse_args(spanner_connection_args)
-    main.run_connections(mock_args)
-
-
-def _remove_spanner_conn():
-    mgr = state_manager.StateManager()
-    file_path = mgr._get_connection_path(SPANNER_CONN_NAME)
-    os.remove(file_path)
+    args = parser.parse_args(
+        [
+            "validate",
+            "schema",
+            "-sc=mock-conn",
+            "-tc=mock-conn",
+            "-tbls=pso_data_validator.dvt_core_types",
+            "--filter-status=fail",
+        ]
+    )
+    config_managers = main.build_config_managers_from_args(args)
+    assert len(config_managers) == 1
+    config_manager = config_managers[0]
+    validator = data_validation.DataValidation(config_manager.config, verbose=False)
+    df = validator.execute()
+    # With filter on failures the data frame should be empty
+    assert len(df) == 0
 
 
-def test_schema_validation_core_types(spanner_connection_config, database_id):
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    new=mock_get_connection_config,
+)
+def test_column_validation_core_types():
     parser = cli_tools.configure_arg_parser()
     args = parser.parse_args(
         [
             "validate",
-            "schema",
+            "column",
             "-sc=mock-conn",
             "-tc=mock-conn",
-            f"-tbls={database_id}.dvt_core_types",
+            "-tbls=pso_data_validator.dvt_core_types",
             "--filter-status=fail",
+            "--grouped-columns=col_varchar_30",
+            "--sum=*",
+            "--min=*",
+            "--max=*",
         ]
     )
-    with mock.patch(
-        "data_validation.state_manager.StateManager.get_connection_config",
-        return_value=spanner_connection_config,
-    ):
-        config_managers = main.build_config_managers_from_args(args)
-        assert len(config_managers) == 1
-        config_manager = config_managers[0]
-        validator = data_validation.DataValidation(config_manager.config, verbose=False)
-        df = validator.execute()
-        # With filter on failures the data frame should be empty
-        assert len(df) == 0
+    config_managers = main.build_config_managers_from_args(args)
+    assert len(config_managers) == 1
+    config_manager = config_managers[0]
+    validator = data_validation.DataValidation(config_manager.config, verbose=False)
+    df = validator.execute()
+    # With filter on failures the data frame should be empty
+    assert len(df) == 0
 
 
-def test_column_validation_core_types(spanner_connection_config, database_id):
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    new=mock_get_connection_config,
+)
+def test_column_validation_core_types_to_bigquery():
     parser = cli_tools.configure_arg_parser()
     args = parser.parse_args(
         [
             "validate",
             "column",
             "-sc=mock-conn",
-            "-tc=mock-conn",
-            f"-tbls={database_id}.dvt_core_types",
+            "-tc=bq-conn",
+            "-tbls=pso_data_validator.dvt_core_types",
            "--filter-status=fail",
             "--grouped-columns=col_varchar_30",
             "--sum=*",
             "--min=*",
             "--max=*",
         ]
     )
-    with mock.patch(
-        "data_validation.state_manager.StateManager.get_connection_config",
-        return_value=spanner_connection_config,
-    ):
-        config_managers = main.build_config_managers_from_args(args)
-        assert len(config_managers) == 1
-        config_manager = config_managers[0]
-        validator = data_validation.DataValidation(config_manager.config, verbose=False)
-        df = validator.execute()
-        # With filter on failures the data frame should be empty
-        assert len(df) == 0
+    config_managers = main.build_config_managers_from_args(args)
+    assert len(config_managers) == 1
+    config_manager = config_managers[0]
+    validator = data_validation.DataValidation(config_manager.config, verbose=False)
+    df = validator.execute()
+    # With filter on failures the data frame should be empty
+    assert len(df) == 0
 
 
-def test_row_validation_core_types(spanner_connection_config, database_id):
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    new=mock_get_connection_config,
+)
+def test_row_validation_core_types():
     parser = cli_tools.configure_arg_parser()
     args = parser.parse_args(
         [
             "validate",
             "row",
             "-sc=mock-conn",
             "-tc=mock-conn",
-            f"-tbls={database_id}.dvt_core_types",
+            "-tbls=pso_data_validator.dvt_core_types",
             "--primary-keys=id",
             "--filter-status=fail",
             "--hash=*",
         ]
     )
-    with mock.patch(
-        "data_validation.state_manager.StateManager.get_connection_config",
-        return_value=spanner_connection_config,
-    ):
-        config_managers = main.build_config_managers_from_args(args)
-        assert len(config_managers) == 1
-        config_manager = config_managers[0]
-        validator = data_validation.DataValidation(config_manager.config, verbose=False)
-        df = validator.execute()
-        # With filter on failures the data frame should be empty
-        assert len(df) == 0
+    config_managers = main.build_config_managers_from_args(args)
+    assert len(config_managers) == 1
+    config_manager = config_managers[0]
+    validator = data_validation.DataValidation(config_manager.config, verbose=False)
+    df = validator.execute()
+    # With filter on failures the data frame should be empty
+    assert len(df) == 0
 
 
-def test_custom_query_validation_core_types(spanner_connection_config, database_id):
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    new=mock_get_connection_config,
+)
+def test_row_validation_core_types_to_bigquery():
+    # TODO Change --hash to include col_date and col_datetime when issue-1061 is complete.
+    parser = cli_tools.configure_arg_parser()
+    args = parser.parse_args(
+        [
+            "validate",
+            "row",
+            "-sc=mock-conn",
+            "-tc=bq-conn",
+            "-tbls=pso_data_validator.dvt_core_types",
+            "--primary-keys=id",
+            "--filter-status=fail",
+            "--hash=col_int8,col_int16,col_int32,col_int64,col_dec_20,col_dec_38,col_dec_10_2,col_float32,col_float64,col_varchar_30,col_string,col_tstz",
+        ]
+    )
+    config_managers = main.build_config_managers_from_args(args)
+    assert len(config_managers) == 1
+    config_manager = config_managers[0]
+    validator = data_validation.DataValidation(config_manager.config, verbose=False)
+    df = validator.execute()
+    # With filter on failures the data frame should be empty
+    assert len(df) == 0
+
+
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    new=mock_get_connection_config,
+)
+def test_custom_query_validation_core_types():
+    """Spanner to Spanner dvt_core_types custom-query validation"""
     parser = cli_tools.configure_arg_parser()
     args = parser.parse_args(
@@ -347,14 +376,10 @@ def test_custom_query_validation_core_types(spanner_connection_config, database_id):
             "--count=*",
         ]
     )
-    with mock.patch(
-        "data_validation.state_manager.StateManager.get_connection_config",
-        return_value=spanner_connection_config,
-    ):
-        config_managers = main.build_config_managers_from_args(args)
-        assert len(config_managers) == 1
-        config_manager = config_managers[0]
-        validator = data_validation.DataValidation(config_manager.config, verbose=False)
-        df = validator.execute()
-        # With filter on failures the data frame should be empty
-        assert len(df) == 0
+    config_managers = main.build_config_managers_from_args(args)
+    assert len(config_managers) == 1
+    config_manager = config_managers[0]
+    validator = data_validation.DataValidation(config_manager.config, verbose=False)
+    df = validator.execute()
+    # With filter on failures the data frame should be empty
+    assert len(df) == 0
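
The recurring pattern above, @mock.patch(..., new=mock_get_connection_config), swaps StateManager.get_connection_config for a plain function, so connection names like "mock-conn" and "bq-conn" resolve to in-memory dicts and the tests no longer create and delete connection files via the removed _store_spanner_conn/_remove_spanner_conn helpers. A minimal, self-contained sketch of the pattern; the StateManager class below is a stand-in for data_validation.state_manager.StateManager, not the real one:

from unittest import mock

SPANNER_CONN = {"source_type": "Spanner", "instance_id": "span1"}

class StateManager:
    # Stand-in: the real class reads connection configs from files on disk.
    def get_connection_config(self, name):
        raise FileNotFoundError(f"no connection file for {name!r}")

def mock_get_connection_config(*args):
    # args[0] is the StateManager instance, args[1] the connection name,
    # mirroring the call signature the patched method must satisfy.
    if args[1] in ("spanner-conn", "mock-conn"):
        return SPANNER_CONN

@mock.patch.object(StateManager, "get_connection_config", new=mock_get_connection_config)
def lookup_demo():
    assert StateManager().get_connection_config("mock-conn") == SPANNER_CONN

lookup_demo()

Because new= replaces the attribute outright instead of wrapping it in a MagicMock, the decorated test receives no extra mock argument, which is why the rewritten test functions take no parameters.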
2 changes: 2 additions & 0 deletions third_party/ibis/ibis_addon/operations.py
@@ -474,6 +474,8 @@ def _bigquery_field_to_ibis_dtype(field):
     Db2ExprTranslator._registry[BinaryLength] = sa_format_binary_length
 
 SpannerExprTranslator._registry[RawSQL] = format_raw_sql
+SpannerExprTranslator._registry[HashBytes] = format_hashbytes_bigquery
+SpannerExprTranslator._registry[BinaryLength] = sa_format_binary_length
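
Spanner speaks GoogleSQL, so the two new registry entries reuse the BigQuery formatters. The HashBytes entry is the substance of the fix: routing Spanner's hash through format_hashbytes_bigquery makes the computed row hash a STRING rather than raw BYTES, so --hash row validations can compare Spanner results against other engines as text. A rough sketch of what such a formatter might look like; the signature and error handling here are assumptions, not copied from operations.py:

def format_hashbytes_bigquery(translator, op):
    # Compile the concatenated-columns expression to SQL first.
    arg = translator.translate(op.arg)
    if op.how == "sha256":
        # SHA256() returns BYTES; TO_HEX() converts the digest to a STRING
        # so row hashes can be compared as text across engines.
        return f"TO_HEX(SHA256({arg}))"
    raise ValueError(f"Unexpected hash function: {op.how}")

Before this change the Spanner translator had no HashBytes entry at all (the diff is purely additive), which is why its hash output came back as BYTES, the mismatch the commit title describes.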
