Initial framework for schema validation (#228)
* Initial framework for schema validation

* Made fixes according to comments from Dylan

* lint fix

* nit remove compare to None

* fixing client tests

* cleaning schema validation code and centralizing metadata

* use fake fs

* cleaning and improving tests

* improving tests

* linting

* lint

* lint style

* lint style

* clean tests

Co-authored-by: Dylan Hercher <[email protected]>
dhaval-d and dhercher committed Apr 27, 2021
1 parent c63f68e commit e0b2be0
Showing 11 changed files with 489 additions and 87 deletions.
23 changes: 13 additions & 10 deletions data_validation/__main__.py
@@ -80,12 +80,15 @@ def build_config_from_args(args, config_manager):
config_manager (ConfigManager): Validation config manager instance.
"""
config_manager.append_aggregates(get_aggregate_config(args, config_manager))
-    if config_manager.validation_type in ["GroupedColumn", "Row"]:
+    if config_manager.validation_type in [
+        consts.GROUPED_COLUMN_VALIDATION,
+        consts.ROW_VALIDATION,
+    ]:
grouped_columns = cli_tools.get_json_arg(args.grouped_columns)
config_manager.append_query_groups(
config_manager.build_config_grouped_columns(grouped_columns)
)
-    if config_manager.validation_type in ["Row"]:
+    if config_manager.validation_type in [consts.ROW_VALIDATION]:
primary_keys = cli_tools.get_json_arg(args.primary_keys, default_value=[])
config_manager.append_primary_keys(
config_manager.build_config_grouped_columns(primary_keys)
@@ -114,8 +117,8 @@ def build_config_managers_from_args(args):
if args.filters:
filter_config = cli_tools.get_json_arg(args.filters)

-    source_client = DataValidation.get_data_client(source_conn)
-    target_client = DataValidation.get_data_client(target_conn)
+    source_client = clients.get_data_client(source_conn)
+    target_client = clients.get_data_client(target_conn)

threshold = args.threshold if args.threshold else 0.0
tables_list = cli_tools.get_json_arg(args.tables_list, default_value=[])
@@ -148,8 +151,8 @@ def build_config_managers_from_yaml(args):
source_conn = cli_tools.get_connection(yaml_configs[consts.YAML_SOURCE])
target_conn = cli_tools.get_connection(yaml_configs[consts.YAML_TARGET])

-    source_client = DataValidation.get_data_client(source_conn)
-    target_client = DataValidation.get_data_client(target_conn)
+    source_client = clients.get_data_client(source_conn)
+    target_client = clients.get_data_client(target_conn)

for config in yaml_configs[consts.YAML_VALIDATIONS]:
config[consts.CONFIG_SOURCE_CONN] = source_conn
@@ -215,8 +218,8 @@ def find_tables_using_string_matching(args):
source_conn = cli_tools.get_connection(args.source_conn)
target_conn = cli_tools.get_connection(args.target_conn)

-    source_client = DataValidation.get_data_client(source_conn)
-    target_client = DataValidation.get_data_client(target_conn)
+    source_client = clients.get_data_client(source_conn)
+    target_client = clients.get_data_client(target_conn)

allowed_schemas = cli_tools.get_json_arg(args.allowed_schemas)
source_table_map = get_table_map(source_client, allowed_schemas=allowed_schemas)
@@ -229,7 +232,7 @@
def run_raw_query_against_connection(args):
"""Return results of raw query for adhoc usage."""
conn = cli_tools.get_connection(args.conn)
-    client = DataValidation.get_data_client(conn)
+    client = clients.get_data_client(conn)

with client.raw_sql(args.query, results=True) as cur:
return cur.fetchall()
@@ -313,7 +316,7 @@ def run_connections(args):
elif args.connect_cmd == "add":
conn = cli_tools.get_connection_config_from_args(args)
# Test getting a client to validate connection details
-        _ = DataValidation.get_data_client(conn)
+        _ = clients.get_data_client(conn)
cli_tools.store_connection(args.connection_name, conn)
else:
raise ValueError(f"Connections Argument '{args.connect_cmd}' is not supported")
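The branching above now keys off the new constants in consts.py instead of bare string literals. A minimal sketch of that dispatch pattern, with a hypothetical helper name that is not part of this commit:

    from data_validation import consts

    def needs_grouped_columns(validation_type):
        # Grouped-column config applies to both GroupedColumn and Row
        # validations, mirroring the branch in build_config_from_args.
        return validation_type in (
            consts.GROUPED_COLUMN_VALIDATION,
            consts.ROW_VALIDATION,
        )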
4 changes: 3 additions & 1 deletion data_validation/cli_tools.py
@@ -176,7 +176,9 @@ def _configure_run_parser(subparsers):
)

run_parser.add_argument(
"--type", "-t", help="Type of Data Validation (Column, GroupedColumn, Row)"
"--type",
"-t",
help="Type of Data Validation (Column, GroupedColumn, Row, Schema)",
)
run_parser.add_argument("--source-conn", "-sc", help="Source connection name")
run_parser.add_argument("--target-conn", "-tc", help="Target connection name")
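The --type help text now advertises Schema as a valid value. A self-contained sketch of the flag's behavior, with illustrative argument values; the real parser defines many more options:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--type",
        "-t",
        help="Type of Data Validation (Column, GroupedColumn, Row, Schema)",
    )
    # Hypothetical invocation: data-validation run --type Schema ...
    args = parser.parse_args(["--type", "Schema"])
    assert args.type == "Schema"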
37 changes: 36 additions & 1 deletion data_validation/clients.py
@@ -15,6 +15,9 @@

import pandas
import warnings
+import copy
+
+import google.oauth2.service_account

from google.cloud import bigquery
from ibis.backends.bigquery.client import BigQueryClient
@@ -27,7 +30,7 @@
from third_party.ibis.ibis_cloud_spanner.api import connect as spanner_connect
from third_party.ibis.ibis_impala.api import impala_connect
from data_validation import client_info
-
+from data_validation import consts, exceptions

# Our customized Ibis Datatype logic adds support for new types
third_party.ibis.ibis_addon.datatypes
@@ -174,6 +177,38 @@ def get_all_tables(client, allowed_schemas=None):
return table_objs


+def get_data_client(connection_config):
+    """ Return DataClient client from given configuration """
+    connection_config = copy.deepcopy(connection_config)
+    source_type = connection_config.pop(consts.SOURCE_TYPE)
+
+    # The BigQueryClient expects a credentials object, not a string.
+    if consts.GOOGLE_SERVICE_ACCOUNT_KEY_PATH in connection_config:
+        key_path = connection_config.pop(consts.GOOGLE_SERVICE_ACCOUNT_KEY_PATH)
+        if key_path:
+            connection_config[
+                "credentials"
+            ] = google.oauth2.service_account.Credentials.from_service_account_file(
+                key_path
+            )
+
+    if source_type not in CLIENT_LOOKUP:
+        msg = 'ConfigurationError: Source type "{source_type}" is not supported'.format(
+            source_type=source_type
+        )
+        raise Exception(msg)
+
+    try:
+        data_client = CLIENT_LOOKUP[source_type](**connection_config)
+    except Exception as e:
+        msg = 'Connection Type "{source_type}" could not connect: {error}'.format(
+            source_type=source_type, error=str(e)
+        )
+        raise exceptions.DataClientConnectionFailure(msg)
+
+    return data_client
+
+
CLIENT_LOOKUP = {
"BigQuery": get_bigquery_client,
"Impala": impala_connect,
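Moving get_data_client into clients.py lets callers build a data client without touching the DataValidation class. A hedged usage sketch; the connection keys shown are illustrative and vary by source type:

    from data_validation import clients

    # Hypothetical BigQuery connection config. "source_type" selects the
    # factory from CLIENT_LOOKUP; when a service-account key path is set,
    # get_data_client swaps it for a credentials object before connecting.
    conn = {
        "source_type": "BigQuery",
        "project_id": "my-project",  # hypothetical project id
    }
    client = clients.get_data_client(conn)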
16 changes: 16 additions & 0 deletions data_validation/config_manager.py
@@ -154,6 +154,22 @@ def target_table(self):
consts.CONFIG_TARGET_TABLE_NAME, self._config[consts.CONFIG_TABLE_NAME]
)

+    @property
+    def full_target_table(self):
+        """Return string value of fully qualified target table."""
+        if self.target_schema:
+            return self.target_schema + "." + self.target_table
+        else:
+            return self.target_table
+
+    @property
+    def full_source_table(self):
+        """Return string value of fully qualified source table."""
+        if self.source_schema:
+            return self.source_schema + "." + self.source_table
+        else:
+            return self.source_table
+
@property
def labels(self):
"""Return labels."""
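The two new properties centralize schema-qualified table naming. A self-contained mirror of their logic; the function below is illustrative, not part of the commit:

    def full_table_name(schema, table):
        # Qualify with the schema when one is configured; otherwise fall
        # back to the bare table name, exactly as the properties do.
        return schema + "." + table if schema else table

    assert full_table_name("my_dataset", "my_table") == "my_dataset.my_table"
    assert full_table_name(None, "my_table") == "my_table"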
6 changes: 6 additions & 0 deletions data_validation/consts.py
@@ -55,6 +55,12 @@
FILTER_TYPE_CUSTOM = "custom"
FILTER_TYPE_EQUALS = "equals"

+# Validation Types
+COLUMN_VALIDATION = "Column"
+GROUPED_COLUMN_VALIDATION = "GroupedColumn"
+ROW_VALIDATION = "Row"
+SCHEMA_VALIDATION = "Schema"
+
# Yaml File Config Fields
ENV_DIRECTORY_VAR = "PSO_DV_CONFIG_HOME"
DEFAULT_ENV_DIRECTORY = "~/.config/google-pso-data-validator/"
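Centralizing the validation types lets config-building code reference constants rather than scattering string literals. A brief sketch, assuming consts.CONFIG_TYPE is the config key that holds the validation type:

    from data_validation import consts

    config = {
        consts.CONFIG_TYPE: consts.SCHEMA_VALIDATION,  # instead of "Schema"
    }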
79 changes: 30 additions & 49 deletions data_validation/data_validation.py
@@ -12,20 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import copy
import datetime
import json
import logging
import warnings

-import google.oauth2.service_account
import ibis.backends.pandas
import pandas
import numpy

-from data_validation import consts, combiner, exceptions, metadata, clients
+from data_validation import consts, combiner, metadata, clients
from data_validation.config_manager import ConfigManager
from data_validation.validation_builder import ValidationBuilder
+from data_validation.schema_validation import SchemaValidation

""" The DataValidation class is where the code becomes source/target aware
@@ -38,49 +37,65 @@

class DataValidation(object):
def __init__(
-        self, config, validation_builder=None, result_handler=None, verbose=False
+        self,
+        config,
+        validation_builder=None,
+        schema_validator=None,
+        result_handler=None,
+        verbose=False,
):
"""Initialize a DataValidation client
Args:
-            config (dict): The validation config used for the comparison
-            validation_builder (ValidationBuilder): Optional instance of a ValidationBuilder
-            result_handler (ResultHandler): Optional instance of as ResultHandler client
-            verbose (bool): If verbose, the Data Validation client will print the queries run
+            config (dict): The validation config used for the comparison.
+            validation_builder (ValidationBuilder): Optional instance of a ValidationBuilder.
+            schema_validator (SchemaValidation): Optional instance of a SchemaValidation.
+            result_handler (ResultHandler): Optional instance of a ResultHandler client.
+            verbose (bool): If verbose, the Data Validation client will print the queries run.
"""
self.verbose = verbose

# Data Client Management
self.config = config

-        self.source_client = DataValidation.get_data_client(
+        self.source_client = clients.get_data_client(
self.config[consts.CONFIG_SOURCE_CONN]
)
-        self.target_client = DataValidation.get_data_client(
+        self.target_client = clients.get_data_client(
self.config[consts.CONFIG_TARGET_CONN]
)

self.config_manager = ConfigManager(
config, self.source_client, self.target_client, verbose=self.verbose
)

+        self.run_metadata = metadata.RunMetadata()
+        self.run_metadata.labels = self.config_manager.labels
+
# Initialize Validation Builder if None was supplied
self.validation_builder = validation_builder or ValidationBuilder(
self.config_manager
)

+        self.schema_validator = schema_validator or SchemaValidation(
+            self.config_manager, run_metadata=self.run_metadata, verbose=self.verbose
+        )
+
# Initialize the default Result Handler if None was supplied
self.result_handler = result_handler or self.config_manager.get_result_handler()

# TODO(dhercher) we planned on shifting this to use an Execution Handler.
    # Leaving it to swast on the design of how this should look.
def execute(self):
""" Execute Queries and Store Results """
-        if self.config_manager.validation_type == "Row":
+        if self.config_manager.validation_type == consts.ROW_VALIDATION:
grouped_fields = self.validation_builder.pop_grouped_fields()
result_df = self.execute_recursive_validation(
self.validation_builder, grouped_fields
)
+        elif self.config_manager.validation_type == consts.SCHEMA_VALIDATION:
+            """ Perform only schema validation """
+            result_df = self.schema_validator.execute()
else:
result_df = self._execute_validation(
self.validation_builder, process_in_memory=True
@@ -233,10 +248,7 @@ def _get_pandas_schema(self, source_df, target_df, join_on_fields, verbose=False):

def _execute_validation(self, validation_builder, process_in_memory=True):
""" Execute Against a Supplied Validation Builder """
-        run_metadata = metadata.RunMetadata()
-        run_metadata.end_time = datetime.datetime.now(datetime.timezone.utc)
-        run_metadata.validations = validation_builder.get_metadata()
-        run_metadata.labels = self.config_manager.labels
+        self.run_metadata.validations = validation_builder.get_metadata()

source_query = validation_builder.get_source_query()
target_query = validation_builder.get_target_query()
@@ -257,7 +269,7 @@ def _execute_validation(self, validation_builder, process_in_memory=True):
try:
result_df = combiner.generate_report(
pandas_client,
-                run_metadata,
+                self.run_metadata,
pandas_client.table(combiner.DEFAULT_SOURCE, schema=pd_schema),
pandas_client.table(combiner.DEFAULT_TARGET, schema=pd_schema),
join_on_fields=join_on_fields,
@@ -275,47 +287,16 @@ def _execute_validation(self, validation_builder, process_in_memory=True):
else:
result_df = combiner.generate_report(
self.source_client,
-                run_metadata,
+                self.run_metadata,
source_query,
target_query,
join_on_fields=join_on_fields,
verbose=self.verbose,
)

+        self.run_metadata.end_time = datetime.datetime.now(datetime.timezone.utc)
return result_df

-    @staticmethod
-    def get_data_client(connection_config):
-        """ Return DataClient client from given configuration """
-        connection_config = copy.deepcopy(connection_config)
-        source_type = connection_config.pop(consts.SOURCE_TYPE)
-
-        # The BigQueryClient expects a credentials object, not a string.
-        if consts.GOOGLE_SERVICE_ACCOUNT_KEY_PATH in connection_config:
-            key_path = connection_config.pop(consts.GOOGLE_SERVICE_ACCOUNT_KEY_PATH)
-            if key_path:
-                connection_config[
-                    "credentials"
-                ] = google.oauth2.service_account.Credentials.from_service_account_file(
-                    key_path
-                )
-
-        if source_type not in clients.CLIENT_LOOKUP:
-            msg = 'ConfigurationError: Source type "{source_type}" is not supported'.format(
-                source_type=source_type
-            )
-            raise Exception(msg)
-
-        try:
-            data_client = clients.CLIENT_LOOKUP[source_type](**connection_config)
-        except Exception as e:
-            msg = 'Connection Type "{source_type}" could not connect: {error}'.format(
-                source_type=source_type, error=str(e)
-            )
-            raise exceptions.DataClientConnectionFailure(msg)
-
-        return data_client
-
def combine_data(self, source_df, target_df, join_on_fields):
""" TODO: Return List of Dictionaries """
# Clean Data to Standardize
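With this wiring, execute() routes Schema-type configs to SchemaValidation, and run metadata is created once in __init__ and shared across the run. An end-to-end sketch; the config below is deliberately incomplete, since a real run also needs connection and table fields:

    from data_validation import consts
    from data_validation.data_validation import DataValidation

    config = {
        consts.CONFIG_TYPE: consts.SCHEMA_VALIDATION,
        # consts.CONFIG_SOURCE_CONN, consts.CONFIG_TARGET_CONN, and table
        # settings would also be required in a real configuration.
    }
    validator = DataValidation(config, verbose=True)
    result_df = validator.execute()  # pandas DataFrame of schema results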
(Diffs for the remaining 5 changed files are not shown.)