Skip to content

Commit

Permalink
fix: Schema validation to make case-insensitive column name comparis…
Browse files Browse the repository at this point in the history
…on (#500)

* fix: Schema validation to make case-insensitive column name comparison

* fix: Schema validation to make case-insensitive column name comparison

* fix: Schema validation to make case-insensitive column name comparison

Co-authored-by: Latika Wadhwa <[email protected]>
  • Loading branch information
latika-wadhwa and Latika Wadhwa committed Jun 9, 2022
1 parent 179d1c5 commit ee8c542
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ data-validation (--verbose or -v) validate row

#### Schema Validations

Below is the syntax for schema validations. These can be used to compare column
Below is the syntax for schema validations. These can be used to compare column names (matched case-insensitively) and
data types between source and target.

```
Expand Down
24 changes: 17 additions & 7 deletions data_validation/schema_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,22 @@ def execute(self):
def schema_validation_matching(source_fields, target_fields):
"""Compare schemas between two dictionary objects"""
results = []
# Casefold the source and target field names so they are compared case-insensitively
source_fields_casefold = {
source_field_name.casefold(): source_field_type
for source_field_name, source_field_type in source_fields.items()
}
target_fields_casefold = {
target_field_name.casefold(): target_field_type
for target_field_name, target_field_type in target_fields.items()
}

# Go through each source and check if target exists and matches
for source_field_name, source_field_type in source_fields.items():
for source_field_name, source_field_type in source_fields_casefold.items():
# target field exists
if source_field_name in target_fields:
if source_field_name in target_fields_casefold:
# target data type matches
if source_field_type == target_fields[source_field_name]:
if source_field_type == target_fields_casefold[source_field_name]:
results.append(
[
source_field_name,
Expand All @@ -111,7 +121,7 @@ def schema_validation_matching(source_fields, target_fields):
"1",
consts.VALIDATION_STATUS_SUCCESS,
"Source_type:{} Target_type:{}".format(
source_field_type, target_fields[source_field_name]
source_field_type, target_fields_casefold[source_field_name]
),
]
)
Expand All @@ -125,7 +135,7 @@ def schema_validation_matching(source_fields, target_fields):
"1",
consts.VALIDATION_STATUS_FAIL,
"Data type mismatch between source and target. Source_type:{} Target_type:{}".format(
source_field_type, target_fields[source_field_name]
source_field_type, target_fields_casefold[source_field_name]
),
]
)
Expand All @@ -143,8 +153,8 @@ def schema_validation_matching(source_fields, target_fields):
)

# source field doesn't exist
for target_field_name, target_field_type in target_fields.items():
if target_field_name not in source_fields:
for target_field_name, target_field_type in target_fields_casefold.items():
if target_field_name not in source_fields_casefold:
results.append(
[
"N/A",
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/test_schema_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def test_import(module_under_test):


def test_schema_validation_matching(module_under_test):
source_fields = {"field1": "string", "field2": "datetime", "field3": "string"}
source_fields = {"FIELD1": "string", "fiEld2": "datetime", "field3": "string"}
target_fields = {"field1": "string", "field2": "timestamp", "field_3": "string"}

expected_results = [
Expand Down Expand Up @@ -202,7 +202,6 @@ def test_execute(module_under_test, fs):
failures = result_df[
result_df["validation_status"].str.contains(consts.VALIDATION_STATUS_FAIL)
]

assert len(result_df) == len(source_data[0]) + 1
assert result_df["source_agg_value"].astype(float).sum() == 7
assert result_df["target_agg_value"].astype(float).sum() == 7
Expand Down

0 comments on commit ee8c542

Please sign in to comment.