Skip to content

Commit

Permalink
fix: Protect column and row validation calculated column names from O…
Browse files Browse the repository at this point in the history
…racle 30 character identifier limit (#749)

* fix: Protect row validation calculated column names from Oracle 30 character identifier limit

* fix: Protect column validation calculated column names from Oracle 30 character identifier limit

* fix: Protect column validation calculated column names from Oracle 30 character identifier limit
  • Loading branch information
nj1973 committed Mar 14, 2023
1 parent a7889bf commit 89413c1
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 25 deletions.
23 changes: 23 additions & 0 deletions data_validation/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ def get_pandas_client(table_name, file_path, file_type):
return pandas_client


def is_oracle_client(client):
    """Return True when ``client`` is an instance of ``OracleClient``.

    A ``TypeError`` raised by the ``isinstance`` check is reported as False
    rather than propagated.
    """
    try:
        matches = isinstance(client, OracleClient)
    except TypeError:
        # When no Oracle client has been installed OracleClient is not a class,
        # so it cannot be used as the second argument of isinstance().
        matches = False
    return matches


def get_ibis_table(client, schema_name, table_name, database_name=None):
"""Return Ibis Table for Supplied Client.
Expand Down Expand Up @@ -269,6 +277,21 @@ def get_data_client(connection_config):
return data_client


def get_max_column_length(client):
    """Return the max column length supported by client.

    Args:
        client (IbisClient): Client to use for tables.

    Returns:
        int: 30 when ``client`` is an Oracle client whose version string
        starts with "10", "11" or "12.1"; 128 in every other case.
    """
    if not is_oracle_client(client):
        return 128

    # We can't reliably know which Version class client.version is stored in
    # because it is out of our control. Therefore the version is identified by
    # its string form (Oracle <= 12.1) to avoid exceptions of this nature:
    # TypeError: '<' not supported between instances of 'Version' and 'Version'
    version_text = str(client.version)
    if version_text.startswith(("10", "11", "12.1")):
        return 30
    return 128


CLIENT_LOOKUP = {
"BigQuery": get_bigquery_client,
"Impala": impala_connect,
Expand Down
81 changes: 56 additions & 25 deletions data_validation/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(self, config, source_client=None, target_client=None, verbose=False
self.verbose = verbose
if self.validation_type not in consts.CONFIG_TYPES:
raise ValueError(f"Unknown Configuration Type: {self.validation_type}")
self._comparison_max_col_length = None

@property
def config(self):
Expand Down Expand Up @@ -525,23 +526,41 @@ def build_config_count_aggregate(self):

return aggregate_config

def _prefix_calc_col_name(
    self, column_name: str, prefix: str, column_number: int
) -> str:
    """Prefix a column name without overflowing the SQL identifier limit.

    The preferred name is ``{prefix}__{column_name}``. When that is longer
    than the limit reported by ``self._get_comparison_max_col_length()``, an
    abstract name ``{prefix}__dvt_calc_col_{column_number}`` is used instead.
    If the abstract name is itself too long (a long prefix combined with a
    multi-digit column number), the prefix is trimmed from the right by just
    enough characters to fit.

    Args:
        column_name: Name of the column the calculation is applied to.
        prefix: Calculation/aggregation label placed in front of the name.
        column_number: Number used to keep the abstract name distinct per
            column.

    Returns:
        The prefixed name, or the (possibly prefix-trimmed) abstract name.
    """
    max_length = self._get_comparison_max_col_length()
    new_name = f"{prefix}__{column_name}"
    if len(new_name) <= max_length:
        return new_name

    # Use an abstract name for the calculated column to avoid composing invalid SQL.
    suffix = f"__dvt_calc_col_{column_number}"
    # Keep the suffix intact (it carries the distinguishing number) and give
    # the prefix whatever room is left. When the prefix already fits this
    # slice is a no-op, so the result matches the untrimmed abstract name.
    prefix_room = max(max_length - len(suffix), 0)
    return f"{prefix[:prefix_room]}{suffix}"

def build_and_append_pre_agg_calc_config(
self, source_column, target_column, calc_func, cast_type=None, depth=0
self,
source_column,
target_column,
calc_func,
column_position,
cast_type=None,
depth=0,
):
"""Create calculated field config used as a pre-aggregation step. Appends to calulated fields if does not already exist and returns created config."""
calculated_config = {
consts.CONFIG_CALCULATED_SOURCE_COLUMNS: [source_column],
consts.CONFIG_CALCULATED_TARGET_COLUMNS: [target_column],
consts.CONFIG_FIELD_ALIAS: f"{calc_func}__{source_column}",
consts.CONFIG_FIELD_ALIAS: self._prefix_calc_col_name(
source_column, calc_func, column_position
),
consts.CONFIG_TYPE: calc_func,
consts.CONFIG_DEPTH: depth,
}

if calc_func == "cast" and cast_type is not None:
calculated_config[consts.CONFIG_DEFAULT_CAST] = cast_type
calculated_config[
consts.CONFIG_FIELD_ALIAS
] = f"{calc_func}_{cast_type}__{source_column}"
calculated_config[consts.CONFIG_FIELD_ALIAS] = self._prefix_calc_col_name(
source_column, f"{calc_func}_{cast_type}", column_position
)

existing_calc_fields = [
config[consts.CONFIG_FIELD_ALIAS] for config in self.calculated_fields
Expand All @@ -552,7 +571,7 @@ def build_and_append_pre_agg_calc_config(
return calculated_config

def append_pre_agg_calc_field(
self, source_column, target_column, agg_type, column_type
self, source_column, target_column, agg_type, column_type, column_position
):
"""Append calculated field for length(string) or epoch_seconds(timestamp) for preprocessing before column validation aggregation."""
depth, cast_type = 0, None
Expand All @@ -567,7 +586,12 @@ def append_pre_agg_calc_field(
calc_func = "cast"
cast_type = "timestamp"
pre_calculated_config = self.build_and_append_pre_agg_calc_config(
source_column, target_column, calc_func, cast_type, depth
source_column,
target_column,
calc_func,
column_position,
cast_type,
depth,
)
source_column = target_column = pre_calculated_config[
consts.CONFIG_FIELD_ALIAS
Expand All @@ -584,13 +608,17 @@ def append_pre_agg_calc_field(
raise ValueError(f"Unsupported column type: {column_type}")

calculated_config = self.build_and_append_pre_agg_calc_config(
source_column, target_column, calc_func, cast_type, depth
source_column, target_column, calc_func, column_position, cast_type, depth
)

aggregate_config = {
consts.CONFIG_SOURCE_COLUMN: f"{calculated_config[consts.CONFIG_FIELD_ALIAS]}",
consts.CONFIG_TARGET_COLUMN: f"{calculated_config[consts.CONFIG_FIELD_ALIAS]}",
consts.CONFIG_FIELD_ALIAS: f"{agg_type}__{calculated_config[consts.CONFIG_FIELD_ALIAS]}",
consts.CONFIG_FIELD_ALIAS: self._prefix_calc_col_name(
calculated_config[consts.CONFIG_FIELD_ALIAS],
f"{agg_type}",
column_position,
),
consts.CONFIG_TYPE: agg_type,
}

Expand All @@ -613,7 +641,7 @@ def build_config_column_aggregates(
supported_types.append("string")

allowlist_columns = arg_value or casefold_source_columns
for column in casefold_source_columns:
for column_position, column in enumerate(casefold_source_columns):
# Get column type and remove precision/scale attributes
column_type_str = str(source_table[casefold_source_columns[column]].type())
column_type = column_type_str.split("(")[0]
Expand Down Expand Up @@ -650,13 +678,16 @@ def build_config_column_aggregates(
casefold_target_columns[column],
agg_type,
column_type,
column_position,
)

else:
aggregate_config = {
consts.CONFIG_SOURCE_COLUMN: casefold_source_columns[column],
consts.CONFIG_TARGET_COLUMN: casefold_target_columns[column],
consts.CONFIG_FIELD_ALIAS: f"{agg_type}__{column}",
consts.CONFIG_FIELD_ALIAS: self._prefix_calc_col_name(
column, f"{agg_type}", column_position
),
consts.CONFIG_TYPE: agg_type,
}

Expand Down Expand Up @@ -711,19 +742,22 @@ def build_config_calculated_fields(

return calculated_config

def _get_comparison_max_col_length(self) -> int:
    """Return the smaller of the source and target max column lengths.

    The value is computed on first use from
    ``clients.get_max_column_length`` for ``self.source_client`` and
    ``self.target_client``, then cached on
    ``self._comparison_max_col_length`` and reused by later calls.
    """
    # Compare against None explicitly so only the "not yet computed" state
    # triggers a recalculation.
    if self._comparison_max_col_length is None:
        self._comparison_max_col_length = min(
            clients.get_max_column_length(self.source_client),
            clients.get_max_column_length(self.target_client),
        )
    return self._comparison_max_col_length

def _strftime_format(
self, column_type: Union[dt.Date, dt.Timestamp], client
) -> str:
def is_oracle_client(client):
# When no Oracle client installed clients.OracleClient is not a class
try:
return isinstance(client, clients.OracleClient)
except TypeError:
return False

if isinstance(column_type, dt.Timestamp):
return "%Y-%m-%d %H:%M:%S"
if is_oracle_client(client):
if clients.is_oracle_client(client):
# Oracle DATE is a DateTime
return "%Y-%m-%d %H:%M:%S"
return "%Y-%m-%d"
Expand Down Expand Up @@ -819,14 +853,11 @@ def build_dependent_aliases(self, calc_type, col_list=None):
column_aliases[name] = i
col_names.append(col)
else:
for (
column
) in (
previous_level
): # this needs to be the previous manifest of columns
# This needs to be the previous manifest of columns
for j, column in enumerate(previous_level):
col = {}
col["reference"] = [column]
col["name"] = f"{calc}__" + column
col["name"] = self._prefix_calc_col_name(column, calc, j)
col["calc_type"] = calc
col["depth"] = i

Expand Down
43 changes: 43 additions & 0 deletions tests/unit/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,46 @@ def test_custom_query_get_query_from_inline(module_under_test):
"Expected arg with sql query, got empty arg or arg "
"with white spaces. input query: ' '"
)


def test__get_comparison_max_col_length(module_under_test):
    """Check that prefixed calculated column names respect the identifier limit.

    Covers three inputs to ``_prefix_calc_col_name``: a short name (kept and
    prefixed), a name already longer than the limit, and a name just under the
    limit that overflows once prefixed (both replaced by a numbered name).
    """
    config_manager = module_under_test.ConfigManager(
        SAMPLE_CONFIG, MockIbisClient(), MockIbisClient(), verbose=False
    )
    max_identifier_length = config_manager._get_comparison_max_col_length()
    assert isinstance(max_identifier_length, int)

    short_identifier = "id"
    too_long_identifier = "a_long_column_name".ljust(max_identifier_length + 1, "_")
    nearly_too_long_identifier = "another_long_column_name".ljust(
        max_identifier_length - 1, "_"
    )
    # Sanity-check the fixtures themselves before exercising the method.
    assert len(short_identifier) < max_identifier_length
    assert len(too_long_identifier) > max_identifier_length
    assert len(nearly_too_long_identifier) < max_identifier_length

    new_identifier = config_manager._prefix_calc_col_name(
        short_identifier, "prefix", 900
    )
    # Assert on the generated name, not on the input identifier.
    assert (
        len(new_identifier) <= max_identifier_length
    ), f"Column name is too long: {new_identifier}"
    assert (
        "900" not in new_identifier
    ), f"Column name should NOT contain ID 900: {new_identifier}"

    new_identifier = config_manager._prefix_calc_col_name(
        too_long_identifier, "prefix", 901
    )
    assert (
        len(new_identifier) <= max_identifier_length
    ), f"Column name is too long: {new_identifier}"
    assert (
        "901" in new_identifier
    ), f"Column name should contain ID 901: {new_identifier}"

    new_identifier = config_manager._prefix_calc_col_name(
        nearly_too_long_identifier, "prefix", 902
    )
    assert (
        len(new_identifier) <= max_identifier_length
    ), f"Column name is too long: {new_identifier}"
    assert (
        "902" in new_identifier
    ), f"Column name should contain ID 902: {new_identifier}"

0 comments on commit 89413c1

Please sign in to comment.