Skip to content

Commit

Permalink
Fix issue when non runtime_protocol does not raise TypeError (#132)
Browse files Browse the repository at this point in the history
Backport of CPython PR 26067 (python/cpython#26067)
  • Loading branch information
AlexWaygood committed Apr 12, 2023
1 parent 25b0971 commit 7e998c2
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 28 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
- Add `typing_extensions.Buffer`, a marker class for buffer types, as proposed
by PEP 688. Equivalent to `collections.abc.Buffer` in Python 3.12. Patch by
Jelle Zijlstra.
- Backport [CPython PR 26067](https://github.com/python/cpython/pull/26067)
(originally by Yurii Karabas), ensuring that `isinstance()` calls on
protocols raise `TypeError` when the protocol is not decorated with
`@runtime_checkable`. Patch by Alex Waygood.

# Release 4.5.0 (February 14, 2023)

Expand Down
31 changes: 25 additions & 6 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,22 @@ class E(C, BP): pass
self.assertNotIsInstance(D(), E)
self.assertNotIsInstance(E(), D)

@skipUnless(
hasattr(typing, "Protocol"),
"Test is only relevant if typing.Protocol exists"
)
def test_runtimecheckable_on_typing_dot_Protocol(self):
@runtime_checkable
class Foo(typing.Protocol):
x: int

class Bar:
def __init__(self):
self.x = 42

self.assertIsInstance(Bar(), Foo)
self.assertNotIsInstance(object(), Foo)

def test_no_instantiation(self):
class P(Protocol): pass
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -1829,11 +1845,7 @@ def meth(self):
self.assertTrue(P._is_protocol)
self.assertTrue(PR._is_protocol)
self.assertTrue(PG._is_protocol)
if hasattr(typing, 'Protocol'):
self.assertFalse(P._is_runtime_protocol)
else:
with self.assertRaises(AttributeError):
self.assertFalse(P._is_runtime_protocol)
self.assertFalse(P._is_runtime_protocol)
self.assertTrue(PR._is_runtime_protocol)
self.assertTrue(PG[int]._is_protocol)
self.assertEqual(typing_extensions._get_protocol_attrs(P), {'meth'})
Expand Down Expand Up @@ -1929,6 +1941,13 @@ class CustomProtocol(TestCase, Protocol):
class CustomContextManager(typing.ContextManager, Protocol):
pass

def test_non_runtime_protocol_isinstance_check(self):
class P(Protocol):
x: int

with self.assertRaisesRegex(TypeError, "@runtime_checkable"):
isinstance(1, P)

def test_no_init_same_for_different_protocol_implementations(self):
class CustomProtocolWithoutInitA(Protocol):
pass
Expand Down Expand Up @@ -3314,7 +3333,7 @@ def test_typing_extensions_defers_when_possible(self):
'is_typeddict',
}
if sys.version_info < (3, 10):
exclude |= {'get_args', 'get_origin'}
exclude |= {'get_args', 'get_origin', 'Protocol', 'runtime_checkable'}
if sys.version_info < (3, 11):
exclude |= {'final', 'NamedTuple', 'Any'}
for item in typing_extensions.__all__:
Expand Down
69 changes: 47 additions & 22 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,21 +398,33 @@ def clear_overloads():
}


_EXCLUDED_ATTRS = {
"__abstractmethods__", "__annotations__", "__weakref__", "_is_protocol",
"_is_runtime_protocol", "__dict__", "__slots__", "__parameters__",
"__orig_bases__", "__module__", "_MutableMapping__marker", "__doc__",
"__subclasshook__", "__orig_class__", "__init__", "__new__",
}

if sys.version_info < (3, 8):
_EXCLUDED_ATTRS |= {
"_gorg", "__next_in_mro__", "__extra__", "__tree_hash__", "__args__",
"__origin__"
}

if sys.version_info >= (3, 9):
_EXCLUDED_ATTRS.add("__class_getitem__")

_EXCLUDED_ATTRS = frozenset(_EXCLUDED_ATTRS)


def _get_protocol_attrs(cls):
attrs = set()
for base in cls.__mro__[:-1]: # without object
if base.__name__ in ('Protocol', 'Generic'):
continue
annotations = getattr(base, '__annotations__', {})
for attr in list(base.__dict__.keys()) + list(annotations.keys()):
if (not attr.startswith('_abc_') and attr not in (
'__abstractmethods__', '__annotations__', '__weakref__',
'_is_protocol', '_is_runtime_protocol', '__dict__',
'__args__', '__slots__',
'__next_in_mro__', '__parameters__', '__origin__',
'__orig_bases__', '__extra__', '__tree_hash__',
'__doc__', '__subclasshook__', '__init__', '__new__',
'__module__', '_MutableMapping__marker', '_gorg')):
if (not attr.startswith('_abc_') and attr not in _EXCLUDED_ATTRS):
attrs.add(attr)
return attrs

Expand Down Expand Up @@ -468,11 +480,18 @@ def _caller(depth=2):
return None


# 3.8+
if hasattr(typing, 'Protocol'):
# A bug in runtime-checkable protocols was fixed in 3.10+,
# but we backport it to all versions
if sys.version_info >= (3, 10):
Protocol = typing.Protocol
# 3.7
runtime_checkable = typing.runtime_checkable
else:
def _allow_reckless_class_checks(depth=4):
"""Allow instance and class checks for special stdlib modules.
The abc and functools modules indiscriminately call isinstance() and
issubclass() on the whole MRO of a user class, which may contain protocols.
"""
return _caller(depth) in {'abc', 'functools', None}

def _no_init(self, *args, **kwargs):
if type(self)._is_protocol:
Expand All @@ -484,11 +503,19 @@ class _ProtocolMeta(abc.ABCMeta):
def __instancecheck__(cls, instance):
# We need this method for situations where attributes are
# assigned in __init__.
if ((not getattr(cls, '_is_protocol', False) or
is_protocol_cls = getattr(cls, "_is_protocol", False)
if (
is_protocol_cls and
not getattr(cls, '_is_runtime_protocol', False) and
not _allow_reckless_class_checks(depth=2)
):
raise TypeError("Instance and class checks can only be used with"
" @runtime_checkable protocols")
if ((not is_protocol_cls or
_is_callable_members_only(cls)) and
issubclass(instance.__class__, cls)):
return True
if cls._is_protocol:
if is_protocol_cls:
if all(hasattr(instance, attr) and
(not callable(getattr(cls, attr, None)) or
getattr(instance, attr) is not None)
Expand Down Expand Up @@ -530,6 +557,7 @@ def meth(self) -> T:
"""
__slots__ = ()
_is_protocol = True
_is_runtime_protocol = False

def __new__(cls, *args, **kwds):
if cls is Protocol:
Expand Down Expand Up @@ -581,12 +609,12 @@ def _proto_hook(other):
if not cls.__dict__.get('_is_protocol', None):
return NotImplemented
if not getattr(cls, '_is_runtime_protocol', False):
if _caller(depth=3) in {'abc', 'functools'}:
if _allow_reckless_class_checks():
return NotImplemented
raise TypeError("Instance and class checks can only be used with"
" @runtime protocols")
if not _is_callable_members_only(cls):
if _caller(depth=3) in {'abc', 'functools'}:
if _allow_reckless_class_checks():
return NotImplemented
raise TypeError("Protocols with non-method members"
" don't support issubclass()")
Expand Down Expand Up @@ -625,12 +653,6 @@ def _proto_hook(other):
f' protocols, got {repr(base)}')
cls.__init__ = _no_init


# 3.8+
if hasattr(typing, 'runtime_checkable'):
runtime_checkable = typing.runtime_checkable
# 3.7
else:
def runtime_checkable(cls):
"""Mark a protocol class as a runtime protocol, so that it
can be used with isinstance() and issubclass(). Raise TypeError
Expand All @@ -639,7 +661,10 @@ def runtime_checkable(cls):
This allows a simple-minded structural check very similar to the
one-offs in collections.abc such as Hashable.
"""
if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol:
if not (
(isinstance(cls, _ProtocolMeta) or issubclass(cls, typing.Generic))
and getattr(cls, "_is_protocol", False)
):
raise TypeError('@runtime_checkable can be only applied to protocol classes,'
f' got {cls!r}')
cls._is_runtime_protocol = True
Expand Down

0 comments on commit 7e998c2

Please sign in to comment.