Skip to content

Commit

Permalink
fix: Add support for numeric and precision with length and precision …
Browse files Browse the repository at this point in the history
…in Postgres Custom Query (#723)

* fix: change numeric to decimal and add supporting defs

* fix: change dt.decimal to dt.Decimal

* fix: change dt.Decimal to dt.float64

* fix: add self

* chore: add parsy to pypi installations

* fix: replace := opertor with python 3.7 compatible assignment

* chore: fix linting in setup.py

* chore: remove commented imports

* fix: handle fixed length specification in the _type_mapping keys
  • Loading branch information
sharangagarwal committed Feb 17, 2023
1 parent 26bb8e9 commit 742b77e
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 57 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"jellyfish==0.8.2",
"tabulate==0.8.9",
"Flask==2.0.2",
"parsy==2.0",
]

extras_require = {
Expand Down
163 changes: 106 additions & 57 deletions third_party/ibis/ibis_postgres/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# from ibis.backends.postgres.client import PostgreSQLClient
from __future__ import annotations
import parsy
import re
import ast
import toolz
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis import util
from ibis.backends.postgres.client import PostgreSQLClient


def _get_schema_using_query(self, query: str) -> sch.Schema:
raw_name = util.guid()
name = self.con.dialect.identifier_preparer.quote_identifier(raw_name)
Expand All @@ -40,72 +45,116 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
tuples = [(col, self._get_type(typestr)) for col, typestr in type_info]
return sch.Schema.from_tuples(tuples)

def _get_type(self,typestr: str) -> dt.DataType:
_type_mapping = {
# "boolean": dt.bool,

_BRACKETS = "[]"
_STRING_REGEX = (
"""('[^\n'\\\\]*(?:\\\\.[^\n'\\\\]*)*'|"[^\n"\\\\"]*(?:\\\\.[^\n"\\\\]*)*")"""
)


def spaceless(parser):
return SPACES.then(parser).skip(SPACES)


def spaceless_string(*strings: str):
return spaceless(
parsy.alt(*(parsy.string(s, transform=str.lower) for s in strings))
)


SPACES = parsy.regex(r"\s*", re.MULTILINE)
RAW_NUMBER = parsy.decimal_digit.at_least(1).concat()
SINGLE_DIGIT = parsy.decimal_digit
PRECISION = SCALE = NUMBER = RAW_NUMBER.map(int)

LPAREN = spaceless_string("(")
RPAREN = spaceless_string(")")

LBRACKET = spaceless_string("[")
RBRACKET = spaceless_string("]")

LANGLE = spaceless_string("<")
RANGLE = spaceless_string(">")

COMMA = spaceless_string(",")
COLON = spaceless_string(":")
SEMICOLON = spaceless_string(";")

RAW_STRING = parsy.regex(_STRING_REGEX).map(ast.literal_eval)
FIELD = parsy.regex("[a-zA-Z_][a-zA-Z_0-9]*")


def _parse_numeric(
text: str, ddp: tuple[int | None, int | None] = (None, None)
) -> dt.DataType:
decimal = spaceless_string("decimal", "numeric").then(
parsy.seq(LPAREN.then(PRECISION.skip(COMMA)), SCALE.skip(RPAREN))
.optional(ddp)
.combine(dt.Decimal)
)

brackets = spaceless(LBRACKET).then(spaceless(RBRACKET))

pg_array = parsy.seq(decimal, brackets.at_least(1).map(len)).combine(
lambda value_type, n: toolz.nth(n, toolz.iterate(dt.Array, value_type))
)

ty = pg_array | decimal
return ty.parse(text)


def _get_type(self, typestr: str) -> dt.DataType:
is_array = typestr.endswith(_BRACKETS)
# typ = _type_mapping.get(typestr.replace(_BRACKETS, ""))
# handle bracket length
typestr_wob = typestr.replace(_BRACKETS, "")
if "(" in typestr_wob:
typestr_wo_length = (
typestr_wob[: typestr_wob.index("(")]
+ typestr_wob[typestr_wob.index(")") + 1 :]
)
else:
typestr_wo_length = typestr_wob
typ = _type_mapping.get(typestr_wo_length)
if typ is not None:
return dt.Array(typ) if is_array else typ
return _parse_numeric(typestr)


_type_mapping = {
"bigint": dt.int64,
"boolean": dt.boolean,
"boolean[]": dt.Array(dt.boolean),
"bytea": dt.binary,
"bytea[]": dt.Array(dt.binary),
"character varying": dt.string,
"character": dt.string,
"character(1)": dt.string,
"character(1)[]": dt.Array(dt.string),
"bigint": dt.int64,
"bigint[]": dt.Array(dt.int64),
"smallint": dt.int16,
"smallint[]": dt.Array(dt.int16),
"date": dt.date,
"double precision": dt.float64,
"geography": dt.geography,
"geometry": dt.geometry,
"inet": dt.inet,
"integer": dt.int32,
"integer[]": dt.Array(dt.int32),
"text": dt.string,
"text[]": dt.Array(dt.string),
"interval": dt.interval,
"json": dt.json,
"json[]": dt.Array(dt.json),
"jsonb": dt.json,
"line": dt.linestring,
"macaddr": dt.macaddr,
"macaddr8": dt.macaddr,
"numeric": dt.float64,
"point": dt.point,
"point[]": dt.Array(dt.point),
"polygon": dt.polygon,
"polygon[]": dt.Array(dt.polygon),
"line": dt.linestring,
"line[]": dt.Array(dt.linestring),
"real": dt.float32,
"real[]": dt.Array(dt.float32),
"double precision": dt.float64,
"double precision[]": dt.Array(dt.float64),
"macaddr8": dt.macaddr,
"macaddr8[]": dt.Array(dt.macaddr),
"macaddr": dt.macaddr,
"macaddr[]": dt.Array(dt.macaddr),
"inet": dt.inet,
"inet[]": dt.Array(dt.inet),
"character": dt.string,
"character[]": dt.Array(dt.string),
"character varying": dt.string,
"character varying[]": dt.Array(dt.string),
"date": dt.date,
"date[]": dt.Array(dt.date),
"smallint": dt.int16,
"text": dt.string,
# NB: this isn't correct because we're losing the "with time zone"
# information (ibis doesn't have time type that is time-zone aware), but we
# try to do _something_ here instead of failing
"time with time zone": dt.time,
"time without time zone": dt.time,
"time without time zone[]": dt.Array(dt.time),
"timestamp without time zone": dt.timestamp,
"timestamp without time zone[]": dt.Array(dt.timestamp),
"timestamp with time zone": dt.Timestamp("UTC"),
"timestamp with time zone[]": dt.Array(dt.Timestamp("UTC")),
"interval": dt.interval,
"interval[]": dt.Array(dt.interval),
# NB: this isn"t correct, but we try not to fail
"time with time zone": "time",
"numeric": dt.float64,
"numeric[]": dt.Array(dt.float64),
"timestamp without time zone": dt.timestamp,
"uuid": dt.uuid,
"uuid[]": dt.Array(dt.uuid),
"jsonb": dt.jsonb,
"jsonb[]": dt.Array(dt.jsonb),
"geometry": dt.geometry,
"geometry[]": dt.Array(dt.geometry),
"geography": dt.geography,
"geography[]": dt.Array(dt.geography),
}
try:
return _type_mapping[typestr]
except KeyError:
return
}

PostgreSQLClient._get_schema_using_query = _get_schema_using_query
PostgreSQLClient._get_type = _get_type
PostgreSQLClient._get_type = _get_type

0 comments on commit 742b77e

Please sign in to comment.