diff --git a/data_validation/config_manager.py b/data_validation/config_manager.py index 010d18a96..f7b5ac11b 100644 --- a/data_validation/config_manager.py +++ b/data_validation/config_manager.py @@ -14,7 +14,7 @@ import copy import logging -from typing import Optional, Union +from typing import Optional, Union, TYPE_CHECKING import google.oauth2.service_account from ibis_bigquery.client import BigQueryClient @@ -26,6 +26,10 @@ from data_validation.validation_builder import ValidationBuilder +if TYPE_CHECKING: + import ibis.expr.types.TableExpr + + class ConfigManager(object): _config: dict = None _source_conn = None @@ -718,17 +722,19 @@ def _strftime_format( return "%Y-%m-%d %H:%M:%S" return "%Y-%m-%d" - def _apply_base_cast_overrides(self, column: str, col_config: dict) -> dict: + def _apply_base_cast_overrides( + self, + column: str, + col_config: dict, + source_table: "ibis.expr.types.TableExpr", + target_table: "ibis.expr.types.TableExpr", + ) -> dict: """Mutates col_config to contain any overrides. Also returns col_config for convenience.""" if col_config["calc_type"] != "cast": return col_config - source_table_schema = { - k: v for k, v in self.get_source_ibis_table().schema().items() - } - target_table_schema = { - k: v for k, v in self.get_target_ibis_table().schema().items() - } + source_table_schema = {k: v for k, v in source_table.schema().items()} + target_table_schema = {k: v for k, v in target_table.schema().items()} if isinstance( source_table_schema[column], (dt.Date, dt.Timestamp) @@ -762,9 +768,11 @@ def _apply_base_cast_overrides(self, column: str, col_config: dict) -> dict: def build_dependent_aliases(self, calc_type, col_list=None): """This is a utility function for determining the required depth of all fields""" + source_table = self.get_source_ibis_calculated_table() + target_table = self.get_target_ibis_calculated_table() + order_of_operations = [] if col_list is None: - source_table = self.get_source_ibis_calculated_table() casefold_source_columns = { x.casefold(): str(x) for x in source_table.columns } @@ -819,7 +827,9 @@ def build_dependent_aliases(self, calc_type, col_list=None): if i == 0: # If we are casting the base column (i == 0) then apply any # datatype specific overrides. - col = self._apply_base_cast_overrides(column, col) + col = self._apply_base_cast_overrides( + column, col, source_table, target_table + ) name = col["name"] column_aliases[name] = i