Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace column index usage in db column alter operations with column attnum. #1132

Merged
merged 11 commits into from
Mar 8, 2022
Merged
45 changes: 24 additions & 21 deletions db/columns/operations/alter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,27 @@
from db.utils import execute_statement


def alter_column(engine, table_oid, column_index, column_data):
def alter_column(engine, table_oid, column_attnum, column_data):
TYPE_KEY = 'plain_type'
TYPE_OPTIONS_KEY = 'type_options'
NULLABLE_KEY = NULLABLE
DEFAULT_DICT = 'column_default_dict'
DEFAULT_KEY = 'value'
NAME_KEY = NAME

table = reflect_table_from_oid(table_oid, engine)
column_index = int(column_index)

with engine.begin() as conn:
if TYPE_KEY in column_data:
retype_column(
table, column_index, engine, conn,
table_oid, column_attnum, engine, conn,
new_type=column_data[TYPE_KEY],
type_options=column_data.get(TYPE_OPTIONS_KEY, {})
)
elif TYPE_OPTIONS_KEY in column_data:
retype_column(
table, column_index, engine, conn,
table_oid, column_attnum, engine, conn,
type_options=column_data[TYPE_OPTIONS_KEY]
)
column_name = table.columns[column_index].name
column_attnum = get_column_attnum_from_name(table_oid, column_name, engine)

if NULLABLE_KEY in column_data:
nullable = column_data[NULLABLE_KEY]
change_column_nullable(table_oid, column_attnum, engine, conn, nullable)
Expand All @@ -53,18 +49,19 @@ def alter_column(engine, table_oid, column_index, column_data):
# Name always needs to be the last item altered
# since previous operations need the name to work
name = column_data[NAME_KEY]
rename_column(table, column_index, engine, conn, name)

rename_column(table_oid, column_attnum, engine, conn, name)
column_name = get_column_name_from_attnum(table_oid, column_attnum, engine)
return get_mathesar_column_with_engine(
reflect_table_from_oid(table_oid, engine).columns[column_index],
reflect_table_from_oid(table_oid, engine).columns[column_name],
engine
)


