Skip to content

Commit

Permalink
fix: issue 265 add cloud spanner functionality (#394)
Browse files Browse the repository at this point in the history
* fix: issue 265 add cloud spanner functionality

*Add/override raw_sql() for cloud spanner
*Add SpannerCursor class to implement ibis client interface
*This is because current spanner client returns a dataframe which is not
compatible with the context manager that raw_query funtionality is
expecting

* fix: update test for raw_sql()
  • Loading branch information
ngdav committed Mar 30, 2022
1 parent 88b6620 commit 783cdf8
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 8 deletions.
80 changes: 73 additions & 7 deletions third_party/ibis/ibis_cloud_spanner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,48 @@ def execute(self):
return dataframe_output


class SpannerCursor:
"""Spanner cursor.
This allows the Spanner client to reuse machinery in
:file:`ibis/client.py`.
"""

def __init__(self, results):
self.results = results

def fetchall(self):
"""Fetch all rows."""
result = self.results
return [tuple(row) for row in result]

@property
def columns(self):
"""Return the columns of the result set."""
result = self.results
return [field.name for field in result.fields]

@property
def description(self):
"""Get the fields of the result set's schema."""
return self.results.metadata.row_type

def __enter__(self):
# For compatibility when constructed from Query.execute()
"""No-op for compatibility.
See Also
--------
ibis.client.Query.execute
"""
return self

def __exit__(self, exc_type, exc_value, traceback):
"""No-op for compatibility.
See Also
--------
ibis.client.Query.execute
"""


class CloudSpannerDatabase(Database):
"""A Cloud spanner dataset."""

Expand Down Expand Up @@ -240,17 +282,16 @@ def __init__(self, instance_id, database_id, project_id=None, credentials=None):
self.spanner_client = spanner.Client(project=project_id)
self.instance = self.spanner_client.instance(instance_id)
self.database_name = self.instance.database(database_id)
(
self.data_instance,
self.dataset,
) = parse_instance_and_dataset(instance_id, database_id)
(self.data_instance, self.dataset,) = parse_instance_and_dataset(
instance_id, database_id
)
self.client = cs.Client()

def _parse_instance_and_dataset(self, dataset):
if not dataset and not self.dataset:
raise ValueError("Unable to determine Cloud Spanner dataset.")
instance, dataset = parse_instance_and_dataset(
self.data_instance,(dataset or self.dataset)
self.data_instance, (dataset or self.dataset)
)

return instance, dataset
Expand Down Expand Up @@ -366,6 +407,33 @@ def _execute(self, stmt, results=True, query_parameters=None):
data_qry = pandas_df.to_pandas(snapshot, stmt, query_parameters)
return data_qry

def raw_sql(self, query: str, results=False, params=None):
query_parameters = [
cloud_spanner_param(param, value) for param, value in (params or {}).items()
]
spanner_client = spanner.Client()
instance_id = self.instance_id
instance = spanner_client.instance(instance_id)
database_id = self.dataset_id
database_1 = instance.database(database_id)
with database_1.snapshot() as snapshot:
if query_parameters:
param = {}
param_type = {}
for i in query_parameters:
param.update(i["params"])
param_type.update(i["param_types"])

results = snapshot.execute_sql(
query, params=param, param_types=param_type
)

else:
results = snapshot.execute_sql(query)

sp = SpannerCursor(results)
return sp

def database(self, name=None):
if name is None and self.dataset is None:
raise ValueError(
Expand All @@ -385,5 +453,3 @@ def dataset(self, database):

def exists_database(self, name):
return self.instance.database(name).exists()


2 changes: 1 addition & 1 deletion third_party/ibis/ibis_cloud_spanner/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def test_scalar_param_date(alltypes, df, date_value):


def test_raw_sql(client):
assert (client.raw_sql("SELECT 1")).iloc[0][0] == 1
assert (client.raw_sql("SELECT 1")).fetchall()[0][0] == 1


def test_scalar_param_scope(alltypes):
Expand Down

0 comments on commit 783cdf8

Please sign in to comment.