diff --git a/setup.py b/setup.py index b2130848d..6b7d4cd63 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ "jellyfish==0.8.2", "tabulate==0.8.9", "Flask==2.0.2", + "parsy==2.0", ] extras_require = { diff --git a/third_party/ibis/ibis_postgres/client.py b/third_party/ibis/ibis_postgres/client.py index 3a0491ba7..dd562e905 100644 --- a/third_party/ibis/ibis_postgres/client.py +++ b/third_party/ibis/ibis_postgres/client.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -# from ibis.backends.postgres.client import PostgreSQLClient +from __future__ import annotations +import parsy +import re +import ast +import toolz import ibis.expr.datatypes as dt import ibis.expr.schema as sch from ibis import util from ibis.backends.postgres.client import PostgreSQLClient + def _get_schema_using_query(self, query: str) -> sch.Schema: raw_name = util.guid() name = self.con.dialect.identifier_preparer.quote_identifier(raw_name) @@ -40,72 +45,116 @@ def _get_schema_using_query(self, query: str) -> sch.Schema: tuples = [(col, self._get_type(typestr)) for col, typestr in type_info] return sch.Schema.from_tuples(tuples) -def _get_type(self,typestr: str) -> dt.DataType: - _type_mapping = { - # "boolean": dt.bool, + +_BRACKETS = "[]" +_STRING_REGEX = ( + """('[^\n'\\\\]*(?:\\\\.[^\n'\\\\]*)*'|"[^\n"\\\\"]*(?:\\\\.[^\n"\\\\]*)*")""" +) + + +def spaceless(parser): + return SPACES.then(parser).skip(SPACES) + + +def spaceless_string(*strings: str): + return spaceless( + parsy.alt(*(parsy.string(s, transform=str.lower) for s in strings)) + ) + + +SPACES = parsy.regex(r"\s*", re.MULTILINE) +RAW_NUMBER = parsy.decimal_digit.at_least(1).concat() +SINGLE_DIGIT = parsy.decimal_digit +PRECISION = SCALE = NUMBER = RAW_NUMBER.map(int) + +LPAREN = spaceless_string("(") +RPAREN = spaceless_string(")") + +LBRACKET = spaceless_string("[") +RBRACKET = spaceless_string("]") + +LANGLE = spaceless_string("<") +RANGLE = spaceless_string(">") + +COMMA = spaceless_string(",") +COLON = spaceless_string(":") +SEMICOLON = spaceless_string(";") + +RAW_STRING = parsy.regex(_STRING_REGEX).map(ast.literal_eval) +FIELD = parsy.regex("[a-zA-Z_][a-zA-Z_0-9]*") + + +def _parse_numeric( + text: str, ddp: tuple[int | None, int | None] = (None, None) +) -> dt.DataType: + decimal = spaceless_string("decimal", "numeric").then( + parsy.seq(LPAREN.then(PRECISION.skip(COMMA)), SCALE.skip(RPAREN)) + .optional(ddp) + .combine(dt.Decimal) + ) + + brackets = spaceless(LBRACKET).then(spaceless(RBRACKET)) + + pg_array = parsy.seq(decimal, brackets.at_least(1).map(len)).combine( + lambda value_type, n: toolz.nth(n, toolz.iterate(dt.Array, value_type)) + ) + + ty = pg_array | decimal + return ty.parse(text) + + +def _get_type(self, typestr: str) -> dt.DataType: + is_array = typestr.endswith(_BRACKETS) + # typ = _type_mapping.get(typestr.replace(_BRACKETS, "")) + # handle bracket length + typestr_wob = typestr.replace(_BRACKETS, "") + if "(" in typestr_wob: + typestr_wo_length = ( + typestr_wob[: typestr_wob.index("(")] + + typestr_wob[typestr_wob.index(")") + 1 :] + ) + else: + typestr_wo_length = typestr_wob + typ = _type_mapping.get(typestr_wo_length) + if typ is not None: + return dt.Array(typ) if is_array else typ + return _parse_numeric(typestr) + + +_type_mapping = { + "bigint": dt.int64, "boolean": dt.boolean, - "boolean[]": dt.Array(dt.boolean), "bytea": dt.binary, - "bytea[]": dt.Array(dt.binary), + "character varying": dt.string, + "character": dt.string, "character(1)": dt.string, - "character(1)[]": dt.Array(dt.string), - "bigint": dt.int64, - "bigint[]": dt.Array(dt.int64), - "smallint": dt.int16, - "smallint[]": dt.Array(dt.int16), + "date": dt.date, + "double precision": dt.float64, + "geography": dt.geography, + "geometry": dt.geometry, + "inet": dt.inet, "integer": dt.int32, - "integer[]": dt.Array(dt.int32), - "text": dt.string, - "text[]": dt.Array(dt.string), + "interval": dt.interval, "json": dt.json, - "json[]": dt.Array(dt.json), + "jsonb": dt.json, + "line": dt.linestring, + "macaddr": dt.macaddr, + "macaddr8": dt.macaddr, + "numeric": dt.float64, "point": dt.point, - "point[]": dt.Array(dt.point), "polygon": dt.polygon, - "polygon[]": dt.Array(dt.polygon), - "line": dt.linestring, - "line[]": dt.Array(dt.linestring), "real": dt.float32, - "real[]": dt.Array(dt.float32), - "double precision": dt.float64, - "double precision[]": dt.Array(dt.float64), - "macaddr8": dt.macaddr, - "macaddr8[]": dt.Array(dt.macaddr), - "macaddr": dt.macaddr, - "macaddr[]": dt.Array(dt.macaddr), - "inet": dt.inet, - "inet[]": dt.Array(dt.inet), - "character": dt.string, - "character[]": dt.Array(dt.string), - "character varying": dt.string, - "character varying[]": dt.Array(dt.string), - "date": dt.date, - "date[]": dt.Array(dt.date), + "smallint": dt.int16, + "text": dt.string, + # NB: this isn't correct because we're losing the "with time zone" + # information (ibis doesn't have time type that is time-zone aware), but we + # try to do _something_ here instead of failing + "time with time zone": dt.time, "time without time zone": dt.time, - "time without time zone[]": dt.Array(dt.time), - "timestamp without time zone": dt.timestamp, - "timestamp without time zone[]": dt.Array(dt.timestamp), "timestamp with time zone": dt.Timestamp("UTC"), - "timestamp with time zone[]": dt.Array(dt.Timestamp("UTC")), - "interval": dt.interval, - "interval[]": dt.Array(dt.interval), - # NB: this isn"t correct, but we try not to fail - "time with time zone": "time", - "numeric": dt.float64, - "numeric[]": dt.Array(dt.float64), + "timestamp without time zone": dt.timestamp, "uuid": dt.uuid, - "uuid[]": dt.Array(dt.uuid), - "jsonb": dt.jsonb, - "jsonb[]": dt.Array(dt.jsonb), - "geometry": dt.geometry, - "geometry[]": dt.Array(dt.geometry), - "geography": dt.geography, - "geography[]": dt.Array(dt.geography), - } - try: - return _type_mapping[typestr] - except KeyError: - return +} PostgreSQLClient._get_schema_using_query = _get_schema_using_query -PostgreSQLClient._get_type = _get_type \ No newline at end of file +PostgreSQLClient._get_type = _get_type