diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 423dbaf9ce58..67e04e99b9a7 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from functools import partial import numpy as np diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 4b52292f0e2d..75df304b8dda 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Sequence import functools from functools import partial diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index 90ae6c1413ec..7a9da13ae296 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations from collections.abc import Callable import types diff --git a/jax/_src/api.py b/jax/_src/api.py index 4a42693c2e8f..1bc1dfc51c8e 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -20,7 +20,6 @@ tree_util.py), which include nested tuples/lists/dicts, where the leaves are arrays. """ -from __future__ import annotations import collections from collections.abc import Callable, Generator, Hashable, Iterable, Sequence diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 16a29e699bbc..e9899f075f71 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Iterable, Sequence import inspect import operator diff --git a/jax/_src/array.py b/jax/_src/array.py index 6e3f0a76f512..d8410c98450d 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections import defaultdict from collections.abc import Callable, Sequence import enum diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index 5809b9649f26..fa194dea02a7 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -14,8 +14,6 @@ # Note that type annotations for this file are defined in basearray.pyi -from __future__ import annotations - import abc import numpy as np from typing import Any, Union diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 054804379043..3b6cc0b0351c 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for JAX callbacks.""" -from __future__ import annotations from collections.abc import Callable, Sequence import dataclasses diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 9bbaa1296c93..a10e9ef0200d 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations from collections.abc import Callable, Sequence import dataclasses diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index 4a12445856f6..935534ffa80c 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import logging import os import re diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py index 7d4fcfd43a3a..7620b31f1778 100644 --- a/jax/_src/clusters/cluster.py +++ b/jax/_src/clusters/cluster.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Sequence import logging from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm diff --git a/jax/_src/clusters/ompi_cluster.py b/jax/_src/clusters/ompi_cluster.py index 9249958e2885..8b590718fa0d 100644 --- a/jax/_src/clusters/ompi_cluster.py +++ b/jax/_src/clusters/ompi_cluster.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import os import re from jax._src import clusters diff --git a/jax/_src/clusters/slurm_cluster.py b/jax/_src/clusters/slurm_cluster.py index 933036a64eab..ce1e4b9d06a9 100644 --- a/jax/_src/clusters/slurm_cluster.py +++ b/jax/_src/clusters/slurm_cluster.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import os from jax._src import clusters diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 9c276151741d..06d407ccfd4b 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import logging import threading import warnings diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 438f1f9e5183..fda8df7713ee 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -14,8 +14,6 @@ # Interface to the compiler -from __future__ import annotations - from collections.abc import Sequence import logging import os diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py index 25b2be78d287..5cbde8e52472 100644 --- a/jax/_src/compute_on.py +++ b/jax/_src/compute_on.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations import threading from contextlib import contextmanager diff --git a/jax/_src/config.py b/jax/_src/config.py index d9e5ee0137d0..46a62c2de137 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Hashable, Iterator, Sequence import contextlib import functools diff --git a/jax/_src/core.py b/jax/_src/core.py index ecb801afed8d..7412e016c16c 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations from collections import Counter, defaultdict, deque, namedtuple from collections.abc import (Callable, Collection, Generator, Hashable, diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 4d41849b75d3..7a7fee416992 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable import functools import operator diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 46d9fab00455..33e624369fcd 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Sequence import dataclasses from functools import update_wrapper, reduce, partial diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index a4de1b8cc46c..410527b05b40 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable import functools from typing import Any diff --git a/jax/_src/debugger/cli_debugger.py b/jax/_src/debugger/cli_debugger.py index bf4b38765026..d268afb04faf 100644 --- a/jax/_src/debugger/cli_debugger.py +++ b/jax/_src/debugger/cli_debugger.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations import cmd import pprint diff --git a/jax/_src/debugger/colab_debugger.py b/jax/_src/debugger/colab_debugger.py index 57a5be4825d6..5ee8d769666c 100644 --- a/jax/_src/debugger/colab_debugger.py +++ b/jax/_src/debugger/colab_debugger.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for Colab-specific debugger.""" -from __future__ import annotations import html import inspect diff --git a/jax/_src/debugger/colab_lib.py b/jax/_src/debugger/colab_lib.py index 69b6adb4bbf3..855ba6daa2f8 100644 --- a/jax/_src/debugger/colab_lib.py +++ b/jax/_src/debugger/colab_lib.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for building interfaces in Colab.""" -from __future__ import annotations import abc import dataclasses diff --git a/jax/_src/debugger/core.py b/jax/_src/debugger/core.py index f6b0a81baf92..7de97f1a8cef 100644 --- a/jax/_src/debugger/core.py +++ b/jax/_src/debugger/core.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations from collections.abc import Hashable import dataclasses diff --git a/jax/_src/debugger/web_debugger.py b/jax/_src/debugger/web_debugger.py index 443bfa676715..c09a19a44101 100644 --- a/jax/_src/debugger/web_debugger.py +++ b/jax/_src/debugger/web_debugger.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations import atexit import functools diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 7d8b3a914b6d..56c201c47661 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -13,8 +13,6 @@ # limitations under the License. """Module for JAX debugging primitives and related functionality.""" -from __future__ import annotations - import importlib.util from collections.abc import Callable, Sequence import functools diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 9ae1f7a6c2a3..3ced84373f30 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -13,7 +13,6 @@ # limitations under the License. # Primitive dispatch and jit dispatch. -from __future__ import annotations import atexit from collections.abc import Callable, Iterator, Sequence diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index c1a8ec7fe948..2205f3143114 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import atexit from collections.abc import Sequence import logging diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index 386123ae61f0..fc64fe1a89f5 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from typing import Any from jax._src.api import device_put diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 1d50c5be74b6..f43ce9e74c9c 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -19,8 +19,6 @@ # b) the set of supported types (e.g., bfloat16), # so we need our own implementation that deviates from NumPy in places. -from __future__ import annotations - import abc import builtins import functools diff --git a/jax/_src/earray.py b/jax/_src/earray.py index f4b5e232bc33..da470dea2cbb 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import math from jax._src import api_util diff --git a/jax/_src/effects.py b/jax/_src/effects.py index 36528c5feae5..f0035ba222a3 100644 --- a/jax/_src/effects.py +++ b/jax/_src/effects.py @@ -50,8 +50,6 @@ https://jax.readthedocs.io/en/latest/jep/10657-sequencing-effects.html. """ -from __future__ import annotations - from collections.abc import Iterable, Set from typing import Any diff --git a/jax/_src/environment_info.py b/jax/_src/environment_info.py index 32d34db254fd..be340724e9dd 100644 --- a/jax/_src/environment_info.py +++ b/jax/_src/environment_info.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import platform import subprocess import sys diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 590f68ac0b3b..cfca4a1c50ee 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations from jax._src import core from jax._src.util import set_module diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index a228eaa8b285..12a2a7cc8159 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -15,8 +15,6 @@ """ -from __future__ import annotations - from collections.abc import Callable, Sequence import copy import dataclasses diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index a47b095e4450..bc319f07996e 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -14,8 +14,6 @@ # Serialization and deserialization of _export.Exported -from __future__ import annotations - from collections.abc import Callable, Sequence from functools import partial from typing import TypeVar diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index d380bc5a2476..593166351dea 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -16,8 +16,6 @@ See documentation at https://jax.readthedocs.io/en/latest/export/shape_poly.html. """ -from __future__ import annotations - import enum from collections.abc import Callable, Sequence import dataclasses diff --git a/jax/_src/export/shape_poly_decision.py b/jax/_src/export/shape_poly_decision.py index e325722b0c26..39e14da4803f 100644 --- a/jax/_src/export/shape_poly_decision.py +++ b/jax/_src/export/shape_poly_decision.py @@ -15,8 +15,6 @@ """ -from __future__ import annotations - from collections.abc import Sequence import itertools import math diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index aec124549e1e..ff321d603fce 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import os import ctypes from collections.abc import Iterable, Mapping, Sequence diff --git a/jax/_src/image/scale.py b/jax/_src/image/scale.py index aa9910555130..f6bef4bf276d 100644 --- a/jax/_src/image/scale.py +++ b/jax/_src/image/scale.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Sequence from functools import partial import enum diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 5a975e3c5a61..a27320ab86ee 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -68,8 +68,6 @@ def func(...): ... """ -from __future__ import annotations - from collections.abc import Callable, Iterable, Sequence import dataclasses import datetime diff --git a/jax/_src/internal_test_util/lax_test_util.py b/jax/_src/internal_test_util/lax_test_util.py index b57b7d0852a9..0460f0fbd3c0 100644 --- a/jax/_src/internal_test_util/lax_test_util.py +++ b/jax/_src/internal_test_util/lax_test_util.py @@ -17,8 +17,6 @@ # only, and may be changed or removed at any time and without any deprecation # cycle. -from __future__ import annotations - import collections import itertools from typing import Union, cast diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 2b22944c17b8..3edd1f7880a1 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -36,8 +36,6 @@ to fail. A Limitation is specific to a harness. """ -from __future__ import annotations - from collections.abc import Callable, Iterable, Sequence import operator import os diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index a527acb8db90..d08f330c85fe 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Sequence import contextlib import functools diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 3a87fffa5116..a369ba425ab0 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations import collections from collections.abc import Callable, Iterable, Sequence diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 7eb826c95a67..c8b0598de9dd 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -13,7 +13,6 @@ # limitations under the License. # Lowering and execution path that converts jaxprs into MLIR. -from __future__ import annotations import collections from collections.abc import Callable, Iterator, Sequence diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 497c9ea129a8..a3df4b64a0dc 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations from collections import namedtuple from collections.abc import Callable, Sequence, Hashable diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 69d7c619b0a6..019c1c75afa4 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -13,8 +13,6 @@ # limitations under the License. """Implementation of pmap and related functionality.""" -from __future__ import annotations - import enum from contextlib import contextmanager import collections diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 2db877d3f970..f6c3639cf6d3 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -14,8 +14,6 @@ # Lowering of jaxprs into XLA (HLO) computations. -from __future__ import annotations - from collections import defaultdict from collections.abc import Callable, Sequence import dataclasses diff --git a/jax/_src/jaxpr_util.py b/jax/_src/jaxpr_util.py index 3f3f677b069d..4e3d098a8a9e 100644 --- a/jax/_src/jaxpr_util.py +++ b/jax/_src/jaxpr_util.py @@ -14,8 +14,6 @@ """Utilities for the Jaxpr IR.""" -from __future__ import annotations - from collections import Counter, defaultdict from collections.abc import Callable import gzip diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index b613193876b6..0e351113d831 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -13,8 +13,6 @@ # limitations under the License. """Module for the common control flow utilities.""" -from __future__ import annotations - from collections.abc import Callable, Sequence import os from functools import partial diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index ffe5d086a98d..98d63d63f4bb 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for conditional control flow primitives.""" -from __future__ import annotations import collections from collections.abc import Callable, Sequence diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 936656b0e7df..01d02bf1cc7b 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -13,8 +13,6 @@ # limitations under the License. """Module for the `for_loop` primitive.""" -from __future__ import annotations - from collections.abc import Callable, Sequence import functools import operator diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 0c704b84475b..082a0e97f7bd 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for the loop primitives.""" -from __future__ import annotations from collections.abc import Callable, Sequence from functools import partial diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 2b2ad5bbb515..953d7813a3fb 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Sequence from functools import partial import operator diff --git a/jax/_src/lax/eigh.py b/jax/_src/lax/eigh.py index fc66b0f2e7ee..919179f19251 100644 --- a/jax/_src/lax/eigh.py +++ b/jax/_src/lax/eigh.py @@ -25,8 +25,6 @@ CPU and GPU also. """ -from __future__ import annotations - from functools import partial from typing import NamedTuple diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index a1cce3500df1..ba1d20e649dd 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Sequence from functools import partial import math diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 4d4abc1b18f7..23620bd14af4 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import builtins from collections.abc import Callable, Sequence import enum diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index d31bba99171c..90e04ca1fd35 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable import functools from functools import partial diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py index c84536495774..84eb131ac7b3 100644 --- a/jax/_src/lax/other.py +++ b/jax/_src/lax/other.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Sequence import math from typing import Any, Union, cast as type_cast diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 47386cb4a5f0..e6766b3eda83 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -15,8 +15,6 @@ Parallelization primitives. """ -from __future__ import annotations - from collections.abc import Sequence from functools import partial import itertools diff --git a/jax/_src/lax/qdwh.py b/jax/_src/lax/qdwh.py index bac3ea957955..3bc3c60570fe 100644 --- a/jax/_src/lax/qdwh.py +++ b/jax/_src/lax/qdwh.py @@ -24,8 +24,6 @@ https://epubs.siam.org/doi/abs/10.1137/090774999 """ -from __future__ import annotations - import functools import jax diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index b2bd30b3d364..b5731c7d066d 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Sequence import enum import operator diff --git a/jax/_src/lax/stack.py b/jax/_src/lax/stack.py index 882195f17d51..29d74cbb3f37 100644 --- a/jax/_src/lax/stack.py +++ b/jax/_src/lax/stack.py @@ -18,8 +18,6 @@ Eigendecomposition on TPU. """ -from __future__ import annotations - from typing import Any import jax diff --git a/jax/_src/lax/svd.py b/jax/_src/lax/svd.py index 77ff4297e137..44f745beb533 100644 --- a/jax/_src/lax/svd.py +++ b/jax/_src/lax/svd.py @@ -33,8 +33,6 @@ https://epubs.siam.org/doi/abs/10.1137/090774999 """ -from __future__ import annotations - from collections.abc import Sequence import functools import operator diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 096fce7deb3a..25a0b717c701 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Sequence from functools import partial import warnings diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 2071794a09fb..7685d8f836d6 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -12,8 +12,6 @@ # See the License for the ific language governing permissions and # limitations under the License. -from __future__ import annotations - from typing import Union from jax._src.sharding import Sharding diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 3d426ff370ee..459475b1f49a 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -15,8 +15,6 @@ # This module is largely a wrapper around `jaxlib` that performs version # checking on import. -from __future__ import annotations - import gc import pathlib import re diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index bc4cc242f055..a98e45574aeb 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -61,7 +61,6 @@ def trans1(static_arg, *dynamic_args, **kwargs): dynamic positional arguments for the generators, and also the auxiliary output data must be immutable, because it will be stored in function memoization tables. """ -from __future__ import annotations from collections.abc import Callable from functools import partial diff --git a/jax/_src/lru_cache.py b/jax/_src/lru_cache.py index 3b1f9df07210..4ed8e4e4e92a 100644 --- a/jax/_src/lru_cache.py +++ b/jax/_src/lru_cache.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import heapq import logging import pathlib diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 20fc54d8fe37..f35b8ff81735 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections import OrderedDict, abc from collections.abc import Callable, Iterable, Sequence, Mapping import contextlib diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 32138678561f..b6c1e06b2b19 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -13,8 +13,6 @@ # limitations under the License. """Definitions of Mesh and ResourceEnv.""" -from __future__ import annotations - import collections from collections.abc import Hashable, Sequence import contextlib diff --git a/jax/_src/monitoring.py b/jax/_src/monitoring.py index 3b291de0061a..af197f72906a 100644 --- a/jax/_src/monitoring.py +++ b/jax/_src/monitoring.py @@ -21,8 +21,6 @@ aggregation/exporting. """ -from __future__ import annotations - from typing import Protocol diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 822fb548ed90..13c1e3224aab 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -14,8 +14,6 @@ """Shared neural network activations and other functions.""" -from __future__ import annotations - from functools import partial import operator import numpy as np diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index cf245f7927be..073a06069a23 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -17,8 +17,6 @@ used in Keras and Sonnet. """ -from __future__ import annotations - from collections.abc import Sequence import math import typing diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 1d27c4b3aa28..6edd60f91f17 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -19,8 +19,6 @@ This is done dynamically in order to avoid circular imports. """ -from __future__ import annotations - __all__ = ['register_jax_array_methods'] import abc diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index 03e468fa99a9..92c12eb78c53 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Sequence import operator import numpy as np diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py index 90a17000cf16..d346234e6fd5 100644 --- a/jax/_src/numpy/index_tricks.py +++ b/jax/_src/numpy/index_tricks.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Iterable from typing import Any, Union diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 80eaf01dbd95..199f6f73f2f3 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -23,7 +23,6 @@ transformations for NumPy primitives can be derived from the transformation rules for the underlying :code:`lax` primitives. """ -from __future__ import annotations import builtins import collections diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 63aca76e6098..364d137c1175 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Sequence from functools import partial import itertools diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 45595c4387a2..aad1a587ae68 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from functools import partial import operator diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 6abdf884b7e3..00b8876d0dd2 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import builtins from collections.abc import Callable, Sequence from functools import partial diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 34968b2c7599..c10d35ddeb26 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from functools import partial import math import operator diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 2e114193af13..8bc492fd9dcc 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -14,8 +14,6 @@ """Tools to create numpy-style ufuncs.""" -from __future__ import annotations - from collections.abc import Callable from functools import partial import math diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 1a75e413379a..aab0928aa1bb 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -16,8 +16,6 @@ Implements ufuncs for jax.numpy. """ -from __future__ import annotations - from collections.abc import Callable from functools import partial import operator diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 21b96deea3c6..d9b4a6bbae3c 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations from collections.abc import Callable, Sequence from functools import partial diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index 2c517467e287..b2b273c7d818 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations from collections.abc import Callable, Collection, Sequence import functools diff --git a/jax/_src/op_shardings.py b/jax/_src/op_shardings.py index 6b248736ce45..93b64c77fc76 100644 --- a/jax/_src/op_shardings.py +++ b/jax/_src/op_shardings.py @@ -13,8 +13,6 @@ # limitations under the License. """Sharding utilities""" -from __future__ import annotations - from collections.abc import Sequence import itertools from typing import Union diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 2bcfe96ad2f0..bea4ff357eb0 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -14,8 +14,6 @@ # Helpers for indexed updates. -from __future__ import annotations - from collections.abc import Callable, Sequence from typing import Union import warnings diff --git a/jax/_src/ops/special.py b/jax/_src/ops/special.py index 59ad594ef2bc..707bd2f9854a 100644 --- a/jax/_src/ops/special.py +++ b/jax/_src/ops/special.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from typing import overload, Literal import jax diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 38866b082da6..ca4a6459bae2 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -13,7 +13,6 @@ # limitations under the License. """Module for pallas-core functionality.""" -from __future__ import annotations from collections.abc import Callable, Iterator, Sequence import copy diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index f4a794792253..33a6402a244b 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -13,7 +13,6 @@ # limitations under the License. """Contains TPU-specific Pallas abstractions.""" -from __future__ import annotations from collections.abc import Sequence import dataclasses diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 3ebbfdb51b5b..f7affa69c533 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -13,7 +13,6 @@ # limitations under the License. """Module for lowering JAX to Mosaic-compatible MLIR dialects.""" -from __future__ import annotations from collections.abc import Callable, Sequence import dataclasses diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 8f307e560bf0..291423bc1b78 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -14,8 +14,6 @@ """Contains registrations for pallas_call on TPU.""" -from __future__ import annotations - from typing import Any import warnings diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 0d778a60c711..35e8b87f05bc 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -13,7 +13,6 @@ # limitations under the License. """Module for emitting custom TPU pipelines within a Pallas call.""" -from __future__ import annotations from collections.abc import Sequence import dataclasses diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index f4c24e4e5e16..4e71acce8b81 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -13,7 +13,6 @@ # limitations under the License. """Module for Pallas:TPU-specific JAX primitives and functions.""" -from __future__ import annotations from collections.abc import Callable import dataclasses diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 5b4db68f2552..cfbaa5cf947b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -14,8 +14,6 @@ """Module for lowering JAX primitives to Mosaic GPU.""" -from __future__ import annotations - from collections.abc import Sequence import dataclasses import functools diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 740f0c31ebb7..dad1e0dcd691 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -15,8 +15,6 @@ """Module registering a lowering rule for pallas_call on GPU.""" -from __future__ import annotations - from typing import Any import jax diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 0748e78a2db9..340b5131a2b5 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -13,7 +13,6 @@ # limitations under the License. """Module for calling pallas functions from JAX.""" -from __future__ import annotations from collections.abc import Callable, Sequence from functools import partial, reduce diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 4abc5ced1af0..21b5db48d917 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -14,8 +14,6 @@ """Pallas-specific JAX primitives.""" -from __future__ import annotations - import enum import functools import string diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index c270e8084f42..36a2dd29c4e4 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -14,8 +14,6 @@ """Module for lowering JAX primitives to Triton IR.""" -from __future__ import annotations - from collections.abc import Callable, Sequence import dataclasses import functools diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index e6d521692ec2..041bc77f4618 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -14,8 +14,6 @@ """Module registering a lowering rule for pallas_call on GPU.""" -from __future__ import annotations - import io from typing import Any diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index 8518a94ed9cf..e611163d9240 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -14,8 +14,6 @@ """Module for GPU-specific JAX primitives.""" -from __future__ import annotations - from collections.abc import Sequence import jax diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index 41466be0822d..bbf2d4c05ad1 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -14,7 +14,6 @@ """Pallas utility functions.""" -from __future__ import annotations from typing import overload import jax diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 204c288d6993..7e49216e5715 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections import defaultdict from collections.abc import Callable, Sequence, Iterable import dataclasses diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index 5c1e7e1198e8..9262a3bfa11e 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -25,8 +25,6 @@ # Annotations. https://ayazhafiz.com/articles/21/strictly-annotated # -from __future__ import annotations - from collections.abc import Sequence import enum from functools import partial diff --git a/jax/_src/prng.py b/jax/_src/prng.py index d585b312fafc..d27c6c52a5fc 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations from collections.abc import Callable, Iterator, Sequence from functools import partial, reduce diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index cad4826ba801..c84a66860c9f 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable from contextlib import contextmanager from functools import wraps diff --git a/jax/_src/random.py b/jax/_src/random.py index 6a0a3c0f9932..7b930d56f896 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Hashable, Sequence from functools import partial import math diff --git a/jax/_src/scipy/cluster/vq.py b/jax/_src/scipy/cluster/vq.py index a82c8928644d..434136de87fe 100644 --- a/jax/_src/scipy/cluster/vq.py +++ b/jax/_src/scipy/cluster/vq.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations import operator diff --git a/jax/_src/scipy/fft.py b/jax/_src/scipy/fft.py index f1d907cf3f3b..cfbbb247011c 100644 --- a/jax/_src/scipy/fft.py +++ b/jax/_src/scipy/fft.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Sequence from functools import partial import math diff --git a/jax/_src/scipy/integrate.py b/jax/_src/scipy/integrate.py index b61cdb163b8d..2d9c26c93138 100644 --- a/jax/_src/scipy/integrate.py +++ b/jax/_src/scipy/integrate.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from functools import partial from jax import jit diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 72f07a3441eb..d5b950dc37ca 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from functools import partial import numpy as np diff --git a/jax/_src/scipy/optimize/_lbfgs.py b/jax/_src/scipy/optimize/_lbfgs.py index aa82ab4fd0c8..0d1e335d7dc2 100644 --- a/jax/_src/scipy/optimize/_lbfgs.py +++ b/jax/_src/scipy/optimize/_lbfgs.py @@ -13,8 +13,6 @@ # limitations under the License. """The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm.""" -from __future__ import annotations - from collections.abc import Callable from functools import partial from typing import NamedTuple diff --git a/jax/_src/scipy/optimize/bfgs.py b/jax/_src/scipy/optimize/bfgs.py index 657b7610e6e1..891e9d00de82 100644 --- a/jax/_src/scipy/optimize/bfgs.py +++ b/jax/_src/scipy/optimize/bfgs.py @@ -13,8 +13,6 @@ # limitations under the License. """The Broyden-Fletcher-Goldfarb-Shanno minimization algorithm.""" -from __future__ import annotations - from collections.abc import Callable from functools import partial from typing import NamedTuple diff --git a/jax/_src/scipy/optimize/line_search.py b/jax/_src/scipy/optimize/line_search.py index 189009693cdd..3e21d6b959cb 100644 --- a/jax/_src/scipy/optimize/line_search.py +++ b/jax/_src/scipy/optimize/line_search.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from typing import NamedTuple from functools import partial diff --git a/jax/_src/scipy/optimize/minimize.py b/jax/_src/scipy/optimize/minimize.py index 4fc006be6df0..14f0490664e0 100644 --- a/jax/_src/scipy/optimize/minimize.py +++ b/jax/_src/scipy/optimize/minimize.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Mapping from typing import Any diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 1282650ae1e5..ca5981ffdb01 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Sequence from functools import partial import math diff --git a/jax/_src/scipy/spatial/transform.py b/jax/_src/scipy/spatial/transform.py index ec7165e32ffd..27133aee19a8 100644 --- a/jax/_src/scipy/spatial/transform.py +++ b/jax/_src/scipy/spatial/transform.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import functools import re import typing diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 70f3ccd2ef80..6571aec6da7e 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from functools import partial import operator from typing import cast, overload, Any diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index 08d1c0b6b538..15dea8d4b6ff 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections import namedtuple from functools import partial import math diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index 14dbbba6e975..70c05c23be50 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Mapping, Sequence import functools diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index acceeef86a4a..6430f7435b34 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import collections from collections import OrderedDict from collections.abc import Mapping, Sequence diff --git a/jax/_src/sharding_specs.py b/jax/_src/sharding_specs.py index 7092b51ab894..d91ffcd54614 100644 --- a/jax/_src/sharding_specs.py +++ b/jax/_src/sharding_specs.py @@ -27,8 +27,6 @@ # This encoding is assumed by various parts of the system, e.g. generating # replica groups for collective operations. -from __future__ import annotations - from collections.abc import Sequence import itertools import math diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index 2f0ff74ab9fe..515764dc8979 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Iterator import contextlib import dataclasses diff --git a/jax/_src/sourcemap.py b/jax/_src/sourcemap.py index b54f2193ff26..826614cd6105 100644 --- a/jax/_src/sourcemap.py +++ b/jax/_src/sourcemap.py @@ -16,8 +16,6 @@ An implementation of sourcemaps following `TC39 `_. """ -from __future__ import annotations - from collections.abc import Iterable, Sequence from dataclasses import dataclass import json diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 874ef8834557..9827131a5414 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -28,7 +28,6 @@ various internal XLA-backed lowerings and executables into the lowering and executable protocols described above. """ -from __future__ import annotations import functools from collections.abc import Sequence diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index f3a3e61a2ace..06ed7a419b2e 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for discharging state primitives.""" -from __future__ import annotations from collections.abc import Callable, Sequence import dataclasses diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index acf1c7216240..25b79aa3ebb1 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -14,8 +14,6 @@ """Contains shared logic and abstractions for Pallas indexing ops.""" -from __future__ import annotations - import dataclasses from typing import Any, Union diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 224c2f351ae3..8817c0486967 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for state primitives.""" -from __future__ import annotations from functools import partial from typing import Any, Union diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 303e4da0b5bf..bf10c4d9ebc1 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -13,8 +13,6 @@ # limitations under the License. """Module for state types.""" -from __future__ import annotations - from collections.abc import Sequence import dataclasses import math diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index edd769aff5c6..de035b36999e 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -13,7 +13,6 @@ # limitations under the License. # pyformat: disable -from __future__ import annotations import collections from collections.abc import Callable, Generator, Iterable, Sequence diff --git a/jax/_src/third_party/scipy/linalg.py b/jax/_src/third_party/scipy/linalg.py index dce4df1fb817..02da16ce1e7b 100644 --- a/jax/_src/third_party/scipy/linalg.py +++ b/jax/_src/third_party/scipy/linalg.py @@ -1,4 +1,3 @@ -from __future__ import annotations from collections.abc import Callable diff --git a/jax/_src/third_party/scipy/signal_helper.py b/jax/_src/third_party/scipy/signal_helper.py index 4a021675804d..a76002f06cc7 100644 --- a/jax/_src/third_party/scipy/signal_helper.py +++ b/jax/_src/third_party/scipy/signal_helper.py @@ -1,7 +1,5 @@ """Utility functions adopted from scipy.signal.""" -from __future__ import annotations - from typing import Any import warnings diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 14721cea7682..8ebf88e958e3 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -15,7 +15,6 @@ """JAX bindings for Mosaic.""" # mypy: ignore-errors -from __future__ import annotations import base64 import collections.abc diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index d66cbb912a99..7756ec6fe094 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable import functools import os diff --git a/jax/_src/tree.py b/jax/_src/tree.py index 49faaa774ef2..b06d50c1d868 100644 --- a/jax/_src/tree.py +++ b/jax/_src/tree.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations from collections.abc import Callable, Iterable from typing import Any, TypeVar, overload diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 32f59b1df36e..53a2e8f528f7 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations import collections from collections.abc import Callable, Hashable, Iterable, Sequence diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 353f63f2a86d..ddc0e86d59ed 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -24,8 +24,6 @@ https://github.com/google/jax/pull/11859/. """ -from __future__ import annotations - from collections.abc import Sequence from typing import Any, Protocol, Union import numpy as np diff --git a/jax/_src/util.py b/jax/_src/util.py index 7aab80b2def3..dad3aaabda02 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import abc from collections.abc import Callable, Iterable, Iterator, Sequence import dataclasses diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 41fd2c586593..3a4481a9d406 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -18,7 +18,6 @@ and provide some automatic type mapping logic for converting between Numpy and XLA. There are also a handful of related casting utilities. """ -from __future__ import annotations import atexit from collections.abc import Callable, Mapping diff --git a/jax/collect_profile.py b/jax/collect_profile.py index a7777085ce90..22ec831f8d98 100644 --- a/jax/collect_profile.py +++ b/jax/collect_profile.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import argparse import gzip import os diff --git a/jax/example_libraries/optimizers.py b/jax/example_libraries/optimizers.py index 71680ca61b96..674419ad66d4 100644 --- a/jax/example_libraries/optimizers.py +++ b/jax/example_libraries/optimizers.py @@ -89,8 +89,6 @@ def step(step, opt_state): .. _Optax: https://github.com/deepmind/optax """ -from __future__ import annotations - from collections.abc import Callable from typing import Any, NamedTuple diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 4240b9d5100c..0d1ce6f7ca26 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -34,8 +34,6 @@ .. _Python array API standard: https://data-apis.org/array-api/latest/ """ -from __future__ import annotations - from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__ from jax.experimental.array_api import fft as fft diff --git a/jax/experimental/array_api/_array_methods.py b/jax/experimental/array_api/_array_methods.py index 2b071db573a8..bbdcc78f2fcc 100644 --- a/jax/experimental/array_api/_array_methods.py +++ b/jax/experimental/array_api/_array_methods.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from typing import Any import jax diff --git a/jax/experimental/array_api/_creation_functions.py b/jax/experimental/array_api/_creation_functions.py index 99b8e3ed4465..68ec122a9f1f 100644 --- a/jax/experimental/array_api/_creation_functions.py +++ b/jax/experimental/array_api/_creation_functions.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import jax import jax.numpy as jnp diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py index 248c1c6dd0fe..881b1ca005b7 100644 --- a/jax/experimental/array_api/_data_type_functions.py +++ b/jax/experimental/array_api/_data_type_functions.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import builtins from typing import NamedTuple import numpy as np diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py index c364b9f5b79c..8ea8416448c2 100644 --- a/jax/experimental/array_api/_manipulation_functions.py +++ b/jax/experimental/array_api/_manipulation_functions.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import jax from jax import Array diff --git a/jax/experimental/array_api/_utility_functions.py b/jax/experimental/array_api/_utility_functions.py index f75b2e2e29af..4276b62a611c 100644 --- a/jax/experimental/array_api/_utility_functions.py +++ b/jax/experimental/array_api/_utility_functions.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import jax from jax._src.sharding import Sharding from jax._src.lib import xla_client as xc diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index c7aa8b590412..cb4e66fce5ee 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -13,8 +13,6 @@ # limitations under the License. """Array serialization and deserialization.""" -from __future__ import annotations - import abc import asyncio from collections.abc import Awaitable, Callable, Sequence diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 8176465c1470..af36561c0c5d 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from contextlib import contextmanager from typing import Any diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index aa138fe88993..779f92f49c1f 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from functools import partial import inspect from typing import Optional diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index deaac9c72c8b..c6fad104f2fe 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -499,8 +499,6 @@ def power3_with_cotangents(x): """ -from __future__ import annotations - import atexit import enum from collections.abc import Callable, Sequence diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index adf43b6b94c0..588e0df9da9b 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -23,8 +23,6 @@ """ -from __future__ import annotations - from collections.abc import Callable, Sequence import dataclasses import functools diff --git a/jax/experimental/jax2tf/examples/mnist_lib.py b/jax/experimental/jax2tf/examples/mnist_lib.py index 41173c79a5b9..6bb8c161ee2b 100644 --- a/jax/experimental/jax2tf/examples/mnist_lib.py +++ b/jax/experimental/jax2tf/examples/mnist_lib.py @@ -19,8 +19,6 @@ See README.md for how these are used. """ -from __future__ import annotations - from collections.abc import Callable, Sequence import functools import logging diff --git a/jax/experimental/jax2tf/examples/saved_model_lib.py b/jax/experimental/jax2tf/examples/saved_model_lib.py index 8f2f0982fd3d..665d054446af 100644 --- a/jax/experimental/jax2tf/examples/saved_model_lib.py +++ b/jax/experimental/jax2tf/examples/saved_model_lib.py @@ -24,8 +24,6 @@ customize this function as needed. """ -from __future__ import annotations - from collections.abc import Callable, Sequence from typing import Any diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index 5ecde602cdaa..426bd12a26b3 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -13,8 +13,6 @@ # limitations under the License. """Workarounds for jax2tf transforms when XLA is not linked in.""" -from __future__ import annotations - import builtins from collections.abc import Callable, Sequence import dataclasses diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 5f3230599a25..4a4114d87fbe 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -13,8 +13,6 @@ # limitations under the License. """Provides JAX and TensorFlow interoperation APIs.""" -from __future__ import annotations - from collections.abc import Callable, Iterable, Sequence from functools import partial import contextlib diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py index 7f903b70d987..128247273db5 100644 --- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py @@ -17,8 +17,6 @@ these tests. """ -from __future__ import annotations - import base64 from collections.abc import Callable, Sequence import io diff --git a/jax/experimental/jax2tf/tests/cross_compilation_check.py b/jax/experimental/jax2tf/tests/cross_compilation_check.py index cc34d78e88d4..460e8638a534 100644 --- a/jax/experimental/jax2tf/tests/cross_compilation_check.py +++ b/jax/experimental/jax2tf/tests/cross_compilation_check.py @@ -24,8 +24,6 @@ """ -from __future__ import annotations - from collections.abc import Callable, Sequence import contextlib import dataclasses diff --git a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py index 5b1169224ed9..c4d552ae5e95 100644 --- a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py +++ b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py @@ -16,8 +16,6 @@ https://github.com/google/flax/tree/main/examples/sst2 """ -from __future__ import annotations - from collections.abc import Callable import functools from typing import Any diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py b/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py index 27535c784e89..e933730ed3b6 100644 --- a/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py +++ b/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py @@ -22,8 +22,6 @@ # pytype: disable=wrong-keyword-args # pytype: disable=attribute-error -from __future__ import annotations - from collections.abc import Callable from typing import Any diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py b/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py index cc78b5a41496..7ec32832eaa3 100644 --- a/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py +++ b/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py @@ -16,8 +16,6 @@ https://github.com/google/flax/tree/main/examples/lm1b """ -from __future__ import annotations - from collections.abc import Callable from typing import Any diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py b/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py index 1cdeffeb6ea9..21e3378badb5 100644 --- a/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py +++ b/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py @@ -22,8 +22,6 @@ # pytype: disable=wrong-keyword-args # pytype: disable=attribute-error -from __future__ import annotations - from collections.abc import Callable from typing import Any diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index 03e6086a4924..e9c2c86531e2 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -13,8 +13,6 @@ # limitations under the License. """See primitives_test docstring for how the Jax2TfLimitations are used.""" -from __future__ import annotations - from collections.abc import Callable, Sequence import itertools from typing import Any diff --git a/jax/experimental/jax2tf/tests/model_harness.py b/jax/experimental/jax2tf/tests/model_harness.py index 91aacf2f596f..88ca99ed1352 100644 --- a/jax/experimental/jax2tf/tests/model_harness.py +++ b/jax/experimental/jax2tf/tests/model_harness.py @@ -13,8 +13,6 @@ # limitations under the License. """All the models to convert.""" -from __future__ import annotations - from collections.abc import Callable, Sequence import dataclasses import functools diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 83aac43f2d9d..0db94889afdf 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -13,8 +13,6 @@ # limitations under the License. """Tests for the shape-polymorphic jax2tf conversion.""" -from __future__ import annotations - from collections.abc import Callable, Sequence import contextlib import math diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 32f89e533daf..239cd920dcc4 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Sequence import contextlib import dataclasses diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index b4989e151a53..eaef7686771e 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections import defaultdict from collections.abc import Callable, Iterator from functools import partial, reduce, total_ordering, wraps diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index dd112db3b269..d95452e197de 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -14,8 +14,6 @@ # ============================================================================== """Utils for building a device mesh.""" -from __future__ import annotations - import collections from collections.abc import Callable, Generator, MutableMapping, Sequence import itertools diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index b3665b5845b7..9920ea6abd9f 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -13,8 +13,6 @@ # limitations under the License. """Utilities for synchronizing and communication across multiple hosts.""" -from __future__ import annotations - from functools import partial, lru_cache from typing import Optional import zlib diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index 6df8f94af2eb..5a7579149549 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -13,7 +13,6 @@ # limitations under the License. """Module containing fused attention forward and backward pass.""" -from __future__ import annotations import functools from typing import Any, Optional diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py index 34ef9872977b..85437243f0e8 100644 --- a/jax/experimental/pallas/ops/gpu/decode_attention.py +++ b/jax/experimental/pallas/ops/gpu/decode_attention.py @@ -13,7 +13,6 @@ # limitations under the License. """Module containing decode attention.""" -from __future__ import annotations import functools from typing import Any diff --git a/jax/experimental/pallas/ops/gpu/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py index 269f29dc71b7..7193755cfafd 100644 --- a/jax/experimental/pallas/ops/gpu/layer_norm.py +++ b/jax/experimental/pallas/ops/gpu/layer_norm.py @@ -14,8 +14,6 @@ """Module containing fused layer norm forward and backward pass.""" -from __future__ import annotations - import functools from typing import Optional diff --git a/jax/experimental/pallas/ops/gpu/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py index e1dfa3c5b9b7..d60522d7b472 100644 --- a/jax/experimental/pallas/ops/gpu/rms_norm.py +++ b/jax/experimental/pallas/ops/gpu/rms_norm.py @@ -14,8 +14,6 @@ """Module containing rms forward and backward pass.""" -from __future__ import annotations - import functools from typing import Optional diff --git a/jax/experimental/pallas/ops/tpu/all_gather.py b/jax/experimental/pallas/ops/tpu/all_gather.py index e121db894122..88921afee4d5 100644 --- a/jax/experimental/pallas/ops/tpu/all_gather.py +++ b/jax/experimental/pallas/ops/tpu/all_gather.py @@ -22,7 +22,7 @@ chunk, splits it in two, and sends each of the half-chunks in each direction (left and right) until every device has received the half chunks. """ -from __future__ import annotations + import functools from collections.abc import Sequence diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index f3b09c96486b..1aaf597af908 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -13,7 +13,6 @@ # limitations under the License. """Flash Attention TPU kernel.""" -from __future__ import annotations import dataclasses import functools diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index a6c0715e6043..656d8aeaa896 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -14,8 +14,6 @@ """Implementation of Sparse Flash Attention, a.k.a. "Splash" attention.""" -from __future__ import annotations - from collections.abc import Callable, Mapping import dataclasses import enum diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py index eab2a695dc02..f89d9ce2546d 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py @@ -14,8 +14,6 @@ """Mini-mask creation library.""" -from __future__ import annotations - from collections.abc import Callable, Sequence import dataclasses from typing import Any diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py index 3c672b8dbe88..9b9ec45e1f44 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py @@ -13,7 +13,6 @@ # limitations under the License. """Mini-mask creation library.""" -from __future__ import annotations import collections from collections.abc import Callable diff --git a/jax/experimental/serialize_executable.py b/jax/experimental/serialize_executable.py index 1c0614da516f..56749e3a00ef 100644 --- a/jax/experimental/serialize_executable.py +++ b/jax/experimental/serialize_executable.py @@ -13,8 +13,6 @@ # limitations under the License. """Pickling support for precompiled binaries.""" -from __future__ import annotations - import pickle import io from typing import Optional, Union diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 4957df4866f0..8d64026da36e 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations from collections.abc import Callable, Hashable, Sequence import enum diff --git a/jax/experimental/slab/djax.py b/jax/experimental/slab/djax.py index c989ede8663e..1078ff635b3d 100644 --- a/jax/experimental/slab/djax.py +++ b/jax/experimental/slab/djax.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import collections from collections.abc import Callable from functools import partial diff --git a/jax/experimental/slab/slab.py b/jax/experimental/slab/slab.py index af7b079eeb7f..e796b4f972fd 100644 --- a/jax/experimental/slab/slab.py +++ b/jax/experimental/slab/slab.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Iterable, Sequence from functools import partial, reduce import sys diff --git a/jax/experimental/sparse/ad.py b/jax/experimental/sparse/ad.py index 2c235c9320d5..d48e5fc5c4ad 100644 --- a/jax/experimental/sparse/ad.py +++ b/jax/experimental/sparse/ad.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Sequence import itertools from typing import Any diff --git a/jax/experimental/sparse/api.py b/jax/experimental/sparse/api.py index 5e64e1e14910..9f4d134a0fea 100644 --- a/jax/experimental/sparse/api.py +++ b/jax/experimental/sparse/api.py @@ -30,8 +30,6 @@ (API should be considered unstable and subject to change). """ -from __future__ import annotations - from functools import partial import operator from typing import Optional, Union diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 4cbe52383751..39e24c2d3a33 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -13,7 +13,6 @@ # limitations under the License. """BCOO (Bached coordinate format) matrix object and associated primitives.""" -from __future__ import annotations from collections.abc import Sequence import functools diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index b831163e1497..ea7ccb81c240 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -13,7 +13,6 @@ # limitations under the License. """BCSR (Bached compressed row) matrix object and associated primitives.""" -from __future__ import annotations from collections.abc import Sequence from functools import partial diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index 8863478df4d3..0f4b312fbaf7 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -13,7 +13,6 @@ # limitations under the License. """COO (coordinate format) matrix object and associated primitives.""" -from __future__ import annotations from collections.abc import Sequence from functools import partial diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index c1178943c02a..2d907baf9075 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -13,7 +13,6 @@ # limitations under the License. """CSR (compressed sparse row) matrix object and associated primitives.""" -from __future__ import annotations from functools import partial import operator diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index b0ac1fa5d380..0454eba54389 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -14,8 +14,6 @@ """Sparse linear algebra routines.""" -from __future__ import annotations - from collections.abc import Callable import functools diff --git a/jax/experimental/sparse/test_util.py b/jax/experimental/sparse/test_util.py index 365c436521b8..d95ab1c79b01 100644 --- a/jax/experimental/sparse/test_util.py +++ b/jax/experimental/sparse/test_util.py @@ -13,8 +13,6 @@ # limitations under the License. """Sparse test utilities.""" -from __future__ import annotations - from collections.abc import Callable, Iterable, Iterator, Sequence import functools import itertools diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 86eb8a9aefe8..46d27ec55d44 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations """ Sparsify transform diff --git a/jax/experimental/topologies.py b/jax/experimental/topologies.py index 7866564e9c01..4e27033d95ae 100644 --- a/jax/experimental/topologies.py +++ b/jax/experimental/topologies.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import abc from collections.abc import Sequence from typing import Optional diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 6663df3ac473..2ad2b7e3dc7e 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -15,8 +15,6 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/google/jax/issues/7570 -from __future__ import annotations - from jax._src.interpreters.ad import ( CustomJVPException as CustomJVPException, CustomVJPException as CustomVJPException, diff --git a/jax/tools/build_utils.py b/jax/tools/build_utils.py index 84cc697d1894..e0256754cdd0 100644 --- a/jax/tools/build_utils.py +++ b/jax/tools/build_utils.py @@ -14,8 +14,6 @@ """Utilities for the building JAX related python packages.""" -from __future__ import annotations - import os import pathlib import platform diff --git a/jax/version.py b/jax/version.py index f3d007eec9b1..c9115c54bd13 100644 --- a/jax/version.py +++ b/jax/version.py @@ -14,7 +14,6 @@ # This file is included as part of both jax and jaxlib. It is also # eval()-ed by setup.py, so it should not have any dependencies. -from __future__ import annotations import datetime import os diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index 12dcacebfa46..38b0c02b1fae 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import functools from functools import partial import importlib diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py index 0d57a04f1aa7..4d60bf07b344 100644 --- a/jaxlib/hlo_helpers.py +++ b/jaxlib/hlo_helpers.py @@ -14,8 +14,6 @@ """A small library of helpers for use in jaxlib to build MLIR operations.""" -from __future__ import annotations - from collections.abc import Callable, Sequence from functools import partial from typing import Union diff --git a/jaxlib/triton/dialect.py b/jaxlib/triton/dialect.py index 1bbb565b69b2..ee2a0ad71b52 100644 --- a/jaxlib/triton/dialect.py +++ b/jaxlib/triton/dialect.py @@ -16,8 +16,6 @@ """Python bindings for the MLIR Triton dialect.""" -from __future__ import annotations - from collections.abc import Sequence from jaxlib.mlir._mlir_libs._triton_ext import ( diff --git a/tests/api_test.py b/tests/api_test.py index 3929a29f9c30..00b7621626b5 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import collections import collections.abc from collections.abc import Callable diff --git a/tests/array_api_test.py b/tests/array_api_test.py index dcb33b9bc57f..7f83d2694ac1 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -17,7 +17,6 @@ The full test suite for the array API is run via the array-api-tests CI; this is just a minimal smoke test to catch issues early. """ -from __future__ import annotations from types import ModuleType diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 5c834f314270..33a8edf9783b 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from dataclasses import dataclass from absl.testing import absltest diff --git a/tests/batching_test.py b/tests/batching_test.py index 4d912bfca206..2999ca2a399c 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable from contextlib import contextmanager from functools import partial diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index 035905d3f9e5..d06a86486acb 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -19,8 +19,6 @@ cross-platform lowering is tested in export_test.py. """ -from __future__ import annotations - from collections.abc import Callable import math import re diff --git a/tests/export_test.py b/tests/export_test.py index 8ccf3bb35849..785588f32a05 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations from collections.abc import Callable, Sequence import contextlib diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 1ad59103c3e7..55e4f099af0a 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import contextlib from collections.abc import Callable, Sequence from functools import partial diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index dab26d86c0a2..66bb1fc99b9b 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from array import array as make_python_array import collections import copy diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index d18244062da6..99c90f951181 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import enum from functools import partial import itertools diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 76c9d3abf278..581dc7a6776b 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from array import array as make_python_array import collections from collections.abc import Iterator diff --git a/tests/lax_test.py b/tests/lax_test.py index ce1a2d4ff897..d848d8d14ff0 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations from functools import partial import itertools diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 37d51c04f8de..afcfccefdfd8 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from functools import partial import itertools import math diff --git a/tests/lru_cache_test.py b/tests/lru_cache_test.py index fb999cbef0cf..4dce1490fd2d 100644 --- a/tests/lru_cache_test.py +++ b/tests/lru_cache_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import importlib.util import tempfile import time diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index fe7ddd618ee1..661f72bd5885 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from absl.testing import absltest from absl.testing import parameterized import numpy as np diff --git a/tests/pallas/all_gather_test.py b/tests/pallas/all_gather_test.py index 98b3e5b40135..7465887fe786 100644 --- a/tests/pallas/all_gather_test.py +++ b/tests/pallas/all_gather_test.py @@ -13,7 +13,6 @@ # limitations under the License. """Tests the simple all_gather kernel.""" -from __future__ import annotations from absl.testing import absltest import jax diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index c11fca350d0d..a6b8ad95e778 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -14,8 +14,6 @@ """Tests for Pallas indexing logic and abstractions.""" -from __future__ import annotations - import unittest from absl.testing import absltest diff --git a/tests/pallas/splash_attention_kernel_test.py b/tests/pallas/splash_attention_kernel_test.py index e6132a1966a3..d21e702bd11e 100644 --- a/tests/pallas/splash_attention_kernel_test.py +++ b/tests/pallas/splash_attention_kernel_test.py @@ -13,7 +13,6 @@ # limitations under the License. """Tests for splash_attention.""" -from __future__ import annotations from collections.abc import Callable import dataclasses diff --git a/tests/pallas/splash_attention_mask_test.py b/tests/pallas/splash_attention_mask_test.py index ce7d8fd09182..2285c3463353 100644 --- a/tests/pallas/splash_attention_mask_test.py +++ b/tests/pallas/splash_attention_mask_test.py @@ -13,7 +13,6 @@ # limitations under the License. """Tests for splash_attention_masks.""" -from __future__ import annotations from absl.testing import absltest from absl.testing import parameterized diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 5e279a5e6daa..44e439d02ffa 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from concurrent.futures import ThreadPoolExecutor import contextlib from functools import partial diff --git a/tests/random_test.py b/tests/random_test.py index 2c45d60cc64d..71be90847684 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import copy import enum from functools import partial diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index f079d6753edd..20b7c487cc64 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -13,8 +13,6 @@ # limitations under the License. """Tests for the shape-polymorphic export.""" -from __future__ import annotations - import enum from collections.abc import Callable, Sequence import cProfile diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index ca9d813e2571..bc166b59e60a 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Generator, Iterable, Iterator, Sequence import contextlib from functools import partial diff --git a/tests/state_test.py b/tests/state_test.py index b6dbb490b794..18a031379727 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Callable, Sequence from functools import partial import itertools as it diff --git a/tests/typing_test.py b/tests/typing_test.py index 562c6c56d2d9..7df2ad2546fa 100644 --- a/tests/typing_test.py +++ b/tests/typing_test.py @@ -18,8 +18,6 @@ so it should be checked with pytype/mypy as well as being run with pytest. """ -from __future__ import annotations - from typing import Any, TYPE_CHECKING import jax diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 428c7fc66801..f7dc10107bf3 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from collections.abc import Generator, Iterator import contextlib import functools