diff --git a/tests/resources/bigquery_test_tables.sql b/tests/resources/bigquery_test_tables.sql
index f92a0aad6..ef3665808 100644
--- a/tests/resources/bigquery_test_tables.sql
+++ b/tests/resources/bigquery_test_tables.sql
@@ -51,3 +51,11 @@ INSERT INTO `pso_data_validator`.`dvt_core_types` VALUES
 ,'Hello DVT','C ','Hello DVT'
 ,DATE '1970-01-03',DATETIME '1970-01-03 00:00:03'
 ,TIMESTAMP '1970-01-03 00:00:03-03:00');
+
+DROP TABLE `pso_data_validator`.`dvt_null_not_null`;
+CREATE TABLE `pso_data_validator`.`dvt_null_not_null`
+( col_nn DATETIME NOT NULL
+, col_nullable DATETIME
+, col_src_nn_trg_n DATETIME
+, col_src_n_trg_nn DATETIME NOT NULL
+) OPTIONS (description='Nullable integration test table, BigQuery is assumed to be a DVT target (not source)');
diff --git a/tests/resources/hive_test_tables.sql b/tests/resources/hive_test_tables.sql
index c6e4c4ab6..038ca9447 100644
--- a/tests/resources/hive_test_tables.sql
+++ b/tests/resources/hive_test_tables.sql
@@ -48,3 +48,12 @@ INSERT INTO pso_data_validator.dvt_core_types VALUES
 ,12345678901234567890,1234567890123456789012345,123.33,123456.3,12345678.3
 ,'Hello DVT','C ','Hello DVT'
 ,'1970-01-03','1970-01-03 00:00:03','1970-01-03 03:00:03');
+
+
+DROP TABLE `pso_data_validator`.`dvt_null_not_null`;
+CREATE TABLE `pso_data_validator`.`dvt_null_not_null`
+( col_nn timestamp NOT NULL
+, col_nullable timestamp
+, col_src_nn_trg_n timestamp NOT NULL
+, col_src_n_trg_nn timestamp
+) COMMENT 'Nullable integration test table, Hive is assumed to be a DVT source (not target).';
diff --git a/tests/resources/mysql_test_tables.sql b/tests/resources/mysql_test_tables.sql
index a56d38b52..44f17f0c5 100644
--- a/tests/resources/mysql_test_tables.sql
+++ b/tests/resources/mysql_test_tables.sql
@@ -89,3 +89,11 @@ INSERT INTO `pso_data_validator`.`dvt_core_types` VALUES
 ,12345678901234567890,1234567890123456789012345,123.33,123456.3,12345678.3
 ,'Hello DVT','C ','Hello DVT'
 ,'1970-01-03','1970-01-03 00:00:03','1970-01-03 03:00:03');
+
+DROP TABLE `pso_data_validator`.`dvt_null_not_null`;
+CREATE TABLE `pso_data_validator`.`dvt_null_not_null`
+( col_nn datetime(0) NOT NULL
+, col_nullable datetime(0)
+, col_src_nn_trg_n datetime(0) NOT NULL
+, col_src_n_trg_nn datetime(0)
+) COMMENT 'Nullable integration test table, MySQL is assumed to be a DVT source (not target).';
diff --git a/tests/resources/oracle_test_tables.sql b/tests/resources/oracle_test_tables.sql
index 1a860b6e2..e1c8b8cef 100644
--- a/tests/resources/oracle_test_tables.sql
+++ b/tests/resources/oracle_test_tables.sql
@@ -67,3 +67,12 @@ INSERT INTO pso_data_validator.dvt_core_types VALUES
 ,DATE'1970-01-03',TIMESTAMP'1970-01-03 00:00:03'
 ,to_timestamp_tz('1970-01-03 00:00:03 -03:00','YYYY-MM-DD HH24:MI:SS TZH:TZM'));
 COMMIT;
