diff --git a/third_party/ibis/ibis_postgres/client.py b/third_party/ibis/ibis_postgres/client.py index 09fa67f92..d44e4ca60 100644 --- a/third_party/ibis/ibis_postgres/client.py +++ b/third_party/ibis/ibis_postgres/client.py @@ -12,15 +12,61 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Literal + import ibis.expr.datatypes as dt import ibis.expr.schema as sch -from ibis import util - import sqlalchemy as sa +from ibis import util from ibis.backends.postgres import Backend as PostgresBackend from ibis.backends.postgres.datatypes import _BRACKETS, _parse_numeric, _type_mapping +def do_connect( + self, + host: str = "localhost", + user: str = None, + password: str = None, + port: int = 5432, + database: str = None, + schema: str = None, + url: str = None, + driver: Literal["psycopg2"] = "psycopg2", +) -> None: + # Override do_connect() method to remove DDL queries to CREATE/DROP FUNCTION + if driver != "psycopg2": + raise NotImplementedError("psycopg2 is currently the only supported driver") + + alchemy_url = self._build_alchemy_url( + url=url, + host=host, + port=port, + user=user, + password=password, + database=database, + driver=f"postgresql+{driver}", + ) + self.database_name = alchemy_url.database + connect_args = {} + if schema is not None: + connect_args["options"] = f"-csearch_path={schema}" + + engine = sa.create_engine( + alchemy_url, connect_args=connect_args, poolclass=sa.pool.StaticPool + ) + + @sa.event.listens_for(engine, "connect") + def connect(dbapi_connection, connection_record): + with dbapi_connection.cursor() as cur: + cur.execute("SET TIMEZONE = UTC") + + # Equivalent of super().do_connect() below + self.con = engine + self._inspector = sa.inspect(self.con) + self._schemas: dict[str, sch.Schema] = {} + self._temp_views: set[str] = set() + + def _metadata(self, query: str) -> sch.Schema: raw_name = util.guid() name = self._quote(raw_name) @@ -72,3 +118,4 @@ def list_schemas(self, like=None): PostgresBackend._metadata = _metadata PostgresBackend.list_databases = list_schemas +PostgresBackend.do_connect = do_connect