Merge pull request #1047 from centerofci/mathesar-896-records-api-column-id-to-name

Mathesar 896: Change records API parameters to use column ids instead of names
seancolsen committed Mar 8, 2022
2 parents 10fa641 + 15e237a commit 8631e1c
Showing 26 changed files with 413 additions and 294 deletions.
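To make the change concrete, here is a rough sketch of how a client's records request parameters shift under this PR. The column names, ids, and the 'direction' key are illustrative assumptions; only the 'field' and 'columns' keys are visible in the diff below.

# Before this PR: order_by and grouping referenced columns by name.
old_params = {
    'order_by': [{'field': 'email', 'direction': 'asc'}],
    'grouping': {'columns': ['email']},
}

# After this PR: clients pass Django Column ids instead, and the viewset
# translates them back to names before querying the database layer.
# Filter specs likewise switch from names to ids (see
# rewrite_db_function_spec_column_ids_to_names below).
new_params = {
    'order_by': [{'field': 3, 'direction': 'asc'}],
    'grouping': {'columns': [3]},
}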
3 changes: 1 addition & 2 deletions mathesar/api/db/viewsets/columns.py
@@ -26,8 +26,7 @@ class ColumnViewSet(viewsets.ModelViewSet):
pagination_class = DefaultLimitOffsetPagination

def get_queryset(self):
table = get_table_or_404(pk=self.kwargs['table_pk'])
return table.get_dj_columns_queryset()
return Column.objects.filter(table=self.kwargs['table_pk'])

def create(self, request, table_pk=None):
table = get_table_or_404(table_pk)
85 changes: 47 additions & 38 deletions mathesar/api/db/viewsets/records.py
@@ -1,20 +1,16 @@
from psycopg2.errors import NotNullViolation

from rest_framework import status, viewsets
from rest_framework.exceptions import NotFound
from rest_framework.response import Response
from rest_framework.renderers import BrowsableAPIRenderer
from sqlalchemy.exc import IntegrityError
from rest_framework.response import Response
from sqlalchemy_filters.exceptions import BadSortFormat, SortFieldNotFound

from db.functions.exceptions import BadDBFunctionFormat, UnknownDBFunctionID, ReferencedColumnsDontExist
from mathesar.functions.operations.convert import rewrite_db_function_spec_column_ids_to_names
from db.records.exceptions import BadGroupFormat, GroupFieldNotFound, InvalidGroupType

import mathesar.api.exceptions.database_exceptions.exceptions as database_api_exceptions
from db.functions.exceptions import BadDBFunctionFormat, ReferencedColumnsDontExist, UnknownDBFunctionID
from db.records.exceptions import BadGroupFormat, GroupFieldNotFound, InvalidGroupType
from mathesar.api.pagination import TableLimitOffsetGroupPagination
from mathesar.api.serializers.records import RecordListParameterSerializer, RecordSerializer
from mathesar.api.utils import get_table_or_404
from mathesar.functions.operations.convert import rewrite_db_function_spec_column_ids_to_names
from mathesar.models import Table
from mathesar.utils.json import MathesarJSONRenderer

@@ -37,23 +33,33 @@ def list(self, request, table_pk=None):

serializer = RecordListParameterSerializer(data=request.GET)
serializer.is_valid(raise_exception=True)
table = get_table_or_404(table_pk)

filter_unprocessed = serializer.validated_data['filter']
order_by = serializer.validated_data['order_by']
grouping = serializer.validated_data['grouping']
filter_processed = None
column_ids_to_names = table.get_column_name_id_bidirectional_map().inverse
if filter_unprocessed:
table = get_table_or_404(table_pk)
filter_processed = rewrite_db_function_spec_column_ids_to_names(
column_ids_to_names=column_ids_to_names,
spec=filter_unprocessed,
)
# Replace the column id values used in the `field` property with column names
name_converted_group_by = None
if grouping:
group_by_columns_names = [column_ids_to_names[column_id] for column_id in grouping['columns']]
name_converted_group_by = {**grouping, 'columns': group_by_columns_names}
name_converted_order_by = [{**column, 'field': column_ids_to_names[column['field']]} for column in order_by]

try:
if filter_unprocessed:
table = get_table_or_404(table_pk)
column_ids_to_names = table.get_dj_column_id_to_name_mapping()
filter_processed = rewrite_db_function_spec_column_ids_to_names(
column_ids_to_names=column_ids_to_names,
spec=filter_unprocessed,
)

records = paginator.paginate_queryset(
self.get_queryset(), request, table_pk,
filter=filter_processed,
order_by=serializer.validated_data['order_by'],
grouping=serializer.validated_data['grouping'],
self.get_queryset(), request, table,
filters=filter_processed,
order_by=name_converted_order_by,
grouping=name_converted_group_by,
duplicate_only=serializer.validated_data['duplicate_only'],
)
except (BadDBFunctionFormat, UnknownDBFunctionID, ReferencedColumnsDontExist) as e:
@@ -74,43 +80,46 @@ def list(self, request, table_pk=None):
field='grouping',
status_code=status.HTTP_400_BAD_REQUEST
)

serializer = RecordSerializer(records, many=True)
serializer = RecordSerializer(
records,
many=True,
context=self.get_serializer_context(table)
)
return paginator.get_paginated_response(serializer.data)

def retrieve(self, request, pk=None, table_pk=None):
table = get_table_or_404(table_pk)
record = table.get_record(pk)
if not record:
raise NotFound
serializer = RecordSerializer(record)
serializer = RecordSerializer(record, context=self.get_serializer_context(table))
return Response(serializer.data)

def create(self, request, table_pk=None):
table = get_table_or_404(table_pk)
# We only support adding a single record through the API.
assert isinstance((request.data), dict)
try:
record = table.create_record_or_records(request.data)
except IntegrityError as e:
if type(e.orig) == NotNullViolation:
raise database_api_exceptions.NotNullViolationAPIException(
e,
status_code=status.HTTP_400_BAD_REQUEST,
table=table
)
else:
raise database_api_exceptions.MathesarAPIException(e, status_code=status.HTTP_400_BAD_REQUEST)
serializer = RecordSerializer(record)
serializer = RecordSerializer(data=request.data, context=self.get_serializer_context(table))
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(serializer.data, status=status.HTTP_201_CREATED)

def partial_update(self, request, pk=None, table_pk=None):
table = get_table_or_404(table_pk)
record = table.update_record(pk, request.data)
serializer = RecordSerializer(record)
serializer = RecordSerializer(
{'id': pk},
data=request.data,
context=self.get_serializer_context(table),
partial=True
)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(serializer.data)

def destroy(self, request, pk=None, table_pk=None):
table = get_table_or_404(table_pk)
table.delete_record(pk)
return Response(status=status.HTTP_204_NO_CONTENT)

def get_serializer_context(self, table):
columns_map = table.get_column_name_id_bidirectional_map()
context = {'columns_map': columns_map, 'table': table}
return context
44 changes: 25 additions & 19 deletions mathesar/api/pagination.py
@@ -12,10 +12,14 @@ class DefaultLimitOffsetPagination(LimitOffsetPagination):
max_limit = 500

def get_paginated_response(self, data):
return Response(OrderedDict([
('count', self.count),
('results', data)
]))
return Response(
OrderedDict(
[
('count', self.count),
('results', data)
]
)
)


class ColumnLimitOffsetPagination(DefaultLimitOffsetPagination):
@@ -37,8 +41,8 @@ def paginate_queryset(
self,
queryset,
request,
table_id,
filter=None,
table,
filters=None,
order_by=[],
group_by=None,
duplicate_only=None,
@@ -48,14 +52,13 @@
self.limit = self.default_limit
self.offset = self.get_offset(request)
# TODO: Cache count value somewhere, since calculating it is expensive.
table = get_table_or_404(pk=table_id)
self.count = table.sa_num_records(filter=filter)
self.count = table.sa_num_records(filter=filters)
self.request = request

return table.get_records(
self.limit,
self.offset,
filter=filter,
filter=filters,
order_by=order_by,
group_by=group_by,
duplicate_only=duplicate_only,
@@ -64,29 +67,32 @@

class TableLimitOffsetGroupPagination(TableLimitOffsetPagination):
def get_paginated_response(self, data):
return Response(OrderedDict([
('count', self.count),
('grouping', self.grouping),
('results', data)
]))
return Response(
OrderedDict(
[
('count', self.count),
('grouping', self.grouping),
('results', data)
]
)
)

def paginate_queryset(
self,
queryset,
request,
table_id,
filter=None,
table,
filters=None,
order_by=[],
grouping={},
duplicate_only=None,
):
group_by = GroupBy(**grouping) if grouping else None

records = super().paginate_queryset(
queryset,
request,
table_id,
filter=filter,
table,
filters=filters,
order_by=order_by,
group_by=group_by,
duplicate_only=duplicate_only,
34 changes: 33 additions & 1 deletion mathesar/api/serializers/records.py
@@ -1,5 +1,9 @@
from psycopg2.errors import NotNullViolation
from rest_framework import serializers
from rest_framework import status
from sqlalchemy.exc import IntegrityError

import mathesar.api.exceptions.database_exceptions.exceptions as database_api_exceptions
from mathesar.api.exceptions.mixins import MathesarErrorMessageMixin


Expand All @@ -11,5 +15,33 @@ class RecordListParameterSerializer(MathesarErrorMessageMixin, serializers.Seria


class RecordSerializer(MathesarErrorMessageMixin, serializers.BaseSerializer):
def update(self, instance, validated_data):
table = self.context['table']
record = table.update_record(instance['id'], validated_data)
return record

def create(self, validated_data):
table = self.context['table']
try:
record = table.create_record_or_records(validated_data)
except IntegrityError as e:
if type(e.orig) == NotNullViolation:
raise database_api_exceptions.NotNullViolationAPIException(
e,
status_code=status.HTTP_400_BAD_REQUEST,
table=table
)
else:
raise database_api_exceptions.MathesarAPIException(e, status_code=status.HTTP_400_BAD_REQUEST)
return record

def to_representation(self, instance):
return instance._asdict() if not isinstance(instance, dict) else instance
records = instance._asdict() if not isinstance(instance, dict) else instance
columns_map = self.context['columns_map']
records = {columns_map[column_name]: column_value for column_name, column_value in records.items()}
return records

def to_internal_value(self, data):
columns_map = self.context['columns_map'].inverse
data = {columns_map[int(column_id)]: value for column_id, value in data.items()}
return data
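For context, a minimal sketch of the name/id translation these two hooks perform, assuming a single made-up column; columns_map mirrors the serializer context built in the viewset:

from bidict import bidict

# Hypothetical context entry: column name -> Django Column id.
columns_map = bidict({'email': 2})

# to_internal_value: incoming payloads are keyed by column id.
incoming = {'2': 'alice@example.com'}
internal = {columns_map.inverse[int(k)]: v for k, v in incoming.items()}
# internal == {'email': 'alice@example.com'} -- what the table layer receives

# to_representation: records come back keyed by name and go out keyed by id.
outgoing = {columns_map[name]: value for name, value in internal.items()}
# outgoing == {2: 'alice@example.com'}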
26 changes: 7 additions & 19 deletions mathesar/models.py
@@ -1,5 +1,6 @@
from typing import Any

from bidict import bidict
from django.contrib.auth.models import User
from django.core.cache import cache
from django.db import models
@@ -227,25 +228,6 @@ def sa_column_names(self):
def has_dependencies(self):
return True

def get_dj_columns_queryset(self):
sa_column_name = [column.name for column in self.sa_columns]
column_attnum_list = get_columns_attnum_from_names(self.oid, sa_column_name, self.schema._sa_engine)
return Column.objects.filter(table=self, attnum__in=column_attnum_list).order_by("attnum")

def get_dj_columns(self):
return tuple(self.get_dj_columns_queryset())

def get_dj_column_id_to_name_mapping(self):
dj_columns = self.get_dj_columns()
return dict(
(dj_column.id, dj_column.name)
for dj_column in dj_columns
)

def get_dj_column_name_to_id_mapping(self):
ids_to_names = self.get_dj_column_id_to_name_mapping()
return dict(map(reversed, ids_to_names.items()))

def add_column(self, column_data):
return create_column(
self.schema._sa_engine,
@@ -350,6 +332,12 @@ def add_constraint(self, constraint_type, columns, name=None):
constraint_oid = get_constraint_oid_by_name_and_table_oid(name, self.oid, engine)
return Constraint.current_objects.create(oid=constraint_oid, table=self)

def get_column_name_id_bidirectional_map(self):
# TODO: Prefetch column names to avoid N+1 queries
columns = Column.objects.filter(table_id=self.id)
columns_map = bidict({column.name: column.id for column in columns})
return columns_map


class Column(ReflectionManagerMixin, BaseModel):
table = models.ForeignKey('Table', on_delete=models.CASCADE, related_name='columns')
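A quick illustration of the bidict returned by the new helper and the .inverse view the records viewset relies on; the column names and ids here are made up:

from bidict import bidict

columns_map = bidict({'email': 2, 'age': 3})  # name -> id
ids_to_names = columns_map.inverse            # id -> name

# list() uses the inverse view to rewrite a grouping spec like this:
grouping = {'columns': [2, 3]}
converted = {**grouping, 'columns': [ids_to_names[i] for i in grouping['columns']]}
# converted == {'columns': ['email', 'age']}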
13 changes: 12 additions & 1 deletion mathesar/tests/api/conftest.py
@@ -3,9 +3,10 @@
from rest_framework.test import APIClient
from sqlalchemy import text

from db.columns.operations.select import get_column_attnum_from_name
from mathesar.database.base import create_mathesar_engine
from mathesar.imports.csv import create_table_from_csv
from mathesar.models import DataFile
from mathesar.models import Column, DataFile


TEST_SCHEMA = 'import_csv_schema'
@@ -55,3 +56,13 @@ def table_for_reflection(test_db_name):
yield schema_name, table_name, engine
with engine.begin() as conn:
conn.execute(text(f'DROP SCHEMA {schema_name} CASCADE;'))


@pytest.fixture
def create_column():
def _create_column(table, column_data):
column = table.add_column(column_data)
attnum = get_column_attnum_from_name(table.oid, [column.name], table.schema._sa_engine)
column = Column.current_objects.get_or_create(attnum=attnum, table=table)
return column[0]
return _create_column
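A hypothetical test using this fixture; the column_test_table fixture and the column_data keys are assumptions for illustration:

# Hypothetical usage; `column_test_table` is an assumed table fixture.
def test_create_column_fixture(create_column, column_test_table):
    column = create_column(column_test_table, {'name': 'age', 'type': 'INTEGER'})
    assert column.name == 'age'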