Skip to content

Commit

Permalink
refactor: split out dataclass utils (#176)
Browse files Browse the repository at this point in the history
* refactor: split out dataclass utils

* feat: struct frozen

* test: add tests

* test: fix coverage

* fix: fix compiled

* fix: fix py37

* fix: fix again py38
  • Loading branch information
tlambert03 committed Feb 21, 2023
1 parent a87ad44 commit 538b280
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 71 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,11 @@ dependencies = [
"msgspec ; python_version >= '3.8'",
]
include = [
"src/psygnal/_signal.py",
"src/psygnal/_group.py",
"src/psygnal/_dataclass_utils.py",
"src/psygnal/_evented_decorator.py",
"src/psygnal/_group_descriptor.py",
"src/psygnal/_group.py",
"src/psygnal/_signal.py",
"src/psygnal/_weak_callback.py",
]

Expand Down
188 changes: 188 additions & 0 deletions src/psygnal/_dataclass_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from __future__ import annotations

import contextlib
import dataclasses
import sys
import types
from typing import TYPE_CHECKING, Any, Iterator, List, cast, overload

from typing_extensions import Protocol

if TYPE_CHECKING:
import attrs
import msgspec
from pydantic import BaseModel
from typing_extensions import TypeGuard

GenericAlias = getattr(types, "GenericAlias", type(List[int])) # safe for < py 3.9


class _DataclassParams(Protocol):
init: bool
repr: bool
eq: bool
order: bool
unsafe_hash: bool
frozen: bool


class AttrsType:
__attrs_attrs__: tuple[attrs.Attribute, ...]


_DATACLASS_PARAMS = "__dataclass_params__"
with contextlib.suppress(ImportError):
from dataclasses import _DATACLASS_PARAMS # type: ignore
_DATACLASS_FIELDS = "__dataclass_fields__"
with contextlib.suppress(ImportError):
from dataclasses import _DATACLASS_FIELDS # type: ignore


class DataClassType:
__dataclass_params__: _DataclassParams
__dataclass_fields__: dict[str, dataclasses.Field]


@overload
def is_dataclass(obj: type) -> TypeGuard[type[DataClassType]]:
...


@overload
def is_dataclass(obj: object) -> TypeGuard[DataClassType]:
...


def is_dataclass(obj: object) -> TypeGuard[DataClassType]:
"""Return True if the object is a dataclass."""
cls = (
obj
if isinstance(obj, type) and not isinstance(obj, GenericAlias)
else type(obj)
)
return hasattr(cls, _DATACLASS_FIELDS)


@overload
def is_attrs_class(obj: type) -> TypeGuard[type[AttrsType]]:
...


@overload
def is_attrs_class(obj: object) -> TypeGuard[AttrsType]:
...


def is_attrs_class(obj: object) -> TypeGuard[type[AttrsType]]:
"""Return True if the class is an attrs class."""
attr = sys.modules.get("attr", None)
cls = obj if isinstance(obj, type) else type(obj)
return attr.has(cls) if attr is not None else False # type: ignore [no-any-return]


@overload
def is_pydantic_model(obj: type) -> TypeGuard[type[BaseModel]]:
...


@overload
def is_pydantic_model(obj: object) -> TypeGuard[BaseModel]:
...


def is_pydantic_model(obj: object) -> TypeGuard[BaseModel]:
"""Return True if the class is a pydantic BaseModel."""
pydantic = sys.modules.get("pydantic", None)
cls = obj if isinstance(obj, type) else type(obj)
return pydantic is not None and issubclass(cls, pydantic.BaseModel)


@overload
def is_msgspec_struct(obj: type) -> TypeGuard[type[msgspec.Struct]]:
...


@overload
def is_msgspec_struct(obj: object) -> TypeGuard[msgspec.Struct]:
...


def is_msgspec_struct(obj: object) -> TypeGuard[msgspec.Struct]:
"""Return True if the class is a `msgspec.Struct`."""
msgspec = sys.modules.get("msgspec", None)
cls = obj if isinstance(obj, type) else type(obj)
return msgspec is not None and issubclass(cls, msgspec.Struct)


def is_frozen(obj: Any) -> bool:
"""Return True if the object is frozen."""
# sourcery skip: reintroduce-else
cls = obj if isinstance(obj, type) else type(obj)

params = cast("_DataclassParams | None", getattr(cls, _DATACLASS_PARAMS, None))
if params is not None:
return params.frozen

# pydantic
cfg = getattr(cls, "__config__", None)
if cfg is not None and getattr(cfg, "allow_mutation", None) is False:
return True

# attrs
if getattr(cls.__setattr__, "__name__", None) == "_frozen_setattrs":
return True

cfg = getattr(cls, "__struct_config__", None)
if cfg is not None: # pragma: no cover
# this will be covered in msgspec > 0.13.1
return bool(getattr(cfg, "frozen", False))

return False


def iter_fields(
cls: type, exclude_frozen: bool = True
) -> Iterator[tuple[str, type | None]]:
"""Iterate over all fields in the class, including inherited fields.
This function recognizes dataclasses, attrs classes, msgspec Structs, and pydantic
models.
Parameters
----------
cls : type
The class to iterate over.
exclude_frozen : bool, optional
If True, frozen fields will be excluded. By default True.
Yields
------
tuple[str, type | None]
The name and type of each field.
"""
# generally opting for speed here over public API

dclass_fields = getattr(cls, "__dataclass_fields__", None)
if dclass_fields is not None:
for d_field in dclass_fields.values():
if d_field._field_type is dataclasses._FIELD: # type: ignore [attr-defined]
yield d_field.name, d_field.type
return

if is_pydantic_model(cls):
for p_field in cls.__fields__.values():
if p_field.field_info.allow_mutation or not exclude_frozen:
yield p_field.name, p_field.outer_type_
return

attrs_fields = getattr(cls, "__attrs_attrs__", None)
if attrs_fields is not None:
for a_field in attrs_fields:
yield a_field.name, a_field.type
return

if is_msgspec_struct(cls):
for m_field in cls.__struct_fields__:
type_ = cls.__annotations__.get(m_field, None)
yield m_field, type_
return
73 changes: 4 additions & 69 deletions src/psygnal/_group_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,18 @@
import sys
import warnings
import weakref
from dataclasses import fields, is_dataclass
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
Type,
TypeVar,
cast,
overload,
)
from typing import TYPE_CHECKING, Any, Callable, Iterable, Type, TypeVar, cast, overload

