Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Logic to add allow-list to support datatype matching with a provided list in case of mismatched datatypes between source and target #643

Merged
merged 7 commits into from
Jan 16, 2023
5 changes: 5 additions & 0 deletions data_validation/cli_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,11 @@ def _configure_schema_parser(schema_parser):
"-ec",
help="Comma separated list of columns 'col_a,col_b' to be excluded from the schema validation",
)
schema_parser.add_argument(
"--allow-list",
"-al",
help="Comma separated list of datatype mappings due to incompatible datatypes in source and columns. e.g: decimal(12,2):decimal(38,9),string[non-nullable]:string",
kanhaPrayas marked this conversation as resolved.
Show resolved Hide resolved
)


def _configure_custom_query_parser(custom_query_parser):
Expand Down
127 changes: 124 additions & 3 deletions data_validation/schema_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import datetime
import pandas
import logging

from data_validation import metadata, consts, clients

Expand Down Expand Up @@ -52,7 +53,7 @@ def execute(self):
target_fields[field_name] = data_type

results = schema_validation_matching(
source_fields, target_fields, self.config_manager.exclusion_columns
source_fields, target_fields, self.config_manager.exclusion_columns, allow_list
)
df = pandas.DataFrame(
results,
Expand Down Expand Up @@ -101,8 +102,7 @@ def execute(self):

return df


def schema_validation_matching(source_fields, target_fields, exclusion_fields):
def schema_validation_matching(source_fields, target_fields, exclusion_fields, allow_list):
"""Compare schemas between two dictionary objects"""
results = []
# Apply the casefold() function to lowercase the keys of source and target
Expand All @@ -120,6 +120,9 @@ def schema_validation_matching(source_fields, target_fields, exclusion_fields):
source_fields_casefold.pop(field, None)
target_fields_casefold.pop(field, None)

#allow list map in case of incompatible data types in source and target
allow_list_map = parse_allow_list(allow_list)

# Go through each source and check if target exists and matches
for source_field_name, source_field_type in source_fields_casefold.items():
# target field exists
Expand All @@ -136,6 +139,31 @@ def schema_validation_matching(source_fields, target_fields, exclusion_fields):
consts.VALIDATION_STATUS_SUCCESS,
]
)
elif source_field_type in allow_list_map:
target_field_type = allow_list_map[source_field_type]
name_mismatch, higher_precision, lower_precision = parse_n_validate_datatypes(source_field_type, target_field_type)
if name_mismatch or lower_precision:
results.append(
[
source_field_name,
source_field_name,
str(source_field_type),
str(target_field_type),
consts.VALIDATION_STATUS_FAIL,
]
)
else:
if higher_precision:
logging.warning("Source and target data type has precision mismatch: %s - %s", source_field_type, target_field_type)
results.append(
[
source_field_name,
source_field_name,
str(source_field_type),
str(target_field_type),
consts.VALIDATION_STATUS_SUCCESS,
]
)
# target data type mismatch
else:
results.append(
Expand Down Expand Up @@ -172,3 +200,96 @@ def schema_validation_matching(source_fields, target_fields, exclusion_fields):
]
)
return results

def is_number(val):
    """Return True when ``val`` can be converted to an integer by ``int()``.

    Args:
        val: Value to test; callers in this module pass single characters.

    Returns:
        bool: True if ``int(val)`` succeeds, False if it raises ``ValueError``
        (non-numeric text) or ``TypeError`` (e.g. ``None``).
    """
    try:
        int(val)
    except (TypeError, ValueError):
        return False
    return True

def parse_allow_list(st):
    """Parse an allow-list string into a ``{source_type: target_type}`` dict.

    The expected format is comma separated ``source:target`` pairs, e.g.
    ``"decimal(12,2):decimal(38,9),string[non-nullable]:string"``.

    A comma that is immediately followed by a decimal digit is treated as part
    of the data type (as in ``decimal(12,2)``) rather than as a pair separator.

    Args:
        st: The raw allow-list string, or ``None``/empty when not supplied.

    Returns:
        dict: Mapping of source data type to allowed target data type. Empty
        when ``st`` is ``None`` or empty. Fragments that contain no ``":"`` are
        ignored instead of overwriting the previous pair.
    """
    output = {}
    if not st:
        return output

    chars = []
    key = None
    last_index = len(st) - 1
    for i, ch in enumerate(st):
        if ch == ":":
            # Everything collected so far is the source type of this pair.
            key = "".join(chars)
            output[key] = None
            chars = []
        elif ch == "," and (i == last_index or not st[i + 1].isdecimal()):
            # Pair separator (a trailing comma simply ends the last pair).
            if key is not None:
                output[key] = "".join(chars)
            # Reset so a following fragment without ":" cannot clobber this key.
            key = None
            chars = []
        else:
            chars.append(ch)

    if key is not None:
        output[key] = "".join(chars)
    return output

