Skip to content

Commit

Permalink
Merge pull request #1104 from centerofci/mathesar-992-attnum-column-s…
Browse files Browse the repository at this point in the history
…election-

Replace column_index usage in column selection operations with column_attnum
  • Loading branch information
silentninja committed Mar 3, 2022
2 parents 3079ecd + d089764 commit eb4b5e4
Show file tree
Hide file tree
Showing 13 changed files with 112 additions and 58 deletions.
22 changes: 19 additions & 3 deletions db/columns/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from db.columns.defaults import TYPE, PRIMARY_KEY, NULLABLE, DEFAULT_COLUMNS
from db.columns.operations.select import (
get_column_default, get_column_default_dict, get_column_index_from_name
get_column_attnum_from_name, get_column_default, get_column_default_dict, get_column_index_from_name,
)
from db.tables.operations.select import get_oid_from_table
from db.types.operations.cast import get_full_cast_map
Expand Down Expand Up @@ -138,12 +138,28 @@ def column_index(self):
self.engine
)

@property
def column_attnum(self):
"""
Get the attnum of this column in its table, if it is
attached to a table that is associated with the column's engine.
"""
engine_exists = self.engine is not None
table_exists = self.table_ is not None
engine_has_table = inspect(self.engine).has_table(self.table_.name, schema=self.table_.schema)
if engine_exists and table_exists and engine_has_table:
return get_column_attnum_from_name(
self.table_oid,
self.name,
self.engine
)

@property
def column_default_dict(self):
if self.table_ is None:
return
default_dict = get_column_default_dict(
self.table_oid, self.column_index, self.engine
self.table_oid, self.column_attnum, self.engine
)
if default_dict:
return {
Expand All @@ -154,7 +170,7 @@ def column_default_dict(self):
@property
def default_value(self):
if self.table_ is not None:
return get_column_default(self.table_oid, self.column_index, self.engine)
return get_column_default(self.table_oid, self.column_attnum, self.engine)

@property
def plain_type(self):
Expand Down
5 changes: 3 additions & 2 deletions db/columns/operations/alter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from db.columns.defaults import NAME, NULLABLE
from db.columns.exceptions import InvalidDefaultError, InvalidTypeError, InvalidTypeOptionError
from db.columns.operations.select import get_column_default, get_column_index_from_name
from db.columns.operations.select import get_column_attnum_from_name, get_column_default, get_column_index_from_name
from db.columns.utils import get_mathesar_column_with_engine, get_type_options
from db.tables.operations.select import get_oid_from_table, reflect_table_from_oid
from db.types.base import get_db_type_name
Expand Down Expand Up @@ -73,8 +73,9 @@ def alter_column_type(
table = reflect_table_from_oid(table_oid, engine, connection)
column = table.columns[column_name]
column_index = get_column_index_from_name(table_oid, column_name, engine, connection)
column_attnum = get_column_attnum_from_name(table_oid, column_name, engine, connection)

default = get_column_default(table_oid, column_index, engine, connection)
default = get_column_default(table_oid, column_attnum, engine, connection)
if default is not None:
default_text = column.server_default.arg.text
set_column_default(table, column_index, engine, connection, None)
Expand Down
6 changes: 3 additions & 3 deletions db/columns/operations/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from db.columns.defaults import DEFAULT, NAME, NULLABLE, TYPE
from db.columns.exceptions import InvalidDefaultError, InvalidTypeError, InvalidTypeOptionError
from db.columns.operations.alter import set_column_default, change_column_nullable
from db.columns.operations.select import get_column_default, get_column_index_from_name
from db.columns.operations.select import get_column_attnum_from_name, get_column_default, get_column_index_from_name
from db.columns.utils import get_mathesar_column_with_engine
from db.constraints.operations.create import copy_constraint
from db.constraints.operations.select import get_column_constraints
Expand Down Expand Up @@ -94,6 +94,7 @@ def compile_copy_column(element, compiler, **_):

def _duplicate_column_data(table_oid, from_column, to_column, engine):
table = reflect_table_from_oid(table_oid, engine)
from_column_attnum = get_column_attnum_from_name(table_oid, table.c[from_column].name, engine)
copy = CopyColumn(
table.schema,
table.name,
Expand All @@ -102,8 +103,7 @@ def _duplicate_column_data(table_oid, from_column, to_column, engine):
)
with engine.begin() as conn:
conn.execute(copy)

from_default = get_column_default(table_oid, from_column, engine)
from_default = get_column_default(table_oid, from_column_attnum, engine)
if from_default is not None:
with engine.begin() as conn:
set_column_default(table, to_column, engine, conn, from_default)
Expand Down
37 changes: 27 additions & 10 deletions db/columns/operations/select.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings

from pglast import Node, parse_sql
from sqlalchemy import Table, MetaData, and_, select, text, func, cast
from sqlalchemy import Table, MetaData, and_, asc, select, text, func, cast

from db.columns.exceptions import DynamicDefaultWarning
from db.tables.operations.select import reflect_table_from_oid
Expand All @@ -17,7 +17,7 @@
DYNAMIC_NODE_TAGS = {"SQLValueFunction", "FuncCall"}


def get_columns_attnum_from_names(table_oid, column_names, engine, connection_to_use=None):
def _get_columns_attnum_from_names(table_oid, column_names, engine, connection_to_use=None):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="Did not recognize type")
pg_attribute = Table("pg_attribute", MetaData(), autoload_with=engine)
Expand All @@ -26,15 +26,31 @@ def get_columns_attnum_from_names(table_oid, column_names, engine, connection_to
pg_attribute.c.attrelid == table_oid,
pg_attribute.c.attname.in_(column_names)
)
)
return execute_statement(engine, sel, connection_to_use).fetchall()
).order_by(asc(pg_attribute.c.attnum))
return sel


