Skip to content

Commit

Permalink
fix: random rows with filter option (#582)
Browse files Browse the repository at this point in the history
  • Loading branch information
kanhaPrayas committed Sep 13, 2022
1 parent 489654c commit da4faaf
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 8 deletions.
4 changes: 4 additions & 0 deletions data_validation/data_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,20 @@ def _add_random_row_filter(self):

# Filter on the first primary key only (multi-column primary keys are not supported)
primary_key_info = self.config_manager.primary_keys[0]

query = RandomRowBuilder(
[primary_key_info[consts.CONFIG_SOURCE_COLUMN]],
self.config_manager.random_row_batch_size(),
).compile(
self.config_manager.source_client,
self.config_manager.source_schema,
self.config_manager.source_table,
self.validation_builder.source_builder,
)

random_rows = self.config_manager.source_client.execute(query)
if len(random_rows) == 0:
return
filter_field = {
consts.CONFIG_TYPE: consts.FILTER_TYPE_ISIN,
consts.CONFIG_FILTER_SOURCE_COLUMN: primary_key_info[
Expand Down
18 changes: 12 additions & 6 deletions data_validation/query_builder/random_row_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@

import random
import logging
from typing import List
from io import StringIO
import ibis
import ibis.expr.operations as ops
import ibis.expr.types as tz
import ibis.expr.rules as rlz
import ibis.backends.base_sqlalchemy.compiler as sql_compiler
import ibis.backends.pandas.execution.util as pandas_util
from ibis_bigquery import BigQueryClient
from ibis.backends.impala.client import ImpalaClient
from ibis.backends.pandas.client import PandasClient
import ibis.backends.pandas.execution.util as pandas_util

from ibis.expr.signature import Argument as Arg
from typing import List
from data_validation import clients
from io import StringIO
from data_validation.query_builder.query_builder import QueryBuilder

try:
from third_party.ibis.ibis_teradata.client import TeradataClient
Expand Down Expand Up @@ -90,7 +90,11 @@ def __init__(self, primary_keys: List[str], batch_size: int):
self.batch_size = batch_size

def compile(
self, data_client: ibis.client, schema_name: str, table_name: str
self,
data_client: ibis.client,
schema_name: str,
table_name: str,
query_builder: QueryBuilder,
) -> ibis.Expr:
"""Return an Ibis query object
Expand All @@ -100,7 +104,9 @@ def compile(
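table_name (String): The name of the table to query.
query_builder (QueryBuilder): Supplies the filter fields applied to the table before random rows are selected.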
"""
table = clients.get_ibis_table(data_client, schema_name, table_name)
randomly_sorted_table = self.maybe_add_random_sort(data_client, table)
compiled_filters = query_builder.compile_filter_fields(table)
filtered_table = table.filter(compiled_filters) if compiled_filters else table
randomly_sorted_table = self.maybe_add_random_sort(data_client, filtered_table)
query = randomly_sorted_table.limit(self.batch_size)[self.primary_keys]

return query
Expand Down
7 changes: 6 additions & 1 deletion tests/system/data_sources/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from data_validation import __main__ as main
from data_validation import cli_tools, clients, consts, data_validation, state_manager
from data_validation.query_builder import random_row_builder
from data_validation.query_builder.query_builder import QueryBuilder


PROJECT_ID = os.environ["PROJECT_ID"]
os.environ[consts.ENV_DIRECTORY_VAR] = f"gs://{PROJECT_ID}/integration_tests/"
Expand Down Expand Up @@ -570,7 +572,10 @@ def test_random_row_query_builder():
bq_client = clients.get_data_client(BQ_CONN)
row_query_builder = random_row_builder.RandomRowBuilder(["station_id"], 10)
query = row_query_builder.compile(
bq_client, "bigquery-public-data.new_york_citibike", "citibike_stations"
bq_client,
"bigquery-public-data.new_york_citibike",
"citibike_stations",
QueryBuilder([], [], [], [], [], None),
)

random_rows = bq_client.execute(query)
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/query_builder/test_random_row_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pytest

from data_validation import clients
from data_validation.query_builder.query_builder import QueryBuilder


TABLE_FILE_PATH = "table_data.json"
Expand Down Expand Up @@ -68,7 +69,9 @@ def test_compile(module_under_test, fs):
primary_keys = ["col_a"]
builder = module_under_test.RandomRowBuilder(primary_keys, 10)

query = builder.compile(client, None, CONN_CONFIG["table_name"])
query = builder.compile(
client, None, CONN_CONFIG["table_name"], QueryBuilder([], [], [], [], [], None)
)
df = client.execute(query)

assert list(df.columns) == primary_keys
Expand Down

0 comments on commit da4faaf

Please sign in to comment.