Skip to content

Commit

Permalink
Merge pull request #1132 from centerofci/mathesar-993-attnum-column-a…
Browse files Browse the repository at this point in the history
…lter

Replace column index usage in db column alter operations with column attnum.
  • Loading branch information
dmos62 committed Mar 8, 2022
2 parents 1a831c1 + 79a663d commit 67e615a
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 47 deletions.
45 changes: 24 additions & 21 deletions db/columns/operations/alter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,27 @@
from db.utils import execute_statement


def alter_column(engine, table_oid, column_index, column_data):
def alter_column(engine, table_oid, column_attnum, column_data):
TYPE_KEY = 'plain_type'
TYPE_OPTIONS_KEY = 'type_options'
NULLABLE_KEY = NULLABLE
DEFAULT_DICT = 'column_default_dict'
DEFAULT_KEY = 'value'
NAME_KEY = NAME

table = reflect_table_from_oid(table_oid, engine)
column_index = int(column_index)

with engine.begin() as conn:
if TYPE_KEY in column_data:
retype_column(
table, column_index, engine, conn,
table_oid, column_attnum, engine, conn,
new_type=column_data[TYPE_KEY],
type_options=column_data.get(TYPE_OPTIONS_KEY, {})
)
elif TYPE_OPTIONS_KEY in column_data:
retype_column(
table, column_index, engine, conn,
table_oid, column_attnum, engine, conn,
type_options=column_data[TYPE_OPTIONS_KEY]
)
column_name = table.columns[column_index].name
column_attnum = get_column_attnum_from_name(table_oid, column_name, engine)

if NULLABLE_KEY in column_data:
nullable = column_data[NULLABLE_KEY]
change_column_nullable(table_oid, column_attnum, engine, conn, nullable)
Expand All @@ -53,18 +49,19 @@ def alter_column(engine, table_oid, column_index, column_data):
# Name always needs to be the last item altered
# since previous operations need the name to work
name = column_data[NAME_KEY]
rename_column(table, column_index, engine, conn, name)

rename_column(table_oid, column_attnum, engine, conn, name)
column_name = get_column_name_from_attnum(table_oid, column_attnum, engine)
return get_mathesar_column_with_engine(
reflect_table_from_oid(table_oid, engine).columns[column_index],
reflect_table_from_oid(table_oid, engine).columns[column_name],
engine
)


