diff --git a/README.md b/README.md index cbf7446f3..8d34b88d5 100644 --- a/README.md +++ b/README.md @@ -203,7 +203,9 @@ case specific CLI arguments or editing the saved YAML configuration file. For example, the following command creates a YAML file for the validation of the `new_york_citibike` table: `data-validation validate column -sc my_bq_conn -tc my_bq_conn -tbls bigquery-public-data.new_york_citibike.citibike_trips -c -citibike.yaml` +citibike.yaml`. + +The validation config file is saved to the GCS path specified by the `PSO_DV_CONFIG_HOME` env variable if that has been set; otherwise, it is saved to wherever the tool is run. Here is the generated YAML file named `citibike.yaml`: @@ -233,12 +235,14 @@ Once the file is updated and saved, the following command runs the new validation: ``` -data-validation run-config -c citibike.yaml +data-validation configs run -c citibike.yaml ``` View the complete YAML file for a GroupedColumn validation on the [examples](docs/examples.md#) page. +You can view a list of all saved validation YAML files using `data-validation configs list`, and print a YAML config using `data-validation configs get -c citibike.yaml`. 
+ ### Aggregated Fields Aggregate fields contain the SQL fields that you want to produce an aggregate diff --git a/data_validation/__main__.py b/data_validation/__main__.py index ffe933457..389632c6e 100644 --- a/data_validation/__main__.py +++ b/data_validation/__main__.py @@ -16,7 +16,6 @@ import logging import json -from yaml import dump, load, Dumper, Loader from data_validation import ( cli_tools, @@ -28,6 +27,9 @@ from data_validation.config_manager import ConfigManager from data_validation.data_validation import DataValidation +from yaml import dump +import sys + def _get_arg_config_file(args): """Return String yaml config file path.""" @@ -39,10 +41,8 @@ def _get_arg_config_file(args): def _get_yaml_config_from_file(config_file_path): """Return Dict of yaml validation data.""" - with open(config_file_path, "r") as yaml_file: - yaml_configs = load(yaml_file.read(), Loader=Loader) - - return yaml_configs + yaml_config = cli_tools.get_validation(config_file_path) + return yaml_config def get_aggregate_config(args, config_manager): @@ -336,12 +336,9 @@ def store_yaml_config_file(args, config_managers): Args: config_managers (list[ConfigManager]): List of config manager instances. 
""" - config_file_path = _get_arg_config_file(args) yaml_configs = convert_config_to_yaml(args, config_managers) - yaml_config_str = dump(yaml_configs, Dumper=Dumper) - - with open(config_file_path, "w") as yaml_file: - yaml_file.write(yaml_config_str) + config_file_path = _get_arg_config_file(args) + cli_tools.store_validation(config_file_path, yaml_configs) def run(args): @@ -355,7 +352,7 @@ def run(args): def run_connections(args): - """ Run commands related to connection management.""" + """Run commands related to connection management.""" if args.connect_cmd == "list": cli_tools.list_connections() elif args.connect_cmd == "add": @@ -367,8 +364,29 @@ def run_connections(args): raise ValueError(f"Connections Argument '{args.connect_cmd}' is not supported") +def run_config(args): + """Run commands related to validation config YAMLs (legacy - superceded by run_validation_configs).""" + config_managers = build_config_managers_from_yaml(args) + run_validations(args, config_managers) + + +def run_validation_configs(args): + """Run commands related to validation config YAMLs.""" + if args.validation_config_cmd == "run": + config_managers = build_config_managers_from_yaml(args) + run_validations(args, config_managers) + elif args.validation_config_cmd == "list": + cli_tools.list_validations() + elif args.validation_config_cmd == "get": + # Get and print yaml file config. 
+ yaml = cli_tools.get_validation(_get_arg_config_file(args)) + dump(yaml, sys.stdout) + else: + raise ValueError(f"Configs argument '{args.validation_config_cmd}' is not supported") + + def validate(args): - """ Run commands related to data validation.""" + """Run commands related to data validation.""" if args.validate_cmd == "column" or args.validate_cmd == "schema": run(args) else: @@ -384,8 +402,9 @@ def main(): elif args.command == "connections": run_connections(args) elif args.command == "run-config": - config_managers = build_config_managers_from_yaml(args) - run_validations(args, config_managers) + run_config(args) + elif args.command == "configs": + run_validation_configs(args) elif args.command == "find-tables": print(find_tables_using_string_matching(args)) elif args.command == "query": diff --git a/data_validation/cli_tools.py b/data_validation/cli_tools.py index 9cfce0ff2..c8bf4e42e 100644 --- a/data_validation/cli_tools.py +++ b/data_validation/cli_tools.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """The Data Validation CLI tool is intended to help to build and execute data validation runs with ease. 
@@ -53,7 +51,6 @@ from data_validation import consts from data_validation import state_manager - CONNECTION_SOURCE_FIELDS = { "BigQuery": [ ["project_id", "GCP Project to use for BigQuery"], @@ -133,7 +130,7 @@ def get_parsed_args(): - """ Return ArgParser with configured CLI arguments.""" + """Return ArgParser with configured CLI arguments.""" parser = configure_arg_parser() return parser.parse_args() @@ -149,6 +146,7 @@ def configure_arg_parser(): subparsers = parser.add_subparsers(dest="command") _configure_validate_parser(subparsers) _configure_run_config_parser(subparsers) + _configure_validation_config_parser(subparsers) _configure_connection_parser(subparsers) _configure_find_tables(subparsers) _configure_raw_query(subparsers) @@ -209,10 +207,12 @@ def _configure_raw_query(subparsers): def _configure_run_config_parser(subparsers): - """ Configure arguments to run a data validation YAML config.""" + """Configure arguments to run a data validation YAML config using the legacy run-config command.""" run_config_parser = subparsers.add_parser( - "run-config", help="Run validations stored in a YAML config file" + "run-config", + help="Run validations stored in a YAML config file. 
Note: the 'configs run' command is now the recommended approach", ) + run_config_parser.add_argument( "--config-file", "-c", @@ -220,8 +220,36 @@ def _configure_run_config_parser(subparsers): ) +def _configure_validation_config_parser(subparsers): + """Configure arguments to run a data validation YAML config.""" + validation_config_parser = subparsers.add_parser( + "configs", help="Run validations stored in a YAML config file" + ) + configs_subparsers = validation_config_parser.add_subparsers( + dest="validation_config_cmd" + ) + _ = configs_subparsers.add_parser("list", help="List your validation configs") + run_parser = configs_subparsers.add_parser( + "run", help="Run your validation configs" + ) + run_parser.add_argument( + "--config-file", + "-c", + help="YAML Config File Path to be used for building or running validations.", + ) + + get_parser = configs_subparsers.add_parser( + "get", help="Get and print a validation config" + ) + get_parser.add_argument( + "--config-file", + "-c", + help="YAML Config File Path to be used for building or running validations.", + ) + + def _configure_run_parser(subparsers): - """ Configure arguments to run a data validation.""" + """Configure arguments to run a data validation.""" # subparsers = parser.add_subparsers(dest="command") @@ -327,7 +355,7 @@ def _configure_run_parser(subparsers): def _configure_connection_parser(subparsers): - """ Configure the Parser for Connection Management. 
""" + """Configure the Parser for Connection Management.""" connection_parser = subparsers.add_parser( "connections", help="Manage & Store connections to your Databases" ) @@ -487,7 +515,7 @@ def _add_common_arguments(parser): def get_connection_config_from_args(args): - """ Return dict with connection config supplied.""" + """Return dict with connection config supplied.""" config = {consts.SOURCE_TYPE: args.connect_type} if args.connect_type == "Raw": @@ -525,7 +553,6 @@ def threshold_float(x): # os.makedirs(dir_path) # return dir_path - # def _get_connection_file(connection_name): # dir_path = _get_data_validation_directory() # file_name = f"{connection_name}.connection.json" @@ -538,7 +565,7 @@ def _generate_random_name(conn): def store_connection(connection_name, conn): - """ Store the connection config under the given name.""" + """Store the connection config under the given name.""" mgr = state_manager.StateManager() mgr.create_connection(connection_name, conn) @@ -549,7 +576,6 @@ def store_connection(connection_name, conn): # with open(file_path, "w") as file: # file.write(json.dumps(conn)) - # def get_connections(): # """ Return dict with connection name and path key pairs.""" # connections = {} @@ -567,7 +593,7 @@ def store_connection(connection_name, conn): def list_connections(): - """ List all saved connections.""" + """List all saved connections.""" mgr = state_manager.StateManager() connections = mgr.list_connections() @@ -576,7 +602,7 @@ def list_connections(): def get_connection(connection_name): - """ Return dict connection details for a specific connection.""" + """Return dict connection details for a specific connection.""" mgr = state_manager.StateManager() return mgr.get_connection_config(connection_name) @@ -586,8 +612,30 @@ def get_connection(connection_name): # return json.loads(conn_str) +def store_validation(validation_file_name, yaml_config): + """Store the validation YAML config under the given name.""" + mgr = 
state_manager.StateManager() + mgr.create_validation_yaml(validation_file_name, yaml_config) + + +def get_validation(validation_name): + """Return validation YAML for a specific connection.""" + mgr = state_manager.StateManager() + return mgr.get_validation_config(validation_name) + + +def list_validations(): + """List all saved validation YAMLs.""" + mgr = state_manager.StateManager() + validations = mgr.list_validations() + + print("Validation YAMLs found:") + for validation_name in validations: + print(f"{validation_name}.yaml") + + def get_labels(arg_labels): - """ Return list of tuples representing key-value label pairs. """ + """Return list of tuples representing key-value label pairs.""" labels = [] if arg_labels: pairs = arg_labels.split(",") @@ -672,7 +720,7 @@ def get_arg_list(arg_value, default_value=None): def get_tables_list(arg_tables, default_value=None, is_filesystem=False): - """ Returns dictionary of tables. Backwards compatible for JSON input. + """Returns dictionary of tables. Backwards compatible for JSON input. arg_table (str): tables_list argument specified default_value (Any): A default value to supply when arg_value is empty. @@ -728,7 +776,7 @@ def get_tables_list(arg_tables, default_value=None, is_filesystem=False): def split_table(table_ref, schema_required=True): - """ Returns schema and table name given list of input values. + """Returns schema and table name given list of input values. table_ref (List): Table reference i.e ['my.schema.my_table'] schema_required (boolean): Indicates whether schema is required. 
A source diff --git a/data_validation/state_manager.py b/data_validation/state_manager.py index e68706d55..d5a799467 100644 --- a/data_validation/state_manager.py +++ b/data_validation/state_manager.py @@ -22,6 +22,7 @@ import os from google.cloud import storage from typing import Dict, List +from yaml import dump, load, Dumper, Loader from data_validation import client_info from data_validation import consts @@ -93,13 +94,59 @@ def _get_connections_directory(self) -> str: def _get_connection_path(self, name: str) -> str: """Returns the full path to a connection. - Args: - name: The name of the connection. - """ + Args: + name: The name of the connection. + """ return os.path.join( self._get_connections_directory(), f"{name}.connection.json" ) + def create_validation_yaml(self, name: str, yaml_config: Dict[str, str]): + """Create a validation file and store the given config as YAML. + + Args: + name (String): The name of the validation. + yaml_config (Dict): A dictionary with the validation details. + """ + validation_path = self._get_validation_path(name) + yaml_config_str = dump(yaml_config, Dumper=Dumper) + self._write_file(validation_path, yaml_config_str) + + def get_validation_config(self, name: str) -> Dict[str, str]: + """Get a validation configuration from the expected file. + + Args: + name: The name of the validation. + Returns: + A dict of the validation values from the file. 
+ """ + validation_path = self._get_validation_path(name) + validation_bytes = self._read_file(validation_path) + return load(validation_bytes, Loader=Loader) + + def list_validations(self): + file_names = self._list_directory(self._get_validations_directory()) + return [ + file_name.split(".")[0] + for file_name in file_names + if file_name.endswith(".yaml") + ] + + def _get_validations_directory(self): + """Returns the validations directory path.""" + if self.file_system == FileSystem.LOCAL: + # Validation configs should be written to tool root dir, not consts.DEFAULT_ENV_DIRECTORY as connections are + return "./" + return os.path.join(self.file_system_root_path, "validations/") + + def _get_validation_path(self, name: str) -> str: + """Returns the full path to a validation. + + Args: + name: The name of the validation. + """ + return os.path.join(self._get_validations_directory(), f"{name}") + def _read_file(self, file_path: str) -> str: if self.file_system == FileSystem.GCS: return self._read_gcs_file(file_path) @@ -113,6 +160,8 @@ def _write_file(self, file_path: str, data: str): with open(file_path, "w") as file: file.write(data) + print("Success! Config output written to {}".format(file_path)) + def _list_directory(self, directory_path: str) -> List[str]: if self.file_system == FileSystem.GCS: return self._list_gcs_directory(directory_path) @@ -154,7 +203,7 @@ def _read_gcs_file(self, file_path: str) -> str: gcs_file_path = self._get_gcs_file_path(file_path) blob = self.gcs_bucket.get_blob(gcs_file_path) - return blob.download_as_string() + return blob.download_as_bytes() def _write_gcs_file(self, file_path: str, data: str): gcs_file_path = self._get_gcs_file_path(file_path) diff --git a/docs/connections.md b/docs/connections.md index 249fca59f..2da380d22 100644 --- a/docs/connections.md +++ b/docs/connections.md @@ -8,7 +8,7 @@ a directory specified by the env variable `PSO_DV_CONFIG_HOME`. 
## GCS Connection Management (recommended) The connections can also be stored in GCS using `PSO_DV_CONFIG_HOME`. -To do so simply add the GCS path to the environment. +To do so simply add the GCS path to the environment. Note that if this path is set, query validation configs will also be saved here. eg. `export PSO_DV_CONFIG_HOME=gs://my-bucket/my/connections/path/` diff --git a/tests/system/data_sources/test_bigquery.py b/tests/system/data_sources/test_bigquery.py index 5b7d71ea6..ebba9593e 100644 --- a/tests/system/data_sources/test_bigquery.py +++ b/tests/system/data_sources/test_bigquery.py @@ -14,10 +14,9 @@ import os -from data_validation import cli_tools, consts, data_validation +from data_validation import cli_tools, consts, data_validation, state_manager from data_validation import __main__ as main - PROJECT_ID = os.environ["PROJECT_ID"] os.environ[consts.ENV_DIRECTORY_VAR] = f"gs://{PROJECT_ID}/integration_tests/" BQ_CONN = {"source_type": "BigQuery", "project_id": PROJECT_ID} @@ -167,7 +166,9 @@ "--config-file", CLI_CONFIG_FILE, ] +EXPECTED_NUM_YAML_LINES = 35 # Expected number of lines for validation config generated by CLI_STORE_COLUMN_ARGS CLI_RUN_CONFIG_ARGS = ["run-config", "--config-file", CLI_CONFIG_FILE] +CLI_CONFIGS_RUN_ARGS = ["configs", "run", "--config-file", CLI_CONFIG_FILE] CLI_FIND_TABLES_ARGS = [ "find-tables", @@ -237,7 +238,48 @@ def test_numeric_types(): ) -def test_cli_store_yaml_then_run(): +def test_cli_store_yaml_then_run_gcs(): + """Test storing and retrieving validation YAML when GCS env var is set.""" + # Store BQ Connection + _store_bq_conn() + + # Build validation and store to file + parser = cli_tools.configure_arg_parser() + mock_args = parser.parse_args(CLI_STORE_COLUMN_ARGS) + main.run(mock_args) + + # Look for YAML file in GCS env directory, since that has been set + yaml_file_path = os.path.join( + os.environ[consts.ENV_DIRECTORY_VAR], "validations/", CLI_CONFIG_FILE + ) + + # The number of lines is not significant, 
except that it represents + # the exact file expected to be created. Any change to this value + # is likely to be a breaking change and must be assessed. + mgr = state_manager.StateManager() + validation_bytes = mgr._read_file(yaml_file_path) + yaml_file_str = validation_bytes.decode("utf-8") + assert len(yaml_file_str.splitlines()) == EXPECTED_NUM_YAML_LINES + + # Run generated config using 'run-config' command + run_config_args = parser.parse_args(CLI_RUN_CONFIG_ARGS) + config_managers = main.build_config_managers_from_yaml(run_config_args) + main.run_validations(run_config_args, config_managers) + + # Run generated config using 'configs run' command + run_config_args = parser.parse_args(CLI_CONFIGS_RUN_ARGS) + config_managers = main.build_config_managers_from_yaml(run_config_args) + main.run_validations(run_config_args, config_managers) + + # _remove_bq_conn() + + +def test_cli_store_yaml_then_run_local(): + """Test storing and retrieving validation YAML when GCS env var not set.""" + # Unset GCS env var so that YAML is saved locally + gcs_path = os.environ[consts.ENV_DIRECTORY_VAR] + os.environ[consts.ENV_DIRECTORY_VAR] = "" + # Store BQ Connection _store_bq_conn() @@ -251,16 +293,24 @@ def test_cli_store_yaml_then_run(): # The number of lines is not significant, except that it represents # the exact file expected to be created. Any change to this value # is likely to be a breaking change and must be assessed. 
- assert len(yaml_file.readlines()) == 35 + assert len(yaml_file.readlines()) == EXPECTED_NUM_YAML_LINES - # Run generated config + # Run generated config using 'run-config' command run_config_args = parser.parse_args(CLI_RUN_CONFIG_ARGS) config_managers = main.build_config_managers_from_yaml(run_config_args) main.run_validations(run_config_args, config_managers) + # Run generated config using 'configs run' command + run_config_args = parser.parse_args(CLI_CONFIGS_RUN_ARGS) + config_managers = main.build_config_managers_from_yaml(run_config_args) + main.run_validations(run_config_args, config_managers) + os.remove(yaml_file_path) # _remove_bq_conn() + # Re-set GCS env var + os.environ[consts.ENV_DIRECTORY_VAR] = gcs_path + def test_cli_find_tables(): _store_bq_conn() diff --git a/tests/unit/test_cli_tools.py b/tests/unit/test_cli_tools.py index 7ac30ebce..564324c64 100644 --- a/tests/unit/test_cli_tools.py +++ b/tests/unit/test_cli_tools.py @@ -18,7 +18,6 @@ from data_validation import cli_tools - TEST_CONN = '{"source_type":"Example"}' CLI_ARGS = { "command": "validate", @@ -59,6 +58,35 @@ "example-project", ] +TEST_VALIDATION_CONFIG = { + "source": "example", + "target": "example", + "result_handler": {}, + "validations": [ + { + "type": "Column", + "table_name": "citibike_trips", + "schema_name": "bigquery-public-data.new_york_citibike", + "target_schema_name": "bigquery-public-data.new_york_citibike", + "target_table_name": "citibike_trips", + "labels": [], + "threshold": 0.0, + "format": "table", + "filters": [], + "aggregates": [ + { + "source_column": None, + "target_column": None, + "field_alias": "count", + "type": "count", + } + ], + } + ], +} + +WRITE_SUCCESS_STRING = "Success! 
Config output written to" + CLI_FIND_TABLES_ARGS = [ "find-tables", "--source-conn", @@ -121,6 +149,8 @@ def test_create_and_list_connections(capsys, fs): conn = cli_tools.get_connection_config_from_args(args) cli_tools.store_connection(args.connection_name, conn) + captured = capsys.readouterr() + assert WRITE_SUCCESS_STRING in captured.out # List Connection cli_tools.list_connections() @@ -129,6 +159,36 @@ assert captured.out == "Connection Name: test\n" +def test_configure_arg_parser_list_and_run_validation_configs(): + """Test configuring arg parse in different ways.""" + parser = cli_tools.configure_arg_parser() + + args = parser.parse_args(["configs", "list"]) + assert args.command == "configs" + assert args.validation_config_cmd == "list" + + args = parser.parse_args(["configs", "run"]) + assert args.command == "configs" + assert args.validation_config_cmd == "run" + + +def test_create_and_list_and_get_validations(capsys, fs): + # Create validation config file + cli_tools.store_validation("example_validation.yaml", TEST_VALIDATION_CONFIG) + captured = capsys.readouterr() + assert WRITE_SUCCESS_STRING in captured.out + + # List validation configs + cli_tools.list_validations() + captured = capsys.readouterr() + + assert captured.out == "Validation YAMLs found:\nexample_validation.yaml\n" + + # Retrieve the stored validation config + yaml_config = cli_tools.get_validation("example_validation.yaml") + assert yaml_config == TEST_VALIDATION_CONFIG + + def test_find_tables_config(): parser = cli_tools.configure_arg_parser() args = parser.parse_args(CLI_FIND_TABLES_ARGS) @@ -167,7 +227,7 @@ def test_get_labels(test_input, expected): ], ) def test_get_labels_err(test_input): - """Ensure that Value Error is raised when incorrect label argument is provided. 
""" + """Ensure that Value Error is raised when incorrect label argument is provided.""" with pytest.raises(ValueError): cli_tools.get_labels(test_input) @@ -329,7 +389,7 @@ def test_get_result_handler(test_input, expected): ], ) def test_get_filters(test_input, expected): - """ Test get filters from file function. """ + """Test get filters from file function.""" res = cli_tools.get_filters(test_input) assert res == expected @@ -338,7 +398,7 @@ def test_get_filters(test_input, expected): "test_input", [("source:"), ("invalid:filter:count")], ) def test_get_filters_err(test_input): - """ Test get filters function returns error. """ + """Test get filters function returns error.""" with pytest.raises(ValueError): cli_tools.get_filters(test_input) diff --git a/tests/unit/test_state_manager.py b/tests/unit/test_state_manager.py index 5bce7e57e..cdee265e0 100644 --- a/tests/unit/test_state_manager.py +++ b/tests/unit/test_state_manager.py @@ -14,12 +14,38 @@ from data_validation import state_manager - TEST_CONN_NAME = "example" TEST_CONN = { "source_type": "BigQuery", "project_id": "my-project", } +TEST_VALIDATION_NAME = "citibike.yaml" +TEST_VALIDATION_CONFIG = { + "source": "example", + "target": "example", + "result_handler": {}, + "validations": [ + { + "type": "Column", + "table_name": "citibike_trips", + "schema_name": "bigquery-public-data.new_york_citibike", + "target_schema_name": "bigquery-public-data.new_york_citibike", + "target_table_name": "citibike_trips", + "labels": [], + "threshold": 0.0, + "format": "table", + "filters": [], + "aggregates": [ + { + "source_column": None, + "target_column": None, + "field_alias": "count", + "type": "count", + } + ], + } + ], +} def test_create_and_get_connection_config(capsys, fs): @@ -50,3 +76,19 @@ def test_create_unknown_filepath(capsys, fs): file_path = manager._get_connection_path(TEST_CONN_NAME) expected_file_path = files_directory + f"{TEST_CONN_NAME}.connection.json" assert file_path == expected_file_path + + 
+def test_create_and_get_validation_config(capsys, fs): + manager = state_manager.StateManager() + manager.create_validation_yaml(TEST_VALIDATION_NAME, TEST_VALIDATION_CONFIG) + + config = manager.get_validation_config(TEST_VALIDATION_NAME) + assert config == TEST_VALIDATION_CONFIG + + +def test_create_and_list_validation(capsys, fs): + manager = state_manager.StateManager() + manager.create_validation_yaml(TEST_VALIDATION_NAME, TEST_VALIDATION_CONFIG) + + validations = manager.list_validations() + assert validations == [TEST_VALIDATION_NAME.split(".")[0]]