From c3bc88d5e4e3b25d99b7aa9ec2b1c93a854d976c Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Thu, 16 May 2024 15:10:01 +0100
Subject: [PATCH] Bumped mypy to 1.10.0 and ruff to 0.4.4

---
 .pre-commit-config.yaml                    |  6 +--
 jax/_src/api_util.py                       |  2 +-
 jax/_src/cache_key.py                      |  3 +-
 jax/_src/compiler.py                       |  8 +--
 jax/_src/core.py                           |  2 +-
 jax/_src/debugging.py                      |  6 +--
 jax/_src/interpreters/mlir.py              | 51 ++++++++++---------
 jax/_src/interpreters/pxla.py              |  6 +--
 jax/_src/lax/control_flow/conditionals.py  |  4 +-
 jax/_src/lax/lax.py                        |  2 +-
 jax/_src/lax/parallel.py                   |  2 +-
 jax/_src/lax/windowed_reductions.py        |  4 +-
 jax/_src/maps.py                           |  4 +-
 jax/_src/numpy/lax_numpy.py                |  2 +-
 jax/_src/pallas/mosaic/lowering.py         |  6 ++-
 jax/_src/prng.py                           |  6 +--
 jax/_src/scipy/ndimage.py                  |  2 +-
 jax/_src/scipy/signal.py                   |  2 +-
 jax/experimental/custom_partitioning.py    |  2 +-
 jax/experimental/export/_export.py         |  6 +--
 .../splash_attention_mask_info.py          |  8 +--
 jax/experimental/rnn.py                    |  2 +-
 tests/shard_map_test.py                    |  4 +-
 23 files changed, 75 insertions(+), 65 deletions(-)
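Note (commentary only, not part of the commit message): most of the churn
below is mechanical. The bumped ruff flags unnecessary list comprehensions
passed to calls like max(), min(), and sum(), so e.g.
max([len(x) for x in xs]) becomes max(len(x) for x in xs), and mypy 1.10.0
is stricter about the loose return types of the MLIR Python bindings, which
is handled with either a narrow "# type: ignore" or an explicit cast. A
minimal sketch of the cast pattern, assuming an ambient MLIR Context as in
jax/_src/interpreters/mlir.py; the helper name _dense_i64_elements is
hypothetical, for illustration only:

    from typing import cast as type_cast

    import numpy as np
    from jaxlib.mlir import ir

    def _dense_i64_elements(xs) -> ir.DenseIntElementsAttr:
      # The bindings annotate .get() as returning the base attribute class,
      # so cast narrows the result to the subclass callers expect.
      return type_cast(ir.DenseIntElementsAttr,
                       ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)))

The updated hooks can be exercised locally with "pre-commit run --all-files".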
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1903bc8fe37d..3405c5e9d76b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,17 +26,17 @@ repos:
     files: \.py$
 
 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.3.5
+  rev: v0.4.4
   hooks:
   - id: ruff
 
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: 'v1.9.0'
+  rev: 'v1.10.0'
   hooks:
   - id: mypy
     files: (jax/|tests/typing_test\.py)
     exclude: jax/_src/basearray.py|jax/numpy/__init__.py  # Use pyi instead
-    additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.23, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4]
+    additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.27, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4]
     args: [--config=pyproject.toml]
 
 - repo: https://github.com/mwouts/jupytext
diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py
index 657f1e11357a..e9cfc4f3bce2 100644
--- a/jax/_src/api_util.py
+++ b/jax/_src/api_util.py
@@ -672,7 +672,7 @@ def result_paths(*args, **kwargs):
   yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]
 
 def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
-                     result_paths: tuple[str | None, ...] | None = None,
+                     result_paths: tuple[str, ...] | None = None,
                      ) -> core.Jaxpr:
   """Add debug info to jaxpr, given trace-time debug info and result paths."""
   if trace_debug is None:
diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py
index 292fa24159de..6fdf0c600b7d 100644
--- a/jax/_src/cache_key.py
+++ b/jax/_src/cache_key.py
@@ -18,6 +18,7 @@
 import logging
 import os
 import sys
+from typing import cast as type_cast
 
 from jax._src import config
 from jax._src.lib import version_str as jaxlib_version_str
@@ -136,7 +137,7 @@ def _serialize_ir(m: ir.Module) -> bytes:
 
 def _canonicalize_ir(m_original: ir.Module) -> bytes:
   with m_original.context:
-    m = m_original.operation.clone()
+    m = type_cast(ir.Module, m_original.operation.clone())
     passes = pm.PassManager.parse(
         "builtin.module(strip-debuginfo)"
     )
diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py
index 6c6f983e7cf9..246a87ddf6fb 100644
--- a/jax/_src/compiler.py
+++ b/jax/_src/compiler.py
@@ -217,9 +217,11 @@ def backend_compile(
     options: xc.CompileOptions,
     host_callbacks: Sequence[Any],
 ) -> xc.LoadedExecutable:
-  # Convert ir.Module to a string representation, unless the
-  # back-end expliclity flags the ability to handle a module directly
-  # (avoiding the overhead of back and forth conversions)
+  # Convert ir.Module to a string representation, unless the backend
+  # explicitly flags the ability to handle a module directly (avoiding the
+  # overhead of back and forth conversions).
+  # TODO(slebedev): Change the backend.compile() to accept ir.Module.
+  built_c: Any
   if getattr(backend, "needs_str_ir", True):
     built_c = mlir.module_to_bytecode(module)
   else:
diff --git a/jax/_src/core.py b/jax/_src/core.py
index e34ddb8cf3e6..dbc1dbae8c57 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -79,7 +79,7 @@ class JaxprDebugInfo(NamedTuple):
   traced_for: str             # e.g. 'jit', 'scan', etc
   func_src_info: str          # e.g. f'{fun.__name__} at {filename}:{lineno}'
   arg_names: tuple[str | None, ...]     # e.g. ('args[0]', ... )
-  result_paths: tuple[str | None, ...]  # e.g. ('[0]', '[1]', ...)
+  result_paths: tuple[str, ...]  # e.g. ('[0]', '[1]', ...)
 
 class Jaxpr:
   __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py
index 0515ab37786d..7c5270efbacf 100644
--- a/jax/_src/debugging.py
+++ b/jax/_src/debugging.py
@@ -384,7 +384,7 @@ def _hlo_sharding_callback(hlo_sharding: xc.HloSharding):
         has_side_effect=ir.BoolAttr.get(True),
         api_version=mlir.i32_attr(1),
         called_computations=ir.ArrayAttr.get([]),
-        backend_config=ir.StringAttr.get(key),
+        backend_config=ir.StringAttr.get(key),  # type: ignore[arg-type]
         operand_layouts=None,
         result_layouts=None)
     return []
@@ -504,11 +504,11 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
       heights[chunk_idxs] = None
       widths[chunk_idxs] = horiz_size / shape[0]
       slices.setdefault(chunk_idxs, set()).add(dev.id)
-  num_rows = max([a[0] for a in slices.keys()]) + 1
+  num_rows = max(a[0] for a in slices.keys()) + 1
   if len(list(slices.keys())[0]) == 1:
     num_cols = 1
   else:
-    num_cols = max([a[1] for a in slices.keys()]) + 1
+    num_cols = max(a[1] for a in slices.keys()) + 1
   color_iter = make_color_iter(color_map, num_rows, num_cols)
   table = rich.table.Table(show_header=False, show_lines=not use_color,
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 2be855c6fae1..340e73bd303d 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -27,7 +27,7 @@
 import re
 import types
 import typing
-from typing import Any, Callable, NamedTuple, Protocol, Union
+from typing import Any, Callable, NamedTuple, Protocol, Union, cast as type_cast
 import warnings
 
 import numpy as np
@@ -87,19 +87,20 @@
 # IR Helpers
 
 def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
-  return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
+  return type_cast(ir.DenseIntElementsAttr,
+                   ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)))
 
-def dense_int_array(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
+def dense_int_array(xs) -> ir.DenseElementsAttr | ir.DenseI64ArrayAttr:
   # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher
   if hlo.get_api_version() < 5:
     return dense_int_elements(xs)
-  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
+  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))  # type: ignore
 
 # TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
 def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
   if hlo.get_api_version() < 6:
     return dense_int_elements(xs)
-  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
+  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))  # type: ignore
 
 def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
   a = np.packbits(np.array(xs, np.bool_), bitorder='little')
@@ -114,7 +115,7 @@ def dense_bool_array(xs: Sequence[bool]) -> ir.DenseElementsAttr | ir.DenseBoolA
   # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v6 or higher
   if hlo.get_api_version() < 6:
     return dense_bool_elements(xs)
-  return ir.DenseBoolArrayAttr.get(xs)
+  return ir.DenseBoolArrayAttr.get(xs)  # type: ignore
 
 def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
 def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
@@ -132,7 +133,7 @@ def lower_dim(d):
     return hlo.reshape(int1d, d)
   ds = map(lower_dim, sizes)
   if not ds:
-    return ir_constant(np.array([], np.int32))
+    return type_cast(ir.RankedTensorType, ir_constant(np.array([], np.int32)))
   elif len(ds) == 1:
     return ds[0]
   else:
@@ -195,7 +196,7 @@ def _array_ir_types(aval: core.ShapedArray | core.DShapedArray
   aval = core.physical_aval(aval)  # type: ignore
   if not core.is_constant_shape(aval.shape):
     return _dynamic_array_ir_types(aval)  # type: ignore
-  return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)
+  return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)  # type: ignore
 
 def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
   dyn_size = ir.ShapedType.get_dynamic_size()
@@ -282,7 +283,7 @@ def _numpy_array_constant(x: np.ndarray | np.generic) -> Sequence[ir.Value]:
   if x.dtype == np.bool_:
     x = np.packbits(x, bitorder='little')  # type: ignore
   x = np.ascontiguousarray(x)
-  attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
+  attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)  # type: ignore
   return (hlo.constant(attr),)
 
 
@@ -314,11 +315,11 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value
   elif np.any(np.equal(0, val.strides)) and val.size > 0:
     zero_stride_axes, = np.where(np.equal(0, val.strides))
     other_axes, = np.where(np.not_equal(0, val.strides))
-    collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) # type: ignore
-                              for ax in range(val.ndim))]  # type: ignore
+    collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)  # type: ignore
+                              for ax in range(val.ndim))]
     out = hlo.broadcast_in_dim(
         ir.RankedTensorType.get(
-            val.shape, dtype_to_ir_type(collapsed_val.dtype)),
+            val.shape, dtype_to_ir_type(collapsed_val.dtype)),  # type: ignore
         _numpy_array_constant(collapsed_val)[0],
         dense_int_array_v6(other_axes))
     return (out,)
@@ -738,7 +739,7 @@ def wrap_singleton_ir_values(x: ir.Value | Sequence[ir.Value]
 
 def flatten_lowering_ir_args(
     xs: Sequence[ir.Value | Sequence[ir.Value]]
-) -> Sequence[Sequence[ir.Value]]:
+) -> Sequence[ir.Value]:
   return util.flatten(map(wrap_singleton_ir_values, xs))
 
 _module_name_regex = re.compile(r"[^\w.-]")
@@ -863,7 +864,7 @@ def lower_jaxpr_to_module(
     in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
     out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
     arg_names: Sequence[str | None] | None = None,
-    result_names: Sequence[str | None] | None = None,
+    result_names: Sequence[str] | None = None,
     num_replicas: int = 1,
     num_partitions: int = 1,
     all_default_mem_kind: bool = True,
@@ -1106,7 +1107,7 @@ def lower_jaxpr_to_fun(
     xla_donated_args: Sequence[bool] | None = None,
     api_name: str = "jit",
     arg_names: Sequence[str | None] | None = None,
-    result_names: Sequence[str | None] | None = None,
+    result_names: Sequence[str] | None = None,
     arg_memory_kinds: Sequence[str | None] | None = None,
     result_memory_kinds: Sequence[str | None] | None = None,
     arg_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
@@ -1618,7 +1619,7 @@ def lower_per_platform(ctx: LoweringRuleContext,
                        default_rule: LoweringRule | None,
                        effects: effects_lib.Effects,
                        *rule_args: ir.Value,
-                       **rule_kwargs) -> ir.Value:
+                       **rule_kwargs) -> Sequence[ir.Value]:
   """Emits code for a primitive for the current lowering platform(s).
 
   For example, given
@@ -2039,9 +2040,8 @@ def compare_hlo(x, y, direction: str, comparison_type: str | None = None):
   """Creates CompareOp."""
   if comparison_type is None:
     elem_type = ir.RankedTensorType(x.type).element_type
-    if ir.IntegerType.isinstance(elem_type):
-      comparison_type = ("UNSIGNED" if ir.IntegerType.is_unsigned(elem_type)
-                         else "SIGNED")
+    if isinstance(elem_type, ir.IntegerType):
+      comparison_type = "UNSIGNED" if elem_type.is_unsigned else "SIGNED"
     else:
       comparison_type = "FLOAT"
 
@@ -2129,7 +2129,7 @@ def get_sharding_attr(sharding_proto: xc.OpSharding):
   # The MHLO to HLO conversion supports both, and the proto representation is
   # more compact.
   if len(sharding_proto.tile_assignment_devices) > 100:
-    return ir.StringAttr.get(sharding_proto.SerializeToString())
+    return ir.StringAttr.get(sharding_proto.SerializeToString())  # type: ignore
   else:
     return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))
 
@@ -2315,7 +2315,8 @@ def send_to_host(channel: int, token: hlo.TokenType, operand: Any,
 
 def receive_from_host(channel: int, token: hlo.TokenType,
                       out_aval: core.ShapedArray, name: str, *,
-                      sharding: xc.OpSharding | None = None) -> ir.Value:
+                      sharding: xc.OpSharding | None = None,
+) -> tuple[ir.Value, ir.Value]:
   channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE)
   recv_op = hlo.RecvOp([aval_to_ir_type(out_aval),
                         hlo.TokenType.get()], token, channel_handle,
@@ -2592,7 +2593,7 @@ def custom_call(
   if backend_config is None:
     backend_config_attr = ir.StringAttr.get("")
   elif isinstance(backend_config, (str, bytes)):
-    backend_config_attr = ir.StringAttr.get(backend_config)
+    backend_config_attr = ir.StringAttr.get(backend_config)  # type: ignore
   elif isinstance(backend_config, dict):
     # TODO(necula): it seems that the CustomCallOp constructor requires that
     # backend_config_attr be a string attribute, even though in some cases we
@@ -2661,8 +2662,8 @@ def custom_call(
   op = hlo.CustomCallOp.build_generic(results=result_types,
                                       operands=operands, attributes=attributes)
   if isinstance(backend_config, dict):
-    backend_config_attr = ir.DictAttr.get(backend_config)
-    op.operation.attributes["mhlo.backend_config"] = backend_config_attr
+    op.operation.attributes["mhlo.backend_config"] = ir.DictAttr.get(
+        backend_config)
   return op
 
@@ -2721,7 +2722,7 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
       base_dilations=dense_int_array_v6(base_dilation),
       window_dilations=dense_int_array_v6(window_dilation),
       padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
-                                          shape=(len(padding), 2)))
+                                          shape=[len(padding), 2]))
   reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
   with ir.InsertionPoint(reducer):
     hlo.return_(reducer_body(reducer))
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index e72cd07a8323..6e49ab3d61bd 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2000,7 +2000,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
       return False
   return True
 
-memory_kind_propagate_rule = {}  # type: ignore
+memory_kind_propagate_rule: dict[Any, Any] = {}
 
 @weakref_lru_cache
 def get_out_memory_kinds_via_propagation(closed_jaxpr: core.ClosedJaxpr
@@ -2386,10 +2386,10 @@ def lower_mesh_computation(
       all_args_info=None)
 
 class MeshComputation(stages.XlaLowering):
-  _hlo: ir.Module | None
+  _hlo: ir.Module
   _executable: MeshExecutable | None
 
-  def __init__(self, name: str, hlo: ir.Module | None,
+  def __init__(self, name: str, hlo: ir.Module,
               donated_invars: Sequence[bool], **compile_args):
     self._name = name
     self._hlo = hlo
diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index 9ee151bafab2..ba957417db20 100644
--- a/jax/_src/lax/control_flow/conditionals.py
+++ b/jax/_src/lax/control_flow/conditionals.py
@@ -975,7 +975,9 @@ def _platform_index_lowering(ctx: mlir.LoweringRuleContext,
                              *,
                              platforms: Sequence[Sequence[str]],
                              has_default: bool):
-  def lower_constant(ctx: mlir.LoweringRuleContext, *, i: int) -> mlir.ir.Value:
+  def lower_constant(
+      ctx: mlir.LoweringRuleContext, *, i: int
+  ) -> Sequence[ir.Value]:
     return mlir.ir_constants(np.int32(i))
   platform_rules: dict[str, mlir.LoweringRule] = {}
   for i, ps in enumerate(platforms):
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index dd70621ace34..889de5f92b92 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -1774,7 +1774,7 @@ def broadcast_hlo(
   return out
 
 def _nary_lower_hlo(op: Callable, ctx,
-                    *args: ir.Value | Sequence[ir.Value],
+                    *args: ir.Value,
                     explicit_type=False, **params) -> Sequence[ir.Value]:
   """Lowers an elementwise operator to its MLIR equivalent.
diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index b9ea7345eea4..0c0a0e1c70ff 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -727,7 +727,7 @@ def _replica_groups(axis_env, axis_name, axis_index_groups):
   return replica_groups
 
 def _replica_groups_hlo(replica_groups: Sequence[Sequence[int]]
-                        ) -> ir.DenseIntElementsAttr:
+                        ) -> ir.DenseElementsAttr:
   # Uneven replica groups are padded with -1.
   groups = np.array(list(itertools.zip_longest(*replica_groups, fillvalue=-1)),
                     dtype=np.int64).T
diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index 3c1473b3f595..18db9c764903 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -902,7 +902,9 @@ def snd(t, t_aval):
   double_word_out_aval = out_aval.update(dtype=double_word_dtype)
 
   def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]:
-    x, y = reducer.arguments
+    x: ir.Value
+    y: ir.Value
+    x, y = reducer.arguments  # type: ignore
     assert select_prim is lax.ge_p or select_prim is lax.le_p
     cmp_op = "GE" if select_prim is lax.ge_p else "LE"
     out = hlo.SelectOp(mlir.compare_hlo(fst(x), fst(y), cmp_op), x, y)
diff --git a/jax/_src/maps.py b/jax/_src/maps.py
index c7224019032a..a541cf2af7ab 100644
--- a/jax/_src/maps.py
+++ b/jax/_src/maps.py
@@ -20,7 +20,7 @@
 from functools import wraps, partial, partialmethod, lru_cache
 import itertools as it
 import math
-from typing import Callable, Any, NamedTuple, Union
+from typing import Callable, Any, NamedTuple, Union, cast as type_cast
 
 import numpy as np
 
@@ -631,7 +631,7 @@ def lower(*args, **kwargs):
                            no_kwargs=True)
 
   fun_mapped.lower = lower
-  return fun_mapped
+  return type_cast(stages.Wrapped, fun_mapped)
 
 def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk,
               donated_invars, global_axis_sizes, axis_resources, resource_env, backend,
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 1924128a9a25..6aa5bbdb9247 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -868,7 +868,7 @@ def gradient_along_axis(a, h, axis):
   if len(axis_tuple) == 0:
     return []
 
-  if min([s for i, s in enumerate(a.shape) if i in axis_tuple]) < 2:
+  if min(s for i, s in enumerate(a.shape) if i in axis_tuple) < 2:
     raise ValueError("Shape of array too small to calculate "
                      "a numerical gradient, "
                      "at least 2 elements are required.")
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index e55dfb1594dd..aaaa7d9908a0 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -20,6 +20,8 @@
 from typing import Any, Callable
 from collections.abc import Sequence
 
+from jaxlib.mlir.ir import Module
+
 import jax
 from jax import core as jax_core
 from jax import lax
@@ -341,7 +343,7 @@ def lower_jaxpr_to_module(
     jaxpr: jax_core.Jaxpr,
     dimension_semantics: tuple[str | None, ...] | None,
     mesh: mesh_lib.Mesh | None = None
-) -> ir.Module:
+) -> tuple[Module, tuple[Any, ...]]:
   mosaic_grid_mapping = MosaicGridMapping(
       jaxpr, grid_mapping, dimension_semantics, mesh)
   mosaic_grid_mapping.maybe_compress_grid()
@@ -2199,7 +2201,7 @@ def _device_id_to_logical(
   device_ids = tree_util.tree_leaves(device_id)
   mesh_strides = ctx.lowering_context.mesh_context.mesh_strides
   def _linearize_mesh_indices(*indices):
-    return sum([a * b for a, b in zip(indices, mesh_strides)])
+    return sum(a * b for a, b in zip(indices, mesh_strides))
   lower_ctx = LoweringRuleContext(
       lowering_context=ctx.lowering_context,
       avals_in=[jax_core.ShapedArray((), jnp.int32)] * len(device_ids),
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 1d26342c9839..547806ab2609 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -1140,11 +1140,11 @@ def _mul(x: core.DimSize, y: ir.Value) -> ir.Value:
     if core.is_constant_dim(x):
       x_const = mlir.ir_constant(np.array(x, np.dtype('uint64')))
     else:
-      x_const, = mlir.eval_dynamic_shape(ctx, (x,))
+      x_shape, = mlir.eval_dynamic_shape(ctx, (x,))
       x_const = hlo.convert(
           ir.RankedTensorType.get(
-              (),
-              mlir.dtype_to_ir_type(np.dtype('uint64'))), x_const)
+              [],
+              mlir.dtype_to_ir_type(np.dtype('uint64'))), x_shape)
     x_bcast = mlir.broadcast_in_dim(ctx, x_const, aval_u64,
                                     broadcast_dimensions=[])
     return mlir.hlo.multiply(x_bcast, y)
diff --git a/jax/_src/scipy/ndimage.py b/jax/_src/scipy/ndimage.py
index 0b315097354b..4445a6130d06 100644
--- a/jax/_src/scipy/ndimage.py
+++ b/jax/_src/scipy/ndimage.py
@@ -116,7 +116,7 @@ def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike],
     else:
      all_valid = functools.reduce(operator.and_, validities)
      contribution = jnp.where(all_valid, input_arr[indices], cval)
-    outputs.append(_nonempty_prod(weights) * contribution)
+    outputs.append(_nonempty_prod(weights) * contribution)  # type: ignore
   result = _nonempty_sum(outputs)
   if jnp.issubdtype(input_arr.dtype, jnp.integer):
     result = _round_half_away_from_zero(result)
diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py
index 27b30164ec40..8d4e8779c696 100644
--- a/jax/_src/scipy/signal.py
+++ b/jax/_src/scipy/signal.py
@@ -627,7 +627,7 @@ def pad(x, n, axis=-1):
       return jnp.zeros(x.shape, freq_dtype), jnp.zeros(x.shape, freq_dtype), jnp.zeros(x.shape, result_dtype)
   else:
     if x.size == 0 or y_arr.size == 0:
-      shape = tuple_insert(outershape, min([x.shape[axis], y_arr.shape[axis]]), axis)
+      shape = tuple_insert(outershape, min(x.shape[axis], y_arr.shape[axis]), axis)
       return jnp.zeros(shape, freq_dtype), jnp.zeros(shape, freq_dtype), jnp.zeros(shape, result_dtype)
 
   # Move time-axis to the end
diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py
index e3db0f3112a1..f756f61651b4 100644
--- a/jax/experimental/custom_partitioning.py
+++ b/jax/experimental/custom_partitioning.py
@@ -536,7 +536,7 @@ def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim):
 
 mlir.register_lowering(custom_partitioning_p,
                        _custom_partitioning_lowering_rule)
 
-xc.register_custom_call_partitioner(  # pytype: disable=module-attr
+xc.register_custom_call_partitioner(  # type: ignore  # pytype: disable=module-attr
     _CUSTOM_PARTITIONING_CALL_NAME,
     _custom_partitioning_propagate_user_sharding,
     _custom_partitioning_partition,
diff --git a/jax/experimental/export/_export.py b/jax/experimental/export/_export.py
index 74fe69d6975f..be496aeb5a22 100644
--- a/jax/experimental/export/_export.py
+++ b/jax/experimental/export/_export.py
@@ -570,7 +570,7 @@ def _wrap_main_func(
   context = mlir.make_ir_context()
   with context, ir.Location.unknown(context):
     # Make a copy, do not mutate because it may be cached
-    wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module))
+    wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module))  # type: ignore
     symbol_table = ir.SymbolTable(wrapped_module.operation)
     orig_main = symbol_table["main"]
     orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
@@ -1107,7 +1107,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
   args = tuple(
       wrap_with_sharding(ctx, x, x_aval, x_sharding)
       for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings))
-  submodule = ir.Module.parse(exported.mlir_module())
+  submodule = ir.Module.parse(exported.mlir_module())  # type: ignore
   symtab = ir.SymbolTable(submodule.operation)
   # The called function may have been exported with polymorphic shapes and called
   # now with more refined shapes. We insert hlo.ConvertOp to ensure the module
@@ -1126,7 +1126,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
       submodule,
       dst_symtab=ctx.module_context.symbol_table)
 
-  submodule_args = []
+  submodule_args: list[ir.Value] = []
   # All the platforms for the current lowering must be among the platforms
   # for which the callee was lowered.
   lowering_platforms = ctx.module_context.platforms
diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py
index 051c367b2e2c..af046688067f 100644
--- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py
+++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py
@@ -426,7 +426,7 @@ def assign_unique_ids(objects):
   # MaskInfo class and runtime overhead to perform an indirect lookup. Since
   # having multiple masks per head-shard is not a common case we leave this for
   # future work.
-  max_masks_per_head_shard = max([len(x) for x in head_shard_to_mask_ids])
+  max_masks_per_head_shard = max(len(x) for x in head_shard_to_mask_ids)
   masks_per_head_shard = 1 if max_masks_per_head_shard == 1 else heads_per_shard
 
   unique_masks = [
@@ -697,7 +697,7 @@ def set_block_mask(mask_id: int, q_index: int, kv_index: int, value: int):
   # maintain the SPMD paradigm.
   padding_axis = 1 if is_dkv else 2
 
-  max_size = max([x.shape[padding_axis] for x in block_mask_shards])
+  max_size = max(x.shape[padding_axis] for x in block_mask_shards)
   padded_block_mask_shards = []
   padded_data_next_shards = []
   padded_mask_next_shards = []
@@ -791,7 +791,7 @@ def _shrink_mask_info(
 
   # Pad each row in the non-zero indices to match the width of the longest
   # row. This avoids having jagged rows.
-  max_non_zero_cols = max([len(x) for x in grouped_non_zero_cols])
+  max_non_zero_cols = max(len(x) for x in grouped_non_zero_cols)
   padded_non_zero_cols = []
   padding = -1
   for row in grouped_non_zero_cols:
@@ -856,7 +856,7 @@ def _shrink_mask_info_dkv(
 
   # Pad each col in the non-zero indices to match the height of the longest
   # col. This avoids having jagged cols.
-  max_non_zero_rows = max([len(x) for x in grouped_non_zero_rows])
+  max_non_zero_rows = max(len(x) for x in grouped_non_zero_rows)
   padded_non_zero_rows = []
   padding = -1
   for col in grouped_non_zero_rows:
diff --git a/jax/experimental/rnn.py b/jax/experimental/rnn.py
index 961bee9b1ef1..6412010f4de8 100644
--- a/jax/experimental/rnn.py
+++ b/jax/experimental/rnn.py
@@ -162,7 +162,7 @@ def get_num_params_in_lstm(input_size: int, hidden_size: int, num_layers: int,
   """Get param count in LSTM."""
   layer_shapes = _get_params_shapes_in_lstm(input_size, hidden_size,
                                             num_layers, bidirectional)
-  param_count = sum([math.prod(shape) for shape in layer_shapes])
+  param_count = sum(math.prod(shape) for shape in layer_shapes)
   return param_count
 
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index 223a6d000aae..5de80d7c84bb 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -1483,7 +1483,7 @@ def f(x):
     e1, _, e2 = jaxpr.eqns
     self.assertLen(e1.outvars, 1)  # only primal output
     self.assertLen(e2.invars, 2)   # res and cotangent inputs
-    self.assertEqual(sum([e1.outvars[0] is v for v in e2.invars]), 1)
+    self.assertEqual(sum(e1.outvars[0] is v for v in e2.invars), 1)
 
   @parameterized.parameters(it.product([True, False], repeat=2))
   def test_res_forwarding_optimization_complex(self, jit, remat):
@@ -1506,7 +1506,7 @@ def f(x):
     e1, _, e2 = jaxpr.eqns
     self.assertLen(e1.outvars, 2)  # one primal and one res output
     self.assertLen(e2.invars, 4)   # two res and two cotangent inputs
-    self.assertEqual(sum([e1.outvars[-1] is v for v in e2.invars]), 1)
+    self.assertEqual(sum(e1.outvars[-1] is v for v in e2.invars), 1)
 
   @parameterized.parameters([True, False])
   def test_check_rep_failure_inside_rule(self, jit):