diff --git a/data_validation/clients.py b/data_validation/clients.py index 1a03bc26e..a89ec2782 100644 --- a/data_validation/clients.py +++ b/data_validation/clients.py @@ -138,6 +138,23 @@ def get_ibis_table(client, schema_name, table_name, database_name=None): return client.table(table_name, database=schema_name) +def get_ibis_table_schema(client, schema_name, table_name): + """Return Ibis Table Schema for Supplied Client. + + client (IbisClient): Client to use for table + schema_name (str): Schema name of table object + table_name (str): Table name of table object + database_name (str): Database name (generally default is used) + """ + if type(client) in [ + MySQLClient, + PostgreSQLClient + ]: + return client.schema(schema_name).table(table_name).schema() + else: + return client.get_schema(table_name, schema_name) + + def list_schemas(client): """Return a list of schemas in the DB.""" if type(client) in [ diff --git a/data_validation/schema_validation.py b/data_validation/schema_validation.py index 7596ab65a..ac5d668dd 100644 --- a/data_validation/schema_validation.py +++ b/data_validation/schema_validation.py @@ -15,7 +15,7 @@ import datetime import pandas -from data_validation import metadata, consts +from data_validation import metadata, consts, clients class SchemaValidation(object): @@ -33,11 +33,15 @@ def __init__(self, config_manager, run_metadata=None, verbose=False): def execute(self): """ Performs a validation between source and a target schema""" - ibis_source_schema = self.config_manager.source_client.get_schema( - self.config_manager.source_table, self.config_manager.source_schema + ibis_source_schema = clients.get_ibis_table_schema( + self.config_manager.source_client, + self.config_manager.source_schema, + self.config_manager.source_table ) - ibis_target_schema = self.config_manager.target_client.get_schema( - self.config_manager.target_table, self.config_manager.target_schema + ibis_target_schema = clients.get_ibis_table_schema( + self.config_manager.target_client, + self.config_manager.target_schema, + self.config_manager.target_table ) source_fields = {} diff --git a/tests/system/data_sources/test_mysql.py b/tests/system/data_sources/test_mysql.py index 81a27f4c8..ec93b2684 100644 --- a/tests/system/data_sources/test_mysql.py +++ b/tests/system/data_sources/test_mysql.py @@ -49,6 +49,15 @@ consts.CONFIG_FORMAT: "table", } +CONFIG_SCHEMA_VALID = { + consts.CONFIG_SOURCE_CONN: CONN, + consts.CONFIG_TARGET_CONN: CONN, + consts.CONFIG_TYPE: "Column", + consts.CONFIG_SCHEMA_NAME: "guestbook", + consts.CONFIG_TABLE_NAME: "entries", + consts.CONFIG_FORMAT: "table", +} + def test_mysql_count_invalid_host(): try: @@ -60,3 +69,17 @@ def test_mysql_count_invalid_host(): except exceptions.DataClientConnectionFailure: # Local Testing will not work for MySQL pass + + +def test_schema_validation(): + try: + data_validator = data_validation.DataValidation( + CONFIG_SCHEMA_VALID, verbose=False, + ) + df = data_validator.execute() + + for validation in df.to_dict(orient="records"): + assert validation["status"] == consts.VALIDATION_STATUS_SUCCESS + except exceptions.DataClientConnectionFailure: + # Local Testing will not work for MySQL + pass diff --git a/tests/system/data_sources/test_postgres.py b/tests/system/data_sources/test_postgres.py index b51a9c928..cd9f0b5b3 100644 --- a/tests/system/data_sources/test_postgres.py +++ b/tests/system/data_sources/test_postgres.py @@ -27,8 +27,18 @@ # Cloud SQL proxy listens to localhost POSTGRES_HOST = os.getenv("POSTGRES_HOST", "localhost") POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD") +POSTGRES_DATABASE = os.getenv("POSTGRES_DATABASE", "guestbook") PROJECT_ID = os.getenv("PROJECT_ID") +CONN = { + "source_type": "Postgres", + "host": POSTGRES_HOST, + "user": "postgres", + "password": POSTGRES_PASSWORD, + "port": 5432, + "database": POSTGRES_DATABASE, +} + @pytest.fixture def cloud_sql(request): @@ -54,19 +64,10 @@ def cloud_sql(request): def test_postgres_count(cloud_sql): """ Test count validation on Postgres instance """ - conn = { - "source_type": "Postgres", - "host": POSTGRES_HOST, - "user": "postgres", - "password": POSTGRES_PASSWORD, - "port": 5432, - "database": "guestbook", - } - config_count_valid = { # BigQuery Specific Connection Config - consts.CONFIG_SOURCE_CONN: conn, - consts.CONFIG_TARGET_CONN: conn, + consts.CONFIG_SOURCE_CONN: CONN, + consts.CONFIG_TARGET_CONN: CONN, # Validation Type consts.CONFIG_TYPE: "Column", # Configuration Required Depending on Validator Type @@ -86,3 +87,21 @@ def test_postgres_count(cloud_sql): data_validator = data_validation.DataValidation(config_count_valid, verbose=False,) df = data_validator.execute() assert df["source_agg_value"][0] == df["target_agg_value"][0] + + +def test_schema_validation(cloud_sql): + """ Test schema validation on Postgres instance """ + config_count_valid = { + consts.CONFIG_SOURCE_CONN: CONN, + consts.CONFIG_TARGET_CONN: CONN, + consts.CONFIG_TYPE: "Schema", + consts.CONFIG_SCHEMA_NAME: "public", + consts.CONFIG_TABLE_NAME: "entries", + consts.CONFIG_FORMAT: "table", + } + + data_validator = data_validation.DataValidation(config_count_valid, verbose=False,) + df = data_validator.execute() + + for validation in df.to_dict(orient="records"): + assert validation["status"] == consts.VALIDATION_STATUS_SUCCESS