diff --git a/build/rocm/run_single_gpu.py b/build/rocm/run_single_gpu.py
index 13dac5ebbb48..add7ee3d86b5 100755
--- a/build/rocm/run_single_gpu.py
+++ b/build/rocm/run_single_gpu.py
@@ -32,7 +32,7 @@ def extract_filename(path):
 def generate_final_report(shell=False, env_vars={}):
   env = os.environ
   env = {**env, **env_vars}
-  cmd = ["pytest_html_merger", "-i", '{}'.format(base_dir), "-o", '{}/final_compiled_report.html'.format(base_dir)]
+  cmd = ["pytest_html_merger", "-i", f'{base_dir}', "-o", f'{base_dir}/final_compiled_report.html']
   result = subprocess.run(cmd,
                           shell=shell,
                           capture_output=True,
@@ -90,7 +90,7 @@ def run_test(testmodule, gpu_tokens):
       "XLA_PYTHON_CLIENT_ALLOCATOR": "default",
   }
   testfile = extract_filename(testmodule)
-  cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-x", testmodule]
+  cmd = ["python3", "-m", "pytest", f'--html={base_dir}/{testfile}_log.html', "--reruns", "3", "-x", testmodule]
   return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
   with GPU_LOCK:
     gpu_tokens.append(target_gpu)
diff --git a/docs/Custom_Operation_for_GPUs.py b/docs/Custom_Operation_for_GPUs.py
index 09aec08be4d4..4c0b4b6f7b38 100644
--- a/docs/Custom_Operation_for_GPUs.py
+++ b/docs/Custom_Operation_for_GPUs.py
@@ -14,7 +14,6 @@

 from functools import partial, reduce
 import math
-from typing import Tuple

 import jax
 import jax.numpy as jnp
@@ -325,9 +324,9 @@ def batcher(batched_args, batch_dims, *, eps):
         return RmsNormFwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims

     @staticmethod
-    def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
-                                     arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
-                                     result_infos : Tuple[jax._src.core.ShapedArray]):
+    def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh,
+                                     arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
+                                     result_infos: tuple[jax._src.core.ShapedArray, ...]):
         del eps, result_infos  # Not needed for this example.
         x_info, weight_info = arg_infos
         assert len(x_info.shape) == 3
@@ -340,9 +339,9 @@ def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
         return (output_sharding, invvar_sharding)

     @staticmethod
-    def partition(eps : float, mesh : jax.sharding.Mesh,
-                  arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
-                  result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
+    def partition(eps: float, mesh : jax.sharding.Mesh,
+                  arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
+                  result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]):
         del result_infos  # Not needed for this example.
         x_info, weight_info = arg_infos
         assert len(x_info.shape) == 3
@@ -395,9 +394,9 @@ def batcher(batched_args, batch_dims, *, eps):
         return RmsNormBwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims

     @staticmethod
-    def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
-                                     arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
-                                     result_infos : Tuple[jax._src.core.ShapedArray]):
+    def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh,
+                                     arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
+                                     result_infos: tuple[jax._src.core.ShapedArray, ...]):
         del eps, result_infos  # Not needed for this example.
         g_info, invvar_info, x_info, weight_info = arg_infos
         assert len(g_info.shape) == 3
@@ -411,9 +410,9 @@ def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
         return (output_sharding, invvar_sharding, output_sharding, )

     @staticmethod
-    def partition(eps : float, mesh : jax.sharding.Mesh,
-                  arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
-                  result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
+    def partition(eps: float, mesh : jax.sharding.Mesh,
+                  arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
+                  result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]):
         del result_infos  # Not needed for this example.
         g_info, invvar_info, x_info, weight_info = arg_infos
         assert len(g_info.shape) == 3
diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb
index 11752f0b0b74..ed242ecc5710 100644
--- a/docs/autodidax.ipynb
+++ b/docs/autodidax.ipynb
@@ -167,15 +167,15 @@
    "source": [
     "from collections.abc import Sequence\n",
     "from contextlib import contextmanager\n",
-    "from typing import Optional, Any\n",
+    "from typing import Any\n",
     "\n",
     "class MainTrace(NamedTuple):\n",
     "  level: int\n",
     "  trace_type: type['Trace']\n",
-    "  global_data: Optional[Any]\n",
+    "  global_data: Any | None\n",
     "\n",
     "trace_stack: list[MainTrace] = []\n",
-    "dynamic_trace: Optional[MainTrace] = None  # to be employed in Part 3\n",
+    "dynamic_trace: MainTrace | None = None  # to be employed in Part 3\n",
     "\n",
     "@contextmanager\n",
     "def new_main(trace_type: type['Trace'], global_data=None):\n",
@@ -912,7 +912,7 @@
    "source": [
     "from collections.abc import Hashable, Iterable, Iterator\n",
     "import itertools as it\n",
-    "from typing import Callable\n",
+    "from collections.abc import Callable\n",
     "\n",
     "class NodeType(NamedTuple):\n",
     "  name: str\n",
@@ -1651,7 +1651,7 @@
    "source": [
     "from functools import lru_cache\n",
     "\n",
-    "@lru_cache()  # ShapedArrays are hashable\n",
+    "@lru_cache  # ShapedArrays are hashable\n",
     "def make_jaxpr_v1(f, *avals_in):\n",
     "  avals_in, in_tree = tree_flatten(avals_in)\n",
     "  f, out_tree = flatten_fun(f, in_tree)\n",
@@ -1803,7 +1803,7 @@
     "  finally:\n",
     "    dynamic_trace = prev_dynamic_trace\n",
     "\n",
-    "@lru_cache()\n",
+    "@lru_cache\n",
     "def make_jaxpr(f: Callable, *avals_in: ShapedArray,\n",
     "               ) -> tuple[Jaxpr, list[Any], PyTreeDef]:\n",
     "  avals_in, in_tree = tree_flatten(avals_in)\n",
@@ -1994,7 +1994,7 @@
     "    return execute(*args)\n",
     "impl_rules[xla_call_p] = xla_call_impl\n",
     "\n",
-    "@lru_cache()\n",
+    "@lru_cache\n",
     "def xla_callable(hashable_jaxpr: IDHashable,\n",
     "                 hashable_consts: tuple[IDHashable, ...]):\n",
     "  jaxpr: Jaxpr = hashable_jaxpr.val\n",
@@ -2227,7 +2227,7 @@
     "  return primals_out, tangents_out\n",
     "jvp_rules[xla_call_p] = xla_call_jvp_rule\n",
     "\n",
-    "@lru_cache()\n",
+    "@lru_cache\n",
     "def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:\n",
     "  def jvp_traceable(*primals_and_tangents):\n",
     "    n = len(primals_and_tangents) // 2\n",
@@ -2253,7 +2253,7 @@
     "  return outs, [0] * len(outs)\n",
     "vmap_rules[xla_call_p] = xla_call_vmap_rule\n",
     "\n",
-    "@lru_cache()\n",
+    "@lru_cache\n",
     "def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]\n",
     "               ) -> tuple[Jaxpr, list[Any]]:\n",
     "  vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n",
@@ -2638,7 +2638,7 @@
    "source": [
     "class PartialVal(NamedTuple):\n",
     "  aval: ShapedArray\n",
-    "  const: Optional[Any]\n",
+    "  const: Any | None\n",
     "\n",
     "  @classmethod\n",
     "  def known(cls, val: Any):\n",
@@ -2727,7 +2727,7 @@
    "source": [
     "class PartialEvalTracer(Tracer):\n",
     "  pval: PartialVal\n",
-    "  recipe: Optional[JaxprRecipe]\n",
+    "  recipe: JaxprRecipe | None\n",
     "\n",
     "  def __init__(self, trace, pval, recipe):\n",
     "    self._trace = trace\n",
@@ -2974,7 +2974,7 @@
     "partial_eval_rules[xla_call_p] = xla_call_partial_eval\n",
     "\n",
     "def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],\n",
-    "                       instantiate: Optional[list[bool]] = None,\n",
+    "                       instantiate: list[bool] | None = None,\n",
     "                       ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:\n",
     "  env: dict[Var, bool] = {}\n",
     "  residuals: set[Var] = set()\n",
@@ -3271,7 +3271,7 @@
     "    return [next(outs) if undef else None for undef in undef_primals]\n",
     "transpose_rules[xla_call_p] = xla_call_transpose_rule\n",
     "\n",
-    "@lru_cache()\n",
+    "@lru_cache\n",
     "def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]\n",
     "                    ) -> tuple[Jaxpr, list[Any]]:\n",
     "  avals_in, avals_out = typecheck_jaxpr(jaxpr)\n",
diff --git a/docs/autodidax.md b/docs/autodidax.md
index 07e997a3de96..0551b9905db3 100644
--- a/docs/autodidax.md
+++ b/docs/autodidax.md
@@ -148,15 +148,15 @@ more descriptive.
 ```{code-cell}
 from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import Optional, Any
+from typing import Any

 class MainTrace(NamedTuple):
   level: int
   trace_type: type['Trace']
-  global_data: Optional[Any]
+  global_data: Any | None

 trace_stack: list[MainTrace] = []
-dynamic_trace: Optional[MainTrace] = None  # to be employed in Part 3
+dynamic_trace: MainTrace | None = None  # to be employed in Part 3

 @contextmanager
 def new_main(trace_type: type['Trace'], global_data=None):
@@ -705,7 +705,7 @@ class Store:

 from collections.abc import Hashable, Iterable, Iterator
 import itertools as it
-from typing import Callable
+from collections.abc import Callable

 class NodeType(NamedTuple):
   name: str
@@ -1295,7 +1295,7 @@ transformation and a pretty-printer:
 ```{code-cell}
 from functools import lru_cache

-@lru_cache()  # ShapedArrays are hashable
+@lru_cache  # ShapedArrays are hashable
 def make_jaxpr_v1(f, *avals_in):
   avals_in, in_tree = tree_flatten(avals_in)
   f, out_tree = flatten_fun(f, in_tree)
@@ -1415,7 +1415,7 @@ def new_dynamic(main: MainTrace):
   finally:
     dynamic_trace = prev_dynamic_trace

-@lru_cache()
+@lru_cache
 def make_jaxpr(f: Callable, *avals_in: ShapedArray,
                ) -> tuple[Jaxpr, list[Any], PyTreeDef]:
   avals_in, in_tree = tree_flatten(avals_in)
@@ -1564,7 +1564,7 @@ def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
     return execute(*args)
 impl_rules[xla_call_p] = xla_call_impl

-@lru_cache()
+@lru_cache
 def xla_callable(hashable_jaxpr: IDHashable,
                  hashable_consts: tuple[IDHashable, ...]):
   jaxpr: Jaxpr = hashable_jaxpr.val
@@ -1734,7 +1734,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
   return primals_out, tangents_out
 jvp_rules[xla_call_p] = xla_call_jvp_rule

-@lru_cache()
+@lru_cache
 def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
   def jvp_traceable(*primals_and_tangents):
     n = len(primals_and_tangents) // 2
@@ -1755,7 +1755,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
   return outs, [0] * len(outs)
 vmap_rules[xla_call_p] = xla_call_vmap_rule

-@lru_cache()
+@lru_cache
 def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
                ) -> tuple[Jaxpr, list[Any]]:
   vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
@@ -2065,7 +2065,7 @@ be either known or unknown:
 ```{code-cell}
 class PartialVal(NamedTuple):
   aval: ShapedArray
-  const: Optional[Any]
+  const: Any | None

   @classmethod
   def known(cls, val: Any):
@@ -2129,7 +2129,7 @@ JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
 ```{code-cell}
 class PartialEvalTracer(Tracer):
   pval: PartialVal
-  recipe: Optional[JaxprRecipe]
+  recipe: JaxprRecipe | None

   def __init__(self, trace, pval, recipe):
     self._trace = trace
@@ -2329,7 +2329,7 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
 partial_eval_rules[xla_call_p] = xla_call_partial_eval

 def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
-                       instantiate: Optional[list[bool]] = None,
+                       instantiate: list[bool] | None = None,
                        ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
   env: dict[Var, bool] = {}
   residuals: set[Var] = set()
@@ -2586,7 +2586,7 @@ def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
   return [next(outs) if undef else None for undef in undef_primals]
 transpose_rules[xla_call_p] = xla_call_transpose_rule

-@lru_cache()
+@lru_cache
 def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
                     ) -> tuple[Jaxpr, list[Any]]:
   avals_in, avals_out = typecheck_jaxpr(jaxpr)
diff --git a/docs/autodidax.py b/docs/autodidax.py
index 8a4f83fec50b..b09534381c69 100644
--- a/docs/autodidax.py
+++ b/docs/autodidax.py
@@ -138,15 +138,15 @@ def bind1(prim, *args, **params):
 # +
 from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import Optional, Any
+from typing import Any

 class MainTrace(NamedTuple):
   level: int
   trace_type: type['Trace']
-  global_data: Optional[Any]
+  global_data: Any | None

 trace_stack: list[MainTrace] = []
-dynamic_trace: Optional[MainTrace] = None  # to be employed in Part 3
+dynamic_trace: MainTrace | None = None  # to be employed in Part 3

 @contextmanager
 def new_main(trace_type: type['Trace'], global_data=None):
@@ -697,7 +697,7 @@ def __call__(self):
 # + tags=["hide-input"]
 from collections.abc import Hashable, Iterable, Iterator
 import itertools as it
-from typing import Callable
+from collections.abc import Callable

 class NodeType(NamedTuple):
   name: str
@@ -1297,7 +1297,7 @@ def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
 # +
 from functools import lru_cache

-@lru_cache()  # ShapedArrays are hashable
+@lru_cache  # ShapedArrays are hashable
 def make_jaxpr_v1(f, *avals_in):
   avals_in, in_tree = tree_flatten(avals_in)
   f, out_tree = flatten_fun(f, in_tree)
@@ -1412,7 +1412,7 @@ def new_dynamic(main: MainTrace):
   finally:
     dynamic_trace = prev_dynamic_trace

-@lru_cache()
+@lru_cache
 def make_jaxpr(f: Callable, *avals_in: ShapedArray,
                ) -> tuple[Jaxpr, list[Any], PyTreeDef]:
   avals_in, in_tree = tree_flatten(avals_in)
@@ -1556,7 +1556,7 @@ def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
     return execute(*args)
 impl_rules[xla_call_p] = xla_call_impl

-@lru_cache()
+@lru_cache
 def xla_callable(hashable_jaxpr: IDHashable,
                  hashable_consts: tuple[IDHashable, ...]):
   jaxpr: Jaxpr = hashable_jaxpr.val
@@ -1728,7 +1728,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
   return primals_out, tangents_out
 jvp_rules[xla_call_p] = xla_call_jvp_rule

-@lru_cache()
+@lru_cache
 def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
   def jvp_traceable(*primals_and_tangents):
     n = len(primals_and_tangents) // 2
@@ -1749,7 +1749,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr,
                        num_consts):
   return outs, [0] * len(outs)
 vmap_rules[xla_call_p] = xla_call_vmap_rule

-@lru_cache()
+@lru_cache
 def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
                ) -> tuple[Jaxpr, list[Any]]:
   vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
@@ -2057,7 +2057,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:

 class PartialVal(NamedTuple):
   aval: ShapedArray
-  const: Optional[Any]
+  const: Any | None

   @classmethod
   def known(cls, val: Any):
@@ -2121,7 +2121,7 @@ class JaxprEqnRecipe(NamedTuple):

 class PartialEvalTracer(Tracer):
   pval: PartialVal
-  recipe: Optional[JaxprRecipe]
+  recipe: JaxprRecipe | None

   def __init__(self, trace, pval, recipe):
     self._trace = trace
@@ -2322,7 +2322,7 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
 partial_eval_rules[xla_call_p] = xla_call_partial_eval

 def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
-                       instantiate: Optional[list[bool]] = None,
+                       instantiate: list[bool] | None = None,
                        ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
   env: dict[Var, bool] = {}
   residuals: set[Var] = set()
@@ -2585,7 +2585,7 @@ def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
   return [next(outs) if undef else None for undef in undef_primals]
 transpose_rules[xla_call_p] = xla_call_transpose_rule

-@lru_cache()
+@lru_cache
 def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
                     ) -> tuple[Jaxpr, list[Any]]:
   avals_in, avals_out = typecheck_jaxpr(jaxpr)
diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index d7a86687cacb..4b52292f0e2d 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -14,11 +14,11 @@

 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import functools
 from functools import partial
 import logging
-from typing import Any, Callable
+from typing import Any
 import types

 import numpy as np
diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py
index 257fcc7c527c..90ae6c1413ec 100644
--- a/jax/_src/ad_util.py
+++ b/jax/_src/ad_util.py
@@ -13,8 +13,9 @@
 # limitations under the License.

 from __future__ import annotations

+from collections.abc import Callable
 import types
-from typing import Any, Callable, TypeVar
+from typing import Any, TypeVar

 from jax._src import core
 from jax._src import traceback_util
diff --git a/jax/_src/api.py b/jax/_src/api.py
index 54ad6ff74a72..4a42693c2e8f 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -23,12 +23,12 @@
 from __future__ import annotations

 import collections
-from collections.abc import Generator, Hashable, Iterable, Sequence
+from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
 from functools import partial, lru_cache
 import inspect
 import math
 import typing
-from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload,
+from typing import (Any, Literal, NamedTuple, TypeVar, overload,
                     cast)
 import weakref
diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py
index 67a8ee8bdfac..16a29e699bbc 100644
--- a/jax/_src/api_util.py
+++ b/jax/_src/api_util.py
@@ -14,11 +14,11 @@

 from __future__ import annotations

-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 import inspect
 import operator
 from functools import partial, lru_cache
-from typing import Any, Callable, Type
+from typing import Any

 import numpy as np
@@ -713,6 +713,6 @@ def __hash__(self):
   def __eq__(self, other):
     return self.val is other.val

-def register_class_with_attrs(t: Type) -> None:
+def register_class_with_attrs(t: type) -> None:
   _class_with_attrs.add(t)
-_class_with_attrs: set[Type] = set()
+_class_with_attrs: set[type] = set()
diff --git a/jax/_src/array.py b/jax/_src/array.py
index b8730c1b0820..6e3f0a76f512 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -15,12 +15,12 @@
 from __future__ import annotations

 from collections import defaultdict
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import enum
 import functools
 import math
 import operator as op
-from typing import Any, Callable, TYPE_CHECKING, cast
+from typing import Any, TYPE_CHECKING, cast

 from jax._src import abstract_arrays
 from jax._src import api
diff --git a/jax/_src/blocked_sampler.py b/jax/_src/blocked_sampler.py
index 66d0f53abc28..16da61d75b3f 100644
--- a/jax/_src/blocked_sampler.py
+++ b/jax/_src/blocked_sampler.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Protocol, Sequence
+from collections.abc import Sequence
+from typing import Any, Protocol
 import jax
 from jax._src import random
 from jax._src.typing import Array, ArrayLike
diff --git a/jax/_src/callback.py b/jax/_src/callback.py
index 8e9ab0b62593..054804379043 100644
--- a/jax/_src/callback.py
+++ b/jax/_src/callback.py
@@ -14,11 +14,11 @@
 """Module for JAX callbacks."""
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 import functools
 import logging
-from typing import Any, Callable
+from typing import Any

 import jax
 from jax._src import core
diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 3727c8364a12..9bbaa1296c93 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -13,11 +13,11 @@
 # limitations under the License.

 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 import functools
 import itertools as it
-from typing import Callable, TypeVar, Any, Union
+from typing import TypeVar, Any, Union

 import numpy as np
diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py
index 19f0be1ff31c..73e61de68008 100644
--- a/jax/_src/cloud_tpu_init.py
+++ b/jax/_src/cloud_tpu_init.py
@@ -16,7 +16,6 @@
 from jax import version
 from jax._src import config
 from jax._src import hardware_utils
-from typing import Optional

 running_in_cloud_tpu_vm: bool = False
@@ -35,7 +34,7 @@ def maybe_import_libtpu():
     return libtpu


-def get_tpu_library_path() -> Optional[str]:
+def get_tpu_library_path() -> str | None:
   path_from_env = os.getenv("TPU_LIBRARY_PATH")
   if path_from_env is not None and os.path.isfile(path_from_env):
     return path_from_env
diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py
index 1d08b8296033..438f1f9e5183 100644
--- a/jax/_src/compiler.py
+++ b/jax/_src/compiler.py
@@ -21,7 +21,7 @@
 import os
 import tempfile
 import time
-from typing import Any, Optional
+from typing import Any
 import warnings

 from jax._src import compilation_cache
@@ -393,7 +393,7 @@ def _share_fdo_profiles(
     backend: xc.Client,
     global_client: lib.xla_extension.DistributedRuntimeClient,
     min_process_id
-) -> Optional[bytes]:
+) -> bytes | None:
   sym_name = computation.operation.attributes['sym_name']
   module_name = ir.StringAttr(sym_name).value
   fdo_profile = compile_options.executable_build_options.fdo_profile
diff --git a/jax/_src/config.py b/jax/_src/config.py
index a0d575b37d6d..d9e5ee0137d0 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -14,7 +14,7 @@

 from __future__ import annotations

-from collections.abc import Hashable, Iterator, Sequence
+from collections.abc import Callable, Hashable, Iterator, Sequence
 import contextlib
 import functools
 import itertools
@@ -22,9 +22,7 @@
 import os
 import sys
 import threading
-from typing import (
-    Any, Callable, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast,
-)
+from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast

 from jax._src import lib
 from jax._src.lib import jax_jit
diff --git a/jax/_src/core.py b/jax/_src/core.py
index b0f4c97e6c3d..ecb801afed8d 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -14,8 +14,8 @@
 from __future__ import annotations

 from collections import Counter, defaultdict, deque, namedtuple
-from collections.abc import (Collection, Generator, Hashable, Iterable,
-                             Iterator, Set, Sequence, MutableSet,
+from collections.abc import (Callable, Collection, Generator, Hashable,
+                             Iterable, Iterator, Set, Sequence, MutableSet,
                              MutableMapping)
 from contextlib import contextmanager, ExitStack
 from dataclasses import dataclass
@@ -28,7 +28,7 @@
 import operator
 import threading
 import types
-from typing import (Any, Callable, ClassVar, Generic, NamedTuple, TypeVar,
+from typing import (Any, ClassVar, Generic, NamedTuple, TypeVar,
                     cast, overload, Union)
 import warnings
 from weakref import ref
diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py
index a9009da74620..e5b1f0084d00 100644
--- a/jax/_src/cudnn/fused_attention_stablehlo.py
+++ b/jax/_src/cudnn/fused_attention_stablehlo.py
@@ -15,7 +15,6 @@
 from enum import Enum
 from functools import partial, reduce
 import operator
-from typing import Optional
 import json

 import jax
@@ -927,10 +926,10 @@ def _dot_product_attention(query: Array,
 def dot_product_attention(query: Array,
                           key: Array,
                           value: Array,
-                          bias: Optional[Array] = None,
-                          mask: Optional[Array] = None,
-                          q_seqlen: Optional[Array] = None,
-                          kv_seqlen: Optional[Array] = None,
+                          bias: Array | None = None,
+                          mask: Array | None = None,
+                          q_seqlen: Array | None = None,
+                          kv_seqlen: Array | None = None,
                           *,
                           scale: float = 1.0,
                           mask_type: MaskType = MaskType.NO_MASK,
diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py
index 22592b854d14..4d41849b75d3 100644
--- a/jax/_src/custom_batching.py
+++ b/jax/_src/custom_batching.py
@@ -14,9 +14,9 @@

 from __future__ import annotations

+from collections.abc import Callable
 import functools
 import operator
-from typing import Callable

 from jax import lax
 from jax._src import api
diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index 142ced1944eb..46d9fab00455 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -14,11 +14,11 @@

 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 from functools import update_wrapper, reduce, partial
 import inspect
-from typing import Any, Callable, Generic, TypeVar
+from typing import Any, Generic, TypeVar

 from jax._src import config
 from jax._src import core
diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py
index da132e085086..a4de1b8cc46c 100644
--- a/jax/_src/custom_transpose.py
+++ b/jax/_src/custom_transpose.py
@@ -14,8 +14,9 @@

 from __future__ import annotations

+from collections.abc import Callable
 import functools
-from typing import Any, Callable
+from typing import Any

 from jax._src import ad_util
 from jax._src import api_util
diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py
index 265dcd0de843..7d8b3a914b6d 100644
--- a/jax/_src/debugging.py
+++ b/jax/_src/debugging.py
@@ -16,12 +16,12 @@
 from __future__ import annotations

 import importlib.util
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import functools
 import logging
 import string
 import sys
-from typing import Any, Callable, Union
+from typing import Any, Union
 import weakref

 import numpy as np
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 82aedb247f7c..9ae1f7a6c2a3 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -16,13 +16,13 @@
 from __future__ import annotations

 import atexit
-from collections.abc import Iterator, Sequence
+from collections.abc import Callable, Iterator, Sequence
 import contextlib
 import dataclasses
 from functools import partial
 import itertools
 import time
-from typing import Any, Callable, NamedTuple
+from typing import Any, NamedTuple
 import logging
 import threading
diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py
index efd240aa9679..a228eaa8b285 100644
--- a/jax/_src/export/_export.py
+++ b/jax/_src/export/_export.py
@@ -17,13 +17,13 @@

 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import copy
 import dataclasses
 import functools
 import itertools
 import re
-from typing import Any, Callable, Union
+from typing import Any, Union
 import warnings

 from absl import logging
diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py
index 5dd44f0d798e..a47b095e4450 100644
--- a/jax/_src/export/serialization.py
+++ b/jax/_src/export/serialization.py
@@ -16,9 +16,9 @@

 from __future__ import annotations

-from typing import Callable, TypeVar
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial
+from typing import TypeVar

 try:
   import flatbuffers
diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py
index e8bd6b4a29db..a872d03a9fdd 100644
--- a/jax/_src/export/serialization_generated.py
+++ b/jax/_src/export/serialization_generated.py
@@ -21,7 +21,7 @@
 from flatbuffers.compat import import_numpy
 np = import_numpy()

-class PyTreeDefKind(object):
+class PyTreeDefKind:
   leaf = 0
   none = 1
   tuple = 2
@@ -29,12 +29,12 @@ class PyTreeDefKind(object):
   dict = 4


-class AbstractValueKind(object):
+class AbstractValueKind:
   shapedArray = 0
   abstractToken = 1


-class DType(object):
+class DType:
   bool = 0
   i8 = 1
   i16 = 2
@@ -60,18 +60,18 @@ class DType(object):
   f0 = 22


-class ShardingKind(object):
+class ShardingKind:
   unspecified = 0
   hlo_sharding = 1


-class DisabledSafetyCheckKind(object):
+class DisabledSafetyCheckKind:
   platform = 0
   custom_call = 1
   shape_assertions = 2


-class PyTreeDef(object):
+class PyTreeDef:
   __slots__ = ['_tab']

   @classmethod
@@ -163,7 +163,7 @@ def PyTreeDefEnd(builder):



-class AbstractValue(object):
+class AbstractValue:
   __slots__ = ['_tab']

   @classmethod
@@ -235,7 +235,7 @@ def AbstractValueEnd(builder):



-class Sharding(object):
+class Sharding:
   __slots__ = ['_tab']

   @classmethod
@@ -304,7 +304,7 @@ def ShardingEnd(builder):



-class Effect(object):
+class Effect:
   __slots__ = ['_tab']

   @classmethod
@@ -340,7 +340,7 @@ def EffectEnd(builder):



-class DisabledSafetyCheck(object):
+class DisabledSafetyCheck:
   __slots__ = ['_tab']

   @classmethod
@@ -386,7 +386,7 @@ def DisabledSafetyCheckEnd(builder):



-class Exported(object):
+class Exported:
   __slots__ = ['_tab']

   @classmethod
diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py
index 43a827e8a269..d380bc5a2476 100644
--- a/jax/_src/export/shape_poly.py
+++ b/jax/_src/export/shape_poly.py
@@ -19,7 +19,7 @@

 from __future__ import annotations

 import enum
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 from enum import Enum
 import functools
@@ -28,7 +28,7 @@
 import copy
 import operator as op
 import tokenize
-from typing import Any, Callable, Union, overload
+from typing import Any, Union, overload
 import warnings

 import numpy as np
diff --git a/jax/_src/extend/random.py b/jax/_src/extend/random.py
index ffc7f63072b9..df927486dd2f 100644
--- a/jax/_src/extend/random.py
+++ b/jax/_src/extend/random.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Callable
-from collections.abc import Hashable
+from collections.abc import Callable, Hashable

 from jax import Array
diff --git a/jax/_src/image/scale.py b/jax/_src/image/scale.py
index 1ef1db916f73..aa9910555130 100644
--- a/jax/_src/image/scale.py
+++ b/jax/_src/image/scale.py
@@ -14,10 +14,9 @@

 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial
 import enum
-from typing import Callable

 import numpy as np
diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py
index 001a3cbac3be..5a975e3c5a61 100644
--- a/jax/_src/internal_test_util/export_back_compat_test_util.py
+++ b/jax/_src/internal_test_util/export_back_compat_test_util.py
@@ -70,13 +70,13 @@ def func(...): ...

 from __future__ import annotations

-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 import dataclasses
 import datetime
 import os
 import re
 import sys
-from typing import Any, Callable
+from typing import Any

 from absl import logging
diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py
index 44b678d7ebea..2b22944c17b8 100644
--- a/jax/_src/internal_test_util/test_harnesses.py
+++ b/jax/_src/internal_test_util/test_harnesses.py
@@ -38,11 +38,11 @@

 from __future__ import annotations

-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 import operator
 import os
 from functools import partial
-from typing import Any, Callable, NamedTuple, Union
+from typing import Any, NamedTuple, Union

 from absl import testing
 import numpy as np
diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index 0106a5b77181..a527acb8db90 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -14,12 +14,12 @@

 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import contextlib
 import functools
 import itertools as it
 from functools import partial
-from typing import Any, Callable
+from typing import Any

 import jax
 from jax._src import config
diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py
index 87add6b74567..3a87fffa5116 100644
--- a/jax/_src/interpreters/batching.py
+++ b/jax/_src/interpreters/batching.py
@@ -14,10 +14,10 @@
 from __future__ import annotations

 import collections
-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 import dataclasses
 from functools import partial
-from typing import Any, Callable, Union
+from typing import Any, Union

 import numpy as np
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 251888966540..7eb826c95a67 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -16,7 +16,7 @@
 from __future__ import annotations

 import collections
-from collections.abc import Iterator, Sequence
+from collections.abc import Callable, Iterator, Sequence
 import dataclasses
 import functools
 from functools import partial
@@ -27,7 +27,7 @@
 import re
 import types
 import typing
-from typing import Any, Callable, NamedTuple, Protocol, Union, cast as type_cast
+from typing import Any, NamedTuple, Protocol, Union, cast as type_cast
 import warnings

 import numpy as np
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index d4b51731dc8d..497c9ea129a8 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -14,13 +14,13 @@
 from __future__ import annotations

 from collections import namedtuple
-from collections.abc import Sequence, Hashable
+from collections.abc import Callable, Sequence, Hashable
 from contextlib import contextmanager, AbstractContextManager
 from functools import partial
 import inspect
 import itertools as it
 import operator as op
-from typing import Any, Callable, NamedTuple, Union
+from typing import Any, NamedTuple, Union
 from weakref import ref

 import numpy as np
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 2a1336b4cfcd..69d7c619b0a6 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -19,15 +19,14 @@
 from contextlib import contextmanager
 import collections
 from collections import namedtuple
-from collections.abc import Sequence, Iterable
+from collections.abc import Callable, Sequence, Iterable, Iterator
 import dataclasses
 from functools import partial, lru_cache, cached_property
 import itertools as it
 import logging
 import math
 import threading
-from typing import Any, Callable, NamedTuple, TypeVar, Union, cast
-from collections.abc import Iterator
+from typing import Any, NamedTuple, TypeVar, Union, cast
 import warnings

 import numpy as np
diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py
index 314fdccfb975..2db877d3f970 100644
--- a/jax/_src/interpreters/xla.py
+++ b/jax/_src/interpreters/xla.py
@@ -17,12 +17,12 @@
 from __future__ import annotations

 from collections import defaultdict
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 import functools
 from functools import partial
 import itertools as it
-from typing import Any, Callable, Protocol, Union
+from typing import Any, Protocol, Union

 import numpy as np
diff --git a/jax/_src/jaxpr_util.py b/jax/_src/jaxpr_util.py
index 1358980d04ac..3f3f677b069d 100644
--- a/jax/_src/jaxpr_util.py
+++ b/jax/_src/jaxpr_util.py
@@ -17,11 +17,12 @@
 from __future__ import annotations

 from collections import Counter, defaultdict
+from collections.abc import Callable
 import gzip
 import itertools
 import json
 import types
-from typing import Any, Callable, Union
+from typing import Any, Union

 from jax._src import core
 from jax._src import util
diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py
index f75175b79bb6..b613193876b6 100644
--- a/jax/_src/lax/control_flow/common.py
+++ b/jax/_src/lax/control_flow/common.py
@@ -15,10 +15,10 @@

 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import os
 from functools import partial
-from typing import Any, Callable
+from typing import Any

 from jax._src import core
 from jax._src import linear_util as lu
diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index 3150f972ce0d..ffe5d086a98d 100644
--- a/jax/_src/lax/control_flow/conditionals.py
+++ b/jax/_src/lax/control_flow/conditionals.py
@@ -15,13 +15,13 @@
 from __future__ import annotations

 import collections
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import functools
 from functools import partial
 import inspect
 import itertools
 import operator
-from typing import Any, Callable, TypeVar
+from typing import Any, TypeVar

 import jax
 from jax.tree_util import tree_flatten, tree_unflatten
diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py
index c20e22a2385b..936656b0e7df 100644
--- a/jax/_src/lax/control_flow/for_loop.py
+++ b/jax/_src/lax/control_flow/for_loop.py
@@ -15,10 +15,10 @@

 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import functools
 import operator
-from typing import Any, Callable, Generic, TypeVar
+from typing import Any, Generic, TypeVar

 import jax.numpy as jnp
 from jax import lax
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index a4ab60db00e9..0c704b84475b 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -14,12 +14,12 @@
 """Module for the loop primitives."""
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial
 import inspect
 import itertools
 import operator
-from typing import Any, Callable, TypeVar
+from typing import Any, TypeVar
 import weakref

 import jax
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 11af8366b44c..4d4abc1b18f7 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -15,14 +15,14 @@
 from __future__ import annotations

 import builtins
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import enum
 import functools
 from functools import partial
 import itertools
 import math
 import operator
-from typing import Any, Callable, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING
+from typing import Any, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING
 import warnings

 import numpy as np
@@ -2986,10 +2986,10 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape:
   m, k = lhs.shape
   group_count, rk, n = rhs.shape
   if k != rk:
-    raise TypeError("ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {} and {}.".format(k, rk))
+    raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.")
   num_groups = group_sizes.shape[0]
   if group_count != num_groups:
-    raise TypeError("ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {} and {}.".format(group_count, num_groups))
+    raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
   return (m, n)

 # DotDimensionNumbers used in the dot_general call for ragged_dot().
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index c7a72b5d8bdc..d31bba99171c 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -14,10 +14,11 @@

 from __future__ import annotations

+from collections.abc import Callable
 import functools
 from functools import partial
 import math
-from typing import Any, Callable, Literal, TypeVar, overload
+from typing import Any, Literal, TypeVar, overload

 import numpy as np
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index 9676a9b07c77..b2bd30b3d364 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -14,12 +14,12 @@

 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import enum
 import operator
 from functools import partial
 import math
-from typing import Callable, NamedTuple
+from typing import NamedTuple
 import weakref

 import numpy as np
diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index 8a3fbf2c37bb..096fce7deb3a 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -14,9 +14,8 @@

 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial
-from typing import Callable
 import warnings

 from jax import tree_util
diff --git a/jax/_src/lazy_loader.py b/jax/_src/lazy_loader.py
index 6041c77c65c0..cf6e68e49c81 100644
--- a/jax/_src/lazy_loader.py
+++ b/jax/_src/lazy_loader.py
@@ -14,9 +14,9 @@

 """A LazyLoader class."""

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import importlib
-from typing import Any, Callable
+from typing import Any


 def attach(package_name: str, submodules: Sequence[str]) -> tuple[
diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py
index 93bd3b9cfe2a..bc4cc242f055 100644
--- a/jax/_src/linear_util.py
+++ b/jax/_src/linear_util.py
@@ -63,8 +63,9 @@ def trans1(static_arg, *dynamic_args, **kwargs):
 """
 from __future__ import annotations

+from collections.abc import Callable
 from functools import partial
-from typing import Any, Callable, NamedTuple
+from typing import Any, NamedTuple
 import weakref

 from jax._src import config
diff --git a/jax/_src/maps.py b/jax/_src/maps.py
index d8c074ff705b..20fc54d8fe37 100644
--- a/jax/_src/maps.py
+++ b/jax/_src/maps.py
@@ -15,12 +15,12 @@
 from __future__ import annotations

 from collections import OrderedDict, abc
-from collections.abc import Iterable, Sequence, Mapping
+from collections.abc import Callable, Iterable, Sequence, Mapping
 import contextlib
 from functools import wraps, partial, partialmethod, lru_cache
 import itertools as it
 import math
-from typing import Callable, Any, NamedTuple, Union, cast as type_cast
+from typing import Any, NamedTuple, Union, cast as type_cast

 import numpy as np
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 247435ff8f43..80eaf01dbd95 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -27,13 +27,13 @@

 import builtins
 import collections
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial
 import importlib
 import math
 import operator
 import types
-from typing import (cast, overload, Any, Callable, Literal, NamedTuple,
+from typing import (cast, overload, Any, Literal, NamedTuple,
                     Protocol, TypeVar, Union)
 from textwrap import dedent as _dedent
 import warnings
diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py
index 7fa667836e31..6abdf884b7e3 100644
--- a/jax/_src/numpy/reductions.py
+++ b/jax/_src/numpy/reductions.py
@@ -15,11 +15,11 @@
 from __future__ import annotations

 import builtins
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial
 import math
 import operator
-from typing import overload, Any, Callable, Literal, Protocol, Union
+from typing import overload, Any, Literal, Protocol, Union
 import warnings

 import numpy as np
diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py
index 7d0769a193f4..2e114193af13 100644
--- a/jax/_src/numpy/ufunc_api.py
+++ b/jax/_src/numpy/ufunc_api.py
@@ -16,10 +16,11 @@

 from __future__ import annotations

+from collections.abc import Callable
 from functools import partial
 import math
 import operator
-from typing import Any, Callable
+from typing import Any

 import jax
 from jax._src.typing import Array, ArrayLike, DTypeLike
diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py
index 673ff2c4d11d..1a75e413379a 100644
--- a/jax/_src/numpy/ufuncs.py
+++ b/jax/_src/numpy/ufuncs.py
@@ -18,9 +18,9 @@

 from __future__ import annotations

+from collections.abc import Callable
 from functools import partial
 import operator
-from typing import Callable

 import warnings
diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py
index fb3b7e4e9dc9..21b96deea3c6 100644
--- a/jax/_src/numpy/util.py
+++ b/jax/_src/numpy/util.py
@@ -13,11 +13,11 @@
 # limitations under the License.
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial
 import re
 import textwrap
-from typing import Any, Callable, NamedTuple, TypeVar
+from typing import Any, NamedTuple, TypeVar

 import warnings
diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py
index 3fe99131e6da..2c517467e287 100644
--- a/jax/_src/numpy/vectorize.py
+++ b/jax/_src/numpy/vectorize.py
@@ -13,10 +13,10 @@
 # limitations under the License.
 from __future__ import annotations

-from collections.abc import Collection, Sequence
+from collections.abc import Callable, Collection, Sequence
 import functools
 import re
-from typing import Any, Callable
+from typing import Any

 from jax._src import api
 from jax import lax
diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py
index b8efde364fc8..2bcfe96ad2f0 100644
--- a/jax/_src/ops/scatter.py
+++ b/jax/_src/ops/scatter.py
@@ -16,8 +16,8 @@

 from __future__ import annotations

-from collections.abc import Sequence
-from typing import Callable, Union
+from collections.abc import Callable, Sequence
+from typing import Union
 import warnings

 import numpy as np
diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 97cfe0bf9656..38866b082da6 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -15,13 +15,13 @@
 """Module for pallas-core functionality."""
 from __future__ import annotations

-from collections.abc import Iterator, Sequence
+from collections.abc import Callable, Iterator, Sequence
 import copy
 import contextlib
 import dataclasses
 import functools
 import threading
-from typing import Any, Callable, Union
+from typing import Any, Union

 import jax
 from jax._src import api_util
diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index cf8c5c47d08e..f4a794792253 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -78,7 +78,7 @@ def __eq__(self, other):
     return self.__class__ == other.__class__

   def __hash__(self) -> int:
-    return hash((self.__class__))
+    return hash(self.__class__)


 # TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy
@@ -109,7 +109,7 @@ def __call__(self, shape: tuple[int, ...]):
     dtype = SemaphoreTy()
     return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE)

-  def get_aval(self) -> "AbstractMemoryRef":
+  def get_aval(self) -> AbstractMemoryRef:
     return self(()).get_aval()

 @dataclasses.dataclass(frozen=True)
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 853a97a8666a..3ebbfdb51b5b 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -15,11 +15,11 @@
 """Module for lowering JAX to Mosaic-compatible MLIR dialects."""
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 import functools
 import string
-from typing import Any, Callable
+from typing import Any

 import jax
 from jax import core as jax_core
diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py
index d0d881281c85..0d778a60c711 100644
--- a/jax/_src/pallas/mosaic/pipeline.py
+++ b/jax/_src/pallas/mosaic/pipeline.py
@@ -15,12 +15,13 @@
 """Module for emitting custom TPU pipelines within a Pallas call."""
 from __future__ import annotations

+from collections.abc import Sequence
 import dataclasses
 import enum
 import functools
 import itertools
 import operator
-from typing import Optional, Union, Any, Sequence
+from typing import Union, Any

 import jax
 from jax import lax
@@ -201,12 +202,12 @@ class BufferedRef:
   spec: pl.BlockSpec       # static metadata
   dtype: Any               # static metadata
   buffer_type: BufferType  # static metadata
-  vmem_ref: Optional[REF]
-  accum_ref: Optional[REF]
-  current_slot: Optional[ArrayRef]
-  next_slot: Optional[ArrayRef]
-  sem_recv: Optional[SemaphoreType]
-  sem_send: Optional[SemaphoreType]
+  vmem_ref: REF | None
+  accum_ref: REF | None
+  current_slot: ArrayRef | None
+  next_slot: ArrayRef | None
+  sem_recv: SemaphoreType | None
+  sem_send: SemaphoreType | None

   def tree_flatten(self):
     return ((self.vmem_ref, self.accum_ref, self.current_slot,
@@ -218,7 +219,7 @@ def tree_unflatten(cls, meta, data):
     return cls(*meta, *data)

   @classmethod
-  def create(cls, spec, dtype, buffer_type) -> 'BufferedRef':
+  def create(cls, spec, dtype, buffer_type) -> BufferedRef:
     """Create a BufferedRef.

     Args:
@@ -810,9 +811,9 @@ def _partition_grid(
         if isinstance(grid[i], int) and grid[i] % num_cores == 0
     }
     if divisible_dimensions:
-      first_divisible_dimension, *_ = [
+      first_divisible_dimension, *_ = (
          i for i in range(len(dimension_semantics)) if i in divisible_dimensions
-      ]
+      )
       partitioned_dim_size = grid[first_divisible_dimension] // num_cores
       partitioned_dim_offset = pl.program_id(core_axis) * partitioned_dim_size
       new_grid = jax_util.tuple_update(
@@ -828,11 +829,11 @@ def _partition_grid(
     # potentially divide it more evenly
     largest_parallel_dimension = max(grid[i] for i in parallel_dimensions
                                      if isinstance(grid[i], int))  # type: ignore
-    partition_dimension, *_ = [
+    partition_dimension, *_ = (
        i
       for i, d in enumerate(grid)
        if isinstance(d, int) and d == largest_parallel_dimension
-    ]
+    )
     base_num_iters, rem = divmod(grid[partition_dimension], num_cores)
     assert rem > 0, rem
     # We have some remainder iterations that we need to assign somewhere. We
diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py
index 0adf75ad7274..f4c24e4e5e16 100644
--- a/jax/_src/pallas/mosaic/primitives.py
+++ b/jax/_src/pallas/mosaic/primitives.py
@@ -15,9 +15,10 @@
 """Module for Pallas:TPU-specific JAX primitives and functions."""
 from __future__ import annotations

+from collections.abc import Callable
 import dataclasses
 import enum
-from typing import Any, Callable
+from typing import Any

 import jax
 from jax._src import api_util
diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py
index 295196bc9c37..cc864c56f4e3 100644
--- a/jax/_src/pallas/mosaic/random.py
+++ b/jax/_src/pallas/mosaic/random.py
@@ -11,7 +11,8 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
-from typing import Callable, Optional
+
+from collections.abc import Callable

 import jax
 import numpy as np
@@ -172,7 +173,7 @@ def sample_block(sampler_fn: SampleFnType,
                  block_size: Shape,
                  tile_size: Shape,
                  total_size: Shape,
-                 block_index: Optional[tuple[typing.ArrayLike, ...]] = None,
+                 block_index: tuple[typing.ArrayLike, ...] | None = None,
                  **kwargs) -> jax.Array:
   """Samples a block of random values with invariance guarantees.
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index 4566a2818829..0748e78a2db9 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -15,10 +15,10 @@
 """Module for calling pallas functions from JAX."""
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial, reduce
 import itertools
-from typing import Any, Callable
+from typing import Any

 import jax
 from jax import api_util
diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index 0e84a17db8a0..c270e8084f42 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -16,12 +16,12 @@

 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 import functools
 import math
 import operator
-from typing import Any, Callable, TypeVar
+from typing import Any, TypeVar

 import jax
 from jax import lax
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index cd472b7884bf..204c288d6993 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -15,7 +15,7 @@
 from __future__ import annotations

 from collections import defaultdict
-from collections.abc import Sequence, Iterable
+from collections.abc import Callable, Sequence, Iterable
 import dataclasses
 from functools import partial
 import inspect
@@ -23,7 +23,7 @@
 import logging
 import operator as op
 import weakref
-from typing import Callable, NamedTuple, Any, Union, Optional, cast
+from typing import NamedTuple, Any, Union, cast
 import threading
 import warnings

@@ -245,7 +245,7 @@ def _need_to_rebuild_with_fdo(pgle_profiler):
 def _get_fastpath_data(
     executable, out_tree, args_flat, out_flat, attrs_tracked, effects,
     consts, abstracted_axes, pgle_profiler
-) -> Optional[pxla.MeshExecutableFastpathData]:
+) -> pxla.MeshExecutableFastpathData | None:
   out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)

   use_fastpath = (
@@ -608,7 +608,7 @@ def _infer_params_impl(
   assert None not in in_shardings_leaves
   assert None not in out_shardings_leaves

-  in_type: Union[core.InputType, tuple[core.AbstractValue, ...]]
+  in_type: core.InputType | tuple[core.AbstractValue, ...]
   if config.dynamic_shapes.value:
     in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
     in_avals = tuple(a for a, e in in_type if e)
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index bcbbe1790f70..d585b312fafc 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -13,11 +13,11 @@
 # limitations under the License.

 from __future__ import annotations

-from collections.abc import Iterator, Sequence
+from collections.abc import Callable, Iterator, Sequence
 from functools import partial, reduce
 import math
 import operator as op
-from typing import Any, Callable, NamedTuple
+from typing import Any, NamedTuple

 import numpy as np
diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py
index d761d50e6fa8..cad4826ba801 100644
--- a/jax/_src/profiler.py
+++ b/jax/_src/profiler.py
@@ -14,6 +14,7 @@

 from __future__ import annotations

+from collections.abc import Callable
 from contextlib import contextmanager
 from functools import wraps
 import glob
@@ -24,7 +25,7 @@
 import os
 import socketserver
 import threading
-from typing import Callable, List, Optional, Union, Any
+from typing import Any

 from jax._src import traceback_util
 traceback_util.register_exclusion(__file__)
@@ -210,7 +211,7 @@ def stop_trace():
     _profile_state.reset()


-def stop_and_get_fdo_profile() -> Union[bytes, str]:
+def stop_and_get_fdo_profile() -> bytes | str:
   """Stops the currently-running profiler trace and export fdo_profile.

   Currently, this is only supported for GPU.
@@ -391,10 +392,10 @@ def __init__(self, retries: int, percentile: int):
     self.percentile: int = percentile
     self.collected_fdo: str | None = None
     self.called_times: int = 0
-    self.fdo_profiles: List[Any] = []
+    self.fdo_profiles: list[Any] = []
     self.current_session: xla_client.profiler.ProfilerSession | None = None

-  def consume_fdo_profile(self) -> Optional[str]:
+  def consume_fdo_profile(self) -> str | None:
     if self.collected_fdo is not None:
       return self.collected_fdo
diff --git a/jax/_src/scipy/ndimage.py b/jax/_src/scipy/ndimage.py
index e64ca7d9152a..d81008308b94 100644
--- a/jax/_src/scipy/ndimage.py
+++ b/jax/_src/scipy/ndimage.py
@@ -12,11 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools import itertools import operator -from typing import Callable from jax._src import api from jax._src import util diff --git a/jax/_src/scipy/optimize/_lbfgs.py b/jax/_src/scipy/optimize/_lbfgs.py index 4df9647debe5..aa82ab4fd0c8 100644 --- a/jax/_src/scipy/optimize/_lbfgs.py +++ b/jax/_src/scipy/optimize/_lbfgs.py @@ -15,8 +15,9 @@ from __future__ import annotations -from typing import Callable, NamedTuple +from collections.abc import Callable from functools import partial +from typing import NamedTuple import jax import jax.numpy as jnp diff --git a/jax/_src/scipy/optimize/bfgs.py b/jax/_src/scipy/optimize/bfgs.py index b6fd9f9dda17..657b7610e6e1 100644 --- a/jax/_src/scipy/optimize/bfgs.py +++ b/jax/_src/scipy/optimize/bfgs.py @@ -15,8 +15,9 @@ from __future__ import annotations +from collections.abc import Callable from functools import partial -from typing import Callable, NamedTuple +from typing import NamedTuple import jax import jax.numpy as jnp diff --git a/jax/_src/scipy/optimize/minimize.py b/jax/_src/scipy/optimize/minimize.py index 830f1228424a..4fc006be6df0 100644 --- a/jax/_src/scipy/optimize/minimize.py +++ b/jax/_src/scipy/optimize/minimize.py @@ -14,8 +14,8 @@ from __future__ import annotations -from collections.abc import Mapping -from typing import Any, Callable +from collections.abc import Callable, Mapping +from typing import Any import jax from jax._src.scipy.optimize.bfgs import minimize_bfgs diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 737378faef85..1282650ae1e5 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -14,11 +14,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial import math import operator -from typing import Callable import warnings import numpy as np diff --git a/jax/_src/sourcemap.py b/jax/_src/sourcemap.py index 276a39e26444..b54f2193ff26 100644 --- a/jax/_src/sourcemap.py +++ b/jax/_src/sourcemap.py @@ -18,9 +18,10 @@ from __future__ import annotations +from collections.abc import Iterable, Sequence from dataclasses import dataclass import json -from typing import Iterable, Sequence, Union +from typing import Union # A Segment encodes how parts in the generated source relate to the original source. # Each segment is made up of 1, 4 or 5 variable-length fields. 
diff --git a/jax/_src/sourcemap.py b/jax/_src/sourcemap.py
index 276a39e26444..b54f2193ff26 100644
--- a/jax/_src/sourcemap.py
+++ b/jax/_src/sourcemap.py
@@ -18,9 +18,10 @@
 from __future__ import annotations

+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 import json
-from typing import Iterable, Sequence, Union
+from typing import Union

 # A Segment encodes how parts in the generated source relate to the original source.
 # Each segment is made up of 1, 4 or 5 variable-length fields. For their semantics see
diff --git a/jax/_src/stages.py b/jax/_src/stages.py
index 93e30be453d7..874ef8834557 100644
--- a/jax/_src/stages.py
+++ b/jax/_src/stages.py
@@ -315,7 +315,7 @@ class XlaLowering(Lowering):
   def hlo(self) -> xc.XlaComputation:
     """Return an HLO representation of this computation."""
     hlo = self.stablehlo()
-    m: Union[str, bytes]
+    m: str | bytes
     m = mlir.module_to_bytecode(hlo)
     return xla_extension.mlir.mlir_module_to_xla_computation(
         m, use_tuple_args=self.compile_args["tuple_args"])
diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py
index 09488c8bb165..f3a3e61a2ace 100644
--- a/jax/_src/state/discharge.py
+++ b/jax/_src/state/discharge.py
@@ -14,11 +14,11 @@
 """Module for discharging state primitives."""
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 from functools import partial
 import operator
-from typing import Any, Callable, Protocol
+from typing import Any, Protocol

 import numpy as np
diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py
index df340c03247c..acf1c7216240 100644
--- a/jax/_src/state/indexing.py
+++ b/jax/_src/state/indexing.py
@@ -65,9 +65,9 @@ def tree_flatten(self):

   @classmethod
   def tree_unflatten(cls, aux_data, children) -> Slice:
-    start, size = [
+    start, size = (
        a if a is not None else b for a, b in zip(children, aux_data[:2])
-    ]
+    )
     return cls(start, size, aux_data[2])

   @classmethod
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index aca5ce6e67db..edd769aff5c6 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -16,7 +16,7 @@
 from __future__ import annotations

 import collections
-from collections.abc import Generator, Iterable, Sequence
+from collections.abc import Callable, Generator, Iterable, Sequence
 from contextlib import ExitStack, contextmanager
 import datetime
 import functools
@@ -28,7 +28,7 @@
 import sys
 import tempfile
 import textwrap
-from typing import Any, Callable
+from typing import Any
 import unittest
 import warnings
 import zlib
diff --git a/jax/_src/third_party/scipy/linalg.py b/jax/_src/third_party/scipy/linalg.py
index 9eb6c3caba5a..dce4df1fb817 100644
--- a/jax/_src/third_party/scipy/linalg.py
+++ b/jax/_src/third_party/scipy/linalg.py
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Callable
+from collections.abc import Callable

 from jax import jit, lax
 import jax.numpy as jnp
diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py
index c53042bce3f5..14721cea7682 100644
--- a/jax/_src/tpu_custom_call.py
+++ b/jax/_src/tpu_custom_call.py
@@ -19,13 +19,13 @@
 import base64
 import collections.abc
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 import functools
 import io
 import os
 import time
-from typing import Any, Callable
+from typing import Any

 import jax
 from jax import core
diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py
index 0a6ef8da4263..d66cbb912a99 100644
--- a/jax/_src/traceback_util.py
+++ b/jax/_src/traceback_util.py
@@ -14,12 +14,13 @@
 from __future__ import annotations

+from collections.abc import Callable
 import functools
 import os
 import sys
 import traceback
 import types
-from typing import Any, Callable, TypeVar, cast
+from typing import Any, TypeVar, cast

 from jax._src import config
 from jax._src import util
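The `Slice.tree_unflatten` hunk above swaps a list comprehension for a generator expression: tuple unpacking accepts any iterable, so the temporary list was never needed. A sketch with made-up `children`/`aux_data` values:

```python
# Sketch with made-up values: unpacking works the same from a generator.
children = (None, 4)   # dynamic values; None means "use the static one"
aux_data = (0, None)   # static values recorded at flatten time

start, size = (a if a is not None else b for a, b in zip(children, aux_data[:2]))
assert (start, size) == (0, 4)
```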
diff --git a/jax/_src/tree.py b/jax/_src/tree.py
index 1c3fae44a2b8..49faaa774ef2 100644
--- a/jax/_src/tree.py
+++ b/jax/_src/tree.py
@@ -13,7 +13,8 @@
 # limitations under the License.

 from __future__ import annotations

-from typing import Any, Callable, Iterable, TypeVar, overload
+from collections.abc import Callable, Iterable
+from typing import Any, TypeVar, overload

 from jax._src import tree_util
diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py
index 925d7ce2ec30..32f59b1df36e 100644
--- a/jax/_src/tree_util.py
+++ b/jax/_src/tree_util.py
@@ -14,14 +14,14 @@
 from __future__ import annotations

 import collections
-from collections.abc import Hashable, Iterable
+from collections.abc import Callable, Hashable, Iterable, Sequence
 from dataclasses import dataclass
 import difflib
 import functools
 from functools import partial
 import operator as op
 import textwrap
-from typing import Any, Callable, NamedTuple, Sequence, TypeVar, Union, overload
+from typing import Any, NamedTuple, TypeVar, Union, overload

 from jax._src import traceback_util
 from jax._src.lib import pytree
diff --git a/jax/_src/util.py b/jax/_src/util.py
index 3bdcd298e7a7..7aab80b2def3 100644
--- a/jax/_src/util.py
+++ b/jax/_src/util.py
@@ -15,14 +15,14 @@
 from __future__ import annotations

 import abc
-from collections.abc import Iterable, Iterator, Sequence
+from collections.abc import Callable, Iterable, Iterator, Sequence
 import dataclasses
 import functools
 from functools import partial
 import itertools as it
 import logging
 import operator
-from typing import (Any, Callable, Generic, TypeVar, overload, TYPE_CHECKING, cast)
+from typing import (Any, Generic, TypeVar, overload, TYPE_CHECKING, cast)
 import weakref

 import numpy as np
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 87925c142e50..41fd2c586593 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -21,7 +21,7 @@
 from __future__ import annotations

 import atexit
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
 import dataclasses
 from functools import lru_cache, partial
 import importlib
@@ -32,7 +32,7 @@
 import platform as py_platform
 import threading
 import traceback
-from typing import Any, Callable, Union
+from typing import Any, Union
 import warnings

 from jax._src import config
diff --git a/jax/example_libraries/optimizers.py b/jax/example_libraries/optimizers.py
index 8bf57792ffd8..71680ca61b96 100644
--- a/jax/example_libraries/optimizers.py
+++ b/jax/example_libraries/optimizers.py
@@ -91,7 +91,8 @@ def step(step, opt_state):

 from __future__ import annotations

-from typing import Any, Callable, NamedTuple
+from collections.abc import Callable
+from typing import Any, NamedTuple
 from collections import namedtuple

 import functools
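The `Tuple`/`List`/`Dict` removals throughout rely on PEP 585: since Python 3.9 the builtin containers subscribe directly. A hedged sketch of the pattern (`kinds_table` is a hypothetical helper, not from the patch):

```python
from __future__ import annotations  # keeps the annotations unevaluated on 3.9

# Hypothetical helper, for illustration only.
def kinds_table(kind: str | tuple[str, ...] | None = None) -> dict[str, list[str]]:
    # tuple[str, ...] means "any-length tuple of str",
    # exactly like the old typing.Tuple[str, ...].
    if kind is None:
        return {}
    kinds = (kind,) if isinstance(kind, str) else kind
    return {k: [] for k in kinds}

assert kinds_table(("bool", "integral")) == {"bool": [], "integral": []}
```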
diff --git a/jax/experimental/array_api/_utility_functions.py b/jax/experimental/array_api/_utility_functions.py
index c5dac25fd8c6..f75b2e2e29af 100644
--- a/jax/experimental/array_api/_utility_functions.py
+++ b/jax/experimental/array_api/_utility_functions.py
@@ -15,7 +15,6 @@
 from __future__ import annotations

 import jax
-from typing import Tuple
 from jax._src.sharding import Sharding
 from jax._src.lib import xla_client as xc
 from jax._src import dtypes as _dtypes, config
@@ -71,7 +70,7 @@ def default_dtypes(self, *, device: xc.Device | Sharding | None = None):
   def dtypes(
       self, *, device: xc.Device | Sharding | None = None,
-      kind: str | Tuple[str, ...] | None = None):
+      kind: str | tuple[str, ...] | None = None):
     # Array API supported dtypes are device-independent in JAX
     del device
     data_types = self._build_dtype_dict()
diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py
index f9e8e2fafe4f..c7aa8b590412 100644
--- a/jax/experimental/array_serialization/serialization.py
+++ b/jax/experimental/array_serialization/serialization.py
@@ -17,16 +17,15 @@
 import abc
 import asyncio
-from collections.abc import Awaitable, Sequence
+from collections.abc import Awaitable, Callable, Sequence
 from functools import partial
 import itertools
 import logging
 import os
 import re
-import sys
 import threading
 import time
-from typing import Any, Callable, Optional, Union
+from typing import Any

 import jax
 from jax._src import array
@@ -130,7 +129,7 @@ def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
   return spec


-def is_remote_storage(tspec: Union[dict[str, Any], str]) -> bool:
+def is_remote_storage(tspec: dict[str, Any] | str) -> bool:
   """Detect if user is using cloud storages.

   This can detect common defines and unable to detect some corner cases such as
@@ -190,7 +189,7 @@ async def async_serialize(
     tensorstore_spec,
     commit_future=None,
     context=TS_CONTEXT,
-    primary_host: Optional[int] = 0,
+    primary_host: int | None = 0,
     replica_id: int = 0,
 ):
   """Serialize an array using TensorStore.
diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py
index a38261d57086..deaac9c72c8b 100644
--- a/jax/experimental/host_callback.py
+++ b/jax/experimental/host_callback.py
@@ -503,14 +503,14 @@ def power3_with_cotangents(x):

 import atexit
 import enum
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import functools
 import itertools
 import logging
 import math
 import threading
 import traceback
-from typing import Any, Callable, cast
+from typing import Any, cast

 import jax
 from jax._src import api
diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py
index 0ce5bcb170f2..adf43b6b94c0 100644
--- a/jax/experimental/jax2tf/call_tf.py
+++ b/jax/experimental/jax2tf/call_tf.py
@@ -25,10 +25,10 @@
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 import functools
-from typing import Any, Callable, Optional
+from typing import Any

 from absl import logging
 import jax
diff --git a/jax/experimental/jax2tf/examples/mnist_lib.py b/jax/experimental/jax2tf/examples/mnist_lib.py
index 1dae2752ffd3..41173c79a5b9 100644
--- a/jax/experimental/jax2tf/examples/mnist_lib.py
+++ b/jax/experimental/jax2tf/examples/mnist_lib.py
@@ -21,12 +21,12 @@
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import functools
 import logging
 import re
 import time
-from typing import Any, Callable, Optional
+from typing import Any

 from absl import flags
 import flax
diff --git a/jax/experimental/jax2tf/examples/saved_model_lib.py b/jax/experimental/jax2tf/examples/saved_model_lib.py
index f0fa145728fe..8f2f0982fd3d 100644
--- a/jax/experimental/jax2tf/examples/saved_model_lib.py
+++ b/jax/experimental/jax2tf/examples/saved_model_lib.py
@@ -26,8 +26,8 @@
 from __future__ import annotations

-from collections.abc import Sequence
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable, Sequence
+from typing import Any

 from jax.experimental import jax2tf
 import tensorflow as tf
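Why nearly every touched module already carries `from __future__ import annotations`: under PEP 563 annotations are stored as plain strings and never evaluated at definition time, which is what lets `X | None` and `list[int]` appear in files that must still import on Python 3.9. A demonstration (illustrative, not from the patch):

```python
from __future__ import annotations  # PEP 563: annotations become strings

def g(x: int | None) -> list[int]:
    return [x or 0]

# The new syntax is never evaluated, so this also runs on 3.9:
print(g.__annotations__)  # {'x': 'int | None', 'return': 'list[int]'}
```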
diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py
index 3305654f2243..5ecde602cdaa 100644
--- a/jax/experimental/jax2tf/impl_no_xla.py
+++ b/jax/experimental/jax2tf/impl_no_xla.py
@@ -16,12 +16,12 @@
 from __future__ import annotations

 import builtins
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 from functools import partial, wraps
 import math
 import string
-from typing import Any, Callable, Optional
+from typing import Any

 from jax._src import core
 from jax import lax
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index b53ffda9c5d1..5f3230599a25 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -15,7 +15,7 @@
 from __future__ import annotations

-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 from functools import partial
 import contextlib
 import math
@@ -23,7 +23,7 @@
 import os
 import re
 import threading
-from typing import Any, Callable, Union
+from typing import Any, Union
 import warnings

 from absl import logging
diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py
index 56091b2c7eae..7f903b70d987 100644
--- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py
+++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py
@@ -20,11 +20,10 @@
 from __future__ import annotations

 import base64
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import io
 import os
 import tarfile
-from typing import Callable, Optional

 from absl.testing import absltest
 import jax
diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py
index cd66cb9ff41b..5740b76038d8 100644
--- a/jax/experimental/jax2tf/tests/call_tf_test.py
+++ b/jax/experimental/jax2tf/tests/call_tf_test.py
@@ -13,10 +13,10 @@
 # limitations under the License.
 """Tests for call_tf."""

+from collections.abc import Callable
 import contextlib
 from functools import partial
 import os
-from typing import Callable
 import unittest

 from absl import logging
diff --git a/jax/experimental/jax2tf/tests/converters.py b/jax/experimental/jax2tf/tests/converters.py
index f0a293ca52d5..1ed017fc0819 100644
--- a/jax/experimental/jax2tf/tests/converters.py
+++ b/jax/experimental/jax2tf/tests/converters.py
@@ -12,10 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
"""Converters for jax2tf.""" + +from collections.abc import Callable import dataclasses import functools import tempfile -from typing import Any, Callable +from typing import Any + from jax.experimental import jax2tf import tensorflow as tf import tensorflowjs as tfjs diff --git a/jax/experimental/jax2tf/tests/cross_compilation_check.py b/jax/experimental/jax2tf/tests/cross_compilation_check.py index 0a4bf61f8847..cc34d78e88d4 100644 --- a/jax/experimental/jax2tf/tests/cross_compilation_check.py +++ b/jax/experimental/jax2tf/tests/cross_compilation_check.py @@ -26,12 +26,11 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import contextlib import dataclasses import os import re -from typing import Callable, Optional import zlib from absl import app diff --git a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py index 9bb7466125c5..5b1169224ed9 100644 --- a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py +++ b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py @@ -18,8 +18,9 @@ from __future__ import annotations +from collections.abc import Callable import functools -from typing import Any, Callable, Optional +from typing import Any from flax import linen as nn import jax diff --git a/jax/experimental/jax2tf/tests/flax_models/gnn.py b/jax/experimental/jax2tf/tests/flax_models/gnn.py index 6746da7a2700..4a74be446ba1 100644 --- a/jax/experimental/jax2tf/tests/flax_models/gnn.py +++ b/jax/experimental/jax2tf/tests/flax_models/gnn.py @@ -16,8 +16,7 @@ https://github.com/google/flax/tree/main/examples/ogbg_molpcba """ -from collections.abc import Sequence -from typing import Callable +from collections.abc import Callable, Sequence from flax import linen as nn diff --git a/jax/experimental/jax2tf/tests/flax_models/resnet.py b/jax/experimental/jax2tf/tests/flax_models/resnet.py index bb6e519deceb..48829127b304 100644 --- a/jax/experimental/jax2tf/tests/flax_models/resnet.py +++ b/jax/experimental/jax2tf/tests/flax_models/resnet.py @@ -19,9 +19,9 @@ # See issue #620. 
 # pytype: disable=wrong-arg-count

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial
-from typing import Any, Callable
+from typing import Any

 from flax import linen as nn
 import jax.numpy as jnp
diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py b/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py
index 334248219962..27535c784e89 100644
--- a/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py
+++ b/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py
@@ -24,7 +24,8 @@
 from __future__ import annotations

-from typing import Callable, Any, Optional
+from collections.abc import Callable
+from typing import Any

 from flax import linen as nn
 from flax import struct
diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py b/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py
index 04111e6c4d5b..cc78b5a41496 100644
--- a/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py
+++ b/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py
@@ -18,7 +18,8 @@
 from __future__ import annotations

-from typing import Callable, Any, Optional
+from collections.abc import Callable
+from typing import Any

 from flax import linen as nn
 from flax import struct
diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py b/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py
index 58e50dacd914..1cdeffeb6ea9 100644
--- a/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py
+++ b/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py
@@ -24,7 +24,8 @@
 from __future__ import annotations

-from typing import Callable, Any, Optional
+from collections.abc import Callable
+from typing import Any

 from flax import linen as nn
 from flax import struct
diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py
index d184cb4e5dcc..03e6086a4924 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -15,9 +15,9 @@
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import itertools
-from typing import Any, Callable, Optional, Union
+from typing import Any

 import jax
 from jax import lax
diff --git a/jax/experimental/jax2tf/tests/model_harness.py b/jax/experimental/jax2tf/tests/model_harness.py
index 9af7229c0530..91aacf2f596f 100644
--- a/jax/experimental/jax2tf/tests/model_harness.py
+++ b/jax/experimental/jax2tf/tests/model_harness.py
@@ -15,10 +15,10 @@
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 import functools
-from typing import Any, Callable, Optional, Union
+from typing import Any
 import re

 import numpy as np
diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py
index e1d2b1fcefd0..83aac43f2d9d 100644
--- a/jax/experimental/jax2tf/tests/shape_poly_test.py
+++ b/jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -15,10 +15,10 @@
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import contextlib
 import math
-from typing import Any, Callable
+from typing import Any
 import unittest

 from absl import logging
diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py
index df028a700f6f..32f89e533daf 100644
--- a/jax/experimental/jax2tf/tests/tf_test_util.py
+++ b/jax/experimental/jax2tf/tests/tf_test_util.py
@@ -14,12 +14,12 @@
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import contextlib
 import dataclasses
 import re
 import os
-from typing import Any, Callable, Optional
+from typing import Any

 from absl.testing import absltest
 from absl import logging
diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py
index ac23debd6d27..1ed6183b1229 100644
--- a/jax/experimental/jet.py
+++ b/jax/experimental/jet.py
@@ -52,7 +52,8 @@
 `outstanding primitive rules `__.
 """

-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any

 from functools import partial
diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py
index 4250ba9b2677..b4989e151a53 100644
--- a/jax/experimental/key_reuse/_core.py
+++ b/jax/experimental/key_reuse/_core.py
@@ -15,8 +15,9 @@
 from __future__ import annotations

 from collections import defaultdict
+from collections.abc import Callable, Iterator
 from functools import partial, reduce, total_ordering, wraps
-from typing import Any, Callable, Iterator, NamedTuple
+from typing import Any, NamedTuple

 import jax
 from jax import lax
diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py
index 7157be1e2c9d..dd112db3b269 100644
--- a/jax/experimental/mesh_utils.py
+++ b/jax/experimental/mesh_utils.py
@@ -17,11 +17,11 @@
 from __future__ import annotations

 import collections
-from collections.abc import Sequence
+from collections.abc import Callable, Generator, MutableMapping, Sequence
 import itertools
 import logging
 import math
-from typing import Any, Callable, Generator, MutableMapping
+from typing import Any

 from jax._src import xla_bridge as xb
 import numpy as np
diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py
index 7df477ab6ffa..2cf9e9d4ff39 100644
--- a/jax/experimental/mosaic/gpu/__init__.py
+++ b/jax/experimental/mosaic/gpu/__init__.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================

-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 import contextlib
 import ctypes
 import dataclasses
@@ -24,7 +24,7 @@
 import subprocess
 import tempfile
 import time
-from typing import Any, Generic, Sequence, TypeVar
+from typing import Any, Generic, TypeVar

 import jax
 from jax._src import config
diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index a620eb1f1302..c9de1eb29985 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -14,12 +14,12 @@
 # ==============================================================================
 """Utilities for code generator."""

-from collections.abc import Iterator
+from collections.abc import Iterator, Sequence
 import contextlib
 import dataclasses
 import enum
 import functools
-from typing import Any, Literal, Sequence
+from typing import Any, Literal

 import jax
 from jaxlib.mlir import ir
diff --git a/jax/experimental/pallas/ops/tpu/megablox/gmm.py b/jax/experimental/pallas/ops/tpu/megablox/gmm.py
index e9072bbe1146..ba8ca6c1b617 100644
--- a/jax/experimental/pallas/ops/tpu/megablox/gmm.py
+++ b/jax/experimental/pallas/ops/tpu/megablox/gmm.py
@@ -14,8 +14,9 @@
 """Grouped matrix multiplication kernels for TPU written in Pallas."""

+from collections.abc import Callable
 import functools
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional

 import jax
 from jax import lax
@@ -315,9 +316,9 @@ def gmm(
     rhs: jnp.ndarray,
     group_sizes: jnp.ndarray,
    preferred_element_type: jnp.dtype = jnp.float32,
-    tiling: Optional[Union[tuple[int, int, int], LutFn]] = (128, 128, 128),
-    group_offset: Optional[jnp.ndarray] = None,
-    existing_out: Optional[jnp.ndarray] = None,
+    tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128),
+    group_offset: jnp.ndarray | None = None,
+    existing_out: jnp.ndarray | None = None,
     transpose_rhs: bool = False,
     interpret: bool = False,
 ) -> jnp.ndarray:
@@ -577,10 +578,10 @@ def tgmm(
     rhs: jnp.ndarray,
     group_sizes: jnp.ndarray,
     preferred_element_type: jnp.dtype = jnp.float32,
-    tiling: Optional[Union[tuple[int, int, int], LutFn]] = (128, 128, 128),
-    group_offset: Optional[jnp.ndarray] = None,
-    num_actual_groups: Optional[int] = None,
-    existing_out: Optional[jnp.ndarray] = None,
+    tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128),
+    group_offset: jnp.ndarray | None = None,
+    num_actual_groups: int | None = None,
+    existing_out: jnp.ndarray | None = None,
     interpret: bool = False,
 ) -> jnp.ndarray:
   """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :].
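A note on the `gmm`/`tgmm` signatures above: `| None` only widens the accepted type; the default stays the concrete `(128, 128, 128)` tiling. A hedged sketch of the dispatch this union enables. `LutFn` is defined elsewhere in gmm.py; the alias below is an assumed shape, not the real definition:

```python
from __future__ import annotations

from collections.abc import Callable

# Assumed shape of the kernel's lookup-table callable (not the real definition).
LutFn = Callable[[int, int, int], tuple[int, int, int] | None]

def resolve_tiling(  # hypothetical helper, for illustration only
    tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128),
    m: int = 512, k: int = 512, n: int = 512,
) -> tuple[int, int, int] | None:
    if callable(tiling):       # a LutFn: ask it for a problem-shaped tiling
        return tiling(m, k, n)
    return tiling              # a static tiling, or None

assert resolve_tiling() == (128, 128, 128)
```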
diff --git a/jax/experimental/pallas/ops/tpu/megablox/ops.py b/jax/experimental/pallas/ops/tpu/megablox/ops.py
index 874951db0452..015c6b3ade67 100644
--- a/jax/experimental/pallas/ops/tpu/megablox/ops.py
+++ b/jax/experimental/pallas/ops/tpu/megablox/ops.py
@@ -14,8 +14,6 @@
 """Grouped matrix multiplication operations with custom VJPs."""

-from typing import Optional
-
 import jax
 from jax.experimental.pallas.ops.tpu.megablox import gmm as backend
 import jax.numpy as jnp
@@ -33,8 +31,8 @@ def _gmm_fwd(
     group_sizes: jnp.ndarray,
     preferred_element_type: jnp.dtype = jnp.float32,
     tiling: tuple[int, int, int] = (128, 128, 128),
-    group_offset: Optional[jnp.ndarray] = None,
-    existing_out: Optional[jnp.ndarray] = None,
+    group_offset: jnp.ndarray | None = None,
+    existing_out: jnp.ndarray | None = None,
     transpose_rhs: bool = False,
     interpret: bool = False,
 ) -> tuple[
@@ -43,7 +41,7 @@
         jnp.ndarray,
         jnp.ndarray,
         jnp.ndarray,
-        Optional[jnp.ndarray],
+        jnp.ndarray | None,
         int,
     ],
 ]:
@@ -71,7 +69,7 @@ def _gmm_bwd(
         jnp.ndarray,
         jnp.ndarray,
         jnp.ndarray,
-        Optional[jnp.ndarray],
+        jnp.ndarray | None,
         int,
     ],
     grad: jnp.ndarray,
diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
index f4ea3f9a0a5e..7d47cc3d0efa 100644
--- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
+++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
@@ -15,7 +15,6 @@
 """PagedAttention TPU kernel."""

 import functools
-from typing import Optional, Union

 import jax
 from jax import lax
@@ -364,14 +363,14 @@ def body(i, _):
   )

 def paged_attention(
     q: jax.Array,
-    k_pages: Union[jax.Array, quantization_utils.QuantizedTensor],
-    v_pages: Union[jax.Array, quantization_utils.QuantizedTensor],
+    k_pages: jax.Array | quantization_utils.QuantizedTensor,
+    v_pages: jax.Array | quantization_utils.QuantizedTensor,
     lengths: jax.Array,
     page_indices: jax.Array,
     *,
     mask_value: float = DEFAULT_MASK_VALUE,
     pages_per_compute_block: int,
-    megacore_mode: Optional[str] = None,
+    megacore_mode: str | None = None,
     inline_seq_dim: bool = True,
 ) -> jax.Array:
   """Paged grouped query attention.
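Aside (not from the patch): PEP 604 unions are also accepted by `isinstance` on Python 3.10+, which pairs naturally with parameters typed like `jax.Array | quantization_utils.QuantizedTensor` above. A toy illustration with a stand-in class:

```python
# Toy illustration; QuantizedPages is a stand-in, not the real class.
class QuantizedPages:
    pass

def needs_dequantize(pages: object) -> bool:
    return isinstance(pages, QuantizedPages)

assert isinstance(3.0, float | int)        # runtime union, Python 3.10+
assert needs_dequantize(QuantizedPages())
```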
diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py
index 75326c65c818..a6c0715e6043 100644
--- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py
+++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py
@@ -16,11 +16,11 @@
 from __future__ import annotations

-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
 import dataclasses
 import enum
 import functools
-from typing import Any, Callable, Literal, NamedTuple, Optional, Union, overload
+from typing import Any, Literal, NamedTuple, Optional, Union, overload

 import jax
 from jax import ad_checkpoint
diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py
index e65d9b073a18..eab2a695dc02 100644
--- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py
+++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py
@@ -16,8 +16,9 @@
 from __future__ import annotations

+from collections.abc import Callable, Sequence
 import dataclasses
-from typing import Any, Callable, Sequence, Tuple
+from typing import Any

 import numpy as np

 # mypy: ignore-errors
@@ -26,7 +27,7 @@ class Mask:
   """A base class for splash attention masks."""

   @property
-  def shape(self) -> Tuple[int, ...]:
+  def shape(self) -> tuple[int, ...]:
     raise NotImplementedError

   def __getitem__(self, idx) -> np.ndarray:
@@ -38,14 +39,14 @@ def __bool__(self) -> bool:
         ' instead of bitwise operations on masks.'
     )

-  def __or__(self, other: 'Mask') -> 'Mask':
+  def __or__(self, other: Mask) -> Mask:
     if self.shape != other.shape:
       raise ValueError(
           f'Invalid shape for other: {other.shape}, expected: {self.shape}'
       )
     return LogicalOr(self, other)

-  def __and__(self, other: 'Mask') -> 'Mask':
+  def __and__(self, other: Mask) -> Mask:
     if self.shape != other.shape:
       raise ValueError(
           f'Invalid shape for other: {other.shape}, expected: {self.shape}'
@@ -53,7 +54,7 @@ def __and__(self, other: 'Mask') -> 'Mask':
       )
     return LogicalAnd(self, other)


-def make_causal_mask(shape: Tuple[int, int], offset: int = 0) -> np.ndarray:
+def make_causal_mask(shape: tuple[int, int], offset: int = 0) -> np.ndarray:
   """Makes a causal attention mask.
   Args:
@@ -73,8 +74,8 @@ def make_causal_mask(shape: Tuple[int, int], offset: int = 0) -> np.ndarray:

 def make_local_attention_mask(
-    shape: Tuple[int, int],
-    window_size: Tuple[int | None, int | None],
+    shape: tuple[int, int],
+    window_size: tuple[int | None, int | None],
     *,
     offset: int = 0,
 ) -> np.ndarray:
@@ -92,7 +93,7 @@ def make_local_attention_mask(

 def make_random_mask(
-    shape: Tuple[int, int], sparsity: float, seed: int
+    shape: tuple[int, int], sparsity: float, seed: int
 ) -> np.ndarray:
   """Makes a random attention mask."""
   np.random.seed(seed)
@@ -111,7 +112,7 @@ def __init__(self, left: Mask, right: Mask):
     self.right = right

   @property
-  def shape(self) -> Tuple[int, ...]:
+  def shape(self) -> tuple[int, ...]:
     return self.left.shape

   def __getitem__(self, idx) -> np.ndarray:
@@ -133,7 +134,7 @@ def __init__(self, left: Mask, right: Mask):
     self.right = right

   @property
-  def shape(self) -> Tuple[int, ...]:
+  def shape(self) -> tuple[int, ...]:
     return self.left.shape

   def __getitem__(self, idx) -> np.ndarray:
@@ -167,7 +168,7 @@ def __post_init__(self):
       raise ValueError('Nesting MultiHeadMasks is not supported')

   @property
-  def shape(self) -> Tuple[int, ...]:
+  def shape(self) -> tuple[int, ...]:
     return (len(self.masks),) + self.masks[0].shape

   def __getitem__(self, idx) -> np.ndarray:
@@ -208,13 +209,13 @@ class _ComputableMask(Mask):
   mask rather than loading it.
   """

-  _shape: Tuple[int, int]
+  _shape: tuple[int, int]
   q_sequence: np.ndarray
   mask_function: Callable[..., Any]

   def __init__(
       self,
-      shape: Tuple[int, int],
+      shape: tuple[int, int],
       mask_function: Callable[..., Any],
       shard_count: int = 1,
   ):
@@ -231,7 +232,7 @@ def __init__(
     self.q_sequence = np.arange(q_seq_len, dtype=np.int32)

   @property
-  def shape(self) -> Tuple[int, ...]:
+  def shape(self) -> tuple[int, ...]:
     return self._shape

   def __getitem__(self, idx) -> np.ndarray:
@@ -271,7 +272,7 @@ class CausalMask(_ComputableMask):

   def __init__(
       self,
-      shape: Tuple[int, int],
+      shape: tuple[int, int],
       offset: int = 0,
       shard_count: int = 1,
   ):
@@ -329,15 +330,15 @@ class LocalMask(Mask):

   # TODO(amagni): Transform LocalMask into a _ComputableMask.
-  _shape: Tuple[int, int]
-  window_size: Tuple[int | None, int | None]
+  _shape: tuple[int, int]
+  window_size: tuple[int | None, int | None]
   offset: int
   _q_sequence: np.ndarray | None = None

   def __init__(
       self,
-      shape: Tuple[int, int],
-      window_size: Tuple[int | None, int | None],
+      shape: tuple[int, int],
+      window_size: tuple[int | None, int | None],
       offset: int,
       shard_count: int = 1,
   ):
@@ -352,7 +353,7 @@ def __init__(
     )

   @property
-  def shape(self) -> Tuple[int, int]:
+  def shape(self) -> tuple[int, int]:
     return self._shape

   def __getitem__(self, idx) -> np.ndarray:
@@ -429,7 +430,7 @@ def __post_init__(self):
       raise ValueError('Mask must be a boolean array')

   @property
-  def shape(self) -> Tuple[int, ...]:
+  def shape(self) -> tuple[int, ...]:
     return self.array.shape

   def __getitem__(self, idx) -> np.ndarray:
@@ -467,7 +468,7 @@ def __post_init__(self):
       raise ValueError(f'Unsupported shape type: {type(self.shape)}')

   @property
-  def shape(self) -> Tuple[int, ...]:
+  def shape(self) -> tuple[int, ...]:
     return self._shape

   def __getitem__(self, idx) -> np.ndarray:
diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py
index af046688067f..3c672b8dbe88 100644
--- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py
+++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py
@@ -16,8 +16,9 @@
 from __future__ import annotations

 import collections
+from collections.abc import Callable
 import functools
-from typing import Callable, Dict, List, NamedTuple, Set, Tuple
+from typing import Dict, List, NamedTuple, Set, Tuple
 from jax import util as jax_util
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
 import numpy as np
@@ -161,11 +162,11 @@ def __eq__(self, other: object) -> bool:

 def _get_mask_info_for_shard(
-    output_shape: Tuple[int, int, int],
+    output_shape: tuple[int, int, int],
     has_mask_next: bool,
     mask: mask_lib.MultiHeadMask,
-    block_shape: Tuple[int, int],
-    coords_to_partial_mask_block_index: Dict[Tuple[int, int, int], int],
+    block_shape: tuple[int, int],
+    coords_to_partial_mask_block_index: dict[tuple[int, int, int], int],
     masks_per_head_shard: int,
     head_start: int,
     num_heads: int,
@@ -173,7 +174,7 @@ def _get_mask_info_for_shard(
     q_seq_shard_size: int,
     blocked_q_seq_start: int,
     is_dkv: bool,
-) -> Tuple[np.ndarray, np.ndarray | None]:
+) -> tuple[np.ndarray, np.ndarray | None]:
   """Process a slice of the mask to compute data_next and mask_next.

   Args:
@@ -310,7 +311,7 @@ def _get_mask_info_for_shard(
 @functools.lru_cache(maxsize=12)
 def _process_mask(
     mask: mask_lib.MultiHeadMask,  # [num_heads, q_seq_len, kv_seq_len]
-    block_shape: Tuple[int, int],
+    block_shape: tuple[int, int],
     is_dkv: bool,
     *,
     downcast_smem_data: bool = True,
@@ -394,18 +395,18 @@ def assign_unique_ids(objects):
     id_map = collections.defaultdict(lambda: len(id_map))
     return {obj: id_map[obj] for obj in objects}

-  unique_masks_dict: Dict[mask_lib.Mask, int] = assign_unique_ids(
+  unique_masks_dict: dict[mask_lib.Mask, int] = assign_unique_ids(
       head_mask for head_mask in mask.masks
   )

   # Build a mapping of heads to unique masks and masks to unique masks.
-  head_to_mask_id: List[int] = [0] * head_count
-  head_shard_to_mask_ids: List[Set[int]] = [set() for _ in range(head_shards)]
-  mask_id_to_heads: List[List[int]] = [
+  head_to_mask_id: list[int] = [0] * head_count
+  head_shard_to_mask_ids: list[set[int]] = [set() for _ in range(head_shards)]
+  mask_id_to_heads: list[list[int]] = [
       [] for _ in range(len(unique_masks_dict))
   ]
-  mask_id_to_head_shards: List[Set[int]] = [
+  mask_id_to_head_shards: list[set[int]] = [
       set() for _ in range(len(unique_masks_dict))
   ]
@@ -436,10 +437,10 @@ def assign_unique_ids(objects):
   # TODO(amagni): checking the validity of the masks is slow for large masks.
   # Disable it for now, reevalute in the future.

-  partial_mask_block_ids: Dict[_HashableNDArray, int] = collections.defaultdict(
+  partial_mask_block_ids: dict[_HashableNDArray, int] = collections.defaultdict(
       lambda: len(partial_mask_block_ids)
   )
-  block_id_to_block_coords: Dict[int, List[Tuple[int, ...]]] = (
+  block_id_to_block_coords: dict[int, list[tuple[int, ...]]] = (
       collections.defaultdict(list)
   )
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index 23a80471279f..4957df4866f0 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -13,14 +13,14 @@
 # limitations under the License.
 from __future__ import annotations

-from collections.abc import Hashable, Sequence
+from collections.abc import Callable, Hashable, Sequence
 import enum
 from functools import partial
 import inspect
 import itertools as it
 from math import prod
 import operator as op
-from typing import Any, Callable, TypeVar, Union
+from typing import Any, TypeVar, Union

 import numpy as np
diff --git a/jax/experimental/slab/djax.py b/jax/experimental/slab/djax.py
index cffe4d07f3f2..c989ede8663e 100644
--- a/jax/experimental/slab/djax.py
+++ b/jax/experimental/slab/djax.py
@@ -14,9 +14,9 @@
 from __future__ import annotations

-from functools import partial
-from typing import Callable
 import collections
+from collections.abc import Callable
+from functools import partial
 import sys

 import numpy as np
diff --git a/jax/experimental/slab/slab.py b/jax/experimental/slab/slab.py
index 5300223787c5..af7b079eeb7f 100644
--- a/jax/experimental/slab/slab.py
+++ b/jax/experimental/slab/slab.py
@@ -14,10 +14,11 @@
 from __future__ import annotations

+from collections.abc import Iterable, Sequence
 from functools import partial, reduce
-from typing import Iterable, NamedTuple, Sequence, Union
 import sys
 import typing
+from typing import NamedTuple, Union

 import numpy as np
diff --git a/jax/experimental/sparse/ad.py b/jax/experimental/sparse/ad.py
index 489a3f748f03..2c235c9320d5 100644
--- a/jax/experimental/sparse/ad.py
+++ b/jax/experimental/sparse/ad.py
@@ -14,9 +14,9 @@
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import itertools
-from typing import Any, Callable, Union
+from typing import Any

 import jax
 from jax._src import core
diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py
index c2990d3fed57..b0ac1fa5d380 100644
--- a/jax/experimental/sparse/linalg.py
+++ b/jax/experimental/sparse/linalg.py
@@ -16,7 +16,7 @@
 from __future__ import annotations

-from typing import Union, Callable
+from collections.abc import Callable
 import functools

 import jax
diff --git a/jax/experimental/sparse/test_util.py b/jax/experimental/sparse/test_util.py
index 38e6785f2ded..365c436521b8 100644
--- a/jax/experimental/sparse/test_util.py
+++ b/jax/experimental/sparse/test_util.py
@@ -15,12 +15,11 @@
 from __future__ import annotations

-from collections.abc import Iterable, Iterator, Sequence
+from collections.abc import Callable, Iterable, Iterator, Sequence
 import functools
 import itertools
 import math
-from typing import Any, Callable, Union
-from typing import NamedTuple
+from typing import Any, NamedTuple

 import jax
 from jax import lax
diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py
index 453083c57f2e..86eb8a9aefe8 100644
--- a/jax/experimental/sparse/transform.py
+++ b/jax/experimental/sparse/transform.py
@@ -47,9 +47,9 @@
   -0.15574613], dtype=float32)
 """

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import functools
-from typing import Any, Callable, NamedTuple, Optional
+from typing import Any, NamedTuple

 import numpy as np
diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py
index 4ec995172585..0d57a04f1aa7 100644
--- a/jaxlib/hlo_helpers.py
+++ b/jaxlib/hlo_helpers.py
@@ -16,9 +16,9 @@
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial
-from typing import Callable, Union
+from typing import Union

 import jaxlib.mlir.ir as ir
 import jaxlib.mlir.dialects.stablehlo as hlo
diff --git a/tests/api_test.py b/tests/api_test.py
index fb1d6f4cd0c8..3929a29f9c30 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -16,6 +16,7 @@
 import collections
 import collections.abc
+from collections.abc import Callable
 import concurrent.futures
 from contextlib import contextmanager
 import copy
@@ -33,7 +34,7 @@
 import subprocess
 import sys
 import types
-from typing import Callable, NamedTuple
+from typing import NamedTuple
 import unittest
 import weakref
diff --git a/tests/batching_test.py b/tests/batching_test.py
index 36e686443ac7..4d912bfca206 100644
--- a/tests/batching_test.py
+++ b/tests/batching_test.py
@@ -14,10 +14,11 @@
 from __future__ import annotations

+from collections.abc import Callable
 from contextlib import contextmanager
 from functools import partial
 import itertools as it
-from typing import Any, Callable, TypeVar, Union
+from typing import Any, TypeVar, Union

 import numpy as np
 from absl.testing import absltest
diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py
index 5d737632a72b..035905d3f9e5 100644
--- a/tests/export_harnesses_multi_platform_test.py
+++ b/tests/export_harnesses_multi_platform_test.py
@@ -21,9 +21,9 @@
 from __future__ import annotations

+from collections.abc import Callable
 import math
 import re
-from typing import Callable

 from absl import logging
 from absl.testing import absltest
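One more pattern worth calling out before the remaining test files: the earlier `splash_attention_mask.py` hunks drop the string quotes in `def __or__(self, other: 'Mask') -> 'Mask':`. That is safe because the module has `from __future__ import annotations`, so a class can name itself in its own method signatures before the class statement finishes executing. Minimal reproduction:

```python
from __future__ import annotations  # makes the self-reference lazy

class Mask:
    def __or__(self, other: Mask) -> Mask:  # no 'Mask' quoting needed
        return other

assert isinstance(Mask() | Mask(), Mask)
```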
diff --git a/tests/export_test.py b/tests/export_test.py
index 7875f82b099b..8ccf3bb35849 100644
--- a/tests/export_test.py
+++ b/tests/export_test.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import contextlib
 import dataclasses
 import functools
@@ -21,7 +21,6 @@
 import math
 import re
 import unittest
-from typing import Callable

 from absl.testing import absltest
 import jax
diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py
index e6ac2480d3e4..9cf7ba80dff5 100644
--- a/tests/fused_attention_stablehlo_test.py
+++ b/tests/fused_attention_stablehlo_test.py
@@ -14,7 +14,6 @@
 from functools import partial
 from absl.testing import absltest
-from typing import Optional

 import os
 os.environ["XLA_FLAGS"] = \
@@ -43,8 +42,8 @@ def sdpa_train(query: Array,
               key: Array,
               value: Array,
               grad: Array,
-              bias: Optional[Array] = None,
-              mask: Optional[Array] = None,
+              bias: Array | None = None,
+              mask: Array | None = None,
               scale: float = 0.5,
               mask_type: MaskType = MaskType.NO_MASK,
               is_bnth: bool = False,
@@ -74,8 +73,8 @@ def sdpa_train(query: Array,

 def sdpa_ref(query: Array,
              key: Array,
              value: Array,
-             bias: Optional[Array] = None,
-             mask: Optional[Array] = None,
+             bias: Array | None = None,
+             mask: Array | None = None,
              scale: float = 0.5,
              mask_type: MaskType = MaskType.NO_MASK,
              dropout_rate: float = 0.1) -> Array:
@@ -150,8 +149,8 @@ def sdpa_train_ref(query: Array,
                   key: Array,
                   value: Array,
                   grad: Array,
-                  bias: Optional[Array] = None,
-                  mask: Optional[Array] = None,
+                  bias: Array | None = None,
+                  mask: Array | None = None,
                   scale: float = 0.5,
                   mask_type: MaskType = MaskType.NO_MASK,
                   dropout_rate: float = 0.1) -> Array:
diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py
index 27b0b6aed9d6..1ad59103c3e7 100644
--- a/tests/host_callback_test.py
+++ b/tests/host_callback_test.py
@@ -15,7 +15,7 @@
 from __future__ import annotations

 import contextlib
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial
 import itertools
 import logging
@@ -23,7 +23,6 @@
 import os
 import re
 import threading
 import time
-from typing import Callable
 import unittest

 from unittest import skip, SkipTest
diff --git a/tests/host_callback_to_tf_test.py b/tests/host_callback_to_tf_test.py
index 05a25bd33b50..fe80c90ace68 100644
--- a/tests/host_callback_to_tf_test.py
+++ b/tests/host_callback_to_tf_test.py
@@ -18,7 +18,7 @@

 This is separate from host_callback_test because it needs a TF dependency.
""" -from typing import Callable +from collections.abc import Callable import unittest from absl.testing import absltest diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 6835c644bed7..0f60a69a3702 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -16,7 +16,6 @@ from functools import partial import operator -from typing import Optional from absl.testing import absltest, parameterized import jax @@ -65,7 +64,7 @@ def mlir_sum(elems): return total -def copy(src: ir.Value, dst: ir.Value, swizzle: Optional[int] = None): +def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None): index = ir.IndexType.get() thread_id = gpu.thread_id(gpu.Dimension.x) stride = gpu.block_dim(gpu.Dimension.x) diff --git a/tests/pallas/gmm_test.py b/tests/pallas/gmm_test.py index f2a74a9275e4..be830a6a4473 100644 --- a/tests/pallas/gmm_test.py +++ b/tests/pallas/gmm_test.py @@ -14,7 +14,7 @@ import functools import itertools -from typing import Any, Union +from typing import Any from absl.testing import absltest from absl.testing import parameterized @@ -114,7 +114,7 @@ def random_dense( shape: tuple[int, ...], key: jax.Array, dtype: jnp.dtype, - limit: Union[int, None] = None, + limit: int | None = None, ) -> jnp.ndarray: if limit is None: limit = 1 / np.prod(shape) diff --git a/tests/pallas/splash_attention_kernel_test.py b/tests/pallas/splash_attention_kernel_test.py index 8173c858ed54..e6132a1966a3 100644 --- a/tests/pallas/splash_attention_kernel_test.py +++ b/tests/pallas/splash_attention_kernel_test.py @@ -15,9 +15,10 @@ """Tests for splash_attention.""" from __future__ import annotations +from collections.abc import Callable import dataclasses import functools -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar import unittest from absl.testing import absltest @@ -360,7 +361,7 @@ def test_splash_attention(self, is_mqa, is_segmented, data): attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy()) masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) - mask = mask_lib.MultiHeadMask(tuple((m.get_mask() for m in masks))) + mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: @@ -421,7 +422,7 @@ def test_splash_attention_fwd( segment_ids = data.draw(segment_ids_strategy(q_seq_len)) attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy()) masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) - mask = mask_lib.MultiHeadMask(tuple((m.get_mask() for m in masks))) + mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: attn_ref = splash.make_masked_mqa_reference(mask) diff --git a/tests/pallas/splash_attention_mask_test.py b/tests/pallas/splash_attention_mask_test.py index a408872100c4..ce7d8fd09182 100644 --- a/tests/pallas/splash_attention_mask_test.py +++ b/tests/pallas/splash_attention_mask_test.py @@ -15,7 +15,6 @@ """Tests for splash_attention_masks.""" from __future__ import annotations -from typing import List from absl.testing import absltest from absl.testing import parameterized import jax @@ -733,7 +732,7 @@ def _expected_local_mask_next(self, mask_base_index: int): _expected_local_mask_next_dkv = _expected_local_mask_next - def _stack(self, arrays: List[np.ndarray]) -> np.ndarray: + def _stack(self, arrays: list[np.ndarray]) -> np.ndarray: return np.stack(arrays, axis=0) # For each test, check 
diff --git a/tests/pallas/splash_attention_mask_test.py b/tests/pallas/splash_attention_mask_test.py
index a408872100c4..ce7d8fd09182 100644
--- a/tests/pallas/splash_attention_mask_test.py
+++ b/tests/pallas/splash_attention_mask_test.py
@@ -15,7 +15,6 @@
 """Tests for splash_attention_masks."""
 from __future__ import annotations

-from typing import List
 from absl.testing import absltest
 from absl.testing import parameterized
 import jax
@@ -733,7 +732,7 @@ def _expected_local_mask_next(self, mask_base_index: int):

   _expected_local_mask_next_dkv = _expected_local_mask_next

-  def _stack(self, arrays: List[np.ndarray]) -> np.ndarray:
+  def _stack(self, arrays: list[np.ndarray]) -> np.ndarray:
     return np.stack(arrays, axis=0)

   # For each test, check both the lazy and the dense versions of the mask.
diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py
index a6d33a258318..f079d6753edd 100644
--- a/tests/shape_poly_test.py
+++ b/tests/shape_poly_test.py
@@ -16,13 +16,13 @@
 from __future__ import annotations

 import enum
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import cProfile
 import itertools
 import math
 import os
 from pstats import Stats
-from typing import Any, Callable
+from typing import Any
 import unittest

 from absl import logging
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index e38801afb5b8..ca9d813e2571 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -14,14 +14,14 @@
 from __future__ import annotations

-from collections.abc import Sequence, Iterable, Iterator, Generator
+from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
 import contextlib
 from functools import partial
 import itertools as it
 import math
 import operator as op
 from types import SimpleNamespace
-from typing import Any, NamedTuple, Callable, TypeVar
+from typing import Any, NamedTuple, TypeVar
 import unittest

 from absl.testing import absltest
diff --git a/tests/state_test.py b/tests/state_test.py
index 5573049e76f0..b6dbb490b794 100644
--- a/tests/state_test.py
+++ b/tests/state_test.py
@@ -14,10 +14,10 @@
 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial
 import itertools as it
-from typing import Any, Callable, NamedTuple, Union
+from typing import Any, NamedTuple, Union

 from absl.testing import absltest
 from absl.testing import parameterized