-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: split out dataclass utils (#176)
* refactor: split out dataclass utils * feat: struct frozen * test: add tests * test: fix coverage * fix: fix compiled * fix: fix py37 * fix: fix again py38
- Loading branch information
1 parent
a87ad44
commit 538b280
Showing
4 changed files
with
257 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
from __future__ import annotations | ||
|
||
import contextlib | ||
import dataclasses | ||
import sys | ||
import types | ||
from typing import TYPE_CHECKING, Any, Iterator, List, cast, overload | ||
|
||
from typing_extensions import Protocol | ||
|
||
if TYPE_CHECKING: | ||
import attrs | ||
import msgspec | ||
from pydantic import BaseModel | ||
from typing_extensions import TypeGuard | ||
|
||
GenericAlias = getattr(types, "GenericAlias", type(List[int])) # safe for < py 3.9 | ||
|
||
|
||
class _DataclassParams(Protocol): | ||
init: bool | ||
repr: bool | ||
eq: bool | ||
order: bool | ||
unsafe_hash: bool | ||
frozen: bool | ||
|
||
|
||
class AttrsType: | ||
__attrs_attrs__: tuple[attrs.Attribute, ...] | ||
|
||
|
||
_DATACLASS_PARAMS = "__dataclass_params__" | ||
with contextlib.suppress(ImportError): | ||
from dataclasses import _DATACLASS_PARAMS # type: ignore | ||
_DATACLASS_FIELDS = "__dataclass_fields__" | ||
with contextlib.suppress(ImportError): | ||
from dataclasses import _DATACLASS_FIELDS # type: ignore | ||
|
||
|
||
class DataClassType: | ||
__dataclass_params__: _DataclassParams | ||
__dataclass_fields__: dict[str, dataclasses.Field] | ||
|
||
|
||
@overload | ||
def is_dataclass(obj: type) -> TypeGuard[type[DataClassType]]: | ||
... | ||
|
||
|
||
@overload | ||
def is_dataclass(obj: object) -> TypeGuard[DataClassType]: | ||
... | ||
|
||
|
||
def is_dataclass(obj: object) -> TypeGuard[DataClassType]: | ||
"""Return True if the object is a dataclass.""" | ||
cls = ( | ||
obj | ||
if isinstance(obj, type) and not isinstance(obj, GenericAlias) | ||
else type(obj) | ||
) | ||
return hasattr(cls, _DATACLASS_FIELDS) | ||
|
||
|
||
@overload | ||
def is_attrs_class(obj: type) -> TypeGuard[type[AttrsType]]: | ||
... | ||
|
||
|
||
@overload | ||
def is_attrs_class(obj: object) -> TypeGuard[AttrsType]: | ||
... | ||
|
||
|
||
def is_attrs_class(obj: object) -> TypeGuard[type[AttrsType]]: | ||
"""Return True if the class is an attrs class.""" | ||
attr = sys.modules.get("attr", None) | ||
cls = obj if isinstance(obj, type) else type(obj) | ||
return attr.has(cls) if attr is not None else False # type: ignore [no-any-return] | ||
|
||
|
||
@overload | ||
def is_pydantic_model(obj: type) -> TypeGuard[type[BaseModel]]: | ||
... | ||
|
||
|
||
@overload | ||
def is_pydantic_model(obj: object) -> TypeGuard[BaseModel]: | ||
... | ||
|
||
|
||
def is_pydantic_model(obj: object) -> TypeGuard[BaseModel]: | ||
"""Return True if the class is a pydantic BaseModel.""" | ||
pydantic = sys.modules.get("pydantic", None) | ||
cls = obj if isinstance(obj, type) else type(obj) | ||
return pydantic is not None and issubclass(cls, pydantic.BaseModel) | ||
|
||
|
||
@overload | ||
def is_msgspec_struct(obj: type) -> TypeGuard[type[msgspec.Struct]]: | ||
... | ||
|
||
|
||
@overload | ||
def is_msgspec_struct(obj: object) -> TypeGuard[msgspec.Struct]: | ||
... | ||
|
||
|
||
def is_msgspec_struct(obj: object) -> TypeGuard[msgspec.Struct]: | ||
"""Return True if the class is a `msgspec.Struct`.""" | ||
msgspec = sys.modules.get("msgspec", None) | ||
cls = obj if isinstance(obj, type) else type(obj) | ||
return msgspec is not None and issubclass(cls, msgspec.Struct) | ||
|
||
|
||
def is_frozen(obj: Any) -> bool: | ||
"""Return True if the object is frozen.""" | ||
# sourcery skip: reintroduce-else | ||
cls = obj if isinstance(obj, type) else type(obj) | ||
|
||
params = cast("_DataclassParams | None", getattr(cls, _DATACLASS_PARAMS, None)) | ||
if params is not None: | ||
return params.frozen | ||
|
||
# pydantic | ||
cfg = getattr(cls, "__config__", None) | ||
if cfg is not None and getattr(cfg, "allow_mutation", None) is False: | ||
return True | ||
|
||
# attrs | ||
if getattr(cls.__setattr__, "__name__", None) == "_frozen_setattrs": | ||
return True | ||
|
||
cfg = getattr(cls, "__struct_config__", None) | ||
if cfg is not None: # pragma: no cover | ||
# this will be covered in msgspec > 0.13.1 | ||
return bool(getattr(cfg, "frozen", False)) | ||
|
||
return False | ||
|
||
|
||
def iter_fields( | ||
cls: type, exclude_frozen: bool = True | ||
) -> Iterator[tuple[str, type | None]]: | ||
"""Iterate over all fields in the class, including inherited fields. | ||
This function recognizes dataclasses, attrs classes, msgspec Structs, and pydantic | ||
models. | ||
Parameters | ||
---------- | ||
cls : type | ||
The class to iterate over. | ||
exclude_frozen : bool, optional | ||
If True, frozen fields will be excluded. By default True. | ||
Yields | ||
------ | ||
tuple[str, type | None] | ||
The name and type of each field. | ||
""" | ||
# generally opting for speed here over public API | ||
|
||
dclass_fields = getattr(cls, "__dataclass_fields__", None) | ||
if dclass_fields is not None: | ||
for d_field in dclass_fields.values(): | ||
if d_field._field_type is dataclasses._FIELD: # type: ignore [attr-defined] | ||
yield d_field.name, d_field.type | ||
return | ||
|
||
if is_pydantic_model(cls): | ||
for p_field in cls.__fields__.values(): | ||
if p_field.field_info.allow_mutation or not exclude_frozen: | ||
yield p_field.name, p_field.outer_type_ | ||
return | ||
|
||
attrs_fields = getattr(cls, "__attrs_attrs__", None) | ||
if attrs_fields is not None: | ||
for a_field in attrs_fields: | ||
yield a_field.name, a_field.type | ||
return | ||
|
||
if is_msgspec_struct(cls): | ||
for m_field in cls.__struct_fields__: | ||
type_ = cls.__annotations__.get(m_field, None) | ||
yield m_field, type_ | ||
return |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from dataclasses import dataclass | ||
|
||
import pytest | ||
from attr import define | ||
from pydantic import BaseModel | ||
|
||
from psygnal import _dataclass_utils | ||
|
||
try: | ||
from msgspec import Struct | ||
except ImportError: | ||
Struct = None | ||
|
||
VARIANTS = ["dataclass", "attrs_class", "pydantic_model"] | ||
if Struct is not None: | ||
VARIANTS.append("msgspec_struct") | ||
|
||
|
||
@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", ""]) | ||
@pytest.mark.parametrize("type_", VARIANTS) | ||
def test_dataclass_utils(type_: str, frozen: bool) -> None: | ||
if type_ == "attrs_class": | ||
|
||
@define(frozen=frozen) # type: ignore | ||
class Foo: | ||
x: int | ||
y: str = "foo" | ||
|
||
elif type_ == "dataclass": | ||
|
||
@dataclass(frozen=frozen) # type: ignore | ||
class Foo: # type: ignore [no-redef] | ||
x: int | ||
y: str = "foo" | ||
|
||
elif type_ == "msgspec_struct": | ||
|
||
class Foo(Struct, frozen=frozen): # type: ignore [no-redef] | ||
x: int | ||
y: str = "foo" | ||
|
||
elif type_ == "pydantic_model": | ||
|
||
class Foo(BaseModel): # type: ignore [no-redef] | ||
x: int | ||
y: str = "foo" | ||
|
||
class Config: | ||
allow_mutation = not frozen | ||
|
||
for name in VARIANTS: | ||
is_type = getattr(_dataclass_utils, f"is_{name}") | ||
assert is_type(Foo) is (name == type_) | ||
assert is_type(Foo(x=1)) is (name == type_) | ||
|
||
assert list(_dataclass_utils.iter_fields(Foo)) == [("x", int), ("y", str)] | ||
|
||
if type_ == "msgspec_struct" and frozen: | ||
# not supported until next release of msgspec | ||
return | ||
|
||
assert _dataclass_utils.is_frozen(Foo) == frozen |