Skip to content

Commit

Permalink
fix: Remove DDL automatically issued by Ibis for Postgres connections (
Browse files Browse the repository at this point in the history
…#1067)

* fix: remove DDL automatically issued by Ibis for PG

* fix: lint

* fix: remove unsupported | operand for python 3.8
  • Loading branch information
nehanene15 committed Dec 6, 2023
1 parent f3cc565 commit c2b660b
Showing 1 changed file with 49 additions and 2 deletions.
51 changes: 49 additions & 2 deletions third_party/ibis/ibis_postgres/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,61 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal

import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis import util

import sqlalchemy as sa
from ibis import util
from ibis.backends.postgres import Backend as PostgresBackend
from ibis.backends.postgres.datatypes import _BRACKETS, _parse_numeric, _type_mapping


def do_connect(
self,
host: str = "localhost",
user: str = None,
password: str = None,
port: int = 5432,
database: str = None,
schema: str = None,
url: str = None,
driver: Literal["psycopg2"] = "psycopg2",
) -> None:
# Override do_connect() method to remove DDL queries to CREATE/DROP FUNCTION
if driver != "psycopg2":
raise NotImplementedError("psycopg2 is currently the only supported driver")

alchemy_url = self._build_alchemy_url(
url=url,
host=host,
port=port,
user=user,
password=password,
database=database,
driver=f"postgresql+{driver}",
)
self.database_name = alchemy_url.database
connect_args = {}
if schema is not None:
connect_args["options"] = f"-csearch_path={schema}"

engine = sa.create_engine(
alchemy_url, connect_args=connect_args, poolclass=sa.pool.StaticPool
)

@sa.event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record):
with dbapi_connection.cursor() as cur:
cur.execute("SET TIMEZONE = UTC")

# Equivalent of super().do_connect() below
self.con = engine
self._inspector = sa.inspect(self.con)
self._schemas: dict[str, sch.Schema] = {}
self._temp_views: set[str] = set()


def _metadata(self, query: str) -> sch.Schema:
raw_name = util.guid()
name = self._quote(raw_name)
Expand Down Expand Up @@ -72,3 +118,4 @@ def list_schemas(self, like=None):

PostgresBackend._metadata = _metadata
PostgresBackend.list_databases = list_schemas
PostgresBackend.do_connect = do_connect

0 comments on commit c2b660b

Please sign in to comment.