def get_columns_attnum_from_names(table_oid, column_names, engine, connection_to_use=None):
"""
Returns the respective list of attnum of the column names passed.
The order is based on the column order in the table and not by the order of the column names argument.
"""
statement = _get_columns_attnum_from_names(table_oid, column_names, engine, connection_to_use=None)
attnums_tuple = execute_statement(engine, statement, connection_to_use).fetchall()
attnums = [attnum_tuple[0] for attnum_tuple in attnums_tuple]
return attnums


def get_column_attnum_from_name(table_oid, column_name, engine, connection_to_use=None):
statement = _get_columns_attnum_from_names(table_oid, [column_name], engine, connection_to_use=None)
return execute_statement(engine, statement, connection_to_use).scalar()


def get_column_index_from_name(table_oid, column_name, engine, connection_to_use=None):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="Did not recognize type")
pg_attribute = Table("pg_attribute", MetaData(), autoload_with=engine)
result = get_columns_attnum_from_names(table_oid, [column_name], engine, connection_to_use)[0][0]
result = get_column_attnum_from_name(table_oid, column_name, engine, connection_to_use)

# Account for dropped columns that don't appear in the SQLAlchemy tables
sel = (
Expand Down Expand Up @@ -77,13 +93,14 @@ def get_column_name_from_attnum(table_oid, attnum, engine, connection_to_use=Non
pg_attribute.c.attnum == attnum
)
)
result = execute_statement(engine, sel, connection_to_use).fetchone()[0]
result = execute_statement(engine, sel, connection_to_use).scalar()
return result


def get_column_default_dict(table_oid, column_index, engine, connection_to_use=None):
def get_column_default_dict(table_oid, attnum, engine, connection_to_use=None):
table = reflect_table_from_oid(table_oid, engine, connection_to_use)
column = table.columns[column_index]
column_name = get_column_name_from_attnum(table_oid, attnum, engine, connection_to_use)
column = table.columns[column_name]
if column.server_default is None:
return

Expand All @@ -108,9 +125,9 @@ def get_column_default_dict(table_oid, column_index, engine, connection_to_use=N
return {"value": default_value, "is_dynamic": is_dynamic}


def get_column_default(table_oid, column_index, engine, connection_to_use=None):
def get_column_default(table_oid, attnum, engine, connection_to_use=None):
default_dict = get_column_default_dict(
table_oid, column_index, engine, connection_to_use=connection_to_use
table_oid, attnum, engine, connection_to_use=connection_to_use
)
if default_dict is not None:
return default_dict['value']
Expand Down
21 changes: 11 additions & 10 deletions db/tests/columns/operations/test_alter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from db import constants
from db.columns.operations import alter as alter_operations
from db.columns.operations.alter import alter_column, batch_update_columns, change_column_nullable, rename_column, retype_column, set_column_default
from db.columns.operations.select import get_column_default, get_column_index_from_name
from db.columns.operations.select import get_column_attnum_from_name, get_column_default, get_column_index_from_name
from db.columns.utils import get_mathesar_column_with_engine
from db.tables.operations.create import create_mathesar_table
from db.tables.operations.select import get_oid_from_table, reflect_table
Expand Down Expand Up @@ -357,11 +357,12 @@ def test_column_default_create(engine_with_schema, col_type):
Column(column_name, col_type)
)
table.create()

