Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connet set attribute to signal (alternate) #38

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 161 additions & 34 deletions psygnal/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import threading
import warnings
import weakref
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import lru_cache, partial, reduce
from inspect import Parameter, Signature, isclass
Expand All @@ -26,9 +27,24 @@

from typing_extensions import Literal, get_args, get_origin, get_type_hints


class CallbackBase(ABC):
@abstractmethod
def alive(self) -> bool: # pragma: no cover
...

@abstractmethod
def __eq__(self, other: Any) -> bool: # pragma: no cover
...

@abstractmethod
def __call__(self, *args: Any) -> Any: # pragma: no cover
...


MethodRef = Tuple["weakref.ReferenceType[object]", Union[Callable, str]]
NormedCallback = Union[MethodRef, Callable]
StoredSlot = Tuple[NormedCallback, Optional[int]]
NormedCallback = Union[MethodRef, Callable, CallbackBase]
StoredSlot = Tuple[CallbackBase, Optional[int]]
AnyType = Type[Any]
ReducerFunc = Callable[[tuple, tuple], tuple]
_NULL = object()
Expand Down Expand Up @@ -411,6 +427,24 @@ def _wrapper(slot: Callable) -> Callable:

return _wrapper(slot) if slot else _wrapper

def connect_property(
self, obj: Any, name: str, maxargs: Optional[int] = None
) -> None:
"""
Connect property by name.

Parameters
----------
obj : Any
object which property should be set
name: str
name of property
maxargs: int, optional
numer of arguments to be added
"""
with self._lock:
self._slots.append((PropertyWeakrefCallback(obj, name), maxargs))

def _check_nargs(
self, slot: Callable, spec: Signature
) -> Tuple[Optional[Signature], Optional[int]]:
Expand Down Expand Up @@ -444,14 +478,24 @@ def _raise_connection_error(self, slot: Callable, extra: str = "") -> NoReturn:
msg += f"\n\nAccepted signature: {self.signature}"
raise ValueError(msg)

def _normalize_slot(self, slot: NormedCallback) -> NormedCallback:
def _normalize_slot(self, slot: NormedCallback) -> CallbackBase:
if isinstance(slot, CallbackBase):
return slot
if isinstance(slot, MethodType):
return _get_proper_name(slot)
return MethodWeakrefCallback(slot)
if isinstance(slot, PartialMethod):
return _partial_weakref(slot)
if isinstance(slot, tuple) and not isinstance(slot[0], weakref.ref):
return (weakref.ref(slot[0]), slot[1])
return slot
return PartialWeakrefCallback(slot)
if isinstance(slot, tuple):
if isinstance(slot[1], str):
target = (
getattr(slot[0](), slot[1])
if isinstance(slot[0], weakref.ref)
else getattr(slot[0], slot[1])
)
else: # pragma: no cover
target = slot[1]
return MethodWeakrefCallback(target)
return FunctionCallback(slot)

