diff --git a/sqlalchemy_file/helpers.py b/sqlalchemy_file/helpers.py index 39ef3a0..021ca0d 100644 --- a/sqlalchemy_file/helpers.py +++ b/sqlalchemy_file/helpers.py @@ -4,7 +4,7 @@ import re from builtins import RuntimeError from tempfile import SpooledTemporaryFile -from typing import Any, Dict, Union +from typing import Any, Dict, List, Union INMEMORY_FILESIZE = 1024 * 1024 LOCAL_STORAGE_DRIVER_NAME = "Local Storage" @@ -74,3 +74,8 @@ def convert_size(size: Union[str, int]) -> int: si_map = {"k": 1000, "K": 1000, "M": 1000**2, "Ki": 1024, "Mi": 1024**2} return int(value) * si_map[si] return size + + +def flatmap(lists: List[List[Any]]) -> List[Any]: + """Flattens a list of lists into a single list.""" + return [value for _list in lists for value in _list] diff --git a/sqlalchemy_file/types.py b/sqlalchemy_file/types.py index 9c97000..69fbb7e 100644 --- a/sqlalchemy_file/types.py +++ b/sqlalchemy_file/types.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import ColumnProperty, Mapper, Session, SessionTransaction from sqlalchemy.orm.attributes import get_history from sqlalchemy_file.file import File +from sqlalchemy_file.helpers import flatmap from sqlalchemy_file.mutable_list import MutableList from sqlalchemy_file.processors import Processor, ThumbnailGenerator from sqlalchemy_file.storage import StorageManager @@ -190,10 +191,10 @@ def extract_files_from_history(cls, data: Union[Tuple[()], List[Any]]) -> List[s paths = [] for item in data: if isinstance(item, list): - paths.extend([f["path"] for f in item]) + paths.extend([f["files"] for f in item]) elif isinstance(item, File): - paths.append(item["path"]) - return paths + paths.append(item["files"]) + return flatmap(paths) @classmethod def _mapper_configured(cls, mapper: Mapper, class_: Any) -> None: # type: ignore[type-arg] @@ -242,10 +243,12 @@ def _after_delete(cls, mapper: Mapper, _: Connection, obj: Any) -> None: # type if value is not None: cls.add_old_files_to_session( inspect(obj).session, - [ - f["path"] - for f in (value if isinstance(value, list) else [value]) - ], + flatmap( + [ + f["files"] + for f in (value if isinstance(value, list) else [value]) + ] + ), ) @classmethod @@ -280,7 +283,9 @@ def _before_update(cls, mapper: Mapper, _: Connection, obj: Any) -> None: # typ ) if isinstance(value, MutableList): _removed = getattr(value, "_removed", ()) - cls.add_old_files_to_session(session, [f["path"] for f in _removed]) + cls.add_old_files_to_session( + session, flatmap([f["files"] for f in _removed]) + ) @classmethod def _before_insert(cls, mapper: Mapper, _: Connection, obj: Any) -> None: # type: ignore[type-arg] diff --git a/tests/test_processor.py b/tests/test_processor.py index a1b81a3..568ac93 100644 --- a/tests/test_processor.py +++ b/tests/test_processor.py @@ -2,6 +2,8 @@ import tempfile import pytest +from libcloud.storage.types import ObjectDoesNotExistError +from PIL import Image from sqlalchemy import Column, Integer, String, select from sqlalchemy.orm import Session, declarative_base from sqlalchemy_file.storage import StorageManager @@ -51,8 +53,6 @@ def setup_method(self, method) -> None: def test_create_image_with_thumbnail(self, fake_image) -> None: with Session(engine) as session: - from PIL import Image - session.add(Book(title="Pointless Meetings", cover=fake_image)) session.flush() book = session.execute( @@ -66,6 +66,38 @@ def test_create_image_with_thumbnail(self, fake_image) -> None: assert book.cover["thumbnail"]["width"] == thumbnail.width assert book.cover["thumbnail"]["height"] == thumbnail.height + def test_update_image_with_thumbnail(self, fake_image) -> None: + with Session(engine) as session: + session.add(Book(title="Pointless Meetings", cover=fake_image)) + session.commit() + book = session.execute( + select(Book).where(Book.title == "Pointless Meetings") + ).scalar_one() + old_file_id = book.cover.path + old_thumbnail_file_id = book.cover.thumbnail["path"] + book.cover = fake_image + session.commit() + with pytest.raises(ObjectDoesNotExistError): + assert StorageManager.get_file(old_file_id) + with pytest.raises(ObjectDoesNotExistError): + assert StorageManager.get_file(old_thumbnail_file_id) + + def test_delete_image_with_thumbnail(self, fake_image) -> None: + with Session(engine) as session: + session.add(Book(title="Pointless Meetings", cover=fake_image)) + session.commit() + book = session.execute( + select(Book).where(Book.title == "Pointless Meetings") + ).scalar_one() + old_file_id = book.cover.path + old_thumbnail_file_id = book.cover.thumbnail["path"] + session.delete(book) + session.commit() + with pytest.raises(ObjectDoesNotExistError): + assert StorageManager.get_file(old_file_id) + with pytest.raises(ObjectDoesNotExistError): + assert StorageManager.get_file(old_thumbnail_file_id) + def teardown_method(self, method): for obj in StorageManager.get().list_objects(): obj.delete()