From 783cdf8810c29755b26e4894555b6dd03f4c9025 Mon Sep 17 00:00:00 2001 From: ngdav <52801477+ngdav@users.noreply.github.com> Date: Wed, 30 Mar 2022 13:12:11 -0400 Subject: [PATCH] fix: issue 265 add cloud spanner functionality (#394) * fix: issue 265 add cloud spanner functionality *Add/override raw_sql() for cloud spanner *Add SpannerCursor class to implement ibis client interface *This is because current spanner client returns a dataframe which is not compatible with the context manager that raw_query functionality is expecting * fix: update test for raw_sql() --- third_party/ibis/ibis_cloud_spanner/client.py | 80 +++++++++++++++++-- .../ibis_cloud_spanner/tests/test_client.py | 2 +- 2 files changed, 74 insertions(+), 8 deletions(-) diff --git a/third_party/ibis/ibis_cloud_spanner/client.py b/third_party/ibis/ibis_cloud_spanner/client.py index 1f1692e69..dd9567eec 100644 --- a/third_party/ibis/ibis_cloud_spanner/client.py +++ b/third_party/ibis/ibis_cloud_spanner/client.py @@ -212,6 +212,48 @@ def execute(self): return dataframe_output +class SpannerCursor: + """Spanner cursor. + This allows the Spanner client to reuse machinery in + :file:`ibis/client.py`. + """ + + def __init__(self, results): + self.results = results + + def fetchall(self): + """Fetch all rows.""" + result = self.results + return [tuple(row) for row in result] + + @property + def columns(self): + """Return the columns of the result set.""" + result = self.results + return [field.name for field in result.fields] + + @property + def description(self): + """Get the fields of the result set's schema.""" + return self.results.metadata.row_type + + def __enter__(self): + # For compatibility when constructed from Query.execute() + """No-op for compatibility. + See Also + -------- + ibis.client.Query.execute + """ + return self + + def __exit__(self, exc_type, exc_value, traceback): + """No-op for compatibility. 
+ See Also + -------- + ibis.client.Query.execute + """ + + class CloudSpannerDatabase(Database): """A Cloud spanner dataset.""" @@ -240,17 +282,16 @@ def __init__(self, instance_id, database_id, project_id=None, credentials=None): self.spanner_client = spanner.Client(project=project_id) self.instance = self.spanner_client.instance(instance_id) self.database_name = self.instance.database(database_id) - ( - self.data_instance, - self.dataset, - ) = parse_instance_and_dataset(instance_id, database_id) + (self.data_instance, self.dataset,) = parse_instance_and_dataset( + instance_id, database_id + ) self.client = cs.Client() def _parse_instance_and_dataset(self, dataset): if not dataset and not self.dataset: raise ValueError("Unable to determine Cloud Spanner dataset.") instance, dataset = parse_instance_and_dataset( - self.data_instance,(dataset or self.dataset) + self.data_instance, (dataset or self.dataset) ) return instance, dataset @@ -366,6 +407,33 @@ def _execute(self, stmt, results=True, query_parameters=None): data_qry = pandas_df.to_pandas(snapshot, stmt, query_parameters) return data_qry + def raw_sql(self, query: str, results=False, params=None): + query_parameters = [ + cloud_spanner_param(param, value) for param, value in (params or {}).items() + ] + spanner_client = spanner.Client() + instance_id = self.instance_id + instance = spanner_client.instance(instance_id) + database_id = self.dataset_id + database_1 = instance.database(database_id) + with database_1.snapshot() as snapshot: + if query_parameters: + param = {} + param_type = {} + for i in query_parameters: + param.update(i["params"]) + param_type.update(i["param_types"]) + + results = snapshot.execute_sql( + query, params=param, param_types=param_type + ) + + else: + results = snapshot.execute_sql(query) + + sp = SpannerCursor(results) + return sp + def database(self, name=None): if name is None and self.dataset is None: raise ValueError( @@ -385,5 +453,3 @@ def dataset(self, database): def 
exists_database(self, name): return self.instance.database(name).exists() - - diff --git a/third_party/ibis/ibis_cloud_spanner/tests/test_client.py b/third_party/ibis/ibis_cloud_spanner/tests/test_client.py index 1e2cd0ea3..37837fccd 100644 --- a/third_party/ibis/ibis_cloud_spanner/tests/test_client.py +++ b/third_party/ibis/ibis_cloud_spanner/tests/test_client.py @@ -263,7 +263,7 @@ def test_scalar_param_date(alltypes, df, date_value): def test_raw_sql(client): - assert (client.raw_sql("SELECT 1")).iloc[0][0] == 1 + assert (client.raw_sql("SELECT 1")).fetchall()[0][0] == 1 def test_scalar_param_scope(alltypes):