From 100b3eabed5ca83245e10e40950e725155332dcd Mon Sep 17 00:00:00 2001 From: Neha Nene Date: Thu, 10 Feb 2022 13:55:44 -0600 Subject: [PATCH] fix: supporting non default schemas for mssql (#365) * fix: supporting non default schemas for mssql * fix:updated MSSQL client instantiation * fix: typo --- data_validation/clients.py | 24 +++++++++++---- .../deploy_cloudsql/mssql_schema.sql | 29 +++++++++++++++++++ tests/system/data_sources/test_sql_server.py | 4 +-- 3 files changed, 49 insertions(+), 8 deletions(-) create mode 100644 tests/system/data_sources/deploy_cloudsql/mssql_schema.sql diff --git a/data_validation/clients.py b/data_validation/clients.py index 6a1068fd3..1a03bc26e 100644 --- a/data_validation/clients.py +++ b/data_validation/clients.py @@ -70,9 +70,9 @@ def get_client_call(*args, **kwargs): OracleClient = _raise_missing_client_error("pip install cx_Oracle") try: - from third_party.ibis.ibis_mssql import connect as mssql_connect + from third_party.ibis.ibis_mssql.client import MSSQLClient except Exception: - mssql_connect = _raise_missing_client_error("pip install pyodbc") + MSSQLClient = _raise_missing_client_error("pip install pyodbc") try: from third_party.ibis.ibis_snowflake.client import ( @@ -126,7 +126,11 @@ def get_ibis_table(client, schema_name, table_name, database_name=None): table_name (str): Table name of table object database_name (str): Database name (generally default is used) """ - if type(client) in [OracleClient, PostgreSQLClient]: + if type(client) in [ + OracleClient, + PostgreSQLClient, + MSSQLClient, + ]: return client.table(table_name, database=database_name, schema=schema_name) elif type(client) in [PandasClient]: return client.table(table_name, schema=schema_name) @@ -136,7 +140,11 @@ def get_ibis_table(client, schema_name, table_name, database_name=None): def list_schemas(client): """Return a list of schemas in the DB.""" - if type(client) in [OracleClient, PostgreSQLClient]: + if type(client) in [ + OracleClient, + PostgreSQLClient, + MSSQLClient, + ]: return client.list_schemas() elif hasattr(client, "list_databases"): return client.list_databases() @@ -146,7 +154,11 @@ def list_schemas(client): def list_tables(client, schema_name): """Return a list of tables in the DB schema.""" - if type(client) in [OracleClient, PostgreSQLClient]: + if type(client) in [ + OracleClient, + PostgreSQLClient, + MSSQLClient, + ]: return client.list_tables(schema=schema_name) elif schema_name: return client.list_tables(database=schema_name) @@ -220,7 +232,7 @@ def get_data_client(connection_config): "Postgres": PostgreSQLClient, "Redshift": PostgreSQLClient, "Teradata": TeradataClient, - "MSSQL": mssql_connect, + "MSSQL": MSSQLClient, "Snowflake": snowflake_connect, "Spanner": spanner_connect, } diff --git a/tests/system/data_sources/deploy_cloudsql/mssql_schema.sql b/tests/system/data_sources/deploy_cloudsql/mssql_schema.sql new file mode 100644 index 000000000..f1739de0f --- /dev/null +++ b/tests/system/data_sources/deploy_cloudsql/mssql_schema.sql @@ -0,0 +1,29 @@ +-- Copyright 2020 Google LLC +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +CREATE SCHEMA Sales; + +DROP TABLE IF EXISTS Sales.customers; +CREATE TABLE Sales.customers ( +customer VARCHAR(255), +content VARCHAR(255), +entryID INT NOT NULL IDENTITY PRIMARY KEY +); +INSERT INTO Sales.customers(customer, content) VALUES ('Madeline', 'Arrived'); +INSERT INTO Sales.customers(customer, content) values ('Annie', 'Me too!'); +INSERT INTO Sales.customers(customer, content) values ('Bob', 'More data coming'); +INSERT INTO Sales.customers(customer, content) values ('Joe', 'Me too!'); +INSERT INTO Sales.customers(customer, content) values ('John', 'Here!'); +INSERT INTO Sales.customers(customer, content) values ('Alex', 'Me too!'); +INSERT INTO Sales.customers(customer, content) values ('Zoe', 'Same!'); diff --git a/tests/system/data_sources/test_sql_server.py b/tests/system/data_sources/test_sql_server.py index da11768fe..757213651 100644 --- a/tests/system/data_sources/test_sql_server.py +++ b/tests/system/data_sources/test_sql_server.py @@ -25,7 +25,7 @@ # https://cloud.google.com/sql/docs/sqlserver/connect-admin-proxy # Cloud SQL Proxy listens on localhost -SQL_SERVER_HOST = os.getenv("SQL_SERVER_HOST", "localhost") +SQL_SERVER_HOST = os.getenv("SQL_SERVER_HOST", "127.0.0.1") SQL_SERVER_USER = os.getenv("SQL_SERVER_USER", "sqlserver") SQL_SERVER_PASSWORD = os.getenv("SQL_SERVER_PASSWORD") PROJECT_ID = os.getenv("PROJECT_ID") @@ -71,7 +71,7 @@ def test_sql_server_count(cloud_sql): # Validation Type consts.CONFIG_TYPE: "Column", # Configuration Required Depending on Validator Type - consts.CONFIG_SCHEMA_NAME: "guestbook", + consts.CONFIG_SCHEMA_NAME: "dbo", consts.CONFIG_TABLE_NAME: "entries", consts.CONFIG_AGGREGATES: [ {