diff --git a/data_validation/data_validation.py b/data_validation/data_validation.py
index 723e11778..013d15535 100644
--- a/data_validation/data_validation.py
+++ b/data_validation/data_validation.py
@@ -107,6 +107,7 @@ def _add_random_row_filter(self):
         # Filter for only first primary key (multi-pk filter not supported)
         primary_key_info = self.config_manager.primary_keys[0]
+
         query = RandomRowBuilder(
             [primary_key_info[consts.CONFIG_SOURCE_COLUMN]],
             self.config_manager.random_row_batch_size(),
@@ -114,9 +115,12 @@ def _add_random_row_filter(self):
             self.config_manager.source_client,
             self.config_manager.source_schema,
             self.config_manager.source_table,
+            self.validation_builder.source_builder,
         )
         random_rows = self.config_manager.source_client.execute(query)
+        if len(random_rows) == 0:
+            return
         filter_field = {
             consts.CONFIG_TYPE: consts.FILTER_TYPE_ISIN,
             consts.CONFIG_FILTER_SOURCE_COLUMN: primary_key_info[
diff --git a/data_validation/query_builder/random_row_builder.py b/data_validation/query_builder/random_row_builder.py
index c77d3c847..01f12c266 100644
--- a/data_validation/query_builder/random_row_builder.py
+++ b/data_validation/query_builder/random_row_builder.py
@@ -14,20 +14,20 @@
 import random
 import logging
+from typing import List
+from io import StringIO
 import ibis
 import ibis.expr.operations as ops
 import ibis.expr.types as tz
 import ibis.expr.rules as rlz
 import ibis.backends.base_sqlalchemy.compiler as sql_compiler
+import ibis.backends.pandas.execution.util as pandas_util
 from ibis_bigquery import BigQueryClient
 from ibis.backends.impala.client import ImpalaClient
 from ibis.backends.pandas.client import PandasClient
-import ibis.backends.pandas.execution.util as pandas_util
-
 from ibis.expr.signature import Argument as Arg
-from typing import List
 from data_validation import clients
-from io import StringIO
+from data_validation.query_builder.query_builder import QueryBuilder

 try:
     from third_party.ibis.ibis_teradata.client import TeradataClient
@@ -90,7 +90,11 @@ def __init__(self, primary_keys: List[str], batch_size: int):
         self.batch_size = batch_size

     def compile(
-        self, data_client: ibis.client, schema_name: str, table_name: str
+        self,
+        data_client: ibis.client,
+        schema_name: str,
+        table_name: str,
+        query_builder: QueryBuilder,
     ) -> ibis.Expr:
         """Return an Ibis query object
@@ -100,7 +104,9 @@ def compile(
             data_client (IbisClient): The client used to query random rows.
             schema_name (String): The name of the schema for the given table.
             table_name (String): The name of the table to query.
         """
         table = clients.get_ibis_table(data_client, schema_name, table_name)
-        randomly_sorted_table = self.maybe_add_random_sort(data_client, table)
+        compiled_filters = query_builder.compile_filter_fields(table)
+        filtered_table = table.filter(compiled_filters) if compiled_filters else table
+        randomly_sorted_table = self.maybe_add_random_sort(data_client, filtered_table)
         query = randomly_sorted_table.limit(self.batch_size)[self.primary_keys]

         return query
diff --git a/tests/system/data_sources/test_bigquery.py b/tests/system/data_sources/test_bigquery.py
index 2d1d2c53c..1a7a08431 100644
--- a/tests/system/data_sources/test_bigquery.py
+++ b/tests/system/data_sources/test_bigquery.py
@@ -17,6 +17,8 @@
 from data_validation import __main__ as main
 from data_validation import cli_tools, clients, consts, data_validation, state_manager
 from data_validation.query_builder import random_row_builder
+from data_validation.query_builder.query_builder import QueryBuilder
+

 PROJECT_ID = os.environ["PROJECT_ID"]
 os.environ[consts.ENV_DIRECTORY_VAR] = f"gs://{PROJECT_ID}/integration_tests/"
@@ -570,7 +572,10 @@ def test_random_row_query_builder():
     bq_client = clients.get_data_client(BQ_CONN)
     row_query_builder = random_row_builder.RandomRowBuilder(["station_id"], 10)
     query = row_query_builder.compile(
-        bq_client, "bigquery-public-data.new_york_citibike", "citibike_stations"
+        bq_client,
+        "bigquery-public-data.new_york_citibike",
+        "citibike_stations",
+        QueryBuilder([], [], [], [], [], None),
     )
     random_rows = bq_client.execute(query)
diff --git a/tests/unit/query_builder/test_random_row_builder.py b/tests/unit/query_builder/test_random_row_builder.py
index 738cf1fe0..806304eed 100644
--- a/tests/unit/query_builder/test_random_row_builder.py
+++ b/tests/unit/query_builder/test_random_row_builder.py
@@ -15,6 +15,7 @@
 import pytest

 from data_validation import clients
+from data_validation.query_builder.query_builder import QueryBuilder

 TABLE_FILE_PATH = "table_data.json"
@@ -68,7 +69,9 @@ def test_compile(module_under_test, fs):
     primary_keys = ["col_a"]

     builder = module_under_test.RandomRowBuilder(primary_keys, 10)
-    query = builder.compile(client, None, CONN_CONFIG["table_name"])
+    query = builder.compile(
+        client, None, CONN_CONFIG["table_name"], QueryBuilder([], [], [], [], [], None)
+    )
     df = client.execute(query)

     assert list(df.columns) == primary_keys