From 409362510d7e405016e87e253e4127a04089fabd Mon Sep 17 00:00:00 2001 From: Fuxin Hao Date: Sat, 23 Apr 2022 20:03:27 +0800 Subject: [PATCH] fix: add get_ibis_table_schema (#410) (#411) --- data_validation/clients.py | 14 +++++++ data_validation/schema_validation.py | 14 ++++--- tests/system/data_sources/test_mysql.py | 24 ++++++++++++ tests/system/data_sources/test_postgres.py | 44 ++++++++++++++++------ 4 files changed, 80 insertions(+), 16 deletions(-) diff --git a/data_validation/clients.py b/data_validation/clients.py index 4eb9e6e34..91f90a6c3 100644 --- a/data_validation/clients.py +++ b/data_validation/clients.py @@ -140,6 +140,20 @@ def get_ibis_table(client, schema_name, table_name, database_name=None): return client.table(table_name, database=schema_name) +def get_ibis_table_schema(client, schema_name, table_name): + """Return Ibis Table Schema for Supplied Client. + + client (IbisClient): Client to use for table + schema_name (str): Schema name of table object + table_name (str): Table name of table object + database_name (str): Database name (generally default is used) + """ + if type(client) in [MySQLClient, PostgreSQLClient]: + return client.schema(schema_name).table(table_name).schema() + else: + return client.get_schema(table_name, schema_name) + + def list_schemas(client): """Return a list of schemas in the DB.""" if type(client) in [ diff --git a/data_validation/schema_validation.py b/data_validation/schema_validation.py index bbdbe4d28..bf6dc12ce 100644 --- a/data_validation/schema_validation.py +++ b/data_validation/schema_validation.py @@ -15,7 +15,7 @@ import datetime import pandas -from data_validation import metadata, consts +from data_validation import metadata, consts, clients class SchemaValidation(object): @@ -33,11 +33,15 @@ def __init__(self, config_manager, run_metadata=None, verbose=False): def execute(self): """Performs a validation between source and a target schema""" - ibis_source_schema = self.config_manager.source_client.get_schema( - self.config_manager.source_table, self.config_manager.source_schema + ibis_source_schema = clients.get_ibis_table_schema( + self.config_manager.source_client, + self.config_manager.source_schema, + self.config_manager.source_table, ) - ibis_target_schema = self.config_manager.target_client.get_schema( - self.config_manager.target_table, self.config_manager.target_schema + ibis_target_schema = clients.get_ibis_table_schema( + self.config_manager.target_client, + self.config_manager.target_schema, + self.config_manager.target_table, ) source_fields = {} diff --git a/tests/system/data_sources/test_mysql.py b/tests/system/data_sources/test_mysql.py index d72e4090a..1973e3f5d 100644 --- a/tests/system/data_sources/test_mysql.py +++ b/tests/system/data_sources/test_mysql.py @@ -49,6 +49,15 @@ consts.CONFIG_FORMAT: "table", } +CONFIG_SCHEMA_VALID = { + consts.CONFIG_SOURCE_CONN: CONN, + consts.CONFIG_TARGET_CONN: CONN, + consts.CONFIG_TYPE: "Column", + consts.CONFIG_SCHEMA_NAME: "guestbook", + consts.CONFIG_TABLE_NAME: "entries", + consts.CONFIG_FORMAT: "table", +} + def test_mysql_count_invalid_host(): try: @@ -61,3 +70,18 @@ def test_mysql_count_invalid_host(): except exceptions.DataClientConnectionFailure: # Local Testing will not work for MySQL pass + + +def test_schema_validation(): + try: + data_validator = data_validation.DataValidation( + CONFIG_SCHEMA_VALID, + verbose=False, + ) + df = data_validator.execute() + + for validation in df.to_dict(orient="records"): + assert validation["status"] == consts.VALIDATION_STATUS_SUCCESS + except exceptions.DataClientConnectionFailure: + # Local Testing will not work for MySQL + pass diff --git a/tests/system/data_sources/test_postgres.py b/tests/system/data_sources/test_postgres.py index c4e6a6d47..26c410349 100644 --- a/tests/system/data_sources/test_postgres.py +++ b/tests/system/data_sources/test_postgres.py @@ -27,8 +27,18 @@ # Cloud SQL proxy listens to localhost POSTGRES_HOST = os.getenv("POSTGRES_HOST", "localhost") POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD") +POSTGRES_DATABASE = os.getenv("POSTGRES_DATABASE", "guestbook") PROJECT_ID = os.getenv("PROJECT_ID") +CONN = { + "source_type": "Postgres", + "host": POSTGRES_HOST, + "user": "postgres", + "password": POSTGRES_PASSWORD, + "port": 5432, + "database": POSTGRES_DATABASE, +} + @pytest.fixture def cloud_sql(request): @@ -54,19 +64,10 @@ def cloud_sql(request): def test_postgres_count(cloud_sql): """Test count validation on Postgres instance""" - conn = { - "source_type": "Postgres", - "host": POSTGRES_HOST, - "user": "postgres", - "password": POSTGRES_PASSWORD, - "port": 5432, - "database": "guestbook", - } - config_count_valid = { # BigQuery Specific Connection Config - consts.CONFIG_SOURCE_CONN: conn, - consts.CONFIG_TARGET_CONN: conn, + consts.CONFIG_SOURCE_CONN: CONN, + consts.CONFIG_TARGET_CONN: CONN, # Validation Type consts.CONFIG_TYPE: "Column", # Configuration Required Depending on Validator Type @@ -103,3 +104,24 @@ def test_postgres_count(cloud_sql): assert df["source_agg_value"].equals(df["target_agg_value"]) assert sorted(list(df["source_agg_value"])) == ["28", "7", "7"] + + +def test_schema_validation(cloud_sql): + """Test schema validation on Postgres instance""" + config_count_valid = { + consts.CONFIG_SOURCE_CONN: CONN, + consts.CONFIG_TARGET_CONN: CONN, + consts.CONFIG_TYPE: "Schema", + consts.CONFIG_SCHEMA_NAME: "public", + consts.CONFIG_TABLE_NAME: "entries", + consts.CONFIG_FORMAT: "table", + } + + data_validator = data_validation.DataValidation( + config_count_valid, + verbose=False, + ) + df = data_validator.execute() + + for validation in df.to_dict(orient="records"): + assert validation["status"] == consts.VALIDATION_STATUS_SUCCESS