diff --git a/README.md b/README.md index f718e16e5..25e25616f 100644 --- a/README.md +++ b/README.md @@ -193,7 +193,7 @@ data-validation (--verbose or -v) validate row #### Schema Validations -Below is the syntax for schema validations. These can be used to compare column +Below is the syntax for schema validations. These can be used to compare case insensitive column names and types between source and target. ``` diff --git a/data_validation/schema_validation.py b/data_validation/schema_validation.py index 1d8e62a7a..1782cbfd8 100644 --- a/data_validation/schema_validation.py +++ b/data_validation/schema_validation.py @@ -97,12 +97,22 @@ def execute(self): def schema_validation_matching(source_fields, target_fields): """Compare schemas between two dictionary objects""" results = [] + # Apply the casefold() function to lowercase the keys of source and target + source_fields_casefold = { + source_field_name.casefold(): source_field_type + for source_field_name, source_field_type in source_fields.items() + } + target_fields_casefold = { + target_field_name.casefold(): target_field_type + for target_field_name, target_field_type in target_fields.items() + } + # Go through each source and check if target exists and matches - for source_field_name, source_field_type in source_fields.items(): + for source_field_name, source_field_type in source_fields_casefold.items(): # target field exists - if source_field_name in target_fields: + if source_field_name in target_fields_casefold: # target data type matches - if source_field_type == target_fields[source_field_name]: + if source_field_type == target_fields_casefold[source_field_name]: results.append( [ source_field_name, @@ -111,7 +121,7 @@ def schema_validation_matching(source_fields, target_fields): "1", consts.VALIDATION_STATUS_SUCCESS, "Source_type:{} Target_type:{}".format( - source_field_type, target_fields[source_field_name] + source_field_type, target_fields_casefold[source_field_name] ), ] ) @@ -125,7 +135,7 @@ def schema_validation_matching(source_fields, target_fields): "1", consts.VALIDATION_STATUS_FAIL, "Data type mismatch between source and target. Source_type:{} Target_type:{}".format( - source_field_type, target_fields[source_field_name] + source_field_type, target_fields_casefold[source_field_name] ), ] ) @@ -143,8 +153,8 @@ def schema_validation_matching(source_fields, target_fields): ) # source field doesn't exist - for target_field_name, target_field_type in target_fields.items(): - if target_field_name not in source_fields: + for target_field_name, target_field_type in target_fields_casefold.items(): + if target_field_name not in source_fields_casefold: results.append( [ "N/A", diff --git a/tests/unit/test_schema_validation.py b/tests/unit/test_schema_validation.py index 2de3a6e0d..57230e6d6 100644 --- a/tests/unit/test_schema_validation.py +++ b/tests/unit/test_schema_validation.py @@ -143,7 +143,7 @@ def test_import(module_under_test): def test_schema_validation_matching(module_under_test): - source_fields = {"field1": "string", "field2": "datetime", "field3": "string"} + source_fields = {"FIELD1": "string", "fiEld2": "datetime", "field3": "string"} target_fields = {"field1": "string", "field2": "timestamp", "field_3": "string"} expected_results = [ @@ -202,7 +202,6 @@ def test_execute(module_under_test, fs): failures = result_df[ result_df["validation_status"].str.contains(consts.VALIDATION_STATUS_FAIL) ] - assert len(result_df) == len(source_data[0]) + 1 assert result_df["source_agg_value"].astype(float).sum() == 7 assert result_df["target_agg_value"].astype(float).sum() == 7