diff --git a/db/columns/operations/alter.py b/db/columns/operations/alter.py index c187af26d9..652ab1aa70 100644 --- a/db/columns/operations/alter.py +++ b/db/columns/operations/alter.py @@ -16,7 +16,7 @@ from db.utils import execute_statement -def alter_column(engine, table_oid, column_index, column_data): +def alter_column(engine, table_oid, column_attnum, column_data): TYPE_KEY = 'plain_type' TYPE_OPTIONS_KEY = 'type_options' NULLABLE_KEY = NULLABLE @@ -24,23 +24,19 @@ def alter_column(engine, table_oid, column_index, column_data): DEFAULT_KEY = 'value' NAME_KEY = NAME - table = reflect_table_from_oid(table_oid, engine) - column_index = int(column_index) - with engine.begin() as conn: if TYPE_KEY in column_data: retype_column( - table, column_index, engine, conn, + table_oid, column_attnum, engine, conn, new_type=column_data[TYPE_KEY], type_options=column_data.get(TYPE_OPTIONS_KEY, {}) ) elif TYPE_OPTIONS_KEY in column_data: retype_column( - table, column_index, engine, conn, + table_oid, column_attnum, engine, conn, type_options=column_data[TYPE_OPTIONS_KEY] ) - column_name = table.columns[column_index].name - column_attnum = get_column_attnum_from_name(table_oid, column_name, engine) + if NULLABLE_KEY in column_data: nullable = column_data[NULLABLE_KEY] change_column_nullable(table_oid, column_attnum, engine, conn, nullable) @@ -53,18 +49,19 @@ def alter_column(engine, table_oid, column_index, column_data): # Name always needs to be the last item altered # since previous operations need the name to work name = column_data[NAME_KEY] - rename_column(table, column_index, engine, conn, name) - + rename_column(table_oid, column_attnum, engine, conn, name) + column_name = get_column_name_from_attnum(table_oid, column_attnum, engine) return get_mathesar_column_with_engine( - reflect_table_from_oid(table_oid, engine).columns[column_index], + reflect_table_from_oid(table_oid, engine).columns[column_name], engine ) def alter_column_type( - table, column_name, engine, connection, target_type_str, + table_oid, column_name, engine, connection, target_type_str, type_options={}, friendly_names=True, ): + table = reflect_table_from_oid(table_oid, engine, connection) _preparer = engine.dialect.identifier_preparer supported_types = get_supported_alter_column_types( engine, friendly_names=friendly_names @@ -104,9 +101,11 @@ def alter_column_type( def retype_column( - table, column_index, engine, connection, new_type=None, type_options={}, + table_oid, column_attnum, engine, connection, new_type=None, type_options={}, ): - column = table.columns[column_index] + table = reflect_table_from_oid(table_oid, engine, connection) + column_name = get_column_name_from_attnum(table_oid, column_attnum, engine) + column = table.columns[column_name] column_db_type = get_db_type_name(column.type, engine) new_type = new_type if new_type is not None else column_db_type column_type_options = get_type_options(column) @@ -119,8 +118,8 @@ def retype_column( try: alter_column_type( - table, - table.columns[column_index].name, + table_oid, + column_name, engine, connection, new_type, @@ -163,8 +162,10 @@ def set_column_default(table_oid, column_attnum, engine, connection, default): raise e -def rename_column(table, column_index, engine, connection, new_name): - column = table.columns[column_index] +def rename_column(table_oid, column_attnum, engine, connection, new_name): + table = reflect_table_from_oid(table_oid, engine, connection) + column_name = get_column_name_from_attnum(table_oid, column_attnum, engine) + column = table.columns[column_name] ctx = MigrationContext.configure(connection) op = Operations(ctx) op.alter_column(table.name, column.name, new_column_name=new_name, schema=table.schema) @@ -190,14 +191,16 @@ def _validate_columns_for_batch_update(table, column_data): raise ValueError(f'Key "{key}" found in columns. Keys allowed are: {allowed_key_list}') -def _batch_update_column_types(table, column_data_list, connection, engine): +def _batch_update_column_types(table_oid, column_data_list, connection, engine): + table = reflect_table_from_oid(table_oid, engine, connection) for index, column_data in enumerate(column_data_list): if 'plain_type' in column_data: new_type = column_data['plain_type'] type_options = column_data.get('type_options', {}) if type_options is None: type_options = {} - retype_column(table, index, engine, connection, new_type, type_options) + column_attnum = get_column_attnum_from_name(table_oid, table.columns[index].name, engine, connection) + retype_column(table_oid, column_attnum, engine, connection, new_type, type_options) def _batch_alter_table_columns(table, column_data_list, connection): @@ -219,5 +222,5 @@ def batch_update_columns(table_oid, engine, column_data_list): table = reflect_table_from_oid(table_oid, engine) _validate_columns_for_batch_update(table, column_data_list) with engine.begin() as conn: - _batch_update_column_types(table, column_data_list, conn, engine) + _batch_update_column_types(table_oid, column_data_list, conn, engine) _batch_alter_table_columns(table, column_data_list, conn) diff --git a/db/columns/operations/infer_types.py b/db/columns/operations/infer_types.py index 10305f4039..93643fd717 100644 --- a/db/columns/operations/infer_types.py +++ b/db/columns/operations/infer_types.py @@ -5,7 +5,7 @@ from db.columns.exceptions import DagCycleError from db.columns.operations.alter import alter_column_type -from db.tables.operations.select import reflect_table +from db.tables.operations.select import get_oid_from_table, reflect_table from db.types.operations.cast import get_supported_alter_column_types from db.types import base @@ -62,10 +62,11 @@ def infer_column_type(schema, table_name, column_name, engine, depth=0, type_inf column_type_str = reverse_type_map.get(column_type) logger.debug(f"column_type_str: {column_type_str}") + table_oid = get_oid_from_table(table_name, schema, engine) for type_str in type_inference_dag.get(column_type_str, []): try: with engine.begin() as conn: - alter_column_type(table, column_name, engine, conn, type_str) + alter_column_type(table_oid, column_name, engine, conn, type_str) logger.info(f"Column {column_name} altered to type {type_str}") column_type = infer_column_type( schema, diff --git a/db/tests/columns/operations/test_alter.py b/db/tests/columns/operations/test_alter.py index 346df01d3c..7fb449c176 100644 --- a/db/tests/columns/operations/test_alter.py +++ b/db/tests/columns/operations/test_alter.py @@ -8,7 +8,7 @@ from db import constants from db.columns.operations import alter as alter_operations from db.columns.operations.alter import alter_column, batch_update_columns, change_column_nullable, rename_column, retype_column, set_column_default -from db.columns.operations.select import get_column_attnum_from_name, get_column_default, get_column_index_from_name +from db.columns.operations.select import get_column_attnum_from_name, get_column_default from db.columns.utils import get_mathesar_column_with_engine from db.tables.operations.create import create_mathesar_table from db.tables.operations.select import get_oid_from_table, reflect_table @@ -31,9 +31,9 @@ def _rename_column_and_assert(table, old_col_name, new_col_name, engine): Renames the colum of a table and assert the change went through """ table_oid = get_oid_from_table(table.name, table.schema, engine) - column_index = get_column_index_from_name(table_oid, old_col_name, engine) + column_attnum = get_column_attnum_from_name(table_oid, old_col_name, engine) with engine.begin() as conn: - rename_column(table, column_index, engine, conn, new_col_name) + rename_column(table_oid, column_attnum, engine, conn, new_col_name) table = reflect_table(table.name, table.schema, engine) assert new_col_name in table.columns assert old_col_name not in table.columns @@ -86,15 +86,16 @@ def test_alter_column_chooses_wisely(column_dict, func_name, engine_with_schema) table_name = "table_with_columns" engine, schema = engine_with_schema metadata = MetaData(bind=engine, schema=schema) - table = Table(table_name, metadata, Column('col', String)) + column_name = 'col' + table = Table(table_name, metadata, Column(column_name, String)) table.create() table_oid = get_oid_from_table(table.name, table.schema, engine) - + target_column_attnum = get_column_attnum_from_name(table_oid, column_name, engine) with patch.object(alter_operations, func_name) as mock_alterer: alter_column( engine, table_oid, - 0, + target_column_attnum, column_dict ) mock_alterer.assert_called_once() @@ -179,11 +180,13 @@ def test_retype_column_correct_column(engine_with_schema): Column(nontarget_column_name, String), ) table.create() + table_oid = get_oid_from_table(table.name, table.schema, engine) + target_column_attnum = get_column_attnum_from_name(table_oid, target_column_name, engine) with engine.begin() as conn: with patch.object(alter_operations, "alter_column_type") as mock_retyper: - retype_column(table, 0, engine, conn, target_type) + retype_column(table_oid, target_column_attnum, engine, conn, target_type) mock_retyper.assert_called_with( - table, + table_oid, target_column_name, engine, conn, @@ -207,11 +210,14 @@ def test_retype_column_adds_options(engine_with_schema, target_type): ) table.create() type_options = {"precision": 5} + table_oid = get_oid_from_table(table.name, table.schema, engine) + target_column_attnum = get_column_attnum_from_name(table_oid, target_column_name, engine) + with engine.begin() as conn: with patch.object(alter_operations, "alter_column_type") as mock_retyper: - retype_column(table, 0, engine, conn, target_type, type_options) + retype_column(table_oid, target_column_attnum, engine, conn, target_type, type_options) mock_retyper.assert_called_with( - table, + table_oid, target_column_name, engine, conn, @@ -233,13 +239,15 @@ def test_retype_column_options_only(engine_with_schema): ) table.create() type_options = {"length": 5} + table_oid = get_oid_from_table(table.name, table.schema, engine) + target_column_attnum = get_column_attnum_from_name(table_oid, target_column_name, engine) with engine.begin() as conn: with patch.object(alter_operations, "alter_column_type") as mock_retyper: retype_column( - table, 0, engine, conn, new_type=None, type_options=type_options + table_oid, target_column_attnum, engine, conn, new_type=None, type_options=type_options ) mock_retyper.assert_called_with( - table, + table_oid, target_column_name, engine, conn, diff --git a/db/tests/types/fixtures.py b/db/tests/types/fixtures.py index dd37fa2a40..0edea258d2 100644 --- a/db/tests/types/fixtures.py +++ b/db/tests/types/fixtures.py @@ -10,6 +10,7 @@ from sqlalchemy import MetaData, Table from sqlalchemy.schema import CreateSchema, DropSchema from db.engine import _add_custom_types_to_engine +from db.tables.operations.select import get_oid_from_table from db.types import base, install from db.columns.operations.alter import alter_column_type @@ -51,7 +52,7 @@ def uris_table_obj(engine_with_uris, uris_table_name): uri_column_name = "uri" uri_type_id = "uri" alter_column_type( - table, + get_oid_from_table(table.name, schema, engine), uri_column_name, engine, conn, @@ -70,7 +71,7 @@ def roster_table_obj(engine_with_roster, roster_table_name): email_column_name = "Teacher Email" email_type_id = "email" alter_column_type( - table, + get_oid_from_table(table.name, schema, engine), email_column_name, engine, conn, diff --git a/db/tests/types/operations/test_cast.py b/db/tests/types/operations/test_cast.py index be92f54011..227308da5b 100644 --- a/db/tests/types/operations/test_cast.py +++ b/db/tests/types/operations/test_cast.py @@ -856,7 +856,7 @@ def test_alter_column_type_alters_column_type( input_table.create() with engine.begin() as conn: alter_column_type( - input_table, + get_oid_from_table(TABLE_NAME, schema, engine), COLUMN_NAME, engine, conn, @@ -935,7 +935,7 @@ def test_alter_column_type_casts_column_data_args( with engine.begin() as conn: conn.execute(ins) alter_column_type( - input_table, + get_oid_from_table(TABLE_NAME, schema, engine), COLUMN_NAME, engine, conn, @@ -1002,7 +1002,7 @@ def test_alter_column_casts_data_gen( with engine.begin() as conn: conn.execute(ins) alter_column_type( - input_table, + get_oid_from_table(TABLE_NAME, schema, engine), COLUMN_NAME, engine, conn, @@ -1064,7 +1064,7 @@ def test_alter_column_type_raises_on_bad_column_data( conn.execute(ins) with pytest.raises(Exception): alter_column_type( - input_table, + get_oid_from_table(TABLE_NAME, schema, engine), COLUMN_NAME, engine, conn, @@ -1092,7 +1092,7 @@ def test_alter_column_type_raises_on_bad_parameters( conn.execute(ins) with pytest.raises(DataError) as e: alter_column_type( - input_table, + get_oid_from_table(TABLE_NAME, schema, engine), COLUMN_NAME, engine, conn, diff --git a/mathesar/api/db/viewsets/columns.py b/mathesar/api/db/viewsets/columns.py index b273329ed4..538dd09a16 100644 --- a/mathesar/api/db/viewsets/columns.py +++ b/mathesar/api/db/viewsets/columns.py @@ -99,7 +99,7 @@ def partial_update(self, request, pk=None, table_pk=None): with warnings.catch_warnings(): warnings.filterwarnings("error", category=DynamicDefaultWarning) try: - table.alter_column(column_instance._sa_column.column_index, serializer.validated_data) + table.alter_column(column_instance._sa_column.column_attnum, serializer.validated_data) except ProgrammingError as e: if type(e.orig) == UndefinedFunction: raise database_api_exceptions.UndefinedFunctionAPIException( diff --git a/mathesar/models.py b/mathesar/models.py index 16c3cfdfd3..97d07c1651 100644 --- a/mathesar/models.py +++ b/mathesar/models.py @@ -253,11 +253,11 @@ def add_column(self, column_data): column_data, ) - def alter_column(self, column_index, column_data): + def alter_column(self, column_attnum, column_data): return alter_column( self.schema._sa_engine, self.oid, - column_index, + column_attnum, column_data, ) diff --git a/mathesar/tests/api/test_column_api.py b/mathesar/tests/api/test_column_api.py index a903a403cb..18a5eb2520 100644 --- a/mathesar/tests/api/test_column_api.py +++ b/mathesar/tests/api/test_column_api.py @@ -432,7 +432,7 @@ def test_column_invalid_display_options_type_on_reflection(column_test_table_wit column_index = 2 column = columns[column_index] with engine.begin() as conn: - alter_column_type(table._sa_table, column.name, engine, conn, 'boolean') + alter_column_type(table.oid, column.name, engine, conn, 'boolean') column_id = column.id response = client.get( f"/api/db/v0/tables/{table.id}/columns/{column_id}/", diff --git a/mathesar/tests/api/test_table_api.py b/mathesar/tests/api/test_table_api.py index 0391662bbe..08af3f7879 100644 --- a/mathesar/tests/api/test_table_api.py +++ b/mathesar/tests/api/test_table_api.py @@ -1146,7 +1146,6 @@ def test_table_patch_columns_multiple_type_change(create_data_types_table, clien } response = client.patch(f'/api/db/v0/tables/{table.id}/', body) response_json = response.json() - assert response.status_code == 200 _check_columns(response_json['columns'], column_data)