feat: Allow user to specify a format for stdout, files reformatted with black (#242)
Yogesh Tewari committed Jul 30, 2021
1 parent 5707b3d commit f730ac9
Showing 24 changed files with 193 additions and 112 deletions.
4 changes: 2 additions & 2 deletions data_validation/__main__.py
@@ -146,7 +146,7 @@ def build_config_managers_from_args(args):
             result_handler_config=result_handler_config,
             filter_config=filter_config,
             verbose=args.verbose,
-            format=args.format
+            format=args.format,
         )
         configs.append(build_config_from_args(args, config_manager))
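Note: beyond the black reformatting, this hunk carries the commit's functional change: a new format value is read off the parsed CLI args and threaded into ConfigManager. A minimal sketch of the wiring, assuming an argparse flag named --format whose default matches the "table" fallback that config_manager.py (below) uses; the help text and any restricted choice list are assumptions, since those hunks are not loaded here:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--format",
        default="table",  # assumed: mirrors ConfigManager.format's "table" fallback
        help="Format for stdout output",  # assumed wording
    )

    args = parser.parse_args(["--format", "json"])
    print(args.format)  # -> "json", later passed along as format=args.format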

@@ -326,7 +326,7 @@ def run(args):
 
 
 def run_connections(args):
-    """ Run commands related to connection management."""
+    """Run commands related to connection management."""
     if args.connect_cmd == "list":
         cli_tools.list_connections()
     elif args.connect_cmd == "add":
28 changes: 15 additions & 13 deletions data_validation/cli_tools.py
@@ -128,7 +128,7 @@
 
 
 def get_parsed_args():
-    """ Return ArgParser with configured CLI arguments."""
+    """Return ArgParser with configured CLI arguments."""
     parser = configure_arg_parser()
     return parser.parse_args()

@@ -194,7 +194,7 @@ def _configure_raw_query(subparsers):
 
 
 def _configure_run_config_parser(subparsers):
-    """ Configure arguments to run a data validation YAML config."""
+    """Configure arguments to run a data validation YAML config."""
     run_config_parser = subparsers.add_parser(
         "run-config", help="Run validations stored in a YAML config file"
     )
@@ -206,7 +206,7 @@ def _configure_run_config_parser(subparsers):
 
 
 def _configure_run_parser(subparsers):
-    """ Configure arguments to run a data validation."""
+    """Configure arguments to run a data validation."""
 
     # subparsers = parser.add_subparsers(dest="command")
 
@@ -273,7 +273,9 @@ def _configure_run_parser(subparsers):
         help="Store the validation in the YAML Config File Path specified",
     )
     run_parser.add_argument(
-        "--labels", "-l", help="Key value pair labels for validation run",
+        "--labels",
+        "-l",
+        help="Key value pair labels for validation run",
     )
     run_parser.add_argument(
         "--service-account",
@@ -300,7 +302,7 @@ def _configure_run_parser(subparsers):
 
 
 def _configure_connection_parser(subparsers):
-    """ Configure the Parser for Connection Management. """
+    """Configure the Parser for Connection Management."""
     connection_parser = subparsers.add_parser(
         "connections", help="Manage & Store connections to your Databases"
     )
@@ -335,7 +337,7 @@ def _configure_database_specific_parsers(parser):
 
 
 def get_connection_config_from_args(args):
-    """ Return dict with connection config supplied."""
+    """Return dict with connection config supplied."""
     config = {consts.SOURCE_TYPE: args.connect_type}
 
     if args.connect_type == "Raw":
@@ -388,7 +390,7 @@ def _generate_random_name(conn):
 
 
 def store_connection(connection_name, conn):
-    """ Store the connection config under the given name."""
+    """Store the connection config under the given name."""
     connection_name = connection_name or _generate_random_name(conn)
     file_path = _get_connection_file(connection_name)
 
@@ -397,7 +399,7 @@ def store_connection(connection_name, conn):
 
 
 def get_connections():
-    """ Return dict with connection name and path key pairs."""
+    """Return dict with connection name and path key pairs."""
     connections = {}
 
     dir_path = _get_data_validation_directory()
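Note: the hunk cuts off after the directory lookup. A hedged sketch of how get_connections plausibly finishes, mapping connection names to file paths; the ".connection.json" naming scheme is an assumption, not confirmed by the loaded lines:

    import os

    def get_connections(dir_path):
        """Return dict with connection name and path key pairs (sketch)."""
        connections = {}
        for file_name in os.listdir(dir_path):
            # Assumed naming scheme: <name>.connection.json
            if file_name.endswith(".connection.json"):
                conn_name = file_name.split(".")[0]
                connections[conn_name] = os.path.join(dir_path, file_name)
        return connections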
@@ -413,15 +415,15 @@ def get_connections():
 
 
 def list_connections():
-    """ List all saved connections."""
+    """List all saved connections."""
     connections = get_connections()
 
     for conn_name in connections:
         print(f"Connection Name: {conn_name}")
 
 
 def get_connection(connection_name):
-    """ Return dict connection details for a specific connection."""
+    """Return dict connection details for a specific connection."""
     file_path = _get_connection_file(connection_name)
     with open(file_path, "r") as file:
         conn_str = file.read()
@@ -430,7 +432,7 @@ def get_connection(connection_name):
 
 
 def get_labels(arg_labels):
-    """ Return list of tuples representing key-value label pairs. """
+    """Return list of tuples representing key-value label pairs."""
     labels = []
     if arg_labels:
         pairs = arg_labels.split(",")
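Note: the hunk ends at the split(","), so the rest of get_labels is a hedged sketch; the unpacking and error behavior below are assumptions based on the docstring:

    def get_labels(arg_labels):
        """Return list of tuples representing key-value label pairs (sketch)."""
        labels = []
        if arg_labels:
            for pair in arg_labels.split(","):
                key, value = pair.split("=")  # assumed: raises on malformed pairs
                labels.append((key, value))
        return labels

    print(get_labels("env=dev,team=data"))  # -> [('env', 'dev'), ('team', 'data')]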
@@ -515,7 +517,7 @@ def get_arg_list(arg_value, default_value=None):
 
 
 def get_tables_list(arg_tables, default_value=None, is_filesystem=False):
-    """ Returns dictionary of tables. Backwards compatible for JSON input.
+    """Returns dictionary of tables. Backwards compatible for JSON input.
     arg_table (str): tables_list argument specified
     default_value (Any): A default value to supply when arg_value is empty.
@@ -571,7 +573,7 @@ def get_tables_list(arg_tables, default_value=None, is_filesystem=False):
 
 
 def split_table(table_ref, schema_required=True):
-    """ Returns schema and table name given list of input values.
+    """Returns schema and table name given list of input values.
     table_ref (List): Table reference i.e ['my.schema.my_table']
     schema_required (boolean): Indicates whether schema is required. A source
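Note: only the docstring of split_table is loaded. A hedged sketch of the contract it describes; the error message and the exact handling of a missing schema are assumptions:

    def split_table(table_ref, schema_required=True):
        """Return (schema, table) parsed from e.g. ['my.schema.my_table'] (sketch)."""
        parts = table_ref[0].split(".")
        if len(parts) == 1:
            if schema_required:
                raise ValueError(f"Please supply a schema in the table reference: {table_ref[0]}")
            return None, parts[0]
        return ".".join(parts[:-1]), parts[-1]

    print(split_table(["my.schema.my_table"]))  # -> ('my.schema', 'my_table')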
2 changes: 1 addition & 1 deletion data_validation/clients.py
@@ -179,7 +179,7 @@ def get_all_tables(client, allowed_schemas=None):
 
 
 def get_data_client(connection_config):
-    """ Return DataClient client from given configuration """
+    """Return DataClient client from given configuration"""
     connection_config = copy.deepcopy(connection_config)
     source_type = connection_config.pop(consts.SOURCE_TYPE)
 
7 changes: 6 additions & 1 deletion data_validation/combiner.py
@@ -32,7 +32,12 @@
 
 
 def generate_report(
-    client, run_metadata, source, target, join_on_fields=(), verbose=False,
+    client,
+    run_metadata,
+    source,
+    target,
+    join_on_fields=(),
+    verbose=False,
 ):
     """Combine results into a report.
20 changes: 11 additions & 9 deletions data_validation/config_manager.py
@@ -82,12 +82,12 @@ def process_in_memory(self):
 
     @property
     def max_recursive_query_size(self):
-        """Return Aggregates from Config """
+        """Return Aggregates from Config"""
         return self._config.get(consts.CONFIG_MAX_RECURSIVE_QUERY_SIZE, 50000)
 
     @property
     def aggregates(self):
-        """Return Aggregates from Config """
+        """Return Aggregates from Config"""
         return self._config.get(consts.CONFIG_AGGREGATES, [])
 
     def append_aggregates(self, aggregate_configs):
@@ -105,7 +105,7 @@ def append_calculated_fields(self, calculated_configs):
 
     @property
     def query_groups(self):
-        """ Return Query Groups from Config """
+        """Return Query Groups from Config"""
         return self._config.get(consts.CONFIG_GROUPED_COLUMNS, [])
 
     def append_query_groups(self, grouped_column_configs):
@@ -116,7 +116,7 @@ def append_query_groups(self, grouped_column_configs):
 
     @property
     def primary_keys(self):
-        """ Return Query Groups from Config """
+        """Return Query Groups from Config"""
         return self._config.get(consts.CONFIG_PRIMARY_KEYS, [])
 
     def append_primary_keys(self, primary_key_configs):
@@ -127,7 +127,7 @@ def append_primary_keys(self, primary_key_configs):
 
     @property
     def filters(self):
-        """Return Filters from Config """
+        """Return Filters from Config"""
         return self._config.get(consts.CONFIG_FILTERS, [])
 
     @property
@@ -185,12 +185,12 @@ def query_limit(self):
 
     @property
     def threshold(self):
-        """Return threshold from Config """
+        """Return threshold from Config"""
         return self._config.get(consts.CONFIG_THRESHOLD, 0.0)
 
     @property
    def format(self):
-        """Return threshold from Config """
+        """Return threshold from Config"""
         return self._config.get(consts.CONFIG_FORMAT, "table")
 
     def get_source_ibis_table(self):
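Note: the config_manager.py hunks above change docstring whitespace only. Two things worth flagging: the format property is where the commit's new option lands ("table" unless consts.CONFIG_FORMAT is set), and its docstring still reads "Return threshold from Config", a copy-paste slip carried over from the threshold property that black, being a formatter, leaves alone. A runnable sketch of the property pattern, with consts.* replaced by string literals for self-containment:

    class ConfigManager:
        """Sketch of the config-backed property pattern used above."""

        def __init__(self, config):
            self._config = config

        @property
        def format(self):
            """Return format from Config"""
            return self._config.get("format", "table")  # assumed stand-in for consts.CONFIG_FORMAT

    print(ConfigManager({}).format)                  # -> "table" (the default)
    print(ConfigManager({"format": "json"}).format)  # -> "json"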
@@ -253,8 +253,10 @@ def get_result_handler(self):
             consts.GOOGLE_SERVICE_ACCOUNT_KEY_PATH
         )
         if key_path:
-            credentials = google.oauth2.service_account.Credentials.from_service_account_file(
-                key_path
-            )
+            credentials = (
+                google.oauth2.service_account.Credentials.from_service_account_file(
+                    key_path
+                )
+            )
         else:
             credentials = None
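Note: this hunk is whitespace-only as well; black wraps the long assignment in parentheses so the call fits its line-length limit. For reference, a hedged sketch of the same credentials pattern (the key path is hypothetical, and the call only succeeds if that file actually exists):

    import google.oauth2.service_account

    key_path = "/tmp/service_account_key.json"  # hypothetical key file
    credentials = (
        google.oauth2.service_account.Credentials.from_service_account_file(key_path)
        if key_path
        else None
    )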
28 changes: 14 additions & 14 deletions data_validation/data_validation.py
@@ -90,14 +90,14 @@ def __init__(
     # TODO(dhercher) we planned on shifting this to use an Execution Handler.
     # Leaving it to swast on the design of how this should look.
     def execute(self):
-        """ Execute Queries and Store Results """
+        """Execute Queries and Store Results"""
         if self.config_manager.validation_type == consts.ROW_VALIDATION:
             grouped_fields = self.validation_builder.pop_grouped_fields()
             result_df = self.execute_recursive_validation(
                 self.validation_builder, grouped_fields
             )
         elif self.config_manager.validation_type == consts.SCHEMA_VALIDATION:
-            """ Perform only schema validation """
+            """Perform only schema validation"""
             result_df = self.schema_validator.execute()
         else:
             result_df = self._execute_validation(
@@ -108,15 +108,15 @@ def execute(self):
         return self.result_handler.execute(self.config, self.format, result_df)
 
     def query_too_large(self, rows_df, grouped_fields):
-        """ Return bool to dictate if another level of recursion
-            would create a too large result set.
-            Rules to define too large are:
-                - If any grouped fields remain, return False.
-                  (assumes user added logical sized groups)
-                - Else, if next group size is larger
-                  than the limit, return True.
-                - Finally return False if no covered case occurred.
+        """Return bool to dictate if another level of recursion
+        would create a too large result set.
+        Rules to define too large are:
+            - If any grouped fields remain, return False.
+              (assumes user added logical sized groups)
+            - Else, if next group size is larger
+              than the limit, return True.
+            - Finally return False if no covered case occurred.
         """
         if len(grouped_fields) > 1:
             return False
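Note: the docstring above is where the decision logic actually lives; the loaded hunk ends right after the first rule. A hedged sketch of the three rules as stated, with the group-size estimate reduced to a simple row count (an assumption):

    def query_too_large(row_count, grouped_fields, limit=50000):
        """Sketch of the recursion guard described in the docstring above."""
        # Rule 1: grouped fields remain, so recurse further.
        if len(grouped_fields) > 1:
            return False
        # Rule 2: the next group would exceed the limit.
        if row_count > limit:
            return True
        # Rule 3: no covered case occurred.
        return False

    print(query_too_large(100, ["col_a"]))  # -> False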
@@ -203,7 +203,7 @@ def execute_recursive_validation(self, validation_builder, grouped_fields):
         return pandas.concat(past_results)
 
     def _add_recursive_validation_filter(self, validation_builder, row):
-        """ Return ValidationBuilder Configured for Next Recursive Search """
+        """Return ValidationBuilder Configured for Next Recursive Search"""
         group_by_columns = json.loads(row[consts.GROUP_BY_COLUMNS])
         for alias, value in group_by_columns.items():
             filter_field = {
@@ -250,7 +250,7 @@ def _get_pandas_schema(self, source_df, target_df, join_on_fields, verbose=False
         return pd_schema
 
     def _execute_validation(self, validation_builder, process_in_memory=True):
-        """ Execute Against a Supplied Validation Builder """
+        """Execute Against a Supplied Validation Builder"""
         self.run_metadata.validations = validation_builder.get_metadata()
 
         source_query = validation_builder.get_source_query()
@@ -301,7 +301,7 @@ def _execute_validation(self, validation_builder, process_in_memory=True):
         return result_df
 
     def combine_data(self, source_df, target_df, join_on_fields):
-        """ TODO: Return List of Dictionaries """
+        """TODO: Return List of Dictionaries"""
         # Clean Data to Standardize
         if join_on_fields:
             df = source_df.merge(
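Note: combine_data opens with a pandas merge on the join fields. A small runnable sketch of that idea; the _source/_target suffixes are an assumption about how the merged columns get distinguished:

    import pandas

    source_df = pandas.DataFrame({"id": [1, 2], "count": [10, 20]})
    target_df = pandas.DataFrame({"id": [1, 2], "count": [10, 21]})

    # Align source and target rows on the join fields so each row's values
    # can be compared side by side.
    df = source_df.merge(target_df, on=["id"], suffixes=("_source", "_target"))
    print(df)  # columns: id, count_source, count_target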