Skip to content

Commit

Permalink
Backport PEP-696 specialisation on Python >=3.11.1 (#397)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed May 16, 2024
1 parent 23378be commit 074d053
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 0 deletions.
68 changes: 68 additions & 0 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6402,6 +6402,34 @@ def test_typevartuple(self):
class A(Generic[Unpack[Ts]]): ...
Alias = Optional[Unpack[Ts]]

@skipIf(
sys.version_info < (3, 11, 1),
"Not yet backported for older versions of Python"
)
def test_typevartuple_specialization(self):
T = TypeVar("T")
Ts = TypeVarTuple('Ts', default=Unpack[Tuple[str, int]])
self.assertEqual(Ts.__default__, Unpack[Tuple[str, int]])
class A(Generic[T, Unpack[Ts]]): ...
self.assertEqual(A[float].__args__, (float, str, int))
self.assertEqual(A[float, range].__args__, (float, range))
self.assertEqual(A[float, Unpack[tuple[int, ...]]].__args__, (float, Unpack[tuple[int, ...]]))

@skipIf(
sys.version_info < (3, 11, 1),
"Not yet backported for older versions of Python"
)
def test_typevar_and_typevartuple_specialization(self):
T = TypeVar("T")
U = TypeVar("U", default=float)
Ts = TypeVarTuple('Ts', default=Unpack[Tuple[str, int]])
self.assertEqual(Ts.__default__, Unpack[Tuple[str, int]])
class A(Generic[T, U, Unpack[Ts]]): ...
self.assertEqual(A[int].__args__, (int, float, str, int))
self.assertEqual(A[int, str].__args__, (int, str, str, int))
self.assertEqual(A[int, str, range].__args__, (int, str, range))
self.assertEqual(A[int, str, Unpack[tuple[int, ...]]].__args__, (int, str, Unpack[tuple[int, ...]]))

def test_no_default_after_typevar_tuple(self):
T = TypeVar("T", default=int)
Ts = TypeVarTuple("Ts")
Expand Down Expand Up @@ -6487,6 +6515,46 @@ def test_allow_default_after_non_default_in_alias(self):
a4 = Callable[[Unpack[Ts]], T]
self.assertEqual(a4.__args__, (Unpack[Ts], T))

@skipIf(
sys.version_info < (3, 11, 1),
"Not yet backported for older versions of Python"
)
def test_paramspec_specialization(self):
T = TypeVar("T")
P = ParamSpec('P', default=[str, int])
self.assertEqual(P.__default__, [str, int])
class A(Generic[T, P]): ...
self.assertEqual(A[float].__args__, (float, (str, int)))
self.assertEqual(A[float, [range]].__args__, (float, (range,)))

@skipIf(
sys.version_info < (3, 11, 1),
"Not yet backported for older versions of Python"
)
def test_typevar_and_paramspec_specialization(self):
T = TypeVar("T")
U = TypeVar("U", default=float)
P = ParamSpec('P', default=[str, int])
self.assertEqual(P.__default__, [str, int])
class A(Generic[T, U, P]): ...
self.assertEqual(A[float].__args__, (float, float, (str, int)))
self.assertEqual(A[float, int].__args__, (float, int, (str, int)))
self.assertEqual(A[float, int, [range]].__args__, (float, int, (range,)))

@skipIf(
sys.version_info < (3, 11, 1),
"Not yet backported for older versions of Python"
)
def test_paramspec_and_typevar_specialization(self):
T = TypeVar("T")
P = ParamSpec('P', default=[str, int])
U = TypeVar("U", default=float)
self.assertEqual(P.__default__, [str, int])
class A(Generic[T, P, U]): ...
self.assertEqual(A[float].__args__, (float, (str, int), float))
self.assertEqual(A[float, [range]].__args__, (float, (range,), float))
self.assertEqual(A[float, [range], int].__args__, (float, (range,), int))


class NoDefaultTests(BaseTestCase):
@skip_if_py313_beta_1
Expand Down
97 changes: 97 additions & 0 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,8 +1513,19 @@ def __new__(cls, name, *constraints, bound=None,
if infer_variance and (covariant or contravariant):
raise ValueError("Variance cannot be specified with infer_variance.")
typevar.__infer_variance__ = infer_variance

_set_default(typevar, default)
_set_module(typevar)

def _tvar_prepare_subst(alias, args):
if (
typevar.has_default()
and alias.__parameters__.index(typevar) == len(args)
):
args += (typevar.__default__,)
return args

typevar.__typing_prepare_subst__ = _tvar_prepare_subst
return typevar

def __init_subclass__(cls) -> None:
Expand Down Expand Up @@ -1613,6 +1624,24 @@ def __new__(cls, name, *, bound=None,

_set_default(paramspec, default)
_set_module(paramspec)

def _paramspec_prepare_subst(alias, args):
params = alias.__parameters__
i = params.index(paramspec)
if i == len(args) and paramspec.has_default():
args = [*args, paramspec.__default__]
if i >= len(args):
raise TypeError(f"Too few arguments for {alias}")
# Special case where Z[[int, str, bool]] == Z[int, str, bool] in PEP 612.
if len(params) == 1 and not typing._is_param_expr(args[0]):
assert i == 0
args = (args,)
# Convert lists to tuples to help other libraries cache the results.
elif isinstance(args[i], list):
args = (*args[:i], tuple(args[i]), *args[i + 1:])
return args

paramspec.__typing_prepare_subst__ = _paramspec_prepare_subst
return paramspec

def __init_subclass__(cls) -> None:
Expand Down Expand Up @@ -2311,6 +2340,17 @@ def __init__(self, getitem):
class _UnpackAlias(typing._GenericAlias, _root=True):
__class__ = typing.TypeVar

@property
def __typing_unpacked_tuple_args__(self):
assert self.__origin__ is Unpack
assert len(self.__args__) == 1
arg, = self.__args__
if isinstance(arg, (typing._GenericAlias, _types.GenericAlias)):
if arg.__origin__ is not tuple:
raise TypeError("Unpack[...] must be used with a tuple type")
return arg.__args__
return None

@_UnpackSpecialForm
def Unpack(self, parameters):
item = typing._type_check(parameters, f'{self._name} accepts only a single type.')
Expand Down Expand Up @@ -2340,6 +2380,16 @@ def _is_unpack(obj):

elif hasattr(typing, "TypeVarTuple"): # 3.11+

def _unpack_args(*args):
newargs = []
for arg in args:
subargs = getattr(arg, '__typing_unpacked_tuple_args__', None)
if subargs is not None and not (subargs and subargs[-1] is ...):
newargs.extend(subargs)
else:
newargs.append(arg)
return newargs

# Add default parameter - PEP 696
class TypeVarTuple(metaclass=_TypeVarLikeMeta):
"""Type variable tuple."""
Expand All @@ -2350,6 +2400,53 @@ def __new__(cls, name, *, default=NoDefault):
tvt = typing.TypeVarTuple(name)
_set_default(tvt, default)
_set_module(tvt)

def _typevartuple_prepare_subst(alias, args):
params = alias.__parameters__
typevartuple_index = params.index(tvt)
for param in params[typevartuple_index + 1:]:
if isinstance(param, TypeVarTuple):
raise TypeError(
f"More than one TypeVarTuple parameter in {alias}"
)

alen = len(args)
plen = len(params)
left = typevartuple_index
right = plen - typevartuple_index - 1
var_tuple_index = None
fillarg = None
for k, arg in enumerate(args):
if not isinstance(arg, type):
subargs = getattr(arg, '__typing_unpacked_tuple_args__', None)
if subargs and len(subargs) == 2 and subargs[-1] is ...:
if var_tuple_index is not None:
raise TypeError(
"More than one unpacked "
"arbitrary-length tuple argument"
)
var_tuple_index = k
fillarg = subargs[0]
if var_tuple_index is not None:
left = min(left, var_tuple_index)
right = min(right, alen - var_tuple_index - 1)
elif left + right > alen:
raise TypeError(f"Too few arguments for {alias};"
f" actual {alen}, expected at least {plen - 1}")
if left == alen - right and tvt.has_default():
replacement = _unpack_args(tvt.__default__)
else:
replacement = args[left: alen - right]

return (
*args[:left],
*([fillarg] * (typevartuple_index - left)),
replacement,
*([fillarg] * (plen - right - left - typevartuple_index - 1)),
*args[alen - right:],
)

tvt.__typing_prepare_subst__ = _typevartuple_prepare_subst
return tvt

def __init_subclass__(self, *args, **kwds):
Expand Down

0 comments on commit 074d053

Please sign in to comment.