Skip to content

Commit

Permalink
feat: Allow user to specify a format for stdout (#242)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yogesh Tewari committed Jul 30, 2021
1 parent eb0f21a commit 28f983f
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 10 deletions.
2 changes: 2 additions & 0 deletions data_validation/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def build_config_managers_from_args(args):
result_handler_config=result_handler_config,
filter_config=filter_config,
verbose=args.verbose,
format=args.format
)
configs.append(build_config_from_args(args, config_manager))

Expand Down Expand Up @@ -281,6 +282,7 @@ def run_validation(config_manager, verbose=False):
"""
validator = DataValidation(
config_manager.config,
format=config_manager.format,
validation_builder=None,
result_handler=None,
verbose=verbose,
Expand Down
6 changes: 6 additions & 0 deletions data_validation/cli_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ def _configure_run_parser(subparsers):
"-filters",
help="Filters in the format source_filter:target_filter",
)
run_parser.add_argument(
"--format",
"-format",
default="table",
help="Set the format for printing command output",
)


def _configure_connection_parser(subparsers):
Expand Down
7 changes: 7 additions & 0 deletions data_validation/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ def threshold(self):
"""Return threshold from Config """
return self._config.get(consts.CONFIG_THRESHOLD, 0.0)

@property
def format(self):
"""Return threshold from Config """
return self._config.get(consts.CONFIG_FORMAT, "table")

def get_source_ibis_table(self):
"""Return IbisTable from source."""
if not hasattr(self, "_source_ibis_table"):
Expand Down Expand Up @@ -269,6 +274,7 @@ def build_config_manager(
table_obj,
labels,
threshold,
format,
result_handler_config=None,
filter_config=None,
verbose=False,
Expand All @@ -289,6 +295,7 @@ def build_config_manager(
consts.CONFIG_THRESHOLD: threshold,
consts.CONFIG_RESULT_HANDLER: result_handler_config,
consts.CONFIG_FILTERS: filter_config,
consts.CONFIG_FORMAT: format,
}

# Only FileSystem connections do not require schemas
Expand Down
1 change: 1 addition & 0 deletions data_validation/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
CONFIG_SOURCE_COLUMN = "source_column"
CONFIG_TARGET_COLUMN = "target_column"
CONFIG_THRESHOLD = "threshold"
CONFIG_FORMAT = "format"
CONFIG_CAST = "cast"
CONFIG_LIMIT = "limit"
CONFIG_FILTERS = "filters"
Expand Down
5 changes: 4 additions & 1 deletion data_validation/data_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class DataValidation(object):
def __init__(
self,
config,
format,
validation_builder=None,
schema_validator=None,
result_handler=None,
Expand All @@ -58,6 +59,8 @@ def __init__(
# Data Client Management
self.config = config

self.format = format

self.source_client = clients.get_data_client(
self.config[consts.CONFIG_SOURCE_CONN]
)
Expand Down Expand Up @@ -102,7 +105,7 @@ def execute(self):
)

# Call Result Handler to Manage Results
return self.result_handler.execute(self.config, result_df)
return self.result_handler.execute(self.config, self.format, result_df)

def query_too_large(self, rows_df, grouped_fields):
""" Return bool to dictate if another level of recursion
Expand Down
25 changes: 23 additions & 2 deletions data_validation/result_handlers/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,29 @@
"""


class TextResultHandler(object):
def execute(self, config, result_df):
def print_formatted_(format, result_df):
"""
Utility for printing formatted results
:param result_df
:param format
"""
if format == "text":
print(result_df.to_string(index=False))
elif format == "csv":
print(result_df.to_csv(index=False))
elif format == "json":
print(result_df.to_json(orient="index"))
elif format == "table":
print(result_df.to_markdown(tablefmt="fancy_grid"))
else:
error_msg = f"format [{format}] not supported, results printed in default(table) mode. " \
f"Supported formats are [text, csv, json, table]"
print(result_df.to_markdown(tablefmt="fancy_grid"))
raise ValueError(error_msg)


class TextResultHandler(object):
def execute(self, config, format, result_df):
print_formatted_(format, result_df)

return result_df
23 changes: 16 additions & 7 deletions tests/system/data_sources/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

import os

import pytest

from data_validation import cli_tools, consts, data_validation
from data_validation import __main__ as main


BQ_CONN = {"source_type": "BigQuery", "project_id": os.environ["PROJECT_ID"]}
CONFIG_COUNT_VALID = {
# BigQuery Specific Connection Name
Expand Down Expand Up @@ -178,7 +179,7 @@


def test_count_validator():
validator = data_validation.DataValidation(CONFIG_COUNT_VALID, verbose=True)
validator = data_validation.DataValidation(CONFIG_COUNT_VALID, format="text", verbose=True)
df = validator.execute()

count_value = df[df["validation_name"] == "count"]["source_agg_value"].values[0]
Expand All @@ -201,13 +202,13 @@ def test_count_validator():
assert float(max_birth_year_value) > 0
assert float(min_birth_year_value) > 0
assert (
df["source_agg_value"].astype(float).sum()
== df["target_agg_value"].astype(float).sum()
df["source_agg_value"].astype(float).sum()
== df["target_agg_value"].astype(float).sum()
)


def test_grouped_count_validator():
validator = data_validation.DataValidation(CONFIG_GROUPED_COUNT_VALID, verbose=True)
validator = data_validation.DataValidation(CONFIG_GROUPED_COUNT_VALID, format="csv", verbose=True)
df = validator.execute()
rows = list(df[df["validation_name"] == "count"].iterrows())

Expand All @@ -223,7 +224,7 @@ def test_grouped_count_validator():


def test_numeric_types():
validator = data_validation.DataValidation(CONFIG_NUMERIC_AGG_VALID, verbose=True)
validator = data_validation.DataValidation(CONFIG_NUMERIC_AGG_VALID, format="json", verbose=True)
df = validator.execute()

for validation in df.to_dict(orient="records"):
Expand All @@ -246,7 +247,7 @@ def test_cli_store_yaml_then_run():
# The number of lines is not significant, except that it represents
# the exact file expected to be created. Any change to this value
# is likely to be a breaking change and must be assessed.
assert len(yaml_file.readlines()) == 32
assert len(yaml_file.readlines()) == 33

# Run generated config
run_config_args = parser.parse_args(CLI_RUN_CONFIG_ARGS)
Expand Down Expand Up @@ -278,3 +279,11 @@ def _store_bq_conn():
def _remove_bq_conn():
file_path = cli_tools._get_connection_file(BQ_CONN_NAME)
os.remove(file_path)


def test_unsupported_result_format():
with pytest.raises(ValueError) as exp:
validator = data_validation.DataValidation(CONFIG_GROUPED_COUNT_VALID, format="foobar", verbose=True)
df = validator.execute()
rows = list(df[df["validation_name"] == "count"].iterrows())
assert len(rows) > 1

0 comments on commit 28f983f

Please sign in to comment.