+
+DROP TABLE pso_data_validator.dvt_null_not_null;
+CREATE TABLE pso_data_validator.dvt_null_not_null
+( col_nn TIMESTAMP(0) NOT NULL
+, col_nullable TIMESTAMP(0)
+, col_src_nn_trg_n TIMESTAMP(0) NOT NULL
+, col_src_n_trg_nn TIMESTAMP(0)
+);
+COMMENT ON TABLE pso_data_validator.dvt_null_not_null IS 'Nullable integration test table, Oracle is assumed to be a DVT source (not target).';
diff --git a/tests/resources/postgresql_test_tables.sql b/tests/resources/postgresql_test_tables.sql
index 7c3afe50b..7455ebc3e 100644
--- a/tests/resources/postgresql_test_tables.sql
+++ b/tests/resources/postgresql_test_tables.sql
@@ -89,4 +89,13 @@ INSERT INTO public.test_generate_partitions (course_id, quarter_id, student_id,
     ('TRI001', 2, 9012, 3.5),
     ('TRI001', 3, 1234, 2.7),
     ('TRI001', 3, 5678, 3.5),
-    ('TRI001', 3, 9012, 2.8);
+    ('TRI001', 3, 9012, 2.8);
+
+DROP TABLE pso_data_validator.dvt_null_not_null;
+CREATE TABLE pso_data_validator.dvt_null_not_null
+( col_nn TIMESTAMP(0) NOT NULL
+, col_nullable TIMESTAMP(0)
+, col_src_nn_trg_n TIMESTAMP(0) NOT NULL
+, col_src_n_trg_nn TIMESTAMP(0)
+);
+COMMENT ON TABLE pso_data_validator.dvt_null_not_null IS 'Nullable integration test table, PostgreSQL is assumed to be a DVT source (not target).';
diff --git a/tests/resources/snowflake_test_tables.sql b/tests/resources/snowflake_test_tables.sql
index f44adbe93..4ad57a5cb 100644
--- a/tests/resources/snowflake_test_tables.sql
+++ b/tests/resources/snowflake_test_tables.sql
@@ -83,4 +83,13 @@ INSERT INTO PSO_DATA_VALIDATOR.PUBLIC.TEST_GENERATE_PARTITIONS (COURSE_ID, QUART
     ('TRI001', 2, 9012, 3.5),
     ('TRI001', 3, 1234, 2.7),
     ('TRI001', 3, 5678, 3.5),
-    ('TRI001', 3, 9012, 2.8);
\ No newline at end of file
+    ('TRI001', 3, 9012, 2.8);
+
+DROP TABLE PSO_DATA_VALIDATOR.PUBLIC.DVT_NULL_NOT_NULL;
+CREATE TABLE PSO_DATA_VALIDATOR.PUBLIC.DVT_NULL_NOT_NULL
+( col_nn TIMESTAMP(0) NOT NULL
+, col_nullable TIMESTAMP(0)
+, col_src_nn_trg_n TIMESTAMP(0) NOT NULL
+, col_src_n_trg_nn TIMESTAMP(0)
+);
+COMMENT ON TABLE PSO_DATA_VALIDATOR.PUBLIC.DVT_NULL_NOT_NULL IS 'Nullable integration test table, Snowflake is assumed to be a DVT source (not target).';
diff --git a/tests/resources/sqlserver_test_tables.sql b/tests/resources/sqlserver_test_tables.sql
index 11978a8ee..435e76bcf 100644
--- a/tests/resources/sqlserver_test_tables.sql
+++ b/tests/resources/sqlserver_test_tables.sql
@@ -53,4 +53,12 @@ INSERT INTO pso_data_validator.dvt_core_types VALUES
 ,12345678901234567890,1234567890123456789012345,123.33,123456.3,12345678.3
 ,'Hello DVT','C ','Hello DVT'
 ,'1970-01-03','1970-01-03 00:00:03'
-,cast('1970-01-03 00:00:03 -03:00' as datetimeoffset(3)));
\ No newline at end of file
+,cast('1970-01-03 00:00:03 -03:00' as datetimeoffset(3)));
+
+DROP TABLE pso_data_validator.dvt_null_not_null;
+CREATE TABLE pso_data_validator.dvt_null_not_null
+( col_nn datetime2(0) NOT NULL
+, col_nullable datetime2(0)
+, col_src_nn_trg_n datetime2(0) NOT NULL
+, col_src_n_trg_nn datetime2(0)
+);
diff --git a/tests/resources/teradata_test_tables.sql b/tests/resources/teradata_test_tables.sql
index 296580adf..84b7cc1cb 100644
--- a/tests/resources/teradata_test_tables.sql
+++ b/tests/resources/teradata_test_tables.sql
@@ -31,6 +31,7 @@ CREATE TABLE udf.dvt_core_types
 , col_datetime TIMESTAMP(3)
 , col_tstz TIMESTAMP(3) WITH TIME ZONE
 );
+COMMENT ON TABLE udf.dvt_core_types AS 'Core data types integration test table';
 INSERT INTO udf.dvt_core_types VALUES
 (1,1,1,1,1
@@ -50,3 +51,12 @@ INSERT INTO udf.dvt_core_types VALUES
 ,'Hello DVT','C ','Hello DVT'
 ,DATE'1970-01-03',TIMESTAMP'1970-01-03 00:00:03'
 ,CAST('1970-01-03 00:00:03.000-03:00' AS TIMESTAMP(3) WITH TIME ZONE));
