Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for model contracts in v1.5 #148

Merged
merged 14 commits into from
Apr 26, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

import duckdb

from .credentials import DuckDBCredentials
from .plugins import Plugin
from .utils import SourceConfig
from ..credentials import DuckDBCredentials
from ..plugins import Plugin
from ..utils import SourceConfig
from dbt.contracts.connection import AdapterResponse
from dbt.exceptions import DbtRuntimeError

Expand All @@ -28,36 +28,12 @@ def _ensure_event_loop():
asyncio.set_event_loop(loop)


class DuckDBCursorWrapper:
def __init__(self, cursor):
self._cursor = cursor

# forward along all non-execute() methods/attribute look-ups
def __getattr__(self, name):
return getattr(self._cursor, name)

def execute(self, sql, bindings=None):
try:
if bindings is None:
return self._cursor.execute(sql)
else:
return self._cursor.execute(sql, bindings)
except RuntimeError as e:
raise DbtRuntimeError(str(e))


class DuckDBConnectionWrapper:
def __init__(self, cursor):
self._cursor = DuckDBCursorWrapper(cursor)

def close(self):
self._cursor.close()

def cursor(self):
return self._cursor


class Environment(abc.ABC):
"""An Environment is an abstraction to describe *where* the code you execute in your dbt-duckdb project
actually runs. This could be the local Python process that runs dbt (which is the default),
a remote server (like a Buena Vista instance), or even a Jupyter notebook kernel.
"""

@abc.abstractmethod
def handle(self):
pass
Expand All @@ -66,13 +42,13 @@ def handle(self):
def submit_python_job(self, handle, parsed_model: dict, compiled_code: str) -> AdapterResponse:
pass

def get_binding_char(self) -> str:
return "?"

@abc.abstractmethod
def load_source(self, plugin_name: str, source_config: SourceConfig) -> str:
pass

def get_binding_char(self) -> str:
return "?"