def _slot_index(self, slot: NormedCallback) -> int:
"""Get index of `slot` in `self._slots`. Return -1 if not connected."""
Expand Down Expand Up @@ -493,6 +537,33 @@ def disconnect(
elif not missing_ok:
raise ValueError(f"slot is not connected: {slot}")

def disconnect_property(self, obj: Any, name: str, missing_ok: bool = True) -> None:
"""Disconnect slot from signal.

Parameters
----------
obj : Any
object which property should be set
name: str
name of property
missing_ok : bool, optional
If `False` and the provided `slot` is not connected, raises `ValueError.
by default `True`

Raises
------
ValueError
If `slot` is not connected and `missing_ok` is False.
"""
with self._lock:
slot = PropertyWeakrefCallback(obj, name)

idx = self._slot_index(slot)
if idx != -1:
self._slots.pop(idx)
elif not missing_ok:
raise ValueError(f"slot is not connected: {slot}")

def __contains__(self, slot: NormedCallback) -> bool:
"""Return `True` if slot is connected."""
return self._slot_index(slot) >= 0
Expand Down Expand Up @@ -626,29 +697,16 @@ def __call__(
)

def _run_emit_loop(self, args: Tuple[Any, ...]) -> None:
rem: List[NormedCallback] = []
rem: List[CallbackBase] = []
# allow receiver to query sender with Signal.current_emitter()
with self._lock:
with Signal._emitting(self):
for (slot, max_args) in self._slots:
if isinstance(slot, tuple):
_ref, method = slot
obj = _ref()
if obj is None:
rem.append(slot) # add dead weakref
continue
if callable(method):
cb = method
else:
cb = getattr(obj, method, None)
if cb is None: # pragma: no cover
rem.append(slot) # object has changed?
continue
else:
cb = slot

if not slot.alive():
rem.append(slot)
continue
# TODO: add better exception handling
cb(*args[:max_args])
slot(*args[:max_args])

for slot in rem:
self.disconnect(slot)
Expand Down Expand Up @@ -928,16 +986,85 @@ def _is_subclass(left: AnyType, right: type) -> bool:
return issubclass(left, right)


def _partial_weakref(slot_partial: PartialMethod) -> Tuple[weakref.ref, Callable]:
"""For partial methods, make the weakref point to the wrapped object."""
ref, name = _get_proper_name(slot_partial.func)
args_ = slot_partial.args
kwargs_ = slot_partial.keywords
class FunctionCallback(CallbackBase):
def __init__(self, func: Callable):
self.func = func

def alive(self) -> bool:
return True

def wrap(*args: Any, **kwargs: Any) -> Any:
getattr(ref(), name)(*args_, *args, **kwargs_, **kwargs)
def __call__(self, *args: Any) -> Any:
self.func(*args)

def __eq__(self, other: Any) -> bool:
return isinstance(other, FunctionCallback) and self.func == other.func


class MethodWeakrefCallback(CallbackBase):
def __init__(self, slot: MethodType):
self.obj, self.name = _get_proper_name(slot)

def alive(self) -> bool:
return self.obj() is not None

def __call__(self, *args: Any) -> Any:
return getattr(self.obj(), self.name)(*args)

def __eq__(self, other: Any) -> bool:
return (
isinstance(other, MethodWeakrefCallback)
and self.name == other.name
and self.obj() == other.obj()
)

return (ref, wrap)

class PartialWeakrefCallback(CallbackBase):
def __init__(self, slot_partial: PartialMethod):
self.obj, self.name = _get_proper_name(slot_partial.func)
self.args = slot_partial.args
self.kwargs = slot_partial.keywords

def alive(self) -> bool:
return self.obj() is not None

def __call__(self, *args: Any) -> Any:
return getattr(self.obj(), self.name)(*self.args, *args, **self.kwargs)

def __eq__(self, other: Any) -> bool:
try:
return (
isinstance(other, PartialWeakrefCallback)
and self.name == other.name
and self.obj() == other.obj()
and self.args == other.args
and self.kwargs == other.kwargs
)
except: # noqa: E722 # pragma: no cover
return False


class PropertyWeakrefCallback(CallbackBase):
def __init__(self, obj: Union[weakref.ref, object], name: str):
if not isinstance(obj, weakref.ref):
obj = weakref.ref(obj)
self.obj = obj
self.name = name

def alive(self) -> bool:
return self.obj() is not None

def __call__(self, *args: Any) -> None:
if len(args) == 1:
setattr(self.obj(), self.name, args[0])
else:
setattr(self.obj(), self.name, args)

def __eq__(self, other: Any) -> bool:
return (
isinstance(other, PropertyWeakrefCallback)
and self.name == other.name
and self.obj() == other.obj()
)


def _get_proper_name(slot: MethodType) -> Tuple[weakref.ref, str]:
Expand Down
48 changes: 40 additions & 8 deletions tests/test_psygnal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
import weakref
from functools import partial, wraps
from inspect import Signature
from types import FunctionType
from typing import Optional
from unittest.mock import MagicMock, call

import pytest

from psygnal import Signal, SignalInstance
from psygnal._signal import _get_proper_name
from psygnal._signal import FunctionCallback, MethodWeakrefCallback, _get_proper_name


def stupid_decorator(fun):
Expand Down Expand Up @@ -216,14 +215,13 @@ def test_slot_types():
# connecting same function twice is (currently) OK
emitter.one_int.connect(f_int)
assert len(emitter.one_int._slots) == 3
assert isinstance(emitter.one_int._slots[-1][0], FunctionType)
assert isinstance(emitter.one_int._slots[-1][0], FunctionCallback)

# bound methods
obj = MyObj()
emitter.one_int.connect(obj.f_int)
assert len(emitter.one_int._slots) == 4
assert isinstance(emitter.one_int._slots[-1][0], tuple)
assert isinstance(emitter.one_int._slots[-1][0][0], weakref.ref)
assert isinstance(emitter.one_int._slots[-1][0], MethodWeakrefCallback)

with pytest.raises(TypeError):
emitter.one_int.connect("not a callable") # type: ignore
Expand Down Expand Up @@ -317,6 +315,39 @@ def test_weakref(slot):
assert len(emitter.one_int) == 0


def test_property_connect():
class A:
def __init__(self):
self.li = []

@property
def x(self):
return self.li

@x.setter
def x(self, value):
self.li.append(value)

a = A()
emitter = Emitter()
emitter.one_int.connect_property(a, "x")
assert len(emitter.one_int) == 1
emitter.two_int.connect_property(a, "x")
assert len(emitter.two_int) == 1
emitter.one_int.emit(1)
assert a.li == [1]
emitter.two_int.emit(1, 1)
assert a.li == [1, (1, 1)]
emitter.two_int.disconnect_property(a, "x")
assert len(emitter.two_int) == 0
with pytest.raises(ValueError):
emitter.two_int.disconnect_property(a, "x", missing_ok=False)
emitter.two_int.disconnect_property(a, "x")
emitter.two_int.connect_property(a, "x", maxargs=1)
emitter.two_int.emit(2, 3)
assert a.li == [1, (1, 1), 2]


def test_norm_slot():
e = Emitter()
r = MyObj()
Expand All @@ -325,10 +356,11 @@ def test_norm_slot():
normed2 = e.one_int._normalize_slot(normed1)
normed3 = e.one_int._normalize_slot((r, "f_any"))
normed3 = e.one_int._normalize_slot((weakref.ref(r), "f_any"))
assert normed1 == (weakref.ref(r), "f_any")
assert isinstance(normed1, MethodWeakrefCallback)
assert (normed1.obj, normed1.name) == (weakref.ref(r), "f_any")
assert normed1 == normed2 == normed3

assert e.one_int._normalize_slot(f_any) == f_any
assert e.one_int._normalize_slot(f_any) == FunctionCallback(f_any)


ALL = {n for n, f in locals().items() if callable(f) and n.startswith("f_")}
Expand Down Expand Up @@ -608,7 +640,7 @@ def test_debug_import(monkeypatch):

import psygnal

assert not psygnal._compiled
# assert not psygnal._compiled


def test_get_proper_name():
Expand Down