def get_datatype_name(st):
    """Extract the lower-cased alphabetic name from a data type string.

    Only the letters ``a``-``z`` are kept; digits, brackets and punctuation
    are dropped, e.g. ``"Decimal(10,2)"`` -> ``"decimal"``.

    Args:
        st: Data type string such as ``"int8"`` or ``"Decimal(10,2)"``.

    Returns:
        str | int: The extracted name, or ``-1`` when ``st`` contains no
        letters.
    """
    name = "".join(ch for ch in st.lower() if "a" <= ch <= "z")
    if name == "":
        return -1
    return name

# "Type A" data types carry a single numeric part, e.g. int8, int16.
def get_typea_numeric_sustr(st):
    """Return the integer formed by all decimal digits found in ``st``.

    Digits are concatenated in order of appearance, so ``"int16"`` -> ``16``.

    Args:
        st: Data type string (or a fragment of one).

    Returns:
        int: The extracted number, or ``-1`` when ``st`` contains no digits.
    """
    digits = "".join(ch for ch in st if ch.isdecimal())
    if not digits:
        return -1
    return int(digits)

# "Type B" data types carry two comma separated numeric parts, e.g. Decimal(10,2).
def get_typeb_numeric_sustr(st):
    """Return the two numeric parts of a data type such as ``Decimal(10,2)``.

    The string is split on ``","``; the number is extracted from the first and
    second pieces with ``get_typea_numeric_sustr``.

    Args:
        st: Data type string, e.g. ``"Decimal(10,2)"``.

    Returns:
        tuple: ``(first_number, second_number)``. When ``st`` has no comma
        (e.g. ``"varchar(10)"``) the second piece is treated as empty, so the
        second number is whatever ``get_typea_numeric_sustr("")`` yields.
    """
    parts = st.split(",")
    first_half = parts[0]
    # A single-argument type has no second piece; use "" instead of raising.
    second_half = parts[1] if len(parts) > 1 else ""
    first_half_num = get_typea_numeric_sustr(first_half)
    second_half_num = get_typea_numeric_sustr(second_half)
    return first_half_num, second_half_num

def validate_typeb_vals(source, target):
    """Compare two ``(first, second)`` numeric pairs of a source and target type.

    Args:
        source: Indexable pair of numbers for the source data type.
        target: Indexable pair of numbers for the target data type.

    Returns:
        tuple: ``(False, True)`` when either source number exceeds the matching
        target number, ``(False, False)`` when both numbers are equal, and
        ``(True, False)`` otherwise.
    """
    fits_in_target = source[0] <= target[0] and source[1] <= target[1]
    if not fits_in_target:
        return False, True

    identical = source[0] == target[0] and source[1] == target[1]
    return (not identical), False

def parse_n_validate_datatypes(source, target):
    """Compare a source data type string with a target data type string.

    Args:
        source: Source data type, e.g. ``"int8"`` or ``"Decimal(10,2)"``.
        target: Target data type to compare against.

    Returns:
        tuple: ``(name_mismatch, higher_precision, lower_precision)``

        * ``name_mismatch`` (bool): True when the alphabetic type names of
          source and target differ; the other two values are then ``None``.
        * ``higher_precision`` (bool): True when the target has a higher
          precision value than the source.
        * ``lower_precision`` (bool): True when the target has a lower
          precision value than the source.
    """
    if get_datatype_name(source) != get_datatype_name(target):
        return True, None, None

    # Check for the kind of precision supplied, e.g. int8, Decimal(10,2), int.
    # The two-number comparison only applies when BOTH sides are parenthesised;
    # otherwise fall through to the single-number comparison below.
    if "(" in source and "(" in target:
        typeb_source = get_typeb_numeric_sustr(source)
        typeb_target = get_typeb_numeric_sustr(target)
        higher_precision, lower_precision = validate_typeb_vals(
            typeb_source, typeb_target
        )
        return False, higher_precision, lower_precision

    source_num = get_typea_numeric_sustr(source)
    target_num = get_typea_numeric_sustr(target)
    if source_num == target_num:
        return False, False, False
    if source_num > target_num:
        return False, False, True
    return False, True, False