@classmethod
def initialize_db(cls, creds: DuckDBCredentials):
config = creds.config_options or {}
Expand Down Expand Up @@ -159,71 +135,14 @@ def run_python_job(cls, con, load_df_function, identifier: str, compiled_code: s
os.unlink(mod_file.name)


class LocalEnvironment(Environment):
def __init__(self, credentials: DuckDBCredentials):
self.conn = self.initialize_db(credentials)
self._plugins = self.initialize_plugins(credentials)
self.creds = credentials

def handle(self):
# Extensions/settings need to be configured per cursor
cursor = self.initialize_cursor(self.creds, self.conn.cursor())
return DuckDBConnectionWrapper(cursor)

def submit_python_job(self, handle, parsed_model: dict, compiled_code: str) -> AdapterResponse:
con = handle.cursor()

def ldf(table_name):
return con.query(f"select * from {table_name}")

self.run_python_job(con, ldf, parsed_model["alias"], compiled_code)
return AdapterResponse(_message="OK")

def load_source(self, plugin_name: str, source_config: SourceConfig):
if plugin_name not in self._plugins:
raise Exception(
f"Plugin {plugin_name} not found; known plugins are: "
+ ",".join(self._plugins.keys())
)
plugin = self._plugins[plugin_name]
handle = self.handle()
cursor = handle.cursor()
save_mode = source_config.meta.get("save_mode", "overwrite")
if save_mode in ("ignore", "error_if_exists"):
schema, identifier = source_config.schema, source_config.identifier
q = f"""SELECT COUNT(1)
FROM information_schema.tables
WHERE table_schema = '{schema}'
AND table_name = '{identifier}'
"""
if cursor.execute(q).fetchone()[0]:
if save_mode == "error_if_exists":
raise Exception(f"Source {source_config.table_name()} already exists!")
else:
# Nothing to do (we ignore the existing table)
return
df = plugin.load(source_config)
assert df is not None
materialization = source_config.meta.get("materialization", "table")
cursor.execute(
f"CREATE OR REPLACE {materialization} {source_config.table_name()} AS SELECT * FROM df"
)
cursor.close()
handle.close()

def close(self):
if self.conn:
self.conn.close()
self.conn = None

def __del__(self):
self.close()


def create(creds: DuckDBCredentials) -> Environment:
"""Create an Environment based on the credentials passed in."""

if creds.remote:
from .buenavista import BVEnvironment

return BVEnvironment(creds)
else:
from .local import LocalEnvironment

return LocalEnvironment(creds)
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import psycopg2

from . import credentials
from . import utils
from .environments import Environment
from . import Environment
from .. import credentials
from .. import utils
from dbt.contracts.connection import AdapterResponse


Expand Down
95 changes: 95 additions & 0 deletions dbt/adapters/duckdb/environments/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from . import Environment
from .. import credentials
from .. import utils
from dbt.contracts.connection import AdapterResponse
from dbt.exceptions import DbtRuntimeError


class DuckDBCursorWrapper:
def __init__(self, cursor):
self._cursor = cursor

# forward along all non-execute() methods/attribute look-ups
def __getattr__(self, name):
return getattr(self._cursor, name)

def execute(self, sql, bindings=None):
try:
if bindings is None:
return self._cursor.execute(sql)
else:
return self._cursor.execute(sql, bindings)
except RuntimeError as e:
raise DbtRuntimeError(str(e))


class DuckDBConnectionWrapper:
def __init__(self, cursor):
self._cursor = DuckDBCursorWrapper(cursor)

def close(self):
self._cursor.close()

def cursor(self):
return self._cursor


class LocalEnvironment(Environment):
def __init__(self, credentials: credentials.DuckDBCredentials):
self.conn = self.initialize_db(credentials)
self._plugins = self.initialize_plugins(credentials)
self.creds = credentials

def handle(self):
# Extensions/settings need to be configured per cursor
cursor = self.initialize_cursor(self.creds, self.conn.cursor())
return DuckDBConnectionWrapper(cursor)

def submit_python_job(self, handle, parsed_model: dict, compiled_code: str) -> AdapterResponse:
con = handle.cursor()

def ldf(table_name):
return con.query(f"select * from {table_name}")

self.run_python_job(con, ldf, parsed_model["alias"], compiled_code)
return AdapterResponse(_message="OK")

def load_source(self, plugin_name: str, source_config: utils.SourceConfig):
if plugin_name not in self._plugins:
raise Exception(
f"Plugin {plugin_name} not found; known plugins are: "
+ ",".join(self._plugins.keys())
)
plugin = self._plugins[plugin_name]
handle = self.handle()
cursor = handle.cursor()
save_mode = source_config.meta.get("save_mode", "overwrite")
if save_mode in ("ignore", "error_if_exists"):
schema, identifier = source_config.schema, source_config.identifier
q = f"""SELECT COUNT(1)
FROM information_schema.tables
WHERE table_schema = '{schema}'
AND table_name = '{identifier}'
"""
if cursor.execute(q).fetchone()[0]:
if save_mode == "error_if_exists":
raise Exception(f"Source {source_config.table_name()} already exists!")
else:
# Nothing to do (we ignore the existing table)
return
df = plugin.load(source_config)
assert df is not None
materialization = source_config.meta.get("materialization", "table")
cursor.execute(
f"CREATE OR REPLACE {materialization} {source_config.table_name()} AS SELECT * FROM df"
)
cursor.close()
handle.close()

def close(self):
if self.conn:
self.conn.close()
self.conn = None

def __del__(self):
self.close()
35 changes: 35 additions & 0 deletions dbt/adapters/duckdb/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@

from dbt.adapters.base import BaseRelation
from dbt.adapters.base.column import Column
from dbt.adapters.base.impl import ConstraintSupport
from dbt.adapters.base.meta import available
from dbt.adapters.duckdb.connections import DuckDBConnectionManager
from dbt.adapters.duckdb.glue import create_or_update_table
from dbt.adapters.duckdb.relation import DuckDBRelation
from dbt.adapters.sql import SQLAdapter
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.nodes import ColumnLevelConstraint
from dbt.contracts.graph.nodes import ConstraintType
from dbt.exceptions import DbtInternalError
from dbt.exceptions import DbtRuntimeError

Expand All @@ -22,6 +25,14 @@ class DuckDBAdapter(SQLAdapter):
ConnectionManager = DuckDBConnectionManager
Relation = DuckDBRelation

CONSTRAINT_SUPPORT = {
ConstraintType.check: ConstraintSupport.ENFORCED,
ConstraintType.not_null: ConstraintSupport.ENFORCED,
ConstraintType.unique: ConstraintSupport.ENFORCED,
ConstraintType.primary_key: ConstraintSupport.ENFORCED,
ConstraintType.foreign_key: ConstraintSupport.ENFORCED,
}

@classmethod
def date_function(cls) -> str:
return "now()"
Expand Down Expand Up @@ -176,6 +187,30 @@ def get_rows_different_sql(
)
return sql

@available.parse(lambda *a, **k: [])
def get_column_schema_from_query(self, sql: str) -> List[Column]:
"""Get a list of the Columns with names and data types from the given sql."""

# Taking advantage of yet another amazing DuckDB SQL feature right here: the
# ability to DESCRIBE a query instead of a relation
describe_sql = f"DESCRIBE ({sql})"
Copy link
Collaborator

@jwills jwills Apr 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jtcohen6 after a truly circuitous journey, this ended up being simple and delightful to implement (and in an environment-independent way to boot!)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoa! this is very handy

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is pretty sweet

_, cursor = self.connections.add_select_query(describe_sql)
ret = []
for row in cursor.fetchall():
name, dtype = row[0], row[1]
ret.append(Column.create(name, dtype))
return ret

@classmethod
def render_column_constraint(cls, constraint: ColumnLevelConstraint) -> Optional[str]:
"""Render the given constraint as DDL text. Should be overriden by adapters which need custom constraint
rendering."""
if constraint.type == ConstraintType.foreign_key:
# DuckDB doesn't support 'foreign key' as an alias
return f"references {constraint.expression}"
Comment on lines +209 to +210
Copy link
Contributor Author

@jtcohen6 jtcohen6 Apr 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually references seems like the more standard syntax. And we're not really doing a good job of supporting FK constraints, anyway:

The majority of columnar db's don't enforce them, so it hasn't felt like a priority. It's neat that DuckDB actually does.

else:
return super().render_column_constraint(constraint)


# Change `table_a/b` to `table_aaaaa/bbbbb` to avoid duckdb binding issues when relation_a/b
# is called "table_a" or "table_b" in some of the dbt tests
Expand Down
28 changes: 28 additions & 0 deletions dbt/include/duckdb/macros/adapters.sql
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,41 @@
{{ return(run_query(sql)) }}
{% endmacro %}

{% macro get_column_names() %}
{# loop through user_provided_columns to get column names #}
{%- set user_provided_columns = model['columns'] -%}
(
{% for i in user_provided_columns %}
{% set col = user_provided_columns[i] %}
{{ col['name'] }} {{ "," if not loop.last }}
{% endfor %}
)
{% endmacro %}


{% macro duckdb__create_table_as(temporary, relation, compiled_code, language='sql') -%}
{%- if language == 'sql' -%}
{% set contract_config = config.get('contract') %}
{% if contract_config.enforced %}
{{ get_assert_columns_equivalent(compiled_code) }}
{% endif %}
{%- set sql_header = config.get('sql_header', none) -%}

{{ sql_header if sql_header is not none }}

create {% if temporary: -%}temporary{%- endif %} table
{{ relation.include(database=(not temporary and adapter.use_database()), schema=(not temporary)) }}
{% if contract_config.enforced and not temporary %}
{#-- DuckDB doesnt support constraints on temp tables --#}
{{ get_table_columns_and_constraints() }} ;
insert into {{ relation }} {{ get_column_names() }} (
{{ get_select_subquery(compiled_code) }}
);
{% else %}
as (
{{ compiled_code }}
);
{% endif %}
{%- elif language == 'python' -%}
{{ py_write_table(temporary=temporary, relation=relation, compiled_code=compiled_code) }}
{%- else -%}
Expand All @@ -62,6 +86,10 @@ def materialize(df, con):
{% endmacro %}

{% macro duckdb__create_view_as(relation, sql) -%}
{% set contract_config = config.get('contract') %}
{% if contract_config.enforced %}
{{ get_assert_columns_equivalent(sql) }}
{%- endif %}
{%- set sql_header = config.get('sql_header', none) -%}

{{ sql_header if sql_header is not none }}
Expand Down
Loading