def alter_column_type(
table, column_name, engine, connection, target_type_str,
table_oid, column_name, engine, connection, target_type_str,
type_options={}, friendly_names=True,
):
table = reflect_table_from_oid(table_oid, engine, connection)
_preparer = engine.dialect.identifier_preparer
supported_types = get_supported_alter_column_types(
engine, friendly_names=friendly_names
Expand Down Expand Up @@ -104,9 +101,11 @@ def alter_column_type(


def retype_column(
table, column_index, engine, connection, new_type=None, type_options={},
table_oid, column_attnum, engine, connection, new_type=None, type_options={},
):
column = table.columns[column_index]
table = reflect_table_from_oid(table_oid, engine, connection)
column_name = get_column_name_from_attnum(table_oid, column_attnum, engine)
column = table.columns[column_name]
column_db_type = get_db_type_name(column.type, engine)
new_type = new_type if new_type is not None else column_db_type
column_type_options = get_type_options(column)
Expand All @@ -119,8 +118,8 @@ def retype_column(

try:
alter_column_type(
table,
table.columns[column_index].name,
table_oid,
column_name,
engine,
connection,
new_type,
Expand Down Expand Up @@ -163,8 +162,10 @@ def set_column_default(table_oid, column_attnum, engine, connection, default):
raise e


def rename_column(table, column_index, engine, connection, new_name):
column = table.columns[column_index]
def rename_column(table_oid, column_attnum, engine, connection, new_name):
table = reflect_table_from_oid(table_oid, engine, connection)
column_name = get_column_name_from_attnum(table_oid, column_attnum, engine)
column = table.columns[column_name]
ctx = MigrationContext.configure(connection)
op = Operations(ctx)
op.alter_column(table.name, column.name, new_column_name=new_name, schema=table.schema)
Expand All @@ -190,14 +191,16 @@ def _validate_columns_for_batch_update(table, column_data):
raise ValueError(f'Key "{key}" found in columns. Keys allowed are: {allowed_key_list}')


def _batch_update_column_types(table, column_data_list, connection, engine):
def _batch_update_column_types(table_oid, column_data_list, connection, engine):
table = reflect_table_from_oid(table_oid, engine, connection)
for index, column_data in enumerate(column_data_list):
if 'plain_type' in column_data:
new_type = column_data['plain_type']
type_options = column_data.get('type_options', {})
if type_options is None:
type_options = {}
retype_column(table, index, engine, connection, new_type, type_options)
column_attnum = get_column_attnum_from_name(table_oid, table.columns[index].name, engine, connection)
retype_column(table_oid, column_attnum, engine, connection, new_type, type_options)


def _batch_alter_table_columns(table, column_data_list, connection):
Expand All @@ -219,5 +222,5 @@ def batch_update_columns(table_oid, engine, column_data_list):
table = reflect_table_from_oid(table_oid, engine)
_validate_columns_for_batch_update(table, column_data_list)
with engine.begin() as conn:
_batch_update_column_types(table, column_data_list, conn, engine)
_batch_update_column_types(table_oid, column_data_list, conn, engine)
_batch_alter_table_columns(table, column_data_list, conn)
5 changes: 3 additions & 2 deletions db/columns/operations/infer_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from db.columns.exceptions import DagCycleError
from db.columns.operations.alter import alter_column_type
from db.tables.operations.select import reflect_table
from db.tables.operations.select import get_oid_from_table, reflect_table
from db.types.operations.cast import get_supported_alter_column_types
from db.types import base

Expand Down Expand Up @@ -62,10 +62,11 @@ def infer_column_type(schema, table_name, column_name, engine, depth=0, type_inf
column_type_str = reverse_type_map.get(column_type)

logger.debug(f"column_type_str: {column_type_str}")
table_oid = get_oid_from_table(table_name, schema, engine)
for type_str in type_inference_dag.get(column_type_str, []):
try:
with engine.begin() as conn:
alter_column_type(table, column_name, engine, conn, type_str)
alter_column_type(table_oid, column_name, engine, conn, type_str)
logger.info(f"Column {column_name} altered to type {type_str}")
column_type = infer_column_type(
schema,
Expand Down
32 changes: 20 additions & 12 deletions db/tests/columns/operations/test_alter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from db import constants
from db.columns.operations import alter as alter_operations
from db.columns.operations.alter import alter_column, batch_update_columns, change_column_nullable, rename_column, retype_column, set_column_default
from db.columns.operations.select import get_column_attnum_from_name, get_column_default, get_column_index_from_name
from db.columns.operations.select import get_column_attnum_from_name, get_column_default
from db.columns.utils import get_mathesar_column_with_engine
from db.tables.operations.create import create_mathesar_table
from db.tables.operations.select import get_oid_from_table, reflect_table
Expand All @@ -31,9 +31,9 @@ def _rename_column_and_assert(table, old_col_name, new_col_name, engine):
Renames the colum of a table and assert the change went through
"""
table_oid = get_oid_from_table(table.name, table.schema, engine)
column_index = get_column_index_from_name(table_oid, old_col_name, engine)
column_attnum = get_column_attnum_from_name(table_oid, old_col_name, engine)
with engine.begin() as conn:
rename_column(table, column_index, engine, conn, new_col_name)
rename_column(table_oid, column_attnum, engine, conn, new_col_name)
table = reflect_table(table.name, table.schema, engine)
assert new_col_name in table.columns
assert old_col_name not in table.columns
Expand Down Expand Up @@ -86,15 +86,16 @@ def test_alter_column_chooses_wisely(column_dict, func_name, engine_with_schema)
table_name = "table_with_columns"
engine, schema = engine_with_schema
metadata = MetaData(bind=engine, schema=schema)
table = Table(table_name, metadata, Column('col', String))
column_name = 'col'
table = Table(table_name, metadata, Column(column_name, String))
table.create()
table_oid = get_oid_from_table(table.name, table.schema, engine)

target_column_attnum = get_column_attnum_from_name(table_oid, column_name, engine)
with patch.object(alter_operations, func_name) as mock_alterer:
alter_column(
engine,
table_oid,
0,
target_column_attnum,
column_dict
)
mock_alterer.assert_called_once()
Expand Down Expand Up @@ -179,11 +180,13 @@ def test_retype_column_correct_column(engine_with_schema):
Column(nontarget_column_name, String),
)
table.create()
table_oid = get_oid_from_table(table.name, table.schema, engine)
target_column_attnum = get_column_attnum_from_name(table_oid, target_column_name, engine)
with engine.begin() as conn:
with patch.object(alter_operations, "alter_column_type") as mock_retyper:
retype_column(table, 0, engine, conn, target_type)
retype_column(table_oid, target_column_attnum, engine, conn, target_type)
mock_retyper.assert_called_with(
table,
table_oid,
target_column_name,
engine,
conn,
Expand All @@ -207,11 +210,14 @@ def test_retype_column_adds_options(engine_with_schema, target_type):
)
table.create()
type_options = {"precision": 5}
table_oid = get_oid_from_table(table.name, table.schema, engine)
target_column_attnum = get_column_attnum_from_name(table_oid, target_column_name, engine)

with engine.begin() as conn:
with patch.object(alter_operations, "alter_column_type") as mock_retyper:
retype_column(table, 0, engine, conn, target_type, type_options)
retype_column(table_oid, target_column_attnum, engine, conn, target_type, type_options)
mock_retyper.assert_called_with(
table,
table_oid,
target_column_name,
engine,
conn,
Expand All @@ -233,13 +239,15 @@ def test_retype_column_options_only(engine_with_schema):
)
table.create()
type_options = {"length": 5}
table_oid = get_oid_from_table(table.name, table.schema, engine)
target_column_attnum = get_column_attnum_from_name(table_oid, target_column_name, engine)
with engine.begin() as conn:
with patch.object(alter_operations, "alter_column_type") as mock_retyper:
retype_column(
table, 0, engine, conn, new_type=None, type_options=type_options
table_oid, target_column_attnum, engine, conn, new_type=None, type_options=type_options
)
mock_retyper.assert_called_with(
table,
table_oid,
target_column_name,
engine,
conn,
Expand Down
5 changes: 3 additions & 2 deletions db/tests/types/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy import MetaData, Table
from sqlalchemy.schema import CreateSchema, DropSchema
from db.engine import _add_custom_types_to_engine
from db.tables.operations.select import get_oid_from_table
from db.types import base, install
from db.columns.operations.alter import alter_column_type

Expand Down Expand Up @@ -51,7 +52,7 @@ def uris_table_obj(engine_with_uris, uris_table_name):
uri_column_name = "uri"
uri_type_id = "uri"
alter_column_type(
table,
get_oid_from_table(table.name, schema, engine),
uri_column_name,
engine,
conn,
Expand All @@ -70,7 +71,7 @@ def roster_table_obj(engine_with_roster, roster_table_name):
email_column_name = "Teacher Email"
email_type_id = "email"
alter_column_type(
table,
get_oid_from_table(table.name, schema, engine),
email_column_name,
engine,
conn,
Expand Down
10 changes: 5 additions & 5 deletions db/tests/types/operations/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ def test_alter_column_type_alters_column_type(
input_table.create()
with engine.begin() as conn:
alter_column_type(
input_table,
get_oid_from_table(TABLE_NAME, schema, engine),
COLUMN_NAME,
engine,
conn,
Expand Down Expand Up @@ -935,7 +935,7 @@ def test_alter_column_type_casts_column_data_args(
with engine.begin() as conn:
conn.execute(ins)
alter_column_type(
input_table,
get_oid_from_table(TABLE_NAME, schema, engine),
COLUMN_NAME,
engine,
conn,
Expand Down Expand Up @@ -1002,7 +1002,7 @@ def test_alter_column_casts_data_gen(
with engine.begin() as conn:
conn.execute(ins)
alter_column_type(
input_table,
get_oid_from_table(TABLE_NAME, schema, engine),
COLUMN_NAME,
engine,
conn,
Expand Down Expand Up @@ -1064,7 +1064,7 @@ def test_alter_column_type_raises_on_bad_column_data(
conn.execute(ins)
with pytest.raises(Exception):
alter_column_type(
input_table,
get_oid_from_table(TABLE_NAME, schema, engine),
COLUMN_NAME,
engine,
conn,
Expand Down Expand Up @@ -1092,7 +1092,7 @@ def test_alter_column_type_raises_on_bad_parameters(
conn.execute(ins)
with pytest.raises(DataError) as e:
alter_column_type(
input_table,
get_oid_from_table(TABLE_NAME, schema, engine),
COLUMN_NAME,
engine,
conn,
Expand Down
2 changes: 1 addition & 1 deletion mathesar/api/db/viewsets/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def partial_update(self, request, pk=None, table_pk=None):
with warnings.catch_warnings():
warnings.filterwarnings("error", category=DynamicDefaultWarning)
try:
table.alter_column(column_instance._sa_column.column_index, serializer.validated_data)
table.alter_column(column_instance._sa_column.column_attnum, serializer.validated_data)
except ProgrammingError as e:
if type(e.orig) == UndefinedFunction:
raise database_api_exceptions.UndefinedFunctionAPIException(
Expand Down
4 changes: 2 additions & 2 deletions mathesar/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,11 @@ def add_column(self, column_data):
column_data,
)

def alter_column(self, column_index, column_data):
def alter_column(self, column_attnum, column_data):
return alter_column(
self.schema._sa_engine,
self.oid,
column_index,
column_attnum,
column_data,
)

Expand Down
2 changes: 1 addition & 1 deletion mathesar/tests/api/test_column_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def test_column_invalid_display_options_type_on_reflection(column_test_table_wit
column_index = 2
column = columns[column_index]
with engine.begin() as conn:
alter_column_type(table._sa_table, column.name, engine, conn, 'boolean')
alter_column_type(table.oid, column.name, engine, conn, 'boolean')
column_id = column.id
response = client.get(
f"/api/db/v0/tables/{table.id}/columns/{column_id}/",
Expand Down
1 change: 0 additions & 1 deletion mathesar/tests/api/test_table_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,7 +1146,6 @@ def test_table_patch_columns_multiple_type_change(create_data_types_table, clien
}
response = client.patch(f'/api/db/v0/tables/{table.id}/', body)
response_json = response.json()

assert response.status_code == 200
_check_columns(response_json['columns'], column_data)

Expand Down