from ._dataclass_utils import iter_fields
from ._group import SignalGroup
from ._signal import Signal

if TYPE_CHECKING:
import msgspec
from pydantic import BaseModel
from typing_extensions import TypeGuard

from ._signal import SignalInstance


__all__ = ["is_evented", "get_evented_namespace", "SignalGroupDescriptor"]
_DATACLASS_PARAMS = "__dataclass_params__"
with contextlib.suppress(ImportError):
from dataclasses import _DATACLASS_PARAMS # type: ignore

T = TypeVar("T", bound=Type)
S = TypeVar("S")
Expand Down Expand Up @@ -113,55 +96,7 @@ def _check_field_equality(
return _check_field_equality(cls, name, before, after, _fail=True)


def is_attrs_class(cls: type) -> bool:
"""Return True if the class is an attrs class."""
attr = sys.modules.get("attr", None)
return attr.has(cls) if attr is not None else False # type: ignore [no-any-return]


def is_pydantic_model(cls: type) -> TypeGuard[BaseModel]:
"""Return True if the class is a pydantic BaseModel."""
pydantic = sys.modules.get("pydantic", None)
return pydantic is not None and issubclass(cls, pydantic.BaseModel)


def is_msgspec_struct(cls: type) -> TypeGuard[msgspec.Struct]:
"""Return True if the class is a `msgspec.Struct`."""
msgspec = sys.modules.get("msgspec", None)
return msgspec is not None and issubclass(cls, msgspec.Struct)


def iter_fields(cls: type) -> Iterator[tuple[str, type]]:
"""Iterate over all mutable fields in the class, including inherited fields.
This function recognizes dataclasses, attrs classes, msgspec Structs, and pydantic
models.
"""
if is_dataclass(cls):
if getattr(cls, _DATACLASS_PARAMS).frozen: # pragma: no cover
raise TypeError("Frozen dataclasses cannot be made evented.")

for d_field in fields(cls):
yield d_field.name, d_field.type

elif is_attrs_class(cls):
import attr

for a_field in attr.fields(cls):
yield a_field.name, cast("type", a_field.type)

elif is_pydantic_model(cls):
for p_field in cls.__fields__.values():
if p_field.field_info.allow_mutation:
yield p_field.name, p_field.outer_type_

elif is_msgspec_struct(cls):
for m_field in cls.__struct_fields__:
type_ = cls.__annotations__.get(m_field, None)
yield m_field, type_


def _pick_equality_operator(type_: type) -> EqOperator:
def _pick_equality_operator(type_: type | None) -> EqOperator:
"""Get the default equality operator for a given type."""
np = sys.modules.get("numpy", None)
if np is not None and hasattr(type_, "__array__"):
Expand All @@ -184,7 +119,7 @@ def _build_dataclass_signal_group(
eq_map[name] = _equality_operators[name]
else:
eq_map[name] = _pick_equality_operator(type_)
signals[name] = Signal(type_)
signals[name] = Signal(object if type_ is None else type_)

return type(f"{cls.__name__}SignalGroup", (SignalGroup,), signals)

Expand Down
62 changes: 62 additions & 0 deletions tests/test_dataclass_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from dataclasses import dataclass

import pytest
from attr import define
from pydantic import BaseModel

from psygnal import _dataclass_utils

try:
from msgspec import Struct
except ImportError:
Struct = None

VARIANTS = ["dataclass", "attrs_class", "pydantic_model"]
if Struct is not None:
VARIANTS.append("msgspec_struct")


@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", ""])
@pytest.mark.parametrize("type_", VARIANTS)
def test_dataclass_utils(type_: str, frozen: bool) -> None:
if type_ == "attrs_class":

@define(frozen=frozen) # type: ignore
class Foo:
x: int
y: str = "foo"

elif type_ == "dataclass":

@dataclass(frozen=frozen) # type: ignore
class Foo: # type: ignore [no-redef]
x: int
y: str = "foo"

elif type_ == "msgspec_struct":

class Foo(Struct, frozen=frozen): # type: ignore [no-redef]
x: int
y: str = "foo"

elif type_ == "pydantic_model":

class Foo(BaseModel): # type: ignore [no-redef]
x: int
y: str = "foo"

class Config:
allow_mutation = not frozen

for name in VARIANTS:
is_type = getattr(_dataclass_utils, f"is_{name}")
assert is_type(Foo) is (name == type_)
assert is_type(Foo(x=1)) is (name == type_)

assert list(_dataclass_utils.iter_fields(Foo)) == [("x", int), ("y", str)]

if type_ == "msgspec_struct" and frozen:
# not supported until next release of msgspec
return

assert _dataclass_utils.is_frozen(Foo) == frozen

0 comments on commit 538b280

Please sign in to comment.