diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 68d10eace7d1..3d85a063b650 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2449,14 +2449,14 @@ def register_lowering(fn, platform=None): if xla_extension_version >= 263: # In XLA, there's a rewriter for an O(N^2) reduce-window implementation. + # TODO(https://github.com/llvm/llvm-project/issues/91883): enable rewrite + # for CPU once the vectorizer crash is fixed.. for platform in ['cuda', 'rocm', 'tpu']: register_lowering( partial(cumred_reduce_window_impl, reduce_window_fn), platform ) - # TODO(https://github.com/llvm/llvm-project/issues/91883) Re-enable rewrite - # for CPU once the vectorizer crash is fixed.. - register_lowering(partial(associative_scan, reduce_fn), 'cpu') + register_lowering(partial(associative_scan, reduce_fn)) else: # Older XLA versions only have this rewrite for TPU. register_lowering(