def alter_column_type(
table, column_name, engine, connection, target_type_str,
table_oid, column_name, engine, connection, target_type_str,
type_options={}, friendly_names=True,
):
table = reflect_table_from_oid(table_oid, engine, connection)
_preparer = engine.dialect.identifier_preparer
supported_types = get_supported_alter_column_types(
engine, friendly_names=friendly_names
Expand Down Expand Up @@ -104,9 +101,11 @@ def alter_column_type(


def retype_column(
table, column_index, engine, connection, new_type=None, type_options={},
table_oid, column_attnum, engine, connection, new_type=None, type_options={},
):
column = table.columns[column_index]
table = reflect_table_from_oid(table_oid, engine, connection)
column_name = get_column_name_from_attnum(table_oid, column_attnum, engine)
column = table.columns[column_name]
column_db_type = get_db_type_name(column.type, engine)
new_type = new_type if new_type is not None else column_db_type
column_type_options = get_type_options(column)
Expand All @@ -119,8 +118,8 @@ def retype_column(

try:
alter_column_type(
table,
table.columns[column_index].name,
table_oid,
column_name,
engine,
connection,
new_type,
Expand Down Expand Up @@ -163,8 +162,10 @@ def set_column_default(table_oid, column_attnum, engine, connection, default):
raise e


def rename_column(table, column_index, engine, connection, new_name):
column = table.columns[column_index]
def rename_column(table_oid, column_attnum, engine, connection, new_name):
table = reflect_table_from_oid(table_oid, engine, connection)
column_name = get_column_name_from_attnum(table_oid, column_attnum, engine)
column = table.columns[column_name]
ctx = MigrationContext.configure(connection)
op = Operations(ctx)
op.alter_column(table.name, column.name, new_column_name=new_name, schema=table.schema)
Expand All @@ -190,14 +191,16 @@ def _validate_columns_for_batch_update(table, column_data):
raise ValueError(f'Key "{key}" found in columns. Keys allowed are: {allowed_key_list}')


def _batch_update_column_types(table, column_data_list, connection, engine):
def _batch_update_column_types(table_oid, column_data_list, connection, engine):
table = reflect_table_from_oid(table_oid, engine, connection)
for index, column_data in enumerate(column_data_list):
if 'plain_type' in column_data:
new_type = column_data['plain_type']
type_options = column_data.get('type_options', {})
if type_options is None:
type_options = {}
retype_column(table, index, engine, connection, new_type, type_options)
column_attnum = get_column_attnum_from_name(table_oid, table.columns[index].name, engine, connection)
retype_column(table_oid, column_attnum, engine, connection, new_type, type_options)


def _batch_alter_table_columns(table, column_data_list, connection):
Expand All @@ -219,5 +222,5 @@ def batch_update_columns(table_oid, engine, column_data_list):
table = reflect_table_from_oid(table_oid, engine)
_validate_columns_for_batch_update(table, column_data_list)
with engine.begin() as conn:
_batch_update_column_types(table, column_data_list, conn, engine)
_batch_update_column_types(table_oid, column_data_list, conn, engine)
_batch_alter_table_columns(table, column_data_list, conn)
5 changes: 3 additions & 2 deletions db/columns/operations/infer_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from db.columns.exceptions import DagCycleError
from db.columns.operations.alter import alter_column_type
from db.tables.operations.select import reflect_table
from db.tables.operations.select import get_oid_from_table, reflect_table
from db.types.operations.cast import get_supported_alter_column_types
from db.types import base

Expand Down Expand Up @@ -62,10 +62,11 @@ def infer_column_type(schema, table_name, column_name, engine, depth=0, type_inf
column_type_str = reverse_type_map.get(column_type)

logger.debug(f"column_type_str: {column_type_str}")
table_oid = get_oid_from_table(table_name, schema, engine)
for type_str in type_inference_dag.get(column_type_str, []):
try:
with engine.begin() as conn:
alter_column_type(table, column_name, engine, conn, type_str)
alter_column_type(table_oid, column_name, engine, conn, type_str)
logger.info(f"Column {column_name} altered to type {type_str}")
column_type = infer_column_type(
schema,
Expand Down
32 changes: 20 additions & 12 deletions db/tests/columns/operations/test_alter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from db import constants
from db.columns.operations import alter as alter_operations
from db.columns.operations.alter import alter_column, batch_update_columns, change_column_nullable, rename_column, retype_column, set_column_default
from db.columns.operations.select import get_column_attnum_from_name, get_column_default, get_column_index_from_name
from db.columns.operations.select import get_column_attnum_from_name, get_column_default
from db.columns.utils import get_mathesar_column_with_engine
from db.tables.operations.create import create_mathesar_table
from db.tables.operations.select import get_oid_from_table, reflect_table
Expand All @@ -31,9 +31,9 @@ def _rename_column_and_assert(table, old_col_name, new_col_name, engine):
Renames the colum of a table and assert the change went through
"""
table_oid = get_oid_from_table(table.name, table.schema, engine)
column_index = get_column_index_from_name(table_oid, old_col_name, engine)
column_attnum = get_column_attnum_from_name(table_oid, old_col_name, engine)
with engine.begin() as conn:
rename_column(table, column_index, engine, conn, new_col_name)
rename_column(table_oid, column_attnum, engine, conn, new_col_name)
table = reflect_table(table.name, table.schema, engine)
assert new_col_name in table.columns
assert old_col_name not in table.columns
Expand Down Expand Up @@ -86,15 +86,16 @@ def test_alter_column_chooses_wisely(column_dict, func_name, engine_with_schema)
table_name = "table_with_columns"
engine, schema = engine_with_schema
metadata = MetaData(bind=engine, schema=schema)
table = Table(table_name, metadata, Column('col', String))
column_name = 'col'
table = Table(table_name, metadata, Column(column_name, String))
table.create()
table_oid = get_oid_from_table(table.name, table.schema, engine)

target_column_attnum = get_column_attnum_from_name(table_oid, column_name, engine)
with patch.object(alter_operations, func_name) as mock_alterer:
alter_column(
engine,
table_oid,
0,
target_column_attnum,
column_dict
)
mock_alterer.assert_called_once()
Expand Down Expand Up @@ -179,11 +180,13 @@ def test_retype_column_correct_column(engine_with_schema):
Column(nontarget_column_name, String),
)
table.create()
table_oid = get_oid_from_table(table.name, table.schema, engine)
target_column_attnum = get_column_attnum_from_name(table_oid, target_column_name, engine)
with engine.begin() as conn:
with patch.object(alter_operations, "alter_column_type") as mock_retyper:
retype_column(table, 0, engine, conn, target_type)
retype_column(table_oid, target_column_attnum, engine, conn, target_type)
mock_retyper.assert_called_with(
table,
table_oid,
target_column_name,
engine,
conn,
Expand All @@ -207,11 +210,14 @@ def test_retype_column_adds_options(engine_with_schema, target_type):
)
table.create()
type_options = {"precision": 5}
table_oid = get_oid_from_table(table.name, table.schema, engine)
target_column_attnum = get_column_attnum_from_name(table_oid, target_column_name, engine)

with engine.begin() as conn:
with patch.object(alter_operations, "alter_column_type") as mock_retyper:
retype_column(table, 0, engine, conn, target_type, type_options)
retype_column(table_oid, target_column_attnum, engine, conn, target_type, type_options)
mock_retyper.assert_called_with(
table,
table_oid,
target_column_name,
engine,
conn,
Expand All @@ -233,13 +239,15 @@ def test_retype_column_options_only(engine_with_schema):
)
table.create()
type_options = {"length": 5}
table_oid = get_oid_from_table(table.name, table.schema, engine)
target_column_attnum = get_column_attnum_from_name(table_oid, target_column_name, engine)
with engine.begin() as conn:
with patch.object(alter_operations, "alter_column_type") as mock_retyper:
retype_column(
table, 0, engine, conn, new_type=None, type_options=type_options
table_oid, target_column_attnum, engine, conn, new_type=None, type_options=type_options
)
mock_retyper.assert_called_with(
table,
table_oid,
target_column_name,
engine,
conn,
Expand Down
5 changes: 3 additions & 2 deletions db/tests/types/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy import MetaData, Table
from sqlalchemy.schema import CreateSchema, DropSchema
from db.engine import _add_custom_types_to_engine
from db.tables.operations.select import get_oid_from_table
from db.types import base, install
from db.columns.operations.alter import alter_column_type

Expand Down Expand Up @@ -51,7 +52,7 @@ def uris_table_obj(engine_with_uris, uris_table_name):
uri_column_name = "uri"
uri_type_id = "uri"
alter_column_type(
table,
get_oid_from_table(table.name, schema, engine),
uri_column_name,
engine,
conn,
Expand All @@ -70,7 +71,7 @@ def roster_table_obj(engine_with_roster, roster_table_name):
email_column_name = "Teacher Email"
email_type_id = "email"
alter_column_type(
table,
get_oid_from_table(table.name, schema, engine),
email_column_name,
engine,
conn,
Expand Down
10 changes: 5 additions & 5 deletions db/tests/types/operations/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ def test_alter_column_type_alters_column_type(
input_table.create()
with engine.begin() as conn:
alter_column_type(
input_table,
get_oid_from_table(TABLE_NAME, schema, engine),
COLUMN_NAME,
engine,
conn,
Expand Down Expand Up @@ -935,7 +935,7 @@ def test_alter_column_type_casts_column_data_args(
with engine.begin() as conn:
conn.execute(ins)
alter_column_type(
input_table,
get_oid_from_table(TABLE_NAME, schema, engine),
COLUMN_NAME,
engine,
conn,
Expand Down Expand Up @@ -1002,7 +1002,7 @@ def test_alter_column_casts_data_gen(
with engine.begin() as conn:
conn.execute(ins)
alter_column_type(
input_table,
get_oid_from_table(TABLE_NAME, schema, engine),
COLUMN_NAME,
engine,
conn,
Expand Down Expand Up @@ -1064,7 +1064,7 @@ def test_alter_column_type_raises_on_bad_column_data(
conn.execute(ins)
with pytest.raises(Exception):
alter_column_type(
input_table,
get_oid_from_table(TABLE_NAME, schema, engine),
COLUMN_NAME,
engine,
conn,
Expand Down Expand Up @@ -1092,7 +1092,7 @@ def test_alter_column_type_raises_on_bad_parameters(
conn.execute(ins)
with pytest.raises(DataError) as e:
alter_column_type(
input_table,
get_oid_from_table(TABLE_NAME, schema, engine),
COLUMN_NAME,
engine,
conn,
Expand Down
2 changes: 1 addition & 1 deletion mathesar/api/db/viewsets/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def partial_update(self, request, pk=None, table_pk=None):
with warnings.catch_warnings():
warnings.filterwarnings("error", category=DynamicDefaultWarning)
try:
table.alter_column(column_instance._sa_column.column_index, serializer.validated_data)
table.alter_column(column_instance._sa_column.column_attnum, serializer.validated_data)
except ProgrammingError as e:
if type(e.orig) == UndefinedFunction:
raise database_api_exceptions.UndefinedFunctionAPIException(
Expand Down
4 changes: 2 additions & 2 deletions mathesar/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,11 @@ def add_column(self, column_data):
column_data,
)

def alter_column(self, column_index, column_data):
def alter_column(self, column_attnum, column_data):
return alter_column(
self.schema._sa_engine,
self.oid,
column_index,
column_attnum,
column_data,
)

Expand Down
2 changes: 1 addition & 1 deletion mathesar/tests/api/test_column_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def test_column_invalid_display_options_type_on_reflection(column_test_table_wit
column_index = 2
column = columns[column_index]
with engine.begin() as conn:
alter_column_type(table._sa_table, column.name, engine, conn, 'boolean')
alter_column_type(table.oid, column.name, engine, conn, 'boolean')
column_id = column.id
response = client.get(
f"/api/db/v0/tables/{table.id}/columns/{column_id}/",
Expand Down
1 change: 0 additions & 1 deletion mathesar/tests/api/test_table_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,7 +1146,6 @@ def test_table_patch_columns_multiple_type_change(create_data_types_table, clien
}
response = client.patch(f'/api/db/v0/tables/{table.id}/', body)
response_json = response.json()

assert response.status_code == 200
_check_columns(response_json['columns'], column_data)

Expand Down

0 comments on commit 67e615a

Please sign in to comment.