Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: issue 265 add cloud spanner functionality #394

Merged
merged 2 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 73 additions & 7 deletions third_party/ibis/ibis_cloud_spanner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,48 @@ def execute(self):
return dataframe_output


class SpannerCursor:
"""Spanner cursor.
This allows the Spanner client to reuse machinery in
:file:`ibis/client.py`.
"""

def __init__(self, results):
self.results = results

def fetchall(self):
"""Fetch all rows."""
result = self.results
return [tuple(row) for row in result]

@property
def columns(self):
"""Return the columns of the result set."""
result = self.results
return [field.name for field in result.fields]

@property
def description(self):
"""Get the fields of the result set's schema."""
return self.results.metadata.row_type

def __enter__(self):
# For compatibility when constructed from Query.execute()
"""No-op for compatibility.
See Also
--------
ibis.client.Query.execute
"""
return self

def __exit__(self, exc_type, exc_value, traceback):
"""No-op for compatibility.
See Also
--------
ibis.client.Query.execute
"""


class CloudSpannerDatabase(Database):
"""A Cloud spanner dataset."""

Expand Down Expand Up @@ -240,17 +282,16 @@ def __init__(self, instance_id, database_id, project_id=None, credentials=None):
self.spanner_client = spanner.Client(project=project_id)
self.instance = self.spanner_client.instance(instance_id)
self.database_name = self.instance.database(database_id)
(
self.data_instance,
self.dataset,
) = parse_instance_and_dataset(instance_id, database_id)
(self.data_instance, self.dataset,) = parse_instance_and_dataset(
instance_id, database_id
)
self.client = cs.Client()

def _parse_instance_and_dataset(self, dataset):
if not dataset and not self.dataset:
raise ValueError("Unable to determine Cloud Spanner dataset.")
instance, dataset = parse_instance_and_dataset(
self.data_instance,(dataset or self.dataset)
self.data_instance, (dataset or self.dataset)
)

return instance, dataset
Expand Down Expand Up @@ -366,6 +407,33 @@ def _execute(self, stmt, results=True, query_parameters=None):
data_qry = pandas_df.to_pandas(snapshot, stmt, query_parameters)
return data_qry

def raw_sql(self, query: str, results=False, params=None):
query_parameters = [
cloud_spanner_param(param, value) for param, value in (params or {}).items()
]
spanner_client = spanner.Client()
instance_id = self.instance_id
instance = spanner_client.instance(instance_id)
database_id = self.dataset_id
database_1 = instance.database(database_id)
with database_1.snapshot() as snapshot:
if query_parameters:
param = {}
param_type = {}
for i in query_parameters:
param.update(i["params"])
param_type.update(i["param_types"])

results = snapshot.execute_sql(
query, params=param, param_types=param_type
)

else:
results = snapshot.execute_sql(query)

sp = SpannerCursor(results)
return sp

def database(self, name=None):
if name is None and self.dataset is None:
raise ValueError(
Expand All @@ -385,5 +453,3 @@ def dataset(self, database):

def exists_database(self, name):
return self.instance.database(name).exists()


2 changes: 1 addition & 1 deletion third_party/ibis/ibis_cloud_spanner/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def test_scalar_param_date(alltypes, df, date_value):


def test_raw_sql(client):
assert (client.raw_sql("SELECT 1")).iloc[0][0] == 1
assert (client.raw_sql("SELECT 1")).fetchall()[0][0] == 1


def test_scalar_param_scope(alltypes):
Expand Down