+
+DROP TABLE udf.dvt_null_not_null;
+CREATE TABLE udf.dvt_null_not_null
+( col_nn TIMESTAMP(0) NOT NULL
+, col_nullable TIMESTAMP(0)
+, col_src_nn_trg_n TIMESTAMP(0) NOT NULL
+, col_src_n_trg_nn TIMESTAMP(0)
+);
+COMMENT ON TABLE udf.dvt_null_not_null AS 'Nullable integration test table, Teradata is assumed to be a DVT source (not target).';
diff --git a/tests/system/data_sources/common_functions.py b/tests/system/data_sources/common_functions.py
new file mode 100644
index 000000000..7602567ab
--- /dev/null
+++ b/tests/system/data_sources/common_functions.py
@@ -0,0 +1,35 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def null_not_null_assertions(df): + """Standard assertions for null_not_null integration test. + These tests use BigQuery as a set target with a mismatch of not null/nullable settings. + All other engines are validated against BigQuery to check we get the correct status.""" + # Should be 4 columns in the Dataframe. + assert len(df) == 4 + match_columns = ["col_nn", "col_nullable"] + mismatch_columns = ["col_src_nn_trg_n", "col_src_n_trg_nn"] + for column_name, status in zip(df["source_column_name"], df["validation_status"]): + assert column_name in (match_columns + mismatch_columns) + if column_name in match_columns: + # These columns are the same for all engines and should succeed. + assert ( + status == "success" + ), f"Column: {column_name}, status: {status} != 'success'" + elif column_name in mismatch_columns: + # These columns are the different for source and target engines and should fail. + assert ( + status == "fail" + ), f"Column: {column_name}, status: {status} != 'fail'" diff --git a/tests/system/data_sources/test_hive.py b/tests/system/data_sources/test_hive.py index 050fe3638..603024ffc 100644 --- a/tests/system/data_sources/test_hive.py +++ b/tests/system/data_sources/test_hive.py @@ -18,6 +18,7 @@ from data_validation import __main__ as main from data_validation import cli_tools, data_validation, consts from data_validation.partition_builder import PartitionBuilder +from tests.system.data_sources.common_functions import null_not_null_assertions from tests.system.data_sources.test_bigquery import BQ_CONN @@ -161,6 +162,7 @@ def test_schema_validation_core_types_to_bigquery(): "-tc=bq-conn", "-tbls=pso_data_validator.dvt_core_types", "--filter-status=fail", + "--exclusion-columns=id", ( # All Hive integrals go to BigQuery INT64. "--allow-list=int8:int64,int16:int64,int32:int64," @@ -184,6 +186,34 @@ def test_schema_validation_core_types_to_bigquery(): assert len(df) == 0 +@mock.patch( + "data_validation.state_manager.StateManager.get_connection_config", + new=mock_get_connection_config, +) +def disabled_test_schema_validation_not_null_vs_nullable(): + """ + Disabled this test because we don't currently pull nullable from Hive. + https://github.com/GoogleCloudPlatform/professional-services-data-validator/issues/934 + Compares a source table with a BigQuery target and ensure we match/fail on nnot null/nullable correctly. 
+    """
+    parser = cli_tools.configure_arg_parser()
+    args = parser.parse_args(
+        [
+            "validate",
+            "schema",
+            "-sc=hive-conn",
+            "-tc=bq-conn",
+            "-tbls=pso_data_validator.dvt_null_not_null=pso_data_validator.dvt_null_not_null",
+        ]
+    )
+    config_managers = main.build_config_managers_from_args(args)
+    assert len(config_managers) == 1
+    config_manager = config_managers[0]
+    validator = data_validation.DataValidation(config_manager.config, verbose=False)
+    df = validator.execute()
+    null_not_null_assertions(df)
+
+
 @mock.patch(
     "data_validation.state_manager.StateManager.get_connection_config",
     new=mock_get_connection_config,
 )
diff --git a/tests/system/data_sources/test_mysql.py b/tests/system/data_sources/test_mysql.py
index 43ac96a0e..87256b870 100644
--- a/tests/system/data_sources/test_mysql.py
+++ b/tests/system/data_sources/test_mysql.py
@@ -18,6 +18,7 @@
 from data_validation import __main__ as main
 from data_validation import cli_tools, data_validation, consts, exceptions
 from data_validation.partition_builder import PartitionBuilder
+from tests.system.data_sources.common_functions import null_not_null_assertions
 from tests.system.data_sources.test_bigquery import BQ_CONN
 
 
@@ -64,6 +65,13 @@
 }
 
 
+def mock_get_connection_config(*args):
+    if args[1] in ("mysql-conn", "mock-conn"):
+        return CONN
+    elif args[1] == "bq-conn":
+        return BQ_CONN
+
+
 def test_mysql_count_invalid_host():
     try:
         data_validator = data_validation.DataValidation(
@@ -428,13 +436,6 @@ def test_mysql_row():
         pass
 
 
-def mock_get_connection_config(*args):
-    if args[1] in ("mysql-conn", "mock-conn"):
-        return CONN
-    elif args[1] == "bq-conn":
-        return BQ_CONN
-
-
 # Expected result from partitioning table on 3 keys
 EXPECTED_PARTITION_FILTER = [
     "course_id < 'ALG001' OR course_id = 'ALG001' AND (quarter_id < 3 OR quarter_id = 3 AND (student_id < 1234))",
@@ -508,6 +509,30 @@ def test_schema_validation_core_types():
     assert len(df) == 0
 
 
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    new=mock_get_connection_config,
+)
+def test_schema_validation_not_null_vs_nullable():
+    """Compares a source table with a BigQuery target and ensures we match/fail on not null/nullable correctly."""
+    parser = cli_tools.configure_arg_parser()
+    args = parser.parse_args(
+        [
+            "validate",
+            "schema",
+            "-sc=mysql-conn",
+            "-tc=bq-conn",
+            "-tbls=pso_data_validator.dvt_null_not_null=pso_data_validator.dvt_null_not_null",
+        ]
+    )
+    config_managers = main.build_config_managers_from_args(args)
+    assert len(config_managers) == 1
+    config_manager = config_managers[0]
+    validator = data_validation.DataValidation(config_manager.config, verbose=False)
+    df = validator.execute()
+    null_not_null_assertions(df)
+
+
 @mock.patch(
     "data_validation.state_manager.StateManager.get_connection_config",
     new=mock_get_connection_config,
 )
diff --git a/tests/system/data_sources/test_oracle.py b/tests/system/data_sources/test_oracle.py
index 377c2fb2e..395eeda42 100644
--- a/tests/system/data_sources/test_oracle.py
+++ b/tests/system/data_sources/test_oracle.py
@@ -18,6 +18,7 @@
 from data_validation import __main__ as main
 from data_validation import cli_tools, data_validation, consts
 from data_validation.partition_builder import PartitionBuilder
+from tests.system.data_sources.common_functions import null_not_null_assertions
 from tests.system.data_sources.test_bigquery import BQ_CONN
 
 
@@ -160,9 +161,10 @@ def test_schema_validation_core_types_to_bigquery():
             "-tc=bq-conn",
             "-tbls=pso_data_validator.dvt_core_types",
             "--filter-status=fail",
+            "--exclusion-columns=id",
"--exclusion-columns=id", ( # Integral Oracle NUMBERS go to BigQuery INT64. - "--allow-list=!decimal(8,0):int64,decimal(2,0):int64,decimal(4,0):int64,decimal(9,0):int64,decimal(18,0):int64," + "--allow-list=decimal(2,0):int64,decimal(4,0):int64,decimal(9,0):int64,decimal(18,0):int64," # Oracle NUMBERS that map to BigQuery NUMERIC. "decimal(20,0):decimal(38,9),decimal(10,2):decimal(38,9)," # Oracle NUMBERS that map to BigQuery BIGNUMERIC. @@ -181,6 +183,30 @@ def test_schema_validation_core_types_to_bigquery(): assert len(df) == 0 +@mock.patch( + "data_validation.state_manager.StateManager.get_connection_config", + new=mock_get_connection_config, +) +def test_schema_validation_not_null_vs_nullable(): + """Compares a source table with a BigQuery target and ensure we match/fail on nnot null/nullable correctly.""" + parser = cli_tools.configure_arg_parser() + args = parser.parse_args( + [ + "validate", + "schema", + "-sc=ora-conn", + "-tc=bq-conn", + "-tbls=pso_data_validator.dvt_null_not_null=pso_data_validator.dvt_null_not_null", + ] + ) + config_managers = main.build_config_managers_from_args(args) + assert len(config_managers) == 1 + config_manager = config_managers[0] + validator = data_validation.DataValidation(config_manager.config, verbose=False) + df = validator.execute() + null_not_null_assertions(df) + + @mock.patch( "data_validation.state_manager.StateManager.get_connection_config", new=mock_get_connection_config, diff --git a/tests/system/data_sources/test_postgres.py b/tests/system/data_sources/test_postgres.py index be40813d7..87705af78 100644 --- a/tests/system/data_sources/test_postgres.py +++ b/tests/system/data_sources/test_postgres.py @@ -23,6 +23,7 @@ from tests.system.data_sources.deploy_cloudsql.cloudsql_resource_manager import ( CloudSQLResourceManager, ) +from tests.system.data_sources.common_functions import null_not_null_assertions from tests.system.data_sources.test_bigquery import BQ_CONN @@ -560,10 +561,11 @@ def test_schema_validation_core_types_to_bigquery(): "-sc=pg-conn", "-tc=bq-conn", "-tbls=pso_data_validator.dvt_core_types", + "--exclusion-columns=id", "--filter-status=fail", ( # PostgreSQL integrals go to BigQuery INT64. - "--allow-list=int16:int64,int32:int64,!int32:int64," + "--allow-list=int16:int64,int32:int64," # Oracle NUMBERS that map to BigQuery NUMERIC. "decimal(20,0):decimal(38,9),decimal(10,2):decimal(38,9)," # Oracle NUMBERS that map to BigQuery BIGNUMERIC. 
@@ -582,6 +584,30 @@ def test_schema_validation_core_types_to_bigquery():
     assert len(df) == 0
 
 
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    new=mock_get_connection_config,
+)
+def test_schema_validation_not_null_vs_nullable():
+    """Compares a source table with a BigQuery target and ensures we match/fail on not null/nullable correctly."""
+    parser = cli_tools.configure_arg_parser()
+    args = parser.parse_args(
+        [
+            "validate",
+            "schema",
+            "-sc=pg-conn",
+            "-tc=bq-conn",
+            "-tbls=pso_data_validator.dvt_null_not_null=pso_data_validator.dvt_null_not_null",
+        ]
+    )
+    config_managers = main.build_config_managers_from_args(args)
+    assert len(config_managers) == 1
+    config_manager = config_managers[0]
+    validator = data_validation.DataValidation(config_manager.config, verbose=False)
+    df = validator.execute()
+    null_not_null_assertions(df)
+
+
 @mock.patch(
     "data_validation.state_manager.StateManager.get_connection_config",
     new=mock_get_connection_config,
 )
diff --git a/tests/system/data_sources/test_snowflake.py b/tests/system/data_sources/test_snowflake.py
index 27d3c2ad6..82bb26b1f 100644
--- a/tests/system/data_sources/test_snowflake.py
+++ b/tests/system/data_sources/test_snowflake.py
@@ -18,8 +18,10 @@
 from data_validation import __main__ as main
 from data_validation import cli_tools, data_validation, consts
 from data_validation.partition_builder import PartitionBuilder
+from tests.system.data_sources.common_functions import null_not_null_assertions
 from tests.system.data_sources.test_bigquery import BQ_CONN
 
+
 SNOWFLAKE_ACCOUNT = os.getenv("SNOWFLAKE_ACCOUNT")
 SNOWFLAKE_USER = os.getenv("SNOWFLAKE_USER")
 SNOWFLAKE_PASSWORD = os.getenv("SNOWFLAKE_PASSWORD")
@@ -158,9 +160,10 @@ def test_schema_validation_core_types_to_bigquery():
             "-tc=bq-conn",
             "-tbls=PSO_DATA_VALIDATOR.PUBLIC.DVT_CORE_TYPES=pso_data_validator.dvt_core_types",
             "--filter-status=fail",
+            "--exclusion-columns=id",
             (
                 # Integral Snowflake NUMBERs to to BigQuery INT64.
-                "--allow-list=!decimal(38,0):int64,decimal(38,0):int64,"
+                "--allow-list=decimal(38,0):int64,"
                 # Snowflake NUMBERS that map to BigQuery NUMERIC.
"decimal(20,0):decimal(38,9),decimal(10,2):decimal(38,9)," # Snowflake NUMBERS that map to BigQuery BIGNUMERIC @@ -179,6 +182,30 @@ def test_schema_validation_core_types_to_bigquery(): assert len(df) == 0 +@mock.patch( + "data_validation.state_manager.StateManager.get_connection_config", + new=mock_get_connection_config, +) +def test_schema_validation_not_null_vs_nullable(): + """Compares a source table with a BigQuery target and ensure we match/fail on nnot null/nullable correctly.""" + parser = cli_tools.configure_arg_parser() + args = parser.parse_args( + [ + "validate", + "schema", + "-sc=snowflake-conn", + "-tc=bq-conn", + "-tbls=PUBLIC.DVT_NULL_NOT_NULL=pso_data_validator.dvt_null_not_null", + ] + ) + config_managers = main.build_config_managers_from_args(args) + assert len(config_managers) == 1 + config_manager = config_managers[0] + validator = data_validation.DataValidation(config_manager.config, verbose=False) + df = validator.execute() + null_not_null_assertions(df) + + @mock.patch( "data_validation.state_manager.StateManager.get_connection_config", new=mock_get_connection_config, diff --git a/tests/system/data_sources/test_sql_server.py b/tests/system/data_sources/test_sql_server.py index bedba7925..83e857710 100644 --- a/tests/system/data_sources/test_sql_server.py +++ b/tests/system/data_sources/test_sql_server.py @@ -23,6 +23,7 @@ from data_validation import __main__ as main from data_validation import cli_tools, data_validation, consts from data_validation.partition_builder import PartitionBuilder +from tests.system.data_sources.common_functions import null_not_null_assertions from tests.system.data_sources.test_bigquery import BQ_CONN @@ -287,6 +288,7 @@ def test_schema_validation_core_types_to_bigquery(): "-tc=bq-conn", "-tbls=pso_data_validator.dvt_core_types", "--filter-status=fail", + "--exclusion-columns=id", ( # All SQL Server integrals go to BigQuery INT64. "--allow-list=int8:int64,int16:int64,int32:int64," @@ -295,9 +297,7 @@ def test_schema_validation_core_types_to_bigquery(): # SQL Server decimals that map to BigQuery BIGNUMERIC. "decimal(38,0):decimal(76,38)," # BigQuery does not have a float32 type. - "float32:float64," - # Ignore ID column, we're not testing that one. 
-                "!int32:int64"
+                "float32:float64"
             ),
         ]
     )
@@ -310,6 +310,30 @@ def test_schema_validation_core_types_to_bigquery():
     assert len(df) == 0
 
 
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    new=mock_get_connection_config,
+)
+def test_schema_validation_not_null_vs_nullable():
+    """Compares a source table with a BigQuery target and ensures we match/fail on not null/nullable correctly."""
+    parser = cli_tools.configure_arg_parser()
+    args = parser.parse_args(
+        [
+            "validate",
+            "schema",
+            "-sc=sql-conn",
+            "-tc=bq-conn",
+            "-tbls=pso_data_validator.dvt_null_not_null=pso_data_validator.dvt_null_not_null",
+        ]
+    )
+    config_managers = main.build_config_managers_from_args(args)
+    assert len(config_managers) == 1
+    config_manager = config_managers[0]
+    validator = data_validation.DataValidation(config_manager.config, verbose=False)
+    df = validator.execute()
+    null_not_null_assertions(df)
+
+
 @mock.patch(
     "data_validation.state_manager.StateManager.get_connection_config",
     new=mock_get_connection_config,
 )
diff --git a/tests/system/data_sources/test_teradata.py b/tests/system/data_sources/test_teradata.py
index 8d1949e10..eebc6d038 100644
--- a/tests/system/data_sources/test_teradata.py
+++ b/tests/system/data_sources/test_teradata.py
@@ -18,6 +18,7 @@
 from data_validation import __main__ as main
 from data_validation import cli_tools, data_validation, consts
 from data_validation.partition_builder import PartitionBuilder
+from tests.system.data_sources.common_functions import null_not_null_assertions
 from tests.system.data_sources.test_bigquery import BQ_CONN
 
 
@@ -240,6 +241,7 @@ def test_schema_validation_core_types_to_bigquery():
             "-tc=bq-conn",
             "-tbls=udf.dvt_core_types=pso_data_validator.dvt_core_types",
             "--filter-status=fail",
+            "--exclusion-columns=id",
             (
                 # Teradata integrals go to BigQuery INT64.
                 "--allow-list=int8:int64,int16:int64,int32:int64,"
@@ -259,6 +261,30 @@ def test_schema_validation_core_types_to_bigquery():
     assert len(df) == 0
 
 
+@mock.patch(
+    "data_validation.state_manager.StateManager.get_connection_config",
+    new=mock_get_connection_config,
+)
+def test_schema_validation_not_null_vs_nullable():
+    """Compares a source table with a BigQuery target and ensures we match/fail on not null/nullable correctly."""
+    parser = cli_tools.configure_arg_parser()
+    args = parser.parse_args(
+        [
+            "validate",
+            "schema",
+            "-sc=td-conn",
+            "-tc=bq-conn",
+            "-tbls=udf.dvt_null_not_null=pso_data_validator.dvt_null_not_null",
+        ]
+    )
+    config_managers = main.build_config_managers_from_args(args)
+    assert len(config_managers) == 1
+    config_manager = config_managers[0]
+    validator = data_validation.DataValidation(config_manager.config, verbose=False)
+    df = validator.execute()
+    null_not_null_assertions(df)
+
+
 @mock.patch(
     "data_validation.state_manager.StateManager.get_connection_config",
     new=mock_get_connection_config,
 )
diff --git a/third_party/ibis/ibis_addon/operations.py b/third_party/ibis/ibis_addon/operations.py
index 963f7b090..9e160593c 100644
--- a/third_party/ibis/ibis_addon/operations.py
+++ b/third_party/ibis/ibis_addon/operations.py
@@ -22,27 +22,39 @@ Ibis as an override, though it would not apply for Pandas
 and other non-textual languages.
""" +import google.cloud.bigquery as bq import ibis import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.rules as rlz import sqlalchemy as sa -from ibis.backends.base.sql.alchemy.registry import \ - fixed_arity as sa_fixed_arity +from ibis.backends.base.sql.alchemy.registry import fixed_arity as sa_fixed_arity from ibis.backends.base.sql.alchemy.translator import AlchemyExprTranslator from ibis.backends.base.sql.compiler.translator import ExprTranslator from ibis.backends.base.sql.registry import fixed_arity -from ibis.backends.bigquery.client import _DTYPE_TO_IBIS_TYPE +from ibis.backends.bigquery.client import ( + _DTYPE_TO_IBIS_TYPE as _BQ_DTYPE_TO_IBIS_TYPE, + _LEGACY_TO_STANDARD as _BQ_LEGACY_TO_STANDARD, +) from ibis.backends.bigquery.compiler import BigQueryExprTranslator -from ibis.backends.bigquery.registry import \ - STRFTIME_FORMAT_FUNCTIONS as BQ_STRFTIME_FORMAT_FUNCTIONS +from ibis.backends.bigquery.registry import ( + STRFTIME_FORMAT_FUNCTIONS as BQ_STRFTIME_FORMAT_FUNCTIONS, +) from ibis.backends.impala.compiler import ImpalaExprTranslator from ibis.backends.mssql.compiler import MsSqlExprTranslator from ibis.backends.mysql.compiler import MySQLExprTranslator from ibis.backends.postgres.compiler import PostgreSQLExprTranslator -from ibis.expr.operations import (Cast, Comparison, HashBytes, IfNull, - RandomScalar, Strftime, StringJoin, - Value, ExtractEpochSeconds) +from ibis.expr.operations import ( + Cast, + Comparison, + HashBytes, + IfNull, + RandomScalar, + Strftime, + StringJoin, + Value, + ExtractEpochSeconds, +) from ibis.expr.types import NumericValue, TemporalValue import third_party.ibis.ibis_mysql.compiler @@ -63,11 +75,21 @@ except Exception: SnowflakeExprTranslator = None + class ToChar(Value): - arg = rlz.one_of([rlz.value(dt.Decimal), rlz.value(dt.float64), rlz.value(dt.Date), rlz.value(dt.Time), rlz.value(dt.Timestamp)]) + arg = rlz.one_of( + [ + rlz.value(dt.Decimal), + rlz.value(dt.float64), + rlz.value(dt.Date), + rlz.value(dt.Time), + rlz.value(dt.Timestamp), + ] + ) fmt = rlz.string output_type = rlz.shape_like("arg") + class RawSQL(Comparison): pass @@ -126,6 +148,7 @@ def strftime_bigquery(translator, op): strftime_format_func_name, fmt_string, arg_formatted ) + def strftime_mysql(translator, op): arg = op.arg format_string = op.format_str @@ -136,6 +159,7 @@ def strftime_mysql(translator, op): fmt_string = "%Y-%m-%d %H:%i:%S" return sa.func.date_format(arg_formatted, fmt_string) + def strftime_mssql(translator, op): """Use MS SQL CONVERT() in place of STRFTIME(). @@ -144,17 +168,17 @@ def strftime_mssql(translator, op): to string in order to complete row data comparison.""" arg, pattern = map(translator.translate, op.args) supported_convert_styles = { - "%Y-%m-%d": 23, # ISO8601 - "%Y-%m-%d %H:%M:%S": 20, # ODBC canonical - "%Y-%m-%d %H:%M:%S.%f": 21, # ODBC canonical (with milliseconds) + "%Y-%m-%d": 23, # ISO8601 + "%Y-%m-%d %H:%M:%S": 20, # ODBC canonical + "%Y-%m-%d %H:%M:%S.%f": 21, # ODBC canonical (with milliseconds) } try: convert_style = supported_convert_styles[pattern.value] except KeyError: raise NotImplementedError( - f'strftime format {pattern.value} not supported for SQL Server.' + f"strftime format {pattern.value} not supported for SQL Server." 
) - result = sa.func.convert(sa.text('VARCHAR(32)'), arg, convert_style) + result = sa.func.convert(sa.text("VARCHAR(32)"), arg, convert_style) return result @@ -179,6 +203,7 @@ def format_hashbytes_hive(translator, op): else: raise ValueError(f"unexpected value for 'how': {op.how}") + def format_hashbytes_alchemy(translator, op): arg = translator.translate(op.arg) if op.how == "sha256": @@ -188,10 +213,12 @@ def format_hashbytes_alchemy(translator, op): else: raise ValueError(f"unexpected value for 'how': {op.how}") + def format_hashbytes_base(translator, op): arg = translator.translate(op.arg) return f"sha2({arg}, 256)" + def compile_raw_sql(table, sql): op = RawSQL(table[table.columns[0]].cast(dt.string), ibis.literal(sql)) return op.to_expr() @@ -206,48 +233,59 @@ def sa_format_raw_sql(translator, op): rand_col, raw_sql = op.args return sa.text(raw_sql.args[0]) + def sa_format_hashbytes_mssql(translator, op): arg = translator.translate(op.arg) cast_arg = sa.func.convert(sa.sql.literal_column("VARCHAR(MAX)"), arg) hash_func = sa.func.hashbytes(sa.sql.literal_column("'SHA2_256'"), cast_arg) - hash_to_string = sa.func.convert(sa.sql.literal_column('CHAR(64)'), hash_func, sa.sql.literal_column('2')) + hash_to_string = sa.func.convert( + sa.sql.literal_column("CHAR(64)"), hash_func, sa.sql.literal_column("2") + ) return sa.func.lower(hash_to_string) + def sa_format_hashbytes_oracle(translator, op): arg = translator.translate(op.arg) convert = sa.func.convert(arg, sa.sql.literal_column("'UTF8'")) hash_func = sa.func.standard_hash(convert, sa.sql.literal_column("'SHA256'")) return sa.func.lower(hash_func) + def sa_format_hashbytes_mysql(translator, op): arg = translator.translate(op.arg) hash_func = sa.func.sha2(arg, sa.sql.literal_column("'256'")) return hash_func + def sa_format_hashbytes_db2(translator, op): compiled_arg = translator.translate(op.arg) - hashfunc = sa.func.hash(compiled_arg,sa.sql.literal_column("2")) + hashfunc = sa.func.hash(compiled_arg, sa.sql.literal_column("2")) hex = sa.func.hex(hashfunc) return sa.func.lower(hex) + def sa_format_hashbytes_redshift(translator, op): arg = translator.translate(op.arg) return sa.sql.literal_column(f"sha2({arg}, 256)") + def sa_format_hashbytes_postgres(translator, op): arg = translator.translate(op.arg) convert = sa.func.convert_to(arg, sa.sql.literal_column("'UTF8'")) hash_func = sa.func.sha256(convert) return sa.func.encode(hash_func, sa.sql.literal_column("'hex'")) + def sa_format_hashbytes_snowflake(translator, op): arg = translator.translate(op.arg) return sa.func.sha2(arg) + def sa_epoch_time_snowflake(translator, op): arg = translator.translate(op.arg) return sa.func.date_part(sa.sql.literal_column("epoch_seconds"), arg) + def sa_format_to_char(translator, op): arg = translator.translate(op.arg) fmt = translator.translate(op.fmt) @@ -263,7 +301,12 @@ def sa_cast_postgres(t, op): sa_arg = t.translate(arg) # Specialize going from numeric(p,s>0) to string - if arg_dtype.is_decimal() and arg_dtype.scale and arg_dtype.scale > 0 and typ.is_string(): + if ( + arg_dtype.is_decimal() + and arg_dtype.scale + and arg_dtype.scale > 0 + and typ.is_string() + ): # When casting a number to string PostgreSQL includes the full scale, e.g.: # SELECT CAST(CAST(100 AS DECIMAL(5,2)) AS VARCHAR(10)); # 100.00 @@ -273,7 +316,9 @@ def sa_cast_postgres(t, op): # Would have liked to use trim_scale but this is only available in PostgreSQL 13+ # return (sa.cast(sa.func.trim_scale(arg), typ)) precision = arg_dtype.precision or 38 - fmt = "FM" + ("9" * 
(precision - arg_dtype.scale)) + "." + ("9" * arg_dtype.scale) + fmt = ( + "FM" + ("9" * (precision - arg_dtype.scale)) + "." + ("9" * arg_dtype.scale) + ) return sa.func.rtrim(sa.func.to_char(sa_arg, fmt), ".") # specialize going from an integer type to a timestamp @@ -281,7 +326,7 @@ def sa_cast_postgres(t, op): return t.integer_to_timestamp(sa_arg, tz=typ.timezone) if arg_dtype.is_binary() and typ.is_string(): - return sa.func.encode(sa_arg, 'escape') + return sa.func.encode(sa_arg, "escape") if typ.is_binary(): # decode yields a column of memoryview which is annoying to deal with @@ -293,19 +338,48 @@ def sa_cast_postgres(t, op): return sa.cast(sa_arg, t.get_sqla_type(typ)) + def _sa_string_join(t, op): return sa.func.concat(*map(t.translate, op.arg)) + def sa_format_new_id(t, op): return sa.func.NEWID() + +_BQ_DTYPE_TO_IBIS_TYPE["TIMESTAMP"] = dt.Timestamp(timezone="UTC") + + +@dt.dtype.register(bq.schema.SchemaField) +def _bigquery_field_to_ibis_dtype(field): + """Convert BigQuery `field` to an ibis type. + Taken from ibis.backends.bigquery.client.py for issue: + https://github.com/GoogleCloudPlatform/professional-services-data-validator/issues/926 + """ + typ = field.field_type + if typ == "RECORD": + fields = field.fields + assert fields, "RECORD fields are empty" + names = [el.name for el in fields] + ibis_types = list(map(dt.dtype, fields)) + ibis_type = dt.Struct(dict(zip(names, ibis_types))) + else: + ibis_type = _BQ_LEGACY_TO_STANDARD.get(typ, typ) + if ibis_type in _BQ_DTYPE_TO_IBIS_TYPE: + ibis_type = _BQ_DTYPE_TO_IBIS_TYPE[ibis_type](nullable=field.is_nullable) + else: + ibis_type = ibis_type + if field.mode == "REPEATED": + ibis_type = dt.Array(ibis_type) + return ibis_type + + NumericValue.to_char = compile_to_char TemporalValue.to_char = compile_to_char BigQueryExprTranslator._registry[HashBytes] = format_hashbytes_bigquery BigQueryExprTranslator._registry[RawSQL] = format_raw_sql BigQueryExprTranslator._registry[Strftime] = strftime_bigquery -_DTYPE_TO_IBIS_TYPE["TIMESTAMP"] = dt.Timestamp(timezone="UTC") AlchemyExprTranslator._registry[RawSQL] = format_raw_sql AlchemyExprTranslator._registry[HashBytes] = format_hashbytes_alchemy @@ -328,7 +402,7 @@ def sa_format_new_id(t, op): MsSqlExprTranslator._registry[HashBytes] = sa_format_hashbytes_mssql MsSqlExprTranslator._registry[RawSQL] = sa_format_raw_sql -MsSqlExprTranslator._registry[IfNull] = sa_fixed_arity(sa.func.isnull,2) +MsSqlExprTranslator._registry[IfNull] = sa_fixed_arity(sa.func.isnull, 2) MsSqlExprTranslator._registry[StringJoin] = _sa_string_join MsSqlExprTranslator._registry[RandomScalar] = sa_format_new_id MsSqlExprTranslator._registry[Strftime] = strftime_mssql diff --git a/third_party/ibis/ibis_teradata/datatypes.py b/third_party/ibis/ibis_teradata/datatypes.py index 35289a74e..0a08c028b 100644 --- a/third_party/ibis/ibis_teradata/datatypes.py +++ b/third_party/ibis/ibis_teradata/datatypes.py @@ -49,6 +49,10 @@ class TeradataTypeTranslator(object): def __init__(self): pass + @classmethod + def _col_data_nullable(cls, col_data: dict) -> bool: + return bool(col_data.get("Nullable", "Y ").startswith("Y")) + @classmethod def to_ibis(cls, col_data): td_type = col_data["Type"].strip() @@ -62,7 +66,7 @@ def to_ibis(cls, col_data): @classmethod def to_ibis_from_other(cls, col_data, return_ibis_type=True): if return_ibis_type: - return dt.string + return dt.string(nullable=cls._col_data_nullable(col_data)) print('Unsupported Date Type Seen: "%s"' % col_data["Type"]) return "VARCHAR" @@ -70,7 +74,7 @@ def 
to_ibis_from_other(cls, col_data, return_ibis_type=True): @classmethod def to_ibis_from_CV(cls, col_data, return_ibis_type=True): if return_ibis_type: - return dt.string + return dt.string(nullable=cls._col_data_nullable(col_data)) return "VARCHAR" @@ -86,13 +90,15 @@ def to_ibis_from_N(cls, col_data, return_ibis_type=True): ) if return_ibis_type: # No precision or scale specified - if precision == -128 or scale ==-128: + if precision == -128 or scale == -128: return dt.Decimal() - return dt.Decimal(precision, scale) - - if precision == -128 or scale ==-128: + return dt.Decimal( + precision, scale, nullable=cls._col_data_nullable(col_data) + ) + + if precision == -128 or scale == -128: return "DECIMAL" - + return "DECIMAL(%d, %d)" % (precision, scale) @classmethod @@ -106,58 +112,62 @@ def to_ibis_from_D(cls, col_data, return_ibis_type=True): ) ) if return_ibis_type: - return dt.Decimal(precision, scale) + return dt.Decimal( + precision, scale, nullable=cls._col_data_nullable(col_data) + ) value_type = "DECIMAL(%d, %d)" % (precision, scale) return value_type @classmethod def to_ibis_from_F(cls, col_data, return_ibis_type=True): if return_ibis_type: - return dt.float64 + return dt.float64(nullable=cls._col_data_nullable(col_data)) return "FLOAT" @classmethod def to_ibis_from_I(cls, col_data, return_ibis_type=True): if return_ibis_type: - return dt.int32 + return dt.int32(nullable=cls._col_data_nullable(col_data)) return "INT" @classmethod def to_ibis_from_I1(cls, col_data, return_ibis_type=True): if return_ibis_type: - return dt.int8 + return dt.int8(nullable=cls._col_data_nullable(col_data)) return "INT" @classmethod def to_ibis_from_I2(cls, col_data, return_ibis_type=True): if return_ibis_type: - return dt.int16 + return dt.int16(nullable=cls._col_data_nullable(col_data)) return "INT" @classmethod def to_ibis_from_I8(cls, col_data, return_ibis_type=True): if return_ibis_type: - return dt.int64 + return dt.int64(nullable=cls._col_data_nullable(col_data)) return "INT" @classmethod def to_ibis_from_DA(cls, col_data, return_ibis_type=True): if return_ibis_type: - return dt.date + return dt.date(nullable=cls._col_data_nullable(col_data)) return "DATE" @classmethod def to_ibis_from_TS(cls, col_data, return_ibis_type=True): if return_ibis_type: - return dt.timestamp + return dt.timestamp(nullable=cls._col_data_nullable(col_data)) return "TIMESTAMP" @classmethod def to_ibis_from_SZ(cls, col_data, return_ibis_type=True): if return_ibis_type: - return dt.timestamp(timezone="UTC") + return dt.timestamp( + timezone="UTC", nullable=cls._col_data_nullable(col_data) + ) return "TIMESTAMP"