table_oid = get_oid_from_table(table_name, schema, engine)
column_attnum = get_column_attnum_from_name(table_oid, column_name, engine)
with engine.begin() as conn:
set_column_default(table, 0, engine, conn, set_default)
table_oid = get_oid_from_table(table_name, schema, engine)
default = get_column_default(table_oid, 0, engine)

default = get_column_default(table_oid, column_attnum, engine)
created_default = get_default(engine, table)

assert default == expt_default
Expand All @@ -380,11 +381,11 @@ def test_column_default_update(engine_with_schema, col_type):
Column(column_name, col_type, server_default=start_default)
)
table.create()

table_oid = get_oid_from_table(table_name, schema, engine)
column_attnum = get_column_attnum_from_name(table_oid, column_name, engine)
with engine.begin() as conn:
set_column_default(table, 0, engine, conn, set_default)
table_oid = get_oid_from_table(table_name, schema, engine)
default = get_column_default(table_oid, 0, engine)
default = get_column_default(table_oid, column_attnum, engine)
created_default = get_default(engine, table)

assert default != start_default
Expand All @@ -404,11 +405,11 @@ def test_column_default_delete(engine_with_schema, col_type):
Column(column_name, col_type, server_default=set_default)
)
table.create()

table_oid = get_oid_from_table(table_name, schema, engine)
column_attnum = get_column_attnum_from_name(table_oid, column_name, engine)
with engine.begin() as conn:
set_column_default(table, 0, engine, conn, None)
table_oid = get_oid_from_table(table_name, schema, engine)
default = get_column_default(table_oid, 0, engine)
default = get_column_default(table_oid, column_attnum, engine)
created_default = get_default(engine, table)

assert default is None
Expand Down
6 changes: 3 additions & 3 deletions db/tests/columns/operations/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sqlalchemy import Integer, Column, Table, MetaData, Numeric, UniqueConstraint

from db.columns.operations.create import create_column, duplicate_column
from db.columns.operations.select import get_column_default, get_column_index_from_name
from db.columns.operations.select import get_column_attnum_from_name, get_column_default, get_column_index_from_name
from db.tables.operations.select import get_oid_from_table, reflect_table_from_oid
from db.constraints.operations.select import get_column_constraints
from db.tests.columns.utils import create_test_table
Expand Down Expand Up @@ -341,8 +341,8 @@ def test_duplicate_column_default(engine_with_schema, copy_data, copy_constraint
table_oid, 0, engine, new_col_name, copy_data, copy_constraints
)

col_index = get_column_index_from_name(table_oid, new_col_name, engine)
default = get_column_default(table_oid, col_index, engine)
column_attnum = get_column_attnum_from_name(table_oid, new_col_name, engine)
default = get_column_default(table_oid, column_attnum, engine)
if copy_data:
assert default == expt_default
else:
Expand Down
43 changes: 33 additions & 10 deletions db/tests/columns/operations/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
from alembic.operations import Operations
import pytest
from sqlalchemy import (
String, Integer, Column, Table, MetaData, DateTime, func, text, DefaultClause
String, Integer, Column, Table, MetaData, DateTime, func, text, DefaultClause,
)

from db.columns.exceptions import DynamicDefaultWarning
from db.columns.operations.select import (
get_column_default, get_column_index_from_name, _is_default_expr_dynamic
get_column_attnum_from_name, get_column_default, get_column_index_from_name, _is_default_expr_dynamic,
get_column_name_from_attnum, get_columns_attnum_from_names,
)
from db.tables.operations.select import get_oid_from_table
from db.tests.columns.utils import column_test_dict, get_default


def test_get_column_index_from_name(engine_with_schema):
def test_get_attnum_from_name(engine_with_schema):
engine, schema = engine_with_schema
table_name = "table_with_columns"
zero_name = "colzero"
Expand All @@ -27,8 +28,28 @@ def test_get_column_index_from_name(engine_with_schema):
)
table.create()
table_oid = get_oid_from_table(table_name, schema, engine)
assert get_column_index_from_name(table_oid, zero_name, engine) == 0
assert get_column_index_from_name(table_oid, one_name, engine) == 1
column_zero_attnum = get_column_attnum_from_name(table_oid, zero_name, engine)
column_one_attnum = get_column_attnum_from_name(table_oid, one_name, engine)
assert get_column_name_from_attnum(table_oid, column_zero_attnum, engine) == zero_name
assert get_column_name_from_attnum(table_oid, column_one_attnum, engine) == one_name


