diff --git a/data_validation/query_builder/query_builder.py b/data_validation/query_builder/query_builder.py index 5fde52a8c..e7a5a99fe 100644 --- a/data_validation/query_builder/query_builder.py +++ b/data_validation/query_builder/query_builder.py @@ -208,6 +208,8 @@ def compile(self, ibis_table): # Fields are supplied on compile or on build comparison_field = ibis_table[self.field_name] alias = self.alias or self.field_name + if self.cast: + comparison_field = comparison_field.cast(self.cast) comparison_field = comparison_field.name(alias) return comparison_field @@ -493,6 +495,8 @@ def compile(self, data_client, schema_name, table_name): calc_table = calc_table.mutate( self.compile_calculated_fields(calc_table, n) ) + if self.comparison_fields: + calc_table = calc_table.mutate(self.compile_comparison_fields(calc_table)) compiled_filters = self.compile_filter_fields(table) filtered_table = ( calc_table.filter(compiled_filters) if compiled_filters else calc_table diff --git a/data_validation/validation_builder.py b/data_validation/validation_builder.py index e4e5442e1..4db2384b7 100644 --- a/data_validation/validation_builder.py +++ b/data_validation/validation_builder.py @@ -19,6 +19,7 @@ from data_validation.query_builder.query_builder import ( AggregateField, CalculatedField, + ComparisonField, FilterField, GroupedField, QueryBuilder, @@ -227,7 +228,7 @@ def add_query_group(self, grouped_field): self.group_aliases[alias] = grouped_field def add_primary_key(self, primary_key): - """Add ComparionField to Queries + """Add ComparisonField to Queries Args: primary_key (Dict): An object with source, target, and cast info @@ -235,10 +236,17 @@ def add_primary_key(self, primary_key): source_field_name = primary_key[consts.CONFIG_SOURCE_COLUMN] target_field_name = primary_key[consts.CONFIG_TARGET_COLUMN] # grab calc field metadata - alias = primary_key[consts.CONFIG_FIELD_ALIAS] + alias = primary_key.get(consts.CONFIG_FIELD_ALIAS) + cast = primary_key.get(consts.CONFIG_CAST) # check if valid calc field and return correct object - self.source_builder.add_comparison_field(source_field_name) - self.target_builder.add_comparison_field(target_field_name) + source_field = ComparisonField( + field_name=source_field_name, alias=alias, cast=cast + ) + target_field = ComparisonField( + field_name=target_field_name, alias=alias, cast=cast + ) + self.source_builder.add_comparison_field(source_field) + self.target_builder.add_comparison_field(target_field) self.primary_keys[alias] = primary_key def add_filter(self, filter_field): @@ -287,9 +295,16 @@ def add_comparison_field(self, comparison_field): target_field_name = comparison_field[consts.CONFIG_TARGET_COLUMN] # grab calc field metadata alias = comparison_field[consts.CONFIG_FIELD_ALIAS] + cast = comparison_field.get(consts.CONFIG_CAST) + source_field = ComparisonField( + field_name=source_field_name, alias=alias, cast=cast + ) + target_field = ComparisonField( + field_name=target_field_name, alias=alias, cast=cast + ) # check if valid calc field and return correct object - self.source_builder.add_comparison_field(source_field_name) - self.target_builder.add_comparison_field(target_field_name) + self.source_builder.add_comparison_field(source_field) + self.target_builder.add_comparison_field(target_field) self._metadata[alias] = metadata.ValidationMetadata( aggregation_type=None, validation_type=self.validation_type,