def test_get_attnum_from_names(engine_with_schema):
engine, schema = engine_with_schema
table_name = "table_with_columns"
zero_name = "colzero"
one_name = "colone"
table = Table(
table_name,
MetaData(bind=engine, schema=schema),
Column(zero_name, Integer),
Column(one_name, String),
)
table.create()
table_oid = get_oid_from_table(table_name, schema, engine)
columns_attnum = get_columns_attnum_from_names(table_oid, [zero_name, one_name], engine)
assert get_column_name_from_attnum(table_oid, columns_attnum[0], engine) == zero_name
assert get_column_name_from_attnum(table_oid, columns_attnum[1], engine) == one_name


def test_get_column_index_from_name_after_delete(engine_with_schema):
Expand All @@ -50,8 +71,9 @@ def test_get_column_index_from_name_after_delete(engine_with_schema):
op.drop_column(table.name, one_name, schema=schema)

table_oid = get_oid_from_table(table_name, schema, engine)
assert get_column_index_from_name(table_oid, zero_name, engine) == 0
assert get_column_index_from_name(table_oid, two_name, engine) == 1
columns_attnum = get_columns_attnum_from_names(table_oid, [zero_name, two_name], engine)
assert get_column_name_from_attnum(table_oid, columns_attnum[0], engine) == zero_name
assert get_column_name_from_attnum(table_oid, columns_attnum[1], engine) == two_name


def test_get_column_index_from_name_after_delete_two_tables(engine_with_schema):
Expand Down Expand Up @@ -108,8 +130,8 @@ def test_get_column_default(engine_with_schema, filler, col_type):
)
table.create()
table_oid = get_oid_from_table(table_name, schema, engine)

default = get_column_default(table_oid, 0, engine)
column_attnum = get_column_attnum_from_name(table_oid, column_name, engine)
default = get_column_default(table_oid, column_attnum, engine)
created_default = get_default(engine, table)
assert default == expt_default
assert default == created_default
Expand All @@ -133,9 +155,10 @@ def test_get_column_generated_default(engine_with_schema, col):
)
table.create()
table_oid = get_oid_from_table(table_name, schema, engine)
column_attnum = get_column_attnum_from_name(table_oid, col.name, engine)
with warnings.catch_warnings(), pytest.raises(DynamicDefaultWarning):
warnings.filterwarnings("error", category=DynamicDefaultWarning)
get_column_default(table_oid, 0, engine)
get_column_default(table_oid, column_attnum, engine)


default_expression_test_list = [
Expand Down
5 changes: 3 additions & 2 deletions db/tests/types/operations/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy.exc import DataError

from db import types
from db.columns.operations.select import get_column_default
from db.columns.operations.select import get_column_attnum_from_name, get_column_default
from db.columns.operations.alter import alter_column_type
from db.tables.operations.select import get_oid_from_table
from db.tests.types import fixtures
Expand Down Expand Up @@ -1017,7 +1017,8 @@ def test_alter_column_casts_data_gen(
actual_value = res[0][0]
assert actual_value == out_val
table_oid = get_oid_from_table(TABLE_NAME, schema, engine)
actual_default = get_column_default(table_oid, 0, engine)
column_attnum = get_column_attnum_from_name(table_oid, COLUMN_NAME, engine)
actual_default = get_column_default(table_oid, column_attnum, engine)
# TODO This needs to be sorted out by fixing how server_default is set.
if all([
source_type != get_qualified_name(MathesarCustomType.MATHESAR_MONEY.value),
Expand Down
5 changes: 2 additions & 3 deletions mathesar/api/db/viewsets/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from db.columns.exceptions import (
DynamicDefaultWarning, InvalidDefaultError, InvalidTypeOptionError, InvalidTypeError,
)
from db.columns.operations.select import get_columns_attnum_from_names
from db.columns.operations.select import get_column_attnum_from_name
from db.types.exceptions import InvalidTypeParameters
from mathesar.api.pagination import DefaultLimitOffsetPagination
from mathesar.api.serializers.columns import ColumnSerializer
Expand Down Expand Up @@ -93,8 +93,7 @@ def create(self, request, table_pk=None):
)
dj_column = Column(
table=table,
attnum=get_columns_attnum_from_names(table.oid, [column.name], table.schema._sa_engine)[0][
0],
attnum=get_column_attnum_from_name(table.oid, column.name, table.schema._sa_engine),
**serializer.validated_model_fields
)
dj_column.save()
Expand Down
Loading

0 comments on commit eb4b5e4

Please sign in to comment.