From 569f92c9272924757acf590478adada79b478f85 Mon Sep 17 00:00:00 2001 From: Corey Adams Date: Thu, 7 Mar 2024 12:27:43 -0600 Subject: [PATCH] Enable automatic detection of distributed variables with any configuration of MPI, as long as mpi4py is available --- .bazelrc | 129 +- .bazelversion | 2 +- .github/workflows/ci-build.yaml | 85 +- .github/workflows/cloud-tpu-ci-nightly.yml | 14 +- .github/workflows/jax-array-api.yml | 6 +- .github/workflows/metal_plugin_ci.yml | 50 + .github/workflows/upstream-nightly.yml | 173 + .github/workflows/wheel_win_x64.yml | 19 +- .github/workflows/windows_ci.yml | 28 +- .pre-commit-config.yaml | 10 +- .readthedocs.yml | 4 +- CHANGELOG.md | 315 +- README.md | 21 +- WORKSPACE | 42 +- benchmarks/api_benchmark.py | 63 +- benchmarks/mosaic/BUILD | 56 + benchmarks/mosaic/matmul_bench.py | 110 + benchmarks/shape_poly_benchmark.py | 6 +- build/BUILD.bazel | 67 + build/build.py | 245 +- build/requirements.in | 24 + build/requirements_lock_3_10.txt | 624 ++ build/requirements_lock_3_11.txt | 613 ++ build/requirements_lock_3_12.txt | 613 ++ build/requirements_lock_3_13.txt | 106 + build/rocm/build_rocm.sh | 2 +- build/rocm/run_single_gpu.py | 6 +- build/test-requirements.txt | 2 + cloud_tpu_colabs/Pmap_Cookbook.ipynb | 2 +- cloud_tpu_colabs/README.md | 10 - docs/Custom_Operation_for_GPUs.md | 1876 +----- docs/Custom_Operation_for_GPUs.py | 527 ++ docs/_static/distributed_data_loading/1.svg | 1 + docs/_static/distributed_data_loading/10.svg | 1 + docs/_static/distributed_data_loading/11.svg | 1 + docs/_static/distributed_data_loading/12.svg | 1 + docs/_static/distributed_data_loading/13.svg | 1 + docs/_static/distributed_data_loading/14.svg | 1 + docs/_static/distributed_data_loading/15.svg | 1 + docs/_static/distributed_data_loading/16.svg | 1 + docs/_static/distributed_data_loading/17.svg | 1 + docs/_static/distributed_data_loading/18.svg | 1 + docs/_static/distributed_data_loading/19.svg | 1 + docs/_static/distributed_data_loading/2.svg | 1 + docs/_static/distributed_data_loading/20.svg | 1 + docs/_static/distributed_data_loading/21.svg | 1 + docs/_static/distributed_data_loading/22.svg | 1 + docs/_static/distributed_data_loading/3.svg | 1 + docs/_static/distributed_data_loading/4.svg | 1 + docs/_static/distributed_data_loading/5.svg | 1 + docs/_static/distributed_data_loading/6.svg | 1 + docs/_static/distributed_data_loading/7.svg | 1 + docs/_static/distributed_data_loading/8.svg | 1 + docs/_static/distributed_data_loading/9.svg | 1 + docs/_static/style.css | 14 + .../advanced-autodiff.md | 10 +- .../advanced-compilation.md | 4 +- docs/_tutorials/advanced-debugging.md | 25 + .../external-callbacks.md | 19 +- .../gradient-checkpointing.md | 4 +- docs/_tutorials/index.rst | 57 + .../jax-primitives.md | 8 +- docs/{tutorials => _tutorials}/jaxpr.md | 10 +- docs/{tutorials => _tutorials}/parallelism.md | 4 +- .../profiling-and-performance.md | 4 +- .../simple-neural-network.md | 4 +- docs/advanced_guide.rst | 1 + docs/aot.md | 73 +- docs/api_compatibility.md | 10 +- docs/async_dispatch.rst | 2 +- docs/autodidax.ipynb | 28 +- docs/autodidax.md | 30 +- docs/autodidax.py | 30 +- .../automatic-differentiation.md | 50 +- .../automatic-vectorization.md | 15 +- docs/beginner_guide.rst | 6 +- docs/build_custom_gpu.sh | 13 + docs/building_on_jax.md | 4 +- docs/conf.py | 85 +- docs/contributing.md | 4 +- docs/cuda_custom_call/BUILD | 63 + docs/cuda_custom_call/Makefile | 35 + .../cuda_custom_call/cuda_custom_call_test.py | 216 + docs/cuda_custom_call/foo.cu.cc | 136 +
docs/debugging.md | 205 + docs/debugging/checkify_guide.md | 2 + docs/debugging/flags.md | 18 +- docs/debugging/index.md | 6 +- docs/debugging/print_breakpoint.md | 4 +- docs/deprecation.md | 25 +- docs/developer.md | 264 +- docs/device_memory_profiling.md | 7 +- docs/distributed_data_loading.md | 466 ++ docs/errors.rst | 1 + docs/export/export.md | 799 +++ docs/export/index.rst | 13 + docs/export/jax2tf.md | 5 + docs/export/shape_poly.md | 663 ++ docs/faq.rst | 36 +- docs/glossary.rst | 20 +- docs/gpu_ops/gpu_ops.cpp | 45 + docs/gpu_ops/kernel_helpers.h | 64 + docs/gpu_ops/kernels.h | 44 + docs/gpu_ops/pybind11_kernel_helpers.h | 41 + docs/gpu_ops/rms_norm_kernels.cu | 970 +++ docs/gpu_performance_tips.md | 23 +- docs/index.rst | 10 +- docs/installation.md | 291 +- docs/investigating_a_regression.md | 2 + docs/jax-101/01-jax-basics.ipynb | 833 --- docs/jax-101/01-jax-basics.md | 383 -- docs/jax-101/02-jitting.ipynb | 673 -- docs/jax-101/02-jitting.md | 338 - docs/jax-101/03-vectorization.ipynb | 369 -- docs/jax-101/03-vectorization.md | 161 - docs/jax-101/04-advanced-autodiff.ipynb | 738 --- docs/jax-101/04-advanced-autodiff.md | 374 -- docs/jax-101/05-random-numbers.ipynb | 509 -- docs/jax-101/05-random-numbers.md | 254 - docs/jax-101/05.1-pytrees.ipynb | 1019 --- docs/jax-101/05.1-pytrees.md | 536 -- docs/jax-101/06-parallelism.ipynb | 912 --- docs/jax-101/06-parallelism.md | 414 -- docs/jax-101/07-state.ipynb | 420 -- docs/jax-101/08-pjit.rst | 9 - docs/jax-101/index.rst | 24 - docs/jax.experimental.key_reuse.rst | 9 - docs/jax.experimental.pallas.rst | 23 + docs/jax.experimental.rst | 3 + .../jax.experimental.serialize_executable.rst | 13 + docs/jax.experimental.shard_map.rst | 12 + docs/jax.export.rst | 52 + docs/jax.extend.ffi.rst | 11 + docs/jax.extend.rst | 1 + docs/jax.lax.rst | 1 - docs/jax.lib.rst | 4 +- docs/jax.nn.initializers.rst | 4 +- docs/jax.nn.rst | 3 + docs/jax.numpy.rst | 4 + docs/jax.random.rst | 1 + docs/jax.rst | 2 + docs/jax.scipy.rst | 31 +- docs/jax.sharding.rst | 3 - docs/jax.stages.rst | 2 +- docs/jax.tree_util.rst | 24 +- docs/jax_array_migration.md | 2 + docs/jaxpr.rst | 33 +- docs/jep/10657-sequencing-effects.md | 4 +- docs/jep/18137-numpy-scipy-scope.md | 4 +- docs/jep/263-prng.md | 1 + docs/jep/9263-typed-keys.md | 20 +- docs/jep/9407-type-promotion.ipynb | 8 +- docs/jep/9407-type-promotion.md | 10 +- docs/{tutorials => }/jit-compilation.md | 73 +- docs/key-concepts.md | 191 + docs/multi_process.md | 48 +- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 33 +- docs/notebooks/Common_Gotchas_in_JAX.md | 35 +- ...tom_derivative_rules_for_Python_code.ipynb | 2 + ...Custom_derivative_rules_for_Python_code.md | 4 +- ...arrays_and_automatic_parallelization.ipynb | 23 +- ...ed_arrays_and_automatic_parallelization.md | 20 +- docs/notebooks/How_JAX_primitives_work.ipynb | 2 + docs/notebooks/How_JAX_primitives_work.md | 4 +- .../Neural_Network_and_Data_Loading.ipynb | 10 +- .../Neural_Network_and_Data_Loading.md | 12 +- .../Writing_custom_interpreters_in_Jax.ipynb | 4 +- .../Writing_custom_interpreters_in_Jax.md | 6 +- docs/notebooks/autodiff_cookbook.ipynb | 10 +- docs/notebooks/autodiff_cookbook.md | 12 +- docs/notebooks/autodiff_remat.ipynb | 4 +- docs/notebooks/autodiff_remat.md | 4 +- docs/notebooks/convolutions.ipynb | 6 +- docs/notebooks/convolutions.md | 8 +- docs/notebooks/external_callbacks.ipynb | 4 +- docs/notebooks/external_callbacks.md | 4 +- .../neural_network_with_tfds_data.ipynb | 10 +- .../neural_network_with_tfds_data.md | 12 +- 
docs/notebooks/quickstart.ipynb | 609 -- docs/notebooks/quickstart.md | 293 - docs/notebooks/shard_map.ipynb | 8 +- docs/notebooks/shard_map.md | 10 +- docs/notebooks/thinking_in_jax.ipynb | 2 + docs/notebooks/thinking_in_jax.md | 4 +- docs/notebooks/vmapped_log_probs.ipynb | 4 +- docs/notebooks/vmapped_log_probs.md | 6 +- docs/notebooks/xmap_tutorial.ipynb | 4 +- docs/notebooks/xmap_tutorial.md | 6 +- docs/pallas/design.md | 435 +- docs/pallas/grid_blockspec.md | 215 + docs/pallas/index.rst | 4 +- docs/pallas/quickstart.ipynb | 177 +- docs/pallas/quickstart.md | 179 +- docs/pallas/tpu/details.rst | 15 +- docs/pallas/tpu/pipelining.ipynb | 306 +- docs/pallas/tpu/pipelining.md | 305 +- docs/persistent_compilation_cache.md | 76 + docs/profiling.md | 20 +- docs/pytrees.md | 4 +- docs/{tutorials => }/quickstart.md | 90 +- docs/{tutorials => }/random-numbers.md | 21 +- docs/rank_promotion_warning.rst | 4 +- docs/requirements.txt | 4 +- docs/sharded-computation.ipynb | 779 +++ docs/sharded-computation.md | 318 + .../07-state.md => stateful-computations.md} | 100 +- docs/tutorials.rst | 18 + docs/tutorials/advanced-debugging.md | 16 - docs/tutorials/debugging.md | 162 - docs/tutorials/index.rst | 57 - docs/tutorials/installation.md | 318 - docs/tutorials/jax-as-accelerated-numpy.md | 8 - docs/tutorials/single-host-sharding.md | 5 - docs/tutorials/stateful-computations.md | 8 - docs/tutorials/thinking-in-jax.md | 417 -- docs/type_promotion.rst | 43 +- docs/user_guides.rst | 4 +- docs/{tutorials => }/working-with-pytrees.md | 154 +- examples/advi.py | 4 +- examples/differentially_private_sgd.py | 2 +- examples/examples_test.py | 13 +- examples/gaussian_process_regression.py | 7 +- examples/jax_cpp/BUILD | 7 +- examples/jax_cpp/main.cc | 13 +- examples/mnist_classifier.py | 2 +- examples/mnist_vae.py | 6 +- jax/BUILD | 110 +- jax/__init__.py | 30 +- jax/_src/abstract_arrays.py | 2 +- jax/_src/ad_checkpoint.py | 54 +- jax/_src/ad_util.py | 3 +- jax/_src/api.py | 486 +- jax/_src/api_util.py | 134 +- jax/_src/array.py | 641 +- jax/_src/basearray.py | 11 +- jax/_src/basearray.pyi | 93 +- jax/_src/blocked_sampler.py | 165 + jax/_src/cache_key.py | 18 +- jax/_src/callback.py | 234 +- jax/_src/checkify.py | 33 +- jax/_src/cloud_tpu_init.py | 35 +- jax/_src/clusters/__init__.py | 4 +- jax/_src/clusters/cloud_tpu_cluster.py | 186 +- jax/_src/clusters/cluster.py | 30 +- jax/_src/clusters/mpi4py_cluster.py | 93 + jax/_src/clusters/ompi_cluster.py | 5 +- jax/_src/clusters/slurm_cluster.py | 5 +- jax/_src/compilation_cache.py | 38 +- jax/_src/compiler.py | 200 +- jax/_src/compute_on.py | 55 + jax/_src/config.py | 422 +- jax/_src/core.py | 459 +- jax/_src/cudnn/fused_attention_stablehlo.py | 1088 ++-- jax/_src/custom_batching.py | 4 +- jax/_src/custom_derivatives.py | 82 +- jax/_src/custom_transpose.py | 7 +- jax/_src/debugger/core.py | 2 +- jax/_src/debugger/web_debugger.py | 41 +- jax/_src/debugging.py | 80 +- jax/_src/deprecations.py | 56 +- jax/_src/dispatch.py | 303 +- jax/_src/distributed.py | 79 +- jax/_src/dlpack.py | 280 +- jax/_src/dtypes.py | 74 +- jax/_src/earray.py | 112 + jax/_src/errors.py | 30 +- .../_src/export/__init__.py | 5 +- jax/{experimental => _src}/export/_export.py | 1044 +-- .../export/serialization.fbs | 10 +- .../export/serialization.py} | 122 +- .../export/serialization_generated.py | 47 +- .../export/shape_poly.py} | 214 +- .../export/shape_poly_decision.py} | 10 +- jax/_src/extend/ffi.py | 257 + jax/_src/extend/random.py | 3 +- jax/_src/gfile_cache.py | 4 + jax/_src/image/scale.py 
| 3 +- .../cpu_ducc_fft.py | 114 - .../pallas/cuda_add_one.py | 47 + .../stablehlo_dynamic_approx_top_k.py | 84 + .../export_back_compat_test_util.py | 35 +- jax/_src/internal_test_util/lax_test_util.py | 4 +- jax/_src/internal_test_util/test_harnesses.py | 61 +- jax/_src/interpreters/ad.py | 7 +- jax/_src/interpreters/batching.py | 12 +- jax/_src/interpreters/mlir.py | 567 +- jax/_src/interpreters/partial_eval.py | 237 +- jax/_src/interpreters/pxla.py | 1122 ++-- jax/_src/interpreters/xla.py | 16 +- jax/_src/jaxpr_util.py | 5 +- jax/_src/lax/ann.py | 68 +- jax/_src/lax/control_flow/common.py | 4 +- jax/_src/lax/control_flow/conditionals.py | 125 +- jax/_src/lax/control_flow/for_loop.py | 8 +- jax/_src/lax/control_flow/loops.py | 611 +- jax/_src/lax/control_flow/solves.py | 2 +- jax/_src/lax/convolution.py | 20 +- jax/_src/lax/eigh.py | 1 + jax/_src/lax/fft.py | 77 - jax/_src/lax/lax.py | 291 +- jax/_src/lax/linalg.py | 151 +- jax/_src/lax/parallel.py | 83 +- jax/_src/lax/qdwh.py | 191 +- jax/_src/lax/slicing.py | 47 +- jax/_src/lax/svd.py | 240 +- jax/_src/lax/windowed_reductions.py | 259 +- jax/_src/lax_reference.py | 29 + jax/_src/layout.py | 143 +- jax/_src/lazy_loader.py | 4 +- jax/_src/lib/BUILD | 1 + jax/_src/lib/__init__.py | 14 +- jax/_src/lib/mlir/__init__.py | 7 +- jax/_src/lib/mlir/dialects/__init__.py | 9 + jax/_src/lib/mosaic_gpu.py | 23 + jax/_src/lib/triton.py | 24 +- jax/_src/linear_util.py | 18 +- jax/_src/lru_cache.py | 184 + jax/_src/maps.py | 51 +- jax/_src/mesh.py | 4 +- jax/_src/nn/functions.py | 134 +- jax/_src/nn/initializers.py | 42 +- jax/_src/numpy/array_methods.py | 17 +- jax/_src/numpy/lax_numpy.py | 3543 ++++++++-- jax/_src/numpy/linalg.py | 1495 ++++- jax/_src/numpy/polynomial.py | 497 +- jax/_src/numpy/reductions.py | 773 ++- jax/_src/numpy/setops.py | 706 +- jax/_src/numpy/ufunc_api.py | 7 +- jax/_src/numpy/ufuncs.py | 752 ++- jax/_src/numpy/util.py | 6 +- jax/_src/numpy/vectorize.py | 6 +- jax/_src/op_shardings.py | 4 +- jax/_src/ops/scatter.py | 21 +- jax/_src/ops/special.py | 15 +- jax/_src/pallas/BUILD | 34 +- jax/_src/pallas/core.py | 230 +- jax/_src/pallas/mosaic/BUILD | 40 +- jax/_src/pallas/mosaic/__init__.py | 39 - jax/_src/pallas/mosaic/core.py | 56 +- .../pallas/mosaic/kernel_regeneration_util.py | 60 - jax/_src/pallas/mosaic/lowering.py | 777 ++- .../pallas/mosaic/pallas_call_registration.py | 59 +- jax/_src/pallas/mosaic/pipeline.py | 2132 +++--- jax/_src/pallas/mosaic/primitives.py | 327 +- jax/_src/pallas/mosaic/random.py | 211 + jax/_src/pallas/mosaic_gpu/BUILD | 63 + jax/_src/pallas/mosaic_gpu/__init__.py | 13 + jax/_src/pallas/mosaic_gpu/lowering.py | 419 ++ .../mosaic_gpu/pallas_call_registration.py | 89 + jax/_src/pallas/pallas_call.py | 786 ++- jax/_src/pallas/primitives.py | 243 +- jax/_src/pallas/triton/BUILD | 37 +- jax/_src/pallas/triton/__init__.py | 20 - jax/_src/pallas/triton/lowering.py | 822 ++- .../pallas/triton/pallas_call_registration.py | 268 +- jax/_src/pallas/triton/primitives.py | 122 + jax/_src/pallas/utils.py | 64 +- jax/_src/partition_spec.py | 2 +- jax/_src/pickle_util.py | 2 +- jax/_src/pjit.py | 1513 +++-- jax/_src/pretty_printer.py | 94 +- jax/_src/prng.py | 238 +- jax/_src/profiler.py | 69 +- jax/_src/random.py | 93 +- jax/_src/scipy/cluster/vq.py | 84 +- jax/_src/scipy/fft.py | 252 +- jax/_src/scipy/integrate.py | 63 +- jax/_src/scipy/linalg.py | 1245 +++- jax/_src/scipy/ndimage.py | 70 +- jax/_src/scipy/optimize/_lbfgs.py | 9 +- jax/_src/scipy/optimize/bfgs.py | 3 +- jax/_src/scipy/optimize/line_search.py | 2 
+- jax/_src/scipy/optimize/minimize.py | 4 +- jax/_src/scipy/signal.py | 542 +- jax/_src/scipy/spatial/transform.py | 65 +- jax/_src/scipy/special.py | 974 ++- jax/_src/scipy/stats/_core.py | 173 +- jax/_src/scipy/stats/bernoulli.py | 112 +- jax/_src/scipy/stats/beta.py | 198 +- jax/_src/scipy/stats/betabinom.py | 61 +- jax/_src/scipy/stats/binom.py | 81 +- jax/_src/scipy/stats/cauchy.py | 234 +- jax/_src/scipy/stats/chi2.py | 198 +- jax/_src/scipy/stats/dirichlet.py | 53 +- jax/_src/scipy/stats/expon.py | 55 +- jax/_src/scipy/stats/gamma.py | 185 +- jax/_src/scipy/stats/gennorm.py | 96 +- jax/_src/scipy/stats/geom.py | 64 +- jax/_src/scipy/stats/kde.py | 28 +- jax/_src/scipy/stats/laplace.py | 76 +- jax/_src/scipy/stats/logistic.py | 160 +- jax/_src/scipy/stats/multinomial.py | 53 +- jax/_src/scipy/stats/multivariate_normal.py | 58 +- jax/_src/scipy/stats/nbinom.py | 81 +- jax/_src/scipy/stats/norm.py | 233 +- jax/_src/scipy/stats/pareto.py | 62 +- jax/_src/scipy/stats/poisson.py | 86 +- jax/_src/scipy/stats/t.py | 55 +- jax/_src/scipy/stats/truncnorm.py | 199 +- jax/_src/scipy/stats/uniform.py | 110 +- jax/_src/scipy/stats/vonmises.py | 56 +- jax/_src/scipy/stats/wrapcauchy.py | 50 +- jax/_src/shard_alike.py | 3 - jax/_src/sharding.py | 127 +- jax/_src/sharding_impls.py | 584 +- jax/_src/sharding_specs.py | 17 - jax/_src/source_info_util.py | 11 +- jax/_src/sourcemap.py | 236 + jax/_src/stages.py | 139 +- jax/_src/state/discharge.py | 100 +- jax/_src/state/indexing.py | 159 +- jax/_src/state/primitives.py | 39 +- jax/_src/state/types.py | 14 +- jax/_src/test_util.py | 630 +- jax/_src/third_party/numpy/LICENSE | 30 - jax/_src/third_party/numpy/__init__.py | 0 jax/_src/third_party/numpy/linalg.py | 211 - jax/_src/third_party/scipy/interpolate.py | 39 +- jax/_src/third_party/scipy/linalg.py | 58 +- jax/_src/third_party/scipy/signal_helper.py | 17 +- jax/_src/tpu_custom_call.py | 294 +- jax/_src/traceback_util.py | 5 +- jax/_src/tree.py | 236 +- jax/_src/tree_util.py | 764 ++- jax/_src/typing.py | 14 + jax/_src/util.py | 50 +- jax/_src/xla_bridge.py | 414 +- jax/config.py | 26 - jax/core.py | 63 +- jax/errors.py | 1 + jax/example_libraries/README.md | 2 +- jax/example_libraries/optimizers.py | 5 +- jax/example_libraries/stax.py | 2 +- jax/experimental/__init__.py | 3 + jax/experimental/array_api/__init__.py | 191 +- jax/experimental/array_api/_array_methods.py | 10 +- .../array_api/_creation_functions.py | 44 +- .../array_api/_data_type_functions.py | 213 +- .../array_api/_elementwise_functions.py | 401 +- jax/experimental/array_api/_fft_functions.py | 51 +- .../array_api/_linear_algebra_functions.py | 130 +- .../array_api/_manipulation_functions.py | 51 +- .../array_api/_searching_functions.py | 39 - jax/experimental/array_api/_set_functions.py | 35 - .../array_api/_sorting_functions.py | 28 - .../array_api/_statistical_functions.py | 34 +- .../array_api/_utility_functions.py | 72 +- jax/experimental/array_api/_version.py | 2 +- jax/experimental/array_api/fft.py | 9 +- jax/experimental/array_api/linalg.py | 12 +- jax/experimental/array_api/skips.txt | 33 +- .../array_serialization/serialization.py | 95 +- .../array_serialization/serialization_test.py | 141 +- jax/experimental/attrs.py | 71 +- jax/experimental/compute_on.py | 17 + jax/experimental/custom_partitioning.py | 46 +- jax/experimental/export/BUILD | 7 +- jax/experimental/export/__init__.py | 79 +- jax/experimental/host_callback.py | 212 +- jax/experimental/jax2tf/BUILD | 1 - jax/experimental/jax2tf/README.md | 1 + 
jax/experimental/jax2tf/call_tf.py | 17 +- .../jax2tf/examples/keras_reuse_main.py | 6 +- .../jax2tf/examples/keras_reuse_main_test.py | 4 +- jax/experimental/jax2tf/examples/mnist_lib.py | 12 +- .../jax2tf/examples/saved_model_lib.py | 8 +- .../jax2tf/examples/saved_model_main.py | 6 +- .../examples/serving/model_server_request.py | 12 +- .../tf_js/quickdraw/input_pipeline.py | 6 +- .../examples/tf_js/quickdraw/quickdraw.py | 3 +- .../zaidalyafeai.github.io/LICENSE | 2 +- .../jax2tf/examples/tflite/mnist/mnist.py | 4 +- jax/experimental/jax2tf/impl_no_xla.py | 6 +- jax/experimental/jax2tf/jax2tf.py | 175 +- .../jax2tf/tests/back_compat_tf_test.py | 9 +- jax/experimental/jax2tf/tests/call_tf_test.py | 226 +- .../jax2tf/tests/control_flow_ops_test.py | 3 +- jax/experimental/jax2tf/tests/converters.py | 5 +- .../jax2tf/tests/cross_compilation_check.py | 8 +- .../tests/flax_models/bilstm_classifier.py | 3 +- .../jax2tf/tests/flax_models/gnn.py | 5 +- .../jax2tf/tests/flax_models/resnet.py | 4 +- .../tests/flax_models/transformer_lm1b.py | 3 +- .../tests/flax_models/transformer_nlp_seq.py | 3 +- .../tests/flax_models/transformer_wmt.py | 3 +- .../jax2tf/tests/jax2tf_limitations.py | 43 +- jax/experimental/jax2tf/tests/jax2tf_test.py | 45 +- .../tests/jax_primitives_coverage_test.py | 2 +- .../jax2tf/tests/model_harness.py | 4 +- .../jax2tf/tests/primitives_test.py | 9 +- .../jax2tf/tests/savedmodel_test.py | 5 +- .../jax2tf/tests/shape_poly_test.py | 41 +- .../jax2tf/tests/sharding_test.py | 219 +- jax/experimental/jax2tf/tests/tf_test_util.py | 26 +- jax/experimental/jet.py | 31 +- jax/experimental/key_reuse/__init__.py | 38 +- jax/experimental/key_reuse/_core.py | 384 +- jax/experimental/layout.py | 4 +- jax/experimental/maps.py | 24 +- jax/experimental/mesh_utils.py | 491 +- jax/experimental/mosaic/gpu/__init__.py | 704 ++ jax/experimental/mosaic/gpu/dsl.py | 50 + jax/experimental/mosaic/gpu/examples/BUILD | 65 + .../mosaic/gpu/examples/flash_attention.py | 650 ++ .../mosaic/gpu/examples/matmul.py | 530 ++ .../mosaic/gpu/fragmented_array.py | 661 ++ jax/experimental/mosaic/gpu/profiler.py | 300 + jax/experimental/mosaic/gpu/utils.py | 784 +++ jax/experimental/mosaic/gpu/wgmma.py | 502 ++ jax/experimental/multihost_utils.py | 30 +- jax/experimental/ode.py | 4 +- jax/experimental/pallas/__init__.py | 20 +- jax/experimental/pallas/gpu.py | 12 +- jax/experimental/pallas/ops/__init__.py | 7 - jax/experimental/pallas/ops/gpu/__init__.py | 13 + .../pallas/ops/{ => gpu}/attention.py | 64 +- .../pallas/ops/gpu/decode_attention.py | 31 +- .../pallas/ops/{ => gpu}/layer_norm.py | 2 - .../pallas/ops/{ => gpu}/rms_norm.py | 0 .../pallas/ops/{ => gpu}/softmax.py | 0 jax/experimental/pallas/ops/tpu/all_gather.py | 18 +- .../pallas/ops/tpu/flash_attention.py | 12 +- .../ops/tpu/megablox/__init__.py} | 10 +- .../pallas/ops/tpu/megablox/common.py | 63 + .../pallas/ops/tpu/megablox/gmm.py | 799 +++ .../pallas/ops/tpu/megablox/ops.py | 109 + .../ops/tpu/paged_attention/__init__.py} | 5 +- .../paged_attention/paged_attention_kernel.py | 649 ++ .../tpu/paged_attention/quantization_utils.py | 107 + .../splash_attention_kernel.py | 212 +- .../splash_attention/splash_attention_mask.py | 45 +- .../splash_attention_mask_info.py | 35 +- jax/experimental/pallas/tpu.py | 69 +- jax/experimental/rnn.py | 4 +- jax/experimental/shard_map.py | 253 +- jax/experimental/slab/djax.py | 187 + jax/experimental/slab/slab.py | 365 ++ jax/experimental/sparse/ad.py | 8 +- jax/experimental/sparse/bcoo.py | 35 +- 
jax/experimental/sparse/linalg.py | 2 +- jax/experimental/sparse/nm.py | 244 + jax/experimental/sparse/random.py | 2 +- jax/experimental/sparse/test_util.py | 9 +- jax/experimental/sparse/transform.py | 26 +- jax/export.py | 36 + jax/extend/BUILD | 27 +- jax/extend/__init__.py | 2 + jax/extend/{core.py => backend.py} | 6 +- jax/extend/core/__init__.py | 33 + jax/extend/core/primitives.py | 229 + jax/extend/ffi.py | 23 + jax/interpreters/ad.py | 20 +- jax/interpreters/mlir.py | 1 - jax/interpreters/pxla.py | 1 - jax/interpreters/xla.py | 59 +- jax/lax/__init__.py | 18 +- jax/nn/__init__.py | 18 +- jax/numpy/__init__.py | 19 +- jax/numpy/__init__.pyi | 528 +- jax/numpy/linalg.py | 11 +- jax/random.py | 74 +- jax/scipy/linalg.py | 1 + jax/scipy/special.py | 21 +- jax/sharding.py | 22 +- jax/stages.py | 2 + jax/tools/BUILD | 2 +- jax/tools/build_utils.py | 21 +- jax/tools/jax_to_ir.py | 18 +- jax/tools/pgo_nsys_converter.py | 62 + .../_dtypes.py => tools/toolchains/BUILD} | 31 +- jax/tree_util.py | 59 +- jax/util.py | 1 + jax/version.py | 20 +- jax_plugins/BUILD.bazel | 5 +- jax_plugins/cuda/__init__.py | 2 +- jax_plugins/cuda/plugin_setup.py | 24 +- jax_plugins/rocm/BUILD.bazel | 55 + jax_plugins/rocm/__init__.py | 91 + jax_plugins/rocm/plugin_pyproject.toml | 3 + jax_plugins/rocm/plugin_setup.py | 70 + jax_plugins/rocm/pyproject.toml | 3 + jax_plugins/rocm/setup.py | 66 + jaxlib/BUILD | 57 +- jaxlib/cpu/BUILD | 56 +- jaxlib/cpu/cpu_kernels.cc | 29 +- jaxlib/cpu/ducc_fft.cc | 66 - jaxlib/cpu/ducc_fft_kernels.cc | 142 - jaxlib/cpu/lapack.cc | 184 +- jaxlib/cpu/lapack.h | 61 + jaxlib/cpu/lapack_kernels.cc | 247 +- jaxlib/cpu/lapack_kernels.h | 132 +- jaxlib/cpu/lapack_kernels_using_lapack.cc | 164 +- jaxlib/cpu_feature_guard.c | 12 +- jaxlib/cuda/BUILD | 72 +- jaxlib/cuda/versions_helpers.cc | 19 +- jaxlib/cuda_plugin_extension.cc | 87 +- jaxlib/gpu/BUILD | 3 + jaxlib/gpu/blas.cc | 2 +- jaxlib/gpu/cholesky_update_kernel.cc | 53 + jaxlib/gpu/cholesky_update_kernel.cu.cc | 136 + jaxlib/gpu/cholesky_update_kernel.h | 50 + jaxlib/gpu/gpu_kernels.cc | 15 + jaxlib/gpu/linalg.cc | 39 +- jaxlib/gpu/lu_pivot_kernels.cc | 78 +- jaxlib/gpu/lu_pivot_kernels.cu.cc | 20 +- jaxlib/gpu/lu_pivot_kernels.h | 37 +- jaxlib/gpu/prng.cc | 4 + jaxlib/gpu/prng_kernels.cc | 38 + jaxlib/gpu/prng_kernels.cu.cc | 17 + jaxlib/gpu/prng_kernels.h | 12 + jaxlib/gpu/solver.cc | 3 +- jaxlib/gpu/sparse.cc | 2 +- jaxlib/gpu/triton_kernels.cc | 18 +- jaxlib/gpu/vendor.h | 36 +- jaxlib/gpu_linalg.py | 70 +- jaxlib/gpu_prng.py | 68 +- jaxlib/gpu_rnn.py | 2 +- jaxlib/gpu_solver.py | 42 +- jaxlib/gpu_sparse.py | 18 +- jaxlib/gpu_triton.py | 17 +- jaxlib/hlo_helpers.py | 15 +- jaxlib/jax.bzl | 55 +- jaxlib/kernel_nanobind_helpers.h | 2 +- jaxlib/mlir/BUILD.bazel | 91 +- jaxlib/mlir/_mlir_libs/BUILD.bazel | 107 +- .../mlir/_mlir_libs/register_jax_dialects.cc | 9 + jaxlib/mlir/_mlir_libs/tpu_ext.cc | 94 +- jaxlib/mosaic/BUILD | 6 + .../dialect/tpu/integrations/c/tpu_dialect.cc | 8 + .../dialect/tpu/integrations/c/tpu_dialect.h | 3 + jaxlib/mosaic/dialect/tpu/layout.cc | 142 +- jaxlib/mosaic/dialect/tpu/layout.h | 73 +- jaxlib/mosaic/dialect/tpu/tpu.td | 156 +- jaxlib/mosaic/dialect/tpu/tpu_dialect.cc | 3 + jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 7 +- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 101 +- .../tpu/transforms/apply_vector_layout.cc | 2009 ++++-- .../tpu/transforms/apply_vector_layout.h | 8 +- .../tpu/transforms/debug_assert_insertion.cc | 33 +- .../tpu/transforms/infer_memref_layout.cc | 66 +- 
.../tpu/transforms/infer_memref_layout.h | 7 +- .../tpu/transforms/infer_vector_layout.cc | 912 ++- .../tpu/transforms/linalg_vectorization.cc | 462 +- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 70 +- jaxlib/mosaic/dialect/tpu/util.h | 7 + jaxlib/mosaic/gpu/BUILD | 198 + jaxlib/mosaic/gpu/custom_call.cc | 446 ++ .../gpu/integrations/c/passes.cc} | 34 +- .../gpu/integrations/c/passes.h} | 24 +- jaxlib/mosaic/gpu/launch_lowering.cc | 329 + jaxlib/mosaic/gpu/launch_lowering.h | 27 + jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 63 + jaxlib/mosaic/gpu/pass_boilerplate.h | 64 + jaxlib/mosaic/gpu/passes.cc | 77 + jaxlib/mosaic/gpu/passes.h | 27 + jaxlib/mosaic/gpu/runtime.cc | 145 + jaxlib/mosaic/python/BUILD | 2 +- jaxlib/rocm/BUILD.bazel | 69 +- jaxlib/rocm_plugin_extension.cc | 152 + jaxlib/setup.py | 34 +- jaxlib/tools/BUILD.bazel | 56 +- jaxlib/tools/LICENSE.txt | 28 - ...ls_wheel.py => build_gpu_kernels_wheel.py} | 75 +- jaxlib/tools/build_gpu_plugin_wheel.py | 68 +- jaxlib/tools/build_wheel.py | 66 +- jaxlib/triton/dialect.py | 3 +- platform_mappings | 11 + pyproject.toml | 59 +- setup.py | 127 +- tests/BUILD | 137 +- tests/ann_test.py | 4 +- tests/aot_test.py | 3 +- tests/api_test.py | 435 +- tests/api_util_test.py | 7 +- tests/array_api_test.py | 126 +- tests/array_interoperability_test.py | 85 +- tests/array_test.py | 224 +- tests/attrs_test.py | 95 +- tests/batching_test.py | 8 +- tests/blocked_sampler_test.py | 90 + tests/cache_key_test.py | 17 + tests/cholesky_update_test.py | 72 + tests/clear_backends_test.py | 7 +- tests/compilation_cache_test.py | 469 +- tests/config_test.py | 73 + tests/core_test.py | 10 +- tests/custom_linear_solve_test.py | 3 +- tests/custom_object_test.py | 4 +- tests/custom_root_test.py | 3 +- tests/debug_nans_test.py | 30 +- tests/debugger_test.py | 40 +- tests/debugging_primitives_test.py | 26 +- tests/deprecation_test.py | 21 +- tests/dtypes_test.py | 89 +- tests/dynamic_api_test.py | 54 +- tests/export_back_compat_test.py | 92 +- tests/export_harnesses_multi_platform_test.py | 48 +- tests/export_test.py | 1075 ++- tests/extend_test.py | 95 +- tests/for_loop_test.py | 7 +- tests/fused_attention_stablehlo_test.py | 362 +- tests/generated_fun_test.py | 4 +- tests/gpu_memory_flags_test.py | 8 +- tests/heap_profiler_test.py | 3 +- tests/host_callback_test.py | 721 +- tests/host_callback_to_tf_test.py | 39 +- tests/image_test.py | 4 +- tests/infeed_test.py | 7 +- tests/jax_to_ir_test.py | 2 + tests/jaxpr_effects_test.py | 120 +- tests/jaxpr_util_test.py | 6 +- tests/jet_test.py | 7 +- tests/key_reuse_test.py | 148 +- tests/lax_autodiff_test.py | 3 +- tests/lax_control_flow_test.py | 139 +- tests/lax_metal_test.py | 5773 +++++++++++++++++ tests/lax_numpy_einsum_test.py | 43 +- tests/lax_numpy_indexing_test.py | 34 + tests/lax_numpy_operators_test.py | 29 +- tests/lax_numpy_reducers_test.py | 115 +- tests/lax_numpy_test.py | 448 +- tests/lax_numpy_ufuncs_test.py | 3 +- tests/lax_numpy_vectorize_test.py | 3 +- tests/lax_scipy_special_functions_test.py | 58 +- tests/lax_scipy_spectral_dac_test.py | 4 +- tests/lax_scipy_test.py | 24 +- tests/lax_test.py | 937 ++- tests/lax_vmap_op_test.py | 3 +- tests/lax_vmap_test.py | 3 +- tests/layout_test.py | 355 +- tests/linalg_test.py | 174 +- tests/lobpcg_test.py | 14 +- tests/logging_test.py | 65 +- tests/lru_cache_test.py | 155 + tests/memories_test.py | 1643 ++--- tests/mesh_utils_test.py | 342 +- tests/metadata_test.py | 4 +- tests/mock_gpu_test.py | 18 +- tests/mosaic/BUILD | 89 + tests/mosaic/flash_attention_test.py | 78 
+ tests/mosaic/gpu_test.py | 976 +++ tests/mosaic/matmul_test.py | 149 + tests/mosaic_test.py | 4 +- tests/multi_device_test.py | 37 +- tests/multibackend_test.py | 17 +- tests/multiprocess_gpu_test.py | 58 +- tests/mutable_array_test.py | 227 + tests/name_stack_test.py | 18 +- tests/nn_test.py | 70 +- tests/ode_test.py | 3 +- tests/optimizers_test.py | 3 +- tests/package_structure_test.py | 4 +- tests/pallas/BUILD | 246 +- tests/pallas/all_gather_test.py | 3 +- .../pallas/export_back_compat_pallas_test.py | 61 + tests/pallas/export_pallas_test.py | 64 + tests/pallas/gmm_test.py | 390 ++ tests/pallas/gpu_attention_test.py | 31 +- tests/pallas/indexing_test.py | 147 +- tests/pallas/mosaic_gpu_test.py | 118 + tests/pallas/ops_test.py | 82 + tests/pallas/paged_attention_kernel_test.py | 182 + tests/pallas/pallas_call_tpu_test.py | 1656 +++-- tests/pallas/pallas_pipeline_tpu_test.py | 1519 +++++ tests/pallas/pallas_shape_poly_test.py | 220 + tests/pallas/pallas_test.py | 2037 ++++-- tests/pallas/splash_attention_kernel_test.py | 45 +- tests/pallas/splash_attention_mask_test.py | 3 +- tests/pallas/tpu/BUILD | 47 + tests/pallas/tpu/pallas_random_test.py | 211 + tests/pgle_test.py | 233 +- tests/pickle_test.py | 3 +- tests/pjit_test.py | 587 +- tests/pmap_test.py | 112 +- tests/polynomial_test.py | 4 +- tests/pretty_printer_test.py | 36 + tests/profiler_test.py | 3 +- tests/python_callback_test.py | 216 +- tests/pytorch_interoperability_test.py | 6 - tests/qdwh_test.py | 273 +- tests/random_lax_test.py | 111 +- tests/random_test.py | 57 +- tests/scipy_fft_test.py | 5 +- tests/scipy_interpolate_test.py | 4 +- tests/scipy_ndimage_test.py | 8 +- tests/scipy_optimize_test.py | 4 +- tests/scipy_signal_test.py | 4 +- tests/scipy_spatial_test.py | 3 +- tests/scipy_stats_test.py | 41 +- tests/shape_poly_test.py | 91 +- tests/shard_alike_test.py | 29 +- tests/shard_map_test.py | 360 +- tests/source_info_test.py | 3 +- tests/sourcemap_test.py | 89 + tests/sparse_bcoo_bcsr_test.py | 29 +- tests/sparse_nm_test.py | 209 + tests/sparse_test.py | 27 +- tests/sparsify_test.py | 4 +- tests/stack_test.py | 4 +- tests/state_test.py | 62 +- tests/stax_test.py | 4 +- tests/svd_test.py | 20 + tests/third_party/scipy/line_search_test.py | 4 +- tests/transfer_guard_test.py | 18 +- tests/tree_util_test.py | 249 +- tests/util_test.py | 4 +- tests/x64_context_test.py | 9 +- tests/xla_bridge_test.py | 43 +- tests/xmap_test.py | 46 +- third_party/nanobind/BUILD.bazel | 22 - third_party/nanobind/workspace.bzl | 26 - third_party/robin_map/BUILD.bazel | 17 - third_party/robin_map/workspace.bzl | 26 - third_party/xla/workspace.bzl | 5 +- 814 files changed, 80657 insertions(+), 33340 deletions(-) create mode 100644 .github/workflows/metal_plugin_ci.yml create mode 100644 .github/workflows/upstream-nightly.yml create mode 100644 benchmarks/mosaic/BUILD create mode 100644 benchmarks/mosaic/matmul_bench.py create mode 100644 build/BUILD.bazel create mode 100644 build/requirements.in create mode 100644 build/requirements_lock_3_10.txt create mode 100644 build/requirements_lock_3_11.txt create mode 100644 build/requirements_lock_3_12.txt create mode 100644 build/requirements_lock_3_13.txt create mode 100644 docs/Custom_Operation_for_GPUs.py create mode 100644 docs/_static/distributed_data_loading/1.svg create mode 100644 docs/_static/distributed_data_loading/10.svg create mode 100644 docs/_static/distributed_data_loading/11.svg create mode 100644 docs/_static/distributed_data_loading/12.svg create mode 100644 
docs/_static/distributed_data_loading/13.svg create mode 100644 docs/_static/distributed_data_loading/14.svg create mode 100644 docs/_static/distributed_data_loading/15.svg create mode 100644 docs/_static/distributed_data_loading/16.svg create mode 100644 docs/_static/distributed_data_loading/17.svg create mode 100644 docs/_static/distributed_data_loading/18.svg create mode 100644 docs/_static/distributed_data_loading/19.svg create mode 100644 docs/_static/distributed_data_loading/2.svg create mode 100644 docs/_static/distributed_data_loading/20.svg create mode 100644 docs/_static/distributed_data_loading/21.svg create mode 100644 docs/_static/distributed_data_loading/22.svg create mode 100644 docs/_static/distributed_data_loading/3.svg create mode 100644 docs/_static/distributed_data_loading/4.svg create mode 100644 docs/_static/distributed_data_loading/5.svg create mode 100644 docs/_static/distributed_data_loading/6.svg create mode 100644 docs/_static/distributed_data_loading/7.svg create mode 100644 docs/_static/distributed_data_loading/8.svg create mode 100644 docs/_static/distributed_data_loading/9.svg rename docs/{tutorials => _tutorials}/advanced-autodiff.md (99%) rename docs/{tutorials => _tutorials}/advanced-compilation.md (81%) create mode 100644 docs/_tutorials/advanced-debugging.md rename docs/{tutorials => _tutorials}/external-callbacks.md (97%) rename docs/{tutorials => _tutorials}/gradient-checkpointing.md (99%) create mode 100644 docs/_tutorials/index.rst rename docs/{tutorials => _tutorials}/jax-primitives.md (98%) rename docs/{tutorials => _tutorials}/jaxpr.md (97%) rename docs/{tutorials => _tutorials}/parallelism.md (82%) rename docs/{tutorials => _tutorials}/profiling-and-performance.md (81%) rename docs/{tutorials => _tutorials}/simple-neural-network.md (66%) rename docs/{tutorials => }/automatic-differentiation.md (83%) rename docs/{tutorials => }/automatic-vectorization.md (87%) create mode 100644 docs/build_custom_gpu.sh create mode 100644 docs/cuda_custom_call/BUILD create mode 100644 docs/cuda_custom_call/Makefile create mode 100644 docs/cuda_custom_call/cuda_custom_call_test.py create mode 100644 docs/cuda_custom_call/foo.cu.cc create mode 100644 docs/debugging.md create mode 100644 docs/distributed_data_loading.md create mode 100644 docs/export/export.md create mode 100644 docs/export/index.rst create mode 100644 docs/export/jax2tf.md create mode 100644 docs/export/shape_poly.md create mode 100644 docs/gpu_ops/gpu_ops.cpp create mode 100644 docs/gpu_ops/kernel_helpers.h create mode 100644 docs/gpu_ops/kernels.h create mode 100644 docs/gpu_ops/pybind11_kernel_helpers.h create mode 100644 docs/gpu_ops/rms_norm_kernels.cu delete mode 100644 docs/jax-101/01-jax-basics.ipynb delete mode 100644 docs/jax-101/01-jax-basics.md delete mode 100644 docs/jax-101/02-jitting.ipynb delete mode 100644 docs/jax-101/02-jitting.md delete mode 100644 docs/jax-101/03-vectorization.ipynb delete mode 100644 docs/jax-101/03-vectorization.md delete mode 100644 docs/jax-101/04-advanced-autodiff.ipynb delete mode 100644 docs/jax-101/04-advanced-autodiff.md delete mode 100644 docs/jax-101/05-random-numbers.ipynb delete mode 100644 docs/jax-101/05-random-numbers.md delete mode 100644 docs/jax-101/05.1-pytrees.ipynb delete mode 100644 docs/jax-101/05.1-pytrees.md delete mode 100644 docs/jax-101/06-parallelism.ipynb delete mode 100644 docs/jax-101/06-parallelism.md delete mode 100644 docs/jax-101/07-state.ipynb delete mode 100644 docs/jax-101/08-pjit.rst delete mode 100644 
docs/jax-101/index.rst create mode 100644 docs/jax.experimental.pallas.rst create mode 100644 docs/jax.experimental.serialize_executable.rst create mode 100644 docs/jax.experimental.shard_map.rst create mode 100644 docs/jax.export.rst create mode 100644 docs/jax.extend.ffi.rst rename docs/{tutorials => }/jit-compilation.md (65%) create mode 100644 docs/key-concepts.md delete mode 100644 docs/notebooks/quickstart.ipynb delete mode 100644 docs/notebooks/quickstart.md create mode 100644 docs/pallas/grid_blockspec.md create mode 100644 docs/persistent_compilation_cache.md rename docs/{tutorials => }/quickstart.md (59%) rename docs/{tutorials => }/random-numbers.md (92%) create mode 100644 docs/sharded-computation.ipynb create mode 100644 docs/sharded-computation.md rename docs/{jax-101/07-state.md => stateful-computations.md} (60%) create mode 100644 docs/tutorials.rst delete mode 100644 docs/tutorials/advanced-debugging.md delete mode 100644 docs/tutorials/debugging.md delete mode 100644 docs/tutorials/index.rst delete mode 100644 docs/tutorials/installation.md delete mode 100644 docs/tutorials/jax-as-accelerated-numpy.md delete mode 100644 docs/tutorials/single-host-sharding.md delete mode 100644 docs/tutorials/stateful-computations.md delete mode 100644 docs/tutorials/thinking-in-jax.md rename docs/{tutorials => }/working-with-pytrees.md (77%) create mode 100644 jax/_src/blocked_sampler.py create mode 100644 jax/_src/clusters/mpi4py_cluster.py create mode 100644 jax/_src/compute_on.py create mode 100644 jax/_src/earray.py rename jaxlib/cpu/_ducc_fft.pyi => jax/_src/export/__init__.py (74%) rename jax/{experimental => _src}/export/_export.py (56%) rename jax/{experimental => _src}/export/serialization.fbs (93%) rename jax/{experimental/export/_serialization.py => _src/export/serialization.py} (81%) rename jax/{experimental => _src}/export/serialization_generated.py (96%) rename jax/{experimental/export/_shape_poly.py => _src/export/shape_poly.py} (92%) rename jax/{experimental/export/_shape_poly_decision.py => _src/export/shape_poly_decision.py} (98%) create mode 100644 jax/_src/extend/ffi.py delete mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/cpu_ducc_fft.py create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/pallas/cuda_add_one.py create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_approx_top_k.py create mode 100644 jax/_src/lib/mosaic_gpu.py create mode 100644 jax/_src/lru_cache.py delete mode 100644 jax/_src/pallas/mosaic/kernel_regeneration_util.py create mode 100644 jax/_src/pallas/mosaic/random.py create mode 100644 jax/_src/pallas/mosaic_gpu/BUILD create mode 100644 jax/_src/pallas/mosaic_gpu/__init__.py create mode 100644 jax/_src/pallas/mosaic_gpu/lowering.py create mode 100644 jax/_src/pallas/mosaic_gpu/pallas_call_registration.py create mode 100644 jax/_src/pallas/triton/primitives.py create mode 100644 jax/_src/sourcemap.py delete mode 100644 jax/_src/third_party/numpy/LICENSE delete mode 100644 jax/_src/third_party/numpy/__init__.py delete mode 100644 jax/_src/third_party/numpy/linalg.py delete mode 100644 jax/config.py delete mode 100644 jax/experimental/array_api/_searching_functions.py delete mode 100644 jax/experimental/array_api/_set_functions.py delete mode 100644 jax/experimental/array_api/_sorting_functions.py create mode 100644 jax/experimental/compute_on.py create mode 100644 jax/experimental/mosaic/gpu/__init__.py create mode 100644 jax/experimental/mosaic/gpu/dsl.py create mode 
100644 jax/experimental/mosaic/gpu/examples/BUILD create mode 100644 jax/experimental/mosaic/gpu/examples/flash_attention.py create mode 100644 jax/experimental/mosaic/gpu/examples/matmul.py create mode 100644 jax/experimental/mosaic/gpu/fragmented_array.py create mode 100644 jax/experimental/mosaic/gpu/profiler.py create mode 100644 jax/experimental/mosaic/gpu/utils.py create mode 100644 jax/experimental/mosaic/gpu/wgmma.py create mode 100644 jax/experimental/pallas/ops/gpu/__init__.py rename jax/experimental/pallas/ops/{ => gpu}/attention.py (88%) rename jax/experimental/pallas/ops/{ => gpu}/layer_norm.py (99%) rename jax/experimental/pallas/ops/{ => gpu}/rms_norm.py (100%) rename jax/experimental/pallas/ops/{ => gpu}/softmax.py (100%) rename jax/experimental/{array_api/_constants.py => pallas/ops/tpu/megablox/__init__.py} (81%) create mode 100644 jax/experimental/pallas/ops/tpu/megablox/common.py create mode 100644 jax/experimental/pallas/ops/tpu/megablox/gmm.py create mode 100644 jax/experimental/pallas/ops/tpu/megablox/ops.py rename jax/experimental/{array_api/_indexing_functions.py => pallas/ops/tpu/paged_attention/__init__.py} (85%) create mode 100644 jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py create mode 100644 jax/experimental/pallas/ops/tpu/paged_attention/quantization_utils.py create mode 100644 jax/experimental/slab/djax.py create mode 100644 jax/experimental/slab/slab.py create mode 100644 jax/experimental/sparse/nm.py create mode 100644 jax/export.py rename jax/extend/{core.py => backend.py} (86%) create mode 100644 jax/extend/core/__init__.py create mode 100644 jax/extend/core/primitives.py create mode 100644 jax/extend/ffi.py create mode 100644 jax/tools/pgo_nsys_converter.py rename jax/{experimental/array_api/_dtypes.py => tools/toolchains/BUILD} (56%) create mode 100644 jax_plugins/rocm/BUILD.bazel create mode 100644 jax_plugins/rocm/__init__.py create mode 100644 jax_plugins/rocm/plugin_pyproject.toml create mode 100644 jax_plugins/rocm/plugin_setup.py create mode 100644 jax_plugins/rocm/pyproject.toml create mode 100644 jax_plugins/rocm/setup.py delete mode 100644 jaxlib/cpu/ducc_fft.cc delete mode 100644 jaxlib/cpu/ducc_fft_kernels.cc create mode 100644 jaxlib/cpu/lapack.h create mode 100644 jaxlib/gpu/cholesky_update_kernel.cc create mode 100644 jaxlib/gpu/cholesky_update_kernel.cu.cc create mode 100644 jaxlib/gpu/cholesky_update_kernel.h create mode 100644 jaxlib/mosaic/gpu/BUILD create mode 100644 jaxlib/mosaic/gpu/custom_call.cc rename jaxlib/{cpu/ducc_fft.fbs => mosaic/gpu/integrations/c/passes.cc} (54%) rename jaxlib/{cpu/ducc_fft_kernels.h => mosaic/gpu/integrations/c/passes.h} (56%) create mode 100644 jaxlib/mosaic/gpu/launch_lowering.cc create mode 100644 jaxlib/mosaic/gpu/launch_lowering.h create mode 100644 jaxlib/mosaic/gpu/mosaic_gpu_ext.cc create mode 100644 jaxlib/mosaic/gpu/pass_boilerplate.h create mode 100644 jaxlib/mosaic/gpu/passes.cc create mode 100644 jaxlib/mosaic/gpu/passes.h create mode 100644 jaxlib/mosaic/gpu/runtime.cc create mode 100644 jaxlib/rocm_plugin_extension.cc rename jaxlib/tools/{build_cuda_kernels_wheel.py => build_gpu_kernels_wheel.py} (60%) create mode 100644 platform_mappings create mode 100644 tests/blocked_sampler_test.py create mode 100644 tests/cholesky_update_test.py create mode 100644 tests/config_test.py create mode 100644 tests/lax_metal_test.py create mode 100644 tests/lru_cache_test.py create mode 100644 tests/mosaic/BUILD create mode 100644 tests/mosaic/flash_attention_test.py create 
mode 100644 tests/mosaic/gpu_test.py create mode 100644 tests/mosaic/matmul_test.py create mode 100644 tests/mutable_array_test.py create mode 100644 tests/pallas/export_back_compat_pallas_test.py create mode 100644 tests/pallas/export_pallas_test.py create mode 100644 tests/pallas/gmm_test.py create mode 100644 tests/pallas/mosaic_gpu_test.py create mode 100644 tests/pallas/ops_test.py create mode 100644 tests/pallas/paged_attention_kernel_test.py create mode 100644 tests/pallas/pallas_pipeline_tpu_test.py create mode 100644 tests/pallas/pallas_shape_poly_test.py create mode 100644 tests/pallas/tpu/BUILD create mode 100644 tests/pallas/tpu/pallas_random_test.py create mode 100644 tests/pretty_printer_test.py create mode 100644 tests/sourcemap_test.py create mode 100644 tests/sparse_nm_test.py delete mode 100644 third_party/nanobind/BUILD.bazel delete mode 100644 third_party/nanobind/workspace.bzl delete mode 100644 third_party/robin_map/BUILD.bazel delete mode 100644 third_party/robin_map/workspace.bzl diff --git a/.bazelrc b/.bazelrc index 6fae63d6081d..a08df1017f91 100644 --- a/.bazelrc +++ b/.bazelrc @@ -79,12 +79,20 @@ build:nvcc_clang --action_env=TF_CUDA_CLANG="1" build:nvcc_clang --action_env=TF_NVCC_CLANG="1" build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc +# Requires MSVC and LLVM to be installed +build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl +build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl +build:win_clang --compiler=clang-cl + # Later Bazel flag values override earlier values. # TODO(jieying): remove enable_gpu and xla_python_enable_gpu from build:cuda # after the pluin is released. build:cuda_plugin --@xla//xla/python:enable_gpu=false build:cuda_plugin --define=xla_python_enable_gpu=false +build:rocm_plugin --@xla//xla/python:enable_gpu=false +build:rocm_plugin --define=xla_python_enable_gpu=false + # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. 
JAX sets RPATH to # point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA @@ -208,75 +216,47 @@ build:rbe_linux --host_linkopt=-lm build:rbe_cpu_linux_base --config=rbe_linux build:rbe_cpu_linux_base --config=cuda_clang build:rbe_cpu_linux_base --action_env=TF_NVCC_CLANG="1" -build:rbe_cpu_linux_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_cpu_linux_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_cpu_linux_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_platform//:platform" -build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_platform//:platform" -build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_platform//:platform" - -build:rbe_cpu_linux_py3.9 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.9" -build:rbe_cpu_linux_py3.9 --python_path="/usr/local/bin/python3.9" -build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.10" -build:rbe_cpu_linux_py3.10 --python_path="/usr/local/bin/python3.10" -build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.11" -build:rbe_cpu_linux_py3.11 --python_path="/usr/local/bin/python3.11" -build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.12" -build:rbe_cpu_linux_py3.12 --python_path="/usr/local/bin/python3.12" +build:rbe_cpu_linux_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" +build:rbe_cpu_linux_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" +build:rbe_cpu_linux_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" +build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" +build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" + +build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10" +build:rbe_cpu_linux_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" +build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11" +build:rbe_cpu_linux_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" +build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.12" +build:rbe_cpu_linux_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" build:rbe_linux_cuda_base --config=rbe_linux build:rbe_linux_cuda_base --config=cuda 
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda11.8_nvcc_base --config=rbe_linux_cuda_base -build:rbe_linux_cuda11.8_nvcc_base --config=cuda_clang -build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_NVCC_CLANG="1" -build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_CUDA_VERSION=11 -build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_CUDNN_VERSION=8 -build:rbe_linux_cuda11.8_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.8" -build:rbe_linux_cuda11.8_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" -build:rbe_linux_cuda11.8_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda11.8_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda11.8_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda11.8_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform" -build:rbe_linux_cuda11.8_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform" -build:rbe_linux_cuda11.8_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform" -build:rbe_linux_cuda11.8_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda" -build:rbe_linux_cuda11.8_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_nccl" -build:rbe_linux_cuda11.8_nvcc_py3.9 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.9" -build:rbe_linux_cuda11.8_nvcc_py3.9 --python_path="/usr/local/bin/python3.9" -build:rbe_linux_cuda11.8_nvcc_py3.10 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.10" -build:rbe_linux_cuda11.8_nvcc_py3.10 --python_path="/usr/local/bin/python3.10" -build:rbe_linux_cuda11.8_nvcc_py3.11 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.11" -build:rbe_linux_cuda11.8_nvcc_py3.11 --python_path="/usr/local/bin/python3.11" -build:rbe_linux_cuda11.8_nvcc_py3.12 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.12" -build:rbe_linux_cuda11.8_nvcc_py3.12 --python_path="/usr/local/bin/python3.12" - build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base build:rbe_linux_cuda12.3_nvcc_base --config=cuda_clang build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_NVCC_CLANG="1" build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDA_VERSION=12 -build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDNN_VERSION=8 +build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDNN_VERSION=9 build:rbe_linux_cuda12.3_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12" build:rbe_linux_cuda12.3_nvcc_base 
--action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" -build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda12.3_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda12.3_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda12.3_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda" -build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_nccl" +build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda12.3_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" +build:rbe_linux_cuda12.3_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" +build:rbe_linux_cuda12.3_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" +build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda" +build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_nccl" # RBE machines have an older CUDA driver version, so we have to enable driver forward compatibility build:rbe_linux_cuda12.3_nvcc_base --test_env=LD_LIBRARY_PATH=/usr/local/cuda/compat -build:rbe_linux_cuda12.3_nvcc_py3.9 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.9" -build:rbe_linux_cuda12.3_nvcc_py3.9 --python_path="/usr/local/bin/python3.9" -build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.10" -build:rbe_linux_cuda12.3_nvcc_py3.10 --python_path="/usr/local/bin/python3.10" -build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.11" -build:rbe_linux_cuda12.3_nvcc_py3.11 --python_path="/usr/local/bin/python3.11" -build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.12" -build:rbe_linux_cuda12.3_nvcc_py3.12 --python_path="/usr/local/bin/python3.12" 
+build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10" +build:rbe_linux_cuda12.3_nvcc_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" +build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11" +build:rbe_linux_cuda12.3_nvcc_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" +build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.12" +build:rbe_linux_cuda12.3_nvcc_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" # These you may need to change for your own GCP project. build:tensorflow_testing_rbe --project_id=tensorflow-testing @@ -303,23 +283,36 @@ build:cross_compile_linux_arm64 --cpu=aarch64 build:cross_compile_linux_arm64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite build:rbe_cross_compile_base --config=rbe -# JAX depends on some local Python headers that are configured as Genrule. They -# are present on the local host machine but not on the remote execution machine, -# leading to build failures. To resolve the issue, the following line is added -# to make sure all Genrule targets are excuted locally. -build:rbe_cross_compile_base --strategy=Genrule=standalone -# Due to the above strategy, all Genrule commands are executed locally, but the -# following actions invoke tools (E.g `flatc`, `llvm-tblgen`, etc.) that are -# only executabe on the RBE (x86) machine, so the strategy_regexp options are -# added to override and run the actions using remote strategy. -build:rbe_cross_compile_base --strategy_regexp='Generating code from table.*=remote' -build:rbe_cross_compile_base --strategy_regexp='Generating flatbuffer files.*=remote' -build:rbe_cross_compile_base --strategy_regexp='Executing genrule @llvm-project.*=remote' # RBE cross-compile configs for Linux Aarch64 build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 build:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base # END LINUX AARCH64 CROSS-COMPILE CONFIGS + +# START MACOS CROSS-COMPILE CONFIGS +build:cross_compile_macos_x86 --config=cross_compile_base +build:cross_compile_macos_x86 --config=nonccl +# Target Catalina (10.15) as the minimum supported OS +build:cross_compile_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 + +# Set the target CPU to Darwin x86 +build:cross_compile_macos_x86 --platforms=@xla//tools/toolchains/cross_compile/config:darwin_x86_64 +build:cross_compile_macos_x86 --cpu=darwin +build:cross_compile_macos_x86 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +# When RBE cross-compiling for macOS, we need to explicitly register the +# toolchain. Otherwise, oddly, RBE complains that a "docker container must be +# specified". +build:cross_compile_macos_x86 --extra_toolchains=@xla//tools/toolchains/cross_compile/config:macos-x86-cross-compile-cc-toolchain +# Map --platforms=darwin_x86_64 to --cpu=darwin and vice-versa to make selects() +# and transistions that use these flags work. The flag --platform_mappings needs +# to be set to a file that exists relative to the package path roots. 
+build:cross_compile_macos_x86 --platform_mappings=platform_mappings + +# RBE cross-compile configs for Darwin x86 +build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86 +build:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base +# END MACOS CROSS-COMPILE CONFIGS + # END CROSS-COMPILE CONFIGS ############################################################################# diff --git a/.bazelversion b/.bazelversion index 5e3254243a3b..f22d756da39d 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -6.1.2 +6.5.0 diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index c465d42c4db2..cc05acbfd7a3 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -1,8 +1,8 @@ name: CI # We test all supported Python versions as follows: -# - 3.9 : Documentation build -# - 3.9 : Part of Matrix with NumPy dispatch +# - 3.10 : Documentation build +# - 3.10 : Part of Matrix with NumPy dispatch # - 3.10 : Part of Matrix # - 3.11 : Part of Matrix @@ -32,7 +32,7 @@ jobs: if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python 3.11 - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # ratchet:actions/setup-python@v5 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5 with: python-version: 3.11 - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet: pre-commit/action@v3.0.1 @@ -45,8 +45,8 @@ jobs: matrix: # Test the oldest and newest supported Python versions here. include: - - name-prefix: "with 3.9" - python-version: "3.9" + - name-prefix: "with 3.10" + python-version: "3.10" os: ubuntu-20.04-16core enable-x64: 1 prng-upgrade: 1 @@ -65,7 +65,7 @@ jobs: if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # ratchet:actions/setup-python@v5 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -74,7 +74,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@ab5e6d0c87105b4c9c2047343972218f562e4319 # ratchet: actions/cache@v4 + uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -108,7 +108,7 @@ jobs: timeout-minutes: 10 strategy: matrix: - python-version: [3.9] + python-version: ['3.10'] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 @@ -117,7 +117,7 @@ jobs: if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # ratchet:actions/setup-python@v5 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - 
name: Get pip cache dir @@ -126,7 +126,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@ab5e6d0c87105b4c9c2047343972218f562e4319 # ratchet: actions/cache@v4 + uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -136,11 +136,12 @@ jobs: - name: Test documentation env: XLA_FLAGS: "--xla_force_host_platform_device_count=8" + JAX_TRACEBACK_FILTERING: "off" JAX_ARRAY: 1 PY_COLORS: 1 run: | - pytest -n auto --tb=short docs - pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas + pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md + pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/maps.py documentation_render: @@ -149,7 +150,7 @@ jobs: timeout-minutes: 10 strategy: matrix: - python-version: [3.9] + python-version: ['3.10'] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 @@ -158,7 +159,7 @@ jobs: if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # ratchet:actions/setup-python@v5 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -167,7 +168,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@ab5e6d0c87105b4c9c2047343972218f562e4319 # ratchet: actions/cache@v4 + uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -177,3 +178,57 @@ jobs: - name: Render documentation run: | sphinx-build --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html + + + jax2tf_test: + name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})" + runs-on: ${{ matrix.os }} + timeout-minutes: 30 + strategy: + matrix: + # Test the oldest supported Python version here. 
+ include: + - python-version: "3.10" + os: ubuntu-latest + enable-x64: 0 + num_generated_cases: 10 + steps: + - name: Cancel previous + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} + if: ${{github.ref != 'refs/heads/main'}} + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip wheel + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip cache + uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} + - name: Install dependencies + run: | + pip install .[minimum-jaxlib] tensorflow -r build/test-requirements.txt + + - name: Run tests + env: + JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }} + JAX_ENABLE_X64: ${{ matrix.enable-x64 }} + JAX_ENABLE_CHECKS: true + JAX_SKIP_SLOW_TESTS: true + PY_COLORS: 1 + run: | + pip install -e . + echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" + echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" + echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" + echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" + pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py + \ No newline at end of file diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index b4356ba64d7c..c7d3afd4c4a9 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -25,12 +25,16 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] - tpu-type: ["v3-8", "v4-8", "v5e-4"] - name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu-type }})" + tpu: [ + {type: "v3-8", cores: "4"}, + {type: "v4-8", cores: "4"}, + {type: "v5e-8", cores: "8"} + ] + name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: LIBTPU_OLDEST_VERSION_DATE: 20240228 ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }} - runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"] + runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"] timeout-minutes: 120 defaults: run: @@ -84,7 +88,7 @@ jobs: PY_COLORS: 1 run: | # Run single-accelerator tests in parallel - JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=4 --tb=short \ + JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \ --maxfail=20 -m "not multiaccelerator" tests examples # Run multi-accelerator across all chips python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests @@ -95,5 +99,5 @@ jobs: curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \ --header 'Content-Type: application/json' \ --data-raw "{ - 'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu-type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID' + 'text': 
'\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu.type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID' }" diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 59367832ad17..3a1dd863c2b0 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -25,11 +25,11 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: '4f83bb3ee9146c30cf997f1527588b7a8b8ee6db' # Latest commit as of 2024-02-16 + ref: '33f2d2ea2f3dd2b3ceeeb4519d55e08096184149' # Latest commit as of 2024-05-28 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # ratchet:actions/setup-python@v5 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -42,4 +42,4 @@ jobs: JAX_ENABLE_X64: 'true' run: | cd ${GITHUB_WORKSPACE}/array-api-tests - pytest --ci array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/jax/experimental/array_api/skips.txt + pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/jax/experimental/array_api/skips.txt diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml new file mode 100644 index 000000000000..6e67841460e9 --- /dev/null +++ b/.github/workflows/metal_plugin_ci.yml @@ -0,0 +1,50 @@ +# JAX-Metal plugin CI + +name: Jax-Metal CI +on: + schedule: + - cron: "0 12 * * *" # Daily at 12:00 UTC + workflow_dispatch: # allows triggering the workflow run manually + pull_request: # Automatically trigger on pull requests affecting this file + branches: + - main + paths: + - '**workflows/metal_plugin_ci.yml' + +jobs: + jax-metal-plugin-test: + + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + jaxlib-version: ["pypi_latest", "nightly"] + name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})" + runs-on: [self-hosted, macOS, ARM64] + + steps: + - name: Get repo + uses: actions/checkout@v4 + with: + path: jax + - name: Setup build and test enviroment + run: | + rm -rf ${GITHUB_WORKSPACE}/jax-metal-venv + python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv + source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate + pip install -U pip numpy wheel + pip install absl-py pytest + if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then + pip install --pre jaxlib \ + -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + fi; + cd jax + pip install . + pip install jax-metal + - name: Run test + run: | + source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate + export ENABLE_PJRT_COMPATIBILITY=1 + cd jax + pytest tests/lax_metal_test.py + + diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml new file mode 100644 index 000000000000..60b31f40bc98 --- /dev/null +++ b/.github/workflows/upstream-nightly.yml @@ -0,0 +1,173 @@ +name: CI - with Numpy/Scipy nightly wheels (nightly) +# This configures a github action that runs the JAX test suite against nightly development builds +# of numpy and scipy, in order to catch issues with new package versions prior to their release. 
+# Unlike our other CI, this is one that we expect to fail frequently, and so we don't run it against +# every commit and PR in the repository. Rather, we run it on a schedule, and failures lead to an +# issue being created or updated. +# Portions of this adapted from https://github.com/pydata/xarray/blob/main/.github/workflows/upstream-dev-ci.yaml + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +on: + schedule: + - cron: "0 12 * * *" # Daily at 12:00 UTC + workflow_dispatch: # allows triggering the workflow run manually + pull_request: # Automatically trigger on pull requests affecting this file + branches: + - main + paths: + - '**workflows/upstream-nightly.yml' + +jobs: + upstream-dev: + runs-on: ubuntu-20.04-16core + permissions: + contents: read + checks: write # for upload-artifact + defaults: + run: + shell: bash -l {0} + strategy: + fail-fast: false + matrix: + python-version: ["3.12"] + outputs: + artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }} + steps: + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install JAX test requirements + run: | + pip install -r build/test-requirements.txt + pip install pytest-reportlog + - name: Install numpy & scipy development versions + run: | + pip install \ + -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \ + --no-deps \ + --pre \ + --upgrade \ + numpy \ + scipy + python -c "import numpy; print(f'{numpy.__version__=}')" + python -c "import scipy; print(f'{scipy.__version__=}')" + - name: Install JAX + run: | + pip install .[ci] + - name: Run tests + if: success() + id: status + env: + JAX_NUM_GENERATED_CASES: 1 + JAX_ENABLE_X64: true + JAX_ENABLE_CHECKS: true + JAX_SKIP_SLOW_TESTS: true + PY_COLORS: 1 + run: | + echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" + echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" + echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" + echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" + pytest -n auto --tb=short -rf --maxfail=20 \ + --report-log output-${{ matrix.python-version }}-log.jsonl \ + tests \ + || ( + echo 'ARTIFACTS_AVAILABLE=true' >> $GITHUB_OUTPUT && false + ) + - name: Upload artifacts + if: | + failure() + && steps.status.outcome == 'failure' + && github.event_name == 'schedule' + && github.repository == 'google/jax' + uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # ratchet: actions/upload-artifact@v4 + with: + name: output-${{ matrix.python-version }}-log.jsonl + path: output-${{ matrix.python-version }}-log.jsonl + retention-days: 5 + + report: + name: report + needs: upstream-dev + permissions: + contents: read + issues: write + if: | + failure() + && github.event_name == 'schedule' + && needs.upstream-dev.outputs.artifacts_availability == 'true' + runs-on: ubuntu-latest + defaults: + run: + shell: bash + steps: + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5 + with: + python-version: "3.x" + - uses: actions/download-artifact@65a9edc5881444af0b9093a5e628f2fe47ea3b2e # ratchet:actions/download-artifact@v4 + with: + path: /tmp/workspace/logs + - name: install requirements + run: | 
+ python -m pip install pytest + - name: Move all log files into a single directory + run: | + rsync -a /tmp/workspace/logs/output-*/ ./logs + ls -R ./logs + cat logs/*.jsonl > pytest-logs.txt + python .github/workflows/parse_logs.py pytest-logs.txt --outfile=parsed-logs.txt + - name: Report failures + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # ratchet:actions/github-script@v7 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const parsed_logs = fs.readFileSync('parsed-logs.txt', 'utf8'); + const title = "⚠️ Nightly upstream-dev CI failed ⚠️" + const workflow_url = `https://github.com/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}` + const issue_body = `[Workflow Run URL](${workflow_url})\n${parsed_logs}` + // Run GraphQL query against GitHub API to find the most recent open issue used for reporting failures + const query = `query($owner:String!, $name:String!, $creator:String!, $label:String!){ + repository(owner: $owner, name: $name) { + issues(first: 1, states: OPEN, filterBy: {createdBy: $creator, labels: [$label]}, orderBy: {field: CREATED_AT, direction: DESC}) { + edges { + node { + body + id + number + } + } + } + } + }`; + const variables = { + owner: context.repo.owner, + name: context.repo.repo, + label: 'CI', + creator: "github-actions[bot]" + } + const result = await github.graphql(query, variables) + // If no issue is open, create a new issue, + // else update the body of the existing issue. + if (result.repository.issues.edges.length === 0) { + github.rest.issues.create({ + owner: variables.owner, + repo: variables.name, + body: issue_body, + title: title, + labels: [variables.label] + }) + } else { + github.rest.issues.update({ + owner: variables.owner, + repo: variables.name, + issue_number: result.repository.issues.edges[0].node.number, + body: issue_body + }) + } diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 96e07ea3f3f9..6433fb66039e 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -11,9 +11,9 @@ jobs: strategy: fail-fast: false # Don't stop all wheel builds if one has a test failure. 
matrix: - os: [win-2019-16core] + os: [windows-2019-32core] arch: [AMD64] - pyver: ['3.9', '3.10', '3.11', '3.12'] + pyver: ['3.10', '3.11', '3.12'] name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build runs-on: ${{ matrix.os }} @@ -23,9 +23,12 @@ jobs: with: access_token: ${{ github.token }} + - name: Install LLVM/Clang + run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # ratchet:actions/setup-python@v5 + - uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.pyver }} cache: 'pip' @@ -36,10 +39,14 @@ jobs: JAXLIB_RELEASE: true run: | python -m pip install -r build/test-requirements.txt + python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py --bazel_options=--color=yes --verbose + python.exe build\build.py ` + --bazel_options=--color=yes ` + --bazel_options=--config=win_clang ` + --verbose - - uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # ratchet: actions/upload-artifact@v4 + - uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # ratchet: actions/upload-artifact@v4 with: name: wheels-${{ matrix.os }}-${{ matrix.pyver }} path: ${{ github.workspace }}\dist\*.whl @@ -54,4 +61,4 @@ jobs: python -m pip install -e ${{ github.workspace }} python -m pip install --no-index --find-links ${{ github.workspace }}\dist jaxlib echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" - pytest -n auto --tb=short tests examples \ No newline at end of file + pytest -n auto --tb=short tests examples diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index e80004bbf3ec..92f9355ae200 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -3,6 +3,8 @@ on: schedule: - cron: "0 12 * * *" # Daily at 12:00 UTC workflow_dispatch: # allows triggering the workflow run manually + pull_request: + types: [ labeled ] # allow force-windows-run label env: DISTUTILS_USE_SDK: 1 @@ -10,13 +12,14 @@ env: jobs: win-wheels: + if: ${{ (github.event.action != 'labeled') || (github.event.label.name == 'windows:force-run')}} strategy: fail-fast: true matrix: - os: [win-2019-16core] + os: [windows-2019-32core] arch: [AMD64] - pyver: ['3.9'] - name: ${{ matrix.os }} CI build + pyver: ['3.10'] + name: Windows CI build runs-on: ${{ matrix.os }} steps: @@ -25,16 +28,14 @@ jobs: with: access_token: ${{ github.token }} - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - with: - path: jax + - name: Install LLVM/Clang + run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 with: - repository: openxla/xla - path: xla + path: jax - - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # ratchet:actions/setup-python@v5 + - uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.pyver }} cache: 'pip' @@ -46,10 +47,13 @@ jobs: run: | cd jax python -m pip install -r build/test-requirements.txt + python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - 
python.exe build\build.py ('--bazel_options=--override_repository=xla=${{ github.workspace }}\xla' -replace '\\','\\') --bazel_options=--color=yes + python.exe build\build.py ` + --bazel_options=--color=yes ` + --bazel_options=--config=win_clang - - uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # ratchet: actions/upload-artifact@v4 + - uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # ratchet: actions/upload-artifact@v4 with: name: wheels path: ${{ github.workspace }}\jax\dist\*.whl @@ -65,4 +69,4 @@ jobs: python -m pip install -e ${{ github.workspace }}\jax python -m pip install --no-index --find-links ${{ github.workspace }}\jax\dist jaxlib echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" - pytest -n auto --tb=short tests examples \ No newline at end of file + pytest -n auto --tb=short tests examples diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0a82aaed45ed..3405c5e9d76b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-ast - id: check-merge-conflict @@ -26,21 +26,21 @@ repos: files: \.py$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.5 + rev: v0.4.4 hooks: - id: ruff - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.8.0' + rev: 'v1.10.0' hooks: - id: mypy files: (jax/|tests/typing_test\.py) exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead - additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.23, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4] + additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.27, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4] args: [--config=pyproject.toml] - repo: https://github.com/mwouts/jupytext - rev: v1.16.0 + rev: v1.16.1 hooks: - id: jupytext args: [--sync] diff --git a/.readthedocs.yml b/.readthedocs.yml index e878805a1377..6f807aa82377 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -6,9 +6,9 @@ version: 2 build: - os: "ubuntu-20.04" + os: "ubuntu-22.04" tools: - python: "3.9" + python: "3.10" # Build documentation in the docs/ directory with Sphinx sphinx: diff --git a/CHANGELOG.md b/CHANGELOG.md index a4fe3af345c8..6c5cd34d1828 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,15 +6,290 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html). Remember to align the itemized text with the first line of an item within a list. --> -## jax 0.4.26 +## jax 0.4.31 + +* Changes + * The minimum Python version is now 3.10. 3.10 will remain the minimum + supported version until July 2025. + * The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum + supported version until December 2024. + * {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output + of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point. + * `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be + installed either as a part of local CUDA installation, or via NVIDIA's CUDA + pip wheels. + * {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to + be passed *before* `index_map`. The old argument order is deprecated and + will be removed in a future release. +* Deprecations + * Removed a number of previously-deprecated internal APIs related to + polymorphic shapes. 
From {mod}`jax.core`: removed `canonicalize_shape`, + `dimension_as_value`, `definitely_equal`, and `symbolic_equal_dim`. + +## jaxlib 0.4.31 + +* Bug fixes + * Fixed a bug that meant that negative static_argnums to a jit were mishandled + by the jit dispatch fast path. + +## jax 0.4.30 (June 18, 2024) + +* Changes + * JAX supports ml_dtypes >= 0.2. In 0.4.29 release, the ml_dtypes version was + bumped to 0.4.0 but this has been rolled back in this release to give users + of both TensorFlow and JAX more time to migrate to a newer TensorFlow + release. + * `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e. + * jax now depends on jaxlib directly. This change was enabled by the CUDA + plugin switch: there are no longer multiple jaxlib variants. You can install + a CPU-only jax with `pip install jax`, no extras required. + * Added an API for exporting and serializing JAX functions. This used + to exist in `jax.experimental.export` (which is being deprecated), + and will now live in `jax.export`. + See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html). + +* Deprecations + * Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed + in a future release. + * Hashing of tracers is deprecated, and will lead to a `TypeError` in a future JAX + release. This previously was the case, but there was an inadvertent regression in + the last several JAX releases. + * `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead. + See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export). + * Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays + `x` and `y`, `x.astype(y)` will raise a warning. To silence it use `x.astype(y.dtype)`. + * `jax.xla_computation` is deprecated and will be removed in a future release. + Please use the AOT APIs to get the same functionality as `jax.xla_computation`. + * `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with + `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`. + * You can also use `.out_info` property of `jax.stages.Lowered` to get the + output information (like tree structure, shape and dtype). + * For cross-backend lowering, you can replace + `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with + `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`. + + +## jaxlib 0.4.30 (June 18, 2024) + + * Support for monolithic CUDA jaxlibs has been dropped. You must use the + plugin-based installation (`pip install jax[cuda12]` or + `pip install jax[cuda12_local]`). + +## jax 0.4.29 (June 10, 2024) + +* Changes + * We anticipate that this will be the last release of JAX and jaxlib + supporting a monolithic CUDA jaxlib. Future releases will use the CUDA + plugin jaxlib (e.g. `pip install jax[cuda12]`). + * JAX now requires ml_dtypes version 0.4.0 or newer. + * Removed backwards-compatibility support for old usage of the + `jax.experimental.export` API. It is not possible anymore to use + `from jax.experimental.export import export`, and instead you should use + `from jax.experimental import export`. + The removed functionality has been deprecated since 0.4.24. + * Added `is_leaf` argument to {func}`jax.tree.all` & {func}`jax.tree_util.tree_all`. + +* Deprecations + * `jax.sharding.XLACompatibleSharding` is deprecated. Please use + `jax.sharding.Sharding`. 
+ * `jax.experimental.Exported.in_shardings` has been renamed as + `jax.experimental.Exported.in_shardings_hlo`. Same for `out_shardings`. + The old names will be removed after 3 months. + * Removed a number of previously-deprecated APIs: + * from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape` + * from {mod}`jax.lax`: `tie_in` + * from {mod}`jax.nn`: `normalize` + * from {mod}`jax.interpreters.xla`: `backend_specific_translations`, + `translations`, `register_translation`, `xla_destructure`, + `TranslationRule`, `TranslationContext`, `XlaOp`. + * The ``tol`` argument of {func}`jax.numpy.linalg.matrix_rank` is being + deprecated and will soon be removed. Use `rtol` instead. + * The ``rcond`` argument of {func}`jax.numpy.linalg.pinv` is being + deprecated and will soon be removed. Use `rtol` instead. + * The deprecated `jax.config` submodule has been removed. To configure JAX + use `import jax` and then reference the config object via `jax.config`. + * {mod}`jax.random` APIs no longer accept batched keys, where previously + some did unintentionally. Going forward, we recommend explicit use of + {func}`jax.vmap` in such cases. + * In {func}`jax.scipy.special.beta`, the `x` and `y` parameters have been + renamed to `a` and `b` for consistency with other `beta` APIs. + +* New Functionality + * Added {func}`jax.experimental.Exported.in_shardings_jax` to construct + shardings that can be used with the JAX APIs from the HloShardings + that are stored in the `Exported` objects. + +## jaxlib 0.4.29 (June 10, 2024) + +* Bug fixes + * Fixed a bug where XLA sharded some concatenation operations incorrectly, + which manifested as an incorrect output for cumulative reductions (#21403). + * Fixed a bug where XLA:CPU miscompiled certain matmul fusions + (https://github.com/openxla/xla/pull/13301). + * Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396). * Deprecations + * `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will + raise an error in a future version of jax. `None` is only a tree-prefix of + itself. To preserve the current behavior, you can ask `jax.tree.map` to + treat `None` as a leaf value by writing: + `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`. + +## jax 0.4.28 (May 9, 2024) + +* Bug fixes + * Reverted a change to `make_jaxpr` that was breaking Equinox (#21116). + +* Deprecations & removals + * The ``kind`` argument to {func}`jax.numpy.sort` and {func}`jax.numpy.argsort` + is now removed. Use `stable=True` or `stable=False` instead. + * Removed ``get_compute_capability`` from the ``jax.experimental.pallas.gpu`` + module. Use the ``compute_capability`` attribute of a GPU device, returned + by {func}`jax.devices` or {func}`jax.local_devices`, instead. + * The ``newshape`` argument to {func}`jax.numpy.reshape`is being deprecated + and will soon be removed. Use `shape` instead. + +* Changes + * The minimum jaxlib version of this release is 0.4.27. + +## jaxlib 0.4.28 (May 9, 2024) + +* Bug fixes + * Fixes a memory corruption bug in the type name of Array and JIT Python + objects in Python 3.10 or earlier. + * Fixed a warning `'+ptx84' is not a recognized feature for this target` + under CUDA 12.4. + * Fixed a slow compilation problem on CPU. + +* Changes + * The Windows build is now built with Clang instead of MSVC. 
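
For illustration only (this note and snippet are editorial, not part of the patch): a minimal sketch of the replacements suggested by the jax 0.4.28 / 0.4.29 deprecation notes above, assuming JAX 0.4.28 or newer. The arrays and the `f`, `a`, `b` names are placeholders.

```python
import jax
import jax.numpy as jnp

x = jnp.array([3, 1, 2])
s = jnp.sort(x, stable=True)                   # replaces the removed `kind` argument
r = jnp.reshape(jnp.arange(6), shape=(2, 3))   # `shape` replaces the deprecated `newshape`

# get_compute_capability was removed from jax.experimental.pallas.gpu;
# read the attribute from a device instead (present only on GPU devices).
cc = getattr(jax.devices()[0], "compute_capability", None)

# Preserving the old `jax.tree.map` handling of None vs. non-None leaves,
# using the workaround quoted in the jaxlib 0.4.29 deprecation note above.
f = lambda u, v: u + v
a, b = {"k": None, "m": 1.0}, {"k": 2.0, "m": 3.0}
out = jax.tree.map(lambda u, v: None if u is None else f(u, v),
                   a, b, is_leaf=lambda u: u is None)
```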
+ + +## jax 0.4.27 (May 7, 2024) + +* New Functionality + * Added {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum`, + following their addition in the array API 2023 standard, soon to be + adopted by NumPy. + * Added a new config option `jax_cpu_collectives_implementation` to select the + implementation of cross-process collective operations used by the CPU backend. + Choices available are `'none'`(default), `'gloo'` and `'mpi'` (requires jaxlib 0.4.26). + If set to `'none'`, cross-process collective operations are disabled. + +* Changes + * {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` + and {func}`jax.debug.callback` now use {class}`jax.Array` instead + of {class}`np.ndarray`. You can recover the old behavior by transforming + the arguments via `jax.tree.map(np.asarray, args)` before passing them + to the callback. + * `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning + False where `complex_arr` is equal to `0 + 0j`, and True otherwise. + * `core.Token` now is a non-trivial class which wraps a `jax.Array`. It could + be created and threaded in and out of computations to build up dependency. + The singleton object `core.token` has been removed, users now should create + and use fresh `core.Token` objects instead. + * On GPU, the Threefry PRNG implementation no longer lowers to a kernel call + by default. This choice can improve runtime memory usage at a compile-time + cost. Prior behavior, which produces a kernel call, can be recovered with + `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new + default causes issues, please file a bug. Otherwise, we intend to remove + this flag in a future release. + +* Deprecations & Removals + * Pallas now exclusively uses XLA for compiling kernels on GPU. The old + lowering pass via Triton Python APIs has been removed and the + `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect. + * {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and + `a_max` are deprecated in favor of `x` (positional only), `min`, and + `max` ({jax-issue}`20550`). + * The `device()` method of JAX arrays has been removed, after being deprecated + since JAX v0.4.21. Use `arr.devices()` instead. + * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` + is deprecated; empty inputs to softmax are now supported without setting this. + * In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames` + now leads to an error rather than a warning. + * The minimum jaxlib version is now 0.4.23. + * The {func}`jax.numpy.hypot` function now issues a deprecation warning when + passing complex-valued inputs to it. This will raise an error when the + deprecation is completed. + * Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and + related functions now raise an error, following a similar change in NumPy. + * The config option `jax_cpu_enable_gloo_collectives` is deprecated. + Use `jax.config.update('jax_cpu_collectives_implementation', 'gloo')` instead. + * The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have + been removed after being deprecated in JAX v0.4.22. Instead use + {attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`. + * The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now + positional-only, following deprecation of the keywords in JAX v0.4.21. + * Non-array arguments to functions in {mod}`jax.lax.linalg` now must be + specified by keyword. 
Previously, this raised a DeprecationWarning. + * Array-like arguments are now required in several :func:`jax.numpy` APIs, + including {func}`~jax.numpy.apply_along_axis`, + {func}`~jax.numpy.apply_over_axes`, {func}`~jax.numpy.inner`, + {func}`~jax.numpy.outer`, {func}`~jax.numpy.cross`, + {func}`~jax.numpy.kron`, and {func}`~jax.numpy.lexsort`. + +* Bug fixes + * {func}`jax.numpy.astype` will now always return a copy when `copy=True`. + Previously, no copy would be made when the output array would have the same + dtype as the input array. This may result in some increased memory usage. + The default value is set to `copy=False` to preserve backwards compatibility. + +## jaxlib 0.4.27 (May 7, 2024) + +## jax 0.4.26 (April 3, 2024) + +* New Functionality + * Added {func}`jax.numpy.trapezoid`, following the addition of this function in + NumPy 2.0. + +* Changes + * Complex-valued {func}`jax.numpy.geomspace` now chooses the logarithmic spiral + branch consistent with that of NumPy 2.0. + * The behavior of `lax.rng_bit_generator`, and in turn the `'rbg'` + and `'unsafe_rbg'` PRNG implementations, under `jax.vmap` [has + changed](https://github.com/google/jax/issues/19085) so that + mapping over keys results in random generation only from the first + key in the batch. + * Docs now use `jax.random.key` for construction of PRNG key arrays + rather than `jax.random.PRNGKey`. + +* Deprecations & Removals * {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`. + * {func}`jax.clear_backends` is deprecated as it does not necessarily do what + its name suggests and can lead to unexpected consequences, e.g., it will not + destroy existing backends and release corresponding owned resources. Use + {func}`jax.clear_caches` if you only want to clean up compilation caches. + For backward compatibility or you really need to switch/reinitialize the + default backend, use {func}`jax.extend.backend.clear_backends`. + * The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are + deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the + `spmd_axis_name` argument for expressing SPMD device-parallel computations. + * The `jax.experimental.host_callback` module is deprecated. + Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). + Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the + new callbacks. See {jax-issue}`#20385` for a discussion. * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` that cannot be converted to a JAX array now results in an exception. + * The deprecated flag `jax_parallel_functions_output_gda` has been removed. + This flag was long deprecated and did nothing; its use was a no-op. + * The previously-deprecated imports `jax.interpreters.ad.config` and + `jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config` + and `jax.extend.source_info_util` instead. + * JAX export does not support older serialization versions anymore. Version 9 + has been supported since October 27th, 2023 and has become the default + since February 1, 2024. + See [a description of the versions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). + This change could break clients that set a specific + JAX serialization version lower than 9. 
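
For illustration only (editorial, not part of the patch): a short sketch of the jax 0.4.26 migrations listed above, assuming JAX 0.4.26 or newer; the pytree contents are placeholders.

```python
import jax
import jax.numpy as jnp

leaves = {"w": jnp.ones(3), "b": jnp.zeros(3)}

# jax.tree_map is deprecated; jax.tree.map is the forward-looking spelling
# (jax.tree_util.tree_map remains available for compatibility with older JAX).
doubled = jax.tree.map(lambda v: 2 * v, leaves)

# Docs now construct PRNG keys with jax.random.key (typed key arrays)
# rather than jax.random.PRNGKey.
key = jax.random.key(0)
sample = jax.random.normal(key, (2, 2))

# jax.clear_backends is deprecated; to drop compilation caches only, use:
jax.clear_caches()
```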
+ +## jaxlib 0.4.26 (April 3, 2024) -## jaxlib 0.4.26 +* Changes + * JAX now supports CUDA 12.1 or newer only. Support for CUDA 11.8 has been + dropped. + * JAX now supports NumPy 2.0. ## jax 0.4.25 (Feb 26, 2024) @@ -63,7 +338,7 @@ Remember to align the itemized text with the first line of an item within a list * Changes * JAX lowering to StableHLO does not depend on physical devices anymore. - If your primitive wraps custom_paritioning or JAX callbacks in the lowering + If your primitive wraps custom_partitioning or JAX callbacks in the lowering rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your primitive to `jax._src.dispatch.prim_requires_devices_during_lowering` set. This is needed because custom_partitioning and JAX callbacks need physical @@ -85,8 +360,8 @@ Remember to align the itemized text with the first line of an item within a list cannot interact, e.g., in arithmetic operations. Scopes are introduced by {func}`jax.experimental.jax2tf.convert`, {func}`jax.experimental.export.symbolic_shape`, {func}`jax.experimental.export.symbolic_args_specs`. - The scope of a symbolic expression `e` can be read with `e.scope` and passed in - to the above functions to direct them to construct sybolic expressions in + The scope of a symbolic expression `e` can be read with `e.scope` and passed + into the above functions to direct them to construct symbolic expressions in a given scope. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. * simplified and faster equality comparisons, where we consider two symbolic dimensions @@ -217,7 +492,7 @@ Remember to align the itemized text with the first line of an item within a list that cannot be converted to a JAX array is deprecated and now raises a {obj}`DeprecationWaning`. Currently the functions return False, in the future this will raise an exception. - * The `device()` method of JAX arrays deprecated. Depending on the context, it may + * The `device()` method of JAX arrays is deprecated. Depending on the context, it may be replaced with one of the following: - {meth}`jax.Array.devices` returns the set of all devices used by the array. - {attr}`jax.Array.sharding` gives the sharding configuration used by the array. @@ -267,7 +542,7 @@ Remember to align the itemized text with the first line of an item within a list * Bug fixes * Only process 0 in a multicontroller distributed JAX program will write persistent compilation cache entries. This fixes write contention if the - cache is placed on a network filesystem such as GCS. + cache is placed on a network file system such as GCS. * The version check for cusolver and cufft no longer considers the patch versions when determining if the installed version of these libraries is at least as new as the versions against which JAX was built. @@ -308,7 +583,7 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.17 (Oct 3, 2023) * New features - * Added new {func}`jax.numpy.bitwise_count` function, matching the API of the simlar + * Added new {func}`jax.numpy.bitwise_count` function, matching the API of the similar function recently added to NumPy. * Deprecations * Removed the deprecated module `jax.abstract_arrays` and all its contents. @@ -735,8 +1010,8 @@ Changes: custom pytree node. This includes: * `tree_flatten_with_path` that flattens a tree and return not only each leaf but also their key paths. 
- * `tree_map_with_paths` that can map a function that takes the key path as argument. - * `register_pytree_with_keys`` to register how the key path and leaves should looks + * `tree_map_with_path` that can map a function that takes the key path as an argument. + * `register_pytree_with_keys` to register how the key path and leaves should looks like in a custom pytree node. * `keystr` that pretty-prints a key path. @@ -799,7 +1074,7 @@ Changes: * Breaking Changes * the `initial` argument to reduction functions like :func:`jax.numpy.sum` is now required to be a scalar, consistent with the corresponding NumPy API. - The previous behavior of broadcating the output against non-scalar `initial` + The previous behavior of broadcasting the output against non-scalar `initial` values was an unintentional implementation detail ({jax-issue}`#14446`). ## jaxlib 0.4.4 (Feb 16, 2023) @@ -1011,7 +1286,7 @@ Changes: changes to how `isinstance` works for {class}`jax.numpy.ndarray` for jax-internal objects, as {class}`jax.numpy.ndarray` is now a simple alias of {class}`jax.Array`. * Breaking changes - * `jax._src` is no longer imported into the from the public `jax` namespace. + * `jax._src` is no longer imported into the public `jax` namespace. This may break users that were using JAX internals. * `jax.soft_pmap` has been deleted. Please use `pjit` or `xmap` instead. `jax.soft_pmap` is undocumented. If it were documented, a deprecation period @@ -1120,7 +1395,7 @@ Changes: * Added {func}`jax.random.ball`. * Added {func}`jax.default_device`. * Added a `python -m jax.collect_profile` script to manually capture program - traces as an alternative to the Tensorboard UI. + traces as an alternative to the TensorBoard UI. * Added a `jax.named_scope` context manager that adds profiler metadata to Python programs (similar to `jax.named_call`). * In scatter-update operations (i.e. :attr:`jax.numpy.ndarray.at`), unsafe implicit @@ -1230,7 +1505,7 @@ Changes: `format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and `_default_tolerance` ({jax-issue}`#10389`). These, along with previously-deprecated `JaxTestCase`, `JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release. - Most of these utilites can be replaced by calls to standard python & numpy testing utilities found + Most of these utilities can be replaced by calls to standard python & numpy testing utilities found in e.g. {mod}`unittest`, {mod}`absl.testing`, {mod}`numpy.testing`, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as {func}`jax.devices`. Many of the deprecated utilities will still exist in {mod}`jax._src.test_util`, but these are not @@ -1395,7 +1670,7 @@ Changes: special autodiff handling for hcb.id_tap and id_print. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS`` - environment variable, or the ```--flax_host_callback_ad_transforms``` flag. + environment variable, or the ```--jax_host_callback_ad_transforms``` flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs ({jax-issue}`#8678`). * Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the @@ -1434,7 +1709,7 @@ Changes: * Bug fixes: * Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with - `FILL_OR_DROP` semantics, as documented. 
This primarily afects the + `FILL_OR_DROP` semantics, as documented. This primarily affects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0. (#8634). * jax2tf will force the converted code to use XLA for the code fragments @@ -1876,7 +2151,7 @@ Changes: ## jaxlib 0.1.61 (February 12 2021) -## jaxlib 0.1.60 (Febuary 3 2021) +## jaxlib 0.1.60 (February 3 2021) * Bug fixes: * Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The @@ -1930,7 +2205,7 @@ Changes: * `host_callback.outfeed_receiver` has been removed (it is not necessary, and was deprecated a few months ago). * New features: - * New flag for debugging `inf`, analagous to that for `NaN` ({jax-issue}`#5224`). + * New flag for debugging `inf`, analogous to that for `NaN` ({jax-issue}`#5224`). ## jax 0.2.7 (Dec 4 2020) @@ -2169,7 +2444,7 @@ Changes: * Adds preliminary support for on-device heap profiling. * Implements `np.nextafter` for `bfloat16` types. * Complex128 support for FFTs on CPU and GPU. -* Bugfixes: +* Bug fixes: * Improved float64 `tanh` accuracy on GPU. * float64 scatters on GPU are much faster. * Complex matrix multiplication on CPU should be much faster. @@ -2286,7 +2561,7 @@ Changes: * Added several new rules for `jax.experimental.jet` {jax-issue}`#2537`. * Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided. * Fix some missing cases of broadcasting in `jax.numpy.einsum` {jax-issue}`#2512`. -* Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitray order {jax-issue}`#2597`. +* Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitrary order {jax-issue}`#2597`. * Add `batch_group_count` to `conv_general_dilated` {jax-issue}`#2635`. * Add docstring for `test_util.check_grads` {jax-issue}`#2656`. * Add `callback_transform` {jax-issue}`#2665`. diff --git a/README.md b/README.md index 522d20cea734..b19d7b9ff128 100644 --- a/README.md +++ b/README.md @@ -17,9 +17,8 @@ ## What is JAX? -JAX is [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla), -brought together for high-performance numerical computing, including -large-scale machine learning research. +JAX is a Python library for accelerator-oriented array computation and program transformation, +designed for high-performance numerical computing and large-scale machine learning. With its updated version of [Autograd](https://github.com/hips/autograd), JAX can automatically differentiate native @@ -84,7 +83,7 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra ## Quickstart: Colab in the Cloud Jump right in using a notebook in your browser, connected to a Google Cloud GPU. 
Here are some starter notebooks: -- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) +- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html) - [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) **JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU @@ -361,8 +360,11 @@ Some standouts: startup (or set the environment variable `JAX_ENABLE_X64=True`). On TPU, JAX uses 32-bit values by default for everything _except_ internal temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`. - Those ops have a `precision` parameter which can be used to simulate - true 32-bit, with a cost of possibly slower runtime. + Those ops have a `precision` parameter which can be used to approximate 32-bit operations + via three bfloat16 passes, with a cost of possibly slower runtime. + Non-matmul operations on TPU lower to implementations that often emphasize speed over + accuracy, so in practice computations on TPU will be less precise than similar + computations on other backends. 1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars and NumPy types aren't preserved, namely `np.add(1, np.array([2], np.float32)).dtype` is `float64` rather than `float32`. @@ -394,8 +396,8 @@ Some standouts: | Hardware | Instructions | |------------|-----------------------------------------------------------------------------------------------------------------| -| CPU | `pip install -U "jax[cpu]"` | -| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` | +| CPU | `pip install -U jax` | +| NVIDIA GPU | `pip install -U "jax[cuda12]"` | | Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` | | AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). | | Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). | @@ -412,7 +414,8 @@ community-supported conda build, and answers to some frequently-asked questions. Multiple Google research groups develop and share libraries for training neural networks in JAX. If you want a fully featured library for neural network training with examples and how-to guides, try -[Flax](https://github.com/google/flax). +[Flax](https://github.com/google/flax). Check out the new [NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) API for a +simplified development experience. Google X maintains the neural network library [Equinox](https://github.com/patrick-kidger/equinox). 
This is used as the diff --git a/WORKSPACE b/WORKSPACE index 51f3df35d6b0..e574bd9f9611 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,6 +2,41 @@ load("//third_party/xla:workspace.bzl", jax_xla_workspace = "repo") jax_xla_workspace() +# Initialize hermetic Python +load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") +python_init_rules() + +load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") +python_init_repositories( + requirements = { + "3.10": "//build:requirements_lock_3_10.txt", + "3.11": "//build:requirements_lock_3_11.txt", + "3.12": "//build:requirements_lock_3_12.txt", + "3.13": "//build:requirements_lock_3_13.txt", + }, + local_wheel_workspaces = ["//jaxlib:jax.bzl"], + local_wheel_dist_folder = "../dist", + default_python_version = "system", +) + +load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") +python_init_toolchains() + +load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip") +python_init_pip() + +load("@pypi//:requirements.bzl", "install_deps") +install_deps() + +# Optional, to facilitate testing against newest versions of Python +load("@xla//third_party/py:python_repo.bzl", "custom_python_interpreter") +custom_python_interpreter( + name = "python_dev", + urls = ["https://www.python.org/ftp/python/3.13.0/Python-{version}.tgz"], + strip_prefix = "Python-{version}", + version = "3.13.0a6", +) + load("@xla//:workspace4.bzl", "xla_workspace4") xla_workspace4() @@ -19,10 +54,3 @@ xla_workspace0() load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() - -load("//third_party/robin_map:workspace.bzl", robin_map = "repo") -robin_map() - -load("//third_party/nanobind:workspace.bzl", nanobind = "repo") -nanobind() - diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 30fb04ace8e3..c68dab85dc8e 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -33,9 +33,7 @@ import jax.numpy as jnp import numpy as np -from jax import config - -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() partial = functools.partial @@ -679,11 +677,31 @@ def host_local_array_to_global_array(state): multihost_utils.host_local_array_to_global_array( (input_data, input_data), global_mesh, (in_pspec, in_pspec)) + @google_benchmark.register -def device_put(state): - x = np.array(1, np.int32) +@google_benchmark.option.arg_names(['num_args']) +@google_benchmark.option.args([1]) +@google_benchmark.option.args([10]) +@google_benchmark.option.args([100]) +@google_benchmark.option.args([1000]) +def device_put_from_numpy_array(state): + x = [np.array(1, np.int32)] * state.range(0) while state: - _ = jax.device_put(x).block_until_ready() + _ = jax.block_until_ready(jax.device_put(x)) + + +@google_benchmark.register +@google_benchmark.option.arg_names(['num_args']) +@google_benchmark.option.args([1]) +@google_benchmark.option.args([10]) +@google_benchmark.option.args([100]) +@google_benchmark.option.args([1000]) +def device_put_from_jax_array(state): + x = [np.array(1, np.int32)] * state.range(0) + x = jax.block_until_ready(jax.device_put(x, device=jax.devices()[0])) + d = jax.devices()[1] + while state: + _ = jax.block_until_ready(jax.device_put(x, device=d)) @google_benchmark.register @@ -856,6 +874,21 @@ def bench_make_array_from_callback_fully_replicated_sharding(state): while state: jax.make_array_from_callback(shape, s, np_arr.__getitem__) + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) 
+def bench_make_array_from_callback_sharded(state): + global_mesh = create_mesh((4, 2), ('x', 'y'), state) + input_shape = (8, 2) + input_data = np.arange(math.prod(input_shape)).reshape(input_shape) + + def callback(index): + return input_data[index] + + s = jax.NamedSharding(global_mesh, jax.sharding.PartitionSpec('x', 'y')) + while state: + jax.make_array_from_callback((8, 2), s, callback) + @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def benchmark_lorentz63_cache_hits(state): @@ -888,5 +921,23 @@ def loss(x0): jax.make_jaxpr(lambda x: training_step(x, 100, unroll=True))(x) +@google_benchmark.register +def jit_add_chain(state): + SIZE = 100 + + @jax.jit + def g(x, y): + return lax.add(x, y) + + x = jax.random.normal(jax.random.PRNGKey(0), (2, 2)) + while state: + @jax.jit + def f(x): + for i in range(SIZE): + x = g(x, x) + return x + f(x).block_until_ready() + + if __name__ == "__main__": google_benchmark.main() diff --git a/benchmarks/mosaic/BUILD b/benchmarks/mosaic/BUILD new file mode 100644 index 000000000000..027da12ce6d3 --- /dev/null +++ b/benchmarks/mosaic/BUILD @@ -0,0 +1,56 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "jax_generate_backend_suites", + "jax_test", + "py_deps", +) + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//visibility:private"], +) + +jax_generate_backend_suites() + +DISABLED_BACKENDS = [ + "cpu", + "tpu", +] + +DISABLED_CONFIGS = [ + "gpu", + "gpu_a100", + "gpu_p100", + "gpu_p100_x32", + "gpu_x32", + "gpu_pjrt_c_api", +] + +jax_test( + name = "matmul_bench", + srcs = ["matmul_bench.py"], + disable_backends = DISABLED_BACKENDS, + disable_configs = DISABLED_CONFIGS, + tags = ["notap"], + deps = [ + "//third_party/py/google_benchmark", + "//third_party/py/jax:mosaic_gpu", + "//third_party/py/jax/experimental/mosaic/gpu/examples:matmul", + ] + py_deps("absl/testing") + py_deps("numpy"), +) diff --git a/benchmarks/mosaic/matmul_bench.py b/benchmarks/mosaic/matmul_bench.py new file mode 100644 index 000000000000..32c147916407 --- /dev/null +++ b/benchmarks/mosaic/matmul_bench.py @@ -0,0 +1,110 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Microbenchmarks for mosaic gpu matrix mutliplication.""" + +import functools +import sys + +from absl import app +import google_benchmark as benchmark +from jax._src import config +from jax.experimental.mosaic.gpu.examples import matmul +from jax._src import test_util as jtu +import jax.numpy as jnp + +config.update("jax_traceback_filtering", "off") +config.parse_flags_with_absl() + +def _params_name(params): + return ",".join(f"{k}={v}" for k, v in params.items()) + +def matmul_benchmark(*args): + def decorator(get_runtimes): + for test_case in args: + + @benchmark.register(name=f"{get_runtimes.__name__}_{_params_name(test_case)}") + @benchmark.option.unit(benchmark.kMillisecond) + @benchmark.option.use_manual_time() + @benchmark.option.iterations(1) + @functools.wraps(get_runtimes) + def wrapper(state, test_case=test_case): + m, n, k = test_case["m"], test_case["n"], test_case["k"] + runtime, ref_runtime = get_runtimes(**test_case) + state.counters["TFlops"] = ( + float(2 * k * m * n) / (runtime / 1e3) / 1e12 + ) + state.counters["jax_TFlops"] = ( + float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12 + ) + state.counters["speedup"] = ref_runtime / runtime + state.set_iteration_time(runtime / 1e3) + + return decorator + + +@matmul_benchmark( + dict(m=55 * 128, n=95 * 128, k=48 * 128, stages=4, tile_m=128), + dict(m=55 * 128, n=45 * 128, k=48 * 128, stages=4, tile_m=128), + dict(m=64, n=95 * 128, k=48 * 128, stages=4, tile_m=64), + dict(m=64, n=45 * 128, k=48 * 128, stages=4, tile_m=64), +) +def bf16_i8_matmul(m, k, n, stages, tile_m): + # RHS.element_size==1b so k_tile=128 + if stages * 128 > k: + raise ValueError(f"Too many stages {(stages, k)=}.") + + return matmul.verify( + m, + k, + n, + stages, + tile_m=tile_m, + rhs_transpose=False, + lhs_dtype=jnp.bfloat16, + rhs_dtype=jnp.int8, + ) + +@matmul_benchmark( + dict(m=1024, n=1024, k=1024, stages=4, tile_m=128, tile_n=256), + dict(m=1024, n=1024, k=1024, stages=4, tile_m=128, tile_n=128), + dict(m=1024, n=1024, k=1024, stages=4, tile_m=64, tile_n=128), +) +def f32_matmul(m, n, k, stages, tile_m, tile_n): + if stages * 32 > k: + raise ValueError(f"Too many stages {(stages, k)=}.") + + return matmul.verify( + m=m, + k=k, + n=n, + stages=stages, + tile_m=tile_m, + tile_n=tile_n, + rhs_transpose=True, + lhs_dtype=jnp.float32, + rhs_dtype=jnp.float32, + ) + + +def main(_): + device = jtu.device_under_test() + if device != "gpu": + raise ValueError(f"Mosaic only work with gpu (got {device})") + + benchmark.run_benchmarks() + + +if __name__ == "__main__": + sys.argv = benchmark.initialize(sys.argv) + app.run(main) diff --git a/benchmarks/shape_poly_benchmark.py b/benchmarks/shape_poly_benchmark.py index bd8dd42d1052..d26801d8dfe5 100644 --- a/benchmarks/shape_poly_benchmark.py +++ b/benchmarks/shape_poly_benchmark.py @@ -15,12 +15,12 @@ import google_benchmark as benchmark -from jax import config +import jax from jax import core from jax._src.numpy import lax_numpy -from jax.experimental import export +from jax import export -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() @benchmark.register diff --git a/build/BUILD.bazel b/build/BUILD.bazel new file mode 100644 index 000000000000..cf43fdab09a6 --- /dev/null +++ b/build/BUILD.bazel @@ -0,0 +1,67 @@ +# Copyright 2024 The Jax Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +licenses(["notice"]) + +load("@python//:defs.bzl", "compile_pip_requirements") +load("@python_version_repo//:py_version.bzl", "REQUIREMENTS") +load("//jaxlib:jax.bzl", "all_py_deps") + +compile_pip_requirements( + name = "requirements", + extra_args = [ + "--allow-unsafe", + "--build-isolation", + "--rebuild", + ], + requirements_in = "requirements.in", + requirements_txt = REQUIREMENTS, + generate_hashes = True, + data = ["test-requirements.txt"] +) + +compile_pip_requirements( + name = "requirements_nightly", + extra_args = [ + "--allow-unsafe", + "--build-isolation", + "--extra-index-url=https://pypi.anaconda.org/scientific-python-nightly-wheels/simple", + "--pre", + "--upgrade" + ], + requirements_in = "requirements.in", + requirements_txt = REQUIREMENTS, + generate_hashes = False, + data = ["test-requirements.txt"] +) + +compile_pip_requirements( + name = "requirements_dev", + extra_args = [ + "--allow-unsafe", + "--build-isolation", + "--upgrade", + "--rebuild", + ], + requirements_in = "requirements.in", + requirements_txt = REQUIREMENTS, + generate_hashes = False, + data = ["test-requirements.txt"] +) + +py_library( + name = "all_py_deps", + deps = all_py_deps(["zstandard"]), +) \ No newline at end of file diff --git a/build/build.py b/build/build.py index dfdf33ec8599..2f68222816ac 100755 --- a/build/build.py +++ b/build/build.py @@ -70,29 +70,10 @@ def get_python_version(python_bin_path): return major, minor def check_python_version(python_version): - if python_version < (3, 9): - print("ERROR: JAX requires Python 3.9 or newer, found ", python_version) + if python_version < (3, 10): + print("ERROR: JAX requires Python 3.10 or newer, found ", python_version) sys.exit(-1) -def check_package_is_installed(python_bin_path, python_version, package): - args = [python_bin_path] - if python_version >= (3, 11): - args.append("-P") # Don't include the current directory. 
- args += ["-c", f"import {package}"] - try: - shell(args) - except: - print(f"ERROR: jaxlib build requires package '{package}' to be installed.") - sys.exit(-1) - -def check_numpy_version(python_bin_path): - version = shell( - [python_bin_path, "-c", "import numpy as np; print(np.__version__)"]) - numpy_version = tuple(map(int, version.split(".")[:2])) - if numpy_version < (1, 22): - print("ERROR: JAX requires NumPy 1.22 or newer, found " + version + ".") - sys.exit(-1) - return version def get_githash(): try: @@ -105,45 +86,45 @@ def get_githash(): # Bazel -BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.1.2/" +BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/" BazelPackage = collections.namedtuple("BazelPackage", ["base_uri", "file", "sha256"]) bazel_packages = { ("Linux", "x86_64"): BazelPackage( base_uri=None, - file="bazel-6.1.2-linux-x86_64", + file="bazel-6.5.0-linux-x86_64", sha256= - "e89747d63443e225b140d7d37ded952dacea73aaed896bca01ccd745827c6289"), + "a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307"), ("Linux", "aarch64"): BazelPackage( base_uri=None, - file="bazel-6.1.2-linux-arm64", + file="bazel-6.5.0-linux-arm64", sha256= - "1c9b249e315601c3703c41668a1204a8fdf0eba7f0f2b7fc38253bad1d1969c7"), + "5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f"), ("Darwin", "x86_64"): BazelPackage( base_uri=None, - file="bazel-6.1.2-darwin-x86_64", + file="bazel-6.5.0-darwin-x86_64", sha256= - "22d4b605ce6a7aad92d4f387458cc68de9907a2efa08f9b8bda244c2b6010561"), + "bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29"), ("Darwin", "arm64"): BazelPackage( base_uri=None, - file="bazel-6.1.2-darwin-arm64", + file="bazel-6.5.0-darwin-arm64", sha256= - "30cdf85af055ca8fdab7de592b1bd64f940955e3f63ed5c503c4e93d0112bd9d"), + "c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb"), ("Windows", "AMD64"): BazelPackage( base_uri=None, - file="bazel-6.1.2-windows-x86_64.exe", + file="bazel-6.5.0-windows-x86_64.exe", sha256= - "47e7f65a3bfa882910f76e2107b4298b28ace33681bd0279e25a8f91551913c0"), + "6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6"), } def download_and_verify_bazel(): - """Downloads a bazel binary from Github, verifying its SHA256 hash.""" + """Downloads a bazel binary from GitHub, verifying its SHA256 hash.""" package = bazel_packages.get((platform.system(), platform.machine())) if package is None: return None @@ -201,7 +182,7 @@ def get_bazel_paths(bazel_path_flag): def get_bazel_path(bazel_path_flag): """Returns the path to a Bazel binary, downloading Bazel if not found. Also, - checks Bazel's version is at least newer than 5.1.1 + checks Bazel's version is at least newer than 6.5.0 A manual version check is needed only for really old bazel versions. Newer bazel releases perform their own version check against .bazelversion @@ -210,11 +191,11 @@ def get_bazel_path(bazel_path_flag): """ for path in filter(None, get_bazel_paths(bazel_path_flag)): version = get_bazel_version(path) - if version is not None and version >= (5, 1, 1): + if version is not None and version >= (6, 5, 0): return path, ".".join(map(str, version)) print("Cannot find or download a suitable version of bazel." 
- "Please install bazel >= 5.1.1.") + "Please install bazel >= 6.5.0.") sys.exit(-1) @@ -256,34 +237,31 @@ def get_clang_major_version(clang_path): -def write_bazelrc(*, python_bin_path, remote_build, +def write_bazelrc(*, remote_build, cuda_toolkit_path, cudnn_install_path, cuda_version, cudnn_version, rocm_toolkit_path, cpu, cuda_compute_capabilities, - rocm_amdgpu_targets, bazel_options, target_cpu_features, + rocm_amdgpu_targets, target_cpu_features, wheel_cpu, enable_mkl_dnn, use_clang, clang_path, clang_major_version, enable_cuda, enable_nccl, enable_rocm, - build_gpu_plugin): + build_gpu_plugin, python_version): tf_cuda_paths = [] with open("../.jax_configure.bazelrc", "w") as f: - if not remote_build and python_bin_path: + if not remote_build: f.write(textwrap.dedent("""\ build --strategy=Genrule=standalone - build --repo_env PYTHON_BIN_PATH="{python_bin_path}" - build --action_env=PYENV_ROOT - build --python_path="{python_bin_path}" - """).format(python_bin_path=python_bin_path)) + """)) if use_clang: f.write(f'build --action_env CLANG_COMPILER_PATH="{clang_path}"\n') f.write(f'build --repo_env CC="{clang_path}"\n') f.write(f'build --repo_env BAZEL_COMPILER="{clang_path}"\n') - bazel_options.append("--copt=-Wno-error=unused-command-line-argument\n") - if clang_major_version in (16, 17): + f.write('build --copt=-Wno-error=unused-command-line-argument\n') + if clang_major_version in (16, 17, 18): # Necessary due to XLA's old version of upb. See: # https://github.com/openxla/xla/blob/c4277a076e249f5b97c8e45c8cb9d1f554089d76/.bazelrc#L505 - bazel_options.append("--copt=-Wno-gnu-offsetof-extensions\n") + f.write("build --copt=-Wno-gnu-offsetof-extensions\n") if cuda_toolkit_path: tf_cuda_paths.append(cuda_toolkit_path) @@ -316,8 +294,6 @@ def write_bazelrc(*, python_bin_path, remote_build, if cpu is not None: f.write(f"build --cpu={cpu}\n") - for o in bazel_options: - f.write(f"build {o}\n") if target_cpu_features == "release": if wheel_cpu == "x86_64": f.write("build --config=avx_windows\n" if is_windows() @@ -342,9 +318,14 @@ def write_bazelrc(*, python_bin_path, remote_build, if not enable_nccl: f.write("build --config=nonccl\n") if build_gpu_plugin: - f.write("build --config=cuda_plugin\n") - - + if enable_cuda: + f.write("build --config=cuda_plugin\n") + elif enable_rocm: + f.write("build --config=rocm_plugin\n") + if python_version: + f.write( + "build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format( + python_version=python_version)) BANNER = r""" _ _ __ __ | | / \ \ \/ / @@ -404,8 +385,9 @@ def main(): "GitHub.") parser.add_argument( "--python_bin_path", - help="Path to Python binary to use. The default is the Python " - "interpreter used to run the build script.") + help="Path to Python binary whose version to match while building with " + "hermetic python. The default is the Python interpreter used to run the " + "build script. DEPRECATED: use --python_version instead.") parser.add_argument( "--target_cpu_features", choices=["release", "native", "default"], @@ -452,21 +434,21 @@ def main(): "plugin is still experimental and is not ready for use yet." ), ) - add_boolean_argument( - parser, - "build_cuda_kernel_plugin", - default=False, - help_str=( - "Are we building the cuda kernel plugin? jaxlib will not be built " - "when this flag is True." + parser.add_argument( + "--build_gpu_kernel_plugin", + choices=["cuda", "rocm"], + default="", + help=( + "Specify 'cuda' or 'rocm' to build the respective kernel plugin." 
+ " When this flag is set, jaxlib will not be built." ), ) add_boolean_argument( parser, - "build_cuda_pjrt_plugin", + "build_gpu_pjrt_plugin", default=False, help_str=( - "Are we building the cuda pjrt plugin? jaxlib will not be built " + "Are we building the cuda/rocm pjrt plugin? jaxlib will not be built " "when this flag is True." ), ) @@ -475,6 +457,11 @@ def main(): choices=["11", "12"], default="12", help="Which CUDA major version the gpu plugin is for.") + parser.add_argument( + "--gpu_plugin_rocm_version", + choices=["60"], + default="60", + help="Which ROCM major version the gpu plugin is for.") add_boolean_argument( parser, "enable_rocm", @@ -527,7 +514,8 @@ def main(): parser.add_argument( "--bazel_options", action="append", default=[], - help="Additional options to pass to bazel.") + help="Additional options to pass to the main Bazel command to be " + "executed, e.g. `run`.") parser.add_argument( "--output_path", default=os.path.join(cwd, "dist"), @@ -541,11 +529,29 @@ def main(): "--editable", action="store_true", help="Create an 'editable' jaxlib build instead of a wheel.") + parser.add_argument( + "--python_version", + default=None, + help="hermetic python version, e.g., 3.10") add_boolean_argument( parser, "configure_only", default=False, help_str="If true, writes a .bazelrc file but does not build jaxlib.") + add_boolean_argument( + parser, + "requirements_update", + default=False, + help_str="If true, writes a .bazelrc and updates requirements_lock.txt " + "for a corresponding version of Python but does not build " + "jaxlib.") + add_boolean_argument( + parser, + "requirements_nightly_update", + default=False, + help_str="Same as update_requirements, but will consider dev, nightly " + "and pre-release versions of packages.") + args = parser.parse_args() logging.basicConfig() @@ -582,17 +588,15 @@ def main(): print(f"Bazel binary path: {bazel_path}") print(f"Bazel version: {bazel_version}") - python_bin_path = get_python_bin_path(args.python_bin_path) - print(f"Python binary path: {python_bin_path}") - python_version = get_python_version(python_bin_path) - print("Python version: {}".format(".".join(map(str, python_version)))) - check_python_version(python_version) - - numpy_version = check_numpy_version(python_bin_path) - print(f"NumPy version: {numpy_version}") - check_package_is_installed(python_bin_path, python_version, "wheel") - check_package_is_installed(python_bin_path, python_version, "build") - check_package_is_installed(python_bin_path, python_version, "setuptools") + if args.python_version: + python_version = args.python_version + else: + python_bin_path = get_python_bin_path(args.python_bin_path) + print(f"Python binary path: {python_bin_path}") + python_version = get_python_version(python_bin_path) + print("Python version: {}".format(".".join(map(str, python_version)))) + check_python_version(python_version) + python_version = ".".join(map(str, python_version)) print("Use clang: {}".format("yes" if args.use_clang else "no")) clang_path = args.clang_path @@ -631,7 +635,6 @@ def main(): print(f"ROCm amdgpu targets: {args.rocm_amdgpu_targets}") write_bazelrc( - python_bin_path=python_bin_path, remote_build=args.remote_build, cuda_toolkit_path=cuda_toolkit_path, cudnn_install_path=cudnn_install_path, @@ -641,7 +644,6 @@ def main(): cpu=args.target_cpu, cuda_compute_capabilities=args.cuda_compute_capabilities, rocm_amdgpu_targets=args.rocm_amdgpu_targets, - bazel_options=args.bazel_options, target_cpu_features=args.target_cpu_features, wheel_cpu=wheel_cpu, 
enable_mkl_dnn=args.enable_mkl_dnn, @@ -652,48 +654,89 @@ def main(): enable_nccl=args.enable_nccl, enable_rocm=args.enable_rocm, build_gpu_plugin=args.build_gpu_plugin, + python_version=python_version, ) + if args.requirements_update: + update_command = ([bazel_path] + args.bazel_startup_options + + ["run", "--verbose_failures=true", "//build:requirements.update"]) + print(" ".join(update_command)) + shell(update_command) + return + + if args.requirements_nightly_update: + update_nightly_command = ([bazel_path] + args.bazel_startup_options + + ["run", "--verbose_failures=true", "//build:requirements_nightly.update"]) + print(" ".join(update_nightly_command)) + shell(update_nightly_command) + return + if args.configure_only: return print("\nBuilding XLA and installing it in the jaxlib source tree...") - if not args.build_cuda_kernel_plugin and not args.build_cuda_pjrt_plugin: - command = ([bazel_path] + args.bazel_startup_options + - ["run", "--verbose_failures=true"] + - ["//jaxlib/tools:build_wheel", "--", + command_base = ( + bazel_path, + *args.bazel_startup_options, + "run", + "--verbose_failures=true", + *args.bazel_options, + ) + + if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin: + build_cpu_wheel_command = [ + *command_base, + "//jaxlib/tools:build_wheel", "--", f"--output_path={output_path}", f"--jaxlib_git_hash={get_githash()}", - f"--cpu={wheel_cpu}"]) + f"--cpu={wheel_cpu}" + ] if args.build_gpu_plugin: - command.append("--include_gpu_plugin_extension") + build_cpu_wheel_command.append("--skip_gpu_kernels") if args.editable: - command += ["--editable"] - print(" ".join(command)) - shell(command) - - if args.build_gpu_plugin or args.build_cuda_kernel_plugin: - build_cuda_kernels_command = ([bazel_path] + args.bazel_startup_options + - ["run", "--verbose_failures=true"] + - ["//jaxlib/tools:build_cuda_kernels_wheel", "--", + build_cpu_wheel_command.append("--editable") + print(" ".join(build_cpu_wheel_command)) + shell(build_cpu_wheel_command) + + if args.build_gpu_plugin or (args.build_gpu_kernel_plugin == "cuda") or \ + (args.build_gpu_kernel_plugin == "rocm"): + build_gpu_kernels_command = [ + *command_base, + "//jaxlib/tools:build_gpu_kernels_wheel", "--", f"--output_path={output_path}", f"--jaxlib_git_hash={get_githash()}", f"--cpu={wheel_cpu}", - f"--cuda_version={args.gpu_plugin_cuda_version}"]) + ] + if args.enable_cuda: + build_gpu_kernels_command.append(f"--enable-cuda={args.enable_cuda}") + build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_cuda_version}") + elif args.enable_rocm: + build_gpu_kernels_command.append(f"--enable-rocm={args.enable_rocm}") + build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_rocm_version}") + else: + raise ValueError("Unsupported GPU plugin backend. 
Choose either 'cuda' or 'rocm'.") if args.editable: - build_cuda_kernels_command.append("--editable") - print(" ".join(build_cuda_kernels_command)) - shell(build_cuda_kernels_command) - - if args.build_gpu_plugin or args.build_cuda_pjrt_plugin: - build_pjrt_plugin_command = ([bazel_path] + args.bazel_startup_options + - ["run", "--verbose_failures=true"] + - ["//jaxlib/tools:build_gpu_plugin_wheel", "--", + build_gpu_kernels_command.append("--editable") + print(" ".join(build_gpu_kernels_command)) + shell(build_gpu_kernels_command) + + if args.build_gpu_plugin or args.build_gpu_pjrt_plugin: + build_pjrt_plugin_command = [ + *command_base, + "//jaxlib/tools:build_gpu_plugin_wheel", "--", f"--output_path={output_path}", f"--jaxlib_git_hash={get_githash()}", f"--cpu={wheel_cpu}", - f"--cuda_version={args.gpu_plugin_cuda_version}"]) + ] + if args.enable_cuda: + build_pjrt_plugin_command.append(f"--enable-cuda={args.enable_cuda}") + build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_cuda_version}") + elif args.enable_rocm: + build_pjrt_plugin_command.append(f"--enable-rocm={args.enable_rocm}") + build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_rocm_version}") + else: + raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.") if args.editable: build_pjrt_plugin_command.append("--editable") print(" ".join(build_pjrt_plugin_command)) diff --git a/build/requirements.in b/build/requirements.in new file mode 100644 index 000000000000..add6b8577350 --- /dev/null +++ b/build/requirements.in @@ -0,0 +1,24 @@ +# +# test deps +# +-r test-requirements.txt + +# matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement +# below. +matplotlib~=3.8.4; python_version<="3.10" +matplotlib; python_version>="3.11" + +# +# build deps +# +numpy~=2.0.0 + +# +# runtime deps +# +scipy~=1.13.1 + +ml_dtypes>=0.4.0 +opt_einsum +zstandard +etils[epath] diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt new file mode 100644 index 000000000000..adabb0dd2e70 --- /dev/null +++ b/build/requirements_lock_3_10.txt @@ -0,0 +1,624 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# bazel run //build:requirements.update +# +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff + # via -r build/test-requirements.txt +attrs==23.2.0 \ + --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ + --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 + # via hypothesis +build==1.2.1 \ + --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ + --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 + # via -r build/test-requirements.txt +cloudpickle==3.0.0 \ + --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ + --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 + # via -r build/test-requirements.txt +colorama==0.4.6 \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 + # via -r build/test-requirements.txt +contourpy==1.2.1 \ + --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ + --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 
\ + --hash=sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9 \ + --hash=sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4 \ + --hash=sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce \ + --hash=sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7 \ + --hash=sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f \ + --hash=sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922 \ + --hash=sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4 \ + --hash=sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e \ + --hash=sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b \ + --hash=sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619 \ + --hash=sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205 \ + --hash=sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480 \ + --hash=sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965 \ + --hash=sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c \ + --hash=sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd \ + --hash=sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5 \ + --hash=sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f \ + --hash=sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc \ + --hash=sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec \ + --hash=sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd \ + --hash=sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b \ + --hash=sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9 \ + --hash=sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe \ + --hash=sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce \ + --hash=sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609 \ + --hash=sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8 \ + --hash=sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0 \ + --hash=sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f \ + --hash=sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8 \ + --hash=sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b \ + --hash=sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364 \ + --hash=sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040 \ + --hash=sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f \ + --hash=sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083 \ + --hash=sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df \ + --hash=sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba \ + --hash=sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445 \ + --hash=sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da \ + --hash=sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3 \ + --hash=sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72 \ + --hash=sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02 \ + --hash=sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985 + # via matplotlib +cycler==0.12.1 \ + 
--hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ + --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c + # via matplotlib +etils[epath,epy]==1.7.0 \ + --hash=sha256:61af8f7c242171de15e22e5da02d527cb9e677d11f8bcafe18fcc3548eee3e60 \ + --hash=sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350 + # via -r build/requirements.in +exceptiongroup==1.2.1 \ + --hash=sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad \ + --hash=sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16 + # via + # hypothesis + # pytest +execnet==2.1.1 \ + --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ + --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 + # via pytest-xdist +filelock==3.14.0 \ + --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ + --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a + # via -r build/test-requirements.txt +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r build/test-requirements.txt +fonttools==4.51.0 \ + --hash=sha256:0118ef998a0699a96c7b28457f15546815015a2710a1b23a7bf6c1be60c01636 \ + --hash=sha256:0d145976194a5242fdd22df18a1b451481a88071feadf251221af110ca8f00ce \ + --hash=sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f \ + --hash=sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1 \ + --hash=sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc \ + --hash=sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f \ + --hash=sha256:1621ee57da887c17312acc4b0e7ac30d3a4fb0fec6174b2e3754a74c26bbed1e \ + --hash=sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716 \ + --hash=sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15 \ + --hash=sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77 \ + --hash=sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034 \ + --hash=sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba \ + --hash=sha256:54dcf21a2f2d06ded676e3c3f9f74b2bafded3a8ff12f0983160b13e9f2fb4a7 \ + --hash=sha256:56fc244f2585d6c00b9bcc59e6593e646cf095a96fe68d62cd4da53dd1287b55 \ + --hash=sha256:599bdb75e220241cedc6faebfafedd7670335d2e29620d207dd0378a4e9ccc5a \ + --hash=sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0 \ + --hash=sha256:60a3409c9112aec02d5fb546f557bca6efa773dcb32ac147c6baf5f742e6258b \ + --hash=sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671 \ + --hash=sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a \ + --hash=sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039 \ + --hash=sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74 \ + --hash=sha256:865a58b6e60b0938874af0968cd0553bcd88e0b2cb6e588727117bd099eef836 \ + --hash=sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2 \ + --hash=sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308 \ + --hash=sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2 \ + --hash=sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5 \ + --hash=sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1 \ + 
--hash=sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438 \ + --hash=sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74 \ + --hash=sha256:aefa011207ed36cd280babfaa8510b8176f1a77261833e895a9d96e57e44802f \ + --hash=sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097 \ + --hash=sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e \ + --hash=sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037 \ + --hash=sha256:c18b49adc721a7d0b8dfe7c3130c89b8704baf599fb396396d07d4aa69b824a1 \ + --hash=sha256:c5b8cab0c137ca229433570151b5c1fc6af212680b58b15abd797dcdd9dd5051 \ + --hash=sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b \ + --hash=sha256:cadf4e12a608ef1d13e039864f484c8a968840afa0258b0b843a0556497ea9ed \ + --hash=sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68 \ + --hash=sha256:de7c29bdbdd35811f14493ffd2534b88f0ce1b9065316433b22d63ca1cd21f14 \ + --hash=sha256:e9d9298be7a05bb4801f558522adbe2feea1b0b103d5294ebf24a92dd49b78e5 \ + --hash=sha256:ee1af4be1c5afe4c96ca23badd368d8dc75f611887fb0c0dac9f71ee5d6f110e \ + --hash=sha256:f7e89853d8bea103c8e3514b9f9dc86b5b4120afb4583b57eb10dfa5afbe0936 + # via matplotlib +fsspec==2024.5.0 \ + --hash=sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a \ + --hash=sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c + # via etils +hypothesis==6.102.4 \ + --hash=sha256:013df31b04a4daede13756f497e60e451963d86f426395a79f99c5d692919bbd \ + --hash=sha256:59b4d144346d5cffb482cc1bafbd21b13ff31608e8c4b3e4630339aee3e87763 + # via -r build/test-requirements.txt +importlib-resources==6.4.0 \ + --hash=sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c \ + --hash=sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145 + # via etils +iniconfig==2.0.0 \ + --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ + --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 + # via pytest +kiwisolver==1.4.5 \ + --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ + --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ + --hash=sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af \ + --hash=sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f \ + --hash=sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046 \ + --hash=sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3 \ + --hash=sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5 \ + --hash=sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71 \ + --hash=sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee \ + --hash=sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3 \ + --hash=sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9 \ + --hash=sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b \ + --hash=sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985 \ + --hash=sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea \ + --hash=sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16 \ + --hash=sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89 \ + --hash=sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c \ + 
--hash=sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9 \ + --hash=sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712 \ + --hash=sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342 \ + --hash=sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a \ + --hash=sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958 \ + --hash=sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d \ + --hash=sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a \ + --hash=sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130 \ + --hash=sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff \ + --hash=sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898 \ + --hash=sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b \ + --hash=sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f \ + --hash=sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265 \ + --hash=sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93 \ + --hash=sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929 \ + --hash=sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635 \ + --hash=sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709 \ + --hash=sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b \ + --hash=sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb \ + --hash=sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a \ + --hash=sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920 \ + --hash=sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e \ + --hash=sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544 \ + --hash=sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45 \ + --hash=sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390 \ + --hash=sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77 \ + --hash=sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355 \ + --hash=sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff \ + --hash=sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4 \ + --hash=sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7 \ + --hash=sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20 \ + --hash=sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c \ + --hash=sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162 \ + --hash=sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228 \ + --hash=sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437 \ + --hash=sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc \ + --hash=sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a \ + --hash=sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901 \ + --hash=sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4 \ + --hash=sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770 \ + --hash=sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525 \ + --hash=sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad \ + 
--hash=sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a \ + --hash=sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29 \ + --hash=sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90 \ + --hash=sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250 \ + --hash=sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d \ + --hash=sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3 \ + --hash=sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54 \ + --hash=sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f \ + --hash=sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1 \ + --hash=sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da \ + --hash=sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238 \ + --hash=sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa \ + --hash=sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523 \ + --hash=sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0 \ + --hash=sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205 \ + --hash=sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3 \ + --hash=sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4 \ + --hash=sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac \ + --hash=sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9 \ + --hash=sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb \ + --hash=sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced \ + --hash=sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd \ + --hash=sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0 \ + --hash=sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da \ + --hash=sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18 \ + --hash=sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9 \ + --hash=sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276 \ + --hash=sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333 \ + --hash=sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b \ + --hash=sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db \ + --hash=sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126 \ + --hash=sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9 \ + --hash=sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09 \ + --hash=sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0 \ + --hash=sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec \ + --hash=sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7 \ + --hash=sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff \ + --hash=sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9 \ + --hash=sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192 \ + --hash=sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8 \ + --hash=sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d \ + --hash=sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6 \ + 
--hash=sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797 \ + --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ + --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f + # via matplotlib +markdown-it-py==3.0.0 \ + --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ + --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb + # via rich +matplotlib==3.8.4 ; python_version <= "3.10" \ + --hash=sha256:1c13f041a7178f9780fb61cc3a2b10423d5e125480e4be51beaf62b172413b67 \ + --hash=sha256:232ce322bfd020a434caaffbd9a95333f7c2491e59cfc014041d95e38ab90d1c \ + --hash=sha256:493e9f6aa5819156b58fce42b296ea31969f2aab71c5b680b4ea7a3cb5c07d94 \ + --hash=sha256:50bac6e4d77e4262c4340d7a985c30912054745ec99756ce213bfbc3cb3808eb \ + --hash=sha256:606e3b90897554c989b1e38a258c626d46c873523de432b1462f295db13de6f9 \ + --hash=sha256:6209e5c9aaccc056e63b547a8152661324404dd92340a6e479b3a7f24b42a5d0 \ + --hash=sha256:6485ac1f2e84676cff22e693eaa4fbed50ef5dc37173ce1f023daef4687df616 \ + --hash=sha256:6addbd5b488aedb7f9bc19f91cd87ea476206f45d7116fcfe3d31416702a82fa \ + --hash=sha256:72f9322712e4562e792b2961971891b9fbbb0e525011e09ea0d1f416c4645661 \ + --hash=sha256:7a6769f58ce51791b4cb8b4d7642489df347697cd3e23d88266aaaee93b41d9a \ + --hash=sha256:8080d5081a86e690d7688ffa542532e87f224c38a6ed71f8fbed34dd1d9fedae \ + --hash=sha256:843cbde2f0946dadd8c5c11c6d91847abd18ec76859dc319362a0964493f0ba6 \ + --hash=sha256:8aac397d5e9ec158960e31c381c5ffc52ddd52bd9a47717e2a694038167dffea \ + --hash=sha256:8f65c9f002d281a6e904976007b2d46a1ee2bcea3a68a8c12dda24709ddc9106 \ + --hash=sha256:90df07db7b599fe7035d2f74ab7e438b656528c68ba6bb59b7dc46af39ee48ef \ + --hash=sha256:9bb0189011785ea794ee827b68777db3ca3f93f3e339ea4d920315a0e5a78d54 \ + --hash=sha256:a0e47eda4eb2614300fc7bb4657fced3e83d6334d03da2173b09e447418d499f \ + --hash=sha256:abc9d838f93583650c35eca41cfcec65b2e7cb50fd486da6f0c49b5e1ed23014 \ + --hash=sha256:ac24233e8f2939ac4fd2919eed1e9c0871eac8057666070e94cbf0b33dd9c338 \ + --hash=sha256:b12ba985837e4899b762b81f5b2845bd1a28f4fdd1a126d9ace64e9c4eb2fb25 \ + --hash=sha256:b7a2a253d3b36d90c8993b4620183b55665a429da8357a4f621e78cd48b2b30b \ + --hash=sha256:c7064120a59ce6f64103c9cefba8ffe6fba87f2c61d67c401186423c9a20fd35 \ + --hash=sha256:c89ee9314ef48c72fe92ce55c4e95f2f39d70208f9f1d9db4e64079420d8d732 \ + --hash=sha256:cc4ccdc64e3039fc303defd119658148f2349239871db72cd74e2eeaa9b80b71 \ + --hash=sha256:ce1edd9f5383b504dbc26eeea404ed0a00656c526638129028b758fd43fc5f10 \ + --hash=sha256:ecd79298550cba13a43c340581a3ec9c707bd895a6a061a78fa2524660482fc0 \ + --hash=sha256:f51c4c869d4b60d769f7b4406eec39596648d9d70246428745a681c327a8ad30 \ + --hash=sha256:fb44f53af0a62dc80bba4443d9b27f2fde6acfdac281d95bc872dc148a6509cc + # via -r build/requirements.in +mdurl==0.1.2 \ + --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ + --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba + # via markdown-it-py +ml-dtypes==0.4.0 \ + --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ + --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ + --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ + --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ + --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ + 
--hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ + --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ + --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ + --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ + --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ + --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ + --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ + --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ + --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ + --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ + --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ + --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 + # via -r build/requirements.in +mpmath==1.4.0a1 \ + --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ + --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 + # via -r build/test-requirements.txt +numpy==2.0.0 \ + --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ + --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ + --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ + --hash=sha256:0ec84b9ba0654f3b962802edc91424331f423dcf5d5f926676e0150789cb3d95 \ + --hash=sha256:17067d097ed036636fa79f6a869ac26df7db1ba22039d962422506640314933a \ + --hash=sha256:1cde1753efe513705a0c6d28f5884e22bdc30438bf0085c5c486cdaff40cd67a \ + --hash=sha256:1e72728e7501a450288fc8e1f9ebc73d90cfd4671ebbd631f3e7857c39bd16f2 \ + --hash=sha256:2635dbd200c2d6faf2ef9a0d04f0ecc6b13b3cad54f7c67c61155138835515d2 \ + --hash=sha256:2ce46fd0b8a0c947ae047d222f7136fc4d55538741373107574271bc00e20e8f \ + --hash=sha256:34f003cb88b1ba38cb9a9a4a3161c1604973d7f9d5552c38bc2f04f829536609 \ + --hash=sha256:354f373279768fa5a584bac997de6a6c9bc535c482592d7a813bb0c09be6c76f \ + --hash=sha256:38ecb5b0582cd125f67a629072fed6f83562d9dd04d7e03256c9829bdec027ad \ + --hash=sha256:3e8e01233d57639b2e30966c63d36fcea099d17c53bf424d77f088b0f4babd86 \ + --hash=sha256:3f6bed7f840d44c08ebdb73b1825282b801799e325bcbdfa6bc5c370e5aecc65 \ + --hash=sha256:4554eb96f0fd263041baf16cf0881b3f5dafae7a59b1049acb9540c4d57bc8cb \ + --hash=sha256:46e161722e0f619749d1cd892167039015b2c2817296104487cd03ed4a955995 \ + --hash=sha256:49d9f7d256fbc804391a7f72d4a617302b1afac1112fac19b6c6cec63fe7fe8a \ + --hash=sha256:4d2f62e55a4cd9c58c1d9a1c9edaedcd857a73cb6fda875bf79093f9d9086f85 \ + --hash=sha256:5f64641b42b2429f56ee08b4f427a4d2daf916ec59686061de751a55aafa22e4 \ + --hash=sha256:63b92c512d9dbcc37f9d81b123dec99fdb318ba38c8059afc78086fe73820275 \ + --hash=sha256:6d7696c615765091cc5093f76fd1fa069870304beaccfd58b5dcc69e55ef49c1 \ + --hash=sha256:79e843d186c8fb1b102bef3e2bc35ef81160ffef3194646a7fdd6a73c6b97196 \ + --hash=sha256:821eedb7165ead9eebdb569986968b541f9908979c2da8a4967ecac4439bae3d \ + --hash=sha256:84554fc53daa8f6abf8e8a66e076aff6ece62de68523d9f665f32d2fc50fd66e \ + --hash=sha256:8d83bb187fb647643bd56e1ae43f273c7f4dbcdf94550d7938cfc32566756514 \ + --hash=sha256:903703372d46bce88b6920a0cd86c3ad82dae2dbef157b5fc01b70ea1cfc430f \ + --hash=sha256:9416a5c2e92ace094e9f0082c5fd473502c91651fb896bc17690d6fc475128d6 \ + 
--hash=sha256:9a1712c015831da583b21c5bfe15e8684137097969c6d22e8316ba66b5baabe4 \ + --hash=sha256:9c27f0946a3536403efb0e1c28def1ae6730a72cd0d5878db38824855e3afc44 \ + --hash=sha256:a356364941fb0593bb899a1076b92dfa2029f6f5b8ba88a14fd0984aaf76d0df \ + --hash=sha256:a7039a136017eaa92c1848152827e1424701532ca8e8967fe480fe1569dae581 \ + --hash=sha256:acd3a644e4807e73b4e1867b769fbf1ce8c5d80e7caaef0d90dcdc640dfc9787 \ + --hash=sha256:ad0c86f3455fbd0de6c31a3056eb822fc939f81b1618f10ff3406971893b62a5 \ + --hash=sha256:b4c76e3d4c56f145d41b7b6751255feefae92edbc9a61e1758a98204200f30fc \ + --hash=sha256:b6f6a8f45d0313db07d6d1d37bd0b112f887e1369758a5419c0370ba915b3871 \ + --hash=sha256:c5a59996dc61835133b56a32ebe4ef3740ea5bc19b3983ac60cc32be5a665d54 \ + --hash=sha256:c73aafd1afca80afecb22718f8700b40ac7cab927b8abab3c3e337d70e10e5a2 \ + --hash=sha256:cee6cc0584f71adefe2c908856ccc98702baf95ff80092e4ca46061538a2ba98 \ + --hash=sha256:cef04d068f5fb0518a77857953193b6bb94809a806bd0a14983a8f12ada060c9 \ + --hash=sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864 \ + --hash=sha256:e61155fae27570692ad1d327e81c6cf27d535a5d7ef97648a17d922224b216de \ + --hash=sha256:e7f387600d424f91576af20518334df3d97bc76a300a755f9a8d6e4f5cadd289 \ + --hash=sha256:ed08d2703b5972ec736451b818c2eb9da80d66c3e84aed1deeb0c345fefe461b \ + --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ + --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 + # via + # -r build/requirements.in + # -r build/test-requirements.txt + # contourpy + # matplotlib + # ml-dtypes + # opt-einsum + # scipy +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 + # via -r build/requirements.in +packaging==24.0 \ + --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ + --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 + # via + # build + # matplotlib + # pytest +pillow==10.3.0 \ + --hash=sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c \ + --hash=sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2 \ + --hash=sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb \ + --hash=sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d \ + --hash=sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa \ + --hash=sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3 \ + --hash=sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1 \ + --hash=sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a \ + --hash=sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd \ + --hash=sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8 \ + --hash=sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999 \ + --hash=sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599 \ + --hash=sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936 \ + --hash=sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375 \ + --hash=sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d \ + --hash=sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b \ + --hash=sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60 \ + 
--hash=sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572 \ + --hash=sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3 \ + --hash=sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced \ + --hash=sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f \ + --hash=sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b \ + --hash=sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19 \ + --hash=sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f \ + --hash=sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d \ + --hash=sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383 \ + --hash=sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795 \ + --hash=sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355 \ + --hash=sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57 \ + --hash=sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09 \ + --hash=sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b \ + --hash=sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462 \ + --hash=sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf \ + --hash=sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f \ + --hash=sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a \ + --hash=sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad \ + --hash=sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9 \ + --hash=sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d \ + --hash=sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45 \ + --hash=sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994 \ + --hash=sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d \ + --hash=sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338 \ + --hash=sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463 \ + --hash=sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451 \ + --hash=sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591 \ + --hash=sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c \ + --hash=sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd \ + --hash=sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32 \ + --hash=sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9 \ + --hash=sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf \ + --hash=sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5 \ + --hash=sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828 \ + --hash=sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3 \ + --hash=sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5 \ + --hash=sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2 \ + --hash=sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b \ + --hash=sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2 \ + --hash=sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475 \ + --hash=sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3 \ + 
--hash=sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb \ + --hash=sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef \ + --hash=sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015 \ + --hash=sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002 \ + --hash=sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170 \ + --hash=sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84 \ + --hash=sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57 \ + --hash=sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f \ + --hash=sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27 \ + --hash=sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a + # via + # -r build/test-requirements.txt + # matplotlib +pluggy==1.5.0 \ + --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ + --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 + # via pytest +portpicker==1.6.0 \ + --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ + --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa + # via -r build/test-requirements.txt +psutil==5.9.8 \ + --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ + --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ + --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ + --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ + --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ + --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ + --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ + --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ + --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ + --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ + --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ + --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ + --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ + --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ + --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ + --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 + # via portpicker +pygments==2.18.0 \ + --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ + --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a + # via rich +pyparsing==3.1.2 \ + --hash=sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad \ + --hash=sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742 + # via matplotlib +pyproject-hooks==1.1.0 \ + --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ + --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 + # via build +pytest==8.2.0 \ + --hash=sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233 \ + --hash=sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f + # via pytest-xdist +pytest-xdist==3.6.1 \ + 
--hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ + --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d + # via -r build/test-requirements.txt +python-dateutil==2.9.0.post0 \ + --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ + --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 + # via matplotlib +rich==13.7.1 \ + --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ + --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 + # via -r build/test-requirements.txt +scipy==1.13.1 \ + --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ + --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ + --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ + --hash=sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9 \ + --hash=sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54 \ + --hash=sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16 \ + --hash=sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2 \ + --hash=sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5 \ + --hash=sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59 \ + --hash=sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326 \ + --hash=sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b \ + --hash=sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1 \ + --hash=sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d \ + --hash=sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24 \ + --hash=sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627 \ + --hash=sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c \ + --hash=sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa \ + --hash=sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949 \ + --hash=sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989 \ + --hash=sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004 \ + --hash=sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f \ + --hash=sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884 \ + --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ + --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ + --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f + # via -r build/requirements.in +six==1.16.0 \ + --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ + --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 + # via python-dateutil +sortedcontainers==2.4.0 \ + --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ + --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 + # via hypothesis +tomli==2.0.1 \ + --hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \ + --hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f + # via + # build + # pytest +typing-extensions==4.12.0rc1 \ + --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ + 
--hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe + # via etils +wheel==0.43.0 \ + --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ + --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 + # via -r build/test-requirements.txt +zipp==3.18.2 \ + --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ + --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e + # via etils +zstandard==0.22.0 \ + --hash=sha256:11f0d1aab9516a497137b41e3d3ed4bbf7b2ee2abc79e5c8b010ad286d7464bd \ + --hash=sha256:1958100b8a1cc3f27fa21071a55cb2ed32e9e5df4c3c6e661c193437f171cba2 \ + --hash=sha256:1a90ba9a4c9c884bb876a14be2b1d216609385efb180393df40e5172e7ecf356 \ + --hash=sha256:1d43501f5f31e22baf822720d82b5547f8a08f5386a883b32584a185675c8fbf \ + --hash=sha256:23d2b3c2b8e7e5a6cb7922f7c27d73a9a615f0a5ab5d0e03dd533c477de23004 \ + --hash=sha256:2612e9bb4977381184bb2463150336d0f7e014d6bb5d4a370f9a372d21916f69 \ + --hash=sha256:275df437ab03f8c033b8a2c181e51716c32d831082d93ce48002a5227ec93019 \ + --hash=sha256:2ac9957bc6d2403c4772c890916bf181b2653640da98f32e04b96e4d6fb3252a \ + --hash=sha256:2b11ea433db22e720758cba584c9d661077121fcf60ab43351950ded20283440 \ + --hash=sha256:2fdd53b806786bd6112d97c1f1e7841e5e4daa06810ab4b284026a1a0e484c0b \ + --hash=sha256:33591d59f4956c9812f8063eff2e2c0065bc02050837f152574069f5f9f17775 \ + --hash=sha256:36a47636c3de227cd765e25a21dc5dace00539b82ddd99ee36abae38178eff9e \ + --hash=sha256:39b2853efc9403927f9065cc48c9980649462acbdf81cd4f0cb773af2fd734bc \ + --hash=sha256:3db41c5e49ef73641d5111554e1d1d3af106410a6c1fb52cf68912ba7a343a0d \ + --hash=sha256:445b47bc32de69d990ad0f34da0e20f535914623d1e506e74d6bc5c9dc40bb09 \ + --hash=sha256:466e6ad8caefb589ed281c076deb6f0cd330e8bc13c5035854ffb9c2014b118c \ + --hash=sha256:48f260e4c7294ef275744210a4010f116048e0c95857befb7462e033f09442fe \ + --hash=sha256:4ac59d5d6910b220141c1737b79d4a5aa9e57466e7469a012ed42ce2d3995e88 \ + --hash=sha256:53866a9d8ab363271c9e80c7c2e9441814961d47f88c9bc3b248142c32141d94 \ + --hash=sha256:589402548251056878d2e7c8859286eb91bd841af117dbe4ab000e6450987e08 \ + --hash=sha256:68953dc84b244b053c0d5f137a21ae8287ecf51b20872eccf8eaac0302d3e3b0 \ + --hash=sha256:6c25b8eb733d4e741246151d895dd0308137532737f337411160ff69ca24f93a \ + --hash=sha256:7034d381789f45576ec3f1fa0e15d741828146439228dc3f7c59856c5bcd3292 \ + --hash=sha256:73a1d6bd01961e9fd447162e137ed949c01bdb830dfca487c4a14e9742dccc93 \ + --hash=sha256:8226a33c542bcb54cd6bd0a366067b610b41713b64c9abec1bc4533d69f51e70 \ + --hash=sha256:888196c9c8893a1e8ff5e89b8f894e7f4f0e64a5af4d8f3c410f0319128bb2f8 \ + --hash=sha256:88c5b4b47a8a138338a07fc94e2ba3b1535f69247670abfe422de4e0b344aae2 \ + --hash=sha256:8a1b2effa96a5f019e72874969394edd393e2fbd6414a8208fea363a22803b45 \ + --hash=sha256:93e1856c8313bc688d5df069e106a4bc962eef3d13372020cc6e3ebf5e045202 \ + --hash=sha256:9501f36fac6b875c124243a379267d879262480bf85b1dbda61f5ad4d01b75a3 \ + --hash=sha256:959665072bd60f45c5b6b5d711f15bdefc9849dd5da9fb6c873e35f5d34d8cfb \ + --hash=sha256:a1d67d0d53d2a138f9e29d8acdabe11310c185e36f0a848efa104d4e40b808e4 \ + --hash=sha256:a493d470183ee620a3df1e6e55b3e4de8143c0ba1b16f3ded83208ea8ddfd91d \ + --hash=sha256:a7ccf5825fd71d4542c8ab28d4d482aace885f5ebe4b40faaa290eed8e095a4c \ + --hash=sha256:a88b7df61a292603e7cd662d92565d915796b094ffb3d206579aaebac6b85d5f \ + --hash=sha256:a97079b955b00b732c6f280d5023e0eefe359045e8b83b08cf0333af9ec78f26 \ + 
--hash=sha256:d22fdef58976457c65e2796e6730a3ea4a254f3ba83777ecfc8592ff8d77d303 \ + --hash=sha256:d75f693bb4e92c335e0645e8845e553cd09dc91616412d1d4650da835b5449df \ + --hash=sha256:d8593f8464fb64d58e8cb0b905b272d40184eac9a18d83cf8c10749c3eafcd7e \ + --hash=sha256:d8fff0f0c1d8bc5d866762ae95bd99d53282337af1be9dc0d88506b340e74b73 \ + --hash=sha256:de20a212ef3d00d609d0b22eb7cc798d5a69035e81839f549b538eff4105d01c \ + --hash=sha256:e9e9d4e2e336c529d4c435baad846a181e39a982f823f7e4495ec0b0ec8538d2 \ + --hash=sha256:f058a77ef0ece4e210bb0450e68408d4223f728b109764676e1a13537d056bb0 \ + --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ + --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ + --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e + # via -r build/requirements.in + +# The following packages are considered to be unsafe in a requirements file: +setuptools==69.5.1 \ + --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ + --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 + # via -r build/test-requirements.txt diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt new file mode 100644 index 000000000000..053e996cefad --- /dev/null +++ b/build/requirements_lock_3_11.txt @@ -0,0 +1,613 @@ +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# bazel run //build:requirements.update +# +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff + # via -r build/test-requirements.txt +attrs==23.2.0 \ + --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ + --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 + # via hypothesis +build==1.2.1 \ + --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ + --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 + # via -r build/test-requirements.txt +cloudpickle==3.0.0 \ + --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ + --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 + # via -r build/test-requirements.txt +colorama==0.4.6 \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 + # via -r build/test-requirements.txt +contourpy==1.2.1 \ + --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ + --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ + --hash=sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9 \ + --hash=sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4 \ + --hash=sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce \ + --hash=sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7 \ + --hash=sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f \ + --hash=sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922 \ + --hash=sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4 \ + --hash=sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e \ + --hash=sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b \ + 
--hash=sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619 \ + --hash=sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205 \ + --hash=sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480 \ + --hash=sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965 \ + --hash=sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c \ + --hash=sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd \ + --hash=sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5 \ + --hash=sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f \ + --hash=sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc \ + --hash=sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec \ + --hash=sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd \ + --hash=sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b \ + --hash=sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9 \ + --hash=sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe \ + --hash=sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce \ + --hash=sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609 \ + --hash=sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8 \ + --hash=sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0 \ + --hash=sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f \ + --hash=sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8 \ + --hash=sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b \ + --hash=sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364 \ + --hash=sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040 \ + --hash=sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f \ + --hash=sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083 \ + --hash=sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df \ + --hash=sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba \ + --hash=sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445 \ + --hash=sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da \ + --hash=sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3 \ + --hash=sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72 \ + --hash=sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02 \ + --hash=sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985 + # via matplotlib +cycler==0.12.1 \ + --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ + --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c + # via matplotlib +etils[epath,epy]==1.8.0 \ + --hash=sha256:f31d7f27a889457eaa44eab18ce836d24fd6d40dbbb167d38879b7296f6456ea \ + --hash=sha256:fb478f57fec202e260e54c9192b317692fd63db2d11d993e70bcdffa29cccd58 + # via -r build/requirements.in +execnet==2.1.1 \ + --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ + --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 + # via pytest-xdist +filelock==3.14.0 \ + --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ + 
--hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a + # via -r build/test-requirements.txt +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r build/test-requirements.txt +fonttools==4.51.0 \ + --hash=sha256:0118ef998a0699a96c7b28457f15546815015a2710a1b23a7bf6c1be60c01636 \ + --hash=sha256:0d145976194a5242fdd22df18a1b451481a88071feadf251221af110ca8f00ce \ + --hash=sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f \ + --hash=sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1 \ + --hash=sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc \ + --hash=sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f \ + --hash=sha256:1621ee57da887c17312acc4b0e7ac30d3a4fb0fec6174b2e3754a74c26bbed1e \ + --hash=sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716 \ + --hash=sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15 \ + --hash=sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77 \ + --hash=sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034 \ + --hash=sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba \ + --hash=sha256:54dcf21a2f2d06ded676e3c3f9f74b2bafded3a8ff12f0983160b13e9f2fb4a7 \ + --hash=sha256:56fc244f2585d6c00b9bcc59e6593e646cf095a96fe68d62cd4da53dd1287b55 \ + --hash=sha256:599bdb75e220241cedc6faebfafedd7670335d2e29620d207dd0378a4e9ccc5a \ + --hash=sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0 \ + --hash=sha256:60a3409c9112aec02d5fb546f557bca6efa773dcb32ac147c6baf5f742e6258b \ + --hash=sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671 \ + --hash=sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a \ + --hash=sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039 \ + --hash=sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74 \ + --hash=sha256:865a58b6e60b0938874af0968cd0553bcd88e0b2cb6e588727117bd099eef836 \ + --hash=sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2 \ + --hash=sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308 \ + --hash=sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2 \ + --hash=sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5 \ + --hash=sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1 \ + --hash=sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438 \ + --hash=sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74 \ + --hash=sha256:aefa011207ed36cd280babfaa8510b8176f1a77261833e895a9d96e57e44802f \ + --hash=sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097 \ + --hash=sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e \ + --hash=sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037 \ + --hash=sha256:c18b49adc721a7d0b8dfe7c3130c89b8704baf599fb396396d07d4aa69b824a1 \ + --hash=sha256:c5b8cab0c137ca229433570151b5c1fc6af212680b58b15abd797dcdd9dd5051 \ + --hash=sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b \ + --hash=sha256:cadf4e12a608ef1d13e039864f484c8a968840afa0258b0b843a0556497ea9ed \ + --hash=sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68 \ + 
--hash=sha256:de7c29bdbdd35811f14493ffd2534b88f0ce1b9065316433b22d63ca1cd21f14 \ + --hash=sha256:e9d9298be7a05bb4801f558522adbe2feea1b0b103d5294ebf24a92dd49b78e5 \ + --hash=sha256:ee1af4be1c5afe4c96ca23badd368d8dc75f611887fb0c0dac9f71ee5d6f110e \ + --hash=sha256:f7e89853d8bea103c8e3514b9f9dc86b5b4120afb4583b57eb10dfa5afbe0936 + # via matplotlib +fsspec==2024.5.0 \ + --hash=sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a \ + --hash=sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c + # via etils +hypothesis==6.102.4 \ + --hash=sha256:013df31b04a4daede13756f497e60e451963d86f426395a79f99c5d692919bbd \ + --hash=sha256:59b4d144346d5cffb482cc1bafbd21b13ff31608e8c4b3e4630339aee3e87763 + # via -r build/test-requirements.txt +importlib-resources==6.4.0 \ + --hash=sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c \ + --hash=sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145 + # via etils +iniconfig==2.0.0 \ + --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ + --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 + # via pytest +kiwisolver==1.4.5 \ + --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ + --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ + --hash=sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af \ + --hash=sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f \ + --hash=sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046 \ + --hash=sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3 \ + --hash=sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5 \ + --hash=sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71 \ + --hash=sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee \ + --hash=sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3 \ + --hash=sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9 \ + --hash=sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b \ + --hash=sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985 \ + --hash=sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea \ + --hash=sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16 \ + --hash=sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89 \ + --hash=sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c \ + --hash=sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9 \ + --hash=sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712 \ + --hash=sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342 \ + --hash=sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a \ + --hash=sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958 \ + --hash=sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d \ + --hash=sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a \ + --hash=sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130 \ + --hash=sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff \ + --hash=sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898 \ + --hash=sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b \ + 
--hash=sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f \ + --hash=sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265 \ + --hash=sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93 \ + --hash=sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929 \ + --hash=sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635 \ + --hash=sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709 \ + --hash=sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b \ + --hash=sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb \ + --hash=sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a \ + --hash=sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920 \ + --hash=sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e \ + --hash=sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544 \ + --hash=sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45 \ + --hash=sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390 \ + --hash=sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77 \ + --hash=sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355 \ + --hash=sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff \ + --hash=sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4 \ + --hash=sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7 \ + --hash=sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20 \ + --hash=sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c \ + --hash=sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162 \ + --hash=sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228 \ + --hash=sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437 \ + --hash=sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc \ + --hash=sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a \ + --hash=sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901 \ + --hash=sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4 \ + --hash=sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770 \ + --hash=sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525 \ + --hash=sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad \ + --hash=sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a \ + --hash=sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29 \ + --hash=sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90 \ + --hash=sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250 \ + --hash=sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d \ + --hash=sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3 \ + --hash=sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54 \ + --hash=sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f \ + --hash=sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1 \ + --hash=sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da \ + --hash=sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238 \ + 
--hash=sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa \ + --hash=sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523 \ + --hash=sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0 \ + --hash=sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205 \ + --hash=sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3 \ + --hash=sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4 \ + --hash=sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac \ + --hash=sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9 \ + --hash=sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb \ + --hash=sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced \ + --hash=sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd \ + --hash=sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0 \ + --hash=sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da \ + --hash=sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18 \ + --hash=sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9 \ + --hash=sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276 \ + --hash=sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333 \ + --hash=sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b \ + --hash=sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db \ + --hash=sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126 \ + --hash=sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9 \ + --hash=sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09 \ + --hash=sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0 \ + --hash=sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec \ + --hash=sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7 \ + --hash=sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff \ + --hash=sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9 \ + --hash=sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192 \ + --hash=sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8 \ + --hash=sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d \ + --hash=sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6 \ + --hash=sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797 \ + --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ + --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f + # via matplotlib +markdown-it-py==3.0.0 \ + --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ + --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb + # via rich +matplotlib==3.9.0 ; python_version >= "3.11" \ + --hash=sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38 \ + --hash=sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321 \ + --hash=sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db \ + --hash=sha256:0fc51eaa5262553868461c083d9adadb11a6017315f3a757fc45ec6ec5f02888 \ + --hash=sha256:13beb4840317d45ffd4183a778685e215939be7b08616f431c7795276e067463 \ + 
--hash=sha256:290d304e59be2b33ef5c2d768d0237f5bd132986bdcc66f80bc9bcc300066a03 \ + --hash=sha256:2bcee1dffaf60fe7656183ac2190bd630842ff87b3153afb3e384d966b57fe56 \ + --hash=sha256:2e7f03e5cbbfacdd48c8ea394d365d91ee8f3cae7e6ec611409927b5ed997ee4 \ + --hash=sha256:3f988bafb0fa39d1074ddd5bacd958c853e11def40800c5824556eb630f94d3b \ + --hash=sha256:52146fc3bd7813cc784562cb93a15788be0b2875c4655e2cc6ea646bfa30344b \ + --hash=sha256:550cdda3adbd596078cca7d13ed50b77879104e2e46392dcd7c75259d8f00e85 \ + --hash=sha256:616fabf4981a3b3c5a15cd95eba359c8489c4e20e03717aea42866d8d0465956 \ + --hash=sha256:76cce0f31b351e3551d1f3779420cf8f6ec0d4a8cf9c0237a3b549fd28eb4abb \ + --hash=sha256:7ff2e239c26be4f24bfa45860c20ffccd118d270c5b5d081fa4ea409b5469fcd \ + --hash=sha256:8146ce83cbc5dc71c223a74a1996d446cd35cfb6a04b683e1446b7e6c73603b7 \ + --hash=sha256:81c40af649d19c85f8073e25e5806926986806fa6d54be506fbf02aef47d5a89 \ + --hash=sha256:9a2fa6d899e17ddca6d6526cf6e7ba677738bf2a6a9590d702c277204a7c6152 \ + --hash=sha256:a5be985db2596d761cdf0c2eaf52396f26e6a64ab46bd8cd810c48972349d1be \ + --hash=sha256:af4001b7cae70f7eaacfb063db605280058246de590fa7874f00f62259f2df7e \ + --hash=sha256:bd4f2831168afac55b881db82a7730992aa41c4f007f1913465fb182d6fb20c0 \ + --hash=sha256:bdd1ecbe268eb3e7653e04f451635f0fb0f77f07fd070242b44c076c9106da84 \ + --hash=sha256:c53aeb514ccbbcbab55a27f912d79ea30ab21ee0531ee2c09f13800efb272674 \ + --hash=sha256:c79f3a585f1368da6049318bdf1f85568d8d04b2e89fc24b7e02cc9b62017382 \ + --hash=sha256:cd53c79fd02f1c1808d2cfc87dd3cf4dbc63c5244a58ee7944497107469c8d8a \ + --hash=sha256:d38e85a1a6d732f645f1403ce5e6727fd9418cd4574521d5803d3d94911038e5 \ + --hash=sha256:d91a4ffc587bacf5c4ce4ecfe4bcd23a4b675e76315f2866e588686cc97fccdf \ + --hash=sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a \ + --hash=sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d \ + --hash=sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241 + # via -r build/requirements.in +mdurl==0.1.2 \ + --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ + --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba + # via markdown-it-py +ml-dtypes==0.4.0 \ + --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ + --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ + --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ + --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ + --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ + --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ + --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ + --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ + --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ + --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ + --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ + --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ + --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ + --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ + --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ + 
--hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ + --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 + # via -r build/requirements.in +mpmath==1.4.0a1 \ + --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ + --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 + # via -r build/test-requirements.txt +numpy==2.0.0 \ + --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ + --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ + --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ + --hash=sha256:0ec84b9ba0654f3b962802edc91424331f423dcf5d5f926676e0150789cb3d95 \ + --hash=sha256:17067d097ed036636fa79f6a869ac26df7db1ba22039d962422506640314933a \ + --hash=sha256:1cde1753efe513705a0c6d28f5884e22bdc30438bf0085c5c486cdaff40cd67a \ + --hash=sha256:1e72728e7501a450288fc8e1f9ebc73d90cfd4671ebbd631f3e7857c39bd16f2 \ + --hash=sha256:2635dbd200c2d6faf2ef9a0d04f0ecc6b13b3cad54f7c67c61155138835515d2 \ + --hash=sha256:2ce46fd0b8a0c947ae047d222f7136fc4d55538741373107574271bc00e20e8f \ + --hash=sha256:34f003cb88b1ba38cb9a9a4a3161c1604973d7f9d5552c38bc2f04f829536609 \ + --hash=sha256:354f373279768fa5a584bac997de6a6c9bc535c482592d7a813bb0c09be6c76f \ + --hash=sha256:38ecb5b0582cd125f67a629072fed6f83562d9dd04d7e03256c9829bdec027ad \ + --hash=sha256:3e8e01233d57639b2e30966c63d36fcea099d17c53bf424d77f088b0f4babd86 \ + --hash=sha256:3f6bed7f840d44c08ebdb73b1825282b801799e325bcbdfa6bc5c370e5aecc65 \ + --hash=sha256:4554eb96f0fd263041baf16cf0881b3f5dafae7a59b1049acb9540c4d57bc8cb \ + --hash=sha256:46e161722e0f619749d1cd892167039015b2c2817296104487cd03ed4a955995 \ + --hash=sha256:49d9f7d256fbc804391a7f72d4a617302b1afac1112fac19b6c6cec63fe7fe8a \ + --hash=sha256:4d2f62e55a4cd9c58c1d9a1c9edaedcd857a73cb6fda875bf79093f9d9086f85 \ + --hash=sha256:5f64641b42b2429f56ee08b4f427a4d2daf916ec59686061de751a55aafa22e4 \ + --hash=sha256:63b92c512d9dbcc37f9d81b123dec99fdb318ba38c8059afc78086fe73820275 \ + --hash=sha256:6d7696c615765091cc5093f76fd1fa069870304beaccfd58b5dcc69e55ef49c1 \ + --hash=sha256:79e843d186c8fb1b102bef3e2bc35ef81160ffef3194646a7fdd6a73c6b97196 \ + --hash=sha256:821eedb7165ead9eebdb569986968b541f9908979c2da8a4967ecac4439bae3d \ + --hash=sha256:84554fc53daa8f6abf8e8a66e076aff6ece62de68523d9f665f32d2fc50fd66e \ + --hash=sha256:8d83bb187fb647643bd56e1ae43f273c7f4dbcdf94550d7938cfc32566756514 \ + --hash=sha256:903703372d46bce88b6920a0cd86c3ad82dae2dbef157b5fc01b70ea1cfc430f \ + --hash=sha256:9416a5c2e92ace094e9f0082c5fd473502c91651fb896bc17690d6fc475128d6 \ + --hash=sha256:9a1712c015831da583b21c5bfe15e8684137097969c6d22e8316ba66b5baabe4 \ + --hash=sha256:9c27f0946a3536403efb0e1c28def1ae6730a72cd0d5878db38824855e3afc44 \ + --hash=sha256:a356364941fb0593bb899a1076b92dfa2029f6f5b8ba88a14fd0984aaf76d0df \ + --hash=sha256:a7039a136017eaa92c1848152827e1424701532ca8e8967fe480fe1569dae581 \ + --hash=sha256:acd3a644e4807e73b4e1867b769fbf1ce8c5d80e7caaef0d90dcdc640dfc9787 \ + --hash=sha256:ad0c86f3455fbd0de6c31a3056eb822fc939f81b1618f10ff3406971893b62a5 \ + --hash=sha256:b4c76e3d4c56f145d41b7b6751255feefae92edbc9a61e1758a98204200f30fc \ + --hash=sha256:b6f6a8f45d0313db07d6d1d37bd0b112f887e1369758a5419c0370ba915b3871 \ + --hash=sha256:c5a59996dc61835133b56a32ebe4ef3740ea5bc19b3983ac60cc32be5a665d54 \ + --hash=sha256:c73aafd1afca80afecb22718f8700b40ac7cab927b8abab3c3e337d70e10e5a2 \ + 
--hash=sha256:cee6cc0584f71adefe2c908856ccc98702baf95ff80092e4ca46061538a2ba98 \ + --hash=sha256:cef04d068f5fb0518a77857953193b6bb94809a806bd0a14983a8f12ada060c9 \ + --hash=sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864 \ + --hash=sha256:e61155fae27570692ad1d327e81c6cf27d535a5d7ef97648a17d922224b216de \ + --hash=sha256:e7f387600d424f91576af20518334df3d97bc76a300a755f9a8d6e4f5cadd289 \ + --hash=sha256:ed08d2703b5972ec736451b818c2eb9da80d66c3e84aed1deeb0c345fefe461b \ + --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ + --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 + # via + # -r build/requirements.in + # -r build/test-requirements.txt + # contourpy + # matplotlib + # ml-dtypes + # opt-einsum + # scipy +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 + # via -r build/requirements.in +packaging==24.0 \ + --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ + --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 + # via + # build + # matplotlib + # pytest +pillow==10.3.0 \ + --hash=sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c \ + --hash=sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2 \ + --hash=sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb \ + --hash=sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d \ + --hash=sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa \ + --hash=sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3 \ + --hash=sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1 \ + --hash=sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a \ + --hash=sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd \ + --hash=sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8 \ + --hash=sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999 \ + --hash=sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599 \ + --hash=sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936 \ + --hash=sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375 \ + --hash=sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d \ + --hash=sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b \ + --hash=sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60 \ + --hash=sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572 \ + --hash=sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3 \ + --hash=sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced \ + --hash=sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f \ + --hash=sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b \ + --hash=sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19 \ + --hash=sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f \ + --hash=sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d \ + --hash=sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383 \ + --hash=sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795 \ + 
--hash=sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355 \ + --hash=sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57 \ + --hash=sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09 \ + --hash=sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b \ + --hash=sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462 \ + --hash=sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf \ + --hash=sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f \ + --hash=sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a \ + --hash=sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad \ + --hash=sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9 \ + --hash=sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d \ + --hash=sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45 \ + --hash=sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994 \ + --hash=sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d \ + --hash=sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338 \ + --hash=sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463 \ + --hash=sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451 \ + --hash=sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591 \ + --hash=sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c \ + --hash=sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd \ + --hash=sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32 \ + --hash=sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9 \ + --hash=sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf \ + --hash=sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5 \ + --hash=sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828 \ + --hash=sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3 \ + --hash=sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5 \ + --hash=sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2 \ + --hash=sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b \ + --hash=sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2 \ + --hash=sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475 \ + --hash=sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3 \ + --hash=sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb \ + --hash=sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef \ + --hash=sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015 \ + --hash=sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002 \ + --hash=sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170 \ + --hash=sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84 \ + --hash=sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57 \ + --hash=sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f \ + --hash=sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27 \ + --hash=sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a + # via + # -r build/test-requirements.txt + # matplotlib +pluggy==1.5.0 
\ + --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ + --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 + # via pytest +portpicker==1.6.0 \ + --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ + --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa + # via -r build/test-requirements.txt +psutil==5.9.8 \ + --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ + --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ + --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ + --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ + --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ + --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ + --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ + --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ + --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ + --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ + --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ + --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ + --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ + --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ + --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ + --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 + # via portpicker +pygments==2.18.0 \ + --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ + --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a + # via rich +pyparsing==3.1.2 \ + --hash=sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad \ + --hash=sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742 + # via matplotlib +pyproject-hooks==1.1.0 \ + --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ + --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 + # via build +pytest==8.2.0 \ + --hash=sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233 \ + --hash=sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f + # via pytest-xdist +pytest-xdist==3.6.1 \ + --hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ + --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d + # via -r build/test-requirements.txt +python-dateutil==2.9.0.post0 \ + --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ + --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 + # via matplotlib +rich==13.7.1 \ + --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ + --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 + # via -r build/test-requirements.txt +scipy==1.13.1 \ + --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ + --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ + --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ + 
--hash=sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9 \ + --hash=sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54 \ + --hash=sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16 \ + --hash=sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2 \ + --hash=sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5 \ + --hash=sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59 \ + --hash=sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326 \ + --hash=sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b \ + --hash=sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1 \ + --hash=sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d \ + --hash=sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24 \ + --hash=sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627 \ + --hash=sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c \ + --hash=sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa \ + --hash=sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949 \ + --hash=sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989 \ + --hash=sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004 \ + --hash=sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f \ + --hash=sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884 \ + --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ + --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ + --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f + # via -r build/requirements.in +six==1.16.0 \ + --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ + --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 + # via python-dateutil +sortedcontainers==2.4.0 \ + --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ + --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 + # via hypothesis +typing-extensions==4.12.0rc1 \ + --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ + --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe + # via etils +wheel==0.43.0 \ + --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ + --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 + # via -r build/test-requirements.txt +zipp==3.18.2 \ + --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ + --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e + # via etils +zstandard==0.22.0 \ + --hash=sha256:11f0d1aab9516a497137b41e3d3ed4bbf7b2ee2abc79e5c8b010ad286d7464bd \ + --hash=sha256:1958100b8a1cc3f27fa21071a55cb2ed32e9e5df4c3c6e661c193437f171cba2 \ + --hash=sha256:1a90ba9a4c9c884bb876a14be2b1d216609385efb180393df40e5172e7ecf356 \ + --hash=sha256:1d43501f5f31e22baf822720d82b5547f8a08f5386a883b32584a185675c8fbf \ + --hash=sha256:23d2b3c2b8e7e5a6cb7922f7c27d73a9a615f0a5ab5d0e03dd533c477de23004 \ + --hash=sha256:2612e9bb4977381184bb2463150336d0f7e014d6bb5d4a370f9a372d21916f69 \ + --hash=sha256:275df437ab03f8c033b8a2c181e51716c32d831082d93ce48002a5227ec93019 \ + 
--hash=sha256:2ac9957bc6d2403c4772c890916bf181b2653640da98f32e04b96e4d6fb3252a \ + --hash=sha256:2b11ea433db22e720758cba584c9d661077121fcf60ab43351950ded20283440 \ + --hash=sha256:2fdd53b806786bd6112d97c1f1e7841e5e4daa06810ab4b284026a1a0e484c0b \ + --hash=sha256:33591d59f4956c9812f8063eff2e2c0065bc02050837f152574069f5f9f17775 \ + --hash=sha256:36a47636c3de227cd765e25a21dc5dace00539b82ddd99ee36abae38178eff9e \ + --hash=sha256:39b2853efc9403927f9065cc48c9980649462acbdf81cd4f0cb773af2fd734bc \ + --hash=sha256:3db41c5e49ef73641d5111554e1d1d3af106410a6c1fb52cf68912ba7a343a0d \ + --hash=sha256:445b47bc32de69d990ad0f34da0e20f535914623d1e506e74d6bc5c9dc40bb09 \ + --hash=sha256:466e6ad8caefb589ed281c076deb6f0cd330e8bc13c5035854ffb9c2014b118c \ + --hash=sha256:48f260e4c7294ef275744210a4010f116048e0c95857befb7462e033f09442fe \ + --hash=sha256:4ac59d5d6910b220141c1737b79d4a5aa9e57466e7469a012ed42ce2d3995e88 \ + --hash=sha256:53866a9d8ab363271c9e80c7c2e9441814961d47f88c9bc3b248142c32141d94 \ + --hash=sha256:589402548251056878d2e7c8859286eb91bd841af117dbe4ab000e6450987e08 \ + --hash=sha256:68953dc84b244b053c0d5f137a21ae8287ecf51b20872eccf8eaac0302d3e3b0 \ + --hash=sha256:6c25b8eb733d4e741246151d895dd0308137532737f337411160ff69ca24f93a \ + --hash=sha256:7034d381789f45576ec3f1fa0e15d741828146439228dc3f7c59856c5bcd3292 \ + --hash=sha256:73a1d6bd01961e9fd447162e137ed949c01bdb830dfca487c4a14e9742dccc93 \ + --hash=sha256:8226a33c542bcb54cd6bd0a366067b610b41713b64c9abec1bc4533d69f51e70 \ + --hash=sha256:888196c9c8893a1e8ff5e89b8f894e7f4f0e64a5af4d8f3c410f0319128bb2f8 \ + --hash=sha256:88c5b4b47a8a138338a07fc94e2ba3b1535f69247670abfe422de4e0b344aae2 \ + --hash=sha256:8a1b2effa96a5f019e72874969394edd393e2fbd6414a8208fea363a22803b45 \ + --hash=sha256:93e1856c8313bc688d5df069e106a4bc962eef3d13372020cc6e3ebf5e045202 \ + --hash=sha256:9501f36fac6b875c124243a379267d879262480bf85b1dbda61f5ad4d01b75a3 \ + --hash=sha256:959665072bd60f45c5b6b5d711f15bdefc9849dd5da9fb6c873e35f5d34d8cfb \ + --hash=sha256:a1d67d0d53d2a138f9e29d8acdabe11310c185e36f0a848efa104d4e40b808e4 \ + --hash=sha256:a493d470183ee620a3df1e6e55b3e4de8143c0ba1b16f3ded83208ea8ddfd91d \ + --hash=sha256:a7ccf5825fd71d4542c8ab28d4d482aace885f5ebe4b40faaa290eed8e095a4c \ + --hash=sha256:a88b7df61a292603e7cd662d92565d915796b094ffb3d206579aaebac6b85d5f \ + --hash=sha256:a97079b955b00b732c6f280d5023e0eefe359045e8b83b08cf0333af9ec78f26 \ + --hash=sha256:d22fdef58976457c65e2796e6730a3ea4a254f3ba83777ecfc8592ff8d77d303 \ + --hash=sha256:d75f693bb4e92c335e0645e8845e553cd09dc91616412d1d4650da835b5449df \ + --hash=sha256:d8593f8464fb64d58e8cb0b905b272d40184eac9a18d83cf8c10749c3eafcd7e \ + --hash=sha256:d8fff0f0c1d8bc5d866762ae95bd99d53282337af1be9dc0d88506b340e74b73 \ + --hash=sha256:de20a212ef3d00d609d0b22eb7cc798d5a69035e81839f549b538eff4105d01c \ + --hash=sha256:e9e9d4e2e336c529d4c435baad846a181e39a982f823f7e4495ec0b0ec8538d2 \ + --hash=sha256:f058a77ef0ece4e210bb0450e68408d4223f728b109764676e1a13537d056bb0 \ + --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ + --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ + --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e + # via -r build/requirements.in + +# The following packages are considered to be unsafe in a requirements file: +setuptools==69.5.1 \ + --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ + --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 + # via -r 
build/test-requirements.txt diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt new file mode 100644 index 000000000000..1468e64c29cd --- /dev/null +++ b/build/requirements_lock_3_12.txt @@ -0,0 +1,613 @@ +# +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: +# +# bazel run //build:requirements.update +# +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff + # via -r build/test-requirements.txt +attrs==23.2.0 \ + --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ + --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 + # via hypothesis +build==1.2.1 \ + --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ + --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 + # via -r build/test-requirements.txt +cloudpickle==3.0.0 \ + --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ + --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 + # via -r build/test-requirements.txt +colorama==0.4.6 \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 + # via -r build/test-requirements.txt +contourpy==1.2.1 \ + --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ + --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ + --hash=sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9 \ + --hash=sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4 \ + --hash=sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce \ + --hash=sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7 \ + --hash=sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f \ + --hash=sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922 \ + --hash=sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4 \ + --hash=sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e \ + --hash=sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b \ + --hash=sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619 \ + --hash=sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205 \ + --hash=sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480 \ + --hash=sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965 \ + --hash=sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c \ + --hash=sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd \ + --hash=sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5 \ + --hash=sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f \ + --hash=sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc \ + --hash=sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec \ + --hash=sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd \ + --hash=sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b \ + --hash=sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9 \ + --hash=sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe 
\ + --hash=sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce \ + --hash=sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609 \ + --hash=sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8 \ + --hash=sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0 \ + --hash=sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f \ + --hash=sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8 \ + --hash=sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b \ + --hash=sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364 \ + --hash=sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040 \ + --hash=sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f \ + --hash=sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083 \ + --hash=sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df \ + --hash=sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba \ + --hash=sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445 \ + --hash=sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da \ + --hash=sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3 \ + --hash=sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72 \ + --hash=sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02 \ + --hash=sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985 + # via matplotlib +cycler==0.12.1 \ + --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ + --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c + # via matplotlib +etils[epath,epy]==1.8.0 \ + --hash=sha256:f31d7f27a889457eaa44eab18ce836d24fd6d40dbbb167d38879b7296f6456ea \ + --hash=sha256:fb478f57fec202e260e54c9192b317692fd63db2d11d993e70bcdffa29cccd58 + # via -r build/requirements.in +execnet==2.1.1 \ + --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ + --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 + # via pytest-xdist +filelock==3.14.0 \ + --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ + --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a + # via -r build/test-requirements.txt +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r build/test-requirements.txt +fonttools==4.51.0 \ + --hash=sha256:0118ef998a0699a96c7b28457f15546815015a2710a1b23a7bf6c1be60c01636 \ + --hash=sha256:0d145976194a5242fdd22df18a1b451481a88071feadf251221af110ca8f00ce \ + --hash=sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f \ + --hash=sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1 \ + --hash=sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc \ + --hash=sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f \ + --hash=sha256:1621ee57da887c17312acc4b0e7ac30d3a4fb0fec6174b2e3754a74c26bbed1e \ + --hash=sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716 \ + --hash=sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15 \ + --hash=sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77 \ + 
--hash=sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034 \ + --hash=sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba \ + --hash=sha256:54dcf21a2f2d06ded676e3c3f9f74b2bafded3a8ff12f0983160b13e9f2fb4a7 \ + --hash=sha256:56fc244f2585d6c00b9bcc59e6593e646cf095a96fe68d62cd4da53dd1287b55 \ + --hash=sha256:599bdb75e220241cedc6faebfafedd7670335d2e29620d207dd0378a4e9ccc5a \ + --hash=sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0 \ + --hash=sha256:60a3409c9112aec02d5fb546f557bca6efa773dcb32ac147c6baf5f742e6258b \ + --hash=sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671 \ + --hash=sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a \ + --hash=sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039 \ + --hash=sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74 \ + --hash=sha256:865a58b6e60b0938874af0968cd0553bcd88e0b2cb6e588727117bd099eef836 \ + --hash=sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2 \ + --hash=sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308 \ + --hash=sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2 \ + --hash=sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5 \ + --hash=sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1 \ + --hash=sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438 \ + --hash=sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74 \ + --hash=sha256:aefa011207ed36cd280babfaa8510b8176f1a77261833e895a9d96e57e44802f \ + --hash=sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097 \ + --hash=sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e \ + --hash=sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037 \ + --hash=sha256:c18b49adc721a7d0b8dfe7c3130c89b8704baf599fb396396d07d4aa69b824a1 \ + --hash=sha256:c5b8cab0c137ca229433570151b5c1fc6af212680b58b15abd797dcdd9dd5051 \ + --hash=sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b \ + --hash=sha256:cadf4e12a608ef1d13e039864f484c8a968840afa0258b0b843a0556497ea9ed \ + --hash=sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68 \ + --hash=sha256:de7c29bdbdd35811f14493ffd2534b88f0ce1b9065316433b22d63ca1cd21f14 \ + --hash=sha256:e9d9298be7a05bb4801f558522adbe2feea1b0b103d5294ebf24a92dd49b78e5 \ + --hash=sha256:ee1af4be1c5afe4c96ca23badd368d8dc75f611887fb0c0dac9f71ee5d6f110e \ + --hash=sha256:f7e89853d8bea103c8e3514b9f9dc86b5b4120afb4583b57eb10dfa5afbe0936 + # via matplotlib +fsspec==2024.5.0 \ + --hash=sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a \ + --hash=sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c + # via etils +hypothesis==6.102.4 \ + --hash=sha256:013df31b04a4daede13756f497e60e451963d86f426395a79f99c5d692919bbd \ + --hash=sha256:59b4d144346d5cffb482cc1bafbd21b13ff31608e8c4b3e4630339aee3e87763 + # via -r build/test-requirements.txt +importlib-resources==6.4.0 \ + --hash=sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c \ + --hash=sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145 + # via etils +iniconfig==2.0.0 \ + --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ + --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 + # via pytest +kiwisolver==1.4.5 \ + 
--hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ + --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ + --hash=sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af \ + --hash=sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f \ + --hash=sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046 \ + --hash=sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3 \ + --hash=sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5 \ + --hash=sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71 \ + --hash=sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee \ + --hash=sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3 \ + --hash=sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9 \ + --hash=sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b \ + --hash=sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985 \ + --hash=sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea \ + --hash=sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16 \ + --hash=sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89 \ + --hash=sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c \ + --hash=sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9 \ + --hash=sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712 \ + --hash=sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342 \ + --hash=sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a \ + --hash=sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958 \ + --hash=sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d \ + --hash=sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a \ + --hash=sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130 \ + --hash=sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff \ + --hash=sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898 \ + --hash=sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b \ + --hash=sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f \ + --hash=sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265 \ + --hash=sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93 \ + --hash=sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929 \ + --hash=sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635 \ + --hash=sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709 \ + --hash=sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b \ + --hash=sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb \ + --hash=sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a \ + --hash=sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920 \ + --hash=sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e \ + --hash=sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544 \ + --hash=sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45 \ + --hash=sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390 \ + 
--hash=sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77 \ + --hash=sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355 \ + --hash=sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff \ + --hash=sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4 \ + --hash=sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7 \ + --hash=sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20 \ + --hash=sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c \ + --hash=sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162 \ + --hash=sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228 \ + --hash=sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437 \ + --hash=sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc \ + --hash=sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a \ + --hash=sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901 \ + --hash=sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4 \ + --hash=sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770 \ + --hash=sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525 \ + --hash=sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad \ + --hash=sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a \ + --hash=sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29 \ + --hash=sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90 \ + --hash=sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250 \ + --hash=sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d \ + --hash=sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3 \ + --hash=sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54 \ + --hash=sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f \ + --hash=sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1 \ + --hash=sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da \ + --hash=sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238 \ + --hash=sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa \ + --hash=sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523 \ + --hash=sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0 \ + --hash=sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205 \ + --hash=sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3 \ + --hash=sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4 \ + --hash=sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac \ + --hash=sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9 \ + --hash=sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb \ + --hash=sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced \ + --hash=sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd \ + --hash=sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0 \ + --hash=sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da \ + --hash=sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18 \ + 
--hash=sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9 \ + --hash=sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276 \ + --hash=sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333 \ + --hash=sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b \ + --hash=sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db \ + --hash=sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126 \ + --hash=sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9 \ + --hash=sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09 \ + --hash=sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0 \ + --hash=sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec \ + --hash=sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7 \ + --hash=sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff \ + --hash=sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9 \ + --hash=sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192 \ + --hash=sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8 \ + --hash=sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d \ + --hash=sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6 \ + --hash=sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797 \ + --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ + --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f + # via matplotlib +markdown-it-py==3.0.0 \ + --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ + --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb + # via rich +matplotlib==3.9.0 ; python_version >= "3.11" \ + --hash=sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38 \ + --hash=sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321 \ + --hash=sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db \ + --hash=sha256:0fc51eaa5262553868461c083d9adadb11a6017315f3a757fc45ec6ec5f02888 \ + --hash=sha256:13beb4840317d45ffd4183a778685e215939be7b08616f431c7795276e067463 \ + --hash=sha256:290d304e59be2b33ef5c2d768d0237f5bd132986bdcc66f80bc9bcc300066a03 \ + --hash=sha256:2bcee1dffaf60fe7656183ac2190bd630842ff87b3153afb3e384d966b57fe56 \ + --hash=sha256:2e7f03e5cbbfacdd48c8ea394d365d91ee8f3cae7e6ec611409927b5ed997ee4 \ + --hash=sha256:3f988bafb0fa39d1074ddd5bacd958c853e11def40800c5824556eb630f94d3b \ + --hash=sha256:52146fc3bd7813cc784562cb93a15788be0b2875c4655e2cc6ea646bfa30344b \ + --hash=sha256:550cdda3adbd596078cca7d13ed50b77879104e2e46392dcd7c75259d8f00e85 \ + --hash=sha256:616fabf4981a3b3c5a15cd95eba359c8489c4e20e03717aea42866d8d0465956 \ + --hash=sha256:76cce0f31b351e3551d1f3779420cf8f6ec0d4a8cf9c0237a3b549fd28eb4abb \ + --hash=sha256:7ff2e239c26be4f24bfa45860c20ffccd118d270c5b5d081fa4ea409b5469fcd \ + --hash=sha256:8146ce83cbc5dc71c223a74a1996d446cd35cfb6a04b683e1446b7e6c73603b7 \ + --hash=sha256:81c40af649d19c85f8073e25e5806926986806fa6d54be506fbf02aef47d5a89 \ + --hash=sha256:9a2fa6d899e17ddca6d6526cf6e7ba677738bf2a6a9590d702c277204a7c6152 \ + --hash=sha256:a5be985db2596d761cdf0c2eaf52396f26e6a64ab46bd8cd810c48972349d1be \ + --hash=sha256:af4001b7cae70f7eaacfb063db605280058246de590fa7874f00f62259f2df7e \ + 
--hash=sha256:bd4f2831168afac55b881db82a7730992aa41c4f007f1913465fb182d6fb20c0 \ + --hash=sha256:bdd1ecbe268eb3e7653e04f451635f0fb0f77f07fd070242b44c076c9106da84 \ + --hash=sha256:c53aeb514ccbbcbab55a27f912d79ea30ab21ee0531ee2c09f13800efb272674 \ + --hash=sha256:c79f3a585f1368da6049318bdf1f85568d8d04b2e89fc24b7e02cc9b62017382 \ + --hash=sha256:cd53c79fd02f1c1808d2cfc87dd3cf4dbc63c5244a58ee7944497107469c8d8a \ + --hash=sha256:d38e85a1a6d732f645f1403ce5e6727fd9418cd4574521d5803d3d94911038e5 \ + --hash=sha256:d91a4ffc587bacf5c4ce4ecfe4bcd23a4b675e76315f2866e588686cc97fccdf \ + --hash=sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a \ + --hash=sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d \ + --hash=sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241 + # via -r build/requirements.in +mdurl==0.1.2 \ + --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ + --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba + # via markdown-it-py +ml-dtypes==0.4.0 \ + --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ + --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ + --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ + --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ + --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ + --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ + --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ + --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ + --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ + --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ + --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ + --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ + --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ + --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ + --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ + --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ + --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 + # via -r build/requirements.in +mpmath==1.4.0a1 \ + --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ + --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 + # via -r build/test-requirements.txt +numpy==2.0.0 \ + --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ + --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ + --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ + --hash=sha256:0ec84b9ba0654f3b962802edc91424331f423dcf5d5f926676e0150789cb3d95 \ + --hash=sha256:17067d097ed036636fa79f6a869ac26df7db1ba22039d962422506640314933a \ + --hash=sha256:1cde1753efe513705a0c6d28f5884e22bdc30438bf0085c5c486cdaff40cd67a \ + --hash=sha256:1e72728e7501a450288fc8e1f9ebc73d90cfd4671ebbd631f3e7857c39bd16f2 \ + --hash=sha256:2635dbd200c2d6faf2ef9a0d04f0ecc6b13b3cad54f7c67c61155138835515d2 \ + --hash=sha256:2ce46fd0b8a0c947ae047d222f7136fc4d55538741373107574271bc00e20e8f \ + 
--hash=sha256:34f003cb88b1ba38cb9a9a4a3161c1604973d7f9d5552c38bc2f04f829536609 \ + --hash=sha256:354f373279768fa5a584bac997de6a6c9bc535c482592d7a813bb0c09be6c76f \ + --hash=sha256:38ecb5b0582cd125f67a629072fed6f83562d9dd04d7e03256c9829bdec027ad \ + --hash=sha256:3e8e01233d57639b2e30966c63d36fcea099d17c53bf424d77f088b0f4babd86 \ + --hash=sha256:3f6bed7f840d44c08ebdb73b1825282b801799e325bcbdfa6bc5c370e5aecc65 \ + --hash=sha256:4554eb96f0fd263041baf16cf0881b3f5dafae7a59b1049acb9540c4d57bc8cb \ + --hash=sha256:46e161722e0f619749d1cd892167039015b2c2817296104487cd03ed4a955995 \ + --hash=sha256:49d9f7d256fbc804391a7f72d4a617302b1afac1112fac19b6c6cec63fe7fe8a \ + --hash=sha256:4d2f62e55a4cd9c58c1d9a1c9edaedcd857a73cb6fda875bf79093f9d9086f85 \ + --hash=sha256:5f64641b42b2429f56ee08b4f427a4d2daf916ec59686061de751a55aafa22e4 \ + --hash=sha256:63b92c512d9dbcc37f9d81b123dec99fdb318ba38c8059afc78086fe73820275 \ + --hash=sha256:6d7696c615765091cc5093f76fd1fa069870304beaccfd58b5dcc69e55ef49c1 \ + --hash=sha256:79e843d186c8fb1b102bef3e2bc35ef81160ffef3194646a7fdd6a73c6b97196 \ + --hash=sha256:821eedb7165ead9eebdb569986968b541f9908979c2da8a4967ecac4439bae3d \ + --hash=sha256:84554fc53daa8f6abf8e8a66e076aff6ece62de68523d9f665f32d2fc50fd66e \ + --hash=sha256:8d83bb187fb647643bd56e1ae43f273c7f4dbcdf94550d7938cfc32566756514 \ + --hash=sha256:903703372d46bce88b6920a0cd86c3ad82dae2dbef157b5fc01b70ea1cfc430f \ + --hash=sha256:9416a5c2e92ace094e9f0082c5fd473502c91651fb896bc17690d6fc475128d6 \ + --hash=sha256:9a1712c015831da583b21c5bfe15e8684137097969c6d22e8316ba66b5baabe4 \ + --hash=sha256:9c27f0946a3536403efb0e1c28def1ae6730a72cd0d5878db38824855e3afc44 \ + --hash=sha256:a356364941fb0593bb899a1076b92dfa2029f6f5b8ba88a14fd0984aaf76d0df \ + --hash=sha256:a7039a136017eaa92c1848152827e1424701532ca8e8967fe480fe1569dae581 \ + --hash=sha256:acd3a644e4807e73b4e1867b769fbf1ce8c5d80e7caaef0d90dcdc640dfc9787 \ + --hash=sha256:ad0c86f3455fbd0de6c31a3056eb822fc939f81b1618f10ff3406971893b62a5 \ + --hash=sha256:b4c76e3d4c56f145d41b7b6751255feefae92edbc9a61e1758a98204200f30fc \ + --hash=sha256:b6f6a8f45d0313db07d6d1d37bd0b112f887e1369758a5419c0370ba915b3871 \ + --hash=sha256:c5a59996dc61835133b56a32ebe4ef3740ea5bc19b3983ac60cc32be5a665d54 \ + --hash=sha256:c73aafd1afca80afecb22718f8700b40ac7cab927b8abab3c3e337d70e10e5a2 \ + --hash=sha256:cee6cc0584f71adefe2c908856ccc98702baf95ff80092e4ca46061538a2ba98 \ + --hash=sha256:cef04d068f5fb0518a77857953193b6bb94809a806bd0a14983a8f12ada060c9 \ + --hash=sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864 \ + --hash=sha256:e61155fae27570692ad1d327e81c6cf27d535a5d7ef97648a17d922224b216de \ + --hash=sha256:e7f387600d424f91576af20518334df3d97bc76a300a755f9a8d6e4f5cadd289 \ + --hash=sha256:ed08d2703b5972ec736451b818c2eb9da80d66c3e84aed1deeb0c345fefe461b \ + --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ + --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 + # via + # -r build/requirements.in + # -r build/test-requirements.txt + # contourpy + # matplotlib + # ml-dtypes + # opt-einsum + # scipy +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 + # via -r build/requirements.in +packaging==24.0 \ + --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ + --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 + # via + # build + # 
matplotlib + # pytest +pillow==10.3.0 \ + --hash=sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c \ + --hash=sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2 \ + --hash=sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb \ + --hash=sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d \ + --hash=sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa \ + --hash=sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3 \ + --hash=sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1 \ + --hash=sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a \ + --hash=sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd \ + --hash=sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8 \ + --hash=sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999 \ + --hash=sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599 \ + --hash=sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936 \ + --hash=sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375 \ + --hash=sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d \ + --hash=sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b \ + --hash=sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60 \ + --hash=sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572 \ + --hash=sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3 \ + --hash=sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced \ + --hash=sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f \ + --hash=sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b \ + --hash=sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19 \ + --hash=sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f \ + --hash=sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d \ + --hash=sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383 \ + --hash=sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795 \ + --hash=sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355 \ + --hash=sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57 \ + --hash=sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09 \ + --hash=sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b \ + --hash=sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462 \ + --hash=sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf \ + --hash=sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f \ + --hash=sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a \ + --hash=sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad \ + --hash=sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9 \ + --hash=sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d \ + --hash=sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45 \ + --hash=sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994 \ + --hash=sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d \ + --hash=sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338 \ + 
--hash=sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463 \ + --hash=sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451 \ + --hash=sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591 \ + --hash=sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c \ + --hash=sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd \ + --hash=sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32 \ + --hash=sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9 \ + --hash=sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf \ + --hash=sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5 \ + --hash=sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828 \ + --hash=sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3 \ + --hash=sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5 \ + --hash=sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2 \ + --hash=sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b \ + --hash=sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2 \ + --hash=sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475 \ + --hash=sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3 \ + --hash=sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb \ + --hash=sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef \ + --hash=sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015 \ + --hash=sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002 \ + --hash=sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170 \ + --hash=sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84 \ + --hash=sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57 \ + --hash=sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f \ + --hash=sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27 \ + --hash=sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a + # via + # -r build/test-requirements.txt + # matplotlib +pluggy==1.5.0 \ + --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ + --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 + # via pytest +portpicker==1.6.0 \ + --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ + --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa + # via -r build/test-requirements.txt +psutil==5.9.8 \ + --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ + --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ + --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ + --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ + --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ + --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ + --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ + --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ + --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ + 
--hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ + --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ + --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ + --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ + --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ + --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ + --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 + # via portpicker +pygments==2.18.0 \ + --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ + --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a + # via rich +pyparsing==3.1.2 \ + --hash=sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad \ + --hash=sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742 + # via matplotlib +pyproject-hooks==1.1.0 \ + --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ + --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 + # via build +pytest==8.2.0 \ + --hash=sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233 \ + --hash=sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f + # via pytest-xdist +pytest-xdist==3.6.1 \ + --hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ + --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d + # via -r build/test-requirements.txt +python-dateutil==2.9.0.post0 \ + --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ + --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 + # via matplotlib +rich==13.7.1 \ + --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ + --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 + # via -r build/test-requirements.txt +scipy==1.13.1 \ + --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ + --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ + --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ + --hash=sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9 \ + --hash=sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54 \ + --hash=sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16 \ + --hash=sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2 \ + --hash=sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5 \ + --hash=sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59 \ + --hash=sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326 \ + --hash=sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b \ + --hash=sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1 \ + --hash=sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d \ + --hash=sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24 \ + --hash=sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627 \ + --hash=sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c \ + --hash=sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa \ + 
--hash=sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949 \ + --hash=sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989 \ + --hash=sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004 \ + --hash=sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f \ + --hash=sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884 \ + --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ + --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ + --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f + # via -r build/requirements.in +six==1.16.0 \ + --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ + --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 + # via python-dateutil +sortedcontainers==2.4.0 \ + --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ + --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 + # via hypothesis +typing-extensions==4.12.0rc1 \ + --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ + --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe + # via etils +wheel==0.43.0 \ + --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ + --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 + # via -r build/test-requirements.txt +zipp==3.18.2 \ + --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ + --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e + # via etils +zstandard==0.22.0 \ + --hash=sha256:11f0d1aab9516a497137b41e3d3ed4bbf7b2ee2abc79e5c8b010ad286d7464bd \ + --hash=sha256:1958100b8a1cc3f27fa21071a55cb2ed32e9e5df4c3c6e661c193437f171cba2 \ + --hash=sha256:1a90ba9a4c9c884bb876a14be2b1d216609385efb180393df40e5172e7ecf356 \ + --hash=sha256:1d43501f5f31e22baf822720d82b5547f8a08f5386a883b32584a185675c8fbf \ + --hash=sha256:23d2b3c2b8e7e5a6cb7922f7c27d73a9a615f0a5ab5d0e03dd533c477de23004 \ + --hash=sha256:2612e9bb4977381184bb2463150336d0f7e014d6bb5d4a370f9a372d21916f69 \ + --hash=sha256:275df437ab03f8c033b8a2c181e51716c32d831082d93ce48002a5227ec93019 \ + --hash=sha256:2ac9957bc6d2403c4772c890916bf181b2653640da98f32e04b96e4d6fb3252a \ + --hash=sha256:2b11ea433db22e720758cba584c9d661077121fcf60ab43351950ded20283440 \ + --hash=sha256:2fdd53b806786bd6112d97c1f1e7841e5e4daa06810ab4b284026a1a0e484c0b \ + --hash=sha256:33591d59f4956c9812f8063eff2e2c0065bc02050837f152574069f5f9f17775 \ + --hash=sha256:36a47636c3de227cd765e25a21dc5dace00539b82ddd99ee36abae38178eff9e \ + --hash=sha256:39b2853efc9403927f9065cc48c9980649462acbdf81cd4f0cb773af2fd734bc \ + --hash=sha256:3db41c5e49ef73641d5111554e1d1d3af106410a6c1fb52cf68912ba7a343a0d \ + --hash=sha256:445b47bc32de69d990ad0f34da0e20f535914623d1e506e74d6bc5c9dc40bb09 \ + --hash=sha256:466e6ad8caefb589ed281c076deb6f0cd330e8bc13c5035854ffb9c2014b118c \ + --hash=sha256:48f260e4c7294ef275744210a4010f116048e0c95857befb7462e033f09442fe \ + --hash=sha256:4ac59d5d6910b220141c1737b79d4a5aa9e57466e7469a012ed42ce2d3995e88 \ + --hash=sha256:53866a9d8ab363271c9e80c7c2e9441814961d47f88c9bc3b248142c32141d94 \ + --hash=sha256:589402548251056878d2e7c8859286eb91bd841af117dbe4ab000e6450987e08 \ + --hash=sha256:68953dc84b244b053c0d5f137a21ae8287ecf51b20872eccf8eaac0302d3e3b0 \ + 
--hash=sha256:6c25b8eb733d4e741246151d895dd0308137532737f337411160ff69ca24f93a \ + --hash=sha256:7034d381789f45576ec3f1fa0e15d741828146439228dc3f7c59856c5bcd3292 \ + --hash=sha256:73a1d6bd01961e9fd447162e137ed949c01bdb830dfca487c4a14e9742dccc93 \ + --hash=sha256:8226a33c542bcb54cd6bd0a366067b610b41713b64c9abec1bc4533d69f51e70 \ + --hash=sha256:888196c9c8893a1e8ff5e89b8f894e7f4f0e64a5af4d8f3c410f0319128bb2f8 \ + --hash=sha256:88c5b4b47a8a138338a07fc94e2ba3b1535f69247670abfe422de4e0b344aae2 \ + --hash=sha256:8a1b2effa96a5f019e72874969394edd393e2fbd6414a8208fea363a22803b45 \ + --hash=sha256:93e1856c8313bc688d5df069e106a4bc962eef3d13372020cc6e3ebf5e045202 \ + --hash=sha256:9501f36fac6b875c124243a379267d879262480bf85b1dbda61f5ad4d01b75a3 \ + --hash=sha256:959665072bd60f45c5b6b5d711f15bdefc9849dd5da9fb6c873e35f5d34d8cfb \ + --hash=sha256:a1d67d0d53d2a138f9e29d8acdabe11310c185e36f0a848efa104d4e40b808e4 \ + --hash=sha256:a493d470183ee620a3df1e6e55b3e4de8143c0ba1b16f3ded83208ea8ddfd91d \ + --hash=sha256:a7ccf5825fd71d4542c8ab28d4d482aace885f5ebe4b40faaa290eed8e095a4c \ + --hash=sha256:a88b7df61a292603e7cd662d92565d915796b094ffb3d206579aaebac6b85d5f \ + --hash=sha256:a97079b955b00b732c6f280d5023e0eefe359045e8b83b08cf0333af9ec78f26 \ + --hash=sha256:d22fdef58976457c65e2796e6730a3ea4a254f3ba83777ecfc8592ff8d77d303 \ + --hash=sha256:d75f693bb4e92c335e0645e8845e553cd09dc91616412d1d4650da835b5449df \ + --hash=sha256:d8593f8464fb64d58e8cb0b905b272d40184eac9a18d83cf8c10749c3eafcd7e \ + --hash=sha256:d8fff0f0c1d8bc5d866762ae95bd99d53282337af1be9dc0d88506b340e74b73 \ + --hash=sha256:de20a212ef3d00d609d0b22eb7cc798d5a69035e81839f549b538eff4105d01c \ + --hash=sha256:e9e9d4e2e336c529d4c435baad846a181e39a982f823f7e4495ec0b0ec8538d2 \ + --hash=sha256:f058a77ef0ece4e210bb0450e68408d4223f728b109764676e1a13537d056bb0 \ + --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ + --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ + --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e + # via -r build/requirements.in + +# The following packages are considered to be unsafe in a requirements file: +setuptools==69.5.1 \ + --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ + --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 + # via -r build/test-requirements.txt diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt new file mode 100644 index 000000000000..62b5e14e65b4 --- /dev/null +++ b/build/requirements_lock_3_13.txt @@ -0,0 +1,106 @@ +# +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: +# +# bazel run //build:requirements_dev.update +# +absl-py==2.1.0 + # via -r build/test-requirements.txt +attrs==23.2.0 + # via hypothesis +build==1.2.1 + # via -r build/test-requirements.txt +cloudpickle==3.0.0 + # via -r build/test-requirements.txt +colorama==0.4.6 + # via -r build/test-requirements.txt +contourpy==1.2.1 + # via matplotlib +cycler==0.12.1 + # via matplotlib +etils[epath,epy]==1.8.0 + # via -r build/requirements.in +execnet==2.1.1 + # via pytest-xdist +flatbuffers==24.3.25 + # via -r build/test-requirements.txt +fonttools==4.51.0 + # via matplotlib +fsspec==2024.3.1 + # via etils +hypothesis==6.100.1 + # via -r build/test-requirements.txt +importlib-resources==6.4.0 + # via etils +iniconfig==2.0.0 + # via pytest +kiwisolver==1.4.5 + # via matplotlib +markdown-it-py==3.0.0 + # via rich +matplotlib==3.8.3 + # via -r 
build/requirements.in +mdurl==0.1.2 + # via markdown-it-py +ml-dtypes==0.4.0 + # via -r build/requirements.in +mpmath==1.3.0 + # via -r build/test-requirements.txt +numpy==1.26.4 + # via + # -r build/requirements.in + # -r build/test-requirements.txt + # contourpy + # matplotlib + # ml-dtypes + # opt-einsum + # scipy +opt-einsum==3.3.0 + # via -r build/requirements.in +packaging==24.0 + # via + # build + # matplotlib + # pytest +pillow==10.3.0 + # via + # -r build/test-requirements.txt + # matplotlib +pluggy==1.4.0 + # via pytest +portpicker==1.6.0 + # via -r build/test-requirements.txt +psutil==5.9.8 + # via portpicker +pygments==2.17.2 + # via rich +pyparsing==3.1.2 + # via matplotlib +pyproject-hooks==1.0.0 + # via build +pytest==8.1.1 + # via pytest-xdist +pytest-xdist==3.5.0 + # via -r build/test-requirements.txt +python-dateutil==2.9.0.post0 + # via matplotlib +rich==13.7.1 + # via -r build/test-requirements.txt +scipy==1.13.1 + # via -r build/requirements.in +six==1.16.0 + # via python-dateutil +sortedcontainers==2.4.0 + # via hypothesis +typing-extensions==4.11.0 + # via etils +wheel==0.43.0 + # via -r build/test-requirements.txt +zipp==3.18.1 + # via etils +zstandard==0.22.0 + # via -r build/requirements.in + +# The following packages are considered to be unsafe in a requirements file: +setuptools==69.2.0 + # via -r build/test-requirements.txt diff --git a/build/rocm/build_rocm.sh b/build/rocm/build_rocm.sh index 64e166239fc8..6374a2a18929 100755 --- a/build/rocm/build_rocm.sh +++ b/build/rocm/build_rocm.sh @@ -57,7 +57,7 @@ rocm_version=$(cat /opt/rocm/.info/version | cut -d "-" -f 1) export JAX_ROCM_VERSION=${rocm_version//./} #Build and install wheel -python3 ./build/build.py --enable_rocm --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR} +python3 ./build/build.py --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR} JAX_RELEASE=1 python -m build pip3 install --force-reinstall dist/*.whl # installs jaxlib (includes XLA) diff --git a/build/rocm/run_single_gpu.py b/build/rocm/run_single_gpu.py index bf37a49ee61e..add7ee3d86b5 100755 --- a/build/rocm/run_single_gpu.py +++ b/build/rocm/run_single_gpu.py @@ -22,7 +22,7 @@ GPU_LOCK = threading.Lock() LAST_CODE = 0 -base_dir="./logs" +base_dir = "./logs" def extract_filename(path): base_name = os.path.basename(path) @@ -32,7 +32,7 @@ def extract_filename(path): def generate_final_report(shell=False, env_vars={}): env = os.environ env = {**env, **env_vars} - cmd = ["pytest_html_merger", "-i", '{}'.format(base_dir), "-o", '{}/final_compiled_report.html'.format(base_dir)] + cmd = ["pytest_html_merger", "-i", f'{base_dir}', "-o", f'{base_dir}/final_compiled_report.html'] result = subprocess.run(cmd, shell=shell, capture_output=True, @@ -90,7 +90,7 @@ def run_test(testmodule, gpu_tokens): "XLA_PYTHON_CLIENT_ALLOCATOR": "default", } testfile = extract_filename(testmodule) - cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-x", testmodule] + cmd = ["python3", "-m", "pytest", f'--html={base_dir}/{testfile}_log.html', "--reruns", "3", "-x", testmodule] return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars) with GPU_LOCK: gpu_tokens.append(target_gpu) diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 0744b2ac312e..4f9d19e76ba2 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -2,8 +2,10 @@ 
absl-py build cloudpickle colorama>=0.4.4 +filelock flatbuffers hypothesis +mpmath>=1.3 numpy>=1.22 pillow>=9.1.0 portpicker diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index cb2bd79ef53d..4f4ba8c165a3 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -19,7 +19,7 @@ "\n", "This notebook is an introduction to writing single-program multiple-data (SPMD) programs in JAX, and executing them synchronously in parallel on multiple devices, such as multiple GPUs or multiple TPU cores. The SPMD model is useful for computations like training neural networks with synchronous gradient descent algorithms, and can be used for data-parallel as well as model-parallel computations.\n", "\n", - "**Note:** To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine, a Google Cloud TPU or a Kaggle TPU VM. The required features are not supported by the Google Colab TPU runtime at this time.\n", + "**Note:** To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine, a Colab TPU, a Google Cloud TPU or a Kaggle TPU VM.\n", "\n", "The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)." ] diff --git a/cloud_tpu_colabs/README.md b/cloud_tpu_colabs/README.md index 106485996edc..4a795f718c84 100644 --- a/cloud_tpu_colabs/README.md +++ b/cloud_tpu_colabs/README.md @@ -9,16 +9,6 @@ computation across multiple TPU cores from Colab. You can also run the same code directly on a [Cloud TPU VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). -## Update (June 2021): introducing Cloud TPU VMs - -A new Cloud TPU architecture was recently -[announced](https://cloud.google.com/blog/products/compute/introducing-cloud-tpu-vms) -that gives you direct access to a VM with TPUs attached, enabling significant -performance and usability improvements when using JAX on Cloud TPU. As of -writing, Colab still uses the previous architecture, but the same JAX code -generally will run on either architecture (there are a few features that are -only available with the new architecture, such as complex number support). - ## Example Cloud TPU notebooks The following notebooks showcase how to use and what you can do with Cloud TPUs on Colab: diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md index 44cdaf1f15e8..2d9c43831e4d 100644 --- a/docs/Custom_Operation_for_GPUs.md +++ b/docs/Custom_Operation_for_GPUs.md @@ -1,5 +1,7 @@ # Custom operations for GPUs with C++ and CUDA + + JAX ships with a large number of built-in operations, but users occasionally run into a situation where they need a new operation that is not supported by JAX. To accommodate such scenarios, JAX allows users to define custom operations and this tutorial is to explain how we can define one for GPUs and use it in single-GPU and multi-GPU environments. @@ -34,7 +36,7 @@ You need to follow these steps in Python: * Define its abstract evaluation. * Define its lowering to MLIR. * [Optional] Define the gradient. 
-* [Optional] Use [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) (or one of the experimental [custom_partitioning](https://jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html) or [shard_map](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) functions) for fast multi-GPU. +* [Optional] Use [custom_partitioning](https://jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html) or [shard_map](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) functions for fast multi-GPU. ## C code @@ -525,7 +527,7 @@ with mesh: print(pjitted.lower(x, weight).compile().runtime_executable().hlo_modules()[0].to_string()) out = pjitted(x, weight) -jnp.allclose(ref, out, atol=1e-2, rtol=1e-2) +jnp.allclose(ref, out, atol=1e-5, rtol=1e-5) ``` ```python HloModule pjit_rms_norm, entry_computation_layout={(bf16[4,512,512]{2,1,0},bf16[512,512]{1,0})->bf16[4,512,512]{2,1,0}} @@ -554,1661 +556,319 @@ ENTRY %main.7_spmd (param: bf16[4,512,512], param.1: bf16[512,512]) -> bf16[4,51 True ``` -The values have been computed correctly for forward operation, however, the generated HLO modules shows an `all-gather` operation to replicate `x` on all devices, incurring large communication overhead. +The values have been computed correctly for forward operation, however, the generated HLO modules show an `all-gather` operation to replicate `x` on all devices, incurring large communication overhead. As XLA does not have enough knowledge about the custom functions to shard input tensors, it decides to replicate them to produce correct values before making the custom call. -To avoid this overhead, we need to use the xmap manual sharding with the following configuration updates - -```python -jax.config.update("experimental_xmap_spmd_lowering", True) -jax.config.update("experimental_xmap_spmd_lowering_manual", True) -``` +To avoid this duplication, we can: +- [custom_partitioning](https://jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html): to make it behave like all native JAX operations (but more complicated) +- Use manual sharding + - [shard_map](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html): the new replacement for xmap + - [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) (now deprecated) -We need to modify the test code to use the xmap manual sharding with the custom operation. +This example demonstrates the use of custom_partitioning. -We first define a function that wraps `rms_norm` with `xmap`. As the size of the data axis that is being sharded must match the size of the corresponding mesh axis in the xmap manual sharding mode, we reshape `x` with the new shape `(device_count, x.shape[0] // device_count, *x.shape[1:])`, and `device_count` represents the size of the corresponding mesh axis. +### Shard the forward function with custom_partitioning -After running `rms_norm` through `xmap`, we reshape the output to match the shape of `x` to match the expectation from clients. +We first create a helper function to help with all the JAX/XLA callback registration required. 
```python -from jax.experimental.maps import xmap - - -def xmap_rms_norm(x, weight, *, device_count): - reshaped = x.reshape(device_count, x.shape[0] // device_count, *x.shape[1:]) - xmapped = xmap( - rms_norm, - in_axes=(("x", None, None, None), (None, None)), - out_axes=("x", None, None, None), - axis_resources={"x": "x"}, - ) - reshaped_out = xmapped(reshaped, weight) - return reshaped_out.reshape(x.shape) -``` - -Now we need to run `xmap_rms_norm`, not `rms_norm` through `pjit`. - -```python -with mesh: - - pjitted = pjit( - partial(xmap_rms_norm, device_count=jax.local_device_count()), - # Shard x by batch dimension and replicate weight on all devices. - in_shardings=( - PartitionSpec("x", None, None), - PartitionSpec(None, None), - ), - # Shard the output by batch dimension. - out_shardings=PartitionSpec("x", None, None), - ) - print(pjitted.lower(x, weight).compile().runtime_executable().hlo_modules()[0].to_string()) - out = pjitted(x, weight) - -jnp.allclose(ref, out, atol=1e-2, rtol=1e-2) -``` -```python -HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,512,512]{2,1,0},bf16[512,512]{1,0})->bf16[4,512,512]{2,1,0}} - -ENTRY %main.17_spmd (param: bf16[4,512,512], param.1: bf16[512,512]) -> bf16[4,512,512] { - %param = bf16[4,512,512]{2,1,0} parameter(0), sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}, metadata={op_name="pjit()/jit(main)/xmap(rms_norm)/squeeze[dimensions=(0,)]" source_file="/tmp/ipykernel_25235/3123505662.py" source_line=13} - %param.1 = bf16[512,512]{1,0} parameter(1), sharding={replicated}, metadata={op_name="pjit()/jit(main)/xmap(rms_norm)/full_to_shard[axes=OrderedDict() mesh=Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=(\'x\',)) manual_axes=(\'x\',)]" source_file="/tmp/ipykernel_25235/3123505662.py" source_line=13} - %custom-call.0 = (bf16[4,512,512]{2,1,0}, f32[4]{0}) custom-call(bf16[4,512,512]{2,1,0} %param, bf16[512,512]{1,0} %param.1), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={bf16[4,512,512]{2,1,0}, bf16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit()/jit(main)/xmap(rms_norm)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\000\000\000\000\000\000\000\000\000\000\000\000\027\177\000\000" - ROOT %get-tuple-element = bf16[4,512,512]{2,1,0} get-tuple-element((bf16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.0), index=0, metadata={op_name="pjit()/jit(main)/xmap(rms_norm)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8} -} -``` -```python -True +def register_primitive(cls): + """ + register jax primitive + + The order of calls. Each operation is composed of two primitives: Inner and Outer. + + Inner, only the basic to wrap the custom_call itself. + - impl to XLA custom_call in C. + - abstract to know the static shapes + - lower to StableHLO XLA custom_call. + Outer, mostly all the rest: + - impl: Bind to the inner primitive. Not used for real computation, but only for tracing. So we only need to bind. + - abstract: same + - lower to StableHLO custom_p. (XLA will call the python callback from it) + - custom_p + - vmap: could be added here. + VJP is based on Outer, but not handled in this function. 
+    """
+
+    def name_of_wrapper_p():
+        return cls.name + "_wrapper"
+
+    inner_p = core.Primitive(cls.name)
+    dispatch.prim_requires_devices_during_lowering.add(inner_p)
+    inner_p.multiple_results = cls.multiple_results
+    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
+    inner_p.def_abstract_eval(cls.abstract)
+    mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
+    cls.inner_primitive = inner_p
+
+    outer_p = core.Primitive(name_of_wrapper_p())
+    dispatch.prim_requires_devices_during_lowering.add(outer_p)
+    outer_p.multiple_results = cls.multiple_results
+    outer_p.def_impl(cls.impl)
+    outer_p.def_abstract_eval(cls.abstract)
+    batching.primitive_batchers[outer_p] = cls.batcher
+    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
+    outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
+                                partition=cls.partition)
+    mlir.register_lowering(outer_p,
+                           mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
+    cls.outer_primitive = outer_p
+...
```
-With this modification, the `all-gather` operation is eliminated and the custom call is made on each shard of `x`.
-
-### Test the backward function
-
-We are moving onto the backward operation using `jax.grad` on multiple devices.
-
-Similarly to the forward operation test, we are creating a simple 1D mesh and sharding `x` on all devices.
+We define two JAX primitives: an inner primitive that maps to the
+real kernel we want to wrap in JAX, and an outer primitive that is
+used for the custom_partitioning registration and for the gradient.
+(If you implement the interface to support vmap, it also goes on the
+outer primitive.)
-We also define the `loss` function with `xmap_rms_norm` instead of `rms_norm`
-
-```python
-def loss_ref(x, weight):
-    predictions = rms_norm(x, weight)
-    return -jnp.mean(predictions**2)
+JAX's custom_partitioning is implemented as callbacks from XLA to Python during XLA's sharding logic.
+XLA sharding goes in two phases: a sharding propagation phase and a partition phase.
+The propagation phase is when XLA plans the shardings to be created; it is the partition phase that creates the sharded graph.
+For XLA to be able to shard our custom operations, it needs us to define two extra functions:
+infer_sharding_from_operands() and partition(). They are used in the first and second phase, respectively.
+The infer_sharding_from_operands() function must do what its name says: infer the output sharding from the input shardings.
-ref = jax.grad(loss_ref, argnums=(0, 1))(x, weight)
+The partition() function does a few things:
+- tell XLA which input shardings are expected (XLA will reshard the inputs if needed),
+- tell XLA the final sharding of the outputs, and
+- give XLA a function that creates the new instruction from the sharded inputs.
+See the code comments in the class below for more explanation.
-# Re-define loss to use xmap_rms_norm instead of rms_norm
-def loss(x, weight, *, device_count):
-    predictions = xmap_rms_norm(x, weight, device_count=device_count)
-    return -jnp.mean(predictions**2)
-
-
-with mesh:
-
-    pjitted = pjit(
-        jax.grad(partial(loss, device_count=jax.local_device_count()), argnums=(0, 1)),
-        # Shard x by batch dimension and replicate weight on all devices.
-        in_shardings=(
-            PartitionSpec("x", None, None),
-            PartitionSpec(None, None),
-        ),
-        # Shard the output by batch dimension and replicate weight grad on all devices.
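+Before the full RMS-norm class, here is a deliberately small, hypothetical sketch of the same contract for a
+single-output elementwise operation; the `toy_scale` names are ours and are not part of the real extension.
+It only exists to isolate what the two callbacks receive and return. The actual forward-pass class for RMS
+norm follows right after.
+
+```python
+# Minimal, hypothetical custom_partitioning example (not the tutorial's kernel).
+from jax.experimental.custom_partitioning import custom_partitioning
+from jax.sharding import NamedSharding
+
+
+@custom_partitioning
+def toy_scale(x):
+    return 2.0 * x
+
+
+def toy_infer_sharding_from_operands(mesh, arg_infos, result_infos):
+    del result_infos  # Not needed: the op is elementwise.
+    # The output can simply keep the sharding of its only input.
+    return NamedSharding(mesh, arg_infos[0].sharding.spec)
+
+
+def toy_partition(mesh, arg_infos, result_infos):
+    del result_infos
+    arg_shardings = (NamedSharding(mesh, arg_infos[0].sharding.spec),)
+    out_sharding = NamedSharding(mesh, arg_infos[0].sharding.spec)
+
+    def sharded_impl(x_shard):
+        # Called with the per-device shard; no communication is needed here.
+        return 2.0 * x_shard
+
+    # The contract: (mesh, per-shard implementation, output shardings, input shardings).
+    return mesh, sharded_impl, out_sharding, arg_shardings
+
+
+toy_scale.def_partition(
+    infer_sharding_from_operands=toy_infer_sharding_from_operands,
+    partition=toy_partition)
+# The callbacks only fire when toy_scale is jitted and given inputs sharded over a Mesh.
+```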
- out_shardings=( - PartitionSpec("x", None, None), - PartitionSpec(None, None), - ), - ) - out = pjitted(x, weight) - -for r, o in zip(ref, out): - print(jnp.allclose(r, o, atol=1e-2, rtol=1e-2)) -``` ```python -True -True -``` - -We can inspect the generated jaxpr, which is the JAX internal representation, to make sure `jax.grad` inserts a `psum` for the gradient accumulation across the devices when needed. +class RmsNormFwdClass: + name = "rms_forward_affine_mixed_dtype" + multiple_results = True + impl_static_args = (2,) # eps + inner_primitive = None + outer_primitive = None + + @staticmethod + def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh, + arg_infos : Tuple[jax._src.api.ShapeDtypeStruct], + result_infos : Tuple[jax._src.core.ShapedArray]): + del eps, result_infos # Not needed for this example. + x_info, weight_info = arg_infos + assert len(x_info.shape) == 3 + assert len(weight_info.shape) == 2 + # partition() will force all dims of all inputs to be replicated except the + # first dim of x that will be kept as is. + # This is because the implementaion can only be sharded on the batch dimensions. + + x_spec = arg_infos[0].sharding.spec + # None mean that we replicate on that dimension. + output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)) + invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0])) + return (output_sharding, invvar_sharding) + + @staticmethod + def partition(eps : float, mesh : jax.sharding.Mesh, + arg_infos : Tuple[jax._src.api.ShapeDtypeStruct], + result_infos : Tuple[jax._src.api.ShapeDtypeStruct]): + del result_infos # Not needed for this example. + x_info, weight_info = arg_infos + assert len(x_info.shape) == 3 + assert len(weight_info.shape) == 2 + x_spec = arg_infos[0].sharding.spec + # We only support sharding on the batch dimensions. + # Force sharding on all others dimensions with None. + arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)), + NamedSharding(mesh, PartitionSpec(None, None))) + invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0])) + output_shardings = (arg_shardings[0], invvar_sharding) + # Sharded_impl only accepts positional arugments + # And they should be Jax traceable variables + impl = partial(RmsNormFwdClass.impl, eps=eps) + + return mesh, impl, output_shardings, arg_shardings +register_primitive(RmsNormFwdClass) +``` +Next we define the primitive for the backward pass of RMSNorm + +### Shard the backward function with custom_partitioning ```python -with mesh: - - print(jax.make_jaxpr(pjitted)(x, weight)) -``` -```python -{ lambda ; a:bf16[32,512,512] b:bf16[512,512]. let - c:bf16[32,512,512] d:bf16[512,512] = pjit[ - donated_invars=(False, False) - in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>, <_PositionalSemantics.GLOBAL: 1>) - in_shardings=(GSPMDSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), GSPMDSharding({replicated})) - jaxpr={ lambda ; e:bf16[32,512,512] f:bf16[512,512]. let - g:bf16[8,4,512,512] = reshape[ - dimensions=None - new_sizes=(8, 4, 512, 512) - ] e - h:bf16[8,4,512,512] i:f32[8,4] j:bf16[8,4,512,512] k:bf16[512,512] = xmap[ - axis_resources=FrozenDict({'x': ('x',)}) - backend=None - call_jaxpr={ lambda ; l:bf16[4,512,512;x:8] m:bf16[512,512]. 
let - n:bf16[4,512,512;x:8] o:f32[4;x:8] = rms_norm_fwd[eps=1e-05] l m - in (n, o, l, m) } - donated_invars=(False, False) - global_axis_sizes=FrozenDict({'x': 8}) - in_axes=(FrozenDict({'x': 0}), FrozenDict({})) - in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>, <_PositionalSemantics.GLOBAL: 1>) - name=rms_norm - out_axes=(FrozenDict({'x': 0}), FrozenDict({'x': 0}), FrozenDict({'x': 0}), FrozenDict({})) - out_positional_semantics=_PositionalSemantics.GLOBAL - resource_env=ResourceEnv(Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=('x',)), ()) - spmd_in_axes=None - spmd_out_axes=None - ] g f - p:bf16[32,512,512] = reshape[dimensions=None new_sizes=(32, 512, 512)] h - q:bf16[32,512,512] = integer_pow[y=2] p - r:bf16[32,512,512] = integer_pow[y=1] p - s:bf16[32,512,512] = mul 2 r - t:f32[32,512,512] = convert_element_type[ - new_dtype=float32 - weak_type=False - ] q - u:f32[] = reduce_sum[axes=(0, 1, 2)] t - v:bf16[] = convert_element_type[new_dtype=bfloat16 weak_type=False] u - w:bf16[] = div v 8.38861e+06 - _:bf16[] = neg w - x:bf16[] = neg 1 - y:bf16[] = div x 8.38861e+06 - z:f32[] = convert_element_type[new_dtype=float32 weak_type=False] y - ba:f32[32,512,512] = broadcast_in_dim[ - broadcast_dimensions=() - shape=(32, 512, 512) - ] z - bb:bf16[32,512,512] = convert_element_type[ - new_dtype=bfloat16 - weak_type=False - ] ba - bc:bf16[32,512,512] = mul bb s - bd:bf16[8,4,512,512] = reshape[ - dimensions=None - new_sizes=(8, 4, 512, 512) - ] bc - be:bf16[8,4,512,512] bf:bf16[512,512] = xmap[ - axis_resources=FrozenDict({'x': ('x',)}) - backend=None - call_jaxpr={ lambda ; bg:f32[4;x:8] bh:bf16[4,512,512;x:8] bi:bf16[512,512] - bj:bf16[4,512,512;x:8]. let - bk:bf16[4,512,512;x:8] bl:bf16[512,512;x:8] _:f32[16,262144;x:8] = rms_norm_bwd[ - eps=1e-05 - ] bj bg bh bi - bm:bf16[512,512] = psum[axes=('x',) axis_index_groups=None] bl - in (bk, bm) } - donated_invars=(False, False, False, False) - global_axis_sizes=FrozenDict({'x': 8}) - in_axes=(FrozenDict({'x': 0}), FrozenDict({'x': 0}), FrozenDict({}), FrozenDict({'x': 0})) - in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>, <_PositionalSemantics.GLOBAL: 1>) - name=transpose(rms_norm) - out_axes=(FrozenDict({'x': 0}), FrozenDict({})) - out_positional_semantics=_PositionalSemantics.GLOBAL - resource_env=ResourceEnv(Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=('x',)), ()) - spmd_in_axes=None - spmd_out_axes=None - ] i j k bd - bn:bf16[32,512,512] = reshape[ - dimensions=None - new_sizes=(32, 512, 512) - ] be - in (bn, bf) } - name= - out_positional_semantics=_PositionalSemantics.GLOBAL - out_shardings=(GSPMDSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), GSPMDSharding({replicated})) - resource_env=ResourceEnv(Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=('x',)), ()) - ] a b - in (c, d) } -``` - -We see that `bm:bf16[512,512] = psum[axes=('x',) axis_index_groups=None] bl` has been added after the call to `rms_norm_bwd` to reduce `grad_weight` across the devices on the axis `"x"`, but there is no `psum` for `grad_input`. - -This is controlled by `named_shape` passed to the `ShapedArray` construction in abstract evaluation and the axes given to `xmap`. - -The following code snippet from `_rms_norm_bwd_abstract` shows that `grad_input` has the exact same shape, type, and named shape as `x` does, which means `grad_input` is sharded the same way as `x`, hence no need for a `psum` for `grad_input`. 
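+# The backward primitive mirrors the forward one: it takes
+# (g, invvar, x, weight) and produces (grad_input, grad_weight, part_grad).
+# Its sharding rules below again only allow sharding along the batch
+# dimension of x; everything else is forced to be replicated.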
-In contrast, `grad_weight` has the same shape and type as `weight` does, but, when `weight.named_shape` is empty, `x.named_shape` is used for `grad_weight`. In `in_axes` of our `xmap` call, `weight` has no named axis and `weight.named_shape` is empty, but `grad_weight` now has a named axis `"x"` in `grad_weight.named_shape`. -This makes `jax.grad` insert `psum` on the axis `"x"` for `grad_weight`. - -``` -weight_named_shape = ( - weight_named_shape if weight.named_shape else x.named_shape -) -... -return ( - ShapedArray( - x.shape, x_dtype, named_shape=x.named_shape - ), # grad input - ShapedArray( - weight.shape, w_dtype, named_shape=weight_named_shape - ), # grad weight - .... -) -``` - -## Let's put it together - -Here is the complete code. +class RmsNormBwdClass: + name = "rms_norm_bwd" + multiple_results = True + impl_static_args = (4,) # eps + inner_primitive = None + outer_primitive = None + + @staticmethod + def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh, + arg_infos : Tuple[jax._src.api.ShapeDtypeStruct], + result_infos : Tuple[jax._src.core.ShapedArray]): + del eps, result_infos # Not needed for this example. + g_info, invvar_info, x_info, weight_info = arg_infos + assert len(g_info.shape) == 3 + assert len(invvar_info.shape) == 1 + assert len(x_info.shape) == 3 + assert len(weight_info.shape) == 2 + # partition() will force all dims to be replicated except the batch dimension. + x_spec = x_info.sharding.spec + output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)) + invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None)) + return (output_sharding, invvar_sharding, output_sharding, ) + + @staticmethod + def partition(eps : float, mesh : jax.sharding.Mesh, + arg_infos : Tuple[jax._src.api.ShapeDtypeStruct], + result_infos : Tuple[jax._src.api.ShapeDtypeStruct]): + del result_infos # Not needed for this example. + g_info, invvar_info, x_info, weight_info = arg_infos + assert len(g_info.shape) == 3 + assert len(invvar_info.shape) == 1 + assert len(x_info.shape) == 3 + assert len(weight_info.shape) == 2 + + # We only support sharding on the batch dimensions. + # Force sharding on all others dimensions with None. + # Also force gx, x and invvar to have the same batch sharding/replication. + x_spec = x_info.sharding.spec + arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)), + NamedSharding(mesh, PartitionSpec(x_spec[0],)), + NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)), + NamedSharding(mesh, PartitionSpec(None, None))) + + output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)) + invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None)) + output_shardings = (output_sharding, invvar_sharding, invvar_sharding) + + + # Sharded_impl only accepts positional arugments + # And they should be Jax traceable variables + def impl(g, invvar, x, weight): + grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind( + g, invvar, x, weight, eps=eps + ) + # We need to sum the weight gradient from all partition. 
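+            # (i.e. a psum over the batch mesh axis; grad_input needs no
+            # collective because it keeps the same batch sharding as x.)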
+ global_weight = grad_weight + if x_spec[0]: + global_weight = jax.lax.psum(grad_weight, x_spec[0]) + return grad_input, global_weight, part_grad + return mesh, impl, output_shardings, arg_shardings +register_primitive(RmsNormBwdClass) +``` +Plumbing to establish the forward and backward primitives with a custom_vjp rule as before: ```python -from functools import partial, reduce -from operator import mul - -import jax -import jax.numpy as jnp -from build import gpu_ops -from jax import core, dtypes -from jax.core import ShapedArray -from jax.experimental.maps import xmap -from jax.experimental.pjit import pjit -from jax.interpreters import mlir, xla -from jax.interpreters.mlir import ir -from jax.lib import xla_client -from jax.sharding import Mesh, PartitionSpec -from jaxlib.hlo_helpers import custom_call - - -# Create _rms_norm_fwd_p for forward operation. -_rms_norm_fwd_p = core.Primitive("rms_norm_fwd") -_rms_norm_fwd_p.multiple_results = True -_rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p)) - - -def rms_norm_fwd(x, weight, eps=1e-05): - output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps) +@partial(jax.custom_vjp, nondiff_argnums=(2,)) +def custom_p_rms_norm(x, weight, eps=1e-05): + output, _ = custom_p_rms_norm_fwd(x, weight, eps=eps) + return output + +def custom_p_rms_norm_fwd(x, weight, eps=1e-05): + output, invvar = RmsNormFwdClass.outer_primitive.bind(x, weight, eps=eps) return output, (invvar, x, weight) - -# Create _rms_norm_bwd_p for backward operation. -_rms_norm_bwd_p = core.Primitive("rms_norm_bwd") -_rms_norm_bwd_p.multiple_results = True -_rms_norm_bwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_bwd_p)) - - -def rms_norm_bwd(eps, res, g): +def custom_p_rms_norm_bwd(eps, res, g): invvar, x, weight = res - grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind( - g, invvar, x, weight, eps=eps - ) + grad_input, grad_weight, part_grad = RmsNormBwdClass.outer_primitive.bind( + g, invvar, x, weight, eps=eps) return grad_input, grad_weight +custom_p_rms_norm.defvjp(custom_p_rms_norm_fwd, custom_p_rms_norm_bwd) +``` -#################### -# Lowering to MLIR # -#################### - - -# Register functions defined in gpu_ops as custom call target for GPUs -for _name, _value in gpu_ops.get_rms_norm_registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="gpu") - - -def element_type_to_descriptor_type_mapping(element_type): - _element_type_to_descriptor_type_mapping = { - ir.BF16Type.get(): gpu_ops.ElementType.BF16, - ir.F16Type.get(): gpu_ops.ElementType.F16, - ir.F32Type.get(): gpu_ops.ElementType.F32, - ir.F64Type.get(): gpu_ops.ElementType.F64, - } - return _element_type_to_descriptor_type_mapping.get(element_type) - - -def default_layouts(*shapes): - return [range(len(shape) - 1, -1, -1) for shape in shapes] - - -def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps): - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - w_type = ir.RankedTensorType(weight.type) - w_shape = w_type.shape - iv_element_type = ( - ir.F32Type.get() - if x_type.element_type in [ir.F16Type.get(), ir.BF16Type.get()] - else x_type.element_type - ) - - n2 = reduce(lambda x, y: x * y, w_shape) - n1 = reduce(lambda x, y: x * y, x_shape) // n2 - - opaque = gpu_ops.create_rms_norm_descriptor( - n1, - n2, - eps, - element_type_to_descriptor_type_mapping(x_type.element_type), - element_type_to_descriptor_type_mapping(w_type.element_type), - 0, # unused - ) - out = custom_call( - b"rms_forward_affine_mixed_dtype", - 
result_types=[ - ir.RankedTensorType.get(x_shape, w_type.element_type), - ir.RankedTensorType.get((n1,), iv_element_type), - ], - operands=[x, weight], - backend_config=opaque, - operand_layouts=default_layouts(x_shape, w_shape), - result_layouts=default_layouts(x_shape, (n1,)), - ).results - return out - - -mlir.register_lowering( - _rms_norm_fwd_p, - _rms_norm_fwd_cuda_lowering, - platform="gpu", -) - - -def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps): - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - w_type = ir.RankedTensorType(weight.type) - w_shape = w_type.shape - iv_type = ir.RankedTensorType(invvar.type) - - n2 = reduce(lambda x, y: x * y, w_shape) - n1 = reduce(lambda x, y: x * y, x_shape) // n2 - - part_grad_shape = ctx.avals_out[-1].shape - - opaque = gpu_ops.create_rms_norm_descriptor( - n1, - n2, - eps, - element_type_to_descriptor_type_mapping(x_type.element_type), - element_type_to_descriptor_type_mapping(w_type.element_type), - part_grad_shape[0], - ) - out = custom_call( - b"rms_backward_affine", - result_types=[ - ir.RankedTensorType.get(x_shape, x_type.element_type), - ir.RankedTensorType.get(w_shape, w_type.element_type), - ir.RankedTensorType.get(part_grad_shape, iv_type.element_type), - ], - operands=[grad_output, invvar, x, weight], - backend_config=opaque, - operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape), - result_layouts=default_layouts(x_shape, w_shape, part_grad_shape), - ).results - return out - - -mlir.register_lowering( - _rms_norm_bwd_p, - _rms_norm_bwd_cuda_lowering, - platform="gpu", -) - - -####################### -# Abstract evaluation # -####################### - - -def _rms_norm_fwd_abstract(x, weight, eps): - w_dtype = dtypes.canonicalize_dtype(weight.dtype) - iv_dtype = dtypes.canonicalize_dtype(x.dtype) - if iv_dtype in [jnp.float16, jnp.bfloat16]: - iv_dtype = jnp.float32 - n2 = reduce(mul, weight.shape) - n1 = reduce(mul, x.shape) // n2 - return ( - ShapedArray(x.shape, w_dtype, named_shape=x.named_shape), # output - ShapedArray((n1,), iv_dtype, named_shape=x.named_shape), # invvar - ) - - -_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract) - - -def _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps): - iv_dtype = dtypes.canonicalize_dtype(invvar.dtype) - w_dtype = dtypes.canonicalize_dtype(weight.dtype) - x_dtype = dtypes.canonicalize_dtype(x.dtype) - n2 = reduce(lambda x, y: x * y, weight.shape) - n1 = reduce(lambda x, y: x * y, x.shape) // n2 - part_grad_shape = (16, n2) - assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype - assert grad_output.shape == x.shape - assert invvar.shape == (n1,) - assert ( - iv_dtype == jnp.float32 if x_dtype in [jnp.float16, jnp.bfloat16] else x_dtype - ) - assert grad_output.named_shape == x.named_shape - weight_named_shape = ( - weight_named_shape if weight.named_shape else grad_output.named_shape - ) - return ( - ShapedArray( - x.shape, x_dtype, named_shape=x.named_shape - ), # grad input - ShapedArray( - weight.shape, w_dtype, named_shape=weight_named_shape - ), # grad weight - ShapedArray( - part_grad_shape, iv_dtype, named_shape=weight_named_shape - ), # part grad - ) - - -_rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract) - - -####################################### -# Top-level interface with custom vjp # -####################################### - - -@partial(jax.custom_vjp, nondiff_argnums=(2,)) -def rms_norm(x, weight, eps=1e-05): - output, _ = rms_norm_fwd(x, weight, eps=eps) - return output - - 
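+As a quick sanity check, the wrapper can be exercised on a single device before the sharded test. The following is a minimal sketch, assuming a CUDA GPU, the `gpu_ops` extension built as above, and throwaway `x_test`/`w_test` arrays (names used only for this illustration):
+
+```python
+# Single-device sanity check of the custom_vjp wrapper defined above.
+x_test = jax.random.normal(jax.random.PRNGKey(0), (8, 512, 512), dtype=jnp.float16)
+w_test = jnp.ones((512, 512), dtype=jnp.float16)
+
+out = custom_p_rms_norm(x_test, w_test)
+print(out.shape, out.dtype)  # same shape and dtype as x_test
+
+# Gradients flow through the forward/backward primitives registered above.
+grads = jax.grad(lambda x, w: -jnp.mean(custom_p_rms_norm(x, w) ** 2),
+                 argnums=(0, 1))(x_test, w_test)
+print(grads[0].shape, grads[1].shape)  # (8, 512, 512) and (512, 512)
+```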
-rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd) - - -###################### -# RMS norm with xmap # -###################### - - -jax.config.update("experimental_xmap_spmd_lowering", True) -jax.config.update("experimental_xmap_spmd_lowering_manual", True) - - -def xmap_rms_norm(x, weight, *, device_count): - reshaped = x.reshape(device_count, x.shape[0] // device_count, *x.shape[1:]) - xmapped = xmap( - rms_norm, - in_axes=(("x", None, None, None), (None, None)), - out_axes=("x", None, None, None), - axis_resources={"x": "x"}, - ) - reshaped_out = xmapped(reshaped, weight) - return reshaped_out.reshape(x.shape) - - -######## -# Test # -######## - - -import jax - - -per_core_batch_size=4 -seq_len=512 -emb_dim=512 -x = jax.random.normal( - jax.random.PRNGKey(0), - shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim), - dtype=jnp.bfloat16, -) -norm_shape = x.shape[-2:] -weight = jnp.ones(norm_shape, dtype=jnp.bfloat16) - +With that we have completely defined our custom RMS norm primitive with custom_partitioning. To check for correctness we define the following loss functions: ref_loss is the reference value to compare against, while custom_p_loss uses our new primitive that implements custom_partitioning. -def loss_ref(x, weight): +```python +def ref_loss(x, weight): predictions = rms_norm(x, weight) return -jnp.mean(predictions**2) -ref = jax.grad(loss_ref, argnums=(0, 1))(x, weight) - +ref = jax.grad(ref_loss, argnums=(0, 1))(x, weight) -def loss(x, weight, *, device_count): - predictions = xmap_rms_norm(x, weight, device_count=device_count) +def custom_p_loss(x, weight): + predictions = custom_p_rms_norm(x, weight) return -jnp.mean(predictions**2) +``` +# Check for correctness +```python with Mesh(jax.local_devices(), ("x",)): - - pjitted = pjit( - jax.grad(partial(loss, device_count=jax.local_device_count()), argnums=(0, 1)), - # Shard x by batch dimension and replicate weight on all devices. - in_shardings=( - PartitionSpec("x", None, None), - PartitionSpec(None, None), - ), - # Shard the output by batch dimension and replicate weight grad on all devices. - out_shardings=( - PartitionSpec("x", None, None), - PartitionSpec(None, None), - ), - ) - out = pjitted(x, weight) - -for r, o in zip(ref, out): - print(jnp.allclose(r, o, atol=1e-2, rtol=1e-2)) + def run_and_verify(loss): + pjitted = pjit( + jax.grad(loss, argnums=(0, 1)), + # Shard x by batch dimension and replicate weight on all devices. + in_shardings=( + PartitionSpec("x", None, None), + PartitionSpec(None, None), + ), + # Shard the output by batch dimension and replicate weight grad on all devices. + out_shardings=( + PartitionSpec("x", None, None), + PartitionSpec(None, None), + ), + ) + hlo = pjitted.lower(x, weight).compile().as_text() + out = pjitted(x, weight) + print(hlo) + assert "all-reduce-done" in hlo, "The gradient will produce wrong value!" 
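+        # The sharded backward pass only computes per-shard weight gradients,
+        # so the compiled module must contain an all-reduce that sums them.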
+ if "all-gather-start" in hlo: + print("NOT OPTIMIZED, ALL_GATHER in the graph!") + return out + + custom_p_out = run_and_verify(custom_p_loss) + + +for r, o in zip(ref_out, custom_p_out): + print(jnp.allclose(r, o, atol=1e-6, rtol=1e-6)) ``` ```python -True -True -``` - -## Appendix - -### `gpu_ops` code listing - -#### `gpu_ops/kernel_helpers.h` +HloModule pjit_custom_p_loss, is_scheduled=true, entry_computation_layout={(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})->(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})}, allow_spmd_sharding_propagation_to_parameters={false,false}, allow_spmd_sharding_propagation_to_output={false,false}, num_partitions=4, frontend_attributes={fingerprint_before_lhs="d7b9bc40de002332dd665ff2ab537b76"} -```cpp -// This header is not specific to our application and you'll probably want -// something like this for any extension you're building. This includes the -// infrastructure needed to serialize descriptors that are used with the -// "opaque" parameter of the GPU custom call. In our example we'll use this -// parameter to pass the size of our problem. - -#ifndef _GPU_OPS_KERNEL_HELPERS_H_ -#define _GPU_OPS_KERNEL_HELPERS_H_ - -#include -#include -#include -#include - -#define JAX_APEX_WARP_SIZE 32 - -namespace gpu_ops { - -// https://en.cppreference.com/w/cpp/numeric/bit_cast -template -typename std::enable_if::value && - std::is_trivially_copyable::value, - To>::type -bit_cast(const From &src) noexcept { - static_assert(std::is_trivially_constructible::value, - "This implementation additionally requires destination type to " - "be trivially constructible"); - - To dst; - memcpy(&dst, &src, sizeof(To)); - return dst; +%fused_multiply (param_0: f16[4,512,512]) -> f16[4,512,512] { + %param_0 = f16[4,512,512]{2,1,0} parameter(0) + %constant_4_1 = f16[] constant(-4.7684e-07) + %broadcast.8.1 = f16[4,512,512]{2,1,0} broadcast(f16[] %constant_4_1), dimensions={}, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484} + ROOT %multiply.5.1 = f16[4,512,512]{2,1,0} multiply(f16[4,512,512]{2,1,0} %param_0, f16[4,512,512]{2,1,0} %broadcast.8.1), metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484} } -template std::string PackDescriptorAsString(const T &descriptor) { - return std::string(bit_cast(&descriptor), sizeof(T)); -} - -template -const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) { - if (opaque_len != sizeof(T)) { - throw std::runtime_error("Invalid opaque object size"); - } - return bit_cast(opaque); -} - -} // namespace gpu_ops - -#endif -``` - -#### `gpu_ops/kernels.h` - -```cpp -#ifndef _GPU_OPS_KERNELS_H_ -#define _GPU_OPS_KERNELS_H_ - -#include - -#include -#include - -namespace gpu_ops { - -enum ElementType { BF16, F16, F32, F64 }; - -struct RMSNormDescriptor { - int n1; - int n2; - double eps; - ElementType x_type; - ElementType w_type; - int part_grad_size; -}; - -void rms_forward_affine_mixed_dtypes(cudaStream_t stream, void **buffers, - const char *opaque, - std::size_t opaque_len); -void rms_backward_affine(cudaStream_t stream, void **buffers, - const char *opaque, std::size_t opaque_len); -} // namespace gpu_ops - -#endif -``` - -#### `gpu_ops/pybind11_kernel_helpers.h` - -```cpp -// This header extends kernel_helpers.h with the pybind11 specific interface to -// serializing descriptors. It also adds a pybind11 function for wrapping our -// custom calls in a Python capsule. 
This is separate from kernel_helpers so -// that the CUDA code itself doesn't include pybind11. I don't think that this -// is strictly necessary, but they do it in jaxlib, so let's do it here too. - -#ifndef _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_ -#define _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_ - -#include - -#include "kernel_helpers.h" - -namespace gpu_ops { - -template pybind11::bytes PackDescriptor(const T &descriptor) { - return pybind11::bytes(PackDescriptorAsString(descriptor)); +%region_0.9._custom_call_lowering_rule (Arg_0.10.0: f16[], Arg_1.11.0: f16[]) -> f16[] { + %Arg_1.11.0 = f16[] parameter(1) + %Arg_0.10.0 = f16[] parameter(0) + ROOT %add.2.0 = f16[] add(f16[] %Arg_0.10.0, f16[] %Arg_1.11.0), metadata={op_name="jit(main)/add" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=433} } -template pybind11::capsule EncapsulateFunction(T *fn) { - return pybind11::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); +ENTRY %main.23_spmd (param.2: f16[4,512,512], param.1.0: f16[512,512]) -> (f16[4,512,512], f16[512,512]) { + %param.1.0 = f16[512,512]{1,0} parameter(1), sharding={replicated} + %param.2 = f16[4,512,512]{2,1,0} parameter(0), sharding={devices=[4,1,1]<=[4]} + %custom-call.3.0 = (f16[4,512,512]{2,1,0}, f32[4]{0}) custom-call(f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition= propagate_user_sharding=None infer_sharding_from_operands= decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\000\000\000\000$V\000\000" + %get-tuple-element.14 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition= propagate_user_sharding=None infer_sharding_from_operands= decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440} + %loop_multiply_fusion = f16[4,512,512]{2,1,0} fusion(f16[4,512,512]{2,1,0} %get-tuple-element.14), kind=kLoop, calls=%fused_multiply, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484} + %get-tuple-element.1.0 = f32[4]{0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition= propagate_user_sharding=None infer_sharding_from_operands= decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440} + %custom-call.5.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) custom-call(f16[4,512,512]{2,1,0} %loop_multiply_fusion, f32[4]{0} %get-tuple-element.1.0, f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_backward_affine", operand_layout_constraints={f16[4,512,512]{2,1,0}, f32[4]{0}, f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, 
metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition= propagate_user_sharding=None infer_sharding_from_operands= decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\020\000\000\000$V\000\000" + %get-tuple-element.7.0 = f16[512,512]{1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition= propagate_user_sharding=None infer_sharding_from_operands= decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483} + %all-reduce-start = f16[512,512]{1,0} all-reduce-start(f16[512,512]{1,0} %get-tuple-element.7.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%region_0.9._custom_call_lowering_rule, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition= propagate_user_sharding=None infer_sharding_from_operands= decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}} + %all-reduce-done = f16[512,512]{1,0} all-reduce-done(f16[512,512]{1,0} %all-reduce-start), metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition= propagate_user_sharding=None infer_sharding_from_operands= decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483} + %get-tuple-element.12.0 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition= propagate_user_sharding=None infer_sharding_from_operands= decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483} + ROOT %tuple.1.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}) tuple(f16[4,512,512]{2,1,0} %get-tuple-element.12.0, f16[512,512]{1,0} %all-reduce-done) } - -} // namespace gpu_ops - -#endif ``` - -#### `gpu_ops/gpu_ops.cpp` - -```cpp -#include "kernels.h" -#include "pybind11_kernel_helpers.h" - -namespace { -pybind11::dict RMSNormRegistrations() { - pybind11::dict dict; - dict["rms_forward_affine_mixed_dtype"] = - gpu_ops::EncapsulateFunction(gpu_ops::rms_forward_affine_mixed_dtypes); - dict["rms_backward_affine"] = - gpu_ops::EncapsulateFunction(gpu_ops::rms_backward_affine); - return dict; -} - -PYBIND11_MODULE(gpu_ops, m) { - m.def("get_rms_norm_registrations", &RMSNormRegistrations); - m.def("create_rms_norm_descriptor", - [](int n1, int n2, double eps, gpu_ops::ElementType x_type, - gpu_ops::ElementType w_type, int part_grad_size) { - return gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{ - n1, n2, eps, x_type, w_type, part_grad_size}); - }); - - pybind11::enum_(m, "ElementType") - .value("BF16", 
gpu_ops::ElementType::BF16) - .value("F16", gpu_ops::ElementType::F16) - .value("F32", gpu_ops::ElementType::F32) - .value("F64", gpu_ops::ElementType::F64); - -} -} // namespace +```python +True +True ``` -#### `gpu_ops/rms_norm_kernels.cu` +Now there are no all-gathers in the HLO, sharding is respected and only gradients are accumulated via an all-reduce. -```cpp -#include "kernel_helpers.h" -#include "kernels.h" -#include "stdio.h" -#include -#include -#include - -namespace { - -#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, \ - NAME, ...) \ - switch (TYPEIN) { \ - case gpu_ops::ElementType::F64: { \ - using scalar_t_in = double; \ - using accscalar_t = double; \ - switch (TYPEOUT) { \ - case gpu_ops::ElementType::F64: { \ - using scalar_t_out = double; \ - __VA_ARGS__; \ - break; \ - } \ - case gpu_ops::ElementType::F32: { \ - using scalar_t_out = float; \ - __VA_ARGS__; \ - break; \ - } \ - case gpu_ops::ElementType::F16: { \ - using scalar_t_out = __half; \ - __VA_ARGS__; \ - break; \ - } \ - case gpu_ops::ElementType::BF16: { \ - using scalar_t_out = __nv_bfloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - break; \ - } \ - break; \ - } \ - case gpu_ops::ElementType::F32: { \ - using scalar_t_in = float; \ - using accscalar_t = float; \ - switch (TYPEOUT) { \ - case gpu_ops::ElementType::F64: { \ - using scalar_t_out = double; \ - __VA_ARGS__; \ - break; \ - } \ - case gpu_ops::ElementType::F32: { \ - using scalar_t_out = float; \ - __VA_ARGS__; \ - break; \ - } \ - case gpu_ops::ElementType::F16: { \ - using scalar_t_out = __half; \ - __VA_ARGS__; \ - break; \ - } \ - case gpu_ops::ElementType::BF16: { \ - using scalar_t_out = __nv_bfloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - break; \ - } \ - break; \ - } \ - case gpu_ops::ElementType::F16: { \ - using scalar_t_in = __half; \ - using accscalar_t = float; \ - switch (TYPEOUT) { \ - case gpu_ops::ElementType::F64: { \ - using scalar_t_out = double; \ - __VA_ARGS__; \ - break; \ - } \ - case gpu_ops::ElementType::F32: { \ - using scalar_t_out = float; \ - __VA_ARGS__; \ - break; \ - } \ - case gpu_ops::ElementType::F16: { \ - using scalar_t_out = __half; \ - __VA_ARGS__; \ - break; \ - } \ - case gpu_ops::ElementType::BF16: { \ - using scalar_t_out = __nv_bfloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - break; \ - } \ - break; \ - } \ - case gpu_ops::ElementType::BF16: { \ - using scalar_t_in = __nv_bfloat16; \ - using accscalar_t = float; \ - switch (TYPEOUT) { \ - case gpu_ops::ElementType::F64: { \ - using scalar_t_out = double; \ - __VA_ARGS__; \ - break; \ - } \ - case gpu_ops::ElementType::F32: { \ - using scalar_t_out = float; \ - __VA_ARGS__; \ - break; \ - } \ - case gpu_ops::ElementType::F16: { \ - using scalar_t_out = __half; \ - __VA_ARGS__; \ - break; \ - } \ - case gpu_ops::ElementType::BF16: { \ - using scalar_t_out = __nv_bfloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - break; \ - } \ - break; \ - } \ - default: \ - break; \ - } - -template -__device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) { - count = count + U(1); - U delta = curr - mu; - U lmean = mu + delta / count; - mu = lmean; - U delta2 = curr - lmean; - sigma2 = sigma2 + delta * delta2; -} - -template -__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, - U &mu, U &sigma2, U &count) { - U delta = muB - mu; - U nA = count; - U nB = countB; - count = count + countB; - U nX = count; - if (nX > U(0)) { - nA = nA / nX; - nB = nB / nX; - 
mu = nA * mu + nB * muB; - sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; - } else { - mu = U(0); - sigma2 = U(0); - } -} - -template __device__ void cuRMSOnlineSum(const U curr, U &sigma2) { - sigma2 = sigma2 + curr * curr; -} - -template -__device__ void cuChanRMSOnlineSum(const U sigma2B, U &sigma2) { - sigma2 = sigma2 + sigma2B; -} -template -__device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1, - const int n2, const int i1, U &mu, U &sigma2, - U *buf, bool rms_only) { - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensor is contiguous - // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. - // - // compute variance and mean over n2 - U count = U(0); - mu = U(0); - sigma2 = U(0); - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const T *lvals = vals + i1 * n2; - int l = 4 * thrx; - for (; l + 3 < n2; l += 4 * numx) { - for (int k = 0; k < 4; ++k) { - U curr = static_cast(lvals[l + k]); - if (!rms_only) { - cuWelfordOnlineSum(curr, mu, sigma2, count); - } else { - cuRMSOnlineSum(curr, sigma2); - } - } - } - for (; l < n2; ++l) { - U curr = static_cast(lvals[l]); - if (!rms_only) { - cuWelfordOnlineSum(curr, mu, sigma2, count); - } else { - cuRMSOnlineSum(curr, sigma2); - } - } - // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x + (1 << l)) & 31; - U sigma2B = __shfl_sync(0xffffffff, sigma2, srcLaneB, warpSize); - if (!rms_only) { - U muB = __shfl_sync(0xffffffff, mu, srcLaneB, warpSize); - U countB = __shfl_sync(0xffffffff, count, srcLaneB, warpSize); - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum(sigma2B, sigma2); - } - } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - U *ubuf = (U *)buf; - U *ibuf = (U *)(ubuf + blockDim.y); - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && - threadIdx.y < 2 * offset) { - const int wrt_y = threadIdx.y - offset; - if (!rms_only) { - ubuf[2 * wrt_y] = mu; - ibuf[wrt_y] = count; - } - ubuf[2 * wrt_y + 1] = sigma2; - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - U sigma2B = ubuf[2 * threadIdx.y + 1]; - if (!rms_only) { - U muB = ubuf[2 * threadIdx.y]; - U countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum(sigma2B, sigma2); - } - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - if (!rms_only) { - ubuf[0] = mu; - } - ubuf[1] = sigma2; - } - __syncthreads(); - if (!rms_only) { - mu = ubuf[0]; - } - sigma2 = ubuf[1] / U(n2); - // don't care about final value of count, we know count == n2 - } else { - if (!rms_only) { - mu = __shfl_sync(0xffffffff, mu, 0, warpSize); - } - sigma2 = __shfl_sync(0xffffffff, sigma2 / U(n2), 0, warpSize); - } - } -} - -template <> -__device__ void cuWelfordMuSigma2(const __half *__restrict__ vals, const int n1, - const int n2, const int i1, float &mu, - float &sigma2, float *buf, bool rms_only) { - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensor is contiguous - // 3) 
2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. - // - // compute variance and mean over n2 - float count = 0.0f; - mu = float(0); - sigma2 = float(0); - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const __half *lvals = vals + i1 * n2; - int l = 8 * thrx; - if ((((size_t)lvals) & 3) != 0) { - // 16 bit alignment - // first thread consumes first point - if (thrx == 0) { - float curr = static_cast(lvals[0]); - if (!rms_only) { - cuWelfordOnlineSum(curr, mu, sigma2, count); - } else { - cuRMSOnlineSum(curr, sigma2); - } - } - ++l; - } - // at this point, lvals[l] are 32 bit aligned for all threads. - for (; l + 7 < n2; l += 8 * numx) { - for (int k = 0; k < 8; k += 2) { - float2 curr = __half22float2(*((__half2 *)(lvals + l + k))); - if (!rms_only) { - cuWelfordOnlineSum(curr.x, mu, sigma2, count); - cuWelfordOnlineSum(curr.y, mu, sigma2, count); - } else { - cuRMSOnlineSum(curr.x, sigma2); - cuRMSOnlineSum(curr.y, sigma2); - } - } - } - for (; l < n2; ++l) { - float curr = static_cast(lvals[l]); - if (!rms_only) { - cuWelfordOnlineSum(curr, mu, sigma2, count); - } else { - cuRMSOnlineSum(curr, sigma2); - } - } - // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x + (1 << l)) & 31; - float sigma2B = __shfl_sync(0xffffffff, sigma2, srcLaneB, warpSize); - if (!rms_only) { - float muB = __shfl_sync(0xffffffff, mu, srcLaneB, warpSize); - float countB = __shfl_sync(0xffffffff, count, srcLaneB, warpSize); - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum(sigma2B, sigma2); - } - } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - float *ubuf = (float *)buf; - float *ibuf = (float *)(ubuf + blockDim.y); - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && - threadIdx.y < 2 * offset) { - const int wrt_y = threadIdx.y - offset; - ubuf[2 * wrt_y + 1] = sigma2; - if (!rms_only) { - ubuf[2 * wrt_y] = mu; - ibuf[wrt_y] = count; - } - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - float sigma2B = ubuf[2 * threadIdx.y + 1]; - if (!rms_only) { - float muB = ubuf[2 * threadIdx.y]; - float countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum(sigma2B, sigma2); - } - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - if (!rms_only) { - ubuf[0] = mu; - } - ubuf[1] = sigma2; - } - __syncthreads(); - if (!rms_only) { - mu = ubuf[0]; - } - sigma2 = ubuf[1] / float(n2); - // don't care about final value of count, we know count == n2 - } else { - if (!rms_only) { - mu = __shfl_sync(0xffffffff, mu, 0, warpSize); - } - sigma2 = __shfl_sync(0xffffffff, sigma2 / float(n2), 0, warpSize); - } - } -} - -// This is the un-specialized struct. Note that we prevent instantiation of -// this struct by putting an undefined symbol in the function body so it won't -// compile. 
-// template -// struct SharedMemory -// { -// // Ensure that we won't compile any un-specialized types -// __device__ T *getPointer() -// { -// extern __device__ void error(void); -// error(); -// return NULL; -// } -// }; -// https://github.com/NVIDIA/apex/issues/246 -template struct SharedMemory; - -template <> struct SharedMemory { - __device__ float *getPointer() { - extern __shared__ float s_float[]; - return s_float; - } -}; - -template <> struct SharedMemory { - __device__ double *getPointer() { - extern __shared__ double s_double[]; - return s_double; - } -}; - -template -__device__ void cuApplyLayerNorm_(V *__restrict__ output_vals, - U *__restrict__ mean, U *__restrict__ invvar, - const T *__restrict__ vals, const int n1, - const int n2, const U epsilon, - const V *__restrict__ gamma, - const V *__restrict__ beta, bool rms_only) { - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensors are contiguous - // - for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { - SharedMemory shared; - U *buf = shared.getPointer(); - U mu, sigma2; - cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, rms_only); - - const T *lvals = vals + i1 * n2; - V *ovals = output_vals + i1 * n2; - U c_invvar = rsqrt(sigma2 + epsilon); - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL && (beta != NULL || rms_only)) { - for (int i = thrx; i < n2; i += numx) { - U curr = static_cast(lvals[i]); - if (!rms_only) { - ovals[i] = - gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; - } else { - ovals[i] = gamma[i] * static_cast(c_invvar * curr); - } - } - } else { - for (int i = thrx; i < n2; i += numx) { - U curr = static_cast(lvals[i]); - if (!rms_only) { - ovals[i] = static_cast(c_invvar * (curr - mu)); - } else { - ovals[i] = static_cast(c_invvar * curr); - } - } - } - if (threadIdx.x == 0 && threadIdx.y == 0) { - if (!rms_only) { - mean[i1] = mu; - } - invvar[i1] = c_invvar; - } - __syncthreads(); - } -} - -template -__global__ void -cuApplyRMSNorm(V *__restrict__ output_vals, U *__restrict__ invvar, - const T *__restrict__ vals, const int n1, const int n2, - const U epsilon, const V *__restrict__ gamma) { - cuApplyLayerNorm_(output_vals, NULL, invvar, vals, n1, n2, epsilon, - gamma, NULL, true); -} - -template -void HostApplyRMSNorm(cudaStream_t stream, V *output, U *invvar, const T *input, - int n1, int n2, double epsilon, const V *gamma) { - auto getMaxGridY = []() { - int device; - int val; - cudaGetDevice(&device); - cudaDeviceGetAttribute(&val, cudaDevAttrMaxGridDimY, device); - return val; - }; - const dim3 threads(32, 4, 1); - const uint64_t maxGridY = getMaxGridY(); - const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); - int nshared = - threads.y > 1 ? 
threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; - cuApplyRMSNorm<<>>( - output, invvar, input, n1, n2, U(epsilon), gamma); -} - -template -__device__ void cuLoadWriteStridedInputs( - const int i1_block, const int thr_load_row_off, const int thr_load_col_off, - const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2, - const T *input, const V *dout, const int i1_end, const int n2, - const U *__restrict__ mean, const U *__restrict__ invvar, bool rms_only) { - int i1 = i1_block + thr_load_row_off; - if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; - for (int k = 0; k < blockDim.y; ++k) { - int i2 = i2_off + k; - int load_idx = i1 * n2 + i2; - int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; - if (i2 < n2) { - U curr_input = static_cast(input[load_idx]); - U curr_dout = static_cast(dout[load_idx]); - if (!rms_only) { - warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = - curr_dout * (curr_input - curr_mean) * curr_invvar; - } else { - warp_buf2[write_idx] = curr_dout * (curr_input)*curr_invvar; - } - } else { - if (!rms_only) { - warp_buf1[write_idx] = U(0); - } - warp_buf2[write_idx] = U(0); - } - } - } else { - for (int k = 0; k < blockDim.y; ++k) { - int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; - if (!rms_only) { - warp_buf1[write_idx] = U(0); - } - warp_buf2[write_idx] = U(0); - } - } -} - -template -__device__ void cuLoadAddStridedInputs( - const int i1_block, const int thr_load_row_off, const int thr_load_col_off, - const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2, - const T *input, const V *dout, const int i1_end, const int n2, - const U *__restrict__ mean, const U *__restrict__ invvar, bool rms_only) { - int i1 = i1_block + thr_load_row_off; - if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; - for (int k = 0; k < blockDim.y; ++k) { - int i2 = i2_off + k; - int load_idx = i1 * n2 + i2; - int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; - if (i2 < n2) { - U curr_input = static_cast(input[load_idx]); - U curr_dout = static_cast(dout[load_idx]); - if (!rms_only) { - warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += - curr_dout * (curr_input - curr_mean) * curr_invvar; - } else { - warp_buf2[write_idx] += curr_dout * (curr_input)*curr_invvar; - } - } - } - } -} - -template -__global__ void cuComputePartGradGammaBeta( - const V *__restrict__ dout, const T *__restrict__ input, const int n1, - const int n2, const U *__restrict__ mean, const U *__restrict__ invvar, - U epsilon, U *part_grad_gamma, U *part_grad_beta, bool rms_only) { - const int numsegs_n1 = - (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); - const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; - const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; - const int i1_beg_plus_one = - (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; - const int i1_end = i1_beg_plus_one < n1 ? 
i1_beg_plus_one : n1; - const int row_stride = blockDim.x + 1; - const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); - const int thr_load_row_off = - (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; - const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; - SharedMemory shared; - U *buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * - // blockDim.y + (blockDim.y - - // 1)*(blockDim.x/blockDim.y) elements - U *warp_buf1 = (U *)buf; - U *warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; - // compute partial sums from strided inputs - // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, - row_stride, warp_buf1, warp_buf2, input, dout, - i1_end, n2, mean, invvar, rms_only); - for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; - i1_block += blockDim.y * blockDim.y) { - cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, - row_stride, warp_buf1, warp_buf2, input, dout, - i1_end, n2, mean, invvar, rms_only); - } - __syncthreads(); - // inter-warp reductions - // sum within each warp - U acc1 = U(0); - U acc2 = U(0); - for (int k = 0; k < blockDim.y; ++k) { - int row1 = threadIdx.y + k * blockDim.y; - int idx1 = row1 * row_stride + threadIdx.x; - if (!rms_only) { - acc1 += warp_buf1[idx1]; - } - acc2 += warp_buf2[idx1]; - } - if (!rms_only) { - warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; - } - warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; - __syncthreads(); - // sum all warps - for (int offset = blockDim.y / 2; offset > 1; offset /= 2) { - if (threadIdx.y < offset) { - int row1 = threadIdx.y; - int row2 = threadIdx.y + offset; - int idx1 = row1 * row_stride + threadIdx.x; - int idx2 = row2 * row_stride + threadIdx.x; - if (!rms_only) { - warp_buf1[idx1] += warp_buf1[idx2]; - } - warp_buf2[idx1] += warp_buf2[idx2]; - } - __syncthreads(); - } - int i2 = blockIdx.x * blockDim.x + threadIdx.x; - if (threadIdx.y == 0 && i2 < n2) { - int row1 = threadIdx.y; - int row2 = threadIdx.y + 1; - int idx1 = row1 * row_stride + threadIdx.x; - int idx2 = row2 * row_stride + threadIdx.x; - if (!rms_only) { - part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; - } - part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; - } -} - -template -__global__ void -cuComputeGradGammaBeta(const U *part_grad_gamma, const U *part_grad_beta, - const int part_size, const int n1, const int n2, - V *grad_gamma, V *grad_beta, bool rms_only) { - // sum partial gradients for gamma and beta - SharedMemory shared; - U *buf = shared.getPointer(); - int i2 = blockIdx.x * blockDim.x + threadIdx.x; - if (i2 < n2) { - // each warp does sequential reductions until reduced part_size is num_warps - int num_warp_reductions = part_size / blockDim.y; - U sum_gamma = U(0); - U sum_beta = U(0); - const U *part_grad_gamma_ptr = - part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; - const U *part_grad_beta_ptr = - part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; - for (int warp_offset = 0; warp_offset < num_warp_reductions; - ++warp_offset) { - sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; - if (!rms_only) { - sum_beta += part_grad_beta_ptr[warp_offset * n2]; - } - } - // inter-warp reductions - const int nbsize3 = blockDim.x * blockDim.y / 2; - for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { - // top half write to shared memory - if 
(threadIdx.y >= offset && threadIdx.y < 2 * offset) { - const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - buf[write_idx] = sum_gamma; - if (!rms_only) { - buf[write_idx + nbsize3] = sum_beta; - } - } - __syncthreads(); - // bottom half sums - if (threadIdx.y < offset) { - const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; - sum_gamma += buf[read_idx]; - if (!rms_only) { - sum_beta += buf[read_idx + nbsize3]; - } - } - __syncthreads(); - } - // write out fully summed gradients - if (threadIdx.y == 0) { - grad_gamma[i2] = sum_gamma; - if (!rms_only) { - grad_beta[i2] = sum_beta; - } - } - } -} - -template -__global__ void -cuComputeGradInput(const V *__restrict__ dout, const T *__restrict__ input, - const int n1, const int n2, const U *__restrict__ mean, - const U *__restrict__ invvar, U epsilon, const V *gamma, - T *grad_input, bool rms_only) { - for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { - U sum_loss1 = U(0); - U sum_loss2 = U(0); - U c_mean; - if (!rms_only) { - c_mean = mean[i1]; - } - const U c_invvar = invvar[i1]; - const T *k_input = input + i1 * n2; - const V *k_dout = dout + i1 * n2; - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL) { - int l = 4 * thrx; - for (; l + 3 < n2; l += 4 * numx) { - for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l + k]); - const U c_loss = static_cast(k_dout[l + k]); - if (!rms_only) { - sum_loss1 += c_loss * static_cast(gamma[l + k]); - sum_loss2 += c_loss * static_cast(gamma[l + k]) * - (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * static_cast(gamma[l + k]) * (c_h)*c_invvar; - } - } - } - for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - if (!rms_only) { - sum_loss1 += c_loss * static_cast(gamma[l]); - sum_loss2 += - c_loss * static_cast(gamma[l]) * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * static_cast(gamma[l]) * (c_h)*c_invvar; - } - } - } else { - int l = 4 * thrx; - for (; l + 3 < n2; l += 4 * numx) { - for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l + k]); - const U c_loss = static_cast(k_dout[l + k]); - if (!rms_only) { - sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * (c_h)*c_invvar; - } - } - } - for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - if (!rms_only) { - sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * (c_h)*c_invvar; - } - } - } - // intra-warp reductions - for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { - if (!rms_only) { - sum_loss1 += __shfl_xor_sync(0xffffffff, sum_loss1, mask, warpSize); - } - sum_loss2 += __shfl_xor_sync(0xffffffff, sum_loss2, mask, warpSize); - } - // inter-warp reductions - if (blockDim.y > 1) { - SharedMemory shared; - U *buf = shared.getPointer(); - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { - const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - if (!rms_only) { - buf[2 * wrt_i] = sum_loss1; - } - buf[2 * wrt_i + 1] = sum_loss2; - } - __syncthreads(); - // lower half merges - if (threadIdx.y < offset) { - const int read_i = threadIdx.y * blockDim.x + threadIdx.x; - if (!rms_only) { - sum_loss1 += buf[2 * read_i]; - } - sum_loss2 += buf[2 * read_i 
+ 1]; - } - __syncthreads(); - } - if (threadIdx.y == 0) { - if (!rms_only) { - buf[2 * threadIdx.x] = sum_loss1; - } - buf[2 * threadIdx.x + 1] = sum_loss2; - } - __syncthreads(); - if (threadIdx.y != 0) { - if (!rms_only) { - sum_loss1 = buf[2 * threadIdx.x]; - } - sum_loss2 = buf[2 * threadIdx.x + 1]; - } - } - // all threads now have the two sums over l - U fH = (U)n2; - U term1 = (U(1) / fH) * c_invvar; - T *k_grad_input = grad_input + i1 * n2; - if (gamma != NULL) { - for (int l = thrx; l < n2; l += numx) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss * static_cast(gamma[l]); - if (!rms_only) { - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; - } else { - f_grad_input -= (c_h)*c_invvar * sum_loss2; - } - f_grad_input *= term1; - k_grad_input[l] = static_cast(f_grad_input); - } - } else { - for (int l = thrx; l < n2; l += numx) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss; - if (!rms_only) { - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; - } else { - f_grad_input -= (c_h)*c_invvar * sum_loss2; - } - f_grad_input *= term1; - k_grad_input[l] = static_cast(f_grad_input); - } - } - // prevent race where buf is written again before reads are done - __syncthreads(); - } -} - -template -void HostRMSNormGradient(cudaStream_t stream, const V *dout, const U *invvar, - const T *input, int n1, int n2, const V *gamma, - double epsilon, T *grad_input, V *grad_gamma, - int part_size, U *part_grad_gamma) { - auto getMaxGridY = []() { - int device; - int val; - cudaGetDevice(&device); - cudaDeviceGetAttribute(&val, cudaDevAttrMaxGridDimY, device); - return val; - }; - const uint64_t maxGridY = getMaxGridY(); - if (gamma != NULL) { - const dim3 threads2(32, 4, 1); - const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); - const int nshared2_a = - 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); - const int nshared2_b = threads2.x * threads2.y * sizeof(U); - const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; - // note (mkozuki): I can hard code part_grad_gamma's dtype as float given - // that the `cuda_layer_norm_gradient` doesn't support double. - cuComputePartGradGammaBeta<<>>( - dout, input, n1, n2, - invvar, // unused - invvar, U(epsilon), part_grad_gamma, part_grad_gamma, /* unused */ - true); - - const dim3 threads3(32, 8, 1); - const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); - const int nshared3 = threads3.x * threads3.y * sizeof(U); - cuComputeGradGammaBeta<<>>( - part_grad_gamma, part_grad_gamma, /* unused */ - part_size, n1, n2, grad_gamma, grad_gamma, /* unused */ - true); - } - - // compute grad_input - const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); - const dim3 threads1(32, 4, 1); - int nshared = threads1.y > 1 ? 
threads1.y * threads1.x * sizeof(U) : 0; - cuComputeGradInput<<>>( - dout, input, n1, n2, invvar, /* unused */ - invvar, U(epsilon), gamma, grad_input, true); -} - -} // namespace - -namespace gpu_ops { +## Let's put it together -void rms_forward_affine_mixed_dtypes(cudaStream_t stream, void **buffers, - const char *opaque, - std::size_t opaque_len) { - const RMSNormDescriptor &d = - *UnpackDescriptor(opaque, opaque_len); - - DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - d.x_type, d.w_type, "rms_norm_cuda_kernel", - HostApplyRMSNorm( - stream, static_cast(buffers[2]), - static_cast(buffers[3]), - static_cast(buffers[0]), d.n1, d.n2, d.eps, - /*gamma=*/static_cast(buffers[1]));) -} +The complete definition of the primitives using custom_partitioning can be found in [Custom_Operation_for_GPUs.py](Custom_Operation_for_GPUs.py) and the corresponding C++ code the defines python bindings in addition to the kernel implementations can be found below: -void rms_backward_affine(cudaStream_t stream, void **buffers, - const char *opaque, std::size_t opaque_len) { - const RMSNormDescriptor &d = - *UnpackDescriptor(opaque, opaque_len); - - DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - d.x_type, d.w_type, "cuComputeGradInputRMS", - HostRMSNormGradient( - stream, - /*dout=*/static_cast(buffers[0]), - /*invvar=*/static_cast(buffers[1]), - /*input=*/static_cast(buffers[2]), d.n1, d.n2, - // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta - // if gamma Tensor is NULL on input. - /*gamma=*/static_cast(buffers[3]), d.eps, - /*grad_input=*/static_cast(buffers[4]), - /*grad_gamma=*/static_cast(buffers[5]), - d.part_grad_size, - /*part_grad_gamma=*/static_cast(buffers[6]));) -} +### `gpu_ops` code listing -} // namespace gpu_ops -``` +[gpu_ops/kernel_helpers.h](gpu_ops/kernel_helpers.h) \ +[gpu_ops/kernels.h](gpu_ops/kernels.h) \ +[gpu_ops/pybind11_kernel_helpers.h](gpu_ops/pybind11_kernel_helpers.h) \ +[gpu_ops/gpu_ops.cpp](gpu_ops/gpu_ops.cpp) \ +[gpu_ops/rms_norm_kernels.cu](gpu_ops/rms_norm_kernels.cu) diff --git a/docs/Custom_Operation_for_GPUs.py b/docs/Custom_Operation_for_GPUs.py new file mode 100644 index 000000000000..4c0b4b6f7b38 --- /dev/null +++ b/docs/Custom_Operation_for_GPUs.py @@ -0,0 +1,527 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
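+
+# Companion source for the Custom_Operation_for_GPUs.md tutorial: the RMS norm
+# forward/backward primitives wired up with custom_partitioning, plus a small
+# multi-device correctness check.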
+ +from functools import partial, reduce +import math + +import jax +import jax.numpy as jnp +from build import gpu_ops +from jax import core, dtypes +from jax.core import ShapedArray +from jax.experimental.custom_partitioning import custom_partitioning +from jax.experimental.pjit import pjit +from jax.interpreters import batching, mlir, xla +from jax.interpreters.mlir import ir +from jax.lib import xla_client +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jaxlib.hlo_helpers import custom_call +from jax._src import dispatch + + +###################################################################### +# Created Primitives for unsharded RMS norm reference implementation # +###################################################################### + +# Create _rms_norm_fwd_p for forward operation. +_rms_norm_fwd_p = core.Primitive("rms_norm_fwd") +_rms_norm_fwd_p.multiple_results = True +_rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p)) + + +def rms_norm_fwd(x, weight, eps=1e-05): + output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps) + return output, (invvar, x, weight) + + +# Create _rms_norm_bwd_p for backward operation. +_rms_norm_bwd_p = core.Primitive("rms_norm_bwd") +_rms_norm_bwd_p.multiple_results = True +_rms_norm_bwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_bwd_p)) + + +def rms_norm_bwd(eps, res, g): + invvar, x, weight = res + grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind( + g, invvar, x, weight, eps=eps + ) + return grad_input, grad_weight + + +#################### +# Lowering to MLIR # +#################### + + +# Register functions defined in gpu_ops as custom call target for GPUs +for _name, _value in gpu_ops.get_rms_norm_registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="gpu") + + +def element_type_to_descriptor_type_mapping(element_type): + _element_type_to_descriptor_type_mapping = { + ir.BF16Type.get(): gpu_ops.ElementType.BF16, + ir.F16Type.get(): gpu_ops.ElementType.F16, + ir.F32Type.get(): gpu_ops.ElementType.F32, + ir.F64Type.get(): gpu_ops.ElementType.F64, + } + return _element_type_to_descriptor_type_mapping.get(element_type) + + +def default_layouts(*shapes): + return [range(len(shape) - 1, -1, -1) for shape in shapes] + + +def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps): + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + w_type = ir.RankedTensorType(weight.type) + w_shape = w_type.shape + iv_element_type = ( + ir.F32Type.get() + if x_type.element_type in [ir.F16Type.get(), ir.BF16Type.get()] + else x_type.element_type + ) + + n2 = math.prod(w_shape) + n1 = math.prod(x_shape) // n2 + + opaque = gpu_ops.create_rms_norm_descriptor( + n1, + n2, + eps, + element_type_to_descriptor_type_mapping(x_type.element_type), + element_type_to_descriptor_type_mapping(w_type.element_type), + 0, # unused + ) + out = custom_call( + b"rms_forward_affine_mixed_dtype", + result_types=[ + ir.RankedTensorType.get(x_shape, w_type.element_type), + ir.RankedTensorType.get((n1,), iv_element_type), + ], + operands=[x, weight], + backend_config=opaque, + operand_layouts=default_layouts(x_shape, w_shape), + result_layouts=default_layouts(x_shape, (n1,)), + ).results + return out + + +mlir.register_lowering( + _rms_norm_fwd_p, + _rms_norm_fwd_cuda_lowering, + platform="gpu", +) + + +def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps): + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + w_type = ir.RankedTensorType(weight.type) + w_shape = 
w_type.shape + iv_type = ir.RankedTensorType(invvar.type) + + n2 = reduce(lambda x, y: x * y, w_shape) + n1 = reduce(lambda x, y: x * y, x_shape) // n2 + + part_grad_shape = ctx.avals_out[-1].shape + + opaque = gpu_ops.create_rms_norm_descriptor( + n1, + n2, + eps, + element_type_to_descriptor_type_mapping(x_type.element_type), + element_type_to_descriptor_type_mapping(w_type.element_type), + part_grad_shape[0], + ) + out = custom_call( + b"rms_backward_affine", + result_types=[ + ir.RankedTensorType.get(x_shape, x_type.element_type), + ir.RankedTensorType.get(w_shape, w_type.element_type), + ir.RankedTensorType.get(part_grad_shape, iv_type.element_type), + ], + operands=[grad_output, invvar, x, weight], + backend_config=opaque, + operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape), + result_layouts=default_layouts(x_shape, w_shape, part_grad_shape), + ).results + return out + + +mlir.register_lowering( + _rms_norm_bwd_p, + _rms_norm_bwd_cuda_lowering, + platform="gpu", +) + + +####################### +# Abstract evaluation # +####################### + + +def _rms_norm_fwd_abstract(x, weight, eps): + w_dtype = dtypes.canonicalize_dtype(weight.dtype) + iv_dtype = dtypes.canonicalize_dtype(x.dtype) + if iv_dtype in [jnp.float16, jnp.bfloat16]: + iv_dtype = jnp.float32 + n2 = math.prod(weight.shape) + n1 = math.prod(x.shape) // n2 + return ( + ShapedArray(x.shape, w_dtype, named_shape=x.named_shape), # output + ShapedArray((n1,), iv_dtype, named_shape=x.named_shape), # invvar + ) + + +_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract) + + +def _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps): + iv_dtype = dtypes.canonicalize_dtype(invvar.dtype) + w_dtype = dtypes.canonicalize_dtype(weight.dtype) + x_dtype = dtypes.canonicalize_dtype(x.dtype) + n2 = reduce(lambda x, y: x * y, weight.shape) + n1 = reduce(lambda x, y: x * y, x.shape) // n2 + part_grad_shape = (16, n2) + assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype + assert grad_output.shape == x.shape + assert invvar.shape == (n1,) + assert ( + iv_dtype == jnp.float32 if x_dtype in [jnp.float16, jnp.bfloat16] else x_dtype + ) + assert grad_output.named_shape == x.named_shape + weight_named_shape = ( + weight.named_shape if weight.named_shape else grad_output.named_shape + ) + return ( + ShapedArray( + x.shape, x_dtype, named_shape=x.named_shape + ), # grad input + ShapedArray( + weight.shape, w_dtype, named_shape=weight_named_shape + ), # grad weight + ShapedArray( + part_grad_shape, iv_dtype, named_shape=weight_named_shape + ), # part grad + ) + + +_rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract) + + +####################################### +# Top-level interface with custom vjp # +####################################### + + +@partial(jax.custom_vjp, nondiff_argnums=(2,)) +def rms_norm(x, weight, eps=1e-05): + output, _ = rms_norm_fwd(x, weight, eps=eps) + return output + + +rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd) + +########################################################### +# Create primitives for RMS norm with custom_partitioning # +########################################################### + +def _check_valid_batch_dims(bdims): + """ + Assert out non-supported bath dims + """ + for dim in bdims: + assert dim in [0, None], \ + "Currently only support batch_dim in [0, None], " \ + f"but got {dim=}" + +def register_primitive(cls): + """ + register jax primitive + + The order of calls. Each operation is composed of two primitives: Inner and Outer. 
+ + Inner, only the basic to wrap the custom_call itself. + - impl to XLA custom_call in C. + - abstract to know the static shapes + - lower to StableHLO XLA custom_call. + Outer, mostly all the rest: + - impl: Bind to the inner primitive. Not used for real computation, but only for tracing. So we only need to bind. + - abstract: same + - lower to StableHLO custom_p. (XLA will call the python callback from it) + - custom_p + - vmap: could be added here. + VJP is based on Outer, but not handled in this function. + """ + + def name_of_wrapper_p(): + return cls.name + "_wrapper" + + inner_p = core.Primitive(cls.name) + dispatch.prim_requires_devices_during_lowering.add(inner_p) + inner_p.multiple_results = cls.multiple_results + inner_p.def_impl(partial(xla.apply_primitive, inner_p)) + inner_p.def_abstract_eval(cls.abstract) + mlir.register_lowering(inner_p, cls.lowering, platform='cuda') + cls.inner_primitive = inner_p + + outer_p = core.Primitive(name_of_wrapper_p()) + dispatch.prim_requires_devices_during_lowering.add(outer_p) + outer_p.multiple_results = cls.multiple_results + outer_p.def_impl(cls.impl) + outer_p.def_abstract_eval(cls.abstract) + batching.primitive_batchers[outer_p] = cls.batcher + outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) + outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands, + partition=cls.partition) + mlir.register_lowering(outer_p, + mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)) + cls.outer_primitive = outer_p + + +class RmsNormFwdClass: + name = "rms_forward_affine_mixed_dtype" + multiple_results = True + impl_static_args = (2,) # eps + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(x_aval, gamma_aval, **kwargs): # pylint: disable=unused-argument + return _rms_norm_fwd_abstract(x_aval, gamma_aval, **kwargs) + + @staticmethod + def lowering(ctx, x, gamma, *, eps): + return _rms_norm_fwd_cuda_lowering(ctx, x, gamma, eps) + + @staticmethod + def impl(x, gamma, eps): + assert RmsNormFwdClass.inner_primitive is not None + out, rsigma = RmsNormFwdClass.inner_primitive.bind(x, gamma, eps=eps) + return out, rsigma + + @staticmethod + def batcher(batched_args, batch_dims, *, eps): + _check_valid_batch_dims(batch_dims) + assert RmsNormFwdClass.outer_primitive is not None + x, gamma = batched_args + x_bdim, _ = batch_dims + + out_bdims = x_bdim, x_bdim + return RmsNormFwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims + + @staticmethod + def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh, + arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...], + result_infos: tuple[jax._src.core.ShapedArray, ...]): + del eps, result_infos # Not needed for this example. + x_info, weight_info = arg_infos + assert len(x_info.shape) == 3 + assert len(weight_info.shape) == 2 + # partition() will force all dims to be replicated except the + # first dim of x that will be kept as is. + x_spec = arg_infos[0].sharding.spec + output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)) + invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0])) + return (output_sharding, invvar_sharding) + + @staticmethod + def partition(eps: float, mesh : jax.sharding.Mesh, + arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...], + result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]): + del result_infos # Not needed for this example. 
+ x_info, weight_info = arg_infos + assert len(x_info.shape) == 3 + assert len(weight_info.shape) == 2 + x_spec = arg_infos[0].sharding.spec + # We only support sharding on the batch dimensions. + # Force sharding on all others dimensions with None. + arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)), + NamedSharding(mesh, PartitionSpec(None, None))) # TODO: TE don't force anything. + invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0])) + output_shardings = (arg_shardings[0], invvar_sharding) + # Sharded_impl only accepts positional arugments + # And they should be Jax traceable variables + impl = partial(RmsNormFwdClass.impl, eps=eps) + + return mesh, impl, output_shardings, arg_shardings + +register_primitive(RmsNormFwdClass) + +class RmsNormBwdClass: + name = "rms_norm_bwd" + multiple_results = True + impl_static_args = (4,) # eps + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(grad_output, invvar, x, weight, eps): # pylint: disable=unused-argument + return _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps) + + @staticmethod + def lowering(ctx, grad_output, invvar, x, weight, eps): + return _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps) + + @staticmethod + def impl(grad_output, invvar, x, weight, eps): + assert RmsNormBwdClass.inner_primitive is not None + gx, gw, part_grad = RmsNormBwdClass.inner_primitive.bind(grad_output, invvar, x, weight, eps=eps) + return gx, gw, part_grad + + @staticmethod + def batcher(batched_args, batch_dims, *, eps): + # TODO: Add to the tutorial! + _check_valid_batch_dims(batch_dims) + assert RmsNormBwdClass.outer_primitive is not None + x, gamma = batched_args + x_bdim, _ = batch_dims + + out_bdims = x_bdim, x_bdim + return RmsNormBwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims + + @staticmethod + def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh, + arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...], + result_infos: tuple[jax._src.core.ShapedArray, ...]): + del eps, result_infos # Not needed for this example. + g_info, invvar_info, x_info, weight_info = arg_infos + assert len(g_info.shape) == 3 + assert len(invvar_info.shape) == 1 + assert len(x_info.shape) == 3 + assert len(weight_info.shape) == 2 + # partition() will force all dims to be replicated except the batch dimension. + x_spec = x_info.sharding.spec + output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)) + invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None)) + return (output_sharding, invvar_sharding, output_sharding, ) + + @staticmethod + def partition(eps: float, mesh : jax.sharding.Mesh, + arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...], + result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]): + del result_infos # Not needed for this example. + g_info, invvar_info, x_info, weight_info = arg_infos + assert len(g_info.shape) == 3 + assert len(invvar_info.shape) == 1 + assert len(x_info.shape) == 3 + assert len(weight_info.shape) == 2 + + # We only support sharding on the batch dimensions. + # Force sharding on all others dimensions with None. + # Also force gx, x and invvar to have the same batch sharding/replication. 
+ x_spec = x_info.sharding.spec + arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)), + NamedSharding(mesh, PartitionSpec(x_spec[0],)), + NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)), + NamedSharding(mesh, PartitionSpec(None, None))) + + output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)) + invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None)) + output_shardings = (output_sharding, invvar_sharding, invvar_sharding) + + + # The sharded implementation only accepts positional arguments, + # and they should be JAX-traceable variables. + def sharded_impl(g, invvar, x, weight): + grad_input, grad_weight, part_grad = RmsNormBwdClass.impl( + g, invvar, x, weight, eps=eps + ) + # We need to sum the weight gradient from all partitions + # when the input is sharded and the weights are replicated. + global_weight = grad_weight + if x_spec[0]: + global_weight = jax.lax.psum(grad_weight, x_spec[0]) + return grad_input, global_weight, part_grad + return mesh, sharded_impl, output_shardings, arg_shardings + +register_primitive(RmsNormBwdClass) + +def custom_p_rms_norm_fwd(x, weight, eps=1e-05): + output, invvar = RmsNormFwdClass.outer_primitive.bind(x, weight, eps=eps) + return output, (invvar, x, weight) + +@partial(jax.custom_vjp, nondiff_argnums=(2,)) +def custom_p_rms_norm(x, weight, eps=1e-05): + output, _ = custom_p_rms_norm_fwd(x, weight, eps=eps) + return output + +def custom_p_rms_norm_bwd(eps, res, g): + invvar, x, weight = res + grad_input, grad_weight, part_grad = RmsNormBwdClass.outer_primitive.bind( + g, invvar, x, weight, eps=eps) + return grad_input, grad_weight + +custom_p_rms_norm.defvjp(custom_p_rms_norm_fwd, custom_p_rms_norm_bwd) + +######## +# Test # +######## + + +import jax + +per_core_batch_size = 4 +seq_len = 512 +emb_dim = 512 +assert jax.local_device_count() > 1, "This example expects more than one device; it will run on a single GPU, but is that really what you want?" +x = jax.random.normal( + jax.random.PRNGKey(0), + shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim), + dtype=jnp.float16, +) +norm_shape = x.shape[-2:] +weight = jnp.ones(norm_shape, dtype=jnp.float16) + + +def ref_loss(x, weight): + predictions = rms_norm(x, weight) + return -jnp.mean(predictions**2) + +ref_out = jax.grad(ref_loss, argnums=(0, 1))(x, weight) + +def custom_p_loss(x, weight): + predictions = custom_p_rms_norm(x, weight) + return -jnp.mean(predictions**2) + +with Mesh(jax.local_devices(), ("x",)): + def run_and_verify(loss): + pjitted = pjit( + jax.grad(loss, argnums=(0, 1)), + # Shard x by batch dimension and replicate weight on all devices. + in_shardings=( + PartitionSpec("x", None, None), + PartitionSpec(None, None), + ), + # Shard the output by batch dimension and replicate weight grad on all devices. + out_shardings=( + PartitionSpec("x", None, None), + PartitionSpec(None, None), + ), + ) + hlo = pjitted.lower(x, weight).compile().as_text() + out = pjitted(x, weight) + print(hlo) + assert "all-reduce-done" in hlo, "The gradient will produce a wrong value!"
+ if "all-gather-start" in hlo: + print("NOT OPTIMIZED, ALL_GATHER in the graph!") + return out + + custom_p_out = run_and_verify(custom_p_loss) + + +for r, o in zip(ref_out, custom_p_out): + print(jnp.allclose(r, o, atol=1e-6, rtol=1e-6)) diff --git a/docs/_static/distributed_data_loading/1.svg b/docs/_static/distributed_data_loading/1.svg new file mode 100644 index 000000000000..78653e459f14 --- /dev/null +++ b/docs/_static/distributed_data_loading/1.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/10.svg b/docs/_static/distributed_data_loading/10.svg new file mode 100644 index 000000000000..43af436dfd63 --- /dev/null +++ b/docs/_static/distributed_data_loading/10.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/11.svg b/docs/_static/distributed_data_loading/11.svg new file mode 100644 index 000000000000..0ca1d7aa84e8 --- /dev/null +++ b/docs/_static/distributed_data_loading/11.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/12.svg b/docs/_static/distributed_data_loading/12.svg new file mode 100644 index 000000000000..28e393f5d2eb --- /dev/null +++ b/docs/_static/distributed_data_loading/12.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/13.svg b/docs/_static/distributed_data_loading/13.svg new file mode 100644 index 000000000000..c13ee5daec44 --- /dev/null +++ b/docs/_static/distributed_data_loading/13.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/14.svg b/docs/_static/distributed_data_loading/14.svg new file mode 100644 index 000000000000..ca00f661726a --- /dev/null +++ b/docs/_static/distributed_data_loading/14.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/15.svg b/docs/_static/distributed_data_loading/15.svg new file mode 100644 index 000000000000..6955ef7ea0cb --- /dev/null +++ b/docs/_static/distributed_data_loading/15.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/16.svg b/docs/_static/distributed_data_loading/16.svg new file mode 100644 index 000000000000..ccb458db2040 --- /dev/null +++ b/docs/_static/distributed_data_loading/16.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/17.svg b/docs/_static/distributed_data_loading/17.svg new file mode 100644 index 000000000000..4232c22d15e3 --- /dev/null +++ b/docs/_static/distributed_data_loading/17.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/18.svg b/docs/_static/distributed_data_loading/18.svg new file mode 100644 index 000000000000..3e5557aa5db9 --- /dev/null +++ b/docs/_static/distributed_data_loading/18.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/19.svg b/docs/_static/distributed_data_loading/19.svg new file mode 100644 index 000000000000..6ce4c46399c8 --- /dev/null +++ b/docs/_static/distributed_data_loading/19.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/2.svg b/docs/_static/distributed_data_loading/2.svg new file mode 100644 index 000000000000..30cb51c8b444 --- /dev/null +++ b/docs/_static/distributed_data_loading/2.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/20.svg b/docs/_static/distributed_data_loading/20.svg 
new file mode 100644 index 000000000000..aa739d561ff3 --- /dev/null +++ b/docs/_static/distributed_data_loading/20.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/21.svg b/docs/_static/distributed_data_loading/21.svg new file mode 100644 index 000000000000..e9a5c959a952 --- /dev/null +++ b/docs/_static/distributed_data_loading/21.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/22.svg b/docs/_static/distributed_data_loading/22.svg new file mode 100644 index 000000000000..8b6848bcb2e2 --- /dev/null +++ b/docs/_static/distributed_data_loading/22.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/3.svg b/docs/_static/distributed_data_loading/3.svg new file mode 100644 index 000000000000..6a0ba6bdaf16 --- /dev/null +++ b/docs/_static/distributed_data_loading/3.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/4.svg b/docs/_static/distributed_data_loading/4.svg new file mode 100644 index 000000000000..2f3093de72fc --- /dev/null +++ b/docs/_static/distributed_data_loading/4.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/5.svg b/docs/_static/distributed_data_loading/5.svg new file mode 100644 index 000000000000..f82f612ad74a --- /dev/null +++ b/docs/_static/distributed_data_loading/5.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/6.svg b/docs/_static/distributed_data_loading/6.svg new file mode 100644 index 000000000000..968a1c71f288 --- /dev/null +++ b/docs/_static/distributed_data_loading/6.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/7.svg b/docs/_static/distributed_data_loading/7.svg new file mode 100644 index 000000000000..8828abfdb615 --- /dev/null +++ b/docs/_static/distributed_data_loading/7.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/8.svg b/docs/_static/distributed_data_loading/8.svg new file mode 100644 index 000000000000..ac35768f4a0f --- /dev/null +++ b/docs/_static/distributed_data_loading/8.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/distributed_data_loading/9.svg b/docs/_static/distributed_data_loading/9.svg new file mode 100644 index 000000000000..3e6a489c2a89 --- /dev/null +++ b/docs/_static/distributed_data_loading/9.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/style.css b/docs/_static/style.css index f2b855838064..7a5c647052f0 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -27,3 +27,17 @@ div.red-background pre { div.green-background pre { background-color: rgba(204, 244, 204, var(--block-bg-opacity)); } + +/* Python code block comments */ +html[data-theme="light"] .highlight span.c1 { + color: #fa8d59; +} + +/* Python code traceback and exception */ +html[data-theme="light"] .highlight span.gt { + color: #ff0000; +} + +html[data-theme="light"] .highlight span.gr { + color: #ff0000; +} diff --git a/docs/tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md similarity index 99% rename from docs/tutorials/advanced-autodiff.md rename to docs/_tutorials/advanced-autodiff.md index 712640e8efec..da95f96d8b25 100644 --- a/docs/tutorials/advanced-autodiff.md +++ b/docs/_tutorials/advanced-autodiff.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - 
jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -13,9 +13,11 @@ kernelspec: --- (advanced-autodiff)= -# Advanced automatic differentiation 201 +# Advanced automatic differentiation -In this tutorial, you will learn about complex applications of automatic differentiation (autodiff) in JAX and gain a better understanding of how taking derivatives in JAX can be both easy and powerful.docs.g + + +In this tutorial, you will learn about complex applications of automatic differentiation (autodiff) in JAX and gain a better understanding of how taking derivatives in JAX can be both easy and powerful. Make sure to check out the {ref}`automatic-differentiation` tutorial to go over the JAX autodiff basics, if you haven't already. @@ -38,7 +40,7 @@ JAX's autodiff makes it easy to compute higher-order derivatives, because the fu The single-variable case was covered in the {ref}`automatic-differentiation` tutorial, where the example showed how to use {func}`jax.grad` to compute the the derivative of $f(x) = x^3 + 2x^2 - 3x + 1$. -In the multi-variable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to: +In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to: $$(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.$$ diff --git a/docs/tutorials/advanced-compilation.md b/docs/_tutorials/advanced-compilation.md similarity index 81% rename from docs/tutorials/advanced-compilation.md rename to docs/_tutorials/advanced-compilation.md index 5be9a7163bda..09535f2fce96 100644 --- a/docs/tutorials/advanced-compilation.md +++ b/docs/_tutorials/advanced-compilation.md @@ -1,7 +1,9 @@ # Advanced compilation + + ```{note} -This is a placeholder for a section in the new {ref}`jax-tutorials`. +This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. For the time being, you may find some related content in the old documentation: - {doc}`../aot` diff --git a/docs/_tutorials/advanced-debugging.md b/docs/_tutorials/advanced-debugging.md new file mode 100644 index 000000000000..56188e0958fa --- /dev/null +++ b/docs/_tutorials/advanced-debugging.md @@ -0,0 +1,25 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.1 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +(advanced-debugging)= +# Advanced debugging + + + +```{note} +This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. 
+ +For the time being, you may find some related content in the old documentation: +- {doc}`../debugging/index` +``` diff --git a/docs/tutorials/external-callbacks.md b/docs/_tutorials/external-callbacks.md similarity index 97% rename from docs/tutorials/external-callbacks.md rename to docs/_tutorials/external-callbacks.md index 3f6bfc2ba5e8..a46927e6a8b4 100644 --- a/docs/tutorials/external-callbacks.md +++ b/docs/_tutorials/external-callbacks.md @@ -5,16 +5,25 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python name: python3 --- +```{code-cell} +:tags: [remove-cell] + +# This ensures that code cell tracebacks appearing below will be concise. +%xmode minimal +``` + (external-callbacks)= # External callbacks + + This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` and {func}`jax.debug.callback`. You can use them even while running under JAX transformations, including {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`. ## Why callbacks? @@ -35,7 +44,7 @@ def f(x): result = f(2) ``` -What is printed is not the runtime value, but the trace-time abstract value (if you're not familiar with *tracing* in JAX, a good primer can be found in {ref}`thinking-in-jax`. +What is printed is not the runtime value, but the trace-time abstract value (if you're not familiar with *tracing* in JAX, a good primer can be found in {ref}`key-concepts-tracing`. To print the value at runtime, you need a callback, for example {func}`jax.debug.print` (you can learn more about debugging in {ref}`debugging`): @@ -49,7 +58,7 @@ def f(x): result = f(2) ``` -This works by passing the runtime value represented by `y` back to the host process, where the host can print the value. +This works by passing the runtime value of `y` as a CPU {class}`jax.Array` back to the host process, where the host can print it. (external-callbacks-flavors-of-callback)= ## Flavors of callback @@ -117,10 +126,6 @@ jax.lax.scan(body_fun, None, jnp.arange(5.0))[1] However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics: -```{code-cell} -%xmode minimal -``` - ```{code-cell} :tags: [raises-exception] diff --git a/docs/tutorials/gradient-checkpointing.md b/docs/_tutorials/gradient-checkpointing.md similarity index 99% rename from docs/tutorials/gradient-checkpointing.md rename to docs/_tutorials/gradient-checkpointing.md index 4b31a650cf2e..b768514e4bb0 100644 --- a/docs/tutorials/gradient-checkpointing.md +++ b/docs/_tutorials/gradient-checkpointing.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -15,6 +15,8 @@ kernelspec: (gradient-checkpointing)= ## Gradient checkpointing with `jax.checkpoint` (`jax.remat`) + + In this tutorial, you will learn how to control JAX automatic differentiation's saved values using {func}`jax.checkpoint` (also known as {func}`jax.remat`), which can be particularly helpful in machine learning. If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has {ref}`automatic-differentiation` and {ref}`advanced-autodiff` tutorials. 
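The hunk above introduces `jax.checkpoint` (also exposed as `jax.remat`). Below is a minimal, self-contained sketch of what it does; the `layer` function and the input shape are illustrative assumptions, not taken from the tutorial itself:

```python
import jax
import jax.numpy as jnp

def layer(x):
    # A stand-in for some block whose intermediates would normally be saved.
    return jnp.sin(jnp.cos(x))

# Wrapping `layer` in jax.checkpoint (a.k.a. jax.remat) tells autodiff to
# recompute its intermediates on the backward pass instead of storing them.
def loss(x):
    return jnp.sum(jax.checkpoint(layer)(x) ** 2)

print(jax.grad(loss)(jnp.arange(3.0)))
```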
diff --git a/docs/_tutorials/index.rst b/docs/_tutorials/index.rst new file mode 100644 index 000000000000..d261612a4cd4 --- /dev/null +++ b/docs/_tutorials/index.rst @@ -0,0 +1,57 @@ +:orphan: + +.. _jax-tutorials-draft: + +JAX tutorials draft +=================== + +.. note:: + + This is a draft of the new JAX tutorials. + The tutorials below are a work in progress; for the time being, please refer + to the older tutorial content, including :ref:`beginner-guide`, + :ref:`user-guides`, and the now-deleted *JAX 101* tutorials. + +JAX 101 +------- +Mostly finalized at :ref:`jax-tutorials`! + +.. toctree:: + :maxdepth: 1 + + ../quickstart + ../key-concepts + ../jit-compilation + ../automatic-vectorization + ../automatic-differentiation + ../debugging + ../random-numbers + ../working-with-pytrees + ../sharded-computation + ../stateful-computations + simple-neural-network + + +JAX 201 +------- + +.. toctree:: + :maxdepth: 1 + + parallelism + advanced-autodiff + gradient-checkpointing + advanced-debugging + external-callbacks + profiling-and-performance + + +JAX 301 +------- + +.. toctree:: + :maxdepth: 1 + + jax-primitives + jaxpr + advanced-compilation diff --git a/docs/tutorials/jax-primitives.md b/docs/_tutorials/jax-primitives.md similarity index 98% rename from docs/tutorials/jax-primitives.md rename to docs/_tutorials/jax-primitives.md index 73f6ff4542ca..51abe2916693 100644 --- a/docs/tutorials/jax-primitives.md +++ b/docs/_tutorials/jax-primitives.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -13,7 +13,9 @@ kernelspec: --- (jax-internals-jax-primitives)= -# JAX internals 301: JAX primitives +# JAX Internals: primitives + + ## Introduction to JAX primitives @@ -406,7 +408,7 @@ If you attempt now to use reverse differentiation, you'll notice that JAX starts When computing the reverse differentiation, JAX first performs an abstract evaluation of the forward differentiation code `multiply_add_value_and_jvp` to obtain a trace of primitives that compute the output tangent. - Observe that JAX performs this abstract evaluation with concrete values for the differentiation point, and abstract values for the tangents. -- Notice that JAX uses the special abstract tangent value `Zero` for the tangent corresponding to the 3rd argument of `ma`. This reflects the fact that you do not differentiate w.r.t. the secibd argument to `square_add_prim`, which flows to the third argument to `multiply_add_prim`. +- Notice that JAX uses the special abstract tangent value `Zero` for the tangent corresponding to the third argument of `ma`. This reflects the fact that you do not differentiate w.r.t. the second argument to `square_add_prim`, which flows to the third argument to `multiply_add_prim`. - Notice also that during the abstract evaluation of the tangent you pass the value `0.0` as the tangent for the third argument. This is because of the use of the `make_zero` function in the definition of `multiply_add_value_and_jvp`.
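The `Zero` tangent behaviour described in the bullets above can be observed directly with `jax.jvp`. This is a minimal sketch; `square_add` here is an illustrative stand-in for the tutorial's `square_add_prim`, not the primitive itself:

```python
import jax

def square_add(a, b):
    return a * a + b

# Differentiating only w.r.t. `a`: the tangent supplied for `b` is zero, which
# is what the symbolic `Zero` / `make_zero` machinery represents during tracing.
primal_out, tangent_out = jax.jvp(square_add, (2.0, 3.0), (1.0, 0.0))
print(primal_out, tangent_out)  # 7.0 4.0
```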
```{code-cell} diff --git a/docs/tutorials/jaxpr.md b/docs/_tutorials/jaxpr.md similarity index 97% rename from docs/tutorials/jaxpr.md rename to docs/_tutorials/jaxpr.md index ccbe2fde8c14..9fe990c0a8ba 100644 --- a/docs/tutorials/jaxpr.md +++ b/docs/_tutorials/jaxpr.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -13,7 +13,9 @@ kernelspec: --- (jax-internals-jaxpr)= -# JAX internals 301: The jaxpr language +# JAX internals: The jaxpr language + + Jaxprs are JAX’s internal intermediate representation (IR) of programs. They are explicitly typed, functional, first-order, and in algebraic normal form (ANF). @@ -195,7 +197,7 @@ def func7(arg): print(make_jaxpr(func7)(5.)) ``` -In this case, the boolean predicate is converted to an integer index (0 or 1), and `branches` are jaxprs that correspond to the false and true branch functionals, in that order. Again, each functional takes one input variable, corresponding to `xfalse` and `xtrue` respectively. +In this case, the boolean predicate is converted to an integer index (0 or 1), and `branches` are jaxprs that correspond to the false and true branch functionals, in that order. Again, each function takes one input variable, corresponding to `xfalse` and `xtrue` respectively. The following example shows a more complicated situation when the input to the branch functionals is a tuple, and the `false` branch functional contains a constant `jnp.ones(1)` that is hoisted as a `constvar`. @@ -218,7 +220,7 @@ lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C ``` -In the above signature, `C` stands for the type of a the loop “carry” value. For example, here is an example `fori_loop`: +In the above signature, `C` stands for the type of the loop “carry” value. For example, here is an example `fori_loop`: ```{code-cell} import numpy as np diff --git a/docs/tutorials/parallelism.md b/docs/_tutorials/parallelism.md similarity index 82% rename from docs/tutorials/parallelism.md rename to docs/_tutorials/parallelism.md index 47a8c646bb37..9b840357e8aa 100644 --- a/docs/tutorials/parallelism.md +++ b/docs/_tutorials/parallelism.md @@ -1,7 +1,9 @@ # Parallel computation + + ```{note} -This is a placeholder for a section in the new {ref}`jax-tutorials`. +This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. For the time being, you may find some related content in the old documentation: - {doc}`../multi_process` diff --git a/docs/tutorials/profiling-and-performance.md b/docs/_tutorials/profiling-and-performance.md similarity index 81% rename from docs/tutorials/profiling-and-performance.md rename to docs/_tutorials/profiling-and-performance.md index 7928b1cfed60..d9a13b213f70 100644 --- a/docs/tutorials/profiling-and-performance.md +++ b/docs/_tutorials/profiling-and-performance.md @@ -1,7 +1,9 @@ # Profiling and performance + + ```{note} -This is a placeholder for a section in the new {ref}`jax-tutorials`. +This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. 
For the time being, you may find some related content in the old documentation: - {doc}`../profiling` diff --git a/docs/tutorials/simple-neural-network.md b/docs/_tutorials/simple-neural-network.md similarity index 66% rename from docs/tutorials/simple-neural-network.md rename to docs/_tutorials/simple-neural-network.md index 5c8f6335f037..76e98db88d82 100644 --- a/docs/tutorials/simple-neural-network.md +++ b/docs/_tutorials/simple-neural-network.md @@ -1,5 +1,7 @@ # Example: Writing a simple neural network + + ```{note} -This is a placeholder for a section in the new {ref}`jax-tutorials`. +This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. ``` diff --git a/docs/advanced_guide.rst b/docs/advanced_guide.rst index eddca16d545c..28830ebcb018 100644 --- a/docs/advanced_guide.rst +++ b/docs/advanced_guide.rst @@ -19,6 +19,7 @@ This section contains examples and tutorials on more advanced topics, such as Mu multi_process notebooks/Distributed_arrays_and_automatic_parallelization notebooks/shard_map + distributed_data_loading notebooks/xmap_tutorial .. toctree:: diff --git a/docs/aot.md b/docs/aot.md index 1a7ec0080e61..ed7f4574900b 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -1,5 +1,9 @@ +(ahead-of-time-lowering)= + # Ahead-of-time lowering and compilation + + JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning a function that is compiled and runs on accelerators or the CPU. As the JIT acronym indicates, all compilation happens _just-in-time_ for execution. @@ -33,8 +37,6 @@ way. An example: ```python >>> import jax ->>> import jax.numpy as jnp ->>> import numpy as np >>> def f(x, y): return 2 * x + y >>> x, y = 3, 4 @@ -43,12 +45,12 @@ way. An example: >>> # Print lowered HLO >>> print(lowered.as_text()) -module @jit_f.0 { - func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = stablehlo.constant dense<2> : tensor - %1 = stablehlo.multiply %0, %arg0 : tensor - %2 = stablehlo.add %1, %arg1 : tensor - return %2 : tensor +module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"}, %arg1: tensor {mhlo.layout_mode = "default"}) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<2> : tensor + %0 = stablehlo.multiply %c, %arg0 : tensor + %1 = stablehlo.add %0, %arg1 : tensor + return %1 : tensor } } @@ -60,9 +62,14 @@ module @jit_f.0 { >>> # Execute the compiled function! >>> compiled(x, y) -DeviceArray(10, dtype=int32) +Array(10, dtype=int32, weak_type=True) + ``` +Note that the lowered objects can be used only in the same process +in which they were lowered. For exporting use cases, +see the {ref}`export` APIs. + See the {mod}`jax.stages` documentation for more details on what functionality the lowering and compiled functions provide. @@ -81,7 +88,8 @@ that have `shape` and `dtype` attributes: ```python >>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32')) >>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y) -DeviceArray(10, dtype=int32) +Array(10, dtype=int32) + ``` More generally, `lower` only needs its arguments to structurally supply what JAX @@ -95,18 +103,21 @@ lowering raises an error: ```python >>> x_1d = y_1d = jnp.arange(3) ->>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) +>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) # doctest: +IGNORE_EXCEPTION_DETAIL ... 
+Traceback (most recent call last): TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are: Argument 'x' compiled with int32[] and called with int32[3] Argument 'y' compiled with int32[] and called with int32[3] >>> x_f = y_f = jnp.float32(72.) ->>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) +>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) # doctest: +IGNORE_EXCEPTION_DETAIL ... +Traceback (most recent call last): TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are: Argument 'x' compiled with int32[] and called with float32[] Argument 'y' compiled with int32[] and called with float32[] + ``` Relatedly, AOT-compiled functions [cannot be transformed by JAX's just-in-time @@ -125,17 +136,23 @@ to invoke the resulting compiled function. Continuing with our example above: >>> # Lowered HLO, specialized to the *value* of the first argument (7) >>> print(lowered_with_x.as_text()) -module @jit_f.1 { - func.func public @main(%arg0: tensor) -> tensor { - %0 = stablehlo.constant dense<14> : tensor - %1 = stablehlo.add %0, %arg0 : tensor - return %1 : tensor +module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"}) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<14> : tensor + %0 = stablehlo.add %c, %arg0 : tensor + return %0 : tensor } } + >>> lowered_with_x.compile()(5) -DeviceArray(19, dtype=int32) +Array(19, dtype=int32, weak_type=True) + ``` +The result of `lower` is not safe to serialize directly for use +in a different process. +See {ref}`export` for additional APIs for this purpose. + Note that `lower` here takes two arguments as usual, but the subsequent compiled function accepts only the remaining non-static second argument. The static first argument (value 7) is taken as a constant at lowering time and built into the @@ -147,11 +164,13 @@ shape/dtype structure, it is necessary that the static first argument be a concrete value. Otherwise, lowering would err: ```python ->>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar) +>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar) # doctest: +SKIP +Traceback (most recent call last): TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct' >>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5) -DeviceArray(25, dtype=int32) +Array(25, dtype=int32) + ``` ## AOT-compiled functions cannot be transformed @@ -177,13 +196,15 @@ in transformations. Example: >>> g_aot = jax.jit(g).lower(z).compile() >>> jax.vmap(g_jit)(zs) -DeviceArray([[ 1., 5., 9.], - [13., 17., 21.], - [25., 29., 33.], - [37., 41., 45.]], dtype=float32) +Array([[ 1., 5., 9.], + [13., 17., 21.], + [25., 29., 33.], + [37., 41., 45.]], dtype=float32) + +>>> jax.vmap(g_aot)(zs) # doctest: +SKIP +Traceback (most recent call last): +TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type ->>> jax.vmap(g_aot)(zs) -TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type . 
``` A similar error is raised when `g_aot` is involved in autodiff diff --git a/docs/api_compatibility.md b/docs/api_compatibility.md index 39c4386f447e..deb1c690335d 100644 --- a/docs/api_compatibility.md +++ b/docs/api_compatibility.md @@ -2,6 +2,8 @@ # API compatibility + + JAX is constantly evolving, and we want to be able to make improvements to its APIs. That said, we want to minimize churn for the JAX user community, and we try to make breaking changes rarely. @@ -33,11 +35,13 @@ Only public JAX APIs are covered, which includes the following modules: * `jax.profiler` * `jax.random` (see [details below](#numerics-and-randomness)) * `jax.scipy` +* `jax.tree` * `jax.tree_util` * `jax.test_util` -Not everything in these modules is public. Over time, we are working to separate -public and private APIs. Public APIs are documented in the JAX documentation. +Not everything in these modules is intended to be public, and over time, we +are working to separate public and private APIs. Public APIs are documented +in the JAX documentation. Additionally, our goal is that all non-public APIs should have names prefixed with underscores, although we do not entirely comply with this yet. @@ -46,9 +50,7 @@ prefixed with underscores, although we do not entirely comply with this yet. * anything prefixed with an underscore. * `jax._src` * `jax.core` -* `jax.linear_util` * `jax.lib` -* `jax.prng` * `jax.interpreters` * `jax.experimental` * `jax.example_libraries` diff --git a/docs/async_dispatch.rst b/docs/async_dispatch.rst index 421a27b0b67b..00fb98a5185d 100644 --- a/docs/async_dispatch.rst +++ b/docs/async_dispatch.rst @@ -9,7 +9,7 @@ program: >>> import numpy as np >>> import jax.numpy as jnp >>> from jax import random ->>> x = random.uniform(random.PRNGKey(0), (1000, 1000)) +>>> x = random.uniform(random.key(0), (1000, 1000)) >>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`) >>> # will block until the value is ready. >>> jnp.dot(x, x) + 3. # doctest: +SKIP diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 24980cf30d04..ed242ecc5710 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -36,6 +36,8 @@ "source": [ "# Autodidax: JAX core from scratch\n", "\n", + "\n", + "\n", "Ever want to learn how JAX works, but the implementation seemed impenetrable?\n", "Well, you're in luck! By reading this tutorial, you'll learn every big idea in\n", "JAX's core system. 
You'll even get clued into our weird jargon!\n", @@ -165,15 +167,15 @@ "source": [ "from collections.abc import Sequence\n", "from contextlib import contextmanager\n", - "from typing import Optional, Any\n", + "from typing import Any\n", "\n", "class MainTrace(NamedTuple):\n", " level: int\n", " trace_type: type['Trace']\n", - " global_data: Optional[Any]\n", + " global_data: Any | None\n", "\n", "trace_stack: list[MainTrace] = []\n", - "dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3\n", + "dynamic_trace: MainTrace | None = None # to be employed in Part 3\n", "\n", "@contextmanager\n", "def new_main(trace_type: type['Trace'], global_data=None):\n", @@ -910,7 +912,7 @@ "source": [ "from collections.abc import Hashable, Iterable, Iterator\n", "import itertools as it\n", - "from typing import Callable\n", + "from collections.abc import Callable\n", "\n", "class NodeType(NamedTuple):\n", " name: str\n", @@ -1649,7 +1651,7 @@ "source": [ "from functools import lru_cache\n", "\n", - "@lru_cache() # ShapedArrays are hashable\n", + "@lru_cache # ShapedArrays are hashable\n", "def make_jaxpr_v1(f, *avals_in):\n", " avals_in, in_tree = tree_flatten(avals_in)\n", " f, out_tree = flatten_fun(f, in_tree)\n", @@ -1801,7 +1803,7 @@ " finally:\n", " dynamic_trace = prev_dynamic_trace\n", "\n", - "@lru_cache()\n", + "@lru_cache\n", "def make_jaxpr(f: Callable, *avals_in: ShapedArray,\n", " ) -> tuple[Jaxpr, list[Any], PyTreeDef]:\n", " avals_in, in_tree = tree_flatten(avals_in)\n", @@ -1992,7 +1994,7 @@ " return execute(*args)\n", "impl_rules[xla_call_p] = xla_call_impl\n", "\n", - "@lru_cache()\n", + "@lru_cache\n", "def xla_callable(hashable_jaxpr: IDHashable,\n", " hashable_consts: tuple[IDHashable, ...]):\n", " jaxpr: Jaxpr = hashable_jaxpr.val\n", @@ -2225,7 +2227,7 @@ " return primals_out, tangents_out\n", "jvp_rules[xla_call_p] = xla_call_jvp_rule\n", "\n", - "@lru_cache()\n", + "@lru_cache\n", "def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:\n", " def jvp_traceable(*primals_and_tangents):\n", " n = len(primals_and_tangents) // 2\n", @@ -2251,7 +2253,7 @@ " return outs, [0] * len(outs)\n", "vmap_rules[xla_call_p] = xla_call_vmap_rule\n", "\n", - "@lru_cache()\n", + "@lru_cache\n", "def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]\n", " ) -> tuple[Jaxpr, list[Any]]:\n", " vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n", @@ -2636,7 +2638,7 @@ "source": [ "class PartialVal(NamedTuple):\n", " aval: ShapedArray\n", - " const: Optional[Any]\n", + " const: Any | None\n", "\n", " @classmethod\n", " def known(cls, val: Any):\n", @@ -2725,7 +2727,7 @@ "source": [ "class PartialEvalTracer(Tracer):\n", " pval: PartialVal\n", - " recipe: Optional[JaxprRecipe]\n", + " recipe: JaxprRecipe | None\n", "\n", " def __init__(self, trace, pval, recipe):\n", " self._trace = trace\n", @@ -2972,7 +2974,7 @@ "partial_eval_rules[xla_call_p] = xla_call_partial_eval\n", "\n", "def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],\n", - " instantiate: Optional[list[bool]] = None,\n", + " instantiate: list[bool] | None = None,\n", " ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:\n", " env: dict[Var, bool] = {}\n", " residuals: set[Var] = set()\n", @@ -3269,7 +3271,7 @@ " return [next(outs) if undef else None for undef in undef_primals]\n", "transpose_rules[xla_call_p] = xla_call_transpose_rule\n", "\n", - "@lru_cache()\n", + "@lru_cache\n", "def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]\n", " ) -> tuple[Jaxpr, list[Any]]:\n", " 
avals_in, avals_out = typecheck_jaxpr(jaxpr)\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index bfffecdb4241..0551b9905db3 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -6,7 +6,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 name: python3 @@ -39,6 +39,8 @@ Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab. # Autodidax: JAX core from scratch + + Ever want to learn how JAX works, but the implementation seemed impenetrable? Well, you're in luck! By reading this tutorial, you'll learn every big idea in JAX's core system. You'll even get clued into our weird jargon! @@ -146,15 +148,15 @@ more descriptive. ```{code-cell} from collections.abc import Sequence from contextlib import contextmanager -from typing import Optional, Any +from typing import Any class MainTrace(NamedTuple): level: int trace_type: type['Trace'] - global_data: Optional[Any] + global_data: Any | None trace_stack: list[MainTrace] = [] -dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3 +dynamic_trace: MainTrace | None = None # to be employed in Part 3 @contextmanager def new_main(trace_type: type['Trace'], global_data=None): @@ -703,7 +705,7 @@ class Store: from collections.abc import Hashable, Iterable, Iterator import itertools as it -from typing import Callable +from collections.abc import Callable class NodeType(NamedTuple): name: str @@ -1293,7 +1295,7 @@ transformation and a pretty-printer: ```{code-cell} from functools import lru_cache -@lru_cache() # ShapedArrays are hashable +@lru_cache # ShapedArrays are hashable def make_jaxpr_v1(f, *avals_in): avals_in, in_tree = tree_flatten(avals_in) f, out_tree = flatten_fun(f, in_tree) @@ -1413,7 +1415,7 @@ def new_dynamic(main: MainTrace): finally: dynamic_trace = prev_dynamic_trace -@lru_cache() +@lru_cache def make_jaxpr(f: Callable, *avals_in: ShapedArray, ) -> tuple[Jaxpr, list[Any], PyTreeDef]: avals_in, in_tree = tree_flatten(avals_in) @@ -1562,7 +1564,7 @@ def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int): return execute(*args) impl_rules[xla_call_p] = xla_call_impl -@lru_cache() +@lru_cache def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: tuple[IDHashable, ...]): jaxpr: Jaxpr = hashable_jaxpr.val @@ -1732,7 +1734,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts): return primals_out, tangents_out jvp_rules[xla_call_p] = xla_call_jvp_rule -@lru_cache() +@lru_cache def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]: def jvp_traceable(*primals_and_tangents): n = len(primals_and_tangents) // 2 @@ -1753,7 +1755,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts): return outs, [0] * len(outs) vmap_rules[xla_call_p] = xla_call_vmap_rule -@lru_cache() +@lru_cache def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...] 
) -> tuple[Jaxpr, list[Any]]: vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in)) @@ -2063,7 +2065,7 @@ be either known or unknown: ```{code-cell} class PartialVal(NamedTuple): aval: ShapedArray - const: Optional[Any] + const: Any | None @classmethod def known(cls, val: Any): @@ -2127,7 +2129,7 @@ JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe] ```{code-cell} class PartialEvalTracer(Tracer): pval: PartialVal - recipe: Optional[JaxprRecipe] + recipe: JaxprRecipe | None def __init__(self, trace, pval, recipe): self._trace = trace @@ -2327,7 +2329,7 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts): partial_eval_rules[xla_call_p] = xla_call_partial_eval def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool], - instantiate: Optional[list[bool]] = None, + instantiate: list[bool] | None = None, ) -> tuple[Jaxpr, Jaxpr, list[bool], int]: env: dict[Var, bool] = {} residuals: set[Var] = set() @@ -2584,7 +2586,7 @@ def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts): return [next(outs) if undef else None for undef in undef_primals] transpose_rules[xla_call_p] = xla_call_transpose_rule -@lru_cache() +@lru_cache def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...] ) -> tuple[Jaxpr, list[Any]]: avals_in, avals_out = typecheck_jaxpr(jaxpr) diff --git a/docs/autodidax.py b/docs/autodidax.py index fc68cfd0b2b3..b09534381c69 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -20,7 +20,7 @@ # extension: .py # format_name: light # format_version: '1.5' -# jupytext_version: 1.16.0 +# jupytext_version: 1.16.1 # kernelspec: # display_name: Python 3 # name: python3 @@ -31,6 +31,8 @@ # # Autodidax: JAX core from scratch # +# +# # Ever want to learn how JAX works, but the implementation seemed impenetrable? # Well, you're in luck! By reading this tutorial, you'll learn every big idea in # JAX's core system. You'll even get clued into our weird jargon! 
@@ -136,15 +138,15 @@ def bind1(prim, *args, **params): # + from collections.abc import Sequence from contextlib import contextmanager -from typing import Optional, Any +from typing import Any class MainTrace(NamedTuple): level: int trace_type: type['Trace'] - global_data: Optional[Any] + global_data: Any | None trace_stack: list[MainTrace] = [] -dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3 +dynamic_trace: MainTrace | None = None # to be employed in Part 3 @contextmanager def new_main(trace_type: type['Trace'], global_data=None): @@ -695,7 +697,7 @@ def __call__(self): # + tags=["hide-input"] from collections.abc import Hashable, Iterable, Iterator import itertools as it -from typing import Callable +from collections.abc import Callable class NodeType(NamedTuple): name: str @@ -1295,7 +1297,7 @@ def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int], # + from functools import lru_cache -@lru_cache() # ShapedArrays are hashable +@lru_cache # ShapedArrays are hashable def make_jaxpr_v1(f, *avals_in): avals_in, in_tree = tree_flatten(avals_in) f, out_tree = flatten_fun(f, in_tree) @@ -1410,7 +1412,7 @@ def new_dynamic(main: MainTrace): finally: dynamic_trace = prev_dynamic_trace -@lru_cache() +@lru_cache def make_jaxpr(f: Callable, *avals_in: ShapedArray, ) -> tuple[Jaxpr, list[Any], PyTreeDef]: avals_in, in_tree = tree_flatten(avals_in) @@ -1554,7 +1556,7 @@ def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int): return execute(*args) impl_rules[xla_call_p] = xla_call_impl -@lru_cache() +@lru_cache def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: tuple[IDHashable, ...]): jaxpr: Jaxpr = hashable_jaxpr.val @@ -1726,7 +1728,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts): return primals_out, tangents_out jvp_rules[xla_call_p] = xla_call_jvp_rule -@lru_cache() +@lru_cache def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]: def jvp_traceable(*primals_and_tangents): n = len(primals_and_tangents) // 2 @@ -1747,7 +1749,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts): return outs, [0] * len(outs) vmap_rules[xla_call_p] = xla_call_vmap_rule -@lru_cache() +@lru_cache def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...] ) -> tuple[Jaxpr, list[Any]]: vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in)) @@ -2055,7 +2057,7 @@ def vspace(aval: ShapedArray) -> ShapedArray: class PartialVal(NamedTuple): aval: ShapedArray - const: Optional[Any] + const: Any | None @classmethod def known(cls, val: Any): @@ -2119,7 +2121,7 @@ class JaxprEqnRecipe(NamedTuple): class PartialEvalTracer(Tracer): pval: PartialVal - recipe: Optional[JaxprRecipe] + recipe: JaxprRecipe | None def __init__(self, trace, pval, recipe): self._trace = trace @@ -2320,7 +2322,7 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts): partial_eval_rules[xla_call_p] = xla_call_partial_eval def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool], - instantiate: Optional[list[bool]] = None, + instantiate: list[bool] | None = None, ) -> tuple[Jaxpr, Jaxpr, list[bool], int]: env: dict[Var, bool] = {} residuals: set[Var] = set() @@ -2583,7 +2585,7 @@ def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts): return [next(outs) if undef else None for undef in undef_primals] transpose_rules[xla_call_p] = xla_call_transpose_rule -@lru_cache() +@lru_cache def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...] 
) -> tuple[Jaxpr, list[Any]]: avals_in, avals_out = typecheck_jaxpr(jaxpr) diff --git a/docs/tutorials/automatic-differentiation.md b/docs/automatic-differentiation.md similarity index 83% rename from docs/tutorials/automatic-differentiation.md rename to docs/automatic-differentiation.md index dbde5d5c3454..cc4a19aaba64 100644 --- a/docs/tutorials/automatic-differentiation.md +++ b/docs/automatic-differentiation.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -13,9 +13,12 @@ kernelspec: --- (automatic-differentiation)= -# Automatic differentiation 101 +# Automatic differentiation -In this tutorial, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general automatic differentiation (autodiff) system. Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as: + + +In this section, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general autodiff system. +Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as: - {ref}`automatic-differentiation-taking-gradients` - {ref}`automatic-differentiation-linear logistic regression` @@ -28,9 +31,9 @@ Make sure to also check out the {ref}`advanced-autodiff` tutorial for more advan While understanding how automatic differentiation works "under the hood" isn't crucial for using JAX in most contexts, you are encouraged to check out this quite accessible [video](https://www.youtube.com/watch?v=wG_nF1awSSY) to get a deeper sense of what's going on. (automatic-differentiation-taking-gradients)= -## 1.Taking gradients with `jax.grad` +## 1. Taking gradients with `jax.grad` -In JAX, you can differentiate a function with the {func}`jax.grad` transformation: +In JAX, you can differentiate a scalar-valued function with the {func}`jax.grad` transformation: ```{code-cell} import jax @@ -63,7 +66,7 @@ dfdx = jax.grad(f) The higher-order derivatives of $f$ are: $$ -\begin{array}{l}s +\begin{array}{l} f'(x) = 3x^2 + 4x -3\\ f''(x) = 6x + 4\\ f'''(x) = 6\\ @@ -105,27 +108,27 @@ print(d4fdx(1.)) The next example shows how to compute gradients with {func}`jax.grad` in a linear logistic regression model. First, the setup: ```{code-cell} -key = jax.random.PRNGKey(0) +key = jax.random.key(0) def sigmoid(x): - return 0.5 * (jnp.tanh(x / 2) + 1) + return 0.5 * (jnp.tanh(x / 2) + 1) # Outputs probability of a label being true. def predict(W, b, inputs): - return sigmoid(jnp.dot(inputs, W) + b) + return sigmoid(jnp.dot(inputs, W) + b) # Build a toy dataset. inputs = jnp.array([[0.52, 1.12, 0.77], - [0.88, -1.08, 0.15], - [0.52, 0.06, -1.30], - [0.74, -2.49, 1.39]]) + [0.88, -1.08, 0.15], + [0.52, 0.06, -1.30], + [0.74, -2.49, 1.39]]) targets = jnp.array([True, True, False, True]) # Training loss is the negative log-likelihood of the training examples. 
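# Concretely: with p_i = predict(W, b, x_i) and boolean target t_i, the
# per-example likelihood is p_i when t_i is True and (1 - p_i) otherwise,
# so the loss below is -sum_i log(p_i if t_i else 1 - p_i).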
def loss(W, b): - preds = predict(W, b, inputs) - label_probs = preds * targets + (1 - preds) * (1 - targets) - return -jnp.sum(jnp.log(label_probs)) + preds = predict(W, b, inputs) + label_probs = preds * targets + (1 - preds) * (1 - targets) + return -jnp.sum(jnp.log(label_probs)) # Initialize random model coefficients key, W_key, b_key = jax.random.split(key, 3) @@ -138,20 +141,20 @@ Use the {func}`jax.grad` function with its `argnums` argument to differentiate a ```{code-cell} # Differentiate `loss` with respect to the first positional argument: W_grad = grad(loss, argnums=0)(W, b) -print('W_grad', W_grad) +print(f'{W_grad=}') # Since argnums=0 is the default, this does the same thing: W_grad = grad(loss)(W, b) -print('W_grad', W_grad) +print(f'{W_grad=}') # But you can choose different values too, and drop the keyword: b_grad = grad(loss, 1)(W, b) -print('b_grad', b_grad) +print(f'{b_grad=}') # Including tuple values W_grad, b_grad = grad(loss, (0, 1))(W, b) -print('W_grad', W_grad) -print('b_grad', b_grad) +print(f'{W_grad=}') +print(f'{b_grad=}') ``` The {func}`jax.grad` API has a direct correspondence to the excellent notation in Spivak's classic *Calculus on Manifolds* (1965), also used in Sussman and Wisdom's [*Structure and Interpretation of Classical Mechanics*](https://mitpress.mit.edu/9780262028967/structure-and-interpretation-of-classical-mechanics) (2015) and their [*Functional Differential Geometry*](https://mitpress.mit.edu/9780262019347/functional-differential-geometry) (2013). Both books are open-access. See in particular the "Prologue" section of *Functional Differential Geometry* for a defense of this notation. @@ -162,7 +165,8 @@ Essentially, when using the `argnums` argument, if `f` is a Python function for (automatic-differentiation-nested-lists-tuples-and-dicts)= ## 3. Differentiating with respect to nested lists, tuples, and dicts -Differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like. +Due to JAX's PyTree abstraction (see {ref}`working-with-pytrees`), differentiating with +respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like. Continuing the previous example: @@ -175,13 +179,13 @@ def loss2(params_dict): print(grad(loss2)({'W': W, 'b': b})) ``` -You can {ref}`pytrees-custom-pytree-nodes` to work with not just {func}`jax.grad` but other JAX transformations ({func}`jax.jit`, {func}`jax.vmap`, and so on). +You can create {ref}`pytrees-custom-pytree-nodes` to work with not just {func}`jax.grad` but other JAX transformations ({func}`jax.jit`, {func}`jax.vmap`, and so on). (automatic-differentiation-evaluating-using-jax-value_and_grad)= ## 4. Evaluating a function and its gradient using `jax.value_and_grad` -Another convenient function is {func}`jax.value_and_grad` for efficiently computing both a function's value as well as its gradient's value. +Another convenient function is {func}`jax.value_and_grad` for efficiently computing both a function's value as well as its gradient's value in one pass. 
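For instance, a minimal sketch reusing the `loss`, `W`, and `b` defined in the linear logistic regression example above:

```python
# Returns the loss value and the gradient with respect to `W` (argnums=0 by
# default) from a single evaluation of the model.
loss_value, W_grad = jax.value_and_grad(loss)(W, b)
print(loss_value)
print(f'{W_grad=}')
```

This is typically what you want in a training loop, where logging the loss alongside the gradient update would otherwise cost an extra forward pass.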
Continuing the previous examples: diff --git a/docs/tutorials/automatic-vectorization.md b/docs/automatic-vectorization.md similarity index 87% rename from docs/tutorials/automatic-vectorization.md rename to docs/automatic-vectorization.md index b9c0b9cc4d80..7559155e2e9e 100644 --- a/docs/tutorials/automatic-vectorization.md +++ b/docs/automatic-vectorization.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -13,11 +13,14 @@ kernelspec: --- (automatic-vectorization)= -# Automatic Vectorization in JAX +# Automatic vectorization -In the previous section we discussed JIT compilation via the `jax.jit` function. This notebook discusses another of JAX's transforms: vectorization via `jax.vmap`. + -## Manual Vectorization +In the previous section we discussed JIT compilation via the {func}`jax.jit` function. +This notebook discusses another of JAX's transforms: vectorization via {func}`jax.vmap`. + +## Manual vectorization Consider the following simple code that computes the convolution of two one-dimensional vectors: @@ -72,9 +75,9 @@ def manually_vectorized_convolve(xs, ws): manually_vectorized_convolve(xs, ws) ``` -Such re-implementation is messy and error-prone; fortunately JAX provides another way. +Such re-implementation can be messy and error-prone as the complexity of a function increases; fortunately JAX provides another way. -## Automatic Vectorization +## Automatic vectorization In JAX, the {func}`jax.vmap` transformation is designed to generate such a vectorized implementation of a function automatically: diff --git a/docs/beginner_guide.rst b/docs/beginner_guide.rst index 2ac1662456f3..204659ec2cb9 100644 --- a/docs/beginner_guide.rst +++ b/docs/beginner_guide.rst @@ -5,7 +5,7 @@ Getting Started with JAX ======================== Welcome to JAX! The JAX documentation contains a number of useful resources for getting started. -:doc:`notebooks/quickstart` is the easiest place to jump-in and get an overview of the JAX project. +:doc:`quickstart` is the easiest place to jump-in and get an overview of the JAX project. If you're accustomed to writing NumPy code and are starting to explore JAX, you might find the following resources helpful: @@ -15,12 +15,12 @@ If you're accustomed to writing NumPy code and are starting to explore JAX, you Tutorials --------- -If you're ready to explore JAX more deeply, the JAX 101 tutorial goes into much more detail: +If you're ready to explore JAX more deeply, the JAX tutorials go into much more detail: .. 
toctree:: :maxdepth: 2 - jax-101/index + tutorials If you prefer a video introduction here is one from JAX contributor Jake VanderPlas: diff --git a/docs/build_custom_gpu.sh b/docs/build_custom_gpu.sh new file mode 100644 index 000000000000..76fbe6a7ba6a --- /dev/null +++ b/docs/build_custom_gpu.sh @@ -0,0 +1,13 @@ +python -m pip install pybind11==2.10.1 +mkdir -p build +touch build/__init__.py +pybind_include_path=$(python -c "import pybind11; print(pybind11.get_include())") +python_executable=$(python -c 'import sys; print(sys.executable)') +#python_include_path=$(python -c 'from distutils.sysconfig import get_python_inc;print(get_python_inc())') +echo pybind_include_path=$pybind_include_path +echo python_executable=$python_executable + +nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.cu.o +c++ -I/usr/local/cuda/include -I$pybind_include_path $(${python_executable}3-config --cflags) -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o build/gpu_ops.cpp.o -c gpu_ops/gpu_ops.cpp +c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o build/gpu_ops$(${python_executable}3-config --extension-suffix) build/gpu_ops.cpp.o build/rms_norm_kernels.cu.o -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl +strip build/gpu_ops$(${python_executable}3-config --extension-suffix) diff --git a/docs/building_on_jax.md b/docs/building_on_jax.md index a50a6343a0ea..e0a4404911a7 100644 --- a/docs/building_on_jax.md +++ b/docs/building_on_jax.md @@ -1,5 +1,7 @@ # Building on JAX + + A great way to learn advanced JAX usage is to see how other libraries are using JAX, both how they integrate the library into their API, what functionality it adds mathematically, @@ -43,7 +45,7 @@ Here are more specific examples of each pattern. ### Direct Usage Jax can be directly imported and utilized to build models “from scratch” as shown across this website, -for example in [JAX 101](https://jax.readthedocs.io/en/latest/jax-101/index.html) +for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html) or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html). This may be the best option if you are unable to find prebuilt code for your particular challenge, or if you're looking to reduce the number diff --git a/docs/conf.py b/docs/conf.py index 5b7bf00d7825..c562b4fb5694 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -26,16 +26,25 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +import inspect +import operator import os import sys sys.path.insert(0, os.path.abspath('..')) - -# Currently type aliases are expanded. We tried a workaround along the lines of: +# Workaround to avoid expanding type aliases. See: # https://github.com/sphinx-doc/sphinx/issues/6518#issuecomment-589613836 -# Unfortunately, this workaround makes Sphinx drop module-level documentation. -# See https://github.com/google/jax/issues/3452. 
+from typing import ForwardRef + +def _do_not_evaluate_in_jax( + self, globalns, *args, _evaluate=ForwardRef._evaluate, +): + if globalns.get('__name__', '').startswith('jax'): + return self + return _evaluate(self, globalns, *args) + +ForwardRef._evaluate = _do_not_evaluate_in_jax # -- Project information ----------------------------------------------------- @@ -63,16 +72,16 @@ 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.intersphinx', + 'sphinx.ext.linkcode', 'sphinx.ext.mathjax', 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', 'matplotlib.sphinxext.plot_directive', - 'sphinx_autodoc_typehints', 'myst_nb', "sphinx_remove_toctrees", 'sphinx_copybutton', 'jax_extensions', - 'sphinx_design' + 'sphinx_design', + 'sphinxext.rediraffe', ] intersphinx_mapping = { @@ -124,8 +133,8 @@ 'pallas/quickstart.md', 'pallas/tpu/pipelining.md', 'jep/9407-type-promotion.md', - 'jax-101/*.md', 'autodidax.md', + 'sharded-computation.md', ] # The name of the Pygments (syntax highlighting) style to use. @@ -197,8 +206,6 @@ # List of patterns, relative to source directory, that match notebook # files that will not be executed. nb_execution_excludepatterns = [ - # Includes GPU timings that shouldn't be executed by doc build - 'notebooks/quickstart.*', # Slow notebook: long time to load tf.ds 'notebooks/neural_network_with_tfds_data.*', # Slow notebook @@ -206,13 +213,14 @@ # Has extra requirements: networkx, pandas, pytorch, tensorflow, etc. 'jep/9407-type-promotion.*', # TODO(jakevdp): enable execution on the following if possible: - 'jax-101/*', 'notebooks/xmap_tutorial.*', 'notebooks/Distributed_arrays_and_automatic_parallelization.*', 'notebooks/autodiff_remat.*', # Requires accelerators 'pallas/quickstart.*', 'pallas/tpu/pipelining.*', + 'sharded-computation.*', + 'distributed_data_loading.*' ] # -- Options for HTMLHelp output --------------------------------------------- @@ -292,17 +300,56 @@ # -- Extension configuration ------------------------------------------------- -# Tell sphinx-autodoc-typehints to generate stub parameter annotations including -# types, even if the parameters aren't explicitly documented. -always_document_param_types = True - - # Tell sphinx autodoc how to render type aliases. +autodoc_typehints = "description" +autodoc_typehints_description_target = "all" autodoc_type_aliases = { - 'ArrayLike': 'ArrayLike', - 'DTypeLike': 'DTypeLike', + 'ArrayLike': 'jax.typing.ArrayLike', + 'DTypeLike': 'jax.typing.DTypeLike', } - # Remove auto-generated API docs from sidebars. They take too long to build. 
remove_from_toctrees = ["_autosummary/*"] + +# Customize code links via sphinx.ext.linkcode + +def linkcode_resolve(domain, info): + import jax + + if domain != 'py': + return None + if not info['module']: + return None + if not info['fullname']: + return None + if info['module'].split(".")[0] != 'jax': + return None + try: + mod = sys.modules.get(info['module']) + obj = operator.attrgetter(info['fullname'])(mod) + if isinstance(obj, property): + obj = obj.fget + while hasattr(obj, '__wrapped__'): # decorated functions + obj = obj.__wrapped__ + filename = inspect.getsourcefile(obj) + source, linenum = inspect.getsourcelines(obj) + except: + return None + filename = os.path.relpath(filename, start=os.path.dirname(jax.__file__)) + lines = f"#L{linenum}-L{linenum + len(source)}" if linenum else "" + return f"https://github.com/google/jax/blob/main/jax/{filename}{lines}" + +# Generate redirects from deleted files to new sources +rediraffe_redirects = { + 'notebooks/quickstart.md': 'quickstart.md', + 'jax-101/01-jax-basics.md': 'key-concepts.md', + 'jax-101/02-jitting.md': 'jit-compilation.md', + 'jax-101/03-vectorization.md': 'automatic-vectorization.md', + 'jax-101/04-advanced-autodiff.md': 'automatic-differentiation.md', + 'jax-101/05-random-numbers.md': 'random-numbers.md', + 'jax-101/05.1-pytrees.md': 'working-with-pytrees.md', + 'jax-101/06-parallelism.md': 'sharded-computation.md', + 'jax-101/07-state.md': 'stateful-computations.md', + 'jax-101/08-pjit.rst': 'sharded-computation.md', + 'jax-101/index.rst': 'tutorials.rst', +} diff --git a/docs/contributing.md b/docs/contributing.md index 5040fbd9f17e..cad7cfc1ea64 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -1,5 +1,7 @@ # Contributing to JAX + + Everyone can contribute to JAX, and we value everyone's contributions. There are several ways to contribute, including: @@ -34,7 +36,7 @@ Follow these steps to contribute code: [repository page](http://www.github.com/google/jax). This creates a copy of the JAX repository in your own account. -3. Install Python >= 3.9 locally in order to run tests. +3. Install Python >= 3.10 locally in order to run tests. 4. `pip` installing your fork from source. This allows you to modify the code and immediately test it out: diff --git a/docs/cuda_custom_call/BUILD b/docs/cuda_custom_call/BUILD new file mode 100644 index 000000000000..93715bdac171 --- /dev/null +++ b/docs/cuda_custom_call/BUILD @@ -0,0 +1,63 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load( + "//jaxlib:jax.bzl", + "cuda_library", + "jax_generate_backend_suites", + "jax_test", +) + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//visibility:private"], +) + +jax_generate_backend_suites() + +jax_test( + name = "cuda_custom_call_test", + srcs = ["cuda_custom_call_test.py"], + data = [":foo"], + disable_backends = [ + "cpu", + "tpu", + ], + tags = ["notap"], + deps = [ + "//jax:extend", + ], +) + +# this second target is needed to properly link in CUDA runtime symbols +# such as cudaLaunchKernel, even though we are only building one library. +cc_shared_library( + name = "foo", + deps = [ + ":foo_", + "@xla//xla/tsl/cuda:cudart", + ], +) + +cuda_library( + name = "foo_", + srcs = ["foo.cu.cc"], + deps = [ + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@local_config_cuda//cuda:cuda_headers", + ], +) diff --git a/docs/cuda_custom_call/Makefile b/docs/cuda_custom_call/Makefile new file mode 100644 index 000000000000..ca51b63b5eaf --- /dev/null +++ b/docs/cuda_custom_call/Makefile @@ -0,0 +1,35 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This Makefile is not used by Bazel for this test, it is intended to serve as +# documentation of build instructions for JAX users that are not using Bazel to +# build their custom call code. For that reason, this Makefile is likely subject +# to bitrot over time. Please file a JAX issue on GitHub if typing "make" in +# this directory no longer runs the test to completion. +NVCC = nvcc +NVCCFLAGS += -I$(shell python -c 'from jax.extend import ffi; print(ffi.include_dir())') +NVCCFLAGS += -arch native +# since the file extension is .cu.cc, tell NVCC explicitly to treat it as .cu +NVCCFLAGS += -x cu + +# depends on libfoo.so being in the same directory as cuda_custom_call_test.py +check: libfoo.so + python cuda_custom_call_test.py + +lib%.so: %.cu.cc + $(NVCC) $(NVCCFLAGS) --compiler-options=-shared,-fPIC -o $@ $< + +clean: + rm -rf *.so diff --git a/docs/cuda_custom_call/cuda_custom_call_test.py b/docs/cuda_custom_call/cuda_custom_call_test.py new file mode 100644 index 000000000000..563462feb472 --- /dev/null +++ b/docs/cuda_custom_call/cuda_custom_call_test.py @@ -0,0 +1,216 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This test is intentionally structured to stay close to what a standalone JAX +# custom call integration might look like. 
JAX test harness is in a separate +# section towards the end of this file. The test can be run standalone by typing +# "make" in the directory containing this file. + +import os +import ctypes +import unittest + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.extend import ffi +from jax.lib import xla_client +from jax.interpreters import mlir + +# start test boilerplate +from absl.testing import absltest +from jax._src import config +from jax._src import test_util as jtu + +config.parse_flags_with_absl() +# end test boilerplate + +# XLA needs uppercase, "cuda" isn't recognized +XLA_PLATFORM = "CUDA" + +# JAX needs lowercase, "CUDA" isn't recognized +JAX_PLATFORM = "cuda" + +# 0 = original ("opaque"), 1 = FFI +XLA_CUSTOM_CALL_API_VERSION = 1 + +# these strings are how we identify kernels to XLA: +# - first we register a pointer to the kernel with XLA under this name +# - then we "tell" JAX to emit StableHLO specifying this name to XLA +XLA_CUSTOM_CALL_TARGET_FWD = "foo-fwd" +XLA_CUSTOM_CALL_TARGET_BWD = "foo-bwd" + +# independently, corresponding JAX primitives must also be named, +# names can be different from XLA targets, here they are the same +JAX_PRIMITIVE_FWD = "foo-fwd" +JAX_PRIMITIVE_BWD = "foo-bwd" + +if jtu.is_running_under_pytest(): + raise unittest.SkipTest("libfoo.so hasn't been built") +SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so") + +library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY) + +#-----------------------------------------------------------------------------# +# Forward pass # +#-----------------------------------------------------------------------------# + +# register the XLA FFI binding pointer with XLA +xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_FWD, + fn=ffi.pycapsule(library.FooFwd), + platform=XLA_PLATFORM, + api_version=XLA_CUSTOM_CALL_API_VERSION) + + +# our forward primitive will also return the intermediate output b+1 +# so it can be reused in the backward pass computation +def _foo_fwd_abstract_eval(a, b): + assert a.shape == b.shape + assert a.dtype == b.dtype + shaped_array = jax.core.ShapedArray(a.shape, a.dtype) + return ( + shaped_array, # output c + shaped_array, # intermediate output b+1 + ) + + +def _foo_fwd_lowering(ctx, a, b): + # ffi.ffi_lowering does most of the heavy lifting building a lowering. + # Keyword arguments passed to the lowering constructed by ffi_lowering are + # turned into custom call backend_config entries, which we take advantage of + # here for the dynamically computed n. 
+ n = np.prod(a.type.shape).astype(np.uint64) + return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_FWD)(ctx, a, b, n=n) + + +# construct a new JAX primitive +foo_fwd_p = jax.core.Primitive(JAX_PRIMITIVE_FWD) +# register the abstract evaluation rule for the forward primitive +foo_fwd_p.def_abstract_eval(_foo_fwd_abstract_eval) +foo_fwd_p.multiple_results = True +mlir.register_lowering(foo_fwd_p, _foo_fwd_lowering, platform=JAX_PLATFORM) + +#-----------------------------------------------------------------------------# +# Backward pass # +#-----------------------------------------------------------------------------# + +# register the XLA FFI binding pointer with XLA +xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_BWD, + fn=ffi.pycapsule(library.FooBwd), + platform=XLA_PLATFORM, + api_version=XLA_CUSTOM_CALL_API_VERSION) + + +def _foo_bwd_abstract_eval(c_grad, a, b_plus_1): + assert c_grad.shape == a.shape + assert a.shape == b_plus_1.shape + assert c_grad.dtype == a.dtype + assert a.dtype == b_plus_1.dtype + + shaped_array = jax.core.ShapedArray(a.shape, a.dtype) + return ( + shaped_array, # a_grad + shaped_array, # b_grad + ) + + +def _foo_bwd_lowering(ctx, c_grad, a, b_plus_1): + n = np.prod(a.type.shape).astype(np.uint64) + return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_BWD)(ctx, + c_grad, + a, + b_plus_1, + n=n) + + +# construct a new JAX primitive +foo_bwd_p = jax.core.Primitive(JAX_PRIMITIVE_BWD) +# register the abstract evaluation rule for the backward primitive +foo_bwd_p.def_abstract_eval(_foo_bwd_abstract_eval) +foo_bwd_p.multiple_results = True +mlir.register_lowering(foo_bwd_p, _foo_bwd_lowering, platform=JAX_PLATFORM) + +#-----------------------------------------------------------------------------# +# User facing API # +#-----------------------------------------------------------------------------# + + +def foo_fwd(a, b): + c, b_plus_1 = foo_fwd_p.bind(a, b) + return c, (a, b_plus_1) + + +def foo_bwd(res, c_grad): + a, b_plus_1 = res + return foo_bwd_p.bind(c_grad, a, b_plus_1) + + +@jax.custom_vjp +def foo(a, b): + c, _ = foo_fwd(a, b) + return c + + +foo.defvjp(foo_fwd, foo_bwd) + +#-----------------------------------------------------------------------------# +# Test # +#-----------------------------------------------------------------------------# + + +class CustomCallTest(jtu.JaxTestCase): + + def test_fwd_interpretable(self): + shape = (2, 3) + a = 2. * jnp.ones(shape) + b = 3. * jnp.ones(shape) + observed = jax.jit(foo)(a, b) + expected = (2. * (3. + 1.)) + self.assertArraysEqual(observed, expected) + + def test_bwd_interpretable(self): + shape = (2, 3) + a = 2. * jnp.ones(shape) + b = 3. 
* jnp.ones(shape) + + def loss(a, b): + return jnp.sum(foo(a, b)) + + da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b) + da_expected = b + 1 + db_expected = a + self.assertArraysEqual(da_observed, da_expected) + self.assertArraysEqual(db_observed, db_expected) + + def test_fwd_random(self): + shape = (2, 3) + akey, bkey = jax.random.split(jax.random.key(0)) + a = jax.random.normal(key=akey, shape=shape) + b = jax.random.normal(key=bkey, shape=shape) + observed = jax.jit(foo)(a, b) + expected = a * (b + 1) + self.assertAllClose(observed, expected) + + def test_bwd_random(self): + shape = (2, 3) + akey, bkey = jax.random.split(jax.random.key(0)) + a = jax.random.normal(key=akey, shape=shape) + b = jax.random.normal(key=bkey, shape=shape) + jtu.check_grads(f=jax.jit(foo), args=(a, b), order=1, modes=("rev",)) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/docs/cuda_custom_call/foo.cu.cc b/docs/cuda_custom_call/foo.cu.cc new file mode 100644 index 000000000000..7072a822f929 --- /dev/null +++ b/docs/cuda_custom_call/foo.cu.cc @@ -0,0 +1,136 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace ffi = xla::ffi; + +//----------------------------------------------------------------------------// +// Forward pass // +//----------------------------------------------------------------------------// + +// c = a * (b+1) +// This strawman operation works well for demo purposes because: +// 1. it's simple enough to be quickly understood, +// 2. it's complex enough to require intermediate outputs in grad computation, +// like many operations in practice do, and +// 3. it does not have a built-in implementation in JAX. +__global__ void FooFwdKernel(const float *a, const float *b, float *c, + float *b_plus_1, // intermediate output b+1 + size_t n) { + size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t grid_stride = blockDim.x * gridDim.x; + for (size_t i = tid; i < n; i += grid_stride) { + b_plus_1[i] = b[i] + 1.0f; + c[i] = a[i] * b_plus_1[i]; + } +} + +// Host function wrapper that launches the kernel with hardcoded grid/block +// size. Note, it uses types from XLA FFI. The return type must be ffi::Error. +// Buffer type provides buffer dimensions, so the "n" argument here is not +// strictly necessary, but it allows us to demonstrate the use of attributes +// (.Attr in the FFI handler definition above). +ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer a, + ffi::Buffer b, + ffi::Result> c, + ffi::Result> b_plus_1, + size_t n) { + const int block_dim = 128; + const int grid_dim = 1; + // Note how we access regular Buffer data vs Result Buffer data: + FooFwdKernel<<>>( + a.data, b.data, c->data, b_plus_1->data, n); + // Check for launch time errors. Note that this function may also + // return error codes from previous, asynchronous launches. 
This + // means that an error status returned here could have been caused + // by a different kernel previously launched by XLA. + cudaError_t last_error = cudaGetLastError(); + if (last_error != cudaSuccess) { + return ffi::Error( + XLA_FFI_Error_Code_INTERNAL, + std::string("CUDA error: ") + cudaGetErrorString(last_error)); + } + return ffi::Error::Success(); +} + +// Creates symbol FooFwd with C linkage that can be loaded using Python ctypes +XLA_FFI_DEFINE_HANDLER_SYMBOL( + FooFwd, FooFwdHost, + ffi::Ffi::Bind() + .Ctx>() // stream + .Arg>() // a + .Arg>() // b + .Ret>() // c + .Ret>() // b_plus_1 + .Attr("n")); + +//----------------------------------------------------------------------------// +// Backward pass // +//----------------------------------------------------------------------------// + +// compute da = dc * (b+1), and +// db = dc * a +__global__ void FooBwdKernel(const float *c_grad, // incoming gradient wrt c + const float *a, // original input a + const float *b_plus_1, // intermediate output b+1 + float *a_grad, // outgoing gradient wrt a + float *b_grad, // outgoing gradient wrt b + size_t n) { + size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t grid_stride = blockDim.x * gridDim.x; + for (size_t i = tid; i < n; i += grid_stride) { + // In practice on GPUs b_plus_1 can be recomputed for practically free + // instead of storing it out and reusing, so the reuse here is a bit + // contrived. We do it to demonstrate residual/intermediate output passing + // between the forward and the backward pass which becomes useful when + // recomputation is more expensive than reuse. + a_grad[i] = c_grad[i] * b_plus_1[i]; + b_grad[i] = c_grad[i] * a[i]; + } +} + +ffi::Error FooBwdHost(cudaStream_t stream, + ffi::Buffer c_grad, + ffi::Buffer a, + ffi::Result> b_plus_1, + ffi::Result> a_grad, + ffi::Result> b_grad, + size_t n) { + const int block_dim = 128; + const int grid_dim = 1; + FooBwdKernel<<>>( + c_grad.data, a.data, b_plus_1->data, a_grad->data, b_grad->data, n); + cudaError_t last_error = cudaGetLastError(); + if (last_error != cudaSuccess) { + return ffi::Error( + XLA_FFI_Error_Code_INTERNAL, + std::string("CUDA error: ") + cudaGetErrorString(last_error)); + } + return ffi::Error::Success(); +} + +// Creates symbol FooBwd with C linkage that can be loaded using Python ctypes +XLA_FFI_DEFINE_HANDLER_SYMBOL( + FooBwd, FooBwdHost, + ffi::Ffi::Bind() + .Ctx>() // stream + .Arg>() // c_grad + .Arg>() // a + .Arg>() // b_plus_1 + .Ret>() // a_grad + .Ret>() // b_grad + .Attr("n")); diff --git a/docs/debugging.md b/docs/debugging.md new file mode 100644 index 000000000000..7ee36f19f5bf --- /dev/null +++ b/docs/debugging.md @@ -0,0 +1,205 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.1 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +(debugging)= +# Introduction to debugging + + + +This section introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations. + +Let's begin with {func}`jax.debug.print`. + +## JAX `debug.print` for high-level + +**TL;DR** Here is a rule of thumb: + +- Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others. +- Use Python {func}`print` for static values, such as dtypes and array shapes. 
+ +Recall from {ref}`jit-compilation` that when transforming a function with {func}`jax.jit`, +the Python code is executed with abstract tracers in place of your arrays. Because of this, +the Python {func}`print` function will only print this tracer value: + +```{code-cell} +import jax +import jax.numpy as jnp + +@jax.jit +def f(x): + print("print(x) ->", x) + y = jnp.sin(x) + print("print(y) ->", y) + return y + +result = f(2.) +``` + +Python's `print` executes at trace-time, before the runtime values exist. +If you want to print the actual runtime values, you can use {func}`jax.debug.print`: + +```{code-cell} +@jax.jit +def f(x): + jax.debug.print("jax.debug.print(x) -> {x}", x=x) + y = jnp.sin(x) + jax.debug.print("jax.debug.print(y) -> {y}", y=y) + return y + +result = f(2.) +``` + +Similarly, within {func}`jax.vmap`, using Python's `print` will only print the tracer; +to print the values being mapped over, use {func}`jax.debug.print`: + +```{code-cell} +def f(x): + jax.debug.print("jax.debug.print(x) -> {}", x) + y = jnp.sin(x) + jax.debug.print("jax.debug.print(y) -> {}", y) + return y + +xs = jnp.arange(3.) + +result = jax.vmap(f)(xs) +``` + +Here's the result with {func}`jax.lax.map`, which is a sequential map rather than a +vectorization: + +```{code-cell} +result = jax.lax.map(f, xs) +``` + +Notice the order is different, as {func}`jax.vmap` and {func}`jax.lax.map` compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect. + +Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only prints the forward pass. In this case, the behavior is similar to Python's {func}`print`, but it's consistent if you apply {func}`jax.jit` during the call. + +```{code-cell} +def f(x): + jax.debug.print("jax.debug.print(x) -> {}", x) + return x ** 2 + +result = jax.grad(f)(1.) +``` + +Sometimes, when the arguments don't depend on one another, calls to {func}`jax.debug.print` may print them in a different order when staged out with a JAX transformation. If you need the original order, such as `x: ...` first and then `y: ...` second, add the `ordered=True` parameter. + +For example: + +```{code-cell} +@jax.jit +def f(x, y): + jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True) + jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True) + return x + y + +f(1, 2) +``` + +To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`. + + +## JAX `debug.breakpoint` for `pdb`-like debugging + +**TL;DR** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. + +To pause your compiled JAX program during certain points during debugging, you can use {func}`jax.debug.breakpoint`. The prompt is similar to Python `pdb`, and it allows you to inspect the values in the call stack. In fact, {func}`jax.debug.breakpoint` is an application of {func}`jax.debug.callback` that captures information about the call stack. + +To print all available commands during a `breakpoint` debugging session, use the `help` command. (Full debugger commands, the Sharp Bits, its strengths and limitations are covered in {ref}`advanced-debugging`.) + +Here is an example of what a debugger session might look like: + +```{code-cell} +:tags: [skip-execution] + +@jax.jit +def f(x): + y, z = jnp.sin(x), jnp.cos(x) + jax.debug.breakpoint() + return y * z +f(2.) 
# ==> Pauses during execution +``` + +![JAX debugger](_static/debugger.gif) + +For value-dependent breakpointing, you can use runtime conditionals like {func}`jax.lax.cond`: + +```{code-cell} +def breakpoint_if_nonfinite(x): + is_finite = jnp.isfinite(x).all() + def true_fn(x): + pass + def false_fn(x): + jax.debug.breakpoint() + jax.lax.cond(is_finite, true_fn, false_fn, x) + +@jax.jit +def f(x, y): + z = x / y + breakpoint_if_nonfinite(z) + return z + +f(2., 1.) # ==> No breakpoint +``` + +```{code-cell} +:tags: [skip-execution] + +f(2., 0.) # ==> Pauses during execution +``` + +## JAX `debug.callback` for more control during debugging + +Both {func}`jax.debug.print` and {func}`jax.debug.breakpoint` are implemented using +the more flexible {func}`jax.debug.callback`, which gives greater control over the +host-side logic executed via a Python callback. +It is compatible with {func}`jax.jit`, {func}`jax.vmap`, {func}`jax.grad` and other +transformations (refer to the {ref}`external-callbacks-flavors-of-callback` table in +{ref}`external-callbacks` for more information). + +For example: + +```{code-cell} +import logging + +def log_value(x): + logging.warning(f'Logged value: {x}') + +@jax.jit +def f(x): + jax.debug.callback(log_value, x) + return x + +f(1.0); +``` + +This callback is compatible with other transformations, including {func}`jax.vmap` and {func}`jax.grad`: + +```{code-cell} +x = jnp.arange(5.0) +jax.vmap(f)(x); +``` + +```{code-cell} +jax.grad(f)(1.0); +``` + +This can make {func}`jax.debug.callback` useful for general-purpose debugging. + +You can learn more about {func}`jax.debug.callback` and other kinds of JAX callbacks in {ref}`external-callbacks`. + +## Next steps + +Check out the {ref}`advanced-debugging` to learn more about debugging in JAX. diff --git a/docs/debugging/checkify_guide.md b/docs/debugging/checkify_guide.md index a804d36038a1..2dad9b863b06 100644 --- a/docs/debugging/checkify_guide.md +++ b/docs/debugging/checkify_guide.md @@ -1,5 +1,7 @@ # The `checkify` transformation + + **TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: ```python diff --git a/docs/debugging/flags.md b/docs/debugging/flags.md index 8434384c4eb5..1cf1829e5152 100644 --- a/docs/debugging/flags.md +++ b/docs/debugging/flags.md @@ -1,5 +1,7 @@ # JAX debugging flags + + JAX offers flags and context managers that enable catching errors more easily. ## `jax_debug_nans` configuration option and context manager @@ -12,14 +14,14 @@ JAX offers flags and context managers that enable catching errors more easily. 
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by: * setting the `JAX_DEBUG_NANS=True` environment variable; -* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file; -* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; +* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file; +* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; ### Example(s) ```python -from jax import config -config.update("jax_debug_nans", True) +import jax +jax.config.update("jax_debug_nans", True) def f(x, y): return x / y @@ -47,14 +49,14 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! You can disable JIT-compilation by: * setting the `JAX_DISABLE_JIT=True` environment variable; -* adding `from jax import config` and `config.update("jax_disable_jit", True)` near the top of your main file; -* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`; +* adding `jax.config.update("jax_disable_jit", True)` near the top of your main file; +* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`; ### Examples ```python -from jax import config -config.update("jax_disable_jit", True) +import jax +jax.config.update("jax_disable_jit", True) def f(x): y = jnp.log(x) diff --git a/docs/debugging/index.md b/docs/debugging/index.md index 9a020b360b81..b00fcc13d0a0 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -1,5 +1,7 @@ # Runtime value debugging in JAX + + Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the "Read more" links at the bottom to learn more. Table of contents: @@ -82,8 +84,8 @@ Click [here](checkify_guide) to learn more! **TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. ```python -from jax import config -config.update("jax_debug_nans", True) +import jax +jax.config.update("jax_debug_nans", True) def f(x, y): return x / y diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index 4bd5d03520c9..440cc38d99f0 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -1,5 +1,7 @@ # `jax.debug.print` and `jax.debug.breakpoint` + + The {mod}`jax.debug` package offers some useful tools for inspecting values inside of JIT-ted functions. @@ -150,7 +152,7 @@ def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> No return None ``` -As with `jax.debug.print`, these callbacks should only be used for debugging output, like printing or plotting. Printing and plotting are pretty harmless, but if you use it for anything else its behavior might surprise you under transformations. 
For example, it's not safe to use `jax.debug.callback` for timing operations, since callbacks might reordered and asynchronous (see below). +As with `jax.debug.print`, these callbacks should only be used for debugging output, like printing or plotting. Printing and plotting are pretty harmless, but if you use it for anything else its behavior might surprise you under transformations. For example, it's not safe to use `jax.debug.callback` for timing operations, since callbacks might be reordered and asynchronous (see below). ### Sharp bits Like most JAX APIs, `jax.debug.print` can cut you if you're not careful. diff --git a/docs/deprecation.md b/docs/deprecation.md index 5b69b7053fcd..7a8b867b6f2e 100644 --- a/docs/deprecation.md +++ b/docs/deprecation.md @@ -1,6 +1,8 @@ (version-support-policy)= # Python and NumPy version support policy + + For NumPy and SciPy version support, JAX follows the Python scientific community's [SPEC 0](https://scientific-python.org/specs/spec-0000/). @@ -11,12 +13,25 @@ nine months longer than SPEC-0 recommends. This means we support at least: -* All minor Python releases in the 45 months prior to each JAX release. +* All minor Python releases in the 45 months prior to each JAX release. For example: + + * **Python 3.9** was released October 2020, and will be supported in new JAX releases at least until **July 2024**. + * **Python 3.10** was released October 2021, and will be supported in new JAX releases at least until **July 2025**. + * **Python 3.11** was released October 2022, and will be supported in new JAX releases at least until **July 2026**. -* All minor NumPy releases in the 24 months prior to each JAX release. +* All minor NumPy releases in the 24 months prior to each JAX release. For example: + + * **NumPy 1.22** was released December 2021, and will be supported in new JAX releases at least until **December 2023**. + * **NumPy 1.23** was released June 2022, and will be supported in new JAX releases at least until **June 2024**. + * **NumPy 1.24** was released December 2022, and will be supported in new JAX releases at least until **December 2024**. * All minor SciPy releases in the 24 months prior to each JAX release, starting - with SciPy version 1.9 + with SciPy version 1.9. For example: + + * **Scipy 1.9** was released July 2022, and will be supported in new JAX releases at least until **July 2024**. + * **Scipy 1.10** was released January 2023, and will be supported in new JAX releases at least until **January 2025**. + * **Scipy 1.11** was released June 2023, and will be supported in new JAX releases at least until **June 2025**. -JAX releases may support older versions of Python, NumPy, and SciPy, but support -for older versions may be dropped at any time. +JAX releases may support older versions of Python, NumPy, and SciPy than strictly required +by this policy, but support for older versions may be dropped at any time beyond the listed +dates. diff --git a/docs/developer.md b/docs/developer.md index 34876b3b1293..018982f4c00d 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -1,6 +1,8 @@ (building-from-source)= # Building from source + + First, obtain the JAX source code: ``` @@ -44,21 +46,31 @@ To build `jaxlib` from source, you must also install some prerequisites: See below for Windows build instructions. -- Python packages: `numpy`, `wheel`, `build`. 
+- there is no need to install Python dependencies locally, as your system + Python will be ignored during the build; please check + [Managing hermetic Python](#managing-hermetic-python) for details. -You can install the necessary Python dependencies using `pip`: +To build `jaxlib` for CPU or TPU, you can run: ``` -pip install numpy wheel build +python build/build.py +pip install dist/*.whl # installs jaxlib (includes XLA) ``` -To build `jaxlib` for CPU or TPU, you can run: +To build a wheel for a version of Python different from your current system +installation pass `--python_version` flag to the build command: ``` -python build/build.py -pip install dist/*.whl # installs jaxlib (includes XLA) +python build/build.py --python_version=3.12 ``` +The rest of this document assumes that you are building for Python version +matching your current system installation. If you need to build for a different +version, simply append `--python_version=` flag every time you call +`python build/build.py`. Note, the Bazel build will always use a hermetic Python +installation regardless of whether the `--python_version` parameter is passed or +not. + There are two ways to build `jaxlib` with CUDA support: (1) use `python build/build.py --enable_cuda` to generate a jaxlib wheel with cuda support, or (2) use @@ -69,8 +81,11 @@ and jax-cuda-pjrt). You can set `gpu_plugin_cuda_version` to 11 or 12. See `python build/build.py --help` for configuration options, including ways to specify the paths to CUDA and CUDNN, which you must have installed. Here `python` should be the name of your Python 3 interpreter; on some systems, you -may need to use `python3` instead. By default, the wheel is written to the -`dist/` subdirectory of the current directory. +may need to use `python3` instead. Despite calling the script with `python`, +Bazel will always use its own hermetic Python interpreter and dependencies, only +the `build/build.py` script itself will be processed by your system Python +interpreter. By default, the wheel is written to the `dist/` subdirectory of the +current directory. ### Building jaxlib from source with a modified XLA repository. @@ -172,6 +187,200 @@ python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.7.0 \ --bazel_options=--override_repository=xla=/path/to/xla-rocm ``` +## Managing hermetic Python + +To make sure that JAX's build is reproducible, behaves uniformly across +supported platforms (Linux, Windows, MacOS) and is properly isolated from +specifics of a local system, we rely on hermetic Python (see +[rules_python](https://github.com/bazelbuild/rules_python)) for all build +and test commands executed via Bazel. This means that your system Python +installation will be ignored during the build and Python interpreter itself +as well as all the Python dependencies will be managed by bazel directly. + +### Specifying Python version + +When you run `build/build.py` tool, the version of hermetic Python is set +automatically to match the version of the Python you used to +run `build/build.py` script. To choose a specific version explicitly you may +pass `--python_version` argument to the tool: + +``` +python build/build.py --python_version=3.12 +``` + +Under the hood, the hermetic Python version is controlled +by `HERMETIC_PYTHON_VERSION` environment variable, which is set automatically +when you run `build/build.py`. 
In case you run bazel directly you may need to +set the variable explicitly in one of the following ways: + +``` +# Either add an entry to your `.bazelrc` file +build --repo_env=HERMETIC_PYTHON_VERSION=3.12 + +# OR pass it directly to your specific build command +bazel build --repo_env=HERMETIC_PYTHON_VERSION=3.12 + +# OR set the environment variable globally in your shell: +export HERMETIC_PYTHON_VERSION=3.12 +``` + +You may run builds and tests against different versions of Python sequentially +on the same machine by simply switching the value of `--python_version` between +the runs. All the python-agnostic parts of the build cache from the previous +build will be preserved and reused for the subsequent builds. + +### Specifying Python dependencies + +During bazel build all JAX's Python dependencies are pinned to their specific +versions. This is necessary to ensure reproducibility of the build. +The pinned versions of the full transitive closure of JAX's dependencies +together with their corresponding hashes are specified in +`build/requirements_lock_.txt` files ( +e.g. `build/requirements_lock_3_12.txt` for `Python 3.12`). + +To update the lock files, make sure `build/requirements.in` contains the desired +direct dependencies list and then execute the following command (which will call +[pip-compile](https://pypi.org/project/pip-tools/) under the hood): + +``` +python build/build.py --requirements_update --python_version=3.12 +``` + +Alternatively, if you need more control, you may run the bazel command +directly (the two commands are equivalent): + +``` +bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.12 +``` + +where `3.12` is the `Python` version you wish to update. + +Note, since it is still `pip` and `pip-compile` tools used under the hood, so +most of the command line arguments and features supported by those tools will be +acknowledged by the Bazel requirements updater command as well. For example, if +you wish the updater to consider pre-release versions simply pass `--pre` +argument to the bazel command: + +``` +bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.12 -- --pre +``` + +### Specifying dependencies on local wheels + +If you need to depend on a local .whl file, for example on your newly built +jaxlib wheel, you may add a path to the wheel in `build/requirements.in` and +re-run the requirements updater command for a selected version of Python. For +example: + +``` +echo -e "\n$(realpath jaxlib-0.4.27.dev20240416-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in +python build/build.py --requirements_update --python_version=3.12 +``` + +### Specifying dependencies on nightly wheels + +To build and test against the very latest, potentially unstable, set of Python +dependencies we provide a special version of the dependency updater command as +follows: + +``` +python build/build.py --requirements_nightly_update --python_version=3.12 +``` + +Or, if you run `bazel` directly (the two commands are equivalent): + +``` +bazel run //build:requirements_nightly.update --repo_env=HERMETIC_PYTHON_VERSION=3.12 +``` + +The difference between this and the regular updater is that by default it would +accept pre-release, dev and nightly packages, it will also +search https://pypi.anaconda.org/scientific-python-nightly-wheels/simple as an +extra index url and will not put hashes in the resultant requirements lock file. 
+ +### Building with pre-release Python version + +We support all of the current versions of Python out of the box, but if you need +to build and test against a different version (for example the latest unstable +version which hasn't been released officially yet) please follow the +instructions below. + +1) Make sure you have installed necessary linux packages needed to build Python + interpreter itself and key packages (like `numpy` or `scipy`) from source. On + a typical Debian system you may need to install the following packages: + +``` +sudo apt-get update +sudo apt-get build-dep python3 -y +sudo apt-get install pkg-config zlib1g-dev libssl-dev -y +# to build scipy +sudo apt-get install libopenblas-dev -y +``` + +2) Check your `WORKSPACE` file and make sure it + has `custom_python_interpreter()` entry there, pointing to the version of + Python you want to build. + +3) Run `bazel build @python_dev//:python_dev` to build Python interpreter. By default it will + be built with GCC compiler. If you wish to build with clang, you need to set + corresponding env variables to do so ( + e.g. `--repo_env=CC=/usr/lib/llvm-17/bin/clang --repo_env=CXX=/usr/lib/llvm-17/bin/clang++`). + +4) Check the output of the previous command. At the very end of it you will find + a code snippet for `python_register_toolchains()` entry with your newly built + Python in it. Copy that code snippet in your `WORKSPACE` file either right + after `python_init_toolchains()` entry (to add the new version of Python) or + instead of it (to replace an existing version, like replacing 3.12 with + custom built variant of 3.12). The code snippet is generated to match your + actual setup, so it should work as is, but you can customize it if you choose + so (for example to change location of Python's `.tgz` file so it could be + downloaded remotely instead of being on local machine). + +5) Make sure there is an entry for your Python's version in `requirements` + parameter for `python_init_repositories()` in your WORKSPACE file. For + example for `Python 3.13` it should have something + like `"3.13": "//build:requirements_lock_3_13.txt"`. + +6) For unstable versions of Python, optionally (but highly recommended) + run `bazel build //build:all_py_deps --repo_env=HERMETIC_PYTHON_VERSION="3.13"`, + where `3.13` is the version of Python interpreter you built on step 3. + This will make `pip` pull and build from sources (for packages which don't + have binaries published yet, for + example `numpy`, `scipy`, `matplotlib`, `zstandard`) all of the JAX's python + dependencies. It is recommended to do this step first (i.e. independently of + actual JAX build) for all unstable versions of Python to avoid conflict + between building JAX itself and building of its Python dependencies. For + example, we normally build JAX with clang but building `matplotlib` from + sources with clang fails out of the box due to differences in LTO behavior ( + Link Time Optimization, triggered by `-flto` flag) between GCC and clang, and + matplotlib assumes GCC by default. + If you build against a stable version of Python, or in general you do not + expect any of your Python dependencies to be built from sources (i.e. binary + distributions for the corresponding Python version already exist in the + repository) this step is not needed. + +7) Congrats, you've built and configured your custom Python for JAX project! 
You + may now execute your built/test commands as usual, just make + sure `HERMETIC_PYTHON_VERSION` environment variable is set and points to your + new version. + +8) Note, if you were building a pre-release version of Python, updating of + `requirements_lock_.txt` files with your newly built Python + is likely to fail, because package repositories will not have matching + binary packages. When there are no binary packages available `pip-compile` + proceeds with building them from sources, which is likely to fail because it + is more restrictive than doing the same thing during `pip` installation. + The recommended way to update requirements lock file for unstable versions of + Python is to update requirements for the latest stable version (e.g. `3.12`) + without hashes (therefore special `//build:requirements_dev.update` target) + and then copy the results to the unstable Python's lock file (e.g. `3.13`): +``` +bazel run //build:requirements_dev.update --repo_env=HERMETIC_PYTHON_VERSION="3.12" +cp build/requirements_lock_3_12.txt build/requirements_lock_3_13.txt +bazel build //build:all_py_deps --repo_env=HERMETIC_PYTHON_VERSION="3.13" +# You may need to edit manually the resultant lock file, depending on how ready +# your dependencies are for the new version of Python. +``` ## Installing `jax` @@ -190,8 +399,6 @@ sets up symbolic links from site-packages into the repository. ## Running the tests -First, install the dependencies by running `pip install -r build/test-requirements.txt`. - There are two supported mechanisms for running the JAX tests, either using Bazel or using pytest. @@ -206,16 +413,33 @@ python build/build.py --configure_only You may pass additional options to `build.py` to configure the build; see the `jaxlib` build documentation for details. -By default the Bazel build runs the JAX tests using `jaxlib` built form source. +By default the Bazel build runs the JAX tests using `jaxlib` built from source. To run JAX tests, run: ``` bazel test //tests:cpu_tests //tests:backend_independent_tests ``` -`//tests:gpu_tests` and `//tests:tpu_tests` are also available, if you have the necessary hardware. +`//tests:gpu_tests` and `//tests:tpu_tests` are also available, if you have the +necessary hardware. -To use a preinstalled `jaxlib` instead of building `jaxlib` from source, run +To use a preinstalled `jaxlib` instead of building it you first need to +make it available in the hermetic Python. To install a specific version of +`jaxlib` within hermetic Python run (using `jaxlib >= 0.4.26` as an example): + +``` +echo -e "\njaxlib >= 0.4.26" >> build/requirements.in +python build/build.py --requirements_update +``` + +Alternatively, to install `jaxlib` from a local wheel (assuming Python 3.12): + +``` +echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in +python build/build.py --requirements_update --python_version=3.12 +``` + +Once you have `jaxlib` installed hermetically, run: ``` bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests @@ -225,13 +449,16 @@ A number of test behaviors can be controlled using environment variables (see below). Environment variables may be passed to JAX tests using the `--test_env=FLAG=value` flag to Bazel. -Some of JAX tests are for multiple accelerators (i.e. GPUs, TPUs). When JAX is already installed, you can run GPUs tests like this: +Some of JAX tests are for multiple accelerators (i.e. GPUs, TPUs). 
When JAX is +already installed, you can run GPUs tests like this: ``` bazel test //tests:gpu_tests --local_test_jobs=4 --test_tag_filters=multiaccelerator --//jax:build_jaxlib=false --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform ``` -You can speed up single accelerator tests by running them in parallel on multiple accelerators. This also triggers multiple concurrent tests per accelerator. For GPUs, you can do it like this: +You can speed up single accelerator tests by running them in parallel on +multiple accelerators. This also triggers multiple concurrent tests per +accelerator. For GPUs, you can do it like this: ``` NB_GPUS=2 @@ -241,10 +468,9 @@ MULTI_GPU="--run_under $PWD/build/parallel_accelerator_execute.sh --test_env=JAX bazel test //tests:gpu_tests //tests:backend_independent_tests --test_env=XLA_PYTHON_CLIENT_PREALLOCATE=false --test_tag_filters=-multiaccelerator $MULTI_GPU ``` -Some test targets, like a `//tests:logpcg_tests` optionally use matplotlib, so you may need to `pip -install matplotlib` to run tests via bazel. - ### Using `pytest` +First, install the dependencies by +running `pip install -r build/test-requirements.txt`. To run all the JAX tests using `pytest`, we recommend using `pytest-xdist`, which can run tests in parallel. It is installed as a part of @@ -418,7 +644,7 @@ notebooks; for example: ``` pip install jupytext==1.16.0 -jupytext --sync docs/notebooks/quickstart.ipynb +jupytext --sync docs/notebooks/thinking_in_jax.ipynb ``` The jupytext version should match that specified in diff --git a/docs/device_memory_profiling.md b/docs/device_memory_profiling.md index 57b2894a53a0..e4d871b780f3 100644 --- a/docs/device_memory_profiling.md +++ b/docs/device_memory_profiling.md @@ -1,5 +1,6 @@ # Device Memory Profiling + ```{note} May 2023 update: we recommend using [Tensorboard @@ -59,7 +60,7 @@ def func2(x): y = func1(x) return y, jnp.tile(x, 10) + 1 -x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000)) +x = jax.random.normal(jax.random.key(42), (1000, 1000)) y, z = func2(x) z.block_until_ready() @@ -107,14 +108,14 @@ import jax.numpy as jnp import jax.profiler def afunction(): - return jax.random.normal(jax.random.PRNGKey(77), (1000000,)) + return jax.random.normal(jax.random.key(77), (1000000,)) z = afunction() def anotherfunc(): arrays = [] for i in range(1, 10): - x = jax.random.normal(jax.random.PRNGKey(42), (i, 10000)) + x = jax.random.normal(jax.random.key(42), (i, 10000)) arrays.append(x) x.block_until_ready() jax.profiler.save_device_memory_profile(f"memory{i}.prof") diff --git a/docs/distributed_data_loading.md b/docs/distributed_data_loading.md new file mode 100644 index 000000000000..70cbd26baa5c --- /dev/null +++ b/docs/distributed_data_loading.md @@ -0,0 +1,466 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.1 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +# Distributed data loading in a multi-host/multi-process environment + + + +This high-level guide demonstrates how you can perform distributed data loading — when you run JAX in a {doc}`multi-host or multi-process environment <./multi_process>`, and the data required for the JAX computations is split across the multiple processes. This document covers the overall approach for how to think about distributed data loading, and then how to apply it to *data-parallel* (simpler) and *model-parallel* (more complicated) workloads. 
+
+Distributed data loading is usually more efficient (the data is split across
+processes) but also *more complex* compared with its alternatives, such as:
+1) loading the *full global data in a single process*, splitting it up and
+sending the needed parts to the other processes via RPC; and 2) loading the
+*full global data in all processes* and only using the needed parts in each
+process. Loading the full global data is often simpler but more expensive.
+For example, in machine learning the training loop can get blocked while
+waiting for data, and additional network bandwidth gets used per process.
+
+```{note}
+When using distributed data loading, it's important that each device (for
+example, each GPU or TPU) has access to the input data shard(s) that it needs
+to run the computation. This is what usually makes distributed data loading
+more complicated and challenging to implement correctly (compared with the
+alternatives described above). If the incorrect data shards end up on the
+wrong devices, the computation can still run without errors, since the
+computation has no way to know what the input data "should" be. However, the
+final result will often be incorrect, since the input data was different than
+intended.
+```
+
+## General approach for loading a `jax.Array`
+
+Consider a case of creating a single {class}`jax.Array` from raw data not
+produced by JAX. These concepts apply beyond loading batched data records, to
+any multi-process {class}`jax.Array` that wasn't directly produced by a JAX
+computation. Examples include: 1) loading model weights from a checkpoint; or
+2) loading a large spatially-sharded image.
+
+Every {class}`jax.Array` has an associated {mod}`~jax.sharding.Sharding`, which
+describes which shard of the global data is required by each global device.
+When you create a {class}`jax.Array` from scratch, you also need to create its
+`Sharding`. This is how JAX can understand how the data is laid out across
+devices. You can create whatever `Sharding` you want. In practice, you usually
+pick a `Sharding` based on what kind of parallelism strategy you are
+implementing (you will learn about data and model parallelism in more detail
+later in this guide). You can also pick a `Sharding` based on how the raw data
+will be produced within each process.
+
+Once you have defined a {mod}`~jax.sharding.Sharding`, you can use its
+{func}`~jax.sharding.Sharding.addressable_devices()` to get the list of
+devices within the current process for which data needs to be loaded. (Note:
+The term "addressable devices" is a more general version of "local devices".)
+The goal is to make sure that each process's data loader provides the right
+data to all of that process's local devices.
+
+### Examples
+
+For example, consider a `(64, 128)` {class}`jax.Array` that you need to shard
+across 4 processes with 2 devices each (8 devices total). This will result in
+8 unique data shards, one for each device. There are many ways to shard this
+{class}`jax.Array`. You can perform a 1D sharding across the second dimension
+of the {class}`jax.Array`, giving each device a `(64, 16)` shard, as
+demonstrated below:
+
+
+ +![8 unique data shards](./_static/distributed_data_loading/1.svg) + +
+ +In the above figure, each data shard has its own color to indicate which process needs to load that shard. For example, you assume process `0`'s 2 devices contain shards `A` and `B`, corresponding to the first `(64, 32)` piece of the global data. + +You can pick a different distribution of shards to devices. For example: + +
+ +![8 unique data shards - different distribution](./_static/distributed_data_loading/2.svg) + +
+ +Here is another example — a 2D sharding: + +
+ +![2D sharding](./_static/distributed_data_loading/3.svg) + +
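+
+For concreteness, here is a minimal sketch of constructing the 1D sharding
+pictured in the first example and asking it which devices in the current
+process need data. It assumes the 4-process, 2-devices-per-process setup
+described above (8 devices total); the mesh axis name `data` is arbitrary:
+
+```python
+import jax
+import numpy as np
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+
+global_shape = (64, 128)
+
+# One mesh axis over all 8 devices; shard the second array dimension over it.
+mesh = Mesh(np.array(jax.devices()), ("data",))
+sharding = NamedSharding(mesh, P(None, "data"))  # each device gets a (64, 16) shard
+
+# The devices in this process that need a shard of the global data...
+print(sharding.addressable_devices)
+# ...and (one way to inspect it) the slice of the global data each of them needs.
+print(sharding.addressable_devices_indices_map(global_shape))
+```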
+ +However the {class}`jax.Array` happens to be sharded, you have to make sure that each process's data loader is provided with/loads the required shard(s) of the global data. There are several high-level methods for achieving this: 1) load the global data in each process; 2) use a per-device data pipeline; 3) use a consolidated per-process data pipeline; 4) load data in some convenient way and then reshard inside computation. + +### Option 1: Load the global data in each process + +
+ +![Loading the global data in each process](./_static/distributed_data_loading/4.svg) + +
+ +Using this option, each process: + +1) Loads the full value needed; and +2) Transfers only the needed shards to that process's local devices. + +This is not an efficient approach to distributed data loading, since each process will throw away the data not needed by its local devices, and the total data ingested can be higher than necessary. But this option works and is relatively simple to implement, while the performance overhead may be acceptable for certain workloads (for example, if the global data is small). + +### Option 2: Use a per-device data pipeline + +
+ +![Using a per-device data pipeline](./_static/distributed_data_loading/5.svg) + +
+ +In this option, each process sets up a data loader for each of its local devices (that is, each device gets its own data loader for just the data shard it requires). + +This is efficient in terms of the data loaded. It can also sometimes be simpler to consider each device independently rather than all of a process's local devices at once (refer to _Option 3: Use a consolidated per-process data pipeline_ below). However, having multiple concurrent data loaders can sometimes cause performance issues. + +### Option 3: Use a consolidated per-process data pipeline + +
+
+![Using a consolidated per-process data pipeline](./_static/distributed_data_loading/6.svg)
+
+
+
+If you choose this option, each process:
+
+1) Sets up a single data loader that loads the data required for all of its
+   local devices; and then
+2) Shards the local data before transferring to each local device.
+
+This is the *most efficient way to do distributed loading*. However, it's also
+the *most complex*, since logic is needed both to figure out which data is
+needed by each device, and to create a single data loader that loads all of
+that data (and, ideally, no other extra data).
+
+### Option 4: Load data in some convenient way, reshard inside computation
+
+
+
+![Loading data in some convenient way, reshard inside computation](./_static/distributed_data_loading/7.svg)
+
+
+
+This option is more challenging to explain, but often easier to implement than
+the above options (1 to 3).
+
+Imagine a scenario where it's difficult or even impossible to set up data
+loaders that load exactly the data you need, either per device or per process.
+However, it may still be possible to set up a data loader per process that
+loads `1 / num_processes` of the data, just not in the right sharding.
+
+Continuing with the 2D example sharding from before, assume it is easier for
+each process to load a single column of the data.
+
+You can then create a {class}`jax.Array` with a {mod}`~jax.sharding.Sharding`
+representing the per-column data, pass that directly into the computation, and
+use {func}`jax.lax.with_sharding_constraint` to immediately reshard the
+column-sharded input to the desired sharding. Since the data is resharded
+inside the computation, it will be resharded over the accelerator
+communication links (for example, TPU ICI or NVLink).
+
+Option 4 has similar benefits to Option 3 (_Use a consolidated per-process
+data pipeline_):
+
+- Each process still has a single data loader;
+- The global data is loaded exactly once across all processes; and
+- This option offers the additional benefit of more flexibility in how the
+  data is loaded.
+
+However, this approach uses accelerator interconnect bandwidth to perform the
+resharding, which may slow down certain workloads. Option 4 also requires that
+the input data be expressed as a separate `Sharding`, in addition to the
+target `Sharding`.
+
+## Replication
+
+Replication describes a process where multiple devices have the same data
+shard. The general options mentioned above (Options 1 through 4) still work
+with replication. The only difference is that some processes may end up
+loading the same data shards. This section describes full replication and
+partial replication.
+
+### Full replication
+
+**Full replication** is a process where all devices have a full copy of the
+data (that is, the data "shard" is the entire array value).
+
+In the example below, since there are 8 devices in total (2 per process), you
+will end up with 8 copies of the full data. Each copy of the data is
+unsharded, that is, the copy lives on a single device:
+
+
+ +![Full replication](./_static/distributed_data_loading/8.svg) + +
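+
+As a hedged sketch of what full replication can look like in code (assuming
+every process already holds the complete value in host memory, and using the
+same `jax.make_array_from_single_device_arrays` helper as the examples later
+in this guide):
+
+```python
+import jax
+import numpy as np
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+
+data = np.arange(64 * 128, dtype=np.float32).reshape(64, 128)  # full copy in every process
+
+mesh = Mesh(np.array(jax.devices()), ("x",))
+replicated = NamedSharding(mesh, P())  # no partitioning -> fully replicated
+
+arr = jax.make_array_from_single_device_arrays(
+    data.shape, replicated,
+    # Every addressable device receives the complete value.
+    [jax.device_put(data, d) for d in replicated.addressable_devices])
+
+assert arr.shape == data.shape
+```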
+ +### Partial replication + +**Partial replication** describes a process where there are multiple copies of the data, and each copy is sharded across multiple devices. For a given array value, there are generally many possible ways to perform partial replication (Note: There is always a single fully-replicated {mod}`~jax.sharding.Sharding` for a given array shape). + +Below are two possible examples. + +In the first example below, each copy is sharded across the two local devices of a process, for a total of 4 copies. This means that each process will need to load the full global data, since its local devices will have a full copy of the data. + +
+ +![Partial replication - example 1](./_static/distributed_data_loading/9.svg) + +
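+
+A minimal sketch of a `Sharding` that matches the first layout above (it
+assumes the device order returned by `jax.devices()` groups each process's 2
+devices together; the axis names are arbitrary):
+
+```python
+import jax
+import numpy as np
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+
+# 4 copies (one per process), each sharded over that process's 2 local devices.
+mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ("replica", "shard"))
+
+# Shard the second array dimension over the in-replica axis only; the data is
+# then replicated across the "replica" axis, giving 4 partially sharded copies.
+sharding = NamedSharding(mesh, P(None, "shard"))
+```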
+ +In the second example below, each copy is still sharded across two devices, but each device pair is spread across two different processes. Process `0` (pink) and process `1` (yellow) both need to load just the first row of the data, and process `2` (green) and process `3` (blue) both need to load just the second row of the data: + +
+ +![Partial replication - example 2](./_static/distributed_data_loading/10.svg) + +
+ +Now that you've gone over the high-level options for creating a {class}`jax.Array`, let's apply them to data loading for ML applications. + +## Data parallelism + +In *pure data parallelism* (without model parallelism): + +- You replicate the model on each device; and +- Each model replica (that is, each device) receives a different per-replica batch of data. + +
+ +![Data parallelism - example 1](./_static/distributed_data_loading/11.svg) + +
+ +When representing the input data as a single {class}`jax.Array`, the Array contains the data across all replicas for this step (this is called *global batch*), with each shard of the {class}`jax.Array` containing a single per-replica batch. You can represent this as a 1D sharding across all devices (check the example below) — in other words, the global batch is composed of all the per-replica batches concatenated together across the batch axis. + +
+ +![Data parallelism - example 2](./_static/distributed_data_loading/12.svg) + +
+ +Applying this framework, you may conclude that process `0` should get the first quarter (2 out of 8) of the global batch, while process `1` should get the second, and so on. + +But how can you know what the first quarter is? And how do you make sure process `0` gets the first quarter? Luckily, there's a very important trick about data parallelism that means you don't have to answer these questions and makes the whole setup simpler. + +## Important trick about data parallelism + +The trick is you don't need to care which per-replica batch lands on which replica. Therefore, it doesn't matter which process loads a batch. The reason is that since each device corresponds to a model replica performing the same thing, it doesn't matter which device gets which per-replica batch within the global batch. + +What this means is that you are free to rearrange the per-replica batches within the global batch. In other words, you are free to randomize which data shard each device gets. + +For example: + +
+ +![Data parallelism - example 3](./_static/distributed_data_loading/13.svg) + +
+ +Usually, rearranging the data shards of a {class}`jax.Array`, as demonstrated above, is not a good idea – you're effectively permuting the value of the {class}`jax.Array`! However, for data parallelism, the global batch order isn't meaningful, and you are free to rearrange the per-replica batches in the global batch, as already mentioned before. + +This simplifies data loading because it means each device just needs an independent stream of per-replica batches, which can be easily implemented in most data loaders by creating an independent pipeline per process and chunking the resulting per-process batch into per-replica batches. + +
+
+![Data parallelism - example 4](./_static/distributed_data_loading/14.svg)
+
+
+
+This is an instance of _Option 3: Use a consolidated per-process data
+pipeline_. You can also use other options (such as 1, 2 and 4, which are
+covered earlier in this document), but this one is relatively simple and
+efficient.
+
+Here's an example of how to implement this setup using tf.data:
+
+```{code-cell}
+import jax
+import tensorflow as tf
+import numpy as np
+
+################################################################################
+# Step 1: setup the Dataset for pure data parallelism (do once)
+################################################################################
+# Fake example data (replace with your Dataset)
+ds = tf.data.Dataset.from_tensor_slices(
+    [np.ones((16, 3)) * i for i in range(100)])
+
+ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())
+
+################################################################################
+# Step 2: create a jax.Array of per-replica batches from the per-process batch
+# produced from the Dataset (repeat every step). This can be used with batches
+# produced by different data loaders as well!
+################################################################################
+# Grab just the first batch from the Dataset for this example
+per_process_batch = ds.as_numpy_iterator().next()
+
+per_process_batch_size = per_process_batch.shape[0]  # adjust if your batch dim
+                                                     # isn't 0
+
+per_replica_batch_size = per_process_batch_size // jax.local_device_count()
+assert per_process_batch_size % per_replica_batch_size == 0, \
+  "This example doesn't implement padding."
+per_replica_batches = np.split(per_process_batch, jax.local_device_count())
+
+# Thanks to the very important trick about data parallelism, no need to care what
+# order the devices appear in the sharding.
+sharding = jax.sharding.PositionalSharding(jax.devices())
+# PositionalSharding must have same rank as data being sharded.
+sharding = sharding.reshape((jax.device_count(),) +
+                            (1,) * (per_process_batch.ndim - 1))
+
+global_batch_size = per_replica_batch_size * jax.device_count()
+global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:])
+
+global_batch_array = jax.make_array_from_single_device_arrays(
+    global_batch_shape, sharding,
+    # Thanks again to the very important trick, no need to care which device gets
+    # which per-replica batch.
+    arrays=[jax.device_put(batch, device)
+            for batch, device
+            in zip(per_replica_batches, sharding.addressable_devices)])
+
+assert global_batch_array.shape == global_batch_shape
+assert (global_batch_array.addressable_shards[0].data.shape ==
+        per_replica_batches[0].shape)
+```
+
+## Data + model parallelism
+
+In **model parallelism** you shard each model replica across multiple devices. If you use **pure model parallelism** (without data parallelism):
+
+- There's just one model replica sharded across all devices; and
+- The data is (usually) fully replicated across all devices.
+
+This guide considers a case where you use **both data and model parallelism**:
+
+- You shard each of the multiple model replicas over multiple devices; and
+- You partially replicate the data over each model replica — each device in the same model replica gets the same per-replica batch, and devices across model replicas get different per-replica batches.
+
+### Model parallelism within a process
+
+For the purposes of data loading, the simplest approach can be to shard each model replica within the local devices of a single process.
+ +For this example, let's switch to 2 processes with 4 devices each (instead of 4 processes with 2 devices each). Consider a scenario where each model replica is sharded over the 2 local devices of a single process. This results in 2 model replicas per process and 4 model replicas total, as demonstrated below: + +
+ +![Data and model parallelism - example 1](./_static/distributed_data_loading/15.svg) + +
+ +Here, once again, the input data is represented as a single {class}`jax.Array` with a 1D sharding where each shard is a per-replica batch with an exception: + +- Unlike in the pure data parallelism case, you introduce partial replication and make 2 copies of the 1D-sharded global batch. +- This is because each model replica is composed of 2 devices that each need a copy of the per-replica batch. + +
+ +![Data and model parallelism - example 2](./_static/distributed_data_loading/16.svg) + +
+ +Keeping each model replica within a single process can make things simpler because you can reuse the pure data parallelism setup described above, except you also need to replicate the per-replica batches: + +
+ +![Data and model parallelism - example 3](./_static/distributed_data_loading/17.svg) + +
+ +```{note} +_It's also very important to replicate the per-replica batches to the correct devices!_ While the very important trick about data parallelism means you don't care which batch ends up on which replica, *you do care that a single replica only gets a single batch*. +``` + +For example, this is OK: + +
+ +![Data and model parallelism - example 4](./_static/distributed_data_loading/18.svg) + +
+ +However, if you’re not careful about which local device you load each batch onto, you may accidentally create unreplicated data, even though the {mod}`~jax.sharding.Sharding` (and the parallelism strategy) says the data is replicated: + +
+ +![Data and model parallelism - example 4](./_static/distributed_data_loading/19.svg) + +
+ +JAX will raise an error if you accidentally create a {class}`jax.Array` with unreplicated data that should be replicated within a single process (this isn't always true for model parallelism across processes though; see the next section). + +Here's an example of how to implement per-process model parallelism and data parallelism using `tf.data`: + +```{code-cell} +import jax +import tensorflow as tf +import numpy as np + +################################################################################ +# Step 1: Set up the Dataset with a different data shard per-process (do once) +# (same as for pure data parallelism) +################################################################################ +# Fake example data (replace with your Dataset) +per_process_batches = [np.ones((16, 3)) * i for i in range(100)] +ds = tf.data.Dataset.from_tensor_slices(per_process_batches) + +ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index()) + +################################################################################ +# Step 2: Create a jax.Array of per-replica batches from the per-process batch +# produced from the Dataset (repeat every step) +################################################################################ +# Grab just the first batch from the Dataset for this example +per_process_batch = ds.as_numpy_iterator().next() + +num_model_replicas_per_process = 2 # set according to your parallelism strategy +num_model_replicas_total = num_model_replicas_per_process * jax.process_count() + +per_process_batch_size = per_process_batch.shape[0] # adjust if your batch dim + # isn't 0 + +per_replica_batch_size = (per_process_batch_size // + num_model_replicas_per_process) +assert per_process_batch_size % per_replica_batch_size == 0, \ + "This example doesn't implement padding." +per_replica_batches = np.split(per_process_batch, + num_model_replicas_per_process) + +# Create an example `Mesh` for per-process data parallelism. Make sure all devices +# are grouped by process, and then resize so each row is a model replica. +mesh_devices = np.array([jax.local_devices(process_idx) + for process_idx in range(jax.process_count())]) +mesh_devices = mesh_devices.reshape(num_model_replicas_total, -1) +# Double check that each replica's devices are on a single process. +for replica_devices in mesh_devices: + num_processes = len(set(d.process_index for d in replica_devices)) + assert num_processes == 1 +mesh = jax.sharding.Mesh(mesh_devices, ["model_replicas", "data_parallelism"]) + +# Shard the data across model replicas. You don't shard across the +# data_parallelism mesh axis, meaning each per-replica shard will be replicated +# across that axis. +sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("model_replicas")) + +global_batch_size = per_replica_batch_size * num_model_replicas_total +global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:]) + +# Create the final jax.Array using jax.make_array_from_callback. The callback +# will be called for each local device, and passed the N-D numpy-style index +# that describes what shard of the global data that device should receive. +# +# You don't need care exactly which index is passed in due to the very important data +# parallelism, but you do use the index argument to make sure you replicate each +# per-replica batch correctly -- the `index` argument will be the same for +# devices in the same model replica, and different for devices in different +# model replicas. 
+ +index_to_batch = {} +def callback(index: tuple[slice, ...]) -> np.ndarray: + # Python `slice` objects aren't hashable, so manually create dict key. + index_key = tuple((slice_.start, slice_.stop) for slice_ in index) + if index_key not in index_to_batch: + # You don't care which per-replica batch goes to which replica, just take the + # next unused one. + index_to_batch[index_key] = per_replica_batches[len(index_to_batch)] + return index_to_batch[index_key] + +global_batch_array = jax.make_array_from_callback( + global_batch_shape, sharding, callback) + +assert global_batch_array.shape == global_batch_shape +assert (global_batch_array.addressable_shards[0].data.shape == + per_replica_batches[0].shape) +``` + +### Model parallelism across processes + +It can get more interesting when model replicas are spread across processes, either: + +- Because a single replica can't fit within a process; or +- Because the device assignment just isn't set up that way. + +For example, going back to the previous setup of 4 processes with 2 devices each, if you assign devices to replicas like so: + +
+ +![Model parallelism across processes - example 1](./_static/distributed_data_loading/20.svg) + +
+ +This is the same parallelism strategy as the previous per-process model parallelism example – 4 model replicas each sharded across 2 devices. The only difference is the device assignment – each replica's two devices are split across different processes, and each process is only responsible for one copy of each per-replica batch (but for two replicas). + +Splitting the model replicas across processes like this may seem like an arbitrary and unnecessary thing to do (and in this example it arguably is), but actual deployments may end up with this kind of device assignment to best take advantage of the communication links between devices. + +Data loading now becomes more complicated because some extra coordination is required across processes. In the pure data parallelism and per-process model parallelism cases, it was only important that each process loaded a unique data stream. Now certain processes must load the same data, and some must load different data. In the above example, processes `0` and `2` (in colors pink and green, respectively) must load the same 2 per-replica batches, and processes `1` and `3` (colors yellow and blue, respectively) must also load the same 2 per-replica batches (but different from process `0` and `2`'s batches). + +Furthermore, it's important that each process doesn't mix up its 2 per-replica batches. While you don't care which batch lands on which replica (the very important trick about data parallelism), you need to care that all the devices in a replica get the same batch. For example, this would be bad: + +
+ +![Model parallelism across processes - example 2](./_static/distributed_data_loading/21.svg) + +
+ +```{note} +As of August 2023, JAX cannot detect if {class}`jax.Array` shards across processes are supposed to be replicated but aren't, and will produce wrong results when the computation is run. So be careful not to do this! +``` + +To get the correct per-replica batch on each device, you need to represent the global input data as the following {class}`jax.Array`: + +
+ +![Model parallelism across processes - example 3](./_static/distributed_data_loading/22.svg) + +
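+
+As a rough, illustrative sketch (not part of the example above): you can ask
+the target `Sharding` which per-replica batches the current process must load
+when replicas span processes. The mesh construction below is just one possible
+device assignment and assumes 4 processes with 2 devices each:
+
+```python
+import jax
+import numpy as np
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+
+per_replica_batch_size = 16
+global_batch_shape = (per_replica_batch_size * 4, 3)  # 4 model replicas
+
+# Pair devices from *different* processes into each model replica
+# (one illustrative assignment; real deployments derive this from the topology).
+devices = np.array([jax.local_devices(p) for p in range(jax.process_count())])
+mesh_devices = np.stack([devices[0::2].ravel(), devices[1::2].ravel()], axis=1)
+mesh = Mesh(mesh_devices, ("model_replicas", "data_parallelism"))
+sharding = NamedSharding(mesh, P("model_replicas"))
+
+# For each local device, the slice of the global batch it needs. Devices in the
+# same replica map to the same slice, which tells this process's data loader
+# exactly which per-replica batches to produce (possibly the same batches that
+# another process also loads).
+needed = sharding.addressable_devices_indices_map(global_batch_shape)
+print({d.id: idx[0] for d, idx in needed.items()})
+```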
diff --git a/docs/errors.rst b/docs/errors.rst index 4c76f5dcf5a1..23dbaf29c46f 100644 --- a/docs/errors.rst +++ b/docs/errors.rst @@ -7,6 +7,7 @@ along with representative examples of how one might fix them. .. currentmodule:: jax.errors .. autoclass:: ConcretizationTypeError +.. autoclass:: KeyReuseError .. autoclass:: NonConcreteBooleanIndexError .. autoclass:: TracerArrayConversionError .. autoclass:: TracerBoolConversionError diff --git a/docs/export/export.md b/docs/export/export.md new file mode 100644 index 000000000000..3a176a5e1bf6 --- /dev/null +++ b/docs/export/export.md @@ -0,0 +1,799 @@ + + +# Exporting and serializing staged-out computations + +The {ref}`ahead-of-time-lowering` APIs produce +objects that can be used for debugging or for compilation and +execution in the same process. +Sometimes you want to serialize a lowered JAX function for +compilation and execution in a separate process, perhaps +at a later time. This would allow you to: + + * compile and execute the function in another process or machine + without requiring access to the JAX program, + and without having to repeat the staging-out and lowering, e.g., + in an inference system. + * trace and lower a function on a machine that does not have access + to the accelerator for which you want to later compile and execute + the function. + * archive a snapshot of a JAX function, e.g., to be able to + reproduce later your results. **Note:** check out the [compatibility + guarantees](#compatibility-guarantees) for this use case. + +For more details see the {mod}`jax.export` API reference. + +Here is an example: + +```python +>>> import re +>>> import numpy as np +>>> import jax +>>> from jax import export + +>>> def f(x): return 2 * x * x + + +>>> exported: export.Exported = export.export(jax.jit(f))( +... jax.ShapeDtypeStruct((), np.float32)) + +>>> # You can inspect the Exported object +>>> exported.fun_name +'f' + +>>> exported.in_avals +(ShapedArray(float32[]),) + +>>> print(re.search(r".*@main.*", exported.mlir_module()).group(0)) + func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"} loc("x")) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { + +>>> # And you can serialize the Exported to a bytearray. +>>> serialized: bytearray = exported.serialize() + +>>> # The serialized function can later be rehydrated and called from +>>> # another JAX computation, possibly in another process. +>>> rehydrated_exp: export.Exported = export.deserialize(serialized) +>>> rehydrated_exp.in_avals +(ShapedArray(float32[]),) + +>>> def callee(y): +... return 3. * rehydrated_exp.call(y * 4.) + +>>> callee(1.) +Array(96., dtype=float32) + +``` + +Serialization is broken down into two stages: + 1. exporting to produce an {class}`jax.export.Exported` object that contains + the StableHLO for the lowered function along with the metadata necessary to + call it from another JAX function. We have plans to add code to generate + `Exported` objects from TensorFlow, and to use `Exported` objects from + TensorFlow and PyTorch. + 2. the actual serialization to a byte array using the flatbuffers format. + See {ref}`jax2tf` for + an alternative serialization to TensorFlow graph that can be used + for interoperation with TensorFlow. + +## Support for reverse-mode AD + +Serialization can optionally support higher-order reverse-mode AD. 
This is done +by serializing the {func}`jax.vjp` of the primal function along with the primal function, +up to a user-specified order (default is 0, meaning that the rehydrated +function cannot be differentiated): + +```python +>>> import jax +>>> from jax import export +>>> from typing import Callable + +>>> def f(x): return 7 * x * x * x + +>>> # Serialize 3 levels of VJP along with the primal function +>>> blob: bytearray = export.export(jax.jit(f))(1.).serialize(vjp_order=3) +>>> rehydrated_f: Callable = export.deserialize(blob).call + +>>> rehydrated_f(0.1) # 7 * 0.1^3 +Array(0.007, dtype=float32) + +>>> jax.grad(rehydrated_f)(0.1) # 7*3 * 0.1^2 +Array(0.21000001, dtype=float32) + +>>> jax.grad(jax.grad(rehydrated_f))(0.1) # 7*3*2 * 0.1 +Array(4.2, dtype=float32) + +>>> jax.grad(jax.grad(jax.grad(rehydrated_f)))(0.1) # 7*3*2 +Array(42., dtype=float32) + +>>> jax.grad(jax.grad(jax.grad(jax.grad(rehydrated_f))))(0.1) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +ValueError: No VJP is available + +``` + +Note that the VJP function is computed lazily while serializing, +when the JAX program is still available. +This means that it respects all features of JAX VJP, +e.g., {func}`jax.custom_vjp` and {func}`jax.remat`. + +Note that the rehydrated function does not support any other +transformations, e.g., forward-mode AD (jvp), or {func}`jax.vmap`. + +## Compatibility guarantees + +You should not use the raw StableHLO that is obtained from just lowering +(`jax.jit(f).lower(1.).compiler_ir()`) +for archival and for compilation in another process, for several reasons. + +First, the compilation may use a different version of the compiler, supporting a +different version of StableHLO. The {class}`jax.export` module takes +care of this by using the +[portable-artifact feature of StableHLO](https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md) +to deal with the possible evolution of the StableHLO opset. + +### Compatibility guarantees for custom calls + +Second, the raw StableHLO may contain custom calls referencing C++ +functions. +JAX uses custom calls for lowering of a small number of primitives, +e.g., linear algebra primitives, sharding annotations, or Pallas kernels. +These do not fall under the compatibility guarantees for StableHLO. +The C++ implementations of these functions change rarely, but they can change. + +`jax.export` makes the following export compatibility guarantees: +A JAX exported artifact can be compiled and executed by a compiler and +JAX runtime system that are: + + * **up to 6 months newer** than the version of JAX used for exporting + (we say that JAX export offers **6 months backward compatibility**). + This is useful if we want to archive the exported artifact to be compiled and executed later. + * **up to 3 weeks older** than the version of JAX used for exporting + (we say that JAX export offers **3 weeks forward compatibility**). + This is useful if we want to compile and run an exported artifact with a + consumer that was built and deployed before the export, e.g., + an inference system that is already deployed when the exporting is done. + +(The particular compatibility window lengths are the same that JAX +[promised for jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model), +and are based on [TensorFlow Compatibility](https://www.tensorflow.org/guide/versions#graph_and_checkpoint_compatibility_when_extending_tensorflow). 
+The terminology “backward compatibility” is from the perspective of the consumer, +e.g., the inference system.) + +What **matters is when the exporting and consuming components were built**, +not the time when the exporting and the compilation happen. +For external JAX users, it is +[possible to run JAX and jaxlib at different versions](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); +what matters is when the jaxlib release was built. + +To reduce chances of incompatibility, internal JAX users should: + * **rebuild and redeploy consumer systems as frequently as possible**. + +and external users should: + * run the exporting and consumer systems with the same version of jaxlib, whenever possible, and + * export for archival **with the latest released version of jaxlib**. + +The compatibility guarantees do not apply if you bypass the `jax.export` APIs +to obtain the StableHLO code. + +In order to ensure forward compatibility, when we change the JAX lowering rules +to use a new custom call target, JAX will refrain for 3 weeks to use the new +target. To use the latest lowering rules, you can pass the +`--jax_export_ignore_forward_compatibility=1` configuration flag +or the `JAX_EXPORT_IGNORE_FORWARD_COMPATIBILITY=1` environment variable. + +Only a subset of custom calls are guaranteed stable and have +compatibility guarantees ([see list](https://github.com/search?q=repo%3Agoogle%2Fjax%20_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE&type=code)). +We continuously +add more custom call targets to the allowed list along with backwards +compatibility tests. If you try to serialize +code that invokes other custom call targets you will get an error +during exporting. + +If you want to disable this safety check for a specific custom call, +e.g., with target `my_target`, you can add +`export.DisabledSafetyCheck.custom_call("my_target")` to the +`disabled_checks` parameter of the `export` method, +as in the following example: + +```python +>>> import jax +>>> from jax import export +>>> from jax import lax +>>> from jax._src import core +>>> from jax._src.interpreters import mlir +>>> # Define a new primitive backed by a custom call +>>> new_prim = core.Primitive("new_prim") +>>> _ = new_prim.def_abstract_eval(lambda x: x) +>>> _ = mlir.register_lowering(new_prim, lambda ctx, o: mlir.custom_call("my_new_prim", operands=[o], result_types=[o.type]).results) +>>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir()) +module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"}) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32} : (tensor) -> tensor + return %0 : tensor + } +} + +>>> # If we try to export, we get an error +>>> export.export(jax.jit(new_prim.bind))(1.) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +ValueError: Cannot serialize code with custom calls whose targets have no compatibility guarantees: my_new_bind + +>>> # We can avoid the error if we pass a `DisabledSafetyCheck.custom_call` +>>> exp = export.export( +... jax.jit(new_prim.bind), +... disabled_checks=[export.DisabledSafetyCheck.custom_call("my_new_prim")])(1.) + +``` + +See {ref}`export_ensuring_compat` for developer information regarding +ensuring compatibility. 
+
+## Cross-platform and multi-platform export
+
+JAX lowering is platform specific for a small number of JAX primitives.
+By default, the code is lowered and exported for the accelerator
+present on the exporting machine:
+
+```python
+>>> from jax import export
+>>> export.default_export_platform()
+'cpu'
+
+```
+
+There is a safety check that will raise an error when trying to compile
+an `Exported` object on a machine that does not have the accelerator
+for which the code was exported.
+
+You can specify explicitly for what platforms the code should be exported.
+This allows you to specify a different accelerator than you have
+available at export time,
+and it even allows you to specify multi-platform export to
+obtain an `Exported` object that can be compiled and executed
+on multiple platforms.
+
+
+```python
+>>> import jax
+>>> from jax import export
+>>> from jax import lax
+
+>>> # You can specify the export platform, e.g., `tpu`, `cpu`, `cuda`, `rocm`
+>>> # even if the current machine does not have that accelerator.
+>>> exp = export.export(jax.jit(lax.cos), platforms=['tpu'])(1.)
+
+>>> # But you will get an error if you try to compile `exp`
+>>> # on a machine that does not have TPUs.
+>>> exp.call(1.)  # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ValueError: Function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.
+
+>>> # We can avoid the error if we pass a `DisabledSafetyCheck.platform`
+>>> # parameter to `export`, e.g., because you have reasons to believe
+>>> # that the code lowered will run adequately on the current
+>>> # compilation platform (which is the case for `cos` in this
+>>> # example):
+>>> exp_unsafe = export.export(jax.jit(lax.cos),
+...                            lowering_platforms=['tpu'],
+...                            disabled_checks=[export.DisabledSafetyCheck.platform()])(1.)
+
+>>> exp_unsafe.call(1.)
+Array(0.5403023, dtype=float32, weak_type=True)
+
+# and similarly with multi-platform lowering
+>>> exp_multi = export.export(jax.jit(lax.cos),
+...                           lowering_platforms=['tpu', 'cpu', 'cuda'])(1.)
+>>> exp_multi.call(1.)
+Array(0.5403023, dtype=float32, weak_type=True)
+
+```
+
+For multi-platform export, the StableHLO will contain multiple
+lowerings but only for those primitives that require it, so the
+resulting module size should be only marginally larger than the
+size of a module with default export.
+As an extreme case, when serializing a module without any
+primitives with platform-specific lowering, you will get
+the same StableHLO as for the single-platform export.
+
+```python
+>>> import jax
+>>> from jax import export
+>>> from jax import lax
+>>> # A largish function
+>>> def f(x):
+...   for i in range(1000):
+...     x = jnp.cos(x)
+...   return x
+
+>>> exp_single = export.export(jax.jit(f))(1.)
+>>> len(exp_single.mlir_module_serialized)  # doctest: +SKIP
+9220
+
+>>> exp_multi = export.export(jax.jit(f),
+...                           lowering_platforms=["cpu", "tpu", "cuda"])(1.)
+>>> len(exp_multi.mlir_module_serialized)  # doctest: +SKIP
+9282
+
+```
+
+## Shape polymorphic export
+
+When used in JIT mode, JAX will trace and lower a function separately
+for each combination of input shapes. When exporting, it is possible
+in some cases to use dimension variables for some input dimensions
+in order to obtain an exported artifact that can be used with multiple
+combinations of input shapes.
+
+See the {ref}`shape_poly` documentation.
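+
+For instance, here is a small sketch (in the same doctest style; the dimension
+name `b` is arbitrary) that exports with a symbolic batch dimension using
+{func}`jax.export.symbolic_shape`:
+
+```python
+>>> import numpy as np
+>>> import jax
+>>> import jax.numpy as jnp
+>>> from jax import export
+
+>>> # Export `sin` for any batch size `b` (a symbolic dimension).
+>>> exp = export.export(jax.jit(jnp.sin))(
+...     jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), np.float32))
+
+>>> # The same artifact accepts different batch sizes at call time.
+>>> exp.call(np.ones((3, 4), dtype=np.float32)).shape
+(3, 4)
+>>> exp.call(np.ones((10, 4), dtype=np.float32)).shape
+(10, 4)
+
+```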
+ +## Device polymorphic export + +An exported artifact may contain sharding annotations for inputs, +outputs and for some intermediates, but these annotations do not refer +directly to the actual physical devices that existed at exporting time. +Instead, the sharding annotations refer to logical devices. This +means that you can compile and run the exported artifacts on different +physical devices that were used for exporting. + +```python +>>> import jax +>>> from jax import export +>>> from jax.sharding import Mesh, NamedSharding +>>> from jax.sharding import PartitionSpec as P + +>>> # Use the first 4 devices for exporting. +>>> export_devices = jax.local_devices()[:4] +>>> export_mesh = Mesh(export_devices, ("a",)) +>>> def f(x): +... return x.T + +>>> arg = jnp.arange(8 * len(export_devices)) +>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg) + +>>> # `exp` knows for how many devices it was exported. +>>> exp.nr_devices +4 + +>>> # and it knows the shardings for the inputs. These will be applied +>>> # when the exported is called. +>>> exp.in_shardings_hlo +({devices=[4]<=[4]},) + +>>> res1 = exp.call(jax.device_put(arg, +... NamedSharding(export_mesh, P("a")))) + +>>> # Check out the first 2 shards of the result +>>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]] +['device=TFRT_CPU_0 index=(slice(0, 8, None),)', + 'device=TFRT_CPU_1 index=(slice(8, 16, None),)'] + +>>> # We can call `exp` with some other 4 devices and another +>>> # mesh with a different shape, as long as the number of devices is +>>> # the same. +>>> other_mesh = Mesh(np.array(jax.local_devices()[2:6]).reshape((2, 2)), ("b", "c")) +>>> res2 = exp.call(jax.device_put(arg, +... NamedSharding(other_mesh, P("b")))) + +>>> # Check out the first 2 shards of the result. Notice that the output is +>>> # sharded similarly; this means that the input was resharded according to the +>>> # exp.in_shardings. +>>> [f"device={s.device} index={s.index}" for s in res2.addressable_shards[:2]] +['device=TFRT_CPU_2 index=(slice(0, 8, None),)', + 'device=TFRT_CPU_3 index=(slice(8, 16, None),)'] + +``` + +It is an error to try to invoke an exported artifact with a different number +of devices than it was exported for: + +```python +>>> import jax +>>> from jax import export +>>> from jax.sharding import Mesh, NamedSharding +>>> from jax.sharding import PartitionSpec as P + +>>> export_devices = jax.local_devices() +>>> export_mesh = Mesh(np.array(export_devices), ("a",)) +>>> def f(x): +... return x.T + +>>> arg = jnp.arange(4 * len(export_devices)) +>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg) + +>>> exp.call(arg) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device. + +``` + +There are helper functions to shard the inputs for calling an exported +artifacts using a new mesh constructed at the call site: + +```python +>>> import jax +>>> from jax import export +>>> from jax.sharding import Mesh, NamedSharding +>>> from jax.sharding import PartitionSpec as P + +>>> export_devices = jax.local_devices() +>>> export_mesh = Mesh(np.array(export_devices), ("a",)) +>>> def f(x): +... 
return x.T + +>>> arg = jnp.arange(4 * len(export_devices)) +>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg) + +>>> # Prepare the mesh for calling `exp`. +>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",)) + +>>> # Shard the arg according to what `exp` expects. +>>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0]) +>>> res = exp.call(sharded_arg) + +``` + +As a special facility, if a function was exported for 1 device and if it contains no +sharding annotations, then it can be invoked on an argument of the same shape but sharded +on multiple devices, and the compiler will shard the function appropriately: + +```python +```python +>>> import jax +>>> from jax import export +>>> from jax.sharding import Mesh, NamedSharding +>>> from jax.sharding import PartitionSpec as P + +>>> def f(x): +... return jnp.cos(x) + +>>> arg = jnp.arange(4) +>>> exp = export.export(jax.jit(f))(arg) +>>> exp.in_avals +(ShapedArray(int32[4]),) + +>>> exp.nr_devices +1 + +>>> # Prepare the mesh for calling `exp`. +>>> calling_mesh = Mesh(jax.local_devices()[:4], ("b",)) + +>>> # Shard the arg according to what `exp` expects. +>>> sharded_arg = jax.device_put(arg, +... NamedSharding(calling_mesh, P("b"))) +>>> res = exp.call(sharded_arg) + +``` + +## Calling convention versions + +The JAX export support has evolved over time, e.g., to support +effects. In order to support compatibility (see [compatibility guarantees](#compatibility-guarantees)) +we maintain a calling convention version for each `Exported`. +As of June 2024, all function exported with version 9 +(the latest, see [all calling convention versions](#calling-convention-versions)): + +```python +>>> from jax import export +>>> exp: export.Exported = export.export(jnp.cos)(1.) +>>> exp.calling_convention_version +9 + +``` + +At any given time, the export APIs may support a range +of calling convention versions. You can control which calling convention +version to use using the `--jax_export_calling_convention_version` flag +or the `JAX_EXPORT_CALLING_CONVENTION_VERSION` environment variable: + +```python +>>> from jax import export +>>> (export.minimum_supported_calling_convention_version, export.maximum_supported_calling_convention_version) +(9, 9) + +>>> from jax._src import config +>>> with config.jax_export_calling_convention_version(9): +... exp = export.export(jnp.cos)(1.) +... exp.calling_convention_version +9 + +``` + +We reserve the right to remove support for +generating or consuming calling convention versions older than 6 months. + +### Module calling convention + +The `Exported.mlir_module` has a `main` function that takes an optional first +platform index argument if the module supports multiple platforms +(`len(platforms) > 1`), followed by the token arguments corresponding +to the ordered effects, followed by the kept array +arguments (corresponding to `module_kept_var_idx` and `in_avals`). +The platform index is a i32 or i64 scalar encoding the index of the current +compilation platform into the `platforms` sequence. + +Inner functions use a different calling convention: an optional +platform index argument, optional dimension variable arguments +(scalar tensors of type i32 or i64), +followed by optional token arguments (in presence of ordered effects), +followed by the regular array arguments. +The dimension arguments correspond to the dimension variables appearing in +the `args_avals`, in sorted order of their names. 
+ +Consider the lowering of a function with one array argument of type +`f32[w, 2 * h]`, where `w` and `h` are two dimension variables. +Assume that we use multi-platform lowering, and we have +one ordered effect. The `main` function will be as follows: + +``` + func public main( + platform_index: i32 {jax.global_constant="_platform_index"}, + token_in: token, + arg: f32[?, ?]) { + arg_w = hlo.get_dimension_size(arg, 0) + dim1 = hlo.get_dimension_size(arg, 1) + arg_h = hlo.floordiv(dim1, 2) + call _check_shape_assertions(arg) # See below + token = new_token() + token_out, res = call _wrapped_jax_export_main(platform_index, + arg_h, + arg_w, + token_in, + arg) + return token_out, res + } +``` + +The actual computation is in `_wrapped_jax_export_main`, taking also +the values of `h` and `w` dimension variables. + +The signature of the `_wrapped_jax_export_main` is: + +``` + func private _wrapped_jax_export_main( + platform_index: i32 {jax.global_constant="_platform_index"}, + arg_h: i32 {jax.global_constant="h"}, + arg_w: i32 {jax.global_constant="w"}, + arg_token: stablehlo.token {jax.token=True}, + arg: f32[?, ?]) -> (stablehlo.token, ...) +``` + +Prior to calling convention version 9 the calling convention for effects was +different: the `main` function does not take or return a token. Instead +the function creates dummy tokens of type `i1[0]` and passes them to the +`_wrapped_jax_export_main`. The `_wrapped_jax_export_main` +takes dummy tokens of type `i1[0]` and will create internally real +tokens to pass to the inner functions. The inner functions use real +tokens (both before and after calling convention version 9) + +Also starting with calling convention version 9, function arguments that contain +the platform index or the dimension variable values have a +`jax.global_constant` string attribute whose value is the name of the +global constant, either `_platform_index` or a dimension variable name. +The global constant name may be empty if it is not known. +Some global constant computations use inner functions, e.g., for +`floor_divide`. The arguments of such functions have a `jax.global_constant` +attribute for all attributes, meaning that the result of the function is +also a global constant. + +Note that `main` contains a call to `_check_shape_assertions`. +JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h` +have values >= 1. We must check these constraints when we invoke the +module. We use a special custom call `@shape_assertion` that takes +a boolean first operand, a string `error_message` attribute that may contain +format specifiers `{0}`, `{1}`, ..., and a variadic number of integer +scalar operands corresponding to the format specifiers. + +``` + func private _check_shape_assertions(arg: f32[?, ?]) { + # Check that w is >= 1 + arg_w = hlo.get_dimension_size(arg, 0) + custom_call @shape_assertion(arg_w >= 1, arg_w, + error_message="Dimension variable 'w' must have integer value >= 1. Found {0}") + # Check that dim1 is even + dim1 = hlo.get_dimension_size(arg, 1) + custom_call @shape_assertion(dim1 % 2 == 0, dim1 % 2, + error_message="Division had remainder {0} when computing the value of 'h') + # Check that h >= 1 + arg_h = hlo.floordiv(dim1, 2) + custom_call @shape_assertion(arg_h >= 1, arg_h, + error_message=""Dimension variable 'h' must have integer value >= 1. 
Found {0}") +``` + +(export-calling-convention-version)= + +### Calling convention versions + +We list here a history of the calling convention version numbers: + + * Version 1 used MHLO & CHLO to serialize the code, not supported anymore. + * Version 2 supports StableHLO & CHLO. Used from October 2022. Not supported + anymore. + * Version 3 supports platform checking and multiple platforms. + Used from February 2023. Not supported anymore. + * Version 4 supports StableHLO with compatibility guarantees. + This is the earliest version at the time of the JAX native serialization + launch. + Used in JAX from March 15, 2023 (cl/516885716). Starting with + March 28th, 2023 we stopped using `dim_args_spec` (cl/520033493). + The support for this version was dropped on + October 17th, 2023 (cl/573858283). + * Version 5 adds support for `call_tf_graph`. This is currently used + for some specialized use cases. Used in JAX from May 3rd, 2023 + (cl/529106145). + * Version 6 adds support for the `disabled_checks` attribute. This version + mandates a non-empty `platforms` attribute. Supported by XlaCallModule + since June 7th, 2023 and available in JAX since + June 13th, 2023 (JAX 0.4.13). + * Version 7 adds support for `stablehlo.shape_assertion` operations and + for `shape_assertions` specified in `disabled_checks`. + See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule + since July 12th, 2023 (cl/547482522), + available in JAX serialization since July 20th, 2023 (JAX 0.4.14), + and the default since August 12th, 2023 (JAX 0.4.15). + * Version 8 adds support for the `jax.uses_shape_polymorphism` module + attribute and enables the shape refinement pass only when the + attribute is present. Supported by XlaCallModule since July 21st, 2023 + (cl/549973693), available in JAX since July 26th, 2023 (JAX 0.4.14), + and the default since October 21st, 2023 (JAX 0.4.20). + * Version 9 adds support for effects. + See the docstring for `export.Exported` for the precise calling convention. + In this calling convention version we also tag the platform index and the + dimension variables arguments with `jax.global_constant` attributes. + Supported by XlaCallModule since October 27th, 2023, + available in JAX since October 20th, 2023 (JAX 0.4.20), + and the default since February 1st, 2024 (JAX 0.4.24). + This is the only supported version as of 27th of March, 2024. + +## Developer documentation + +(export_debugging)= +### Debugging + +You can log the exported modules, with somewhat different flags in OSS versus +in Google. In OSS you can do the following: + +```shell +# Log from python +python tests/export_test.py JaxExportTest.test_basic -v=3 +# Or, log from pytest to /tmp/mylog.txt +pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt +``` + +You will see a log line of the form: +```shell +I0619 10:54:18.978733 8299482112 _export.py:606] Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=() +I0619 10:54:18.978767 8299482112 _export.py:607] Define JAX_DUMP_IR_TO to dump the module. +``` + +If you set the environment variable `JAX_DUMP_IR_TO` to a directory, the exported (and the JIT compiled) HLO +modules will be saved there. 
+
+```shell
+JAX_DUMP_IR_TO=/tmp/export.dumps pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt
+INFO absl:_export.py:606 Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=()
+INFO absl:_export.py:607 The module was dumped to jax_ir0_jit_sin_export.mlir.
+```
+
+You will see both the exported modules (named `..._export.mlir`)
+and the JIT compiled modules (named `..._compile.mlir`):
+```shell
+$ ls -l /tmp/export.dumps/
+total 32
+-rw-rw-r--@ 1 necula  wheel  2316 Jun 19 11:04 jax_ir0_jit_sin_export.mlir
+-rw-rw-r--@ 1 necula  wheel  2279 Jun 19 11:04 jax_ir1_jit_sin_compile.mlir
+-rw-rw-r--@ 1 necula  wheel  3377 Jun 19 11:04 jax_ir2_jit_call_exported_compile.mlir
+-rw-rw-r--@ 1 necula  wheel  2333 Jun 19 11:04 jax_ir3_jit_my_fun_export.mlir
+```
+
+Inside Google, you can turn on logging by using the `--vmodule` argument to
+specify the logging levels for different modules,
+e.g., `--vmodule=_export=3`.
+
+
+(export_ensuring_compat)=
+### Ensuring forward and backward compatibility
+
+This section discusses the process JAX developers
+should use to ensure the [compatibility guarantees](#compatibility-guarantees).
+
+One complication is that external users install JAX and jaxlib
+in separate packages,
+and users often end up using an older jaxlib than JAX.
+We observe that the custom calls live in the jaxlib, and only the jaxlib is relevant
+for a consumer of an exported artifact.
+To simplify the process, we are setting the expectation for external users
+that the compatibility window is defined in terms of jaxlib releases,
+and it is their responsibility to ensure that they export with a new jaxlib
+even if JAX would function with an older version.
+
+Thus, we care only about jaxlib releases.
+We can start a backward-compatibility deprecation clock when we make a jaxlib release,
+even if we don’t force it to be the minimum allowed version.
+
+Let’s say that we need to add, delete, or change the semantics of a
+custom call target `T` used by the JAX lowering rules.
+Here is a possible chronology (for changing custom call targets
+that live in jaxlib):
+
+  1. Day “D - 1”, before the change. Say that the active internal JAX version is `0.4.31`
+     (the version of the next JAX and jaxlib releases).
+     The JAX lowering rules use a custom call `T`.
+  2. Day “D”, we add the new custom call target `T_NEW`.
+     We should create a new custom call target, and clean up the old
+     target roughly after 6 months, rather than updating `T` in place:
+     * See the example [PR #20997](https://github.com/google/jax/pull/20997)
+       implementing the steps below.
+     * We add the custom call target `T_NEW`.
+     * We change the JAX lowering rules that were previously using `T`
+       to use `T_NEW`, conditionally as follows:
+
+       ```python
+       from jax._src import config
+       from jax._src.lib import version as jaxlib_version
+
+       def my_lowering_rule(ctx: LoweringRuleContext, ...):
+         lowering_parameters = ctx.module_context.lowering_parameters
+         forward_compat_mode = (lowering_parameters.for_export and
+                                not lowering_parameters.export_ignore_forward_compatibility)
+         if forward_compat_mode or jaxlib_version < (0, 4, 31):
+           # This is the old lowering, using target T, while we
+           # are in forward compatibility mode for T, or we
+           # are in OSS and are using an old jaxlib.
+           return hlo.custom_call("T", ...)
+ else: + # This is the new lowering, using target T_NEW, for + # when we use a jaxlib with version `>= (0, 4, 31)` + # (or when this is internal usage), and also we are + # in JIT mode. + return hlo.custom_call("T_NEW", ...) + ``` + * Note that the forward compatibility mode is always false in JIT mode + or if the user passes `--jax_export_ignore_forward_compatibility=true` + * We add `T_NEW` to the list of + [`_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE`](https://github.com/search?q=repo%3Agoogle%2Fjax++%22_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE+%3D%22+path%3A_export.py&%3Btype=code&type=code) + in `_export.py`. + 3. Day “D + 21” (end of forward compatibility window; can be even later than 21 days): + We remove the `forward_compat_mode` in the lowering code, so now exporting + will start using the new custom call target `T_NEW` as long as we are using a new `jaxlib`. + * We add a backwards compatibility test for `T_NEW`. + 4. Day "RELEASE > D" (the first JAX release date after `D`, when we release version `0.4.31`): + we start the clock for the 6 months backwards compatibility. + Note that this is relevant only if `T` is among the custom call targets for which + we already guarantee stability, i.e., are listed in + [`_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE`](https://github.com/search?q=repo%3Agoogle%2Fjax++%22_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE+%3D%22+path%3A_export.py&%3Btype=code&type=code). + * If `RELEASE` is in the forward compatibility window `[D, D + 21]` and if + we make `RELEASE` the minimum allowed jaxlib version then we can + remove the `jaxlib_version < (0, 4, 31)` conditional in the + JIT branch. + 5. Day “RELEASE + 180” (end of backward compatibility window, + can be even later than 180 days): By now, we must have bumped + the minimum jaxlib so that the lowering conditional `jaxlib_version < (0, 4, 31)` + was already removed and JAX lowering cannot generate custom calls to `T`. + * We remove the C++ implementation of the old custom call target `T`. + * We remove also the backwards compatibility test for `T` + +## Migration guide from jax.experimental.export + +On June 18, 2024 (JAX version 0.4.30) +we deprecated the `jax.experimental.export` APIs +in favor of `jax.export` APIs. There have been some minor changes: + + * `jax.experimental.export.export`: + * The old function used to allow any Python callable, or the result of + `jax.jit`. Now only the latter is accepted. You have to manually apply + `jax.jit` to the function to export before calling `export`. + * The old `lowering_parameters` kwarg is now named `platforms` + * `jax.experimental.export.default_lowering_platform()` is now + at {func}`jax.export.default_export_platform`. + * `jax.experimental.export.call` is now a method of the {class}`jax.export.Exported` object. + Instead of `export.call(exp)` you should use `exp.call`. + * `jax.experimental.export.serialize` is now a method of the {class}`jax.export.Exported` + object. Instead of `export.serialize(exp)` you should use `exp.serialize()`. + * The configuration flag `--jax-serialization-version` is deprecated. + Use `--jax-export-calling-convention-version`. + * The value `jax.experimental.export.minimum_supported_serialization_version` + is now at `jax.export.minimum_supported_calling_convention_version`. + * The following fields of {class}`jax.export.Exported` have been renamed + * `uses_shape_polymorphism` is now `uses_global_constants` + * `mlir_module_serialization_version` is now `calling_convention_version` + * `lowering_platforms` is now `platforms`. 
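+
+For illustration, here is a minimal sketch of the new-style API after the
+renames above (using a trivial function `f` purely as an example; the
+exported artifact and its serialized form are not shown):
+
+```python
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax import export
+
+def f(x):
+  return jnp.sin(x)
+
+# `export` now takes the result of `jax.jit`, not a bare callable.
+exp: export.Exported = export.export(jax.jit(f))(
+    jax.ShapeDtypeStruct((4,), jnp.float32))
+
+# `call` and `serialize` are now methods of the `Exported` object.
+res = exp.call(np.arange(4, dtype=np.float32))
+serialized = exp.serialize()
+```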
diff --git a/docs/export/index.rst b/docs/export/index.rst new file mode 100644 index 000000000000..24cf2716cafe --- /dev/null +++ b/docs/export/index.rst @@ -0,0 +1,13 @@ +.. _export: + +Exporting and serialization +============================= + +.. toctree:: + :caption: Guides + :maxdepth: 2 + + export + shape_poly + + jax2tf diff --git a/docs/export/jax2tf.md b/docs/export/jax2tf.md new file mode 100644 index 000000000000..498a0418f232 --- /dev/null +++ b/docs/export/jax2tf.md @@ -0,0 +1,5 @@ +(jax2tf)= + +## Interoperation with TensorFlow + +See the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md new file mode 100644 index 000000000000..a88d70fb1d9a --- /dev/null +++ b/docs/export/shape_poly.md @@ -0,0 +1,663 @@ +(shape_poly)= + +# Shape polymorphism + +When JAX is used in JIT mode, a function will be traced, lowered to StableHLO, and compiled for each +combination of input types and shapes. After exporting a function and +deserializing it on another system we don't have the Python sources available anymore, +so we cannot re-trace and re-lower it. **Shape polymorphism** is a feature of JAX export +to allow some exported functions to be used for a whole family of input shapes. +These functions are traced and lowered once, during exporting, and `Exported` +object contains the information needed to be able to compile and execute the function +on many concrete input shapes. We do this by specifying shapes that contain +dimension variables (symbolic shapes) when exporting, as in the +following example: + +```python +>>> import jax +>>> from jax import export +>>> from jax import numpy as jnp +>>> def f(x): # f: f32[a, b] +... return jnp.concatenate([x, x], axis=1) + +>>> # We construct symbolic dimension variables. +>>> a, b = export.symbolic_shape("a, b") + +>>> # We can use the symbolic dimensions to construct shapes. +>>> x_shape = (a, b) +>>> x_shape +(a, b) + +>>> # Then we export with symbolic shapes: +>>> exp: export.Exported = export.export(jax.jit(f))( +... jax.ShapeDtypeStruct(x_shape, jnp.int32)) +>>> exp.in_avals +(ShapedArray(int32[a,b]),) +>>> exp.out_avals +(ShapedArray(int32[a,2*b]),) + +>>> # We can later call with concrete shapes (with a=3 and b=4), without re-tracing `f`. +>>> res = exp.call(np.ones((3, 4), dtype=np.int32)) +>>> res.shape +(3, 8) + +``` + +Note that such functions are still re-compiled on demand for +each concrete input shapes they are invoked on. Only the +tracing and the lowering are saved. + +The {func}`jax.export.symbolic_shape` is used in the above +example to parse a string representation of a symbolic shape +into dimension expressions objects (of type `_DimExpr`) that are usable in place of integer +constants to construct shapes. The dimension expression objects +overload most integer operators, so you can use them as +you'd use integer constants in most cases. +See {ref}`computing-with-dimension-variables` for more details. + +Additionally, we provide the {func}`jax.export.symbolic_args_specs` that +can be used to construct pytrees of `jax.ShapeDtypeStruct` objects based +on a polymorphic shape specification: + +```python +>>> def f1(x, y): # x: f32[a, 1], y : f32[a, 4] +... 
return x + y
+
+>>> # Assuming you have some actual args with concrete shapes
+>>> x = np.ones((3, 1), dtype=np.int32)
+>>> y = np.ones((3, 4), dtype=np.int32)
+>>> args_specs = export.symbolic_args_specs((x, y), "a, ...")
+>>> exp = export.export(jax.jit(f1))(*args_specs)
+>>> exp.in_avals
+(ShapedArray(int32[a,1]), ShapedArray(int32[a,4]))
+
+```
+
+Note how the polymorphic shape specification `"a, ..."` contains
+the placeholder `...` to be filled in from the concrete shapes of
+the arguments `(x, y)`.
+The placeholder `...` stands for 0 or more dimensions, while the
+placeholder `_` stands for one dimension.
+The {func}`jax.export.symbolic_args_specs` supports pytrees of arguments,
+which are used to fill in the dtypes and any placeholders.
+The function will construct a pytree of
+argument specifications ({class}`jax.ShapeDtypeStruct`)
+matching the structure of the arguments passed to it.
+The polymorphic shapes specification can be a
+pytree prefix in cases where one specification should apply
+to multiple arguments, as in the above example.
+See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
+
+
+A few examples of shape specifications:
+
+  * `("(b, _, _)", None)` can be used for a function with two arguments, the first
+    being a 3D array with a batch leading dimension that should be symbolic.
+    The other dimensions for the
+    first argument and the shape of the second argument are specialized based on the actual
+    arguments. Note that the same specification would work if the first
+    argument is a pytree of 3D arrays, all with the same leading dimension
+    but possibly with different trailing dimensions.
+    The value `None` for the second argument means that the argument
+    is not symbolic. Equivalently, one can use `...`.
+
+  * `("(batch, ...)", "(batch,)")` specifies that the two arguments
+    have matching leading dimensions, the first argument has rank at
+    least 1, and the second has rank 1.
+
+## Correctness of shape polymorphism
+
+We want to trust that the exported program produces the same results as the
+original JAX program when compiled and executed for any applicable concrete shapes.
+More precisely:
+
+For any JAX function `f` and any argument specification `arg_spec` containing a
+symbolic shape, and any concrete argument `arg` whose shape matches `arg_spec`:
+
+ * If the JAX native execution succeeds on the concrete argument: `res = f(arg)`,
+ * and if the exporting succeeds with symbolic shapes: `exp = export.export(f)(arg_spec)`,
+ * then compiling and running the export will succeed with the same result: `res == exp.call(arg)`
+
+It is crucial to understand that `f(arg)` has the freedom to re-invoke
+the JAX tracing machinery,
+and in fact it does so for each distinct concrete `arg` shape,
+while the execution of `exp.call(arg)` cannot use JAX tracing anymore
+(this execution may happen in an environment where the source code
+of `f` is not available).
+
+Ensuring this form of correctness is hard, and in the hardest cases
+exporting fails. The rest of this chapter describes how to handle these failures.
+
+(computing-with-dimension-variables)=
+
+## Computing with dimension variables
+
+JAX keeps track of the shapes of all intermediate results. When those shapes depend
+on dimension variables, JAX computes them as symbolic dimension expressions
+involving dimension variables.
+Dimension variables stand for integer values greater than or equal to 1.
+
+The symbolic expressions can represent the result
+of applying arithmetic operators (add, sub, mul, floordiv, mod,
+including the NumPy variants `np.sum`, `np.prod`, etc.) **on dimension
+expressions and integers** (`int`, `np.int`, or anything convertible by `operator.index`).
+These symbolic dimensions can then be used in shape-parameters of JAX primitives
+and APIs, e.g., in `jnp.reshape`, `jnp.arange`, slicing indices, etc.
+
+For example, in the following code to flatten a 2D array, the computation
+`x.shape[0] * x.shape[1]` computes the symbolic dimension `4 * b` as the
+new shape:
+
+```python
+>>> f = lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],))
+>>> arg_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.int32)
+>>> exp = export.export(jax.jit(f))(arg_spec)
+>>> exp.out_avals
+(ShapedArray(int32[4*b]),)

+```
+
+It is possible to convert dimension expressions explicitly
+to JAX arrays, with `jnp.array(x.shape[0])` or even `jnp.array(x.shape)`.
+The result of these operations can be used as regular JAX arrays,
+but cannot be used anymore as dimensions in shapes.
+
+```python
+>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))(
+...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))
+>>> exp.call(jnp.arange(3, dtype=np.int32))
+Array([3, 4, 5], dtype=int32)
+
+>>> exp = export.export(jax.jit(lambda x: x.reshape(jnp.array(x.shape[0]) + 2)))(
+...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Tracedwith].
+
+```
+
+When a symbolic dimension is used in arithmetic operations with **non-integers**,
+e.g., `float`, `np.float`, `np.ndarray`, or JAX arrays, it is automatically
+converted to a JAX array using `jnp.array`.
+For example, in the function below all occurrences of `x.shape[0]`
+are converted implicitly to `jnp.array(x.shape[0])` because
+they are involved in operations with non-integer scalars or with
+JAX arrays:
+
+```python
+>>> exp = export.export(jax.jit(
+...     lambda x: (5. + x.shape[0],
+...                x.shape[0] - np.arange(5, dtype=jnp.int32),
+...                x + x.shape[0] + jnp.sin(x.shape[0]))))(
+...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), jnp.int32))
+>>> exp.out_avals
+(ShapedArray(float32[], weak_type=True),
+ ShapedArray(int32[5]),
+ ShapedArray(float32[b], weak_type=True))
+
+>>> exp.call(jnp.ones((3,), jnp.int32))
+ (Array(8., dtype=float32, weak_type=True),
+  Array([ 3,  2,  1,  0, -1], dtype=int32),
+  Array([4.14112, 4.14112, 4.14112], dtype=float32, weak_type=True))
+
+```
+
+Another typical example is when computing averages
+(observe how `x.shape[0]` is automatically turned into a JAX array):
+
+```python
+>>> exp = export.export(jax.jit(
+...     lambda x: jnp.sum(x, axis=0) / x.shape[0]))(
+...     jax.ShapeDtypeStruct(export.symbolic_shape("b, c"), jnp.int32))
+>>> exp.call(jnp.arange(12, dtype=jnp.int32).reshape((3, 4)))
+Array([4., 5., 6., 7.], dtype=float32)
+
+```
+
+### Errors in presence of shape polymorphism
+
+Most JAX code assumes that the shapes of JAX arrays are tuples of integers,
+but with shape polymorphism some dimensions may be symbolic expressions.
+This can lead to a number of errors. For example, we can have the usual
+JAX shape check errors:
+
+```python
+>>> v, = export.symbolic_shape("v,")
+>>> export.export(jax.jit(lambda x, y: x + y))(
+...     jax.ShapeDtypeStruct((v,), dtype=np.int32),
+...     
jax.ShapeDtypeStruct((4,), dtype=np.int32))
+Traceback (most recent call last):
+TypeError: add got incompatible shapes for broadcasting: (v,), (4,).
+
+>>> export.export(jax.jit(lambda x: jnp.matmul(x, x)))(
+...     jax.ShapeDtypeStruct((v, 4), dtype=np.int32))
+Traceback (most recent call last):
+TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (v,).
+
+```
+
+We can fix the above matmul example by specifying that the
+argument has shape `(v, v)`.
+
+### Comparison of symbolic dimensions is partially supported
+
+Inside JAX there are a number of equality and inequality comparisons
+involving shapes, e.g., for doing shape checking or even for choosing
+the implementation for some primitives. Comparisons are supported
+as follows:
+
+  * equality is supported with a caveat: if the two symbolic dimensions denote the same
+    value under all valuations for dimension variables, then equality evaluates to `True`,
+    e.g., for `b + b == 2*b`; otherwise the equality evaluates to `False`.
+    See [below](#caveat-for-equality-comparisons)
+    for a discussion of important consequences of this behavior.
+  * disequality is always the negation of equality.
+  * inequality is partially supported, in a similar way as partial equality.
+    However, in this
+    case we take into consideration that dimension variables range over strictly positive
+    integers. E.g., `b >= 1`, `b >= 0`, `2 * a + b >= 3` are `True`, while `b >= 2`,
+    `a >= b`, `a - b >= 0` are inconclusive and result in an exception.
+
+In cases where a comparison operation cannot be resolved to a boolean,
+we raise {class}`InconclusiveDimensionOperation`. E.g.,
+
+```python
+import jax
+>>> export.export(jax.jit(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1))(
+...     jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'a + 1' >= 'b' is inconclusive.
+This error arises for comparison operations with shapes that
+are non-constant, and the result of the operation cannot be represented as
+a boolean value for all values of the symbolic dimensions involved.
+
+```
+
+If you do get an `InconclusiveDimensionOperation`, you can try
+several strategies:
+
+ * If your code uses the built-in `max` or `min`, or
+   `np.max` or `np.min`, then you can replace those with
+   `core.max_dim` and `core.min_dim`, which have the effect
+   of delaying the inequality comparison to compilation
+   time, when shapes become known.
+ * Try to rewrite conditionals using `core.max_dim` and
+   `core.min_dim`, e.g., instead of `d if d > 0 else 0`
+   you can write `core.max_dim(d, 0)`.
+ * Try to rewrite the code to be less dependent on the fact
+   that dimensions should be integers, and rely on the fact
+   that symbolic dimensions duck-type as integers for most
+   arithmetic operations. E.g., instead of `int(d) + 5` write
+   `d + 5`.
+ * Specify symbolic constraints, as explained below.
+
+#### User-specified symbolic constraints
+
+By default, JAX assumes that all dimension variables range
+over values greater than or equal to 1, and it tries to derive
+other simple inequalities from that, e.g.:
+
+  * `a + 2 >= 3`,
+  * `a * 2 >= 1`,
+  * `a + b + c >= 3`,
+  * `a // 4 >= 0`, `a**2 >= 1`, and so on.
+
+You can avoid some inequality comparison failures if you
+change the symbolic shape specifications to add **implicit** constraints
+for dimension sizes.
E.g.,
+
+ * You can use `2*b` for a dimension to constrain it to be even and greater than or equal
+   to 2.
+ * You can use `b + 15` for a dimension to constrain it to
+   be at least 16. E.g., the following code would fail without
+   the `+ 15` part, because JAX will want to verify that slice sizes
+   are at most as large as the axis size.
+
+```python
+>>> _ = export.export(jax.jit(lambda x: x[0:16]))(
+...    jax.ShapeDtypeStruct(export.symbolic_shape("b + 15"), dtype=np.int32))
+
+```
+
+Such implicit symbolic constraints are used for deciding comparisons and are
+checked at compile time, as explained [below](#shape-assertion-errors).
+
+You can also specify **explicit** symbolic constraints:
+
+```python
+>>> # Introduce dimension variable with constraints.
+>>> a, b = export.symbolic_shape("a, b",
+...                              constraints=("a >= b", "b >= 16"))
+>>> _ = export.export(jax.jit(lambda x: x[:x.shape[1], :16]))(
+...    jax.ShapeDtypeStruct((a, b), dtype=np.int32))
+
+```
+
+The constraints form a conjunction together with the implicit
+constraints. You can specify `>=`, `<=`, and `==` constraints.
+At the moment, JAX has limited support for reasoning with
+symbolic constraints:
+
+  * You get the most from constraints of the form of a variable
+    being greater than or equal to, or less than or equal to, a constant.
+    For example, from the constraints that
+    `a >= 16` and `b >= 8` we can infer
+    that `a + 2*b >= 32`.
+  * You get limited power when the constraint involves
+    more complex expressions, e.g., from `a >= b + 8` we
+    can infer that `a - b >= 8` but not that `a >= 9`.
+    We may improve this area somewhat in the future.
+  * Equality constraints are treated as normalization rules.
+    E.g., `floordiv(a, b) = c` works by replacing all
+    occurrences of the left-hand-side with the right-hand-side.
+    You can only have equality constraints where the left-hand-side
+    is a multiplication of factors, e.g., `a * b`, or `4 * a`, or
+    `floordiv(a, b)`. Thus, the left-hand-side cannot contain
+    addition or subtraction at the top-level.
+
+The symbolic constraints can also help to work around the
+limitations in the JAX reasoning mechanisms.
+For example, in the code below JAX will attempt to prove that
+the slice size `x.shape[0] % 3`, which is the symbolic expression
+`mod(b, 3)`, is less than or equal to the axis size, which is `b`.
+This happens to be true for all strictly positive values of
+`b`, but it is not something JAX's symbolic comparison rules
+can prove. Hence the following code raises an error:
+
+```python
+from jax import lax
+>>> b, = export.symbolic_shape("b")
+>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
+>>> export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((b,), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'b' >= 'mod(b, 3)' is inconclusive.
+This error arises for comparison operations with shapes that
+are non-constant, and the result of the operation cannot be represented as
+a boolean value for all values of the symbolic dimensions involved.
+
+```
+
+One option here would be to restrict the code to work only on
+axis sizes that are multiples of `3` (by replacing
+`b` with `3*b` in the shape). Then, JAX would be able
+to simplify the modulo operation `mod(3*b, 3)` to `0`.
+Another option is to add a symbolic constraint
+with the exact inconclusive inequality that JAX
+is attempting to prove:
+
+```python
+>>> b, = export.symbolic_shape("b",
+...     
constraints=["b >= mod(b, 3)"])
+>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
+>>> _ = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((b,), dtype=np.int32))
+
+```
+
+Just like the implicit constraints, the explicit
+symbolic constraints are checked at compile time,
+using the same mechanism as explained [below](#shape-assertion-errors).
+
+#### Symbolic dimension scopes
+
+The symbolic constraints are stored in an
+{class}`jax.export.SymbolicScope` object, which is created implicitly
+for each call to {func}`jax.export.symbolic_shape`. You must be careful
+to not mix symbolic expressions that use different scopes.
+For example,
+the following code will fail because `a1` and `a2`
+use different scopes (created by different invocations of
+{func}`jax.export.symbolic_shape`):
+
+```python
+>>> a1, = export.symbolic_shape("a,")
+>>> a2, = export.symbolic_shape("a,", constraints=("a >= 8",))
+
+>>> a1 + a2  # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ValueError: Invalid mixing of symbolic scopes for linear combination.
+Expected scope 4776451856 created at :1:6 ()
+and found for 'a' (unknown) scope 4776979920 created at :1:6 () with constraints:
+  a >= 8
+```
+
+The symbolic expressions that originate from a single call
+to {func}`jax.export.symbolic_shape` share a scope and
+can be mixed in arithmetic operations. The result would
+also share the same scope.
+
+You can re-use scopes:
+
+```python
+>>> a, = export.symbolic_shape("a,", constraints=("a >= 8",))
+>>> b, = export.symbolic_shape("b,", scope=a.scope)  # Reuse the scope of `a`
+
+>>> a + b  # Allowed
+b + a
+
+```
+
+You can also create scopes explicitly:
+
+```python
+>>> my_scope = export.SymbolicScope()
+>>> c, = export.symbolic_shape("c", scope=my_scope)
+>>> d, = export.symbolic_shape("d", scope=my_scope)
+>>> c + d  # Allowed
+d + c
+
+```
+
+JAX tracing uses caches keyed partially by shapes, and
+symbolic shapes that are printed identically will be considered
+distinct if they use different scopes.
+
+### Caveat for equality comparisons
+
+The equality comparison returns `False` for `b + 1 == b` or `b == 0`
+(in which case it is certain that the dimensions are different for all values
+of the dimension variables),
+but also for `b == 1` and for `a == b`. This is unsound, and we
+ought to raise `core.InconclusiveDimensionOperation` because under
+some valuations the result should be `True` and under other
+valuations it should be `False`. We choose to make equality total,
+thus allowing unsoundness, because otherwise we may get spurious errors
+in the presence of hash collisions
+when hashing dimension expressions or objects that include
+them (shapes, `core.AbstractValue`, `core.Jaxpr`).
+Besides the hashing errors, a partial semantics of equality
+leads to errors for expressions such as `b == a or b == b` or `b in [a, b]`,
+even though the error is avoided if we change the order of the comparisons.
+
+Code of the form `if x.shape[0] != 1: raise NiceErrorMessage` is sound even
+with this treatment of equality, but code of the form `if x.shape[0] != 1: return 1`
+is unsound.
+
+### Dimension variables must be solvable from the input shapes
+
+Currently, the only way to pass the values of dimension variables
+when an exported object is invoked is indirectly through the shapes
+of the array arguments. E.g., the value of `b` can be inferred at the
+call site from the shape of the first argument of type `f32[b]`.
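+
+For instance (a small sketch, assuming the same imports as the examples
+above), the value of `b` below is solved from the shape of the actual
+argument on each call, so a single export serves several input sizes:
+
+```python
+>>> b, = export.symbolic_shape("b")
+>>> exp = export.export(jax.jit(lambda x: x + x.shape[0]))(
+...     jax.ShapeDtypeStruct((b,), dtype=np.int32))
+>>> exp.call(np.arange(3, dtype=np.int32))   # b is solved to 3
+Array([3, 4, 5], dtype=int32)
+>>> exp.call(np.arange(5, dtype=np.int32))   # the same export, now with b = 5
+Array([5, 6, 7, 8, 9], dtype=int32)
+
+```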
+
+This works well for most use cases, and
+it mirrors the calling convention of JIT functions.
+
+Sometimes you may want to export a function parameterized
+by an integer value that determines some shapes in the program.
+For example, we may
+want to export the function `my_top_k` defined below,
+parameterized by the
+value of `k`, which determines the shape of the result.
+The following attempt will lead to an error since the dimension
+variable `k` cannot be derived from the shape of the input `x: i32[4, 10]`:
+
+```python
+>>> def my_top_k(k, x):  # x: i32[4, 10], k <= 10
+...   return lax.top_k(x, k)[0]  # : i32[4, 3]
+>>> x = np.arange(40, dtype=np.int32).reshape((4, 10))
+
+>>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`.
+>>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x)
+>>> exp_static_k.in_avals[0]
+ShapedArray(int32[4,10])
+
+>>> exp_static_k.out_avals[0]
+ShapedArray(int32[4,3])
+
+>>> # When calling the exported function we pass only the non-static arguments
+>>> exp_static_k.call(x)
+Array([[ 9,  8,  7],
+       [19, 18, 17],
+       [29, 28, 27],
+       [39, 38, 37]], dtype=int32)
+
+>>> # Now attempt to export with symbolic `k` so that we choose `k` after export.
+>>> k, = export.symbolic_shape("k", constraints=["k <= 10"])
+>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x)  # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+KeyError: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments
+
+```
+
+In the future, we may add an additional mechanism to pass the values of
+dimension variables, besides implicitly through the input shapes.
+Meanwhile, the workaround for the above use case is to replace the
+function parameter `k` with an array of shape `(0, k)`, so that
+`k` can be derived from the input shape of an array.
+The first dimension is 0 to ensure that the whole array is empty
+and there is no performance penalty when we call the exported function.
+
+```python
+>>> def my_top_k_with_dimensions(dimensions, x):  # dimensions: i32[0, k], x: i32[4, 10]
+...   return my_top_k(dimensions.shape[1], x)
+>>> exp = export.export(jax.jit(my_top_k_with_dimensions))(
+...     jax.ShapeDtypeStruct((0, k), dtype=np.int32),
+...     x)
+>>> exp.in_avals
+(ShapedArray(int32[0,k]), ShapedArray(int32[4,10]))
+
+>>> exp.out_avals[0]
+ShapedArray(int32[4,k])
+
+>>> # When we invoke `exp` we must construct and pass an array of shape (0, k)
+>>> exp.call(np.zeros((0, 3), dtype=np.int32), x)
+Array([[ 9,  8,  7],
+       [19, 18, 17],
+       [29, 28, 27],
+       [39, 38, 37]], dtype=int32)
+
+```
+
+Another situation in which you may get an error is when some dimension
+variables do appear in the input shapes, but in a non-linear
+expression that JAX cannot currently solve:
+
+```python
+>>> a, = export.symbolic_shape("a")
+>>> export.export(jax.jit(lambda x: x.shape[0]))(
+...     jax.ShapeDtypeStruct((a * a,), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ValueError: Cannot solve for values of dimension variables {'a'}.
+We can only solve linear uni-variate constraints.
+Using the following polymorphic shapes specifications: args[0].shape = (a^2,).
+Unprocessed specifications: 'a^2' for dimension size args[0].shape[0].
+
+```
+
+### Shape assertion errors
+
+JAX assumes that dimension variables range over strictly positive integers,
+and this assumption is checked when the code is compiled for concrete
+input shapes.
+
+For example, given the symbolic input shape `(b, b, 2*d)`,
+JAX will generate code to check the following assertions when
+invoked with actual argument `arg`:
+
+  * `arg.shape[0] >= 1`
+  * `arg.shape[1] == arg.shape[0]`
+  * `arg.shape[2] % 2 == 0`
+  * `arg.shape[2] // 2 >= 1`
+
+For instance, here is the error we get when we call the exported function
+on an argument of shape `(3, 3, 5)`:
+
+```python
+>>> def f(x):  # x: f32[b, b, 2*d]
+...   return x
+>>> exp = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct(export.symbolic_shape("b, b, 2*d"), dtype=np.int32))
+>>> exp.call(np.ones((3, 3, 5), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ValueError: Input shapes do not match the polymorphic shapes specification.
+Division had remainder 1 when computing the value of 'd'.
+Using the following polymorphic shapes specifications:
+  args[0].shape = (b, b, 2*d).
+Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), .
+Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.
+
+```
+
+These errors arise in a pre-processing step before compilation.
+
+### Division of symbolic dimensions is partially supported
+
+JAX will attempt to simplify division and modulo operations,
+e.g., `(a * b + a) // (b + 1) == a` and `(6*a + 4) % 3 == 1`.
+In particular, JAX will handle the cases when either (a) there
+is no remainder, or (b) the divisor is a constant,
+in which case there may be a constant remainder.
+
+For example, the code below results in a division error when trying to
+compute the inferred dimension for a `reshape` operation:
+
+```python
+>>> b, = export.symbolic_shape("b")
+>>> export.export(jax.jit(lambda x: x.reshape((2, -1))))(
+...     jax.ShapeDtypeStruct((b,), dtype=np.int32))
+Traceback (most recent call last):
+jax._src.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (b,) and (2, -1).
+The remainder mod(b, - 2) should be 0.
+
+```
+
+Note that the following will succeed:
+
+```python
+>>> b, = export.symbolic_shape("b")
+>>> # We specify that the first dimension is a multiple of 4
+>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))(
+...     jax.ShapeDtypeStruct((4*b,), dtype=np.int32))
+>>> exp.out_avals
+(ShapedArray(int32[2,2*b]),)
+
+>>> # We specify that some other dimension is even
+>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))(
+...     jax.ShapeDtypeStruct((b, 5, 6), dtype=np.int32))
+>>> exp.out_avals
+(ShapedArray(int32[2,15*b]),)
+
+```
+
+(shape_poly_debugging)=
+## Debugging
+
+First, see the {ref}`export_debugging` documentation.
+Additionally, you can debug the shape refinement, which is
+invoked at compilation time for modules that have dimension variables or multi-platform
+support.
+
+If there is an error during shape refinement, you can set the `JAX_DUMP_IR_TO`
+environment variable to see a dump of the HLO module before
+shape refinement (named `..._before_refine_polymorphic_shapes.mlir`).
+This module should already have static input shapes.
+ +To enable the logging of all stages of shape refinement you can set the +environment variable `TF_CPP_VMODULE=refine_polymorphic_shapes=3` in OSS +(inside Google, you pass `--vmodule=refine_polymorphic_shapes=3`): + +```shell +# Log from python +JAX_DUMP_IR_TO=/tmp/export.dumps/ TF_CPP_VMODULE=refine_polymorphic_shapes=3 python tests/shape_poly_test.py ShapePolyTest.test_simple_unary -v=3 +``` diff --git a/docs/faq.rst b/docs/faq.rst index 7262bda736b1..3b63128d2c28 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -6,7 +6,7 @@ JAX Frequently Asked Questions (FAQ) .. _JAX - The Sharp Bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html -We are collecting here answers to frequently asked questions. +We are collecting answers to frequently asked questions here. Contributions welcome! ``jit`` changes the behavior of my function @@ -337,7 +337,7 @@ referred to as being *sticky* to the device). By default, JAX arrays are placed uncommitted on the default device (``jax.devices()[0]``), which is the first GPU or TPU by default. If no GPU or TPU is present, ``jax.devices()[0]`` is the CPU. The default device can -temporarily overridden with the :func:`jax.default_device` context manager, or +be temporarily overridden with the :func:`jax.default_device` context manager, or set for the whole process by setting the environment variable ``JAX_PLATFORMS`` or the absl flag ``--jax_platforms`` to "cpu", "gpu", or "tpu" (``JAX_PLATFORMS`` can also be a list of platforms, which determines which @@ -413,7 +413,7 @@ speed of code using JAX: use 32-bit dtypes in NumPy or enable 64-bit dtypes in JAX (see `Double (64 bit) precision`_) for a fair comparison. 4. **Transferring data between CPUs and accelerators takes time.** If you only - want to measure the how long it takes to evaluate a function, you may want to + want to measure how long it takes to evaluate a function, you may want to transfer data to the device on which you want to run it first (see :ref:`faq-data-placement`). @@ -686,7 +686,7 @@ The inner ``jnp.where`` may be needed in addition to the original one, e.g.:: def my_log_or_y(x, y): """Return log(x) if x > 0 or y""" - return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.), y) + return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y) Additional reading: @@ -814,6 +814,32 @@ computation at runtime. For example: For more information on runtime callbacks and examples of their use, see `External callbacks in JAX`_. +Why do some CUDA libraries fail to load/initialize? +--------------------------------------------------- + +When resolving dynamic libraries, JAX uses the usual `dynamic linker search pattern`_. +JAX sets :code:`RPATH` to point to the JAX-relative location of the +pip-installed NVIDIA CUDA packages, preferring them if installed. If :code:`ld.so` +cannot find your CUDA runtime libraries along its usual search path, then you +must include the paths to those libraries explicitly in :code:`LD_LIBRARY_PATH`. +The easiest way to ensure your CUDA files are discoverable is to simply install +the :code:`nvidia-*-cu12` pip packages, which are included in the standard +:code:`jax[cuda_12]` install option. + +Occasionally, even when you have ensured that your runtime libraries are discoverable, +there may still be some issues with loading or initializing them. A common cause of +such issues is simply having insufficient memory for CUDA library initialization at +runtime. 
This sometimes occurs because JAX will pre-allocate too large of a chunk of +currently available device memory for faster execution, occasionally resulting in +insufficient memory being left available for runtime CUDA library initialization. + +This is especially likely when running multiple JAX instances, running JAX in +tandem with TensorFlow which performs its own pre-allocation, or when running +JAX on a system where the GPU is being heavily utilized by other processes. When +in doubt, try running the program again with reduced pre-allocation, either by +reducing :code:`XLA_PYTHON_CLIENT_MEM_FRACTION` from the default of :code:`.75`, +or setting :code:`XLA_PYTHON_CLIENT_PREALLOCATE=false`. For more details, please +see the page on `JAX GPU memory allocation`_. .. _JIT mechanics: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables .. _External callbacks in JAX: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html @@ -822,3 +848,5 @@ see `External callbacks in JAX`_. .. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function .. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function .. _algebraic_simplifier.cc: https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/compiler/xla/service/algebraic_simplifier.cc#L3266 +.. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html +.. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html diff --git a/docs/glossary.rst b/docs/glossary.rst index 34a9d1de7268..78b7fcd246f3 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -3,6 +3,9 @@ JAX Glossary of Terms .. glossary:: + Array + JAX's analog of :class:`numpy.ndarray`. See :class:`jax.Array`. + CPU Short for *Central Processing Unit*, CPUs are the standard computational architecture available in most computers. JAX can run computations on CPUs, but often can achieve @@ -12,9 +15,6 @@ JAX Glossary of Terms The generic name used to refer to the :term:`CPU`, :term:`GPU`, or :term:`TPU` used by JAX to perform computations. - DeviceArray - JAX's analog of the :class:`numpy.ndarray`. See :class:`jaxlib.xla_extension.DeviceArray`. - forward-mode autodiff See :term:`JVP` @@ -30,7 +30,7 @@ JAX Glossary of Terms jaxpr Short for *JAX Expression*, a jaxpr is an intermediate representation of a computation that is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution. - See :ref:`understanding-jaxprs` for more information. + See :ref:`understanding-jaxprs` for more discussion and examples. JIT Short for *Just In Time* compilation, JIT in JAX generally refers to the compilation of @@ -41,11 +41,21 @@ JAX Glossary of Terms differentiation. For more details, see :ref:`jacobian-vector-product`. In JAX, JVP is a :term:`transformation` that is implemented via :func:`jax.jvp`. See also :term:`VJP`. + primitive + A primitive is a fundamental unit of computation used in JAX programs. Most functions + in :mod:`jax.lax` represent individual primitives. When representing a computation in + a :term:`jaxpr`, each operation in the jaxpr is a primitive. + pure function A pure function is a function whose outputs are based only on its inputs, and which has no side-effects. JAX's :term:`transformation` model is designed to work with pure functions. See also :term:`functional programming`. 
+ pytree + A pytree is an abstraction that lets JAX handle tuples, lists, dicts, and other more + general containers of array values in a uniform way. Refer to :ref:`working-with-pytrees` + for a more detailed discussion. + reverse-mode autodiff See :term:`VJP`. @@ -65,7 +75,7 @@ JAX Glossary of Terms fast operations on arrays (see also :term:`CPU` and :term:`GPU`). Tracer - An object used as a standin for a JAX :term:`DeviceArray` in order to determine the + An object used as a standin for a JAX :term:`Array` in order to determine the sequence of operations performed by a Python function. Internally, JAX implements this via the :class:`jax.core.Tracer` class. diff --git a/docs/gpu_ops/gpu_ops.cpp b/docs/gpu_ops/gpu_ops.cpp new file mode 100644 index 000000000000..0684f752edfd --- /dev/null +++ b/docs/gpu_ops/gpu_ops.cpp @@ -0,0 +1,45 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "kernels.h" +#include "pybind11_kernel_helpers.h" + +namespace { +pybind11::dict RMSNormRegistrations() { + pybind11::dict dict; + dict["rms_forward_affine_mixed_dtype"] = + gpu_ops::EncapsulateFunction(gpu_ops::rms_forward_affine_mixed_dtypes); + dict["rms_backward_affine"] = + gpu_ops::EncapsulateFunction(gpu_ops::rms_backward_affine); + return dict; +} + +PYBIND11_MODULE(gpu_ops, m) { + m.def("get_rms_norm_registrations", &RMSNormRegistrations); + m.def("create_rms_norm_descriptor", + [](int n1, int n2, double eps, gpu_ops::ElementType x_type, + gpu_ops::ElementType w_type, int part_grad_size) { + return gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{ + n1, n2, eps, x_type, w_type, part_grad_size}); + }); + + pybind11::enum_(m, "ElementType") + .value("BF16", gpu_ops::ElementType::BF16) + .value("F16", gpu_ops::ElementType::F16) + .value("F32", gpu_ops::ElementType::F32) + .value("F64", gpu_ops::ElementType::F64); + +} +} // namespace diff --git a/docs/gpu_ops/kernel_helpers.h b/docs/gpu_ops/kernel_helpers.h new file mode 100644 index 000000000000..0c146b38209d --- /dev/null +++ b/docs/gpu_ops/kernel_helpers.h @@ -0,0 +1,64 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header is not specific to our application and you'll probably want +// something like this for any extension you're building. 
This includes the +// infrastructure needed to serialize descriptors that are used with the +// "opaque" parameter of the GPU custom call. In our example we'll use this +// parameter to pass the size of our problem. + +#ifndef _GPU_OPS_KERNEL_HELPERS_H_ +#define _GPU_OPS_KERNEL_HELPERS_H_ + +#include +#include +#include +#include + +#define JAX_APEX_WARP_SIZE 32 + +namespace gpu_ops { + +// https://en.cppreference.com/w/cpp/numeric/bit_cast +template +typename std::enable_if::value && + std::is_trivially_copyable::value, + To>::type +bit_cast(const From &src) noexcept { + static_assert(std::is_trivially_constructible::value, + "This implementation additionally requires destination type to " + "be trivially constructible"); + + To dst; + memcpy(&dst, &src, sizeof(To)); + return dst; +} + +template std::string PackDescriptorAsString(const T &descriptor) { + return std::string(bit_cast(&descriptor), sizeof(T)); +} + +template +const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) { + if (opaque_len != sizeof(T)) { + throw std::runtime_error("Invalid opaque object size"); + } + return bit_cast(opaque); +} + +} // namespace gpu_ops + +#endif diff --git a/docs/gpu_ops/kernels.h b/docs/gpu_ops/kernels.h new file mode 100644 index 000000000000..18207bbd5345 --- /dev/null +++ b/docs/gpu_ops/kernels.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef _GPU_OPS_KERNELS_H_ +#define _GPU_OPS_KERNELS_H_ + +#include + +#include +#include + +namespace gpu_ops { + +enum ElementType { BF16, F16, F32, F64 }; + +struct RMSNormDescriptor { + int n1; + int n2; + double eps; + ElementType x_type; + ElementType w_type; + int part_grad_size; +}; + +void rms_forward_affine_mixed_dtypes(cudaStream_t stream, void **buffers, + const char *opaque, + std::size_t opaque_len); +void rms_backward_affine(cudaStream_t stream, void **buffers, + const char *opaque, std::size_t opaque_len); +} // namespace gpu_ops + +#endif diff --git a/docs/gpu_ops/pybind11_kernel_helpers.h b/docs/gpu_ops/pybind11_kernel_helpers.h new file mode 100644 index 000000000000..248ffb145616 --- /dev/null +++ b/docs/gpu_ops/pybind11_kernel_helpers.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header extends kernel_helpers.h with the pybind11 specific interface to +// serializing descriptors. 
It also adds a pybind11 function for wrapping our +// custom calls in a Python capsule. This is separate from kernel_helpers so +// that the CUDA code itself doesn't include pybind11. I don't think that this +// is strictly necessary, but they do it in jaxlib, so let's do it here too. + +#ifndef _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_ +#define _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_ + +#include + +#include "kernel_helpers.h" + +namespace gpu_ops { + +template pybind11::bytes PackDescriptor(const T &descriptor) { + return pybind11::bytes(PackDescriptorAsString(descriptor)); +} + +template pybind11::capsule EncapsulateFunction(T *fn) { + return pybind11::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); +} + +} // namespace gpu_ops + +#endif diff --git a/docs/gpu_ops/rms_norm_kernels.cu b/docs/gpu_ops/rms_norm_kernels.cu new file mode 100644 index 000000000000..7622ddc08772 --- /dev/null +++ b/docs/gpu_ops/rms_norm_kernels.cu @@ -0,0 +1,970 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "kernel_helpers.h" +#include "kernels.h" +#include "stdio.h" +#include +#include +#include + +namespace { + +#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, \ + NAME, ...) 
\ + switch (TYPEIN) { \ + case gpu_ops::ElementType::F64: { \ + using scalar_t_in = double; \ + using accscalar_t = double; \ + switch (TYPEOUT) { \ + case gpu_ops::ElementType::F64: { \ + using scalar_t_out = double; \ + __VA_ARGS__; \ + break; \ + } \ + case gpu_ops::ElementType::F32: { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case gpu_ops::ElementType::F16: { \ + using scalar_t_out = __half; \ + __VA_ARGS__; \ + break; \ + } \ + case gpu_ops::ElementType::BF16: { \ + using scalar_t_out = __nv_bfloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + break; \ + } \ + break; \ + } \ + case gpu_ops::ElementType::F32: { \ + using scalar_t_in = float; \ + using accscalar_t = float; \ + switch (TYPEOUT) { \ + case gpu_ops::ElementType::F64: { \ + using scalar_t_out = double; \ + __VA_ARGS__; \ + break; \ + } \ + case gpu_ops::ElementType::F32: { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case gpu_ops::ElementType::F16: { \ + using scalar_t_out = __half; \ + __VA_ARGS__; \ + break; \ + } \ + case gpu_ops::ElementType::BF16: { \ + using scalar_t_out = __nv_bfloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + break; \ + } \ + break; \ + } \ + case gpu_ops::ElementType::F16: { \ + using scalar_t_in = __half; \ + using accscalar_t = float; \ + switch (TYPEOUT) { \ + case gpu_ops::ElementType::F64: { \ + using scalar_t_out = double; \ + __VA_ARGS__; \ + break; \ + } \ + case gpu_ops::ElementType::F32: { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case gpu_ops::ElementType::F16: { \ + using scalar_t_out = __half; \ + __VA_ARGS__; \ + break; \ + } \ + case gpu_ops::ElementType::BF16: { \ + using scalar_t_out = __nv_bfloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + break; \ + } \ + break; \ + } \ + case gpu_ops::ElementType::BF16: { \ + using scalar_t_in = __nv_bfloat16; \ + using accscalar_t = float; \ + switch (TYPEOUT) { \ + case gpu_ops::ElementType::F64: { \ + using scalar_t_out = double; \ + __VA_ARGS__; \ + break; \ + } \ + case gpu_ops::ElementType::F32: { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case gpu_ops::ElementType::F16: { \ + using scalar_t_out = __half; \ + __VA_ARGS__; \ + break; \ + } \ + case gpu_ops::ElementType::BF16: { \ + using scalar_t_out = __nv_bfloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + break; \ + } \ + break; \ + } \ + default: \ + break; \ + } + +template +__device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) { + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template +__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, + U &mu, U &sigma2, U &count) { + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA * mu + nB * muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template __device__ void cuRMSOnlineSum(const U curr, U &sigma2) { + sigma2 = sigma2 + curr * curr; +} + +template +__device__ void cuChanRMSOnlineSum(const U sigma2B, U &sigma2) { + sigma2 = sigma2 + sigma2B; +} + +template +__device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1, + const int n2, const int i1, U &mu, U &sigma2, + U *buf, bool rms_only) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor 
is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + U count = U(0); + mu = U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T *lvals = vals + i1 * n2; + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l + k]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + U sigma2B = __shfl_sync(0xffffffff, sigma2, srcLaneB, warpSize); + if (!rms_only) { + U muB = __shfl_sync(0xffffffff, mu, srcLaneB, warpSize); + U countB = __shfl_sync(0xffffffff, count, srcLaneB, warpSize); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U *ubuf = (U *)buf; + U *ibuf = (U *)(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + if (!rms_only) { + ubuf[2 * wrt_y] = mu; + ibuf[wrt_y] = count; + } + ubuf[2 * wrt_y + 1] = sigma2; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U sigma2B = ubuf[2 * threadIdx.y + 1]; + if (!rms_only) { + U muB = ubuf[2 * threadIdx.y]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + ubuf[0] = mu; + } + ubuf[1] = sigma2; + } + __syncthreads(); + if (!rms_only) { + mu = ubuf[0]; + } + sigma2 = ubuf[1] / U(n2); + // don't care about final value of count, we know count == n2 + } else { + if (!rms_only) { + mu = __shfl_sync(0xffffffff, mu, 0, warpSize); + } + sigma2 = __shfl_sync(0xffffffff, sigma2 / U(n2), 0, warpSize); + } + } +} + +template <> +__device__ void cuWelfordMuSigma2(const __half *__restrict__ vals, const int n1, + const int n2, const int i1, float &mu, + float &sigma2, float *buf, bool rms_only) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. 
+ // + // compute variance and mean over n2 + float count = 0.0f; + mu = float(0); + sigma2 = float(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const __half *lvals = vals + i1 * n2; + int l = 8 * thrx; + if ((((size_t)lvals) & 3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. + for (; l + 7 < n2; l += 8 * numx) { + for (int k = 0; k < 8; k += 2) { + float2 curr = __half22float2(*((__half2 *)(lvals + l + k))); + if (!rms_only) { + cuWelfordOnlineSum(curr.x, mu, sigma2, count); + cuWelfordOnlineSum(curr.y, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr.x, sigma2); + cuRMSOnlineSum(curr.y, sigma2); + } + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + float sigma2B = __shfl_sync(0xffffffff, sigma2, srcLaneB, warpSize); + if (!rms_only) { + float muB = __shfl_sync(0xffffffff, mu, srcLaneB, warpSize); + float countB = __shfl_sync(0xffffffff, count, srcLaneB, warpSize); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + float *ubuf = (float *)buf; + float *ibuf = (float *)(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y + 1] = sigma2; + if (!rms_only) { + ubuf[2 * wrt_y] = mu; + ibuf[wrt_y] = count; + } + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float sigma2B = ubuf[2 * threadIdx.y + 1]; + if (!rms_only) { + float muB = ubuf[2 * threadIdx.y]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + ubuf[0] = mu; + } + ubuf[1] = sigma2; + } + __syncthreads(); + if (!rms_only) { + mu = ubuf[0]; + } + sigma2 = ubuf[1] / float(n2); + // don't care about final value of count, we know count == n2 + } else { + if (!rms_only) { + mu = __shfl_sync(0xffffffff, mu, 0, warpSize); + } + sigma2 = __shfl_sync(0xffffffff, sigma2 / float(n2), 0, warpSize); + } + } +} + +// This is the un-specialized struct. Note that we prevent instantiation of +// this struct by putting an undefined symbol in the function body so it won't +// compile. 
+// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template struct SharedMemory; + +template <> struct SharedMemory { + __device__ float *getPointer() { + extern __shared__ float s_float[]; + return s_float; + } +}; + +template <> struct SharedMemory { + __device__ double *getPointer() { + extern __shared__ double s_double[]; + return s_double; + } +}; + +template +__device__ void cuApplyLayerNorm_(V *__restrict__ output_vals, + U *__restrict__ mean, U *__restrict__ invvar, + const T *__restrict__ vals, const int n1, + const int n2, const U epsilon, + const V *__restrict__ gamma, + const V *__restrict__ beta, bool rms_only) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U *buf = shared.getPointer(); + U mu, sigma2; + cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, rms_only); + + const T *lvals = vals + i1 * n2; + V *ovals = output_vals + i1 * n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && (beta != NULL || rms_only)) { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + if (!rms_only) { + ovals[i] = + gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } else { + ovals[i] = gamma[i] * static_cast(c_invvar * curr); + } + } + } else { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + if (!rms_only) { + ovals[i] = static_cast(c_invvar * (curr - mu)); + } else { + ovals[i] = static_cast(c_invvar * curr); + } + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + mean[i1] = mu; + } + invvar[i1] = c_invvar; + } + __syncthreads(); + } +} + +template +__global__ void +cuApplyRMSNorm(V *__restrict__ output_vals, U *__restrict__ invvar, + const T *__restrict__ vals, const int n1, const int n2, + const U epsilon, const V *__restrict__ gamma) { + cuApplyLayerNorm_(output_vals, NULL, invvar, vals, n1, n2, epsilon, + gamma, NULL, true); +} + +template +void HostApplyRMSNorm(cudaStream_t stream, V *output, U *invvar, const T *input, + int n1, int n2, double epsilon, const V *gamma) { + auto getMaxGridY = []() { + int device; + int val; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&val, cudaDevAttrMaxGridDimY, device); + return val; + }; + const dim3 threads(32, 4, 1); + const uint64_t maxGridY = getMaxGridY(); + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? 
threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; + cuApplyRMSNorm<<>>( + output, invvar, input, n1, n2, U(epsilon), gamma); +} + +template +__device__ void cuLoadWriteStridedInputs( + const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2, + const T *input, const V *dout, const int i1_end, const int n2, + const U *__restrict__ mean, const U *__restrict__ invvar, bool rms_only) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean; + if (!rms_only) { + curr_mean = mean[i1]; + } + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + if (!rms_only) { + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = + curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf2[write_idx] = curr_dout * (curr_input)*curr_invvar; + } + } else { + if (!rms_only) { + warp_buf1[write_idx] = U(0); + } + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (!rms_only) { + warp_buf1[write_idx] = U(0); + } + warp_buf2[write_idx] = U(0); + } + } +} + +template +__device__ void cuLoadAddStridedInputs( + const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2, + const T *input, const V *dout, const int i1_end, const int n2, + const U *__restrict__ mean, const U *__restrict__ invvar, bool rms_only) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean; + if (!rms_only) { + curr_mean = mean[i1]; + } + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + if (!rms_only) { + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += + curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf2[write_idx] += curr_dout * (curr_input)*curr_invvar; + } + } + } + } +} + +template +__global__ void cuComputePartGradGammaBeta( + const V *__restrict__ dout, const T *__restrict__ input, const int n1, + const int n2, const U *__restrict__ mean, const U *__restrict__ invvar, + U epsilon, U *part_grad_gamma, U *part_grad_beta, bool rms_only) { + const int numsegs_n1 = + (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; + const int i1_beg_plus_one = + (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? 
i1_beg_plus_one : n1; + const int row_stride = blockDim.x + 1; + const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); + const int thr_load_row_off = + (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U *buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * + // blockDim.y + (blockDim.y - + // 1)*(blockDim.x/blockDim.y) elements + U *warp_buf1 = (U *)buf; + U *warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, + row_stride, warp_buf1, warp_buf2, input, dout, + i1_end, n2, mean, invvar, rms_only); + for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; + i1_block += blockDim.y * blockDim.y) { + cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, + row_stride, warp_buf1, warp_buf2, input, dout, + i1_end, n2, mean, invvar, rms_only); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k * blockDim.y; + int idx1 = row1 * row_stride + threadIdx.x; + if (!rms_only) { + acc1 += warp_buf1[idx1]; + } + acc2 += warp_buf2[idx1]; + } + if (!rms_only) { + warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; + } + warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y / 2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + if (!rms_only) { + warp_buf1[idx1] += warp_buf1[idx2]; + } + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + if (!rms_only) { + part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; + } + part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template +__global__ void +cuComputeGradGammaBeta(const U *part_grad_gamma, const U *part_grad_beta, + const int part_size, const int n1, const int n2, + V *grad_gamma, V *grad_beta, bool rms_only) { + // sum partial gradients for gamma and beta + SharedMemory shared; + U *buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U *part_grad_gamma_ptr = + part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U *part_grad_beta_ptr = + part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; + ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; + if (!rms_only) { + sum_beta += part_grad_beta_ptr[warp_offset * n2]; + } + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { + // top half write to shared memory + if 
(threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + if (!rms_only) { + buf[write_idx + nbsize3] = sum_beta; + } + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + if (!rms_only) { + sum_beta += buf[read_idx + nbsize3]; + } + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + if (!rms_only) { + grad_beta[i2] = sum_beta; + } + } + } +} + +template +__global__ void +cuComputeGradInput(const V *__restrict__ dout, const T *__restrict__ input, + const int n1, const int n2, const U *__restrict__ mean, + const U *__restrict__ invvar, U epsilon, const V *gamma, + T *grad_input, bool rms_only) { + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + U c_mean; + if (!rms_only) { + c_mean = mean[i1]; + } + const U c_invvar = invvar[i1]; + const T *k_input = input + i1 * n2; + const V *k_dout = dout + i1 * n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + if (!rms_only) { + sum_loss1 += c_loss * static_cast(gamma[l + k]); + sum_loss2 += c_loss * static_cast(gamma[l + k]) * + (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * static_cast(gamma[l + k]) * (c_h)*c_invvar; + } + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + if (!rms_only) { + sum_loss1 += c_loss * static_cast(gamma[l]); + sum_loss2 += + c_loss * static_cast(gamma[l]) * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * static_cast(gamma[l]) * (c_h)*c_invvar; + } + } + } else { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + if (!rms_only) { + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * (c_h)*c_invvar; + } + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + if (!rms_only) { + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * (c_h)*c_invvar; + } + } + } + // intra-warp reductions + for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { + if (!rms_only) { + sum_loss1 += __shfl_xor_sync(0xffffffff, sum_loss1, mask, warpSize); + } + sum_loss2 += __shfl_xor_sync(0xffffffff, sum_loss2, mask, warpSize); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U *buf = shared.getPointer(); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + if (!rms_only) { + buf[2 * wrt_i] = sum_loss1; + } + buf[2 * wrt_i + 1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + if (!rms_only) { + sum_loss1 += buf[2 * read_i]; + } + sum_loss2 += buf[2 * read_i 
+ 1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + if (!rms_only) { + buf[2 * threadIdx.x] = sum_loss1; + } + buf[2 * threadIdx.x + 1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y != 0) { + if (!rms_only) { + sum_loss1 = buf[2 * threadIdx.x]; + } + sum_loss2 = buf[2 * threadIdx.x + 1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T *k_grad_input = grad_input + i1 * n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * static_cast(gamma[l]); + if (!rms_only) { + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + f_grad_input -= (c_h)*c_invvar * sum_loss2; + } + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + if (!rms_only) { + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + f_grad_input -= (c_h)*c_invvar * sum_loss2; + } + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + // prevent race where buf is written again before reads are done + __syncthreads(); + } +} + +template +void HostRMSNormGradient(cudaStream_t stream, const V *dout, const U *invvar, + const T *input, int n1, int n2, const V *gamma, + double epsilon, T *grad_input, V *grad_gamma, + int part_size, U *part_grad_gamma) { + auto getMaxGridY = []() { + int device; + int val; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&val, cudaDevAttrMaxGridDimY, device); + return val; + }; + const uint64_t maxGridY = getMaxGridY(); + if (gamma != NULL) { + const dim3 threads2(32, 4, 1); + const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); + const int nshared2_a = + 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + // note (mkozuki): I can hard code part_grad_gamma's dtype as float given + // that the `cuda_layer_norm_gradient` doesn't support double. + cuComputePartGradGammaBeta<<>>( + dout, input, n1, n2, + invvar, // unused + invvar, U(epsilon), part_grad_gamma, part_grad_gamma, /* unused */ + true); + + const dim3 threads3(32, 8, 1); + const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma, part_grad_gamma, /* unused */ + part_size, n1, n2, grad_gamma, grad_gamma, /* unused */ + true); + } + + // compute grad_input + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32, 4, 1); + int nshared = threads1.y > 1 ? 
threads1.y * threads1.x * sizeof(U) : 0; + cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>( + dout, input, n1, n2, invvar, /* unused */ + invvar, U(epsilon), gamma, grad_input, true); +} + +} // namespace + +namespace gpu_ops { + +void rms_forward_affine_mixed_dtypes(cudaStream_t stream, void **buffers, + const char *opaque, + std::size_t opaque_len) { + const RMSNormDescriptor &d = + *UnpackDescriptor<RMSNormDescriptor>(opaque, opaque_len); + + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + d.x_type, d.w_type, "rms_norm_cuda_kernel", + HostApplyRMSNorm( + stream, static_cast(buffers[2]), + static_cast(buffers[3]), + static_cast(buffers[0]), d.n1, d.n2, d.eps, + /*gamma=*/static_cast(buffers[1]));) +} + +void rms_backward_affine(cudaStream_t stream, void **buffers, + const char *opaque, std::size_t opaque_len) { + const RMSNormDescriptor &d = + *UnpackDescriptor<RMSNormDescriptor>(opaque, opaque_len); + + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + d.x_type, d.w_type, "cuComputeGradInputRMS", + HostRMSNormGradient( + stream, + /*dout=*/static_cast(buffers[0]), + /*invvar=*/static_cast(buffers[1]), + /*input=*/static_cast(buffers[2]), d.n1, d.n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + /*gamma=*/static_cast(buffers[3]), d.eps, + /*grad_input=*/static_cast(buffers[4]), + /*grad_gamma=*/static_cast(buffers[5]), + d.part_grad_size, + /*part_grad_gamma=*/static_cast(buffers[6]));) +} + +} // namespace gpu_ops diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index 5aa4b0ecba6a..1f5cc0727605 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -1,4 +1,6 @@ -# GPU peformance tips +# GPU performance tips + + This document focuses on performance tips for neural network workloads @@ -23,6 +25,10 @@ code examples: ## XLA performance flags +```{note} + JAX-Toolbox also has a page on [NVIDIA XLA performance FLAGS](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/docs/GPU_performance.md). +``` + The existence and exact behavior of XLA flags may be `jaxlib`-version dependent. As of `jaxlib==0.4.18` (released [Oct 6 @@ -60,16 +66,12 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta ### Communication flags -* **--xla_gpu_enable_async_collectives** This flag enables the collective ops - such as `AllReduce`, `AllGather`, `ReduceScatter` and `CollectivePermute` to - be asynchronous. Asynchronous communication can overlap cross-core - communication with computation. The default value is False. * **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False. * **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism, this flag enables overlapping the (i+1)-th layer weight `AllGather` with the - i-th layer computation. It also enables enable overlapping (i+1)-th layer + i-th layer computation. It also enables overlapping (i+1)-th layer weight `Reduce`/`ReduceScatter` with i-th layer's computation. The default value is False. **There are some bugs when this flag is turned on.** * **--xla_gpu_collective_permute_decomposer_threshold** This flag is useful when @@ -107,3 +109,12 @@ os.environ.update({ These NCCL flags could improve single-host communication speed. These flags don't seem useful for multi-host communication yet. + +## Multi-Process + +We recommend using one process per GPU rather than one per node. In some +cases, this can speed up jitted computation. The +{func}`jax.distributed.initialize` API will automatically understand +that configuration when run under SLURM. However, this is only a rule of +thumb, and it may be useful to test both one process per GPU and one +process per node for your use case.
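+
+For example, a minimal per-process sketch (illustrative only; the script and
+launcher names are placeholders, and outside SLURM or MPI-launched jobs you
+may need to pass the coordinator address, process count and process id to
+`initialize()` explicitly):
+
+```python
+# Launched once per GPU, e.g. `srun --ntasks-per-node=8 python train.py`.
+import jax
+
+# Under SLURM (or under MPI when mpi4py is available), the coordinator
+# address, process count and process id are detected automatically.
+jax.distributed.initialize()
+
+# Each process typically drives one local GPU; jax.devices() is global.
+print(jax.process_index(), jax.local_device_count(), len(jax.devices()))
+```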
diff --git a/docs/index.rst b/docs/index.rst index bef957568f92..2e13c109dbbe 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,7 +1,8 @@ JAX: High-Performance Array Computing ===================================== -JAX is Autograd_ and XLA_, brought together for high-performance numerical computing. +JAX is a Python library for accelerator-oriented array computation and program transformation, +designed for high-performance numerical computing and large-scale machine learning. If you're looking to train neural networks, use Flax_ and start with its documentation. Some associated tools are Optax_ and Orbax_. @@ -60,8 +61,7 @@ For an end-to-end transformer library built on JAX, see MaxText_. :caption: Getting Started installation - notebooks/quickstart - notebooks/thinking_in_jax + quickstart notebooks/Common_Gotchas_in_JAX faq @@ -69,7 +69,7 @@ For an end-to-end transformer library built on JAX, see MaxText_. :hidden: :maxdepth: 1 - jax-101/index + tutorials .. toctree:: @@ -93,8 +93,6 @@ For an end-to-end transformer library built on JAX, see MaxText_. glossary -.. _Autograd: https://github.com/hips/autograd -.. _XLA: https://openxla.org/xla .. _Flax: https://flax.readthedocs.io/ .. _Orbax: https://orbax.readthedocs.io/ .. _Optax: https://optax.readthedocs.io/ diff --git a/docs/installation.md b/docs/installation.md index 379c7822acce..fa77d1fc29f6 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -1,47 +1,62 @@ +(installation)= # Installing JAX -JAX is written in pure Python, but it depends on XLA, which needs to be -installed as the `jaxlib` package. Use the following instructions to install a -binary package with `pip` or `conda`, to use a -[Docker container](#docker-containers-nvidia-gpu), or to [build JAX from -source](developer.md#building-from-source). + -## Supported platforms +Using JAX requires installing two packages: `jax`, which is pure Python and +cross-platform, and `jaxlib`, which contains compiled binaries and requires +different builds for different operating systems and accelerators. + +**TL;DR** For most users, a typical JAX installation may look something like this: + +* **CPU-only (Linux/macOS/Windows)** + ``` + pip install -U jax + ``` +* **GPU (NVIDIA, CUDA 12)** + ``` + pip install -U "jax[cuda12]" + ``` -| | Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac ARM | Windows x86_64 | Windows WSL2 x86_64 | -|------------|--------------|-------------------------|--------------|----------------|----------------|---------------------| -| CPU | [yes](#cpu) | [yes](#cpu) | [yes](#cpu) | [yes](#cpu) | [yes](#cpu) | [yes](#cpu) | -| NVIDIA GPU | [yes](#nvidia-gpu) | [yes](#nvidia-gpu) | no | n/a | no | [experimental](#nvidia-gpu) | -| Google TPU | [yes](#google-tpu) | n/a | n/a | n/a | n/a | n/a | -| AMD GPU | [experimental](#amd-gpu) | no | no | n/a | no | no | -| Apple GPU | n/a | no | [experimental](#apple-gpu) | [experimental](#apple-gpu) | n/a | n/a | +* **TPU (Google Cloud TPU VM)** + ``` + pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + ``` + +(install-supported-platforms)= ## Supported platforms +The table below shows all supported platforms and installation options.
Check if your setup is supported, and if it says _"yes"_ or _"experimental"_, then click on the corresponding link to learn how to install JAX in greater detail. -We support installing or building `jaxlib` on Linux (Ubuntu 20.04 or later) and -macOS (10.12 or later) platforms. There is also *experimental* native Windows -support. +| | Linux, x86_64 | Linux, aarch64 | macOS, Intel x86_64, AMD GPU | macOS, Apple Silicon, ARM-based | Windows, x86_64 | Windows WSL2, x86_64 | +|------------------|---------------------------------------|--------------------------------|----------------------------------------|----------------------------------------|-------------------------|-----------------------------------------| +| CPU | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>`| {ref}`yes <install-cpu>`| {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>`| +| NVIDIA GPU | {ref}`yes <install-nvidia-gpu>` | {ref}`yes <install-nvidia-gpu>` | no | n/a | no | {ref}`experimental <install-nvidia-gpu>` | +| Google Cloud TPU | {ref}`yes <install-google-tpu>` | n/a | n/a | n/a | n/a | n/a | +| AMD GPU | {ref}`experimental <install-amd-gpu>` | no | no | n/a | no | no | +| Apple GPU | n/a | no | {ref}`experimental <install-apple-gpu>` | {ref}`experimental <install-apple-gpu>` | n/a | n/a | -Windows users can use JAX on CPU and GPU via the [Windows Subsystem for -Linux](https://docs.microsoft.com/en-us/windows/wsl/about), or alternatively -they can use the native Windows CPU-only support. +(install-cpu)= ## CPU ### pip installation: CPU -We currently release `jaxlib` wheels for the following +Currently, the JAX team releases `jaxlib` wheels for the following operating systems and architectures: - -* Linux, x86-64 -* Mac, Intel -* Mac, ARM -* Windows, x86-64 (*experimental*) + +- Linux, x86_64 +- Linux, aarch64 +- macOS, Intel +- macOS, Apple ARM-based +- Windows, x86_64 (*experimental*) To install a CPU-only version of JAX, which might be useful for doing local -development on a laptop, you can run +development on a laptop, you can run: ```bash pip install --upgrade pip -pip install --upgrade "jax[cpu]" +pip install --upgrade jax ``` On Windows, you may also need to install the @@ -53,94 +68,100 @@ to pip install on other operating systems and architectures may lead to `jaxlib` not being installed alongside `jax`, although `jax` may successfully install (but fail at runtime). + +(install-nvidia-gpu)= ## NVIDIA GPU JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer. Note that Kepler-series GPUs are no longer supported by JAX since NVIDIA has dropped support for Kepler GPUs in its software. -You must first install the NVIDIA driver. We -recommend installing the newest driver available from NVIDIA, but the driver -must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux. +You must first install the NVIDIA driver. You're +recommended to install the newest driver available from NVIDIA, but the driver +version must be >= 525.60.13 for CUDA 12 on Linux. + If you need to use a newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be able to use the [CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/) that NVIDIA provides for this purpose. -### pip installation: GPU (CUDA, installed via pip, easier) +### pip installation: NVIDIA GPU (CUDA, installed via pip, easier) + +There are two ways to install JAX with NVIDIA GPU support: -There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN -installed from pip wheels, and using a self-installed CUDA/CUDNN.
We -strongly recommend installing CUDA and CUDNN using the pip wheels, since it is -much easier! This method is only supported on x86_64, because NVIDIA has not -released aarch64 CUDA pip packages. +- Using NVIDIA CUDA and cuDNN installed from pip wheels +- Using a self-installed CUDA/cuDNN + +The JAX team strongly recommends installing CUDA and cuDNN using the pip wheels, +since it is much easier! + +NVIDIA has released CUDA pip packages only for x86_64 and aarch64; on other +platforms you must use a local installation of CUDA. ```bash pip install --upgrade pip -# CUDA 12 installation +# NVIDIA CUDA 12 installation # Note: wheels only available on linux. -pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - -# CUDA 11 installation -# Note: wheels only available on linux. -pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +pip install --upgrade "jax[cuda12]" ``` -If JAX detects the wrong version of the CUDA libraries, there are several things -to check: -* make sure that `LD_LIBRARY_PATH` is not set, since `LD_LIBRARY_PATH` can - override the CUDA libraries. -* make sure that the CUDA libraries installed are those requested by JAX. +If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things +you need to check: + +* Make sure that `LD_LIBRARY_PATH` is not set, since `LD_LIBRARY_PATH` can + override the NVIDIA CUDA libraries. +* Make sure that the NVIDIA CUDA libraries installed are those requested by JAX. Rerunning the installation command above should work. -### pip installation: GPU (CUDA, installed locally, harder) +### pip installation: NVIDIA GPU (CUDA, installed locally, harder) -If you prefer to use a preinstalled copy of CUDA, you must first -install [CUDA](https://developer.nvidia.com/cuda-downloads) and -[CuDNN](https://developer.nvidia.com/CUDNN). +If you prefer to use a preinstalled copy of NVIDIA CUDA, you must first +install NVIDIA [CUDA](https://developer.nvidia.com/cuda-downloads) and +[cuDNN](https://developer.nvidia.com/CUDNN). -JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. Other +JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 and Linux aarch64 only**. Other combinations of operating system and architecture are possible, but require -[building from source](developer.md#building-from-source). +building from source (refer to {ref}`building-from-source` to learn more}. You should use an NVIDIA driver version that is at least as new as your -[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions). +[NVIDIA CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions). If you need to use a newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be able to use the [CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/) that NVIDIA provides for this purpose. 
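+
+Whichever CUDA installation route you choose, a quick way to confirm that the
+GPU backend is actually being used is shown below (an illustrative check added
+here for convenience, not part of the upstream instructions):
+
+```python
+import jax
+
+print(jax.__version__)
+print(jax.default_backend())  # expected to print "gpu" on a CUDA machine
+print(jax.devices())          # expected to list CUDA devices, not only CPU
+```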
-JAX currently ships two CUDA wheel variants: -* CUDA 12.3, cuDNN 8.9, NCCL 2.16 -* CUDA 11.8, cuDNN 8.6, NCCL 2.16 +JAX currently ships one CUDA wheel variant: + +| Built with | Compatible with | +|------------|--------------------| +| CUDA 12.3 | CUDA >=12.1 | +| CUDNN 9.0 | CUDNN >=9.0, <10.0 | +| NCCL 2.19 | NCCL >=2.18 | -You may use a JAX wheel provided the major version of your CUDA, cuDNN, and NCCL -installations match, and the minor versions are the same or newer. JAX checks the versions of your libraries, and will report an error if they are not sufficiently new. +Setting the `JAX_SKIP_CUDA_CONSTRAINTS_CHECK` environment variable will disable +the check, but using older versions of CUDA may lead to errors, or incorrect +results. NCCL is an optional dependency, required only if you are performing multi-GPU computations. -To install, run +To install, run: ```bash pip install --upgrade pip -# Installs the wheel compatible with CUDA 12 and cuDNN 8.9 or newer. -# Note: wheels only available on linux. -pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - -# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer. +# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer. # Note: wheels only available on linux. -pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +pip install --upgrade "jax[cuda12_local]" ``` -**These `pip` installations do not work with Windows, and may fail silently; see -[above](#installing-jax).** +**These `pip` installations do not work with Windows, and may fail silently; refer to the table +[above](#supported-platforms).** You can find your CUDA version with the command: @@ -152,48 +173,22 @@ JAX uses `LD_LIBRARY_PATH` to find CUDA libraries and `PATH` to find binaries (`ptxas`, `nvlink`). Please make sure that these paths point to the correct CUDA installation. -Please let us know on [the issue tracker](https://github.com/google/jax/issues) -if you run into any errors or problems with the prebuilt wheels. +JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package. +Make sure that it is present in your CUDA installation. + +Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues) +if you run into any errors or problems with the pre-built wheels. -### Docker containers: NVIDIA GPU +(docker-containers-nvidia-gpu)= +### NVIDIA GPU Docker containers NVIDIA provides the [JAX Toolbox](https://github.com/NVIDIA/JAX-Toolbox) containers, which are bleeding edge containers containing nightly releases of jax and some models/frameworks. -## Nightly installation - -Nightly releases reflect the state of the main repository at the time they are -built, and may not pass the full test suite. 
- -* JAX: -```bash -pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -``` - -* Jaxlib CPU: -```bash -pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -``` - -* Jaxlib TPU: -```bash -pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -pip install -U --pre libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -``` - -* Jaxlib GPU (Cuda 12): -```bash -pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html -``` - -* Jaxlib GPU (Cuda 11): -```bash -pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html -``` - -## Google TPU +(install-google-tpu)= +## Google Cloud TPU ### pip installation: Google Cloud TPU @@ -201,50 +196,53 @@ JAX provides pre-built wheels for [Google Cloud TPU](https://cloud.google.com/tpu/docs/users-guide-tpu-vm). To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can run the following in your cloud TPU VM: + ```bash pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` -For interactive notebook users: Colab TPUs no longer support JAX as of -JAX version 0.4. However, for an interactive TPU notebook in the cloud, you can -use [Kaggle TPU notebooks](https://www.kaggle.com/docs/tpu), which fully -support JAX. +For users of Colab (https://colab.research.google.com/), be sure you are +using *TPU v2* and not the older, deprecated TPU runtime. -## Apple GPU +(install-apple-gpu)= +## Apple Silicon GPU (ARM-based) -### pip installation: Apple GPUs +### pip installation: Apple ARM-based Silicon GPUs -Apple provides an experimental Metal plugin for Apple GPU hardware. For details, -see +Apple provides an experimental Metal plugin for Apple ARM-based GPU hardware. For details, +refer to [Apple's JAX on Metal documentation](https://developer.apple.com/metal/jax/). -There are several caveats with the Metal plugin: -* the Metal plugin is new and experimental and has a number of +**Note:** There are several caveats with the Metal plugin: + +* The Metal plugin is new and experimental and has a number of [known issues](https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22). Please report any issues on the JAX issue tracker. -* the Metal plugin currently requires very specific versions of `jax` and +* The Metal plugin currently requires very specific versions of `jax` and `jaxlib`. This restriction will be relaxed over time as the plugin API matures. +(install-amd-gpu)= ## AMD GPU -JAX has experimental ROCM support. There are two ways to install JAX: +JAX has experimental ROCm support. There are two ways to install JAX: -* use [AMD's docker container](https://hub.docker.com/r/rocm/jax), or -* [build from source](developer.md#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). +* Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax); or +* Build from source (refer to {ref}`building-from-source` — a section called _Additional notes for building a ROCM `jaxlib` for AMD GPUs_). -## Conda +## Conda (community-supported) ### Conda installation -There is a community-supported Conda build of `jax`. To install using `conda`, -simply run +There is a community-supported Conda build of `jax`. 
To install it using `conda`, +simply run: ```bash conda install jax -c conda-forge ``` -To install on a machine with an NVIDIA GPU, run +To install it on a machine with an NVIDIA GPU, run: + ```bash conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia ``` @@ -260,20 +258,57 @@ install the CUDA build on a machine without GPUs, follow the instructions in the [Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch) section of the `conda-forge` website. -See the `conda-forge` +Go to the `conda-forge` [jaxlib](https://github.com/conda-forge/jaxlib-feedstock#installing-jaxlib) and [jax](https://github.com/conda-forge/jax-feedstock#installing-jax) repositories for more details. -## Building JAX from source -See [Building JAX from source](developer.md#building-from-source). -## Installing older jaxlib wheels +## JAX nightly installation + +Nightly releases reflect the state of the main JAX repository at the time they are +built, and may not pass the full test suite. + +- CPU only: -Due to storage limitations on the Python package index, we periodically remove -older jaxlib wheels from the releases on http://pypi.org/project/jax. These can -still be installed directly via the URLs here; for example: +```bash +pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html ``` + +- Google Cloud TPU: + +```bash +pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +``` + +- NVIDIA GPU (CUDA 12): + +```bash +pip install -U --pre jax[cuda12] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +``` + +- NVIDIA GPU (CUDA 12) legacy: + +Use the following for historical nightly releases of monolithic CUDA jaxlibs. +You most likely do not want this; no further monolithic CUDA jaxlibs will be +built and those that exist will expire by Sep 2024. Use the "CUDA 12" option above. + +```bash +pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html +``` + +(building-jax-from-source)= +## Building JAX from source + +Refer to {ref}`building-from-source`. + +## Installing older `jaxlib` wheels + +Due to storage limitations on the Python package index, the JAX team periodically removes +older `jaxlib` wheels from the releases on http://pypi.org/project/jax. These can +still be installed directly via the URLs here. For example: + +```bash # Install jaxlib on CPU via the wheel archive pip install jax[cpu]==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html @@ -281,6 +316,6 @@ pip install jax[cpu]==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_ pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html ``` For specific older GPU wheels, be sure to use the `jax_cuda_releases.html` URL; for example -``` +```bash pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` diff --git a/docs/investigating_a_regression.md b/docs/investigating_a_regression.md index fb5293a82274..4affae3a65d8 100644 --- a/docs/investigating_a_regression.md +++ b/docs/investigating_a_regression.md @@ -1,6 +1,8 @@ (investigating-a-regression)= # Investigating a regression + + So you updated JAX and you hit a speed regression? You have a little bit of time and are ready to investigate this? 
Let's first make a JAX issue. diff --git a/docs/jax-101/01-jax-basics.ipynb b/docs/jax-101/01-jax-basics.ipynb deleted file mode 100644 index 79f64a5da61f..000000000000 --- a/docs/jax-101/01-jax-basics.ipynb +++ /dev/null @@ -1,833 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "6_117sy0CGEU" - }, - "source": [ - "# JAX As Accelerated NumPy\n", - "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/01-jax-basics.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/01-jax-basics.ipynb)\n", - "\n", - "*Authors: Rosalia Schneider & Vladimir Mikulik*\n", - "\n", - "In this first section you will learn the very fundamentals of JAX." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CXjHL4L6ku3-" - }, - "source": [ - "## Getting started with JAX numpy\n", - "\n", - "Fundamentally, JAX is a library that enables transformations of array-manipulating programs written with a NumPy-like API. \n", - "\n", - "Over the course of this series of guides, we will unpack exactly what that means. For now, you can think of JAX as *differentiable NumPy that runs on accelerators*.\n", - "\n", - "The code below shows how to import JAX and create a vector." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "ZqUzvqF1B1TO" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0 1 2 3 4 5 6 7 8 9]\n" - ] - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "x = jnp.arange(10)\n", - "print(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rPBmlAxXlBAy" - }, - "source": [ - "So far, everything is just like NumPy. A big appeal of JAX is that you don't need to learn a new API. Many common NumPy programs would run just as well in JAX if you substitute `np` for `jnp`. However, there are some important differences which we touch on at the end of this section.\n", - "\n", - "You can notice the first difference if you check the type of `x`. It is a variable of type `Array`, which is the way JAX represents arrays." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "3fLtgPUAn7mi" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)" - ] - }, - "execution_count": 2, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "x" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Yx8VofzzoHFH" - }, - "source": [ - "One useful feature of JAX is that the same code can be run on different backends -- CPU, GPU and TPU.\n", - "\n", - "We will now perform a dot product to demonstrate that it can be done in different devices without changing the code. We use `%timeit` to check the performance. \n", - "\n", - "(Technical detail: when a JAX function is called (including `jnp.array`\n", - "creation), the corresponding operation is dispatched to an accelerator to be\n", - "computed asynchronously when possible. The returned array is therefore not\n", - "necessarily 'filled in' as soon as the function returns. 
Thus, if we don't\n", - "require the result immediately, the computation won't block Python execution.\n", - "Therefore, unless we `block_until_ready` or convert the array to a regular\n", - "Python type, we will only time the dispatch, not the actual computation. See\n", - "[Asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html#asynchronous-dispatch)\n", - "in the JAX docs.)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "mRvjVxoqo-Bi" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The slowest run took 7.39 times longer than the fastest. This could mean that an intermediate result is being cached.\n", - "100 loops, best of 5: 7.85 ms per loop\n" - ] - } - ], - "source": [ - "long_vector = jnp.arange(int(1e7))\n", - "\n", - "%timeit jnp.dot(long_vector, long_vector).block_until_ready()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DKBB0zs-p-RC" - }, - "source": [ - "**Tip**: Try running the code above twice, once without an accelerator, and once with a GPU runtime (while in Colab, click *Runtime* → *Change Runtime Type* and choose `GPU`). Notice how much faster it runs on a GPU." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PkCpI-v0uQQO" - }, - "source": [ - "## JAX first transformation: `grad`\n", - "\n", - "A fundamental feature of JAX is that it allows you to transform functions.\n", - "\n", - "One of the most commonly used transformations is `jax.grad`, which takes a numerical function written in Python and returns you a new Python function that computes the gradient of the original function. \n", - "\n", - "To use it, let's first define a function that takes an array and returns the sum of squares." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "LuaGUVRUvbzQ" - }, - "outputs": [], - "source": [ - "def sum_of_squares(x):\n", - " return jnp.sum(x**2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QAqloI1Wvtp2" - }, - "source": [ - "Applying `jax.grad` to `sum_of_squares` will return a different function, namely the gradient of `sum_of_squares` with respect to its first parameter `x`. \n", - "\n", - "Then, you can use that function on an array to return the derivatives with respect to each element of the array." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "dKeorwJfvpeI" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "30.0\n", - "[2. 4. 6. 8.]\n" - ] - } - ], - "source": [ - "sum_of_squares_dx = jax.grad(sum_of_squares)\n", - "\n", - "x = jnp.asarray([1.0, 2.0, 3.0, 4.0])\n", - "\n", - "print(sum_of_squares(x))\n", - "\n", - "print(sum_of_squares_dx(x))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VfBt5CYbyKUX" - }, - "source": [ - "You can think of `jax.grad` by analogy to the $\\nabla$ operator from vector calculus. 
Given a function $f(x)$, $\\nabla f$ represents the function that computes $f$'s gradient, i.e.\n", - "\n", - "$$\n", - "(\\nabla f)(x)_i = \\frac{\\partial f}{\\partial x_i}(x).\n", - "$$\n", - "\n", - "Analogously, `jax.grad(f)` is the function that computes the gradient, so `jax.grad(f)(x)` is the gradient of `f` at `x`.\n", - "\n", - "(Like $\\nabla$, `jax.grad` will only work on functions with a scalar output -- it will raise an error otherwise.)\n", - "\n", - "This makes the JAX API quite different from other autodiff libraries like Tensorflow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math.\n", - "\n", - "This way of doing things makes it straightforward to control things like which variables to differentiate with respect to. By default, `jax.grad` will find the gradient with respect to the first argument. In the example below, the result of `sum_squared_error_dx` will be the gradient of `sum_squared_error` with respect to `x`." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "f3NfaVu4yrQE" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[-0.20000005 -0.19999981 -0.19999981 -0.19999981]\n" - ] - } - ], - "source": [ - "def sum_squared_error(x, y):\n", - " return jnp.sum((x-y)**2)\n", - "\n", - "sum_squared_error_dx = jax.grad(sum_squared_error)\n", - "\n", - "y = jnp.asarray([1.1, 2.1, 3.1, 4.1])\n", - "\n", - "print(sum_squared_error_dx(x, y))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1tOztA5zpLWN" - }, - "source": [ - "To find the gradient with respect to a different argument (or several), you can set `argnums`:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "FQSczVQkqIPY" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n", - " Array([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))" - ] - }, - "execution_count": 7, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "jax.grad(sum_squared_error, argnums=(0, 1))(x, y) # Find gradient wrt both x & y" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yQAMTnZSqo-t" - }, - "source": [ - "Does this mean that when doing machine learning, we need to write functions with gigantic argument lists, with an argument for each model parameter array? No. JAX comes equipped with machinery for bundling arrays together in data structures called 'pytrees', on which more in a [later guide](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb). So, most often, use of `jax.grad` looks like this:\n", - "\n", - "```\n", - "def loss_fn(params, data):\n", - " ...\n", - "\n", - "grads = jax.grad(loss_fn)(params, data_batch)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oBowiovisT97" - }, - "source": [ - "where `params` is, for example, a nested dict of arrays, and the returned `grads` is another nested dict of arrays with the same structure." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LNjf9jUEsZZ8" - }, - "source": [ - "## Value and Grad\n", - "\n", - "Often, you need to find both the value and the gradient of a function, e.g. if you want to log the training loss. JAX has a handy sister transformation for efficiently doing that:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "dWg4_-h3sYwl" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(Array(0.03999995, dtype=float32),\n", - " Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))" - ] - }, - "execution_count": 8, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "jax.value_and_grad(sum_squared_error)(x, y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QVT2EWHJsvvv" - }, - "source": [ - "which returns a tuple of, you guessed it, (value, grad). To be precise, for any `f`,\n", - "\n", - "```\n", - "jax.value_and_grad(f)(*xs) == (f(*xs), jax.grad(f)(*xs)) \n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QmHTVpAks3OX" - }, - "source": [ - "## Auxiliary data\n", - "\n", - "In addition to wanting to log the value, we often want to report some intermediate results obtained in computing the loss function. But if we try doing that with regular `jax.grad`, we run into trouble:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "ffGCEzT4st41", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "TypeError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mFilteredStackTrace\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquared_error_with_aux\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mFilteredStackTrace\u001b[0m: TypeError: Gradient only defined for scalar-output functions. Output was (Array(0.03999995, dtype=float32), Array([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).\n\nThe stack trace above excludes JAX-internal frames." - ] - } - ], - "source": [ - "def squared_error_with_aux(x, y):\n", - " return sum_squared_error(x, y), x-y\n", - "\n", - "jax.grad(squared_error_with_aux)(x, y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IUubno3nth4i" - }, - "source": [ - "This is because `jax.grad` is only defined on scalar functions, and our new function returns a tuple. But we need to return a tuple to return our intermediate results! 
This is where `has_aux` comes in:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "uzUFihyatgiF" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n", - " Array([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))" - ] - }, - "execution_count": 10, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "jax.grad(squared_error_with_aux, has_aux=True)(x, y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "g5s3UiFauwDk" - }, - "source": [ - "`has_aux` signifies that the function returns a pair, `(out, aux)`. It makes `jax.grad` ignore `aux`, passing it through to the user, while differentiating the function as if only `out` was returned." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fk4FUXe7vsW4" - }, - "source": [ - "## Differences from NumPy\n", - "\n", - "The `jax.numpy` API closely follows that of NumPy. However, there are some important differences. We cover many of these in future guides, but it's worth pointing some out now.\n", - "\n", - "The most important difference, and in some sense the root of all the rest, is that JAX is designed to be _functional_, as in _functional programming_. The reason behind this is that the kinds of program transformations that JAX enables are much more feasible in functional-style programs.\n", - "\n", - "An introduction to functional programming (FP) is out of scope of this guide. If you already are familiar with FP, you will find your FP intuition helpful while learning JAX. If not, don't worry! The important feature of functional programming to grok when working with JAX is very simple: don't write code with side-effects.\n", - "\n", - "A side-effect is any effect of a function that doesn't appear in its output. One example is modifying an array in place:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "o_YBuLQC1wPJ" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "array([123, 2, 3])" - ] - }, - "execution_count": 11, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy as np\n", - "\n", - "x = np.array([1, 2, 3])\n", - "\n", - "def in_place_modify(x):\n", - " x[0] = 123\n", - " return None\n", - "\n", - "in_place_modify(x)\n", - "x" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JTtUihVZ13F6" - }, - "source": [ - "The side-effectful function modifies its argument, but returns a completely unrelated value. The modification is a side-effect. \n", - "\n", - "The code below will run in NumPy. 
However, JAX arrays won't allow themselves to be modified in-place:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "u6grTYIVcZ3f", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "TypeError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0min_place_modify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Raises error if we cast input to jnp.ndarray\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m\u001b[0m in \u001b[0;36min_place_modify\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0min_place_modify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m123\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py\u001b[0m in \u001b[0;36m_unimplemented_setitem\u001b[0;34m(self, i, x)\u001b[0m\n\u001b[1;32m 6594\u001b[0m \u001b[0;34m\"or another .at[] method: \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6595\u001b[0m \"https://jax.readthedocs.io/en/latest/jax.ops.html\")\n\u001b[0;32m-> 6596\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6597\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6598\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_operator_round\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumber\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mndigits\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mTypeError\u001b[0m: '' object does not support item assignment. JAX arrays are immutable. 
Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html" - ] - } - ], - "source": [ - "in_place_modify(jnp.array(x)) # Raises error if we cast input to jnp.ndarray" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RGqVfYSpc49s" - }, - "source": [ - "Helpfully, the error points us to JAX's side-effect-free way of doing the same thing via the [`jax.numpy.ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) index update operators (be careful [`jax.ops.index_*`](https://jax.readthedocs.io/en/latest/jax.ops.html#indexed-update-functions-deprecated) functions are deprecated). They are analogous to in-place modification by index, but create a new array with the corresponding modifications made:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "Rmklk6BB2xF0" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([123, 2, 3], dtype=int32)" - ] - }, - "execution_count": 13, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "def jax_in_place_modify(x):\n", - " return x.at[0].set(123)\n", - "\n", - "y = jnp.array([1, 2, 3])\n", - "jax_in_place_modify(y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "91tn_25vdrNf" - }, - "source": [ - "Note that the old array was untouched, so there is no side-effect:" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "KQGXig4Hde6T" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([1, 2, 3], dtype=int32)" - ] - }, - "execution_count": 14, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "y" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d5TibzPO25qa" - }, - "source": [ - "Side-effect-free code is sometimes called *functionally pure*, or just *pure*.\n", - "\n", - "Isn't the pure version less efficient? Strictly, yes; we are creating a new array. However, as we will explain in the next guide, JAX computations are often compiled before being run using another program transformation, `jax.jit`. If we don't use the old array after modifying it 'in place' using indexed update operators, the compiler can recognise that it can in fact compile to an in-place modify, resulting in efficient code in the end.\n", - "\n", - "Of course, it's possible to mix side-effectful Python code and functionally pure JAX code, and we will touch on this more later. As you get more familiar with JAX, you will learn how and when this can work. As a rule of thumb, however, any functions intended to be transformed by JAX should avoid side-effects, and the JAX primitives themselves will try to help you do that.\n", - "\n", - "We will explain other places where the JAX idiosyncrasies become relevant as they come up. There is even a section that focuses entirely on getting used to the functional programming style of handling state: [Part 7: Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). However, if you're impatient, you can find a [summary of JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in the JAX docs." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dFn_VBFFlGCz" - }, - "source": [ - "## Your first JAX training loop\n", - "\n", - "We still have much to learn about JAX, but you already know enough to understand how we can use JAX to build a simple training loop.\n", - "\n", - "To keep things simple, we'll start with a linear regression.\n", - "\n", - "Our data is sampled according to $y = w_{true} x + b_{true} + \\epsilon$." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "WGgyEWFqrPq1" - }, - "outputs": [ - { - "data": { - "image/png": "[base64-encoded PNG data omitted: the figure is the scatter plot of the sampled (xs, ys) data produced by the plt.scatter call in this cell]\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light", - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "\n", - "xs = np.random.normal(size=(100,))\n", - "noise = np.random.normal(scale=0.1, size=(100,))\n", - "ys = xs * 3 - 1 + noise\n", - "\n", - "plt.scatter(xs, ys);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RTh22mo4rR1x" - }, - "source": [ - "Therefore, our model is $\\hat y(x; \\theta) = wx + b$.\n", - "\n", - "We will use a single array, `theta = [w, b]` to house both parameters:" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "TnVrRTMamyzb" - }, - "outputs": [], - "source": [ - "def model(theta, x):\n", - " \"\"\"Computes wx + b on a batch of input x.\"\"\"\n", - " w, b = theta\n", - " return w * x + b" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qCrLmmKrn9_h" - }, - "source": [ - "The loss function is $J(x, y; \\theta) = (\\hat y - y)^2$." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "07eMcDLMn9Ww" - }, - "outputs": [], - "source": [ - "def loss_fn(theta, x, y):\n", - " prediction = model(theta, x)\n", - " return jnp.mean((prediction-y)**2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ejMt4dulnoYX" - }, - "source": [ - "How do we optimize a loss function? Using gradient descent. At each update step, we will find the gradient of the loss w.r.t. the parameters, and take a small step in the direction of steepest descent:\n", - "\n", - "$\\theta_{new} = \\theta - 0.1 (\\nabla_\\theta J) (x, y; \\theta)$" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "id": "2I6T5Wphpaaa" - }, - "outputs": [], - "source": [ - "def update(theta, x, y, lr=0.1):\n", - " return theta - lr * jax.grad(loss_fn)(theta, x, y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MAUL1gT_opVn" - }, - "source": [ - "In JAX, it's common to define an `update()` function that is called every step, taking the current parameters as input and returning the new parameters. This is a natural consequence of JAX's functional nature, and is explained in more detail in [The Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb).\n", - "\n", - "This function can then be JIT-compiled in its entirety for maximum efficiency. The next guide will explain exactly how `jax.jit` works, but if you want to, you can try adding `@jax.jit` before the `update()` definition, and see how the training loop below runs much faster." 
- ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "id": "WLZxY7nIpuVW" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "w: 3.00, b: -1.00\n" - ] - }, - { - "data": { - "image/png": "[base64-encoded PNG data omitted: the figure is the (xs, ys) scatter plot with the fitted line model(theta, xs) overlaid, produced by this cell]\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light", - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "theta = jnp.array([1., 1.])\n", - "\n", - "for _ in range(1000):\n", - " theta = update(theta, xs, ys)\n", - "\n", - "plt.scatter(xs, ys)\n", - "plt.plot(xs, model(theta, xs))\n", - "\n", - "w, b = theta\n", - "print(f\"w: {w:<.2f}, b: {b:<.2f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5-q17kJ_rjLc" - }, - "source": [ - "As you will see going through these guides, this basic recipe underlies almost all training loops you'll see implemented in JAX. The main difference between this example and real training loops is the simplicity of our model: that allows us to use a single array to house all our parameters. We cover managing more parameters in the later [pytree guide](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb). Feel free to skip forward to that guide now to see how to manually define and train a simple MLP in JAX." - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "Jax Basics.ipynb", - "provenance": [] - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/jax-101/01-jax-basics.md b/docs/jax-101/01-jax-basics.md deleted file mode 100644 index 45959796abba..000000000000 --- a/docs/jax-101/01-jax-basics.md +++ /dev/null @@ -1,383 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.0 -kernelspec: - display_name: Python 3 - name: python3 ---- - -+++ {"id": "6_117sy0CGEU"} - -# JAX As Accelerated NumPy - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/01-jax-basics.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/01-jax-basics.ipynb) - -*Authors: Rosalia Schneider & Vladimir Mikulik* - -In this first section you will learn the very fundamentals of JAX. - -+++ {"id": "CXjHL4L6ku3-"} - -## Getting started with JAX numpy - -Fundamentally, JAX is a library that enables transformations of array-manipulating programs written with a NumPy-like API. - -Over the course of this series of guides, we will unpack exactly what that means. For now, you can think of JAX as *differentiable NumPy that runs on accelerators*. - -The code below shows how to import JAX and create a vector. - -```{code-cell} ipython3 -:id: ZqUzvqF1B1TO - -import jax -import jax.numpy as jnp - -x = jnp.arange(10) -print(x) -``` - -+++ {"id": "rPBmlAxXlBAy"} - -So far, everything is just like NumPy. A big appeal of JAX is that you don't need to learn a new API. Many common NumPy programs would run just as well in JAX if you substitute `np` for `jnp`. However, there are some important differences which we touch on at the end of this section. - -You can notice the first difference if you check the type of `x`. It is a variable of type `Array`, which is the way JAX represents arrays. 
- -```{code-cell} ipython3 -:id: 3fLtgPUAn7mi - -x -``` - -+++ {"id": "Yx8VofzzoHFH"} - -One useful feature of JAX is that the same code can be run on different backends -- CPU, GPU and TPU. - -We will now perform a dot product to demonstrate that it can be done in different devices without changing the code. We use `%timeit` to check the performance. - -(Technical detail: when a JAX function is called (including `jnp.array` -creation), the corresponding operation is dispatched to an accelerator to be -computed asynchronously when possible. The returned array is therefore not -necessarily 'filled in' as soon as the function returns. Thus, if we don't -require the result immediately, the computation won't block Python execution. -Therefore, unless we `block_until_ready` or convert the array to a regular -Python type, we will only time the dispatch, not the actual computation. See -[Asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html#asynchronous-dispatch) -in the JAX docs.) - -```{code-cell} ipython3 -:id: mRvjVxoqo-Bi - -long_vector = jnp.arange(int(1e7)) - -%timeit jnp.dot(long_vector, long_vector).block_until_ready() -``` - -+++ {"id": "DKBB0zs-p-RC"} - -**Tip**: Try running the code above twice, once without an accelerator, and once with a GPU runtime (while in Colab, click *Runtime* → *Change Runtime Type* and choose `GPU`). Notice how much faster it runs on a GPU. - -+++ {"id": "PkCpI-v0uQQO"} - -## JAX first transformation: `grad` - -A fundamental feature of JAX is that it allows you to transform functions. - -One of the most commonly used transformations is `jax.grad`, which takes a numerical function written in Python and returns you a new Python function that computes the gradient of the original function. - -To use it, let's first define a function that takes an array and returns the sum of squares. - -```{code-cell} ipython3 -:id: LuaGUVRUvbzQ - -def sum_of_squares(x): - return jnp.sum(x**2) -``` - -+++ {"id": "QAqloI1Wvtp2"} - -Applying `jax.grad` to `sum_of_squares` will return a different function, namely the gradient of `sum_of_squares` with respect to its first parameter `x`. - -Then, you can use that function on an array to return the derivatives with respect to each element of the array. - -```{code-cell} ipython3 -:id: dKeorwJfvpeI - -sum_of_squares_dx = jax.grad(sum_of_squares) - -x = jnp.asarray([1.0, 2.0, 3.0, 4.0]) - -print(sum_of_squares(x)) - -print(sum_of_squares_dx(x)) -``` - -+++ {"id": "VfBt5CYbyKUX"} - -You can think of `jax.grad` by analogy to the $\nabla$ operator from vector calculus. Given a function $f(x)$, $\nabla f$ represents the function that computes $f$'s gradient, i.e. - -$$ -(\nabla f)(x)_i = \frac{\partial f}{\partial x_i}(x). -$$ - -Analogously, `jax.grad(f)` is the function that computes the gradient, so `jax.grad(f)(x)` is the gradient of `f` at `x`. - -(Like $\nabla$, `jax.grad` will only work on functions with a scalar output -- it will raise an error otherwise.) - -This makes the JAX API quite different from other autodiff libraries like Tensorflow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math. 
- -This way of doing things makes it straightforward to control things like which variables to differentiate with respect to. By default, `jax.grad` will find the gradient with respect to the first argument. In the example below, the result of `sum_squared_error_dx` will be the gradient of `sum_squared_error` with respect to `x`. - -```{code-cell} ipython3 -:id: f3NfaVu4yrQE - -def sum_squared_error(x, y): - return jnp.sum((x-y)**2) - -sum_squared_error_dx = jax.grad(sum_squared_error) - -y = jnp.asarray([1.1, 2.1, 3.1, 4.1]) - -print(sum_squared_error_dx(x, y)) -``` - -+++ {"id": "1tOztA5zpLWN"} - -To find the gradient with respect to a different argument (or several), you can set `argnums`: - -```{code-cell} ipython3 -:id: FQSczVQkqIPY - -jax.grad(sum_squared_error, argnums=(0, 1))(x, y) # Find gradient wrt both x & y -``` - -+++ {"id": "yQAMTnZSqo-t"} - -Does this mean that when doing machine learning, we need to write functions with gigantic argument lists, with an argument for each model parameter array? No. JAX comes equipped with machinery for bundling arrays together in data structures called 'pytrees', on which more in a [later guide](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb). So, most often, use of `jax.grad` looks like this: - -``` -def loss_fn(params, data): - ... - -grads = jax.grad(loss_fn)(params, data_batch) -``` - -+++ {"id": "oBowiovisT97"} - -where `params` is, for example, a nested dict of arrays, and the returned `grads` is another nested dict of arrays with the same structure. - -+++ {"id": "LNjf9jUEsZZ8"} - -## Value and Grad - -Often, you need to find both the value and the gradient of a function, e.g. if you want to log the training loss. JAX has a handy sister transformation for efficiently doing that: - -```{code-cell} ipython3 -:id: dWg4_-h3sYwl - -jax.value_and_grad(sum_squared_error)(x, y) -``` - -+++ {"id": "QVT2EWHJsvvv"} - -which returns a tuple of, you guessed it, (value, grad). To be precise, for any `f`, - -``` -jax.value_and_grad(f)(*xs) == (f(*xs), jax.grad(f)(*xs)) -``` - -+++ {"id": "QmHTVpAks3OX"} - -## Auxiliary data - -In addition to wanting to log the value, we often want to report some intermediate results obtained in computing the loss function. But if we try doing that with regular `jax.grad`, we run into trouble: - -```{code-cell} ipython3 -:id: ffGCEzT4st41 -:tags: [raises-exception] - -def squared_error_with_aux(x, y): - return sum_squared_error(x, y), x-y - -jax.grad(squared_error_with_aux)(x, y) -``` - -+++ {"id": "IUubno3nth4i"} - -This is because `jax.grad` is only defined on scalar functions, and our new function returns a tuple. But we need to return a tuple to return our intermediate results! This is where `has_aux` comes in: - -```{code-cell} ipython3 -:id: uzUFihyatgiF - -jax.grad(squared_error_with_aux, has_aux=True)(x, y) -``` - -+++ {"id": "g5s3UiFauwDk"} - -`has_aux` signifies that the function returns a pair, `(out, aux)`. It makes `jax.grad` ignore `aux`, passing it through to the user, while differentiating the function as if only `out` was returned. - -+++ {"id": "fk4FUXe7vsW4"} - -## Differences from NumPy - -The `jax.numpy` API closely follows that of NumPy. However, there are some important differences. We cover many of these in future guides, but it's worth pointing some out now. - -The most important difference, and in some sense the root of all the rest, is that JAX is designed to be _functional_, as in _functional programming_. 
The reason behind this is that the kinds of program transformations that JAX enables are much more feasible in functional-style programs. - -An introduction to functional programming (FP) is out of scope of this guide. If you already are familiar with FP, you will find your FP intuition helpful while learning JAX. If not, don't worry! The important feature of functional programming to grok when working with JAX is very simple: don't write code with side-effects. - -A side-effect is any effect of a function that doesn't appear in its output. One example is modifying an array in place: - -```{code-cell} ipython3 -:id: o_YBuLQC1wPJ - -import numpy as np - -x = np.array([1, 2, 3]) - -def in_place_modify(x): - x[0] = 123 - return None - -in_place_modify(x) -x -``` - -+++ {"id": "JTtUihVZ13F6"} - -The side-effectful function modifies its argument, but returns a completely unrelated value. The modification is a side-effect. - -The code below will run in NumPy. However, JAX arrays won't allow themselves to be modified in-place: - -```{code-cell} ipython3 -:id: u6grTYIVcZ3f -:tags: [raises-exception] - -in_place_modify(jnp.array(x)) # Raises error if we cast input to jnp.ndarray -``` - -+++ {"id": "RGqVfYSpc49s"} - -Helpfully, the error points us to JAX's side-effect-free way of doing the same thing via the [`jax.numpy.ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) index update operators (be careful [`jax.ops.index_*`](https://jax.readthedocs.io/en/latest/jax.ops.html#indexed-update-functions-deprecated) functions are deprecated). They are analogous to in-place modification by index, but create a new array with the corresponding modifications made: - -```{code-cell} ipython3 -:id: Rmklk6BB2xF0 - -def jax_in_place_modify(x): - return x.at[0].set(123) - -y = jnp.array([1, 2, 3]) -jax_in_place_modify(y) -``` - -+++ {"id": "91tn_25vdrNf"} - -Note that the old array was untouched, so there is no side-effect: - -```{code-cell} ipython3 -:id: KQGXig4Hde6T - -y -``` - -+++ {"id": "d5TibzPO25qa"} - -Side-effect-free code is sometimes called *functionally pure*, or just *pure*. - -Isn't the pure version less efficient? Strictly, yes; we are creating a new array. However, as we will explain in the next guide, JAX computations are often compiled before being run using another program transformation, `jax.jit`. If we don't use the old array after modifying it 'in place' using indexed update operators, the compiler can recognise that it can in fact compile to an in-place modify, resulting in efficient code in the end. - -Of course, it's possible to mix side-effectful Python code and functionally pure JAX code, and we will touch on this more later. As you get more familiar with JAX, you will learn how and when this can work. As a rule of thumb, however, any functions intended to be transformed by JAX should avoid side-effects, and the JAX primitives themselves will try to help you do that. - -We will explain other places where the JAX idiosyncrasies become relevant as they come up. There is even a section that focuses entirely on getting used to the functional programming style of handling state: [Part 7: Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). However, if you're impatient, you can find a [summary of JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in the JAX docs. 
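To make the point about indexed updates concrete, here is a minimal sketch (the function name `zero_out_first` is illustrative, and `jax.jit` is only previewed here since it is covered in the next guide): a pure function built on `.at[].set()` returns a fresh array, and under `jit` the compiler may turn it into an in-place update when the old value is not reused.

```python
import jax
import jax.numpy as jnp

@jax.jit
def zero_out_first(x):
    # Pure: returns a new array instead of mutating x.
    return x.at[0].set(0)

print(zero_out_first(jnp.array([1, 2, 3])))  # [0 2 3]
```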
- -+++ {"id": "dFn_VBFFlGCz"} - -## Your first JAX training loop - -We still have much to learn about JAX, but you already know enough to understand how we can use JAX to build a simple training loop. - -To keep things simple, we'll start with a linear regression. - -Our data is sampled according to $y = w_{true} x + b_{true} + \epsilon$. - -```{code-cell} ipython3 -:id: WGgyEWFqrPq1 - -import numpy as np -import matplotlib.pyplot as plt - -xs = np.random.normal(size=(100,)) -noise = np.random.normal(scale=0.1, size=(100,)) -ys = xs * 3 - 1 + noise - -plt.scatter(xs, ys); -``` - -+++ {"id": "RTh22mo4rR1x"} - -Therefore, our model is $\hat y(x; \theta) = wx + b$. - -We will use a single array, `theta = [w, b]` to house both parameters: - -```{code-cell} ipython3 -:id: TnVrRTMamyzb - -def model(theta, x): - """Computes wx + b on a batch of input x.""" - w, b = theta - return w * x + b -``` - -+++ {"id": "qCrLmmKrn9_h"} - -The loss function is $J(x, y; \theta) = (\hat y - y)^2$. - -```{code-cell} ipython3 -:id: 07eMcDLMn9Ww - -def loss_fn(theta, x, y): - prediction = model(theta, x) - return jnp.mean((prediction-y)**2) -``` - -+++ {"id": "ejMt4dulnoYX"} - -How do we optimize a loss function? Using gradient descent. At each update step, we will find the gradient of the loss w.r.t. the parameters, and take a small step in the direction of steepest descent: - -$\theta_{new} = \theta - 0.1 (\nabla_\theta J) (x, y; \theta)$ - -```{code-cell} ipython3 -:id: 2I6T5Wphpaaa - -def update(theta, x, y, lr=0.1): - return theta - lr * jax.grad(loss_fn)(theta, x, y) -``` - -+++ {"id": "MAUL1gT_opVn"} - -In JAX, it's common to define an `update()` function that is called every step, taking the current parameters as input and returning the new parameters. This is a natural consequence of JAX's functional nature, and is explained in more detail in [The Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). - -This function can then be JIT-compiled in its entirety for maximum efficiency. The next guide will explain exactly how `jax.jit` works, but if you want to, you can try adding `@jax.jit` before the `update()` definition, and see how the training loop below runs much faster. - -```{code-cell} ipython3 -:id: WLZxY7nIpuVW - -theta = jnp.array([1., 1.]) - -for _ in range(1000): - theta = update(theta, xs, ys) - -plt.scatter(xs, ys) -plt.plot(xs, model(theta, xs)) - -w, b = theta -print(f"w: {w:<.2f}, b: {b:<.2f}") -``` - -+++ {"id": "5-q17kJ_rjLc"} - -As you will see going through these guides, this basic recipe underlies almost all training loops you'll see implemented in JAX. The main difference between this example and real training loops is the simplicity of our model: that allows us to use a single array to house all our parameters. We cover managing more parameters in the later [pytree guide](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb). Feel free to skip forward to that guide now to see how to manually define and train a simple MLP in JAX. 
diff --git a/docs/jax-101/02-jitting.ipynb b/docs/jax-101/02-jitting.ipynb deleted file mode 100644 index d72b310531e6..000000000000 --- a/docs/jax-101/02-jitting.ipynb +++ /dev/null @@ -1,673 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "O-SkdlPxvETZ" - }, - "source": [ - "# Just In Time Compilation with JAX\n", - "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/02-jitting.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/02-jitting.ipynb)\n", - "\n", - "*Authors: Rosalia Schneider & Vladimir Mikulik*\n", - "\n", - "In this section, we will further explore how JAX works, and how we can make it performant.\n", - "We will discuss the `jax.jit()` transform, which will perform *Just In Time* (JIT) compilation\n", - "of a JAX Python function so it can be executed efficiently in XLA.\n", - "\n", - "## How JAX transforms work\n", - "\n", - "In the previous section, we discussed that JAX allows us to transform Python functions. This is done by first converting the Python function into a simple intermediate language called jaxpr. The transformations then work on the jaxpr representation. \n", - "\n", - "We can show a representation of the jaxpr of a function by using `jax.make_jaxpr`:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "P9Xj77Wx3Z2P", - "outputId": "5a0597eb-86c9-4762-ce10-2811debbc732" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{ lambda ; a:f32[]. let\n", - " b:f32[] = log a\n", - " c:f32[] = log 2.0\n", - " d:f32[] = div b c\n", - " in (d,) }\n" - ] - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "global_list = []\n", - "\n", - "def log2(x):\n", - " global_list.append(x)\n", - " ln_x = jnp.log(x)\n", - " ln_2 = jnp.log(2.0)\n", - " return ln_x / ln_2\n", - "\n", - "print(jax.make_jaxpr(log2)(3.0))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jiDsT7y0RwIp" - }, - "source": [ - "The [Understanding Jaxprs](https://jax.readthedocs.io/en/latest/jaxpr.html) section of the documentation provides more information on the meaning of the above output.\n", - "\n", - "Importantly, note how the jaxpr does not capture the side-effect of the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX is designed to understand side-effect-free (a.k.a. functionally pure) code. If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).\n", - "\n", - "Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour once converted to jaxpr. However, as a rule of thumb, you can expect (but shouldn't rely on) the side-effects of a JAX-transformed function to run once (during the first call), and never again. This is because of the way that JAX generates jaxpr, using a process called 'tracing'.\n", - "\n", - "When tracing, JAX wraps each argument by a *tracer* object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). 
Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record the Python side-effects, they do not appear in the jaxpr. However, the side-effects still happen during the trace itself.\n", - "\n", - "Note: the Python `print()` function is not pure: the text output is a side-effect of the function. Therefore, any `print()` calls will only happen during tracing, and will not appear in the jaxpr:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JxV2p7e2RawC", - "outputId": "9dfe8a56-e553-4640-a04e-5405aea7832d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "printed x: Tracedwith\n", - "{ lambda ; a:f32[]. let\n", - " b:f32[] = log a\n", - " c:f32[] = log 2.0\n", - " d:f32[] = div b c\n", - " in (d,) }\n" - ] - } - ], - "source": [ - "def log2_with_print(x):\n", - " print(\"printed x:\", x)\n", - " ln_x = jnp.log(x)\n", - " ln_2 = jnp.log(2.0)\n", - " return ln_x / ln_2\n", - "\n", - "print(jax.make_jaxpr(log2_with_print)(3.))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "f6W_YYwRRwGp" - }, - "source": [ - "See how the printed `x` is a `Traced` object? That's the JAX internals at work.\n", - "\n", - "The fact that the Python code runs at least once is strictly an implementation detail, and so shouldn't be relied upon. However, it's useful to understand as you can use it when debugging to print out intermediate values of a computation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PgVqi6NlRdWZ" - }, - "source": [ - "A key thing to understand is that jaxpr captures the function as executed on the parameters given to it. For example, if we have a conditional, jaxpr will only know about the branch we take:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "hn0CuphEZKZm", - "outputId": "99dae727-d2be-4577-831c-e1e14af5890a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{ lambda ; a:i32[3]. 
let in (a,) }\n" - ] - } - ], - "source": [ - "def log2_if_rank_2(x):\n", - " if x.ndim == 2:\n", - " ln_x = jnp.log(x)\n", - " ln_2 = jnp.log(2.0)\n", - " return ln_x / ln_2\n", - " else:\n", - " return x\n", - "\n", - "print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Qp3WhqaqvHyD" - }, - "source": [ - "## JIT compiling a function\n", - "\n", - "As explained before, JAX enables operations to execute on CPU/GPU/TPU using the same code.\n", - "Let's look at an example of computing a *Scaled Exponential Linear Unit*\n", - "([SELU](https://proceedings.neurips.cc/paper/6698-self-normalizing-neural-networks.pdf)), an\n", - "operation commonly used in deep learning:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "JAXFYtlRvD6p", - "outputId": "e94d7dc2-a9a1-4ac2-fd3f-152e3f6d141b" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "100 loops, best of 5: 2.05 ms per loop\n" - ] - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "def selu(x, alpha=1.67, lambda_=1.05):\n", - " return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n", - "\n", - "x = jnp.arange(1000000)\n", - "%timeit selu(x).block_until_ready()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ecN5lEXe6ncy" - }, - "source": [ - "The code above is sending one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize our functions.\n", - "\n", - "Naturally, what we want to do is give the XLA compiler as much code as possible, so it can fully optimize it. For this purpose, JAX provides the `jax.jit` transformation, which will JIT compile a JAX-compatible function. The example below shows how to use JIT to speed up the previous function." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "nJVEwPcH6bQX", - "outputId": "289eb2f7-a5ce-4cec-f652-5c4e5b0b86cf" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "10000 loops, best of 5: 150 µs per loop\n" - ] - } - ], - "source": [ - "selu_jit = jax.jit(selu)\n", - "\n", - "# Warm up\n", - "selu_jit(x).block_until_ready()\n", - "\n", - "%timeit selu_jit(x).block_until_ready()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hMNKi1mYXQg5" - }, - "source": [ - "Here's what just happened:\n", - "\n", - "1) We defined `selu_jit` as the compiled version of `selu`.\n", - "\n", - "2) We called `selu_jit` once on `x`. This is where JAX does its tracing -- it needs to have some inputs to wrap in tracers, after all. The jaxpr is then compiled using XLA into very efficient code optimized for your GPU or TPU. Finally, the compiled code is executed to satisfy the call. Subsequent calls to `selu_jit` will use the compiled code directly, skipping the python implementation entirely.\n", - "\n", - "(If we didn't include the warm-up call separately, everything would still work, but then the compilation time would be included in the benchmark. It would still be faster, because we run many loops in the benchmark, but it wouldn't be a fair comparison.)\n", - "\n", - "3) We timed the execution speed of the compiled version. (Note the use of `block_until_ready()`, which is required due to JAX's [Asynchronous execution](https://jax.readthedocs.io/en/latest/async_dispatch.html) model)." 
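To make step 2 above concrete, here is a small sketch (reusing `jax.make_jaxpr` from the start of this notebook) that prints the jaxpr produced by tracing `selu`, i.e. the program that `jax.jit` hands to XLA for compilation.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# The jaxpr below is what jax.jit compiles with XLA on the first call.
print(jax.make_jaxpr(selu)(jnp.arange(5.0)))
```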
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DRJ6R6-d9Q_U" - }, - "source": [ - "## Why can't we just JIT everything?\n", - "\n", - "After going through the example above, you might be wondering whether we should simply apply `jax.jit` to every function. To understand why this is not the case, and when we should/shouldn't apply `jit`, let's first check some cases where JIT doesn't work." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "GO1Mwd_3_W6g", - "outputId": "a6fcf6d1-7bd6-4bb7-99c3-2a5a827183e2", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ConcretizationTypeError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mUnfilteredStackTrace\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mf_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mf_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Should raise an error.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mcache_miss\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 418\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mflat_fun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 419\u001b[0;31m donated_invars=donated_invars, inline=inline)\n\u001b[0m\u001b[1;32m 420\u001b[0m \u001b[0mout_pytree_def\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1631\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1632\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcall_bind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1633\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mcall_bind\u001b[0;34m(primitive, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1622\u001b[0m \u001b[0mtracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1623\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1624\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapply_todos\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_trace_todo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess\u001b[0;34m(self, trace, fun, tracers, params)\u001b[0m\n\u001b[1;32m 1634\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1635\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1636\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess_call\u001b[0;34m(self, primitive, f, tracers, params)\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 627\u001b[0;31m 
\u001b[0;32mreturn\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 628\u001b[0m \u001b[0mprocess_map\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_call_impl\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 687\u001b[0m compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,\n\u001b[0;32m--> 688\u001b[0;31m *unsafe_map(arg_spec, args))\n\u001b[0m\u001b[1;32m 689\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mmemoized_fun\u001b[0;34m(fun, *args)\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_callable_uncached\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 759\u001b[0m return lower_xla_callable(fun, device, backend, name, donated_invars,\n\u001b[0;32m--> 760\u001b[0;31m *arg_specs).compile().unsafe_call\n\u001b[0m\u001b[1;32m 761\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36mlower_xla_callable\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 771\u001b[0m jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(\n\u001b[0;32m--> 772\u001b[0;31m fun, abstract_args, pe.debug_info_final(fun, \"jit\"))\n\u001b[0m\u001b[1;32m 773\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTracer\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mc\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_final\u001b[0;34m(fun, in_avals, debug_info)\u001b[0m\n\u001b[1;32m 1541\u001b[0m \u001b[0;32mwith\u001b[0m 
\u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_sublevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1542\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_dynamic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1543\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_subjaxpr_dynamic\u001b[0;34m(fun, main, in_avals)\u001b[0m\n\u001b[1;32m 1519\u001b[0m \u001b[0min_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_arg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0min_tracers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0mout_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 167\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mf\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - 
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36m__bool__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 548\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__nonzero__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nonzero\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 549\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0m__bool__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_bool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 550\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__int__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_int\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36merror\u001b[0;34m(self, arg)\u001b[0m\n\u001b[1;32m 999\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1000\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mConcretizationTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfname_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1001\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mUnfilteredStackTrace\u001b[0m: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the bool function. 
\nWhile tracing the function f at :3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mConcretizationTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mf_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mf_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Should raise an error.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m\u001b[0m in \u001b[0;36mf\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mConcretizationTypeError\u001b[0m: Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the bool function. \nWhile tracing the function f at :3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError" - ] - } - ], - "source": [ - "# Condition on value of x.\n", - "\n", - "def f(x):\n", - " if x > 0:\n", - " return x\n", - " else:\n", - " return 2 * x\n", - "\n", - "f_jit = jax.jit(f)\n", - "f_jit(10) # Should raise an error. 
" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "LHlipkIMFUhi", - "outputId": "54935882-a180-45c0-ad03-9dfb5e3baa97", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ConcretizationTypeError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mUnfilteredStackTrace\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mg_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mg_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Should raise an error.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mcache_miss\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 418\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mflat_fun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 419\u001b[0;31m donated_invars=donated_invars, inline=inline)\n\u001b[0m\u001b[1;32m 420\u001b[0m \u001b[0mout_pytree_def\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1631\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1632\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcall_bind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1633\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mcall_bind\u001b[0;34m(primitive, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1622\u001b[0m \u001b[0mtracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1623\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1624\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapply_todos\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_trace_todo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess\u001b[0;34m(self, trace, fun, tracers, params)\u001b[0m\n\u001b[1;32m 1634\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1635\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1636\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess_call\u001b[0;34m(self, primitive, f, tracers, params)\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 627\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 628\u001b[0m \u001b[0mprocess_map\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_call_impl\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 687\u001b[0m compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,\n\u001b[0;32m--> 688\u001b[0;31m *unsafe_map(arg_spec, args))\n\u001b[0m\u001b[1;32m 689\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mmemoized_fun\u001b[0;34m(fun, *args)\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_callable_uncached\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 759\u001b[0m return lower_xla_callable(fun, device, backend, name, donated_invars,\n\u001b[0;32m--> 760\u001b[0;31m *arg_specs).compile().unsafe_call\n\u001b[0m\u001b[1;32m 761\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36mlower_xla_callable\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 771\u001b[0m jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(\n\u001b[0;32m--> 772\u001b[0;31m fun, abstract_args, pe.debug_info_final(fun, \"jit\"))\n\u001b[0m\u001b[1;32m 773\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTracer\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mc\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_final\u001b[0;34m(fun, in_avals, debug_info)\u001b[0m\n\u001b[1;32m 1541\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_sublevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1542\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_dynamic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1543\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_subjaxpr_dynamic\u001b[0;34m(fun, main, in_avals)\u001b[0m\n\u001b[1;32m 1519\u001b[0m \u001b[0min_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_arg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0min_tracers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0mout_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 167\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mg\u001b[0;34m(x, n)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36m__bool__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 548\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__nonzero__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nonzero\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 549\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0m__bool__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_bool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 550\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__int__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_int\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36merror\u001b[0;34m(self, arg)\u001b[0m\n\u001b[1;32m 999\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1000\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mConcretizationTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfname_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1001\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mUnfilteredStackTrace\u001b[0m: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the bool function. 
\nWhile tracing the function g at :3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'n'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mConcretizationTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mg_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mg_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Should raise an error.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m\u001b[0m in \u001b[0;36mg\u001b[0;34m(x, n)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mConcretizationTypeError\u001b[0m: Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the bool function. \nWhile tracing the function g at :3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'n'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError" - ] - } - ], - "source": [ - "# While loop conditioned on x and n.\n", - "\n", - "def g(x, n):\n", - " i = 0\n", - " while i < n:\n", - " i += 1\n", - " return x + i\n", - "\n", - "g_jit = jax.jit(g)\n", - "g_jit(10, 20) # Should raise an error. " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "isz2U_XX_wH2" - }, - "source": [ - "The problem is that we tried to condition on the *value* of an input to the function being jitted. The reason we can't do this is related to the fact mentioned above that jaxpr depends on the actual values used to trace it. \n", - "\n", - "The more specific information about the values we use in the trace, the more we can use standard Python control flow to express ourselves. However, being too specific means we can't reuse the same traced function for other values. 
JAX solves this by tracing at different levels of abstraction for different purposes.\n", - "\n", - "For `jax.jit`, the default level is `ShapedArray` -- that is, each tracer has a concrete shape (which we're allowed to condition on), but no concrete value. This allows the compiled function to work on all possible inputs with the same shape -- the standard use case in machine learning. However, because the tracers have no concrete value, if we attempt to condition on one, we get the error above.\n", - "\n", - "In `jax.grad`, the constraints are more relaxed, so you can do more. If you compose several transformations, however, you must satisfy the constraints of the most strict one. So, if you `jit(grad(f))`, `f` mustn't condition on value. For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).\n", - "\n", - "One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special [control flow operators](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) like `jax.lax.cond`. However, sometimes that is impossible. In that case, you can consider jitting only part of the function. For example, if the most computationally expensive part of the function is inside the loop, we can JIT just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot):" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "OeR8hF-NHAML", - "outputId": "d47fd6b2-8bbd-4939-a794-0b80183d3179" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(30, dtype=int32, weak_type=True)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# While loop conditioned on x and n with a jitted body.\n", - "\n", - "@jax.jit\n", - "def loop_body(prev_i):\n", - " return prev_i + 1\n", - "\n", - "def g_inner_jitted(x, n):\n", - " i = 0\n", - " while i < n:\n", - " i = loop_body(i)\n", - " return x + i\n", - "\n", - "g_inner_jitted(10, 20)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5XUT2acoHBz-" - }, - "source": [ - "If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to get limited different values." 
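Coming back to the control-flow-operator option mentioned a few paragraphs up: here is a minimal sketch (my own rewrite of `f` and `g`, not a cell from the original notebook) of how the same logic can stay jittable by using `jax.lax.cond` and `jax.lax.while_loop`; the `static_argnums`/`static_argnames` examples continue below.

```python
import jax

# Structured control flow: both branches and the loop body are traced into the
# jaxpr, so no concrete Python bool is ever required at trace time.
@jax.jit
def f_cond(x):
    return jax.lax.cond(x > 0, lambda v: v, lambda v: 2 * v, x)

@jax.jit
def g_while(x, n):
    i = jax.lax.while_loop(lambda i: i < n, lambda i: i + 1, 0)
    return x + i

print(f_cond(10))       # 10
print(g_while(10, 20))  # 30
```

Because the condition is now part of the compiled program, it is evaluated on-device at run time rather than in Python during tracing.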
- ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "2yQmQTDNAenY", - "outputId": "c48f07b8-c3f9-4d2a-9dfd-663838a52511" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "10\n" - ] - } - ], - "source": [ - "f_jit_correct = jax.jit(f, static_argnums=0)\n", - "print(f_jit_correct(10))" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "R4SXUEu-M-u1", - "outputId": "9e712e14-4e81-4744-dcf2-a10f470d9121" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "30\n" - ] - } - ], - "source": [ - "g_jit_correct = jax.jit(g, static_argnames=['n'])\n", - "print(g_jit_correct(10, 20))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To specify such arguments when using `jit` as a decorator, a common pattern is to use python's `functools.partial`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2X5rR4jkIO", - "outputId": "81-4744-dc2e4-4e10f470f2-a19e71d9121" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "30\n" - ] - } - ], - "source": [ - "from functools import partial\n", - "\n", - "@partial(jax.jit, static_argnames=['n'])\n", - "def g_jit_decorated(x, n):\n", - " i = 0\n", - " while i < n:\n", - " i += 1\n", - " return x + i\n", - "\n", - "print(g_jit_decorated(10, 20))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LczjIBt2X2Ms" - }, - "source": [ - "## When to use JIT\n", - "\n", - "In many of the examples above, jitting is not worth it:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "uMOqsNnqYApD", - "outputId": "2d6c5122-43ad-4257-e56b-e77c889131c2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "g jitted:\n", - "The slowest run took 13.54 times longer than the fastest. This could mean that an intermediate result is being cached.\n", - "1000 loops, best of 5: 229 µs per loop\n", - "g:\n", - "The slowest run took 11.72 times longer than the fastest. This could mean that an intermediate result is being cached.\n", - "1000000 loops, best of 5: 1.2 µs per loop\n" - ] - } - ], - "source": [ - "print(\"g jitted:\")\n", - "%timeit g_jit_correct(10, 20).block_until_ready()\n", - "\n", - "print(\"g:\")\n", - "%timeit g(10, 20)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cZmGYq80YP0j" - }, - "source": [ - "This is because `jax.jit` introduces some overhead itself. Therefore, it usually only saves time if the compiled function is complex and you will run it numerous times. Fortunately, this is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations.\n", - "\n", - "Generally, you want to jit the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hJMjUlRcIzVS" - }, - "source": [ - "## Caching\n", - "\n", - "It's important to understand the caching behaviour of `jax.jit`.\n", - "\n", - "Suppose I define `f = jax.jit(g)`. When I first invoke `f`, it will get compiled, and the resulting XLA code will get cached. Subsequent calls of `f` will reuse the cached code. 
This is how `jax.jit` makes up for the up-front cost of compilation.\n", - "\n", - "If I specify `static_argnums`, then the cached code will be used only for the same values of arguments labelled as static. If any of them change, recompilation occurs. If there are many values, then your program might spend more time compiling than it would have executing ops one-by-one.\n", - "\n", - "Avoid calling `jax.jit` inside loops. For most cases, JAX will be able to use the compiled, cached function in subsequent calls to `jax.jit`. However, because the cache relies on the hash of the function, it becomes problematic when equivalent functions are redefined. This will cause unnecessary compilation each time in the loop:" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "6MDSXCfmSZVZ", - "outputId": "a035d0b7-6a4d-4a9e-c6b4-7521970829fc" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "jit called in a loop with partials:\n", - "1 loop, best of 5: 192 ms per loop\n", - "jit called in a loop with lambdas:\n", - "10 loops, best of 5: 199 ms per loop\n", - "jit called in a loop with caching:\n", - "10 loops, best of 5: 21.6 ms per loop\n" - ] - } - ], - "source": [ - "from functools import partial\n", - "\n", - "def unjitted_loop_body(prev_i):\n", - " return prev_i + 1\n", - "\n", - "def g_inner_jitted_partial(x, n):\n", - " i = 0\n", - " while i < n:\n", - " # Don't do this! each time the partial returns\n", - " # a function with different hash\n", - " i = jax.jit(partial(unjitted_loop_body))(i)\n", - " return x + i\n", - "\n", - "def g_inner_jitted_lambda(x, n):\n", - " i = 0\n", - " while i < n:\n", - " # Don't do this!, lambda will also return\n", - " # a function with a different hash\n", - " i = jax.jit(lambda x: unjitted_loop_body(x))(i)\n", - " return x + i\n", - "\n", - "def g_inner_jitted_normal(x, n):\n", - " i = 0\n", - " while i < n:\n", - " # this is OK, since JAX can find the\n", - " # cached, compiled function\n", - " i = jax.jit(unjitted_loop_body)(i)\n", - " return x + i\n", - "\n", - "print(\"jit called in a loop with partials:\")\n", - "%timeit g_inner_jitted_partial(10, 20).block_until_ready()\n", - "\n", - "print(\"jit called in a loop with lambdas:\")\n", - "%timeit g_inner_jitted_lambda(10, 20).block_until_ready()\n", - "\n", - "print(\"jit called in a loop with caching:\")\n", - "%timeit g_inner_jitted_normal(10, 20).block_until_ready()" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "Jitting functions in JAX", - "provenance": [] - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/jax-101/02-jitting.md b/docs/jax-101/02-jitting.md deleted file mode 100644 index f1cd242b9e3d..000000000000 --- a/docs/jax-101/02-jitting.md +++ /dev/null @@ -1,338 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.0 -kernelspec: - display_name: Python 3 - name: python3 ---- - -+++ {"id": "O-SkdlPxvETZ"} - -# Just In Time Compilation with JAX - -[![Open in 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/02-jitting.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/02-jitting.ipynb) - -*Authors: Rosalia Schneider & Vladimir Mikulik* - -In this section, we will further explore how JAX works, and how we can make it performant. -We will discuss the `jax.jit()` transform, which will perform *Just In Time* (JIT) compilation -of a JAX Python function so it can be executed efficiently in XLA. - -## How JAX transforms work - -In the previous section, we discussed that JAX allows us to transform Python functions. This is done by first converting the Python function into a simple intermediate language called jaxpr. The transformations then work on the jaxpr representation. - -We can show a representation of the jaxpr of a function by using `jax.make_jaxpr`: - -```{code-cell} ipython3 -:id: P9Xj77Wx3Z2P -:outputId: 5a0597eb-86c9-4762-ce10-2811debbc732 - -import jax -import jax.numpy as jnp - -global_list = [] - -def log2(x): - global_list.append(x) - ln_x = jnp.log(x) - ln_2 = jnp.log(2.0) - return ln_x / ln_2 - -print(jax.make_jaxpr(log2)(3.0)) -``` - -+++ {"id": "jiDsT7y0RwIp"} - -The [Understanding Jaxprs](https://jax.readthedocs.io/en/latest/jaxpr.html) section of the documentation provides more information on the meaning of the above output. - -Importantly, note how the jaxpr does not capture the side-effect of the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX is designed to understand side-effect-free (a.k.a. functionally pure) code. If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). - -Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour once converted to jaxpr. However, as a rule of thumb, you can expect (but shouldn't rely on) the side-effects of a JAX-transformed function to run once (during the first call), and never again. This is because of the way that JAX generates jaxpr, using a process called 'tracing'. - -When tracing, JAX wraps each argument by a *tracer* object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record the Python side-effects, they do not appear in the jaxpr. However, the side-effects still happen during the trace itself. - -Note: the Python `print()` function is not pure: the text output is a side-effect of the function. Therefore, any `print()` calls will only happen during tracing, and will not appear in the jaxpr: - -```{code-cell} ipython3 -:id: JxV2p7e2RawC -:outputId: 9dfe8a56-e553-4640-a04e-5405aea7832d - -def log2_with_print(x): - print("printed x:", x) - ln_x = jnp.log(x) - ln_2 = jnp.log(2.0) - return ln_x / ln_2 - -print(jax.make_jaxpr(log2_with_print)(3.)) -``` - -+++ {"id": "f6W_YYwRRwGp"} - -See how the printed `x` is a `Traced` object? That's the JAX internals at work. 
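A quick way to see the run-once rule of thumb in action, assuming `log2_with_print` from the cell above is still in scope (this check is an illustrative addition rather than part of the original tutorial):

```{code-cell} ipython3
log2_with_print_jit = jax.jit(log2_with_print)

print(log2_with_print_jit(3.0))  # "printed x: Traced<...>" appears once, during tracing
print(log2_with_print_jit(4.0))  # no print: the cached, compiled code is reused
```

The second call never re-enters the Python function at all, which is exactly why the side effect does not fire again.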
- -The fact that the Python code runs at least once is strictly an implementation detail, and so shouldn't be relied upon. However, it's useful to understand as you can use it when debugging to print out intermediate values of a computation. - -+++ {"id": "PgVqi6NlRdWZ"} - -A key thing to understand is that jaxpr captures the function as executed on the parameters given to it. For example, if we have a conditional, jaxpr will only know about the branch we take: - -```{code-cell} ipython3 -:id: hn0CuphEZKZm -:outputId: 99dae727-d2be-4577-831c-e1e14af5890a - -def log2_if_rank_2(x): - if x.ndim == 2: - ln_x = jnp.log(x) - ln_2 = jnp.log(2.0) - return ln_x / ln_2 - else: - return x - -print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3]))) -``` - -+++ {"id": "Qp3WhqaqvHyD"} - -## JIT compiling a function - -As explained before, JAX enables operations to execute on CPU/GPU/TPU using the same code. -Let's look at an example of computing a *Scaled Exponential Linear Unit* -([SELU](https://proceedings.neurips.cc/paper/6698-self-normalizing-neural-networks.pdf)), an -operation commonly used in deep learning: - -```{code-cell} ipython3 -:id: JAXFYtlRvD6p -:outputId: e94d7dc2-a9a1-4ac2-fd3f-152e3f6d141b - -import jax -import jax.numpy as jnp - -def selu(x, alpha=1.67, lambda_=1.05): - return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) - -x = jnp.arange(1000000) -%timeit selu(x).block_until_ready() -``` - -+++ {"id": "ecN5lEXe6ncy"} - -The code above is sending one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize our functions. - -Naturally, what we want to do is give the XLA compiler as much code as possible, so it can fully optimize it. For this purpose, JAX provides the `jax.jit` transformation, which will JIT compile a JAX-compatible function. The example below shows how to use JIT to speed up the previous function. - -```{code-cell} ipython3 -:id: nJVEwPcH6bQX -:outputId: 289eb2f7-a5ce-4cec-f652-5c4e5b0b86cf - -selu_jit = jax.jit(selu) - -# Warm up -selu_jit(x).block_until_ready() - -%timeit selu_jit(x).block_until_ready() -``` - -+++ {"id": "hMNKi1mYXQg5"} - -Here's what just happened: - -1) We defined `selu_jit` as the compiled version of `selu`. - -2) We called `selu_jit` once on `x`. This is where JAX does its tracing -- it needs to have some inputs to wrap in tracers, after all. The jaxpr is then compiled using XLA into very efficient code optimized for your GPU or TPU. Finally, the compiled code is executed to satisfy the call. Subsequent calls to `selu_jit` will use the compiled code directly, skipping the python implementation entirely. - -(If we didn't include the warm-up call separately, everything would still work, but then the compilation time would be included in the benchmark. It would still be faster, because we run many loops in the benchmark, but it wouldn't be a fair comparison.) - -3) We timed the execution speed of the compiled version. (Note the use of `block_until_ready()`, which is required due to JAX's [Asynchronous execution](https://jax.readthedocs.io/en/latest/async_dispatch.html) model). - -+++ {"id": "DRJ6R6-d9Q_U"} - -## Why can't we just JIT everything? - -After going through the example above, you might be wondering whether we should simply apply `jax.jit` to every function. To understand why this is not the case, and when we should/shouldn't apply `jit`, let's first check some cases where JIT doesn't work. 
- -```{code-cell} ipython3 -:id: GO1Mwd_3_W6g -:outputId: a6fcf6d1-7bd6-4bb7-99c3-2a5a827183e2 -:tags: [raises-exception] - -# Condition on value of x. - -def f(x): - if x > 0: - return x - else: - return 2 * x - -f_jit = jax.jit(f) -f_jit(10) # Should raise an error. -``` - -```{code-cell} ipython3 -:id: LHlipkIMFUhi -:outputId: 54935882-a180-45c0-ad03-9dfb5e3baa97 -:tags: [raises-exception] - -# While loop conditioned on x and n. - -def g(x, n): - i = 0 - while i < n: - i += 1 - return x + i - -g_jit = jax.jit(g) -g_jit(10, 20) # Should raise an error. -``` - -+++ {"id": "isz2U_XX_wH2"} - -The problem is that we tried to condition on the *value* of an input to the function being jitted. The reason we can't do this is related to the fact mentioned above that jaxpr depends on the actual values used to trace it. - -The more specific information about the values we use in the trace, the more we can use standard Python control flow to express ourselves. However, being too specific means we can't reuse the same traced function for other values. JAX solves this by tracing at different levels of abstraction for different purposes. - -For `jax.jit`, the default level is `ShapedArray` -- that is, each tracer has a concrete shape (which we're allowed to condition on), but no concrete value. This allows the compiled function to work on all possible inputs with the same shape -- the standard use case in machine learning. However, because the tracers have no concrete value, if we attempt to condition on one, we get the error above. - -In `jax.grad`, the constraints are more relaxed, so you can do more. If you compose several transformations, however, you must satisfy the constraints of the most strict one. So, if you `jit(grad(f))`, `f` mustn't condition on value. For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow). - -One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special [control flow operators](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) like `jax.lax.cond`. However, sometimes that is impossible. In that case, you can consider jitting only part of the function. For example, if the most computationally expensive part of the function is inside the loop, we can JIT just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot): - -```{code-cell} ipython3 -:id: OeR8hF-NHAML -:outputId: d47fd6b2-8bbd-4939-a794-0b80183d3179 - -# While loop conditioned on x and n with a jitted body. - -@jax.jit -def loop_body(prev_i): - return prev_i + 1 - -def g_inner_jitted(x, n): - i = 0 - while i < n: - i = loop_body(i) - return x + i - -g_inner_jitted(10, 20) -``` - -+++ {"id": "5XUT2acoHBz-"} - -If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to get limited different values. 
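To make the retracing cost concrete before the basic usage examples below, here is a small sketch; `add_n` and the `trace_count` counter are hypothetical additions for illustration, not part of the original notebook:

```{code-cell} ipython3
import jax

trace_count = 0

def add_n(x, n):
    global trace_count
    trace_count += 1   # a Python side effect, so it only runs while tracing
    return x + n

add_n_jit = jax.jit(add_n, static_argnames=['n'])

add_n_jit(1.0, 5)
add_n_jit(2.0, 5)    # same static value of n: the cached compilation is reused
add_n_jit(1.0, 7)    # new static value of n: traced and compiled again
print(trace_count)   # 2
```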
- -```{code-cell} ipython3 -:id: 2yQmQTDNAenY -:outputId: c48f07b8-c3f9-4d2a-9dfd-663838a52511 - -f_jit_correct = jax.jit(f, static_argnums=0) -print(f_jit_correct(10)) -``` - -```{code-cell} ipython3 -:id: R4SXUEu-M-u1 -:outputId: 9e712e14-4e81-4744-dcf2-a10f470d9121 - -g_jit_correct = jax.jit(g, static_argnames=['n']) -print(g_jit_correct(10, 20)) -``` - -To specify such arguments when using `jit` as a decorator, a common pattern is to use python's `functools.partial`: - -```{code-cell} ipython3 -:id: 2X5rR4jkIO -:outputId: 81-4744-dc2e4-4e10f470f2-a19e71d9121 - -from functools import partial - -@partial(jax.jit, static_argnames=['n']) -def g_jit_decorated(x, n): - i = 0 - while i < n: - i += 1 - return x + i - -print(g_jit_decorated(10, 20)) -``` - -+++ {"id": "LczjIBt2X2Ms"} - -## When to use JIT - -In many of the examples above, jitting is not worth it: - -```{code-cell} ipython3 -:id: uMOqsNnqYApD -:outputId: 2d6c5122-43ad-4257-e56b-e77c889131c2 - -print("g jitted:") -%timeit g_jit_correct(10, 20).block_until_ready() - -print("g:") -%timeit g(10, 20) -``` - -+++ {"id": "cZmGYq80YP0j"} - -This is because `jax.jit` introduces some overhead itself. Therefore, it usually only saves time if the compiled function is complex and you will run it numerous times. Fortunately, this is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations. - -Generally, you want to jit the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise. - -+++ {"id": "hJMjUlRcIzVS"} - -## Caching - -It's important to understand the caching behaviour of `jax.jit`. - -Suppose I define `f = jax.jit(g)`. When I first invoke `f`, it will get compiled, and the resulting XLA code will get cached. Subsequent calls of `f` will reuse the cached code. This is how `jax.jit` makes up for the up-front cost of compilation. - -If I specify `static_argnums`, then the cached code will be used only for the same values of arguments labelled as static. If any of them change, recompilation occurs. If there are many values, then your program might spend more time compiling than it would have executing ops one-by-one. - -Avoid calling `jax.jit` inside loops. For most cases, JAX will be able to use the compiled, cached function in subsequent calls to `jax.jit`. However, because the cache relies on the hash of the function, it becomes problematic when equivalent functions are redefined. This will cause unnecessary compilation each time in the loop: - -```{code-cell} ipython3 -:id: 6MDSXCfmSZVZ -:outputId: a035d0b7-6a4d-4a9e-c6b4-7521970829fc - -from functools import partial - -def unjitted_loop_body(prev_i): - return prev_i + 1 - -def g_inner_jitted_partial(x, n): - i = 0 - while i < n: - # Don't do this! 
each time the partial returns - # a function with different hash - i = jax.jit(partial(unjitted_loop_body))(i) - return x + i - -def g_inner_jitted_lambda(x, n): - i = 0 - while i < n: - # Don't do this!, lambda will also return - # a function with a different hash - i = jax.jit(lambda x: unjitted_loop_body(x))(i) - return x + i - -def g_inner_jitted_normal(x, n): - i = 0 - while i < n: - # this is OK, since JAX can find the - # cached, compiled function - i = jax.jit(unjitted_loop_body)(i) - return x + i - -print("jit called in a loop with partials:") -%timeit g_inner_jitted_partial(10, 20).block_until_ready() - -print("jit called in a loop with lambdas:") -%timeit g_inner_jitted_lambda(10, 20).block_until_ready() - -print("jit called in a loop with caching:") -%timeit g_inner_jitted_normal(10, 20).block_until_ready() -``` diff --git a/docs/jax-101/03-vectorization.ipynb b/docs/jax-101/03-vectorization.ipynb deleted file mode 100644 index cbcf120d4812..000000000000 --- a/docs/jax-101/03-vectorization.ipynb +++ /dev/null @@ -1,369 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "zMIrmiaZxiJC" - }, - "source": [ - "# Automatic Vectorization in JAX\n", - "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/03-vectorization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/03-vectorization.ipynb)\n", - "\n", - "*Authors: Matteo Hessel*\n", - "\n", - "In the previous section we discussed JIT compilation via the `jax.jit` function. This notebook discusses another of JAX's transforms: vectorization via `jax.vmap`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Kw-_imBrx4nN" - }, - "source": [ - "## Manual Vectorization\n", - "\n", - "Consider the following simple code that computes the convolution of two one-dimensional vectors:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "5Obro91lwE_s", - "outputId": "061983c6-2faa-4a54-83a5-d2a823f61087" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([11., 20., 29.], dtype=float32)" - ] - }, - "execution_count": 1, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "x = jnp.arange(5)\n", - "w = jnp.array([2., 3., 4.])\n", - "\n", - "def convolve(x, w):\n", - " output = []\n", - " for i in range(1, len(x)-1):\n", - " output.append(jnp.dot(x[i-1:i+2], w))\n", - " return jnp.array(output)\n", - "\n", - "convolve(x, w)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "z_nPhEhLRysk" - }, - "source": [ - "Suppose we would like to apply this function to a batch of weights `w` to a batch of vectors `x`." 
- ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "rHQJnnrVUbxE" - }, - "outputs": [], - "source": [ - "xs = jnp.stack([x, x])\n", - "ws = jnp.stack([w, w])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ghaJQW1aUfPi" - }, - "source": [ - "The most naive option would be to simply loop over the batch in Python:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "yM-IycdlzGyJ", - "outputId": "07ed6ffc-0265-45ef-d585-4b5fa7d221f1" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[11., 20., 29.],\n", - " [11., 20., 29.]], dtype=float32)" - ] - }, - "execution_count": 10, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "def manually_batched_convolve(xs, ws):\n", - " output = []\n", - " for i in range(xs.shape[0]):\n", - " output.append(convolve(xs[i], ws[i]))\n", - " return jnp.stack(output)\n", - "\n", - "manually_batched_convolve(xs, ws)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VTh0l_1SUlh4" - }, - "source": [ - "This produces the correct result, however it is not very efficient.\n", - "\n", - "In order to batch the computation efficiently, you would normally have to rewrite the function manually to ensure it is done in vectorized form. This is not particularly difficult to implement, but does involve changing how the function treats indices, axes, and other parts of the input.\n", - "\n", - "For example, we could manually rewrite `convolve()` to support vectorized computation across the batch dimension as follows:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "I4Wd9nrcTRRL", - "outputId": "0b037b43-7b41-4625-f9e0-a6e0dbc4c65a" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[11., 20., 29.],\n", - " [11., 20., 29.]], dtype=float32)" - ] - }, - "execution_count": 5, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "def manually_vectorized_convolve(xs, ws):\n", - " output = []\n", - " for i in range(1, xs.shape[-1] -1):\n", - " output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))\n", - " return jnp.stack(output, axis=1)\n", - "\n", - "manually_vectorized_convolve(xs, ws)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DW-RJ2Zs2QVu" - }, - "source": [ - "Such re-implementation is messy and error-prone; fortunately JAX provides another way." 
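Before moving on, a quick sanity check (an illustrative addition that assumes the definitions from the cells above are in scope): the loop-based and manually vectorized versions should agree on the same batched inputs.

```python
batched = manually_batched_convolve(xs, ws)
vectorized = manually_vectorized_convolve(xs, ws)
print(jnp.allclose(batched, vectorized))  # True
```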
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2oVLanQmUAo_" - }, - "source": [ - "## Automatic Vectorization\n", - "\n", - "In JAX, the `jax.vmap` transformation is designed to generate such a vectorized implementation of a function automatically:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "Brl-BoTqSQDw", - "outputId": "af608dbb-27f2-4fbc-f225-79f3101b13ff" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[11., 20., 29.],\n", - " [11., 20., 29.]], dtype=float32)" - ] - }, - "execution_count": 6, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "auto_batch_convolve = jax.vmap(convolve)\n", - "\n", - "auto_batch_convolve(xs, ws)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7aVAy7332lFj" - }, - "source": [ - "It does this by tracing the function similarly to `jax.jit`, and automatically adding batch axes at the beginning of each input.\n", - "\n", - "If the batch dimension is not the first, you may use the `in_axes` and `out_axes` arguments to specify the location of the batch dimension in inputs and outputs. These may be an integer if the batch axis is the same for all inputs and outputs, or lists, otherwise." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "_VEEm1CGT2n0", - "outputId": "751e0fbf-bdfb-41df-9436-4da5de23123f" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[11., 11.],\n", - " [20., 20.],\n", - " [29., 29.]], dtype=float32)" - ] - }, - "execution_count": 7, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)\n", - "\n", - "xst = jnp.transpose(xs)\n", - "wst = jnp.transpose(ws)\n", - "\n", - "auto_batch_convolve_v2(xst, wst)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-gNiLuxzSX32" - }, - "source": [ - "`jax.vmap` also supports the case where only one of the arguments is batched: for example, if you would like to convolve to a single set of weights `w` with a batch of vectors `x`; in this case the `in_axes` argument can be set to `None`:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "2s2YDsamSxki", - "outputId": "5c70879b-5cce-4549-e38a-f45dbe663ab2" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[11., 20., 29.],\n", - " [11., 20., 29.]], dtype=float32)" - ] - }, - "execution_count": 8, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])\n", - "\n", - "batch_convolve_v3(xs, w)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bsxT4hA6RTCG" - }, - "source": [ - "## Combining transformations\n", - "\n", - "As with all JAX transformations, `jax.jit` and `jax.vmap` are designed to be composable, which means you can wrap a vmapped function with `jit`, or a JITted function with `vmap`, and everything will work correctly:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "gsC-Myg0RVdj", - "outputId": "cbdd384e-6633-4cea-b1a0-a01ad934a768" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[11., 20., 29.],\n", - " [11., 20., 29.]], dtype=float32)" - ] - }, - "execution_count": 9, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "jitted_batch_convolve = jax.jit(auto_batch_convolve)\n", - "\n", - 
"jitted_batch_convolve(xs, ws)" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "Vectorization in JAX", - "provenance": [] - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/jax-101/03-vectorization.md b/docs/jax-101/03-vectorization.md deleted file mode 100644 index cb46b87a97bf..000000000000 --- a/docs/jax-101/03-vectorization.md +++ /dev/null @@ -1,161 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.0 -kernelspec: - display_name: Python 3 - name: python3 ---- - -+++ {"id": "zMIrmiaZxiJC"} - -# Automatic Vectorization in JAX - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/03-vectorization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/03-vectorization.ipynb) - -*Authors: Matteo Hessel* - -In the previous section we discussed JIT compilation via the `jax.jit` function. This notebook discusses another of JAX's transforms: vectorization via `jax.vmap`. - -+++ {"id": "Kw-_imBrx4nN"} - -## Manual Vectorization - -Consider the following simple code that computes the convolution of two one-dimensional vectors: - -```{code-cell} ipython3 -:id: 5Obro91lwE_s -:outputId: 061983c6-2faa-4a54-83a5-d2a823f61087 - -import jax -import jax.numpy as jnp - -x = jnp.arange(5) -w = jnp.array([2., 3., 4.]) - -def convolve(x, w): - output = [] - for i in range(1, len(x)-1): - output.append(jnp.dot(x[i-1:i+2], w)) - return jnp.array(output) - -convolve(x, w) -``` - -+++ {"id": "z_nPhEhLRysk"} - -Suppose we would like to apply this function to a batch of weights `w` to a batch of vectors `x`. - -```{code-cell} ipython3 -:id: rHQJnnrVUbxE - -xs = jnp.stack([x, x]) -ws = jnp.stack([w, w]) -``` - -+++ {"id": "ghaJQW1aUfPi"} - -The most naive option would be to simply loop over the batch in Python: - -```{code-cell} ipython3 -:id: yM-IycdlzGyJ -:outputId: 07ed6ffc-0265-45ef-d585-4b5fa7d221f1 - -def manually_batched_convolve(xs, ws): - output = [] - for i in range(xs.shape[0]): - output.append(convolve(xs[i], ws[i])) - return jnp.stack(output) - -manually_batched_convolve(xs, ws) -``` - -+++ {"id": "VTh0l_1SUlh4"} - -This produces the correct result, however it is not very efficient. - -In order to batch the computation efficiently, you would normally have to rewrite the function manually to ensure it is done in vectorized form. This is not particularly difficult to implement, but does involve changing how the function treats indices, axes, and other parts of the input. 
- -For example, we could manually rewrite `convolve()` to support vectorized computation across the batch dimension as follows: - -```{code-cell} ipython3 -:id: I4Wd9nrcTRRL -:outputId: 0b037b43-7b41-4625-f9e0-a6e0dbc4c65a - -def manually_vectorized_convolve(xs, ws): - output = [] - for i in range(1, xs.shape[-1] -1): - output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1)) - return jnp.stack(output, axis=1) - -manually_vectorized_convolve(xs, ws) -``` - -+++ {"id": "DW-RJ2Zs2QVu"} - -Such re-implementation is messy and error-prone; fortunately JAX provides another way. - -+++ {"id": "2oVLanQmUAo_"} - -## Automatic Vectorization - -In JAX, the `jax.vmap` transformation is designed to generate such a vectorized implementation of a function automatically: - -```{code-cell} ipython3 -:id: Brl-BoTqSQDw -:outputId: af608dbb-27f2-4fbc-f225-79f3101b13ff - -auto_batch_convolve = jax.vmap(convolve) - -auto_batch_convolve(xs, ws) -``` - -+++ {"id": "7aVAy7332lFj"} - -It does this by tracing the function similarly to `jax.jit`, and automatically adding batch axes at the beginning of each input. - -If the batch dimension is not the first, you may use the `in_axes` and `out_axes` arguments to specify the location of the batch dimension in inputs and outputs. These may be an integer if the batch axis is the same for all inputs and outputs, or lists, otherwise. - -```{code-cell} ipython3 -:id: _VEEm1CGT2n0 -:outputId: 751e0fbf-bdfb-41df-9436-4da5de23123f - -auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1) - -xst = jnp.transpose(xs) -wst = jnp.transpose(ws) - -auto_batch_convolve_v2(xst, wst) -``` - -+++ {"id": "-gNiLuxzSX32"} - -`jax.vmap` also supports the case where only one of the arguments is batched: for example, if you would like to convolve to a single set of weights `w` with a batch of vectors `x`; in this case the `in_axes` argument can be set to `None`: - -```{code-cell} ipython3 -:id: 2s2YDsamSxki -:outputId: 5c70879b-5cce-4549-e38a-f45dbe663ab2 - -batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None]) - -batch_convolve_v3(xs, w) -``` - -+++ {"id": "bsxT4hA6RTCG"} - -## Combining transformations - -As with all JAX transformations, `jax.jit` and `jax.vmap` are designed to be composable, which means you can wrap a vmapped function with `jit`, or a JITted function with `vmap`, and everything will work correctly: - -```{code-cell} ipython3 -:id: gsC-Myg0RVdj -:outputId: cbdd384e-6633-4cea-b1a0-a01ad934a768 - -jitted_batch_convolve = jax.jit(auto_batch_convolve) - -jitted_batch_convolve(xs, ws) -``` diff --git a/docs/jax-101/04-advanced-autodiff.ipynb b/docs/jax-101/04-advanced-autodiff.ipynb deleted file mode 100644 index de0627549e0b..000000000000 --- a/docs/jax-101/04-advanced-autodiff.ipynb +++ /dev/null @@ -1,738 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "kORMl5KmfByI" - }, - "source": [ - "# Advanced Automatic Differentiation in JAX\n", - "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/04-advanced-autodiff.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/04-advanced-autodiff.ipynb)\n", - "\n", - "*Authors: Vlatimir Mikulik & Matteo Hessel*\n", - "\n", - "Computing gradients is a critical part of modern machine learning methods. 
This section considers a few advanced topics in the areas of automatic differentiation as it relates to modern machine learning.\n", - "\n", - "While understanding how automatic differentiation works under the hood isn't crucial for using JAX in most contexts, we encourage the reader to check out this quite accessible [video](https://www.youtube.com/watch?v=wG_nF1awSSY) to get a deeper sense of what's going on.\n", - "\n", - "[The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) is a more advanced and more detailed explanation of how these ideas are implemented in the JAX backend. It's not necessary to understand this to do most things in JAX. However, some features (like defining [custom derivatives](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)) depend on understanding this, so it's worth knowing this explanation exists if you ever need to use them." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qx50CO1IorCc" - }, - "source": [ - "## Higher-order derivatives\n", - "\n", - "JAX's autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations.\n", - "\n", - "We illustrate this in the single-variable case:\n", - "\n", - "The derivative of $f(x) = x^3 + 2x^2 - 3x + 1$ can be computed as:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "Kqsbj98UTVdi" - }, - "outputs": [], - "source": [ - "import jax\n", - "\n", - "f = lambda x: x**3 + 2*x**2 - 3*x + 1\n", - "\n", - "dfdx = jax.grad(f)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ItEt15OGiiAF" - }, - "source": [ - "The higher-order derivatives of $f$ are:\n", - "\n", - "$$\n", - "\\begin{array}{l}\n", - "f'(x) = 3x^2 + 4x -3\\\\\n", - "f''(x) = 6x + 4\\\\\n", - "f'''(x) = 6\\\\\n", - "f^{iv}(x) = 0\n", - "\\end{array}\n", - "$$\n", - "\n", - "Computing any of these in JAX is as easy as chaining the `grad` function:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "5X3yQqLgimqH" - }, - "outputs": [], - "source": [ - "d2fdx = jax.grad(dfdx)\n", - "d3fdx = jax.grad(d2fdx)\n", - "d4fdx = jax.grad(d3fdx)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fVL2P_pcj8T1" - }, - "source": [ - "Evaluating the above in $x=1$ would give us:\n", - "\n", - "$$\n", - "\\begin{array}{l}\n", - "f'(1) = 4\\\\\n", - "f''(1) = 10\\\\\n", - "f'''(1) = 6\\\\\n", - "f^{iv}(1) = 0\n", - "\\end{array}\n", - "$$\n", - "\n", - "Using JAX:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "tJkIp9wFjxL3", - "outputId": "581ecf87-2d20-4c83-9443-5befc1baf51d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "4.0\n", - "10.0\n", - "6.0\n", - "0.0\n" - ] - } - ], - "source": [ - "print(dfdx(1.))\n", - "print(d2fdx(1.))\n", - "print(d3fdx(1.))\n", - "print(d4fdx(1.))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3-fTelU7LHRr" - }, - "source": [ - "In the multivariable case, higher-order derivatives are more complicated. 
The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to\n", - "\n", - "$$(\\mathbf{H}f)_{i,j} = \\frac{\\partial^2 f}{\\partial_i\\partial_j}.$$\n", - "\n", - "The Hessian of a real-valued function of several variables, $f: \\mathbb R^n\\to\\mathbb R$, can be identified with the Jacobian of its gradient. JAX provides two transformations for computing the Jacobian of a function, `jax.jacfwd` and `jax.jacrev`, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances – see the [video about autodiff](https://www.youtube.com/watch?v=wG_nF1awSSY) linked above for an explanation." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "ILhkef1rOB6_" - }, - "outputs": [], - "source": [ - "def hessian(f):\n", - " return jax.jacfwd(jax.grad(f))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xaENwADXOGf_" - }, - "source": [ - "Let's double check this is correct on the dot-product $f: \\mathbf{x} \\mapsto \\mathbf{x} ^\\top \\mathbf{x}$.\n", - "\n", - "if $i=j$, $\\frac{\\partial^2 f}{\\partial_i\\partial_j}(\\mathbf{x}) = 2$. Otherwise, $\\frac{\\partial^2 f}{\\partial_i\\partial_j}(\\mathbf{x}) = 0$." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "Xm3A0QdWRdJl", - "outputId": "e1e8cba9-b567-439b-b8fc-34b21497e67f" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[2., 0., 0.],\n", - " [0., 2., 0.],\n", - " [0., 0., 2.]], dtype=float32)" - ] - }, - "execution_count": 6, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "import jax.numpy as jnp\n", - "\n", - "def f(x):\n", - " return jnp.dot(x, x)\n", - "\n", - "hessian(f)(jnp.array([1., 2., 3.]))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7_gbi34WSUsD" - }, - "source": [ - "Often, however, we aren't interested in computing the full Hessian itself, and doing so can be very inefficient. [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) explains some tricks, like the Hessian-vector product, that allow to use it without materialising the whole matrix.\n", - "\n", - "If you plan to work with higher-order derivatives in JAX, we strongly recommend reading the Autodiff Cookbook." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zMT2qAi-SvcK" - }, - "source": [ - "## Higher order optimization\n", - "\n", - "Some meta-learning techniques, such as Model-Agnostic Meta-Learning ([MAML](https://arxiv.org/abs/1703.03400)), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it's much easier:\n", - "\n", - "```python\n", - "def meta_loss_fn(params, data):\n", - " \"\"\"Computes the loss after one step of SGD.\"\"\"\n", - " grads = jax.grad(loss_fn)(params, data)\n", - " return loss_fn(params - lr * grads, data)\n", - "\n", - "meta_grads = jax.grad(meta_loss_fn)(params, data)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3h9Aj3YyuL6P" - }, - "source": [ - "## Stopping gradients\n", - "\n", - "Auto-diff enables automatic computation of the gradient of a function with respect to its inputs. 
Sometimes, however, we might want some additional control: for instance, we might want to avoid back-propagating gradients through some subset of the computational graph.\n", - "\n", - "Consider for instance the TD(0) ([temporal difference](https://en.wikipedia.org/wiki/Temporal_difference_learning)) reinforcement learning update. This is used to learn to estimate the *value* of a state in an environment from experience of interacting with the environment. Let's assume the value estimate $v_{\\theta}(s_{t-1}$) in a state $s_{t-1}$ is parameterised by a linear function." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "fjLqbCb6SiOm" - }, - "outputs": [], - "source": [ - "# Value function and initial parameters\n", - "value_fn = lambda theta, state: jnp.dot(theta, state)\n", - "theta = jnp.array([0.1, -0.1, 0.])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "85S7HBo1tBzt" - }, - "source": [ - "Consider a transition from a state $s_{t-1}$ to a state $s_t$ during which we observed the reward $r_t$" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "T6cRPau6tCSE" - }, - "outputs": [], - "source": [ - "# An example transition.\n", - "s_tm1 = jnp.array([1., 2., -1.])\n", - "r_t = jnp.array(1.)\n", - "s_t = jnp.array([2., 1., 0.])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QO5CHA9_Sk01" - }, - "source": [ - "The TD(0) update to the network parameters is:\n", - "\n", - "$$\n", - "\\Delta \\theta = (r_t + v_{\\theta}(s_t) - v_{\\theta}(s_{t-1})) \\nabla v_{\\theta}(s_{t-1})\n", - "$$\n", - "\n", - "This update is not the gradient of any loss function.\n", - "\n", - "However, it can be **written** as the gradient of the pseudo loss function\n", - "\n", - "$$\n", - "L(\\theta) = - \\frac{1}{2} [r_t + v_{\\theta}(s_t) - v_{\\theta}(s_{t-1})]^2\n", - "$$\n", - "\n", - "if the dependency of the target $r_t + v_{\\theta}(s_t)$ on the parameter $\\theta$ is ignored.\n", - "\n", - "How can we implement this in JAX? 
If we write the pseudo loss naively we get:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "uMcFny2xuOwz", - "outputId": "79c10af9-10b8-4e18-9753-a53918b9d72d" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ -1.2, 1.2, -1.2], dtype=float32)" - ] - }, - "execution_count": 9, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "def td_loss(theta, s_tm1, r_t, s_t):\n", - " v_tm1 = value_fn(theta, s_tm1)\n", - " target = r_t + value_fn(theta, s_t)\n", - " return -0.5 * ((target - v_tm1) ** 2)\n", - "\n", - "td_update = jax.grad(td_loss)\n", - "delta_theta = td_update(theta, s_tm1, r_t, s_t)\n", - "\n", - "delta_theta" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CPnjm59GG4Gq" - }, - "source": [ - "But `td_update` will **not** compute a TD(0) update, because the gradient computation will include the dependency of `target` on $\\theta$.\n", - "\n", - "We can use `jax.lax.stop_gradient` to force JAX to ignore the dependency of the target on $\\theta$:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "MKeq7trKPS4V", - "outputId": "0f27d754-a871-4c47-8e3a-a961418a24cc" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([1.2, 2.4, -1.2], dtype=float32)" - ] - }, - "execution_count": 10, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "def td_loss(theta, s_tm1, r_t, s_t):\n", - " v_tm1 = value_fn(theta, s_tm1)\n", - " target = r_t + value_fn(theta, s_t)\n", - " return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2)\n", - "\n", - "td_update = jax.grad(td_loss)\n", - "delta_theta = td_update(theta, s_tm1, r_t, s_t)\n", - "\n", - "delta_theta" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JOnjm59GG4Gq" - }, - "source": [ - "This will treat `target` as if it did **not** depend on the parameters $\\theta$ and compute the correct update to the parameters.\n", - "\n", - "Now, let's also calculate $\\Delta \\theta$ using the original TD(0) update expression, to cross-check our work. You may wish to try and implement this yourself using jax.grad and your knowledge so far. Here's our solution:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "WCeq7trKPS4V", - "outputId": "0f19d754-a871-4c47-8e3a-a961418a24cc" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ 1.2, 2.4, -1.2], dtype=float32)" - ] - }, - "execution_count": 11, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "s_grad = jax.grad(value_fn)(theta, s_tm1)\n", - "delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad\n", - "\n", - "delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta`" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TNF0CkwOTKpD" - }, - "source": [ - "`jax.lax.stop_gradient` may also be useful in other settings, for instance if you want the gradient from some loss to only affect a subset of the parameters of the neural network (because, for instance, the other parameters are trained using a different loss)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UMY0IyuOTKpG" - }, - "source": [ - "## Straight-through estimator using `stop_gradient`\n", - "\n", - "The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. 
Given a non-differentiable function $f : \\mathbb{R}^n \\to \\mathbb{R}^n$ that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function. This can be implemented neatly using `jax.lax.stop_gradient`:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "hdORJENmVHvX", - "outputId": "f0839541-46a4-45a9-fce7-ead08f20046b" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "f(x): 3.0\n", - "straight_through_f(x): 3.0\n", - "grad(f)(x): 0.0\n", - "grad(straight_through_f)(x): 1.0\n" - ] - } - ], - "source": [ - "def f(x):\n", - " return jnp.round(x) # non-differentiable\n", - "\n", - "def straight_through_f(x):\n", - " # Create an exactly-zero expression with Sterbenz lemma that has\n", - " # an exactly-one gradient.\n", - " zero = x - jax.lax.stop_gradient(x)\n", - " return zero + jax.lax.stop_gradient(f(x))\n", - "\n", - "print(\"f(x): \", f(3.2))\n", - "print(\"straight_through_f(x):\", straight_through_f(3.2))\n", - "\n", - "print(\"grad(f)(x):\", jax.grad(f)(3.2))\n", - "print(\"grad(straight_through_f)(x):\", jax.grad(straight_through_f)(3.2))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Wx3RNE0Sw5mn" - }, - "source": [ - "## Per-example gradients\n", - "\n", - "While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch.\n", - "\n", - "For instance, this is needed to prioritise data based on gradient magnitude, or to apply clipping / normalisations on a sample by sample basis.\n", - "\n", - "In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients are typically very inefficient.\n", - "\n", - "In JAX we can define the code to compute the gradient per-sample in an easy but efficient way.\n", - "\n", - "Just combine the `jit`, `vmap` and `grad` transformations together:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "tFLyd9ifw4GG", - "outputId": "bf3ad4a3-102d-47a6-ece0-f4a8c9e5d434" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[1.2, 2.4, -1.2],\n", - " [1.2, 2.4, -1.2]], dtype=float32)" - ] - }, - "execution_count": 13, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))\n", - "\n", - "# Test it:\n", - "batched_s_tm1 = jnp.stack([s_tm1, s_tm1])\n", - "batched_r_t = jnp.stack([r_t, r_t])\n", - "batched_s_t = jnp.stack([s_t, s_t])\n", - "\n", - "perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VxvYVEYQYiS_" - }, - "source": [ - "Let's walk through this one transformation at a time.\n", - "\n", - "First, we apply `jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss w.r.t. 
the parameters on single (unbatched) inputs:" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "rPO67QQrY5Bk", - "outputId": "fbb45b98-2dbf-4865-e6e5-87dc3eef5560" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([1.2, 2.4, -1.2], dtype=float32)" - ] - }, - "execution_count": 14, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "dtdloss_dtheta = jax.grad(td_loss)\n", - "\n", - "dtdloss_dtheta(theta, s_tm1, r_t, s_t)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cU36nVAlcnJ0" - }, - "source": [ - "This function computes one row of the array above." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "c6DQF0b3ZA5u" - }, - "source": [ - "Then, we vectorise this function using `jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, we produce a batch of outputs -- each output in the batch corresponds to the gradient for the corresponding member of the input batch." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "5agbNKavaNDM", - "outputId": "ab081012-88ab-4904-a367-68e9f81445f0" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[1.2, 2.4, -1.2],\n", - " [1.2, 2.4, -1.2]], dtype=float32)" - ] - }, - "execution_count": 15, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "almost_perex_grads = jax.vmap(dtdloss_dtheta)\n", - "\n", - "batched_theta = jnp.stack([theta, theta])\n", - "almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "K-v34yLuan7k" - }, - "source": [ - "This isn't quite what we want, because we have to manually feed this function a batch of `theta`s, whereas we actually want to use a single `theta`. We fix this by adding `in_axes` to the `jax.vmap`, specifying theta as `None`, and the other args as `0`. This makes the resulting function add an extra axis only to the other arguments, leaving `theta` unbatched, as we want:" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "S6kd5MujbGrr", - "outputId": "d3d731ef-3f7d-4a0a-ce91-7df57627ddbd" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[1.2, 2.4, -1.2],\n", - " [1.2, 2.4, -1.2]], dtype=float32)" - ] - }, - "execution_count": 16, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))\n", - "\n", - "inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O0hbsm70be5T" - }, - "source": [ - "Almost there! This does what we want, but is slower than it has to be. 
Now, we wrap the whole thing in a `jax.jit` to get the compiled, efficient version of the same function:" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "Fvr709FcbrSW", - "outputId": "627db899-5620-4bed-8d34-cd1364d3d187" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[1.2, 2.4, -1.2],\n", - " [1.2, 2.4, -1.2]], dtype=float32)" - ] - }, - "execution_count": 17, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "perex_grads = jax.jit(inefficient_perex_grads)\n", - "\n", - "perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "id": "FH42yzbHcNs2", - "outputId": "c8e52f93-615a-4ce7-d8ab-fb6215995a39" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "100 loops, best of 5: 7.74 ms per loop\n", - "10000 loops, best of 5: 86.2 µs per loop\n" - ] - } - ], - "source": [ - "%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()\n", - "%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "Advanced Grads", - "provenance": [] - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/jax-101/04-advanced-autodiff.md b/docs/jax-101/04-advanced-autodiff.md deleted file mode 100644 index 1eb99376e101..000000000000 --- a/docs/jax-101/04-advanced-autodiff.md +++ /dev/null @@ -1,374 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.0 -kernelspec: - display_name: Python 3 - name: python3 ---- - -+++ {"id": "kORMl5KmfByI"} - -# Advanced Automatic Differentiation in JAX - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/04-advanced-autodiff.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/04-advanced-autodiff.ipynb) - -*Authors: Vlatimir Mikulik & Matteo Hessel* - -Computing gradients is a critical part of modern machine learning methods. This section considers a few advanced topics in the areas of automatic differentiation as it relates to modern machine learning. - -While understanding how automatic differentiation works under the hood isn't crucial for using JAX in most contexts, we encourage the reader to check out this quite accessible [video](https://www.youtube.com/watch?v=wG_nF1awSSY) to get a deeper sense of what's going on. - -[The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) is a more advanced and more detailed explanation of how these ideas are implemented in the JAX backend. It's not necessary to understand this to do most things in JAX. 
However, some features (like defining [custom derivatives](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)) depend on understanding this, so it's worth knowing this explanation exists if you ever need to use them. - -+++ {"id": "qx50CO1IorCc"} - -## Higher-order derivatives - -JAX's autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations. - -We illustrate this in the single-variable case: - -The derivative of $f(x) = x^3 + 2x^2 - 3x + 1$ can be computed as: - -```{code-cell} ipython3 -:id: Kqsbj98UTVdi - -import jax - -f = lambda x: x**3 + 2*x**2 - 3*x + 1 - -dfdx = jax.grad(f) -``` - -+++ {"id": "ItEt15OGiiAF"} - -The higher-order derivatives of $f$ are: - -$$ -\begin{array}{l} -f'(x) = 3x^2 + 4x -3\\ -f''(x) = 6x + 4\\ -f'''(x) = 6\\ -f^{iv}(x) = 0 -\end{array} -$$ - -Computing any of these in JAX is as easy as chaining the `grad` function: - -```{code-cell} ipython3 -:id: 5X3yQqLgimqH - -d2fdx = jax.grad(dfdx) -d3fdx = jax.grad(d2fdx) -d4fdx = jax.grad(d3fdx) -``` - -+++ {"id": "fVL2P_pcj8T1"} - -Evaluating the above in $x=1$ would give us: - -$$ -\begin{array}{l} -f'(1) = 4\\ -f''(1) = 10\\ -f'''(1) = 6\\ -f^{iv}(1) = 0 -\end{array} -$$ - -Using JAX: - -```{code-cell} ipython3 -:id: tJkIp9wFjxL3 -:outputId: 581ecf87-2d20-4c83-9443-5befc1baf51d - -print(dfdx(1.)) -print(d2fdx(1.)) -print(d3fdx(1.)) -print(d4fdx(1.)) -``` - -+++ {"id": "3-fTelU7LHRr"} - -In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to - -$$(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.$$ - -The Hessian of a real-valued function of several variables, $f: \mathbb R^n\to\mathbb R$, can be identified with the Jacobian of its gradient. JAX provides two transformations for computing the Jacobian of a function, `jax.jacfwd` and `jax.jacrev`, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances – see the [video about autodiff](https://www.youtube.com/watch?v=wG_nF1awSSY) linked above for an explanation. - -```{code-cell} ipython3 -:id: ILhkef1rOB6_ - -def hessian(f): - return jax.jacfwd(jax.grad(f)) -``` - -+++ {"id": "xaENwADXOGf_"} - -Let's double check this is correct on the dot-product $f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}$. - -if $i=j$, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2$. Otherwise, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0$. - -```{code-cell} ipython3 -:id: Xm3A0QdWRdJl -:outputId: e1e8cba9-b567-439b-b8fc-34b21497e67f - -import jax.numpy as jnp - -def f(x): - return jnp.dot(x, x) - -hessian(f)(jnp.array([1., 2., 3.])) -``` - -+++ {"id": "7_gbi34WSUsD"} - -Often, however, we aren't interested in computing the full Hessian itself, and doing so can be very inefficient. [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) explains some tricks, like the Hessian-vector product, that allow to use it without materialising the whole matrix. - -If you plan to work with higher-order derivatives in JAX, we strongly recommend reading the Autodiff Cookbook. 
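One such trick, sketched here rather than quoted from the Cookbook, is a Hessian-vector product: differentiating `jax.grad(f)` along a single direction `v` (forward-over-reverse) yields $\mathbf{H}v$ without ever materialising $\mathbf{H}$. The helper name `hvp` is just for illustration:

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Forward-mode derivative of the (reverse-mode) gradient along v:
    # returns H(x) @ v without building the full Hessian matrix.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def f(x):
    return jnp.dot(x, x)   # Hessian is 2 * identity

x = jnp.array([1., 2., 3.])
v = jnp.array([1., 0., 0.])

hvp(f, x, v)   # [2., 0., 0.], the first column of the Hessian
```

Each product costs roughly one extra gradient evaluation, which is why this scales to functions with many inputs where the full $n \times n$ Hessian would not fit in memory.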
- -+++ {"id": "zMT2qAi-SvcK"} - -## Higher order optimization - -Some meta-learning techniques, such as Model-Agnostic Meta-Learning ([MAML](https://arxiv.org/abs/1703.03400)), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it's much easier: - -```python -def meta_loss_fn(params, data): - """Computes the loss after one step of SGD.""" - grads = jax.grad(loss_fn)(params, data) - return loss_fn(params - lr * grads, data) - -meta_grads = jax.grad(meta_loss_fn)(params, data) -``` - -+++ {"id": "3h9Aj3YyuL6P"} - -## Stopping gradients - -Auto-diff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, we might want some additional control: for instance, we might want to avoid back-propagating gradients through some subset of the computational graph. - -Consider for instance the TD(0) ([temporal difference](https://en.wikipedia.org/wiki/Temporal_difference_learning)) reinforcement learning update. This is used to learn to estimate the *value* of a state in an environment from experience of interacting with the environment. Let's assume the value estimate $v_{\theta}(s_{t-1}$) in a state $s_{t-1}$ is parameterised by a linear function. - -```{code-cell} ipython3 -:id: fjLqbCb6SiOm - -# Value function and initial parameters -value_fn = lambda theta, state: jnp.dot(theta, state) -theta = jnp.array([0.1, -0.1, 0.]) -``` - -+++ {"id": "85S7HBo1tBzt"} - -Consider a transition from a state $s_{t-1}$ to a state $s_t$ during which we observed the reward $r_t$ - -```{code-cell} ipython3 -:id: T6cRPau6tCSE - -# An example transition. -s_tm1 = jnp.array([1., 2., -1.]) -r_t = jnp.array(1.) -s_t = jnp.array([2., 1., 0.]) -``` - -+++ {"id": "QO5CHA9_Sk01"} - -The TD(0) update to the network parameters is: - -$$ -\Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1}) -$$ - -This update is not the gradient of any loss function. - -However, it can be **written** as the gradient of the pseudo loss function - -$$ -L(\theta) = - \frac{1}{2} [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2 -$$ - -if the dependency of the target $r_t + v_{\theta}(s_t)$ on the parameter $\theta$ is ignored. - -How can we implement this in JAX? If we write the pseudo loss naively we get: - -```{code-cell} ipython3 -:id: uMcFny2xuOwz -:outputId: 79c10af9-10b8-4e18-9753-a53918b9d72d - -def td_loss(theta, s_tm1, r_t, s_t): - v_tm1 = value_fn(theta, s_tm1) - target = r_t + value_fn(theta, s_t) - return -0.5 * ((target - v_tm1) ** 2) - -td_update = jax.grad(td_loss) -delta_theta = td_update(theta, s_tm1, r_t, s_t) - -delta_theta -``` - -+++ {"id": "CPnjm59GG4Gq"} - -But `td_update` will **not** compute a TD(0) update, because the gradient computation will include the dependency of `target` on $\theta$. - -We can use `jax.lax.stop_gradient` to force JAX to ignore the dependency of the target on $\theta$: - -```{code-cell} ipython3 -:id: MKeq7trKPS4V -:outputId: 0f27d754-a871-4c47-8e3a-a961418a24cc - -def td_loss(theta, s_tm1, r_t, s_t): - v_tm1 = value_fn(theta, s_tm1) - target = r_t + value_fn(theta, s_t) - return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2) - -td_update = jax.grad(td_loss) -delta_theta = td_update(theta, s_tm1, r_t, s_t) - -delta_theta -``` - -+++ {"id": "JOnjm59GG4Gq"} - -This will treat `target` as if it did **not** depend on the parameters $\theta$ and compute the correct update to the parameters. 
- -Now, let's also calculate $\Delta \theta$ using the original TD(0) update expression, to cross-check our work. You may wish to try and implement this yourself using jax.grad and your knowledge so far. Here's our solution: - -```{code-cell} ipython3 -:id: WCeq7trKPS4V -:outputId: 0f19d754-a871-4c47-8e3a-a961418a24cc - -s_grad = jax.grad(value_fn)(theta, s_tm1) -delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad - -delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta` -``` - -+++ {"id": "TNF0CkwOTKpD"} - -`jax.lax.stop_gradient` may also be useful in other settings, for instance if you want the gradient from some loss to only affect a subset of the parameters of the neural network (because, for instance, the other parameters are trained using a different loss). - -+++ {"id": "UMY0IyuOTKpG"} - -## Straight-through estimator using `stop_gradient` - -The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f : \mathbb{R}^n \to \mathbb{R}^n$ that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function. This can be implemented neatly using `jax.lax.stop_gradient`: - -```{code-cell} ipython3 -:id: hdORJENmVHvX -:outputId: f0839541-46a4-45a9-fce7-ead08f20046b - -def f(x): - return jnp.round(x) # non-differentiable - -def straight_through_f(x): - # Create an exactly-zero expression with Sterbenz lemma that has - # an exactly-one gradient. - zero = x - jax.lax.stop_gradient(x) - return zero + jax.lax.stop_gradient(f(x)) - -print("f(x): ", f(3.2)) -print("straight_through_f(x):", straight_through_f(3.2)) - -print("grad(f)(x):", jax.grad(f)(3.2)) -print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2)) -``` - -+++ {"id": "Wx3RNE0Sw5mn"} - -## Per-example gradients - -While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch. - -For instance, this is needed to prioritise data based on gradient magnitude, or to apply clipping / normalisations on a sample by sample basis. - -In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients are typically very inefficient. - -In JAX we can define the code to compute the gradient per-sample in an easy but efficient way. - -Just combine the `jit`, `vmap` and `grad` transformations together: - -```{code-cell} ipython3 -:id: tFLyd9ifw4GG -:outputId: bf3ad4a3-102d-47a6-ece0-f4a8c9e5d434 - -perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0))) - -# Test it: -batched_s_tm1 = jnp.stack([s_tm1, s_tm1]) -batched_r_t = jnp.stack([r_t, r_t]) -batched_s_t = jnp.stack([s_t, s_t]) - -perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) -``` - -+++ {"id": "VxvYVEYQYiS_"} - -Let's walk through this one transformation at a time. - -First, we apply `jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss w.r.t. 
the parameters on single (unbatched) inputs: - -```{code-cell} ipython3 -:id: rPO67QQrY5Bk -:outputId: fbb45b98-2dbf-4865-e6e5-87dc3eef5560 - -dtdloss_dtheta = jax.grad(td_loss) - -dtdloss_dtheta(theta, s_tm1, r_t, s_t) -``` - -+++ {"id": "cU36nVAlcnJ0"} - -This function computes one row of the array above. - -+++ {"id": "c6DQF0b3ZA5u"} - -Then, we vectorise this function using `jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, we produce a batch of outputs -- each output in the batch corresponds to the gradient for the corresponding member of the input batch. - -```{code-cell} ipython3 -:id: 5agbNKavaNDM -:outputId: ab081012-88ab-4904-a367-68e9f81445f0 - -almost_perex_grads = jax.vmap(dtdloss_dtheta) - -batched_theta = jnp.stack([theta, theta]) -almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t) -``` - -+++ {"id": "K-v34yLuan7k"} - -This isn't quite what we want, because we have to manually feed this function a batch of `theta`s, whereas we actually want to use a single `theta`. We fix this by adding `in_axes` to the `jax.vmap`, specifying theta as `None`, and the other args as `0`. This makes the resulting function add an extra axis only to the other arguments, leaving `theta` unbatched, as we want: - -```{code-cell} ipython3 -:id: S6kd5MujbGrr -:outputId: d3d731ef-3f7d-4a0a-ce91-7df57627ddbd - -inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0)) - -inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) -``` - -+++ {"id": "O0hbsm70be5T"} - -Almost there! This does what we want, but is slower than it has to be. Now, we wrap the whole thing in a `jax.jit` to get the compiled, efficient version of the same function: - -```{code-cell} ipython3 -:id: Fvr709FcbrSW -:outputId: 627db899-5620-4bed-8d34-cd1364d3d187 - -perex_grads = jax.jit(inefficient_perex_grads) - -perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) -``` - -```{code-cell} ipython3 -:id: FH42yzbHcNs2 -:outputId: c8e52f93-615a-4ce7-d8ab-fb6215995a39 - -%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready() -%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready() -``` diff --git a/docs/jax-101/05-random-numbers.ipynb b/docs/jax-101/05-random-numbers.ipynb deleted file mode 100644 index 65977674a062..000000000000 --- a/docs/jax-101/05-random-numbers.ipynb +++ /dev/null @@ -1,509 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "1Op_vnmkjw3z" - }, - "source": [ - "# Pseudo Random Numbers in JAX\n", - "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb)\n", - "\n", - "*Authors: Matteo Hessel & Rosalia Schneider*\n", - "\n", - "In this section we focus on pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution. 
\n", - "\n", - "PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next.\n", - "\n", - "Pseudo random number generation is an essential component of any machine learning or scientific computing framework. Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception.\n", - "\n", - "To better understand the difference between the approaches taken by JAX and NumPy when it comes to random number generation we will discuss both approaches in this section." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6_117sy0CGEU" - }, - "source": [ - "## Random numbers in NumPy\n", - "\n", - "Pseudo random number generation is natively supported in NumPy by the `numpy.random` module.\n", - "\n", - "In NumPy, pseudo random number generation is based on a global `state`.\n", - "\n", - "This can be set to a deterministic initial condition using `random.seed(SEED)`." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "qbmCquES5beU" - }, - "outputs": [], - "source": [ - "import numpy as np\n", - "np.random.seed(0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WImNZxJ-7plK" - }, - "source": [ - "You can inspect the content of the state using the following command." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "qNO_vG7z7qUb", - "outputId": "47817350-83be-40cc-85c3-46419fdbfda0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,\n", - " 2481403966, 4042607538, 337614300, 3232553940, 1018809052,\n", - " 3202401494, 1775180719, 3192392114, 594215549, 184016991,\n", - " 829906058, 610491522, 3879932251, 3139825610, 297902587,\n", - " 4075895579, 2943625357, 3530655617, 1423771745, 2135928312,\n", - " 2891506774, 1066338622, 135451537, 933040465, 2759011858,\n", - " 2273819758, 3545703099, 2516396728, 127 ...\n" - ] - } - ], - "source": [ - "def print_truncated_random_state():\n", - " \"\"\"To avoid spamming the outputs, print only part of the state.\"\"\"\n", - " full_random_state = np.random.get_state()\n", - " print(str(full_random_state)[:460], '...')\n", - "\n", - "print_truncated_random_state()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nmqx0gJW9CFo" - }, - "source": [ - "The `state` is updated by each call to a random function:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "ZqUzvqF1B1TO", - "outputId": "c1874391-eb8d-43d8-eb8f-c918ed0a0c1a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,\n", - " 2481403966, 4042607538, 337614300, 3232553940, 1018809052,\n", - " 3202401494, 1775180719, 3192392114, 594215549, 184016991,\n", - " 829906058, 610491522, 3879932251, 3139825610, 297902587,\n", - " 4075895579, 2943625357, 3530655617, 1423771745, 2135928312,\n", - " 2891506774, 1066338622, 135451537, 933040465, 2759011858,\n", - " 2273819758, 3545703099, 2516396728, 127 ...\n", - "('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,\n", - " 3904844661, 676747479, 2085143622, 1056793272, 3812477442,\n", - " 2168787041, 275552121, 2696932952, 3432054210, 1657102335,\n", - 
" 3518946594, 962584079, 1051271004, 3806145045, 1414436097,\n", - " 2032348584, 1661738718, 1116708477, 2562755208, 3176189976,\n", - " 696824676, 2399811678, 3992505346, 569184356, 2626558620,\n", - " 136797809, 4273176064, 296167901, 343 ...\n" - ] - } - ], - "source": [ - "np.random.seed(0)\n", - "\n", - "print_truncated_random_state()\n", - "\n", - "_ = np.random.uniform()\n", - "\n", - "print_truncated_random_state()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "G1ICXejY_xR0" - }, - "source": [ - "NumPy allows you to sample both individual numbers, or entire vectors of numbers in a single function call. For instance, you may sample a vector of 3 scalars from a uniform distribution by doing:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "6Xqx2e8tAW5d", - "outputId": "a428facb-cd16-4375-f5c4-8fc601e60169" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0.5488135 0.71518937 0.60276338]\n" - ] - } - ], - "source": [ - "np.random.seed(0)\n", - "print(np.random.uniform(size=3))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zPfs8tXTAlr7" - }, - "source": [ - "NumPy provides a *sequential equivalent guarantee*, meaning that sampling N numbers in a row individually or sampling a vector of N numbers results in the same pseudo-random sequences:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "bZiBZXHW_2wO", - "outputId": "3aff9a51-8a19-4737-a7ad-91b23bfc05f8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "individually: [0.5488135 0.71518937 0.60276338]\n", - "all at once: [0.5488135 0.71518937 0.60276338]\n" - ] - } - ], - "source": [ - "np.random.seed(0)\n", - "print(\"individually:\", np.stack([np.random.uniform() for _ in range(3)]))\n", - "\n", - "np.random.seed(0)\n", - "print(\"all at once: \", np.random.uniform(size=3))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JGCZI9UTl7o4" - }, - "source": [ - "## Random numbers in JAX\n", - "\n", - "JAX's random number generation differs from NumPy's in important ways. The reason is that NumPy's PRNG design makes it hard to simultaneously guarantee a number of desirable properties for JAX, specifically that code must be:\n", - "\n", - "1. reproducible,\n", - "2. parallelizable,\n", - "3. vectorisable.\n", - "\n", - "We will discuss why in the following. First, we will focus on the implications of a PRNG design based on a global state. Consider the code:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "j441y2NCmnbt", - "outputId": "77fe84d7-c86e-417a-95b9-d73663ed40fc" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.9791922366721637\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "\n", - "np.random.seed(0)\n", - "\n", - "def bar(): return np.random.uniform()\n", - "def baz(): return np.random.uniform()\n", - "\n", - "def foo(): return bar() + 2 * baz()\n", - "\n", - "print(foo())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5kVpfSV5n1d7" - }, - "source": [ - "The function `foo` sums two scalars sampled from a uniform distribution.\n", - "\n", - "The output of this code can only satisfy requirement #1 if we assume a specific order of execution for `bar()` and `baz()`, as native Python does.\n", - "\n", - "This doesn't seem to be a major issue in NumPy, as it is already enforced by Python, but it becomes an issue in JAX. 
\n", - "\n", - "Making this code reproducible in JAX would require enforcing this specific order of execution. This would violate requirement #2, as JAX should be able to parallelize `bar` and `baz` when jitting as these functions don't actually depend on each other.\n", - "\n", - "To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a `key` ." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "LuaGUVRUvbzQ", - "outputId": "bbf525d7-d407-49b4-8bee-2cd827846e04" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 0 42]\n" - ] - } - ], - "source": [ - "from jax import random\n", - "\n", - "key = random.PRNGKey(42)\n", - "\n", - "print(key)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XhFpKnW9F2nF" - }, - "source": [ - "A key is just an array of shape `(2,)`.\n", - "\n", - "'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "Tc_Tsv06Fz3l", - "outputId": "1472ae73-edbf-4163-9992-46781d258014" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-0.18471184\n", - "-0.18471184\n" - ] - } - ], - "source": [ - "print(random.normal(key))\n", - "print(random.normal(key))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "foUEgtmTesOx" - }, - "source": [ - "**Note:** Feeding the same key to different random functions can result in correlated outputs, which is generally undesirable. \n", - "\n", - "**The rule of thumb is: never reuse keys (unless you want identical outputs).**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "T4dOLP0GGJuB" - }, - "source": [ - "In order to generate different and independent samples, you must `split()` the key *yourself* whenever you want to call a random function:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "qChuz1C9CSJe", - "outputId": "f6eb1dc3-d83c-45ef-d90e-5a12d36fa7e6" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "old key [ 0 42]\n", - " \\---SPLIT --> new key [2465931498 3679230171]\n", - " \\--> new subkey [255383827 267815257] --> normal 1.3694694\n" - ] - } - ], - "source": [ - "print(\"old key\", key)\n", - "new_key, subkey = random.split(key)\n", - "del key # The old key is discarded -- we must never use it again.\n", - "normal_sample = random.normal(subkey)\n", - "print(r\" \\---SPLIT --> new key \", new_key)\n", - "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_sample)\n", - "del subkey # The subkey is also discarded after use.\n", - "\n", - "# Note: you don't actually need to `del` keys -- that's just for emphasis.\n", - "# Not reusing the same values is enough.\n", - "\n", - "key = new_key # If we wanted to do this again, we would use new_key as the key." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WKQMJQB6cGhV" - }, - "source": [ - "`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. 
We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.\n", - "\n", - "If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n", - "\n", - "It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.\n", - "\n", - "Usually, the above example would be written concisely as" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "Xkt5OYjHjWiP" - }, - "outputs": [], - "source": [ - "key, subkey = random.split(key)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ULmPVyd9jWSv" - }, - "source": [ - "which discards the old key automatically." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dlaAsObh68R1" - }, - "source": [ - "It's worth noting that `split()` can create as many keys as you need, not just 2:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "hbHZP2xM7Egf" - }, - "outputs": [], - "source": [ - "key, *forty_two_subkeys = random.split(key, num=43)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Fhu7ejhLB4R_" - }, - "source": [ - "Another difference between NumPy's and JAX's random modules relates to the sequential equivalence guarantee mentioned above.\n", - "\n", - "As in NumPy, JAX's random module also allows sampling of vectors of numbers.\n", - "However, JAX does not provide a sequential equivalence guarantee, because doing so would interfere with the vectorization on SIMD hardware (requirement #3 above).\n", - "\n", - "In the example below, sampling 3 values out of a normal distribution individually using three subkeys gives a different result to using giving a single key and specifying `shape=(3,)`:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "4nB_TA54D-HT", - "outputId": "2f259f63-3c45-46c8-f597-4e53dc63cb56" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "individually: [-0.04838839 0.10796146 -1.2226542 ]\n", - "all at once: [ 0.18693541 -1.2806507 -1.5593133 ]\n" - ] - } - ], - "source": [ - "key = random.PRNGKey(42)\n", - "subkeys = random.split(key, 3)\n", - "sequence = np.stack([random.normal(subkey) for subkey in subkeys])\n", - "print(\"individually:\", sequence)\n", - "\n", - "key = random.PRNGKey(42)\n", - "print(\"all at once: \", random.normal(key, shape=(3,)))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_vBAaU2jrWPk" - }, - "source": [ - "Note that contrary to our recommendation above, we use `key` directly as an input to `random.normal()` in the second example. This is because we won't reuse it anywhere else, so we don't violate the single-use principle." 
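A common pattern that combines key splitting with vectorization, sketched here as a supplement to the example above, is to split one key into a batch of subkeys and draw all the independent samples in a single `jax.vmap` call, so that no key is ever reused:

```python
import jax
from jax import random

key = random.PRNGKey(0)
subkeys = random.split(key, 3)               # one subkey per desired sample
samples = jax.vmap(random.normal)(subkeys)   # three independent draws
print(samples)
```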
- ] - } - ], - "metadata": { - "colab": { - "name": "Random Numbers in JAX", - "provenance": [] - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/jax-101/05-random-numbers.md b/docs/jax-101/05-random-numbers.md deleted file mode 100644 index f9f3ae178efe..000000000000 --- a/docs/jax-101/05-random-numbers.md +++ /dev/null @@ -1,254 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.0 -kernelspec: - display_name: Python 3 - name: python3 ---- - -+++ {"id": "1Op_vnmkjw3z"} - -# Pseudo Random Numbers in JAX - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb) - -*Authors: Matteo Hessel & Rosalia Schneider* - -In this section we focus on pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution. - -PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next. - -Pseudo random number generation is an essential component of any machine learning or scientific computing framework. Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception. - -To better understand the difference between the approaches taken by JAX and NumPy when it comes to random number generation we will discuss both approaches in this section. - -+++ {"id": "6_117sy0CGEU"} - -## Random numbers in NumPy - -Pseudo random number generation is natively supported in NumPy by the `numpy.random` module. - -In NumPy, pseudo random number generation is based on a global `state`. - -This can be set to a deterministic initial condition using `random.seed(SEED)`. - -```{code-cell} ipython3 -:id: qbmCquES5beU - -import numpy as np -np.random.seed(0) -``` - -+++ {"id": "WImNZxJ-7plK"} - -You can inspect the content of the state using the following command. 
- -```{code-cell} ipython3 -:id: qNO_vG7z7qUb -:outputId: 47817350-83be-40cc-85c3-46419fdbfda0 - -def print_truncated_random_state(): - """To avoid spamming the outputs, print only part of the state.""" - full_random_state = np.random.get_state() - print(str(full_random_state)[:460], '...') - -print_truncated_random_state() -``` - -+++ {"id": "nmqx0gJW9CFo"} - -The `state` is updated by each call to a random function: - -```{code-cell} ipython3 -:id: ZqUzvqF1B1TO -:outputId: c1874391-eb8d-43d8-eb8f-c918ed0a0c1a - -np.random.seed(0) - -print_truncated_random_state() - -_ = np.random.uniform() - -print_truncated_random_state() -``` - -+++ {"id": "G1ICXejY_xR0"} - -NumPy allows you to sample both individual numbers and entire vectors of numbers in a single function call. For instance, you may sample a vector of 3 scalars from a uniform distribution by doing: - -```{code-cell} ipython3 -:id: 6Xqx2e8tAW5d -:outputId: a428facb-cd16-4375-f5c4-8fc601e60169 - -np.random.seed(0) -print(np.random.uniform(size=3)) -``` - -+++ {"id": "zPfs8tXTAlr7"} - -NumPy provides a *sequential equivalence guarantee*, meaning that sampling N numbers in a row individually or sampling a vector of N numbers results in the same pseudo-random sequences: - -```{code-cell} ipython3 -:id: bZiBZXHW_2wO -:outputId: 3aff9a51-8a19-4737-a7ad-91b23bfc05f8 - -np.random.seed(0) -print("individually:", np.stack([np.random.uniform() for _ in range(3)])) - -np.random.seed(0) -print("all at once: ", np.random.uniform(size=3)) -``` - -+++ {"id": "JGCZI9UTl7o4"} - -## Random numbers in JAX - -JAX's random number generation differs from NumPy's in important ways. The reason is that NumPy's PRNG design makes it hard to simultaneously guarantee a number of desirable properties for JAX, specifically that code must be: - -1. reproducible, -2. parallelizable, -3. vectorizable. - -We will discuss why below. First, we will focus on the implications of a PRNG design based on a global state. Consider the code: - -```{code-cell} ipython3 -:id: j441y2NCmnbt -:outputId: 77fe84d7-c86e-417a-95b9-d73663ed40fc - -import numpy as np - -np.random.seed(0) - -def bar(): return np.random.uniform() -def baz(): return np.random.uniform() - -def foo(): return bar() + 2 * baz() - -print(foo()) -``` - -+++ {"id": "5kVpfSV5n1d7"} - -The function `foo` sums two scalars sampled from a uniform distribution. - -The output of this code can only satisfy requirement #1 if we assume a specific order of execution for `bar()` and `baz()`, as native Python does. - -This doesn't seem to be a major issue in NumPy, as it is already enforced by Python, but it becomes an issue in JAX. - -Making this code reproducible in JAX would require enforcing this specific order of execution. This would violate requirement #2, as JAX should be able to parallelize `bar` and `baz` when jitting, as these functions don't actually depend on each other. - -To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a `key`. - -```{code-cell} ipython3 -:id: LuaGUVRUvbzQ -:outputId: bbf525d7-d407-49b4-8bee-2cd827846e04 - -from jax import random - -key = random.PRNGKey(42) - -print(key) -``` - -+++ {"id": "XhFpKnW9F2nF"} - -A key is just an array of shape `(2,)`. - -'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it.
Feeding the same key to a random function will always result in the same sample being generated: - -```{code-cell} ipython3 -:id: Tc_Tsv06Fz3l -:outputId: 1472ae73-edbf-4163-9992-46781d258014 - -print(random.normal(key)) -print(random.normal(key)) -``` - -+++ {"id": "foUEgtmTesOx"} - -**Note:** Feeding the same key to different random functions can result in correlated outputs, which is generally undesirable. - -**The rule of thumb is: never reuse keys (unless you want identical outputs).** - -+++ {"id": "T4dOLP0GGJuB"} - -In order to generate different and independent samples, you must `split()` the key *yourself* whenever you want to call a random function: - -```{code-cell} ipython3 -:id: qChuz1C9CSJe -:outputId: f6eb1dc3-d83c-45ef-d90e-5a12d36fa7e6 - -print("old key", key) -new_key, subkey = random.split(key) -del key # The old key is discarded -- we must never use it again. -normal_sample = random.normal(subkey) -print(r" \---SPLIT --> new key ", new_key) -print(r" \--> new subkey", subkey, "--> normal", normal_sample) -del subkey # The subkey is also discarded after use. - -# Note: you don't actually need to `del` keys -- that's just for emphasis. -# Not reusing the same values is enough. - -key = new_key # If we wanted to do this again, we would use new_key as the key. -``` - -+++ {"id": "WKQMJQB6cGhV"} - -`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever. - -If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it. - -It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later. - -Usually, the above example would be written concisely as - -```{code-cell} ipython3 -:id: Xkt5OYjHjWiP - -key, subkey = random.split(key) -``` - -+++ {"id": "ULmPVyd9jWSv"} - -which discards the old key automatically. - -+++ {"id": "dlaAsObh68R1"} - -It's worth noting that `split()` can create as many keys as you need, not just 2: - -```{code-cell} ipython3 -:id: hbHZP2xM7Egf - -key, *forty_two_subkeys = random.split(key, num=43) -``` - -+++ {"id": "Fhu7ejhLB4R_"} - -Another difference between NumPy's and JAX's random modules relates to the sequential equivalence guarantee mentioned above. - -As in NumPy, JAX's random module also allows sampling of vectors of numbers. -However, JAX does not provide a sequential equivalence guarantee, because doing so would interfere with the vectorization on SIMD hardware (requirement #3 above). 
- -In the example below, sampling 3 values out of a normal distribution individually using three subkeys gives a different result than using a single key and specifying `shape=(3,)`: - -```{code-cell} ipython3 -:id: 4nB_TA54D-HT -:outputId: 2f259f63-3c45-46c8-f597-4e53dc63cb56 - -key = random.PRNGKey(42) -subkeys = random.split(key, 3) -sequence = np.stack([random.normal(subkey) for subkey in subkeys]) -print("individually:", sequence) - -key = random.PRNGKey(42) -print("all at once: ", random.normal(key, shape=(3,))) -``` - -+++ {"id": "_vBAaU2jrWPk"} - -Note that contrary to our recommendation above, we use `key` directly as an input to `random.normal()` in the second example. This is because we won't reuse it anywhere else, so we don't violate the single-use principle. diff --git a/docs/jax-101/05.1-pytrees.ipynb b/docs/jax-101/05.1-pytrees.ipynb deleted file mode 100644 index d80ff386a27a..000000000000 --- a/docs/jax-101/05.1-pytrees.ipynb +++ /dev/null @@ -1,1019 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "-h05_PNNhZ-D" - }, - "source": [ - "# Working with Pytrees\n", - "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb)\n", - "\n", - "*Author: Vladimir Mikulik*\n", - "\n", - "Often, we want to operate on objects that look like dicts of arrays, or lists of lists of dicts, or other nested structures. In JAX, we refer to these as *pytrees*, but you can sometimes see them called *nests*, or just *trees*.\n", - "\n", - "JAX has built-in support for such objects, both in its library functions as well as through the use of functions from [`jax.tree_utils`](https://jax.readthedocs.io/en/latest/jax.tree_util.html) (with the most common ones also available as `jax.tree_*`). This section will explain how to use them, give some useful snippets and point out common gotchas." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9UjxVY9ulSCn" - }, - "source": [ - "## What is a pytree?\n", - "\n", - "As defined in the [JAX pytree docs](https://jax.readthedocs.io/en/latest/pytrees.html):\n", - "\n", - "> a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e.
a non-container object, is also considered a pytree.\n", - "\n", - "Some example pytrees:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "executionInfo": { - "elapsed": 11002, - "status": "ok", - "timestamp": 1692698031720, - "user": { - "displayName": "", - "userId": "" - }, - "user_tz": -60 - }, - "id": "Wh6BApZ9lrR1", - "outputId": "df1fa4cd-88a6-4d71-a376-b2ddf91568dd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1, 'a', ] has 3 leaves: [1, 'a', ]\n", - "(1, (2, 3), ()) has 3 leaves: [1, 2, 3]\n", - "[1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5]\n", - "{'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3]\n", - "Array([1, 2, 3], dtype=int32) has 1 leaves: [Array([1, 2, 3], dtype=int32)]\n" - ] - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "example_trees = [\n", - " [1, 'a', object()],\n", - " (1, (2, 3), ()),\n", - " [1, {'k1': 2, 'k2': (3, 4)}, 5],\n", - " {'a': 2, 'b': (2, 3)},\n", - " jnp.array([1, 2, 3]),\n", - "]\n", - "\n", - "# Let's see how many leaves they have:\n", - "for pytree in example_trees:\n", - " leaves = jax.tree_util.tree_leaves(pytree)\n", - " print(f\"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_tWkkGNwW8vf" - }, - "source": [ - "We've also introduced our first `jax.tree_*` function, which allowed us to extract the flattened leaves from the trees." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RcsmneIGlltm" - }, - "source": [ - "## Why pytrees?\n", - "\n", - "In machine learning, some places where you commonly find pytrees are:\n", - "* Model parameters\n", - "* Dataset entries\n", - "* RL agent observations\n", - "\n", - "They also often arise naturally when working in bulk with datasets (e.g., lists of lists of dicts)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sMrSGSIJn9MD" - }, - "source": [ - "## Common pytree functions\n", - "Perhaps the most commonly used pytree function is `jax.tree_map`. It works analogously to Python's native `map`, but on entire pytrees:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wZRcuQu4n7o5", - "outputId": "3528bc9f-54ed-49c8-b79a-1cbea176c0f3" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[[2, 4, 6], [2, 4], [2, 4, 6, 8]]" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "list_of_lists = [\n", - " [1, 2, 3],\n", - " [1, 2],\n", - " [1, 2, 3, 4]\n", - "]\n", - "\n", - "jax.tree_map(lambda x: x*2, list_of_lists)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xu8X3fk4orC9" - }, - "source": [ - "`jax.tree_map` also works with multiple arguments:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KVpB4r1OkeUK", - "outputId": "33f88a7e-aac7-48cd-d207-2c531cd37733" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[[2, 4, 6], [2, 4], [2, 4, 6, 8]]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "another_list_of_lists = list_of_lists\n", - "jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dkRKy3LvowAb" - }, - "source": [ - "When using multiple arguments with `jax.tree_map`, the structure of the inputs must exactly match. 
That is, lists must have the same number of elements, dicts must have the same keys, etc." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Lla4hDW6sgMZ" - }, - "source": [ - "## Example: ML model parameters\n", - "\n", - "A simple example of training an MLP displays some ways in which pytree operations come in useful:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "j2ZUzWx8tKB2" - }, - "outputs": [], - "source": [ - "import numpy as np\n", - "\n", - "def init_mlp_params(layer_widths):\n", - " params = []\n", - " for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):\n", - " params.append(\n", - " dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),\n", - " biases=np.ones(shape=(n_out,))\n", - " )\n", - " )\n", - " return params\n", - "\n", - "params = init_mlp_params([1, 128, 128, 1])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kUFwJOspuGvU" - }, - "source": [ - "We can use `jax.tree_map` to check that the shapes of our parameters are what we expect:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ErWsXuxXse-z", - "outputId": "d3e549ab-40ef-470e-e460-1b5939d9696f" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'biases': (128,), 'weights': (1, 128)},\n", - " {'biases': (128,), 'weights': (128, 128)},\n", - " {'biases': (1,), 'weights': (128, 1)}]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.tree_map(lambda x: x.shape, params)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zQtRKaj4ua6-" - }, - "source": [ - "Now, let's train our MLP:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "iL4GvW9OuZ-X" - }, - "outputs": [], - "source": [ - "def forward(params, x):\n", - " *hidden, last = params\n", - " for layer in hidden:\n", - " x = jax.nn.relu(x @ layer['weights'] + layer['biases'])\n", - " return x @ last['weights'] + last['biases']\n", - "\n", - "def loss_fn(params, x, y):\n", - " return jnp.mean((forward(params, x) - y) ** 2)\n", - "\n", - "LEARNING_RATE = 0.0001\n", - "\n", - "@jax.jit\n", - "def update(params, x, y):\n", - "\n", - " grads = jax.grad(loss_fn)(params, x, y)\n", - " # Note that `grads` is a pytree with the same structure as `params`.\n", - " # `jax.grad` is one of the many JAX functions that has\n", - " # built-in support for pytrees.\n", - "\n", - " # This is handy, because we can apply the SGD update using tree utils:\n", - " return jax.tree_map(\n", - " lambda p, g: p - LEARNING_RATE * g, params, grads\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "B3HniT9-xohz", - "outputId": "d77e9811-373e-45d6-ccbe-edb6f43120d7" - }, - "outputs": [ - { - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAhYAAAGdCAYAAABO2DpVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAABGgUlEQVR4nO3deVxU5f4H8M/MsAsziopAouKShpiliYFpaaaomd4Wu5ZrXksTzXu7JnZ/XfKaqbflWmpo3lLLzLqZa4aZuW9YZEm4i2aKoiKLC9uc8/vjOMgwM3DOcGbl8369eBnDOTNPZPCZ5/k+30cjiqIIIiIiIhVoXT0AIiIi8h4MFkRERKQaBgsiIiJSDYMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWp8nP2CgiDg/PnzCAkJgUajcfbLExERkR1EUURRUREiIyOh1dqel3B6sDh//jyioqKc/bJERESkgrNnz6Jp06Y2v+70YBESEgJAGpher3f2yxMREZEdCgsLERUVVfF73BanBwvT8oder2ewICIi8jA1lTGweJOIiIhUw2BBREREqmGwICIiItU4vcZCDqPRiLKyMlcPgwg6nQ4+Pj7cGk1EJJPbBYtr167hjz/+gCiKrh4KEQAgKCgIERER8PPzc/VQiIjcnlsFC6PRiD/++ANBQUFo3Lgx3yWSS4miiNLSUly6dAnZ2dlo06ZNtU1hiIjIzYJFWVkZRFFE48aNERgY6OrhECEwMBC+vr44c+YMSktLERAQ4OohERG5Nbd8+8WZCnInnKUgIpLPrWYsiIiIyD5GQUR6dh5yi4oRFhKAuOhQ6LTOf6POt2IeYtu2bdBoNMjPz5d9T4sWLTB37lyHjUmphx56CJMnT674XI3xudu/IxGRK6Rl5uCBOT9g6OJ9eGnlQQxdvA8PzPkBaZk5Th8Lg4UKRo0aBY1Gg3Hjxll8bcKECdBoNBg1apTzB+bmDhw4gOeff17WtUuXLkX9+vVr9RxERN4oLTMH45dnIKeg2OzxCwXFGL88w+nhgsFCJVFRUVi5ciVu3rxZ8VhxcTFWrFiBZs2auXBk6iotLVXtuRo3boygoCCXPwcRkacyCiKmr8+CtQYNpsemr8+CUXBeCwfvDBaCEcjeCRz6SvpTMDr8JTt16oSoqCh8/fXXFY99/fXXaNasGe69916za0tKSjBp0iSEhYUhICAADzzwAA4cOGB2zcaNG3HnnXciMDAQPXv2xOnTpy1ec9euXejevTsCAwMRFRWFSZMm4fr167LHPGrUKAwePBjTp09H48aNodfrMW7cOLPw8NBDDyEpKQmTJ09Go0aN0LdvXwBAZmYm+vXrh+DgYDRp0gTDhw/H5cuXK+67fv06RowYgeDgYEREROCdd96xeP2qyxj5+fl44YUX0KRJEwQEBCA2NhYbNmzAtm3bMHr0aBQUFECj0UCj0eD111+3+hy///47Bg0ahODgYOj1egwZMgQXL16s+Prrr7+Oe+65B59++ilatGgBg8GAP//5zygqKpL9fSMichfp2XkWMxWViQByCoqRnp3ntDF5X7DIWgfMjQWWPQqsGiP9OTdWetzBnnvuOSxZsqTi848//hijR4+2uO6VV17BqlWrsGzZMmRkZKB169bo27cv8vKk//Bnz57F448/joEDB+LgwYP4y1/+guTkZLPnOHnyJBITE/HEE0/g119/xRdffIFdu3YhKSlJ0Zi3bNmCw4cPY9u2bfj888/x9ddfY/r06WbXLFu2DH5+fti9ezcWLlyI/Px89OrVC/feey9+/PFHpKWl4eLFixgyZEjFPVOmTMH27duxdu1afPfdd9i2bRsyMjJsjkMQBPTr1w+7d+/G8uXLkZWVhdmzZ0On0yEhIQFz586FXq9HTk4OcnJy8Pe//93qcwwaNAh5eXnYvn07Nm/ejFOnTuHpp5+2+N6tWbMGGzZswIYNG7B9+3bMnj1b0feNiMgd5BbZDhX2XKcG79oVkrUO+HIEUHVSqDBHenzIJ0DMYw57+WHDhmHatGk4c+YMAGD37t1YuXIltm3bVnHN9evXkZqaiqVLl6Jfv34AgMWLF2Pz5s346KOPMGXKFKSmpqJVq1YV7/Lbtm2LQ4cOYc6cORXPM2vWLDz77LMVxZBt2rTB+++/jwcffBCpqamy+y34+fnh448/RlBQENq3b49//etfmDJlCmbMmFGxzbJNmzb497//XXHPG2+8gXvvvRdvvvlmxWMff/wxoqKicOzYMURGRuKjjz7C8uXL8fDDDwOQwknTpk1tjuP7779Heno6Dh8+jDvvvBMA0LJly4qvGwwGaDQahIeH23yOLVu24NChQ8jOzkZUVBQA4JNPPkH79u1x4MABdOnSBYAUQJYuXYqQkBAAwPDhw7FlyxbMnDlT1veMiMhdhIXI+1kv9zo1eE+wEIxA2lRYhArg1mMaIC0ZaDcA0OocMoTGjRtjwIABWLp0KURRxIABA9CoUSOza06ePImysjJ069at4jFfX1/ExcXh8OHDAIDDhw+ja9euZvfFx8ebff7LL7/g119/xWeffVbxmCiKEAQB2dnZuOuuu2SNuWPHjmY1CvHx8bh27RrOnj2L5s2bAwA6d+5s8dpbt25FcHCwxfOdPHkSN2/eRGlpqdm/Q2hoKNq2bWtzHAcPHkTTpk0rQoU9Dh8+jKioqIpQAQAxMTGoX78+Dh8+XBEsWrRoUREqACAiIgK5ubl2vy4RkavERYciwhCACwXFVn/7aQCEG6Stp87iPcHizB6g8Hw1F4hA4TnpuujuDhvGc889V7EcsWDBAoe9zrVr1/DCCy9g0qRJFl9Tu1i0Xr16Fq89cOBAsxkUk4iICJw4cULxaziz06qvr6/Z5xqNBoIgOO31iYjUotNqkDIwBuOXZ0AD87fWpg4WKQNjnNrPwntqLK5drPkaJdfZKTExEaWlpSgrK6sodKysVatWFfUKJmVlZThw4ABiYmIAAHfddRfS09PN7tu3b5/Z5506dUJWVhZat25t8aHksKxffvnFbCfLvn37EBwcbPauv6pOnTrht99+Q4sWLSxeu169emjVqhV8fX2xf//+inuuXr2KY8eO2XzOu+++G3/88YfNa/z8/GA0Vl+Ee9ddd+Hs2bM4e/ZsxWNZWVnIz8+v+N4SEXmbxNgIpA7rhHCD+XJHuCEAqcM6ITE2wqnj8Z5gEdxE3evspNPpcPjwYWRlZUGns1xyqVevHsaPH48pU6YgLS0NWVlZGDt2LG7cuIExY8YAAMaNG4fjx49jypQpOHr0KFasWIGlS5eaPc/UqVOxZ88eJCUl4eDBgzh+/DjWrl2ruHiztLQUY8aMQVZWFjZu3IiUlBQkJSVV28Z6woQJyMvLw9ChQ3HgwAGcPHkSmzZtwujRo2E0GhEcHIwxY8ZgypQp+OGHH5CZmYlRo0ZV+5wPPvggevTogSeeeAKbN29GdnY2vv32W6SlpQGQli+uXbuGLVu24PLly7hx44bFc/Tu3RsdOnTAs88+i4yMDKSnp2PEiBF48MEHcd999yn6vh
AReZLE2AjsmtoLn4+9H+/9+R58PvZ+7Jray+mhAvCmYNE8AdBH4vbkT1UaQH+HdJ2D6fV66PV6m1+fPXs2nnjiCQwfPhydOnXCiRMnsGnTJjRo0ACAtJSxatUqrFmzBh07dsTChQvNCiUB6R3+9u3bcezYMXTv3h333nsv/vnPfyIyMlLRWB9++GG0adMGPXr0wNNPP43HHnusYiunLZGRkdi9ezeMRiP69OmDDh06YPLkyahfv35FeHjrrbfQvXt3DBw4EL1798YDDzxgUatR1apVq9ClSxcMHToUMTExeOWVVypmKRISEjBu3Dg8/fTTaNy4sVkxqYlGo8HatWvRoEED9OjRA71790bLli3xxRdfKPqeEBF5Ip1Wg/hWDTHonjsQ36qhS9p5A4BGFEXndc0AUFhYCIPBgIKCAotfvsXFxcjOzkZ0dLR9p0hW7AoBrK40OXhXiKcZNWoU8vPzsWbNGlcPxa3V+u8lEZEXqO73d2XeM2MBSKFhyCeAvsrUjz6SoYKIiMgJvGdXiEnMY9KW0jN7pELN4CbS8oeDtpgSERHRbd4XLAApRDhwS6m3qFoQSkREVFvetRRCRERELsVgQURERKpxy2Dh5I0qRNXi30ciIvncKliYGkpVPrabyNVMzbiqtgInIiJLblW86ePjg6CgIFy6dAm+vr7VdmokcjRRFHHjxg3k5uaifv36VjupEhGRObcKFhqNBhEREcjOzq44epzI1erXr1/tce1ERHSbWwULQDpsqk2bNlwOIbfg6+vLmQoiIgXcLlgAgFarZetkIiIiD8QiBiIiIlINgwURERGphsGCiIiIVMNgQURERKphsCAiIiLVMFgQERGRahgsiIiISDUMFkRERKQaBgsiIiJSDYMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWoYLIiIiEg1DBZERESkGgYLIiIiUg2DBREREamGwYKIiIhUw2BBREREqmGwICIiItUwWBAREZFqGCyIiIhINYqChdFoxGuvvYbo6GgEBgaiVatWmDFjBkRRdNT4iIiIyIP4KLl4zpw5SE1NxbJly9C+fXv8+OOPGD16NAwGAyZNmuSoMRIREZGHUBQs9uzZg0GDBmHAgAEAgBYtWuDzzz9Henq6QwZHREREnkXRUkhCQgK2bNmCY8eOAQB++eUX7Nq1C/369bN5T0lJCQoLC80+iIiIyDspmrFITk5GYWEh2rVrB51OB6PRiJkzZ+LZZ5+1ec+sWbMwffr0Wg+UiIiI3J+iGYsvv/wSn332GVasWIGMjAwsW7YMb7/9NpYtW2bznmnTpqGgoKDi4+zZs7UeNBEREbknjahgS0dUVBSSk5MxYcKEisfeeOMNLF++HEeOHJH1HIWFhTAYDCgoKIBer1c+YiIiInI6ub+/Fc1Y3LhxA1qt+S06nQ6CINg3SiIiIvIqimosBg4ciJkzZ6JZs2Zo3749fv75Z7z77rt47rnnHDU+IiIi8iCKlkKKiorw2muvYfXq1cjNzUVkZCSGDh2Kf/7zn/Dz85P1HFwKISIi8jxyf38rChZqYLAgIiLyPA6psSAiIiKqDoMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWoYLIiIiEg1DBZERESkGgYLIiIiUg2DBREREamGwYKIiIhUw2BBREREqmGwICIiItUwWBAREZFqGCyIiIhINQwWREREpBoGCyIiIlINgwURERGphsGCiIiIVMNgQURERKphsCAiIiLVMFgQERGRahgsiIiISDUMFkRERKQaBgsiIiJSDYMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWoYLIiIiEg1DBZERESkGgYLIiIiUg2DBREREamGwYKIiIhUw2BBREREqmGwICIiItUwWBAREZFqGCyIiIhINQwWREREpBoGCyIiIlINgwURERGphsGCiIiIVMNgQURERKphsCAiIiLVMFgQERGRahgsiIiISDUMFkRERKQaBgsiIiJSDYMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWoYLIiIiEg1DBZERESkGgYLIiIiUg2DBREREalGcbA4d+4chg0bhoYNGyIwMBAdOnTAjz/+6IixERERkVyCEcjeCRz6SvpTMLpkGD5KLr569Sq6deuGnj174ttvv0Xjxo1x/PhxNGjQwFHjIyIioppkrQPSpgKF528/po8EEucAMY85dSiKgsWcOXMQFRWFJUuWVDwWHR2t+qCIiIhIpqx1wJcjAIjmjxfmSI8P+cSp4ULRUsi6detw33334amnnkJYWBjuvfdeLF68uNp7SkpKUFhYaPZBREREKhCM0kxF1VAB3H4sLdmpyyKKgsWpU6eQmpqKNm3aYNOmTRg/fjwmTZqEZcuW2bxn1qxZMBgMFR9RUVG1HjQREREBOLPHfPnDgggUnpOucxKNKIrWYo5Vfn5+uO+++7Bnz+0BTpo0CQcOHMDevXut3lNSUoKSkpKKzwsLCxEVFYWCggLo9fpaDJ2IiKiOO/QVsGpMzdc98RHQ4clavVRhYSEMBkONv78VzVhEREQgJibG7LG77roLv//+u817/P39odfrzT6IiIhIBcFN1L1OBYqCRbdu3XD06FGzx44dO4bmzZurOigiIiKSoXmCtPsDGhsXaAD9HdJ1TqIoWPz1r3/Fvn378Oabb+LEiRNYsWIFPvzwQ0yYMMFR4yMiIiJbtDppSykAy3Bx6/PE2dJ1zhqSkou7dOmC1atX4/PPP0dsbCxmzJiBuXPn4tlnn3XU+IiIiKg6MY9JW0r1EeaP6yOdvtUUUFi8qQa5xR9ERESkgGCUdn9cuyjVVDRPUHWmQu7vb0UNsoiIiMhNaXVAdHdXj4KHkBEREZF6GCyIiIhINQwWREREpBoGCyIiIlINgwURERGphsGCiIiIVMNgQURERKphsCAiIiLVMFgQERGRahgsiIiISDUMFkRERKQaBgsiIiJSDYMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWoYLIiIiEg1DBZERESkGgYLIiIiUg2DBREREamGwYKIiIhU4+PqAajBKIhIz85DblExwkICEBcdCp1W4+phERER1TkeHyzSMnMwfX0WcgqKKx6LMAQgZWAMEmMjXDgyIiKiusejl0LSMnMwfnmGWagAgAsFxRi/PANpmTkuGhkREVHd5LHBwiiImL4+C6KVr5kem74+C0bB2hVERETkCB4bLNKz8yxmKioTAeQUFCM9O895gyIiIqrjPDZY5BbZDhWVXSiUdx0RERHVnscGi7CQAFnXzdjwG2stiIiInMRjg0VcdCgiDAGoaVNp3vUyFnISERE5iccGC51Wg5SBMQBQY7gAWMhJRETkDB4bLAAgMTYCqcM6oUE9v2qvYyEnERGRc3h0sACkcPHagLtkXSu34JOIiIjs4/HBAgDCDYGyrpNb8ElERET28YpgUVMhpwZSm++46FBnDouIiKjO8YpgUV0hp+nzlIExPJiMiIjIwbwiWAC3C
znDDebLHeGGAKQO68QDyYiIiJzA4083rSwxNgKPxITzCHUiIiIX8apgAUjLIvGtGrp6GERERHWS1yyFEBERkesxWBAREZFqGCyIiIhINQwWREREpBoGCyIiIlINgwURERGphsGCiIiIVMNgQURERKphsCAiIiLVMFgQERGRahgsiIiISDVed1aIPYyCyIPLiIiIVFDng0VaZg6mr89CTkFxxWMRhgCkDIzhUetEREQK1emlkLTMHIxfnmEWKgAgp6AY45dnIC0zx0UjIyIi8kx1NlgYBRHT12dBtPF1EcC0rw/BKNi6goiIlDIKIvaevIK1B89h78kr/BnrhersUkh6dp7FTEVVV2+UYf4Px/FS7zudNCoiIu+VlpmD19dl4ULh7Z+94foAvP4Yl569SZ2dscgtqj5UmCzZfZqJmoioltIyczBueYZZqACAC4XFGMelZ69SZ4NFWEiArOvyb5YhPTvPwaMhIvJeRkFE8teHqr2GS8/eo84Gi7joUNQP9JV1rdzZDSIisrTv1BXk3yir9pqrN8qw79QVJ42IHKlWwWL27NnQaDSYPHmySsNxHp1Wg9HdomVdK3d2g4iILO09KS8wyL2O3JvdweLAgQNYtGgR7r77bjXH41RJvVqjfpDtWQsNpJ4WcdGhzhsUEZHXkbvEwaUQb2BXsLh27RqeffZZLF68GA0aNFB7TE6j02ow+/EOVr9m6ruZMjCGXTiJiGohvmUjVa8j92ZXsJgwYQIGDBiA3r1713htSUkJCgsLzT7cSWJsBBYO64QIg/lyR7ghAKnDOnELFBFRLd3fqmHF7LAPyvGcbiNe91mK53Qb4YNyAED9IF/c36qhK4dJKlHcx2LlypXIyMjAgQMHZF0/a9YsTJ8+XfHAnCkxNgKPxITzvBAiIgcwzQ6f/vzvGOvzDXSa20se//D5DIvLB6DF42/zZ66XUBQszp49i5deegmbN29GQIC8gsZp06bhb3/7W8XnhYWFiIqKUjZKJ9BpNYhnWiYicojE8x9A9N1g8bhWI+IF3w3QnG8JxM5wwchIbRpRFGVXy6xZswZ/+tOfoNPpKh4zGo3QaDTQarUoKSkx+5o1hYWFMBgMKCgogF6vt3/kRETkGcpLgZlNAFGwfY1GB/zjAuDj57xxkSJyf38rmrF4+OGHceiQeZOT0aNHo127dpg6dWqNocJhBCNwZg9w7SIQ3ARongBoXTQWIiIyd2Bx9aECAESjdF38BOeMiRxGUbAICQlBbGys2WP16tVDw4YNLR53mqx1QNpUoPD87cf0kUDiHCDmMdeMiYiIbrt6Wt3ryK15dufNrHXAlyPMQwUgff7lcGDbHGk2g4iIXKdBC3WvI7emqMZCDarVWAhGYG6sZaioKiQC6Pdvzl4QEbkKayy8gtzf3547Y3FmT82hAgCKcqRZjax1jh8TERFZ8vED4pOqvyZ+AkOFl/DcYHHtorLr05K5LEJE5Cp9ZgAJkwBNlV87Gp30eB9uNfUWihtkuY3gJgouFoHCc9IsR3R3hw2JiIiq0WcG0Os1affH1dNSTUWXsZyp8DKeGyyaJ0i7PwpzIPvgGqWzHEREpC4fP24p9XKeuxSi1UlbSpVQNMtBRERESnlusACknR5DPpFmLqqlAfR3SLMcRERE5DCeHSwAKVxMzgQeetXGBbcOtUmcbbsbp2AEsncCh76S/mSRJxERkV08t8aiMq0OeGgqEHaXjS6cs233sbDWuTOoETDgHaD9YIcOm4iIyNt4boMsW5ScG2Lq3Gmr+JNboIiIiAA46BAyj6DVydtSKhilmYrqdpTseR+I7AzEDlZrdERERF7N82ss7CW3c+fGl1lzQUREJFPdDRZye1rcuCyFECKiuojF7aSQ9y2FyKWkpwUbaxFRXWStuF0fKfUQ4sGOZEPdnbFoniDt/pCDjbWIqK7JXAN8OdxyybiQBztS9epusNDqpC2lNbGzsVZpuYCPdp7CP9dm4qOdp1BaXs1xwURE7uS3NcCq0Ta+eKvgnQc7kg11dykEkPpUnJsk7f6wSlN9Yy0bZm3MwuKd2RAqbTiZufEwxnaPxrT+MXYPl4jI4bLWAf8bWcNFPNiRbKu7MxYmfWYATy4DghqaP66/Q2oXrnAdcdbGLCzaYR4qAEAQgUU7sjFrY1YtB0xE5CAV2/BlYv0ZWVG3ZyxMYgcDMQPlN9ayobRcwOKd2dVes3hnNl7u0w5+Psx0RORm5G7DN2H9GVnBYGEit7FWNT7de9pipqIqQZSuG9O9Za1ei4hIdTJnIAQRyNU0ROOoeCh7+0V1Ad82q+hM3g1VryMicioFMxAppcORfqbAgYMhT8VgoaLmoUGqXkdE5FTNEwB9ZHUHHaBc1OLFsknYJMQht6jYaUMjz8FgoaLh8S2g1VR/jVYjXUdE5Ha0OvzcPhmiCItlXVGUPiaWTUSacD8AICwkwAWDJHfHYKEiPx8txnaPBgD4oBzP6TbidZ+leE63ET4oBwCM7R7Nwk0icktGQcSLGU0xvmwyLiDU7Gs5aIhxZZPxrdAVABBhCEBcdKi1p6E6jsWbKpvWPwYP/T4fcTkroNPcjvz/8PkM6RHPIL7/By4cHRGRbenZecgpKEYO4rC55D7EaY8gDPnIRX2kC+0gVHovmjIwBrqapmipTmKwUNt3ryH+wmcQq/z/ptWIiL/wGfBdQ6l3BhGRm6lcMyFAi32C9YZ+Y7q1QGJshLOGRR6Gc/JqKi8F9s4HAFTN8RWf710gXUdE5Gbk1kz0jgl38EjIkzFYqOnAYkCs4UwQ0ShdR0TkZuKiQxFhCLB4Y2SiAWsrqGYMFmq6elrRdUZBxN6TV7D24DnsPXkFxpq6axEROZBOq0HKQGn5w9asK2srqCassVBTgxayr0vLzMH09VnIKbi9phlhCEDKwBiuXRKROgSj4qMKEmMjkDqsk8XPp3D+fCKZNKIoOvVtcmFhIQwGAwoKCqDX65350o5XXgrMbFL9cohGh7Q//YzxKzItmtCY3gOkDuvE/3mJqHay1kkHilU++0MfCSTOkXW4olEQkZ6dh9yiYoSFSMsfnKmo2+T+/uZSiJp8/ID4pGovMd7/IqZ/cwIiAC0E3K/NwmPaPbhfmwUNpEAyfX0Wl0WIyH5Z64AvR1geKFaYIz2eta7Gp9BpNYhv1RCD7rkD8a0aMlSQbFwKUZtpK+ne+eYzFxodED8B6a0mI2frPvTVpiPF9xNEavIqLjkvhmJ62QhsKohDenYe4ltVOcqdiKgmFUefW3tzIgLQAGnJQLsBik9wJpKDwcIR+swAer0m7f64elqqvegyFvDxQ+7Bc+irTUeq71yL28KRh1TfuRhfNhm7T7Tm1CMRKVfj0eciUHhOuq6WJzoTWcNg4Sg+fkD8BIuHw+r5IsX3EwCwOFdEq5H686f4fooHtt6HVRl/sFiKiGpWuUjz0hF598g8Ip1IKQYLJ4vTHYGu0vJHVVoNEIkrGKVLw7KCRIxfnsFiTiKyzVqRphwKjkgnUoLFm06mu54r67p/+i7HTv9J6KtNZzEnEVlnq0izWhpAf4e0
9ZTIARgsnE3Bu4Rw5OED37m4u2gH0rNtz3IQUR1UbZGmLbfWXxNns3CTHIbBwtmaJ0h7yW02zb3NVIOR4vspcguvO3ZcRORZdrytfPlDHwkM+URWHwsie7HGwtm0OqlBzZcjIIWL6t9tmGouWt84BKCZM0ZIRO4uax2w7U1513afAoS1k915k6i2OGPhCjGPSe8a9PILMgOOr8dvu7+BsbzcgQMjIrdXsQQiz28B92KtMR57hRgY+SOfnIAtvV1JMAL7FwKbXpV9SwGCkRszGm2enM53HkR1UfZOYNmjNV4mAriIhkgofg/CrUDB84ioNtjS2xNodUDXcYA+EqKNmouqsc+Aa2iTNQ+ls1vKastLRF5Gbv8JEUgpHV4RKgDgQkExxi/PQFpmjoMGR8Rg4Xq3ai6kagvzcCGKgMZGjadvST5EmT3/iciLyNxZ9m75k9gkxJk9Znqfwi3s5EgMFu7gVs2FpkrNha1QYfqaKIooWTNJOlWViLyTYJSWPw59Jf0Z1bXanWUipHOHFhgH2/x6TkExt7CTw3BXiLuIeUw6FOjMHpzasQIts1fUeItWA/iXXoX4bjtoHp3LLWRE3sbW0eexTwJ75sFyZ5kUNqaXjTBbArEmt6hY9eESAZyxcC9aHRDdHTdb11yYZebGFdlHIRORh6ju6PM984CEiZY7y/SROPbgAoslEGvCQgJUHCzRbZyxcEPtuvbFxc0N0Vi8YnFQmTWm9ywaHoVM5B3kHH2euQqY9Atwdj+Eogs4XBSEE0Ed0CgkCOH6g7hYWGL1bg2AcEMA4qJDHfqvQHUXg4Ub0vn44Hx8ChrvmVRtAWdlGh6FTOQdTNvQ5Rx9fnY/0q63xvRvSpFTUAzgEACgfpCvKX5YWSgBUgbGQCfnXQuRHRgs3NS9fUfiJ0FEy33/QANck32fUHSB61tEnkrhSaW/HD6C8TsKLWYmCm6UAQAMQb7Iv/XPgDRTwT4W5GgMFm6sc79R2BjZG4e/TMHzPhsQoqm52OpwURDaO2FsRKQyU02FgkPFFv18o7rFEgT66rBgTCdcvl6CsBBp+YMzFeRofHPr5vp3bIr2Q99AT83HuCKGWDTMMhFE4LzYECeCOjh3gERUe4pPKtWgJCgCaUUtbV5h2laq1Wow6J47EN+qIUMFOQWDhQdIjI3Ae8O64tWyMRAhhYjKTJ9PLxuOMH29Ww9W2fsuGJ06ZiKSSVZNRWVSOFgXPrHGLaUAt5WS83EpxEPc37Ih/h7SAy8WAf/0/QSRuN3c5gIa4l9lw/FrSA+p0tvW3vfEOex1QeROFNZUAAD0kfi5/VRM2dpI1uXcVkrOxmDhIXRaDVIGxmD88mJsLrkPXbRHEIZ85KI+DgjtIECL1IEx0B1Zb32dtjBHenzIJwwXRO7AjpoK9H0Txi4v4MW3tgOofiaC20rJVbgU4kESYyOQOqwTwgxB2CfEYJ2QgH1CDMIMQUgd1gmJMWE17H0HkJbMZREiV7OjpgL6O4Cu45B+puDW1tLqieC2UnINzlh4mMTYCDwSE4707DzkFhWbV3pn75S39529Lohcx86aCiTOBrQ62TUTz3VrwW2l5BIMFh5Ip9UgvlVDyy/IPU555zvArneB0JbAI28AfoHqDpCILAlGYMfbwP4PgJv58u/TR0qh4tYSptyaiUdiwu0YJFHtMVh4E5nHKePUVunPkz8AB/4LtO0PDP3cceMiquuy1gHrJwE3ryq771ZNRfqZAuQePIewkAB0bt4AEYYAXCgoZstucksMFt6keQJuBobD/8YFq2eMiLe65lh86ehG4POhDBdEjpC1DvhyuMKbNIA+EmnBgzD9re1mNRURhgA81jECH+7IZstuckuKijdnzZqFLl26ICQkBGFhYRg8eDCOHj3qqLGRQkZoMb1sBADLXhemxlo2f9Qc3QiU3nTY2IjqJMEIrH9J4U3S/6U/t5+K8Z/9YlGoeaGgGB/uyMbzPaIRbjBfFgk3BEiF3KytIBdSNGOxfft2TJgwAV26dEF5eTleffVV9OnTB1lZWahXr56jxkgypWfnYeW1e3BVOxkpVXpdyDnIDJ8PBXq8DDRP4AmpRGo4vQu4mVfzdZXpI2HsOwsvrguGaGVLqald97pfcrB9Sk/8dOaqZSE3kQspChZpaWlmny9duhRhYWH46aef0KNHD1UHRsqZqsU3CXHYXHIf4m71unhO9w3u0WXX/ATZW6UPNtMiUkf2TmXX931T2lKanY+cgn02LzO16/7pzFXrhdxELlSrPhYFBQUAgNBQ20VCJSUlKCwsNPsgx6hcLS5AW9Hr4hexlbInMjXTylqn8giJ6hglkwe3+lQYocXuE5dk3cJ23eSO7A4WgiBg8uTJ6NatG2JjY21eN2vWLBgMhoqPqKgoe1+SahAXHYoIQ4DFz7KZ5cMgirB5gJklNtMislvlc3oCDPLvS5yNtKxcPDDnB8zfelLWLWzXTe7I7mAxYcIEZGZmYuXKldVeN23aNBQUFFR8nD171t6XpBqY2n4D5m+USuGH74ydAShpHlypmRYRyZP5NTCnBbDsUWDVGOC7/0PN0xZa4KllSBO6YPzyDFldNTWQdodwSym5I7u2myYlJWHDhg3YsWMHmjZtWu21/v7+8Pf3t2twpJyp7ff09VlmP6CS/ZKB0tnoo/tJ2fTsya1A9nYpkUR3B1o8wMJOIms+HyrtrrJQQ5x/cgmMdw3C9Dk/yAr+3FJK7k4jigomyEUREydOxOrVq7Ft2za0adNG8QsWFhbCYDCgoKAAer1e8f0kj1EQK9p+Nwr2x1+WHcDNMgF+KMU/fJbjHs1JdJRT0FlVYCgw8D0WdhJVtukfwN75NVxUpetESCTQTyqS3nvyCoYutl2sWVmEIQApA2O4pZScTu7vb0UzFhMmTMCKFSuwdu1ahISE4MKFCwAAg8GAwEC2hXYnldt+7zx2CTfLBADSskhK+XPQQsAu7SSEI89qMy2bbuZJzX6GfMpwQQQA5aXA3gUyLhSlXR/BTaSPW9u6jYKI3Scuy3qppJ6t8ddH7uRMBbk1RTUWqampKCgowEMPPYSIiIiKjy+++MJR4yMVrMr4w+IxoZpmWrKwsJNIcmAxZFcvBTcBOjwpLStqdUjLzLlVrHlC1u3dWjdiqCC3p2jGQsGqCbmRG6XlVh/fJMRhfJllMy1ZeEoq1XWCUfp/4MT38u+pdJ5PWmYOxi/PkF1XwfM/yFPwrJA6oEuLhvguK9fq16o202qt+QOTfNfIe2K5p6kSeZusdUDaVAVHn0Paeto8AYBUAzV9fRaLNckr1apBFnmGkQktqm3pXbmZ1h7Rdk8SC3JPUyXyJoe+luqMlIQKAHh0bkVNxdLd2bK2lQI8/4M8D2cs6gA/Hy2e7x6NRTtq3gWSLrTDebEBInC1+vNF9HdUvPsiqjM2/R+wd57y+9r2hzHmT5j//TEs2X0a+TfLZN2W1LMV/vpIW85UkEdhsKgjpvWXGmct3pldbbGmVNQ5Eqm+cwGxmsPLEmff7mdhWmu+dtGs2p3Iq3z3mh2hQgvEv4i0O5KQ/MZ
m5N+QFyhMurVuzFBBHkdRHws1sI+Fa5WWC/h072lsO3YJO4/b3uLWV5uOWb7/RajmmvkXqvaxsLbWHFgf6Poi0OPvDBjkHUpvArMiAVGQd32rXkDr3kCXsUg7cgXjlmcoejlTseauqb0YLMhtyP39zWBRR8lpyKOFgPfvv4aGl/YD0CC4XU/ExPeHzufWRFfWOumwMlslaGymRd4gax2wdgJQouAAxZEbgOjuMAoiOiucqTDFCNZVkLtxSIMs8h6mA8suFBTbrEwXoUXSPj2AR6QHTgL1t/6A2Y93QGJMmDRTUV1du6mZ1v0vAm37c4mEPE9N4dmaSvVH+05dUbz8Ec7OmuThuCukjrJ1YFll1n6U5t8ow7jlGUjftl5+Vfy+D6RDmebG8ih28hyCsebwbE2l+qO9J68ouvW1AXdh19ReDBXk0Rgs6jDTgWXhBvOjl+Ws6G7Yc1D5Cxael2YwGC7IE5zZo3xL6RNLqiz9yQ8lEYYAjOoWzZoK8nhcCqnjEmMj8EhMeMWBZZeLSjDjm8M13nfsRj3Az84XXTsBuDMR8LH3CYicQGkDuPiJQIfHzQ4ANAT6yr6dDbDIWzBYkNmBZWsPnpN1T7rQDjcDmiCw2I7umyWFwLvtpIZBLOwkdyW3AZxGC8QnAX1mYP0v5/Hq6kMoKr7dRr/KmaaWtwNY8AwLNcl7cCmEzISFBNR8EaR+F793TbH/hW5ckYriuCxC7qp5AqCPRLWLg/56YNp5oM8MjP3kACZ+/rNZqABqXgxZ8My96H83QwV5DwYLMhMXHYpwvX+N10UYAtD6wWek49MDG9j5aiKwfhJwajtPSiXXK70JfPMy8OmfpD/LS4HEObe+WDVcaKSPQQsAv0DM/OY3bLZxHk/lOyoL1/tj4bBO6H93pEr/AkTugX0syEJaZk6NDX0WVt5jLxiBHW8D+1OBm1fte1F9pPRDnEsj5AqfDwWObrR8vG1/oONQyyZw+juk3R8xj6G0XEDb//tWVpnmawPuQqMQf4SFSCeVsqaCPAkbZFGtpGXmIPnrQxZ78BsE+WLW4x2srwcLRuD0LumHdNl1ha946wfskE8YLsi5bIUKk7b9gaeXW7StN0KL9Ow8fHHgd6w5KG/3yHt/vgeD7rlDpYETORcbZFGtmHaL7Dt5BXtPXQYgFXje37KhxbusylXwYSGxiBucCt3/Rih8RRGABkhLBtoNYCMtco7Sm9WHCkD6enkpEN294qG0zBxMX58l+4RSE7k1TESejMGCbNJpNejWphG6tWlk8xprP2DrBwbijXZzMOCP/0Bz7YKCVxSBwnPSO8NKP8SJHOa7f8i7bvP/AQPeAQBsOHgOSSsPKn4pfYAP4qJDFd9H5GlYvEl2S8vMwfjlGRbv2vJvliHpYBQ6X5+L4zETlT/xqe3Aoa+A7J0s6iTHyVwD/LRM3rV5pwAAM7/JsitUAMDMP3VgTQXVCZyxILsYBRHT12dVW7CWd1NAn4x4fN2zDe79bbb8LoY737r9zyzqJEf47jVgz/uyLzc2aImXVmRgw685dr3cIzFhGNiRuz+obuCMBdklPTtP1vqyCODFjKYwTjoEDF+rfGtqYQ77XZC6flujKFSIAB7K6GVXqNAAGNu9BRaP6KL4XiJPxWBBdsktkl+0llNQjNc3HEZp8x7AwPdR0QNAlltzImnJXBah2hGMwIkfgNXjZN8iAviuvDPO3lD2Un1imuC1AXfh6Bv98I8B7ZXdTOThGCzILkqr2z/d9zvavfYtZp1uLW0p1SvpNFipqJPIHlnrgLdaAcv/BJTflHWLCGC7Jg4vlL+s6KUiDAFIHdYZY7q3hJ8Pf8RS3cMaC7JLXHQoIgwBirbbCSKwaEc2zuU3xXuTDkF3dq/UF+DSEWDHWzU/gdJDoYgAKVR8OVz5bb2WYtRG5Qfl8TAxqusYp8kuOq0GKQNj7Lp3w685SJizDWnXWwMdngSiH5R3Y3ATaTo7eyd3jZA8ghFYN0n5fUGNcCKks+Lb5v/5Hh4mRnUegwXZLTE2AguHdUL9IPlHQ5tcLCrB+OUZSMvMkXHYk0ZqoXz9CjA3Flj2KLBqjPTnW62AbXMYMMiSYAT2pQLF8tvMi7c+jt73OhoFByl6ubHdo/Eou2oSsaU31Z5REDH/hxNYsjsb+TfLar6hkvqBvljwbCfcX7Ibuv+NvPVo5b+St8JGwkRgzzzYPCsysIFUGMptqQRIyx9Vz/eQQRSBReWPYrbxGYTr/VFcLqDgRlm126q1GilUTOtv3wwekafgWSHkdEZBxOvrMvHpvt8V3xthCMAHnf6w7HehvwPo+yawaZq8XxJPLgNiByt+ffIiWeukLcqyjgW7rUAMQnLZWHwrdLX4msbGsz3ZqSnefLwDizSpTmCwIJcoLRfQ7rVvISj8W2VaBEl9tiMSg7PNDnvCmT3SsofcZ3pwKvDgKzxvpC4SjNJymcKZigIxCJ1LFqLcSj27v48WDYJ8caGwpOKxCEMAUgbGsJ6C6hQeQkYu4eejxdju0Vi0I1vRfbeOIMP0DUfxyNRe5lX1inaDiMD22UD6Ii6N1EVn9igKFaa3VcllY62GCgAoKRcw54m74eeju3XQHo88J6oO5+9IddP6x+CFHtFQ+nNXhNRMKz07z/wLwU2UD+LmVWmLITt21i12bEleVP6o1eWPytb8fA7xrRpi0D13IL6V5Qm/RHQbgwU5xLT+MTgyox+e7NRU8b27T1zG2oPnsPfkFRgFsdKuETuwY2fdYNqGfOmI7FsuiyEYXzYJs43P1Hjt9VL+HSKSi0sh5DB+Plq8PaQjeseEWRytXp35W09U/HPFWnbiHLsK8ngMex2gYAeIIAIFCMaLZZOwX4iBIPO9VZcWCs+4IarDOGNBDpcYG4FdU3vhszFdUT9QWc+LCwXFUr8LoYvUCjwwVPkA2LHTe2WukZa8ZIYKAEgu+wv2CrGyQ4UGwMiEaPvHSFTHMFiQU+i0GnRr0wizn+hgzxFkmL4+C8Z2A4EpJ4AHkxU8A+yr0SD399saYNVom1+uOrd1AQ0xvmwyNglxil7m+R7R3E5KpAD/byGnSoyNQOqwTgg3yD/EzKyoU6sDek4Dnloq485bHTubJ9g7XHJXWeuA/40ERMHmJabo+X75YPy59P/wQMl7ikKFBsALPdj4ikgp1liQ0yXGRuCRmHCkZ+cht6gYxy8WYf7WkzXeZ3ZUe/vBgOZTYP1LwM08K1ff+rWSONu8n4VglGouKvfJYL8LzyIYpZoKmU4ITbFPUBYOnugUiVmPd+RMBZEdGCzIJXRaDeJbNQQA7D15RVawCAsJgFEQKwJJWEg3xL18HLpd7wD7PwBu5t++WB8phYrKfSysFfkFNQT6v8tunZ5kx9uKelXkor6ip+csBVHtMFiQy5mOYL9QUGx1z4cGQLghAFevl+CBOT+Y7S6pH+iL0d0eR9LLL98+ht3aTIStNs83rgBfjQQOPw488V/OXri7rHXAtjdlXSqIUl1FutBO1vWhQb54Y3AH9L+b3TSJaoMtvcktpGXmYPzyDABWjyDD8z2i8e
GObJubTYP8dHihR0sk9Wpj2bxIbptnHmTm3hS26xZEyC7WnPxwG0x82MrfHSKqIPf3NxcQyS3YKuoMNwRgwTOdsO6XnGo7WNwoNeI/3x/H3a9vwnvfH5caa5nIbfNs6tbJY9jdk4J23eWiFi+WTZIVKl7oEY3Jj9zJUEGkEi6FkNuoWtRpOpMhPTtPdnOt66VG/Of7Y1iyJxuzH+8gHRKltI/FtjeB/YuAR9+VikTJPcj87yiKwMSyiUiroU13gyBfzBwci/5329nVlYisYrAgt1K5qNPEbDeITPk3yjBueQYWDuuERLvOGrkibWfcEQuM+R7wC1T+HKQO006eXHntut8tf7Lao8+f69YCj8SE8yAxIgdhsCC3FxYiv+dFVdPXZ+GRKQ9CF9RQKtRU6mIm8GY4cGc/4JmVdo+D7JS1DmLaVGgqLYGIIqCxkgekYs1QLDAOtvpU4TzqnMgpWGNBbs+0a8QeOQXFSD9TIG0prY1j3wKLHqrdc5AyWesgfjkCopW6iqol56aSmullIyxadSf1bI3Px96PXVN7MVQQOQGDBbk9nVaDlIExSpp4m8ktKpb6VCRMqt1Acn4GvhjBwk4HMwoi9h7PRcHqlyGKosUPKWuzFbbadUcYAvDXR+7kUedETsRgQR7BtGukfpCyQ8yASkspfWbA+ORSlPkE2z+Qw2uB2c2kcypIdWmZOXhgzg94b8kyGMpyYSsLmMLF+2XVt+tOGRjDQEHkZAwW5DESYyPw0/89gr/2boMgv5obWWkgvWONi5ZORE3LzEHnr+uh7bWFeKfsSVwV69k3kNJrUmHnd6/Zdz+ZMQoidh67hCEL92Dc8gzkFBQjDPmy7j0hSu26qy5/1A/ylQp3ufRB5HQs3iSPotNq8FLvO5HUqw3m/3ACi3acxI1Sy6UJ03tU0zvWtMwcjLvVgAvQYp7xcSwwDkaSbjX+6rPK6vR6jfa8D/gEAg9NZcdOOxgFEfN/OIEPtp1ASfntw8S0ENBIky/rOaq2667np8PzPVohqVdrzlQQuQg7b5JHk345HceS3aeRf7Os4vGISjsAjIKIbrO34EJhidXnSNYtxws+G+0LFwAQUB94bB47dspk+m+2aMcpi1DYV5uOFN9PEKmxdrDcbaZ23Q+UvAcB2lut3VtY77xKRKqQ+/ubwYK8gvnhZAFmPQr2nryCoYv3VXv/NN1neN7nG/vDBQA8tYwNtapRWi7g1a9/xfpfc8xmKEz6atOR6jsXAMxqK6puLzXtAHlF93d0e3Q0wg2B7ElB5ARyf39zKYS8grXGWiZyGmzNMj6Lg2JrvO87D74ay196svxvJHAxGXjoFS6NVDFrY1a1Z734oBxv+n4EDSx3fVT9/AIa4l9lwzH4qb+whoLIDbF4k7ye3AZb3wpd0bbkE6wr72LRJ0G2HbOB2VF1/rwRoyBi9/HLeHvTETy+YBcWVRMq+mrTsc//RTTUFNU4Y/SvsmF4yn8hBj8zjqGCyE1xKYS8Xk01FtYkavfh374fQq9R3k68gk8gMOgDoMPj9j+Hh6i8FHX68g0s2Z1tVvNii2n5w9pMhTXHHpiLVr1GcdmDyAVYY0FUifmuEHm0EPAfnwV4TLe3drUXEfcAfd4Amid4zRKJURCx58RlfJ3xB05dvoaTl67hWomyJSQflONX/zEIRJn87+/IDUB0d+UDJqJaY7AgqiItMwfJXx9C/o2a30lX1l+7F/N850FX2zfJ+kggcY5H7h4xCiL2nbqCvSev4MSlIvxwOBelRvt+dGghYIJuNV70WYdAjdz/Fhrp+zf5kNeEMyJPw2BBZIVRELHv5BUs338aO45fxvUSeXUQ/bR78IHvfADypuytEaGR+msM+cSjwoW9gcyavtp0zPL9L0I11xTeqfG47xuRt+GuECIrdFoNurVphG5tGlX0U/jP98drvO9bIQGLyk/jBZ8Ndr+2BiJEaIC0ZGReAU6fycaZkhAU39EV8a2b4P6Wrj3PwlQncf7qDRz8Ix+ABjdLjfgq4w9Vnr+/di8W+M5TfqN/iFSrwlBB5BE4Y0F1XlpmDl7+8hdct9LBs6p+2v34t++HCNHcVO31C8UA/Le8P5b7DcHIbq3RolEQwkICcE9UfXy69zS+++0CCkvKcVd4CJ7sHIWE1o1kB5DKSxiAiPiWjdCpeQN8uvc0Nv2Wg6KScrQLN6BFaCC+/OkPRQWuStSqT8iw1UDrXqqPiYiU4VIIkQJGQcS8Lccxb+txGGuoQTTVCLzgswHBGvV+EV8TA7Co/FGcEcORi/pIF9pZnIEBSG2r3xnSEY/EhGPPicv4309nceRCEUL8dejbPgLPdG2OLw78jh3HLyM9Ow83y1y77TVZtwIv+GywL1T4hQDJZ1hXQeQGHBosFixYgLfeegsXLlxAx44dMW/ePMTFWZ4sWJuBEbmCURCx69glvPltFk7kXkd19YlSwFiDF3zWqxIwqnaYzBEbYIexA24iAL+LYfjE2AfllVYv/X20VjtYuhMflOOo/0hoIdoXLJ5cJh15T0Qu57Bg8cUXX2DEiBFYuHAhunbtirlz5+J///sfjh49irCwMNUGRuRqppqDTb/lYFXGORQVl1u9TgsBK33/hS7aY7XbllqFtVbW3xjj8FL5JKszGe5GCwEpPksx0ud7+54gYRLQZ4a6gyIiuzksWHTt2hVdunTB/PlShbwgCIiKisLEiRORnJys2sCI3IkpZHy48yS2HblktYvkNN1n+IvPRug0jl1dLBV1mFg2EZsEebOErjC03k94TfMRgsrzld/sr5cOdeO5K0RuxSHBorS0FEFBQfjqq68wePDgisdHjhyJ/Px8rF271uKekpISlJTcniYuLCxEVFQUgwV5rNJyAZ/uPY3TV27gzJXrOHA6DzfLpCUJH5RjhO47PKA9hM7a4zBobqj++qb/Y8eVTXa7cNGrXWPMCPoSkVmLoXjyxq8ekPAS0OPvrKkgckMO2W56+fJlGI1GNGnSxOzxJk2a4MiRI1bvmTVrFqZPn67kZYjcmp+PFmO6t6z43LRtdcHWEyg1+uBjY398bOwPLQTEaY8gDPm4BD3m+b6PRor7N1jSaKRwkeK7DIVlQWiMwmqLPR3N30eLRztEYNYTd8PvyDrgq8XKn6T7K0DPZAYKIi+gaMbi/PnzuOOOO7Bnzx7Ex8dXPP7KK69g+/bt2L9/v8U9nLGgusIoiNhz/DJW/fwHsi9fx6nL183qMp4M+BFvie8CsL/JVnVKRB+sM8ZjWvlYsyJPtTzaIRwtGgVDhIj6gb5oFOx/+8hyCED2TuCLZ4FSheEpPgnoO1P18RKRuhwyY9GoUSPodDpcvHjR7PGLFy8iPDzc6j3+/v7w9/dX8jJEHkmn1aB728bo3rYxAPODucJCAhAX3R/ZKwsQfewjh7y+v6YcT/nsxBO6ndgjxOCE2NTqbhKltBpgbPdoTOsfY/2CrHVA2lSg8LyyJ9bogPgJLNAk8jJ2FW/GxcVh3jypg54gCGjWrBmSkpJYvEkkgzFzNYQ1SfAtN39nLwLK6xKsqLqbxChqsLh8A
GYbn6n2vnr+OjzfvRWa1g+o6LzZomEQhse3gJ9PpSUWwSjNTmRvB459B+RmKh9kl78AfWcBPn7K7yUil3DodtORI0di0aJFiIuLw9y5c/Hll1/iyJEjFrUXtRkYkVcz/XI+s0tKFFodxIxl0BTlVFyiZtCABjhfLxZ5909F6y6J+HT/WbPOm091biqvo2fWOmD9JODmVfsHFNQI+Psx1lMQeRiHNsiaP39+RYOse+65B++//z66du2q6sCI6hzBCJzZA1y7CGybDVyp+QwTu/gbgLb9gLxsoKQQaNIeuGcY0LKH7V/2ghHY9m9gx+zav/5Ty7iVlMgDsaU3kadLmwbs+8B5r+cXLDWlatAc+OOA9FhoSyAkAvjmb7WbpTBh0ysij8VgQeQNykuB9EXA7/ukPg+xQ4BVo6WZBk8S1Ajo/w7bcxN5MAYLIm+V+TXw1WhXj0KewAbS0keLB1hTQeTh5P7+dv8DB4jIXOzjQNv+rh6FDBpg4PtAywcZKojqEAYLIk809HOpsZS7Co4AhnwCxDzm6pEQkZNxKYTIk5WXAvtTgSMbpc+1PsCZ3YDVY9KcpP3jwBP/5SwFkZdxSOdNInIzPn5At5ekD5PyUmDDS8DBFU4ejBZISOKuD6I6jsGCyNv4+AGDU4E7+wHrXwJu5jn29SI7SXUfcS+wkyYRMVgQea2Yx4B2A4DTu6QunxoAl08AWWugylJJYCgw8D3WURCRGQYLIm+m1Um7Mlo+ePux8lJg30LgyIbbnTdDWwE/fwJUaileQaMD7n8RaPPI7YDS/AEgujvrKIjIAos3iUhiaileeM6882aXsVziICIWbxKRQlqdNAsBAB3/7NqxEJHHYh8LIiIiUg2DBREREamGwYKIiIhUw2BBREREqmGwICIiItUwWBAREZFqGCyIiIhINQwWREREpBoGCyIiIlKN0ztvmjqIFxYWOvuliYiIyE6m39s1nQTi9GBRVFQEAIiKinL2SxMREVEtFRUVwWAw2Py60w8hEwQB58+fR0hICDQajTNfWpHCwkJERUXh7NmzPCxNRfy+qo/fU/Xxe+oY/L6qz5nfU1EUUVRUhMjISGi1tispnD5jodVq0bRpU2e/rN30ej3/B3AAfl/Vx++p+vg9dQx+X9XnrO9pdTMVJizeJCIiItUwWBAREZFqGCxs8Pf3R0pKCvz9/V09FK/C76v6+D1VH7+njsHvq/rc8Xvq9OJNIiIi8l6csSAiIiLVMFgQERGRahgsiIiISDUMFkRERKQaBgsZTp8+jTFjxiA6OhqBgYFo1aoVUlJSUFpa6uqhebSZM2ciISEBQUFBqF+/vquH47EWLFiAFi1aICAgAF27dkV6erqrh+SxduzYgYEDByIyMhIajQZr1qxx9ZA83qxZs9ClSxeEhIQgLCwMgwcPxtGjR109LI+XmpqKu+++u6IxVnx8PL799ltXDwsAg4UsR44cgSAIWLRoEX777Tf85z//wcKFC/Hqq6+6emgerbS0FE899RTGjx/v6qF4rC+++AJ/+9vfkJKSgoyMDHTs2BF9+/ZFbm6uq4fmka5fv46OHTtiwYIFrh6K19i+fTsmTJiAffv2YfPmzSgrK0OfPn1w/fp1Vw/NozVt2hSzZ8/GTz/9hB9//BG9evXCoEGD8Ntvv7l6aNxuaq+33noLqampOHXqlKuH4vGWLl2KyZMnIz8/39VD8Thdu3ZFly5dMH/+fADSWTxRUVGYOHEikpOTXTw6z6bRaLB69WoMHjzY1UPxKpcuXUJYWBi2b9+OHj16uHo4XiU0NBRvvfUWxowZ49JxcMbCTgUFBQgNDXX1MKgOKy0txU8//YTevXtXPKbVatG7d2/s3bvXhSMjsq2goAAA+PNTRUajEStXrsT169cRHx/v6uE4/xAyb3DixAnMmzcPb7/9tquHQnXY5cuXYTQa0aRJE7PHmzRpgiNHjrhoVES2CYKAyZMno1u3boiNjXX1cDzeoUOHEB8fj+LiYgQHB2P16tWIiYlx9bDq9oxFcnIyNBpNtR9Vf0CfO3cOiYmJeOqppzB27FgXjdx92fM9JaK6YcKECcjMzMTKlStdPRSv0LZtWxw8eBD79+/H+PHjMXLkSGRlZbl6WHV7xuLll1/GqFGjqr2mZcuWFf98/vx59OzZEwkJCfjwww8dPDrPpPR7SvZr1KgRdDodLl68aPb4xYsXER4e7qJREVmXlJSEDRs2YMeOHWjatKmrh+MV/Pz80Lp1awBA586dceDAAbz33ntYtGiRS8dVp4NF48aN0bhxY1nXnjt3Dj179kTnzp2xZMkSaLV1erLHJiXfU6odPz8/dO7cGVu2bKkoMBQEAVu2bEFSUpJrB0d0iyiKmDhxIlavXo1t27YhOjra1UPyWoIgoKSkxNXDqNvBQq5z587hoYceQvPmzfH222/j0qVLFV/jO0P7/f7778jLy8Pvv/8Oo9GIgwcPAgBat26N4OBg1w7OQ/ztb3/DyJEjcd999yEuLg5z587F9evXMXr0aFcPzSNdu3YNJ06cqPg8OzsbBw8eRGhoKJo1a+bCkXmuCRMmYMWKFVi7di1CQkJw4cIFAIDBYEBgYKCLR+e5pk2bhn79+qFZs2YoKirCihUrsG3bNmzatMnVQwNEqtGSJUtEAFY/yH4jR460+j3dunWrq4fmUebNmyc2a9ZM9PPzE+Pi4sR9+/a5ekgea+vWrVb/To4cOdLVQ/NYtn52LlmyxNVD82jPPfec2Lx5c9HPz09s3Lix+PDDD4vfffedq4cliqIoso8FERERqYaFAkRERKQaBgsiIiJSDYMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWoYLIiIiEg1DBZERESkGgYLIiIiUg2DBREREamGwYKIiIhU8//maQ/JCima+QAAAABJRU5ErkJggg==", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "xs = np.random.normal(size=(128, 1))\n", - "ys = xs ** 2\n", - "\n", - "for _ in range(1000):\n", - " params = update(params, xs, ys)\n", - "\n", - "plt.scatter(xs, ys)\n", - "plt.scatter(xs, forward(params, xs), label='Model prediction')\n", - "plt.legend();" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lNAvmpzdoE9l" - }, - "source": [ - "## Key paths\n", - "\n", - "In a pytree each leaf has a _key path_. A key path for a leaf is a `list` of _keys_, where the length of the list is equal to the depth of the leaf in the pytree . Each _key_ is a [hashable object](https://docs.python.org/3/glossary.html#term-hashable) that represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type; for example, the type of keys for `dict`s is different from the type of keys for `tuple`s.\n", - "\n", - "For built-in pytree node types, the set of keys for any pytree node instance is unique. For a pytree comprising nodes with this property, the key path for each leaf is unique.\n", - "\n", - "The APIs for working with key paths are:\n", - "\n", - "* [`jax.tree_util.tree_flatten_with_path`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_flatten_with_path.html): Works similarly with `jax.tree_util.tree_flatten`, but returns key paths.\n", - "\n", - "* [`jax.tree_util.tree_map_with_path`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map_with_path.html): Works similarly with `jax.tree_util.tree_map`, but the function also takes key paths as arguments.\n", - "\n", - "* [`jax.tree_util.keystr`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.keystr.html): Given a general key path, returns a reader-friendly string expression.\n", - "\n", - "One use case is to print debugging information related to a certain leaf value:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "G6E2YzhvoE9l", - "outputId": "5aec83c8-e15e-48eb-b2c3-6fa0164344b5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Value of tree[0]: 1\n", - "Value of tree[1]['k1']: 2\n", - "Value of tree[1]['k2'][0]: 3\n", - "Value of tree[1]['k2'][1]: 4\n", - "Value of tree[2].name: foo\n" - ] - } - ], - "source": [ - "import collections\n", - "ATuple = collections.namedtuple(\"ATuple\", ('name'))\n", - "\n", - "tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]\n", - "flattened, _ = jax.tree_util.tree_flatten_with_path(tree)\n", - "for key_path, value in flattened:\n", - " print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zrKqmANgoE9l" - }, - "source": [ - "To express key paths, JAX provides a few default key types for the built-in pytree node types, namely:\n", - "\n", - "* `SequenceKey(idx: int)`: for lists and tuples.\n", - "* `DictKey(key: Hashable)`: for dictionaries.\n", - "* `GetAttrKey(name: str)`: for `namedtuple`s and preferably custom pytree nodes (more in the next section)\n", - "\n", - "You are free to define your own key types for your own custom nodes. They will work with `jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ohDq0kGuoE9l", - "outputId": "9b8ff3ec-3461-482e-ff27-30dc2a7e68c9" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Key path of tree[0]: (SequenceKey(idx=0),)\n", - "Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))\n", - "Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))\n", - "Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))\n", - "Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))\n" - ] - } - ], - "source": [ - "for key_path, _ in flattened:\n", - " print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sBxOB21YNEDA" - }, - "source": [ - "## Custom pytree nodes\n", - "\n", - "So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "CK8LN2PRFnQf" - }, - "outputs": [], - "source": [ - "class MyContainer:\n", - " \"\"\"A named container.\"\"\"\n", - "\n", - " def __init__(self, name: str, a: int, b: int, c: int):\n", - " self.name = name\n", - " self.a = a\n", - " self.b = b\n", - " self.c = c" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "OPGe2R7ZOXCT", - "outputId": "40db1f41-9df8-4dea-972a-6a7bc44a49c6" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[<__main__.MyContainer at 0x121ae9ac0>, <__main__.MyContainer at 0x1233f9910>]" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.tree_util.tree_leaves([\n", - " MyContainer('Alice', 1, 2, 3),\n", - " MyContainer('Bob', 4, 5, 6)\n", - "])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vk4vucGXPADj" - }, - "source": [ - "Accordingly, if we try to use a tree map expecting our leaves to be the elements inside the container, we will get an error:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "vIr9_JOIOku7", - "outputId": "dadc9c15-4a10-4fac-e70d-f23e7085cf74" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TypeError: unsupported operand type(s) for +: 'MyContainer' and 'int'\n" - ] - } - ], - "source": [ - "try:\n", - " jax.tree_map(lambda x: x + 1, [\n", - " MyContainer('Alice', 1, 2, 3),\n", - " MyContainer('Bob', 4, 5, 6)\n", - " ])\n", - "except TypeError as e:\n", - " print(f'TypeError: {e}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nAZ4FR2lPN51", - "tags": [ - "raises-exception" - ] - }, - "source": [ - "To solve this, we need to register our container with JAX by telling it how to flatten and unflatten it:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2RR5cDFvoE9m", - "outputId": "94745373-abe4-4bca-967c-4133e8027c30" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[1, 2, 3, 4, 5, 6]" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from typing import Iterable\n", - "\n", - "def flatten_MyContainer(container) -> tuple[Iterable[int], str]:\n", - " \"\"\"Returns an iterable over container contents, and aux 
data.\"\"\"\n", - " flat_contents = [container.a, container.b, container.c]\n", - "\n", - " # we don't want the name to appear as a child, so it is auxiliary data.\n", - " # auxiliary data is usually a description of the structure of a node,\n", - " # e.g., the keys of a dict -- anything that isn't a node's children.\n", - " aux_data = container.name\n", - " return flat_contents, aux_data\n", - "\n", - "def unflatten_MyContainer(\n", - " aux_data: str, flat_contents: Iterable[int]) -> MyContainer:\n", - " \"\"\"Converts aux data and the flat contents into a MyContainer.\"\"\"\n", - " return MyContainer(aux_data, *flat_contents)\n", - "\n", - "jax.tree_util.register_pytree_node(\n", - " MyContainer, flatten_MyContainer, unflatten_MyContainer)\n", - "\n", - "jax.tree_util.tree_leaves([\n", - " MyContainer('Alice', 1, 2, 3),\n", - " MyContainer('Bob', 4, 5, 6)\n", - "])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JXaEe76ZoE9m" - }, - "source": [ - "Alternatively, using the key path API mentioned above, you can register this container with its keys in mind by defining what the keys should look like for each flattened-out value." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "D_juQx-2OybX", - "outputId": "ee2cf4ad-ec21-4636-c9c5-2c64b81429bb" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[1, 2, 3, 4, 5, 6]" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "class MyKeyPathContainer(MyContainer):\n", - " pass\n", - "\n", - "def flatten_with_keys_MyKeyPathContainer(container) -> tuple[Iterable[int], str]:\n", - " \"\"\"Returns an iterable over container contents, and aux data.\"\"\"\n", - "\n", - " # GetAttrKey is a common way to express an attribute key.
Users are free\n", - " # to pick any other expression that fits their use cases the best.\n", - " flat_contents = [(jax.tree_util.GetAttrKey('a'), container.a),\n", - " (jax.tree_util.GetAttrKey('b'), container.b),\n", - " (jax.tree_util.GetAttrKey('c'), container.c)]\n", - "\n", - " # we don't want the name to appear as a child, so it is auxiliary data.\n", - " # auxiliary data is usually a description of the structure of a node,\n", - " # e.g., the keys of a dict -- anything that isn't a node's children.\n", - " aux_data = container.name\n", - " return flat_contents, aux_data\n", - "\n", - "def unflatten_MyKeyPathContainer(\n", - " aux_data: str, flat_contents: Iterable[int]) -> MyKeyPathContainer:\n", - " \"\"\"Converts aux data and the flat contents into a MyContainer.\"\"\"\n", - " return MyKeyPathContainer(aux_data, *flat_contents)\n", - "\n", - "jax.tree_util.register_pytree_with_keys(\n", - " MyKeyPathContainer, flatten_with_keys_MyKeyPathContainer, unflatten_MyKeyPathContainer)\n", - "\n", - "jax.tree_util.tree_leaves([\n", - " MyKeyPathContainer('Alice', 1, 2, 3),\n", - " MyKeyPathContainer('Bob', 4, 5, 6)\n", - "])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HPX23W4zoE9m" - }, - "source": [ - "`register_pytree_with_keys` is an extended API of `register_pytree_node`, and containers registered in either way can freely use all the `tree_util` utilities without error.\n", - "\n", - "When a container registered with `register_pytree_node` uses `.*_with_path` APIs, the keys being returned will be a series of \"flat index\" fallbacks:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "E1BwD2aZoE9m", - "outputId": "4fe12b06-aef4-426a-a732-891affa63842" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "MyContainer container[]: 1\n", - "MyContainer container[]: 2\n", - "MyContainer container[]: 3\n", - "MyKeyPathContainer container.a: 1\n", - "MyKeyPathContainer container.b: 2\n", - "MyKeyPathContainer container.c: 3\n" - ] - } - ], - "source": [ - "flattened, _ = jax.tree_util.tree_flatten_with_path(MyContainer('Alice', 1, 2, 3))\n", - "for key_path, value in flattened:\n", - " print(f'MyContainer container{jax.tree_util.keystr(key_path)}: {value}')\n", - "\n", - "flattened, _ = jax.tree_util.tree_flatten_with_path(MyKeyPathContainer('Alice', 1, 2, 3))\n", - "for key_path, value in flattened:\n", - " print(f'MyKeyPathContainer container{jax.tree_util.keystr(key_path)}: {value}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JgnAp7fFShEB" - }, - "source": [ - "Modern Python comes equipped with helpful tools to make defining containers easier. Some of these will work with JAX out-of-the-box, but others require more care. 
For instance, a `NamedTuple` subclass doesn't need to be registered to be considered a pytree node type:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8DNoLABtO0fr", - "outputId": "9a448508-43eb-4450-bfaf-eeeb59a9e349" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "['Alice', 1, 2, 3, 'Bob', 4, 5, 6]" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from typing import NamedTuple, Any\n", - "\n", - "class MyOtherContainer(NamedTuple):\n", - " name: str\n", - " a: Any\n", - " b: Any\n", - " c: Any\n", - "\n", - "# NamedTuple subclasses are handled as pytree nodes, so\n", - "# this will work out-of-the-box:\n", - "jax.tree_util.tree_leaves([\n", - " MyOtherContainer('Alice', 1, 2, 3),\n", - " MyOtherContainer('Bob', 4, 5, 6)\n", - "])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TVdtzJDVTZb6" - }, - "source": [ - "Notice that the `name` field now appears as a leaf, as all tuple elements are children. That's the price we pay for not having to register the class the hard way." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wDbVszv-oE9n" - }, - "source": [ - "One shortcut is to use `jax.tree_util.register_static` to register a type as being a node without children:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "executionInfo": { - "elapsed": 59, - "status": "ok", - "timestamp": 1692698060536, - "user": { - "displayName": "", - "userId": "" - }, - "user_tz": -60 - }, - "id": "Rclc079ioE9n", - "outputId": "6b6a4402-8fc1-409c-b6da-88568a612e1b" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[1, 2, 3, 4, 5, 6]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from typing import NamedTuple, Any\n", - "\n", - "@jax.tree_util.register_static\n", - "class StaticStr(str):\n", - " pass\n", - "\n", - "\n", - "class YetAnotherContainer(NamedTuple):\n", - " name: StaticStr\n", - " a: Any\n", - " b: Any\n", - " c: Any\n", - "\n", - "\n", - "# NamedTuple subclasses are handled as pytree nodes, so\n", - "# this will work out-of-the-box:\n", - "jax.tree_util.tree_leaves([\n", - " YetAnotherContainer(StaticStr('Alice'), 1, 2, 3),\n", - " YetAnotherContainer(StaticStr('Bob'), 4, 5, 6)\n", - "])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kNsTszcEEHD0" - }, - "source": [ - "## Common pytree gotchas and patterns" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0ki-JDENzyL7" - }, - "source": [ - "### Gotchas\n", - "#### Mistaking nodes for leaves\n", - "A common problem to look out for is accidentally introducing tree nodes instead of leaves:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "N-th4jOAGJlM", - "outputId": "23eed14d-d383-4d88-d6f9-02bac06020df" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),\n", - " (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]\n", - "\n", - "# Try to make another tree with ones instead of zeros\n", - "shapes = jax.tree_map(lambda x: x.shape, a_tree)\n", - "jax.tree_map(jnp.ones, shapes)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "q8d4y-hfHTWh" - }, - "source": [ - 
"What happened is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`.\n", - "\n", - "The solution will depend on the specifics, but there are two broadly applicable options:\n", - "* rewrite the code to avoid the intermediate `tree_map`.\n", - "* convert the tuple into an `np.array` or `jnp.array`, which makes the entire\n", - "sequence a leaf." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4OKlbFlEIda-" - }, - "source": [ - "#### Handling of None\n", - "`jax.tree_utils` treats `None` as a node without children, not as a leaf:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gIwlwo2MJcEC", - "outputId": "1e59f323-a7b7-42be-8603-afa4693c00cc" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.tree_util.tree_leaves([None, None, None])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pwNz-rp1JvW4" - }, - "source": [ - "### Patterns\n", - "#### Transposing trees\n", - "\n", - "If you would like to transpose a pytree, i.e. turn a list of trees into a tree of lists, you can do so using `jax.tree_map`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "UExN7-G7qU-F", - "outputId": "fd049086-ef37-44db-8e2c-9f1bd9fad950" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'obs': [3, 4], 't': [1, 2]}" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def tree_transpose(list_of_trees):\n", - " \"\"\"Convert a list of trees of identical structure into a single tree of lists.\"\"\"\n", - " return jax.tree_map(lambda *xs: list(xs), *list_of_trees)\n", - "\n", - "\n", - "# Convert a dataset from row-major to column-major:\n", - "episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]\n", - "tree_transpose(episode_steps)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ao6R2ffm2CF4" - }, - "source": [ - "For more complicated transposes, JAX provides `jax.tree_transpose`, which is more verbose, but allows you specify the structure of the inner and outer Pytree for more flexibility:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bZvVwxshz1D3", - "outputId": "a0314dc8-4267-41e6-a763-931d40433c26" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/v7/zgx9fwms2fnd2d8pwb2c3r2400gbn1/T/ipykernel_94597/112383129.py:2: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.\n", - " outer_treedef = jax.tree_structure([0 for e in episode_steps]),\n", - "/var/folders/v7/zgx9fwms2fnd2d8pwb2c3r2400gbn1/T/ipykernel_94597/112383129.py:3: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.\n", - " inner_treedef = jax.tree_structure(episode_steps[0]),\n", - "/var/folders/v7/zgx9fwms2fnd2d8pwb2c3r2400gbn1/T/ipykernel_94597/112383129.py:1: FutureWarning: jax.tree_transpose is deprecated, and will be removed in a future release. 
Use jax.tree_util.tree_transpose instead.\n", - " jax.tree_transpose(\n" - ] - }, - { - "data": { - "text/plain": [ - "{'obs': [3, 4], 't': [1, 2]}" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.tree_transpose(\n", - " outer_treedef = jax.tree_structure([0 for e in episode_steps]),\n", - " inner_treedef = jax.tree_structure(episode_steps[0]),\n", - " pytree_to_transpose = episode_steps\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KlYA2R6N2h_8" - }, - "source": [ - "## More Information\n", - "\n", - "For more information on pytrees in JAX and the operations that are available, see the [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) section in the JAX documentation." - ] - } - ], - "metadata": { - "colab": { - "last_runtime": { - "build_target": "//learning/deepmind/dm_python:dm_notebook3", - "kind": "private" - }, - "name": "jax101-pytrees", - "provenance": [] - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.16" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/jax-101/05.1-pytrees.md b/docs/jax-101/05.1-pytrees.md deleted file mode 100644 index 7784afcff590..000000000000 --- a/docs/jax-101/05.1-pytrees.md +++ /dev/null @@ -1,536 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.0 -kernelspec: - display_name: Python 3 - name: python3 ---- - -+++ {"id": "-h05_PNNhZ-D"} - -# Working with Pytrees - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb) - -*Author: Vladimir Mikulik* - -Often, we want to operate on objects that look like dicts of arrays, or lists of lists of dicts, or other nested structures. In JAX, we refer to these as *pytrees*, but you can sometimes see them called *nests*, or just *trees*. - -JAX has built-in support for such objects, both in its library functions as well as through the use of functions from [`jax.tree_utils`](https://jax.readthedocs.io/en/latest/jax.tree_util.html) (with the most common ones also available as `jax.tree_*`). This section will explain how to use them, give some useful snippets and point out common gotchas. - -+++ {"id": "9UjxVY9ulSCn"} - -## What is a pytree? - -As defined in the [JAX pytree docs](https://jax.readthedocs.io/en/latest/pytrees.html): - -> a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree. 
- -Some example pytrees: - -```{code-cell} ipython3 ---- -executionInfo: - elapsed: 11002 - status: ok - timestamp: 1692698031720 - user: - displayName: '' - userId: '' - user_tz: -60 -id: Wh6BApZ9lrR1 -outputId: df1fa4cd-88a6-4d71-a376-b2ddf91568dd ---- -import jax -import jax.numpy as jnp - -example_trees = [ - [1, 'a', object()], - (1, (2, 3), ()), - [1, {'k1': 2, 'k2': (3, 4)}, 5], - {'a': 2, 'b': (2, 3)}, - jnp.array([1, 2, 3]), -] - -# Let's see how many leaves they have: -for pytree in example_trees: - leaves = jax.tree_util.tree_leaves(pytree) - print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}") -``` - -+++ {"id": "_tWkkGNwW8vf"} - -We've also introduced our first `jax.tree_*` function, which allowed us to extract the flattened leaves from the trees. - -+++ {"id": "RcsmneIGlltm"} - -## Why pytrees? - -In machine learning, some places where you commonly find pytrees are: -* Model parameters -* Dataset entries -* RL agent observations - -They also often arise naturally when working in bulk with datasets (e.g., lists of lists of dicts). - -+++ {"id": "sMrSGSIJn9MD"} - -## Common pytree functions -Perhaps the most commonly used pytree function is `jax.tree_map`. It works analogously to Python's native `map`, but on entire pytrees: - -```{code-cell} ipython3 -:id: wZRcuQu4n7o5 -:outputId: 3528bc9f-54ed-49c8-b79a-1cbea176c0f3 - -list_of_lists = [ - [1, 2, 3], - [1, 2], - [1, 2, 3, 4] -] - -jax.tree_map(lambda x: x*2, list_of_lists) -``` - -+++ {"id": "xu8X3fk4orC9"} - -`jax.tree_map` also works with multiple arguments: - -```{code-cell} ipython3 -:id: KVpB4r1OkeUK -:outputId: 33f88a7e-aac7-48cd-d207-2c531cd37733 - -another_list_of_lists = list_of_lists -jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists) -``` - -+++ {"id": "dkRKy3LvowAb"} - -When using multiple arguments with `jax.tree_map`, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc. - -+++ {"id": "Lla4hDW6sgMZ"} - -## Example: ML model parameters - -A simple example of training an MLP displays some ways in which pytree operations come in useful: - -```{code-cell} ipython3 -:id: j2ZUzWx8tKB2 - -import numpy as np - -def init_mlp_params(layer_widths): - params = [] - for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]): - params.append( - dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in), - biases=np.ones(shape=(n_out,)) - ) - ) - return params - -params = init_mlp_params([1, 128, 128, 1]) -``` - -+++ {"id": "kUFwJOspuGvU"} - -We can use `jax.tree_map` to check that the shapes of our parameters are what we expect: - -```{code-cell} ipython3 -:id: ErWsXuxXse-z -:outputId: d3e549ab-40ef-470e-e460-1b5939d9696f - -jax.tree_map(lambda x: x.shape, params) -``` - -+++ {"id": "zQtRKaj4ua6-"} - -Now, let's train our MLP: - -```{code-cell} ipython3 -:id: iL4GvW9OuZ-X - -def forward(params, x): - *hidden, last = params - for layer in hidden: - x = jax.nn.relu(x @ layer['weights'] + layer['biases']) - return x @ last['weights'] + last['biases'] - -def loss_fn(params, x, y): - return jnp.mean((forward(params, x) - y) ** 2) - -LEARNING_RATE = 0.0001 - -@jax.jit -def update(params, x, y): - - grads = jax.grad(loss_fn)(params, x, y) - # Note that `grads` is a pytree with the same structure as `params`. - # `jax.grad` is one of the many JAX functions that has - # built-in support for pytrees. 
- - # This is handy, because we can apply the SGD update using tree utils: - return jax.tree_map( - lambda p, g: p - LEARNING_RATE * g, params, grads - ) -``` - -```{code-cell} ipython3 -:id: B3HniT9-xohz -:outputId: d77e9811-373e-45d6-ccbe-edb6f43120d7 - -import matplotlib.pyplot as plt - -xs = np.random.normal(size=(128, 1)) -ys = xs ** 2 - -for _ in range(1000): - params = update(params, xs, ys) - -plt.scatter(xs, ys) -plt.scatter(xs, forward(params, xs), label='Model prediction') -plt.legend(); -``` - -+++ {"id": "lNAvmpzdoE9l"} - -## Key paths - -In a pytree each leaf has a _key path_. A key path for a leaf is a `list` of _keys_, where the length of the list is equal to the depth of the leaf in the pytree . Each _key_ is a [hashable object](https://docs.python.org/3/glossary.html#term-hashable) that represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type; for example, the type of keys for `dict`s is different from the type of keys for `tuple`s. - -For built-in pytree node types, the set of keys for any pytree node instance is unique. For a pytree comprising nodes with this property, the key path for each leaf is unique. - -The APIs for working with key paths are: - -* [`jax.tree_util.tree_flatten_with_path`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_flatten_with_path.html): Works similarly with `jax.tree_util.tree_flatten`, but returns key paths. - -* [`jax.tree_util.tree_map_with_path`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map_with_path.html): Works similarly with `jax.tree_util.tree_map`, but the function also takes key paths as arguments. - -* [`jax.tree_util.keystr`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.keystr.html): Given a general key path, returns a reader-friendly string expression. - -One use case is to print debugging information related to a certain leaf value: - -```{code-cell} ipython3 -:id: G6E2YzhvoE9l -:outputId: 5aec83c8-e15e-48eb-b2c3-6fa0164344b5 - -import collections -ATuple = collections.namedtuple("ATuple", ('name')) - -tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')] -flattened, _ = jax.tree_util.tree_flatten_with_path(tree) -for key_path, value in flattened: - print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}') -``` - -+++ {"id": "zrKqmANgoE9l"} - -To express key paths, JAX provides a few default key types for the built-in pytree node types, namely: - -* `SequenceKey(idx: int)`: for lists and tuples. -* `DictKey(key: Hashable)`: for dictionaries. -* `GetAttrKey(name: str)`: for `namedtuple`s and preferably custom pytree nodes (more in the next section) - -You are free to define your own key types for your own custom nodes. They will work with `jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression. - -```{code-cell} ipython3 -:id: ohDq0kGuoE9l -:outputId: 9b8ff3ec-3461-482e-ff27-30dc2a7e68c9 - -for key_path, _ in flattened: - print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}') -``` - -+++ {"id": "sBxOB21YNEDA"} - -## Custom pytree nodes - -So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. 
Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it: - -```{code-cell} ipython3 -:id: CK8LN2PRFnQf - -class MyContainer: - """A named container.""" - - def __init__(self, name: str, a: int, b: int, c: int): - self.name = name - self.a = a - self.b = b - self.c = c -``` - -```{code-cell} ipython3 -:id: OPGe2R7ZOXCT -:outputId: 40db1f41-9df8-4dea-972a-6a7bc44a49c6 - -jax.tree_util.tree_leaves([ - MyContainer('Alice', 1, 2, 3), - MyContainer('Bob', 4, 5, 6) -]) -``` - -+++ {"id": "vk4vucGXPADj"} - -Accordingly, if we try to use a tree map expecting our leaves to be the elements inside the container, we will get an error: - -```{code-cell} ipython3 -:id: vIr9_JOIOku7 -:outputId: dadc9c15-4a10-4fac-e70d-f23e7085cf74 - -try: - jax.tree_map(lambda x: x + 1, [ - MyContainer('Alice', 1, 2, 3), - MyContainer('Bob', 4, 5, 6) - ]) -except TypeError as e: - print(f'TypeError: {e}') -``` - -+++ {"id": "nAZ4FR2lPN51", "tags": ["raises-exception"]} - -To solve this, we need to register our container with JAX by telling it how to flatten and unflatten it: - -```{code-cell} ipython3 -:id: 2RR5cDFvoE9m -:outputId: 94745373-abe4-4bca-967c-4133e8027c30 - -from typing import Iterable - -def flatten_MyContainer(container) -> tuple[Iterable[int], str]: - """Returns an iterable over container contents, and aux data.""" - flat_contents = [container.a, container.b, container.c] - - # we don't want the name to appear as a child, so it is auxiliary data. - # auxiliary data is usually a description of the structure of a node, - # e.g., the keys of a dict -- anything that isn't a node's children. - aux_data = container.name - return flat_contents, aux_data - -def unflatten_MyContainer( - aux_data: str, flat_contents: Iterable[int]) -> MyContainer: - """Converts aux data and the flat contents into a MyContainer.""" - return MyContainer(aux_data, *flat_contents) - -jax.tree_util.register_pytree_node( - MyContainer, flatten_MyContainer, unflatten_MyContainer) - -jax.tree_util.tree_leaves([ - MyContainer('Alice', 1, 2, 3), - MyContainer('Bob', 4, 5, 6) -]) -``` - -+++ {"id": "JXaEe76ZoE9m"} - -Alternatively, using the key path API mentioned above, you can register this container with its keys in mind by defining how the keys should look like for each flattened-out value. - -```{code-cell} ipython3 -:id: D_juQx-2OybX -:outputId: ee2cf4ad-ec21-4636-c9c5-2c64b81429bb - -class MyKeyPathContainer(MyContainer): - pass - -def flatten_with_keys_MyKeyPathContainer(container) -> tuple[Iterable[int], str]: - """Returns an iterable over container contents, and aux data.""" - - # GetAttrKey is a common way to express an attribute key. Users are free - # to pick any other expression that fits their use cases the best. - flat_contents = [(jax.tree_util.GetAttrKey('a'), container.a), - (jax.tree_util.GetAttrKey('b'), container.b), - (jax.tree_util.GetAttrKey('c'), container.c)] - - # we don't want the name to appear as a child, so it is auxiliary data. - # auxiliary data is usually a description of the structure of a node, - # e.g., the keys of a dict -- anything that isn't a node's children. 
- aux_data = container.name - return flat_contents, aux_data - -def unflatten_MyKeyPathContainer( - aux_data: str, flat_contents: Iterable[int]) -> MyKeyPathContainer: - """Converts aux data and the flat contents into a MyContainer.""" - return MyKeyPathContainer(aux_data, *flat_contents) - -jax.tree_util.register_pytree_with_keys( - MyKeyPathContainer, flatten_with_keys_MyKeyPathContainer, unflatten_MyKeyPathContainer) - -jax.tree_util.tree_leaves([ - MyKeyPathContainer('Alice', 1, 2, 3), - MyKeyPathContainer('Bob', 4, 5, 6) -]) -``` - -+++ {"id": "HPX23W4zoE9m"} - -`register_pytree_with_keys` is an extended API of `register_pytree_node`, and containers registered in either way can freely use all the `tree_util` utilities without error. - -When a container registered with `register_pytree_node` uses `.*_with_path` APIs, the keys being returned will be a series of "flat index" fallbacks: - -```{code-cell} ipython3 -:id: E1BwD2aZoE9m -:outputId: 4fe12b06-aef4-426a-a732-891affa63842 - -flattened, _ = jax.tree_util.tree_flatten_with_path(MyContainer('Alice', 1, 2, 3)) -for key_path, value in flattened: - print(f'MyContainer container{jax.tree_util.keystr(key_path)}: {value}') - -flattened, _ = jax.tree_util.tree_flatten_with_path(MyKeyPathContainer('Alice', 1, 2, 3)) -for key_path, value in flattened: - print(f'MyKeyPathContainer container{jax.tree_util.keystr(key_path)}: {value}') -``` - -+++ {"id": "JgnAp7fFShEB"} - -Modern Python comes equipped with helpful tools to make defining containers easier. Some of these will work with JAX out-of-the-box, but others require more care. For instance, a `NamedTuple` subclass doesn't need to be registered to be considered a pytree node type: - -```{code-cell} ipython3 -:id: 8DNoLABtO0fr -:outputId: 9a448508-43eb-4450-bfaf-eeeb59a9e349 - -from typing import NamedTuple, Any - -class MyOtherContainer(NamedTuple): - name: str - a: Any - b: Any - c: Any - -# NamedTuple subclasses are handled as pytree nodes, so -# this will work out-of-the-box: -jax.tree_util.tree_leaves([ - MyOtherContainer('Alice', 1, 2, 3), - MyOtherContainer('Bob', 4, 5, 6) -]) -``` - -+++ {"id": "TVdtzJDVTZb6"} - -Notice that the `name` field now appears as a leaf, as all tuple elements are children. That's the price we pay for not having to register the class the hard way. 
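To make that price concrete, here is a minimal sketch (reusing the `MyOtherContainer` defined above; the exact error text may vary by Python version) showing a numeric `tree_map` tripping over the string leaf:

```{code-cell} ipython3
# `name` is now an ordinary leaf, so a numeric map reaches the string and fails:
try:
    jax.tree_map(lambda x: x + 1, MyOtherContainer('Alice', 1, 2, 3))
except TypeError as e:
    print(f'TypeError: {e}')
```

Registering the class by hand, or marking the offending field static as shown next, avoids this.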
- -+++ {"id": "wDbVszv-oE9n"} - -One shortcut is to use `jax.tree_util.register_static` to register a type as being a node without children: - -```{code-cell} ipython3 ---- -executionInfo: - elapsed: 59 - status: ok - timestamp: 1692698060536 - user: - displayName: '' - userId: '' - user_tz: -60 -id: Rclc079ioE9n -outputId: 6b6a4402-8fc1-409c-b6da-88568a612e1b ---- -from typing import NamedTuple, Any - -@jax.tree_util.register_static -class StaticStr(str): - pass - - -class YetAnotherContainer(NamedTuple): - name: StaticStr - a: Any - b: Any - c: Any - - -# NamedTuple subclasses are handled as pytree nodes, so -# this will work out-of-the-box: -jax.tree_util.tree_leaves([ - YetAnotherContainer(StaticStr('Alice'), 1, 2, 3), - YetAnotherContainer(StaticStr('Bob'), 4, 5, 6) -]) -``` - -+++ {"id": "kNsTszcEEHD0"} - -## Common pytree gotchas and patterns - -+++ {"id": "0ki-JDENzyL7"} - -### Gotchas -#### Mistaking nodes for leaves -A common problem to look out for is accidentally introducing tree nodes instead of leaves: - -```{code-cell} ipython3 -:id: N-th4jOAGJlM -:outputId: 23eed14d-d383-4d88-d6f9-02bac06020df - -a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))] - -# Try to make another tree with ones instead of zeros -shapes = jax.tree_map(lambda x: x.shape, a_tree) -jax.tree_map(jnp.ones, shapes) -``` - -+++ {"id": "q8d4y-hfHTWh"} - -What happened is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`. - -The solution will depend on the specifics, but there are two broadly applicable options: -* rewrite the code to avoid the intermediate `tree_map`. -* convert the tuple into an `np.array` or `jnp.array`, which makes the entire -sequence a leaf. - -+++ {"id": "4OKlbFlEIda-"} - -#### Handling of None -`jax.tree_utils` treats `None` as a node without children, not as a leaf: - -```{code-cell} ipython3 -:id: gIwlwo2MJcEC -:outputId: 1e59f323-a7b7-42be-8603-afa4693c00cc - -jax.tree_util.tree_leaves([None, None, None]) -``` - -+++ {"id": "pwNz-rp1JvW4"} - -### Patterns -#### Transposing trees - -If you would like to transpose a pytree, i.e. turn a list of trees into a tree of lists, you can do so using `jax.tree_map`: - -```{code-cell} ipython3 -:id: UExN7-G7qU-F -:outputId: fd049086-ef37-44db-8e2c-9f1bd9fad950 - -def tree_transpose(list_of_trees): - """Convert a list of trees of identical structure into a single tree of lists.""" - return jax.tree_map(lambda *xs: list(xs), *list_of_trees) - - -# Convert a dataset from row-major to column-major: -episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)] -tree_transpose(episode_steps) -``` - -+++ {"id": "Ao6R2ffm2CF4"} - -For more complicated transposes, JAX provides `jax.tree_transpose`, which is more verbose, but allows you specify the structure of the inner and outer Pytree for more flexibility: - -```{code-cell} ipython3 -:id: bZvVwxshz1D3 -:outputId: a0314dc8-4267-41e6-a763-931d40433c26 - -jax.tree_transpose( - outer_treedef = jax.tree_structure([0 for e in episode_steps]), - inner_treedef = jax.tree_structure(episode_steps[0]), - pytree_to_transpose = episode_steps -) -``` - -+++ {"id": "KlYA2R6N2h_8"} - -## More Information - -For more information on pytrees in JAX and the operations that are available, see the [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) section in the JAX documentation. 
diff --git a/docs/jax-101/06-parallelism.ipynb b/docs/jax-101/06-parallelism.ipynb deleted file mode 100644 index 86aa7bd4260c..000000000000 --- a/docs/jax-101/06-parallelism.ipynb +++ /dev/null @@ -1,912 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "tCOWitsAS1EE" - }, - "source": [ - "# Parallel Evaluation in JAX\n", - "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/06-parallelism.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/06-parallelism.ipynb)\n", - "\n", - "*Authors: Vladimir Mikulik & Roman Ring*\n", - "\n", - "In this section we will discuss the facilities built into JAX for single-program, multiple-data (SPMD) code.\n", - "\n", - "SPMD refers to a parallelism technique where the same computation (e.g., the forward pass of a neural net) is run on different input data (e.g., different inputs in a batch) in parallel on different devices (e.g., several TPUs).\n", - "\n", - "Conceptually, this is not very different from vectorisation, where the same operations occur in parallel in different parts of memory on the same device. We have already seen that vectorisation is supported in JAX as a program transformation, `jax.vmap`. JAX supports device parallelism analogously, using `jax.pmap` to transform a function written for one device into a function that runs in parallel on multiple devices. This colab will teach you all about it." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7mCgBzix2fd3" - }, - "source": [ - "## TPU Setup\n", - "\n", - "This notebook requires multiple accelerators and we recommend running it using Kaggle TPU VMs." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gN6VbcdRTcdE" - }, - "source": [ - "Next run the following to see the TPU devices you have available:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "tqbpCcqY3Cn7", - "outputId": "1fb88cf7-35f7-4565-f370-51586213b988" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[TpuDevice(id=0, host_id=0, coords=(0,0,0), core_on_chip=0),\n", - " TpuDevice(id=1, host_id=0, coords=(0,0,0), core_on_chip=1),\n", - " TpuDevice(id=2, host_id=0, coords=(1,0,0), core_on_chip=0),\n", - " TpuDevice(id=3, host_id=0, coords=(1,0,0), core_on_chip=1),\n", - " TpuDevice(id=4, host_id=0, coords=(0,1,0), core_on_chip=0),\n", - " TpuDevice(id=5, host_id=0, coords=(0,1,0), core_on_chip=1),\n", - " TpuDevice(id=6, host_id=0, coords=(1,1,0), core_on_chip=0),\n", - " TpuDevice(id=7, host_id=0, coords=(1,1,0), core_on_chip=1)]" - ] - }, - "execution_count": 2, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "import jax\n", - "jax.devices()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4_EDa0Dlgtf8" - }, - "source": [ - "## The basics\n", - "\n", - "The most basic use of `jax.pmap` is completely analogous to `jax.vmap`, so let's return to the convolution example from the [Vectorisation notebook](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/03-vectorization.ipynb)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "IIQKBr-CgtD2", - "outputId": "6e7f8755-fdfd-4cf9-e2b5-a10c5a870dd4" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([11., 20., 29.], dtype=float32)" - ] - }, - "execution_count": 5, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy as np\n", - "import jax.numpy as jnp\n", - "\n", - "x = np.arange(5)\n", - "w = np.array([2., 3., 4.])\n", - "\n", - "def convolve(x, w):\n", - " output = []\n", - " for i in range(1, len(x)-1):\n", - " output.append(jnp.dot(x[i-1:i+2], w))\n", - " return jnp.array(output)\n", - "\n", - "convolve(x, w)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lqxz9NNJOQ9Z" - }, - "source": [ - "Now, let's convert our `convolve` function into one that runs on entire batches of data. In anticipation of spreading the batch across several devices, we'll make the batch size equal to the number of devices:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "ll-hEa0jihzx", - "outputId": "788be05a-10d4-4a05-8d9d-49d0083541ab" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[ 0, 1, 2, 3, 4],\n", - " [ 5, 6, 7, 8, 9],\n", - " [10, 11, 12, 13, 14],\n", - " [15, 16, 17, 18, 19],\n", - " [20, 21, 22, 23, 24],\n", - " [25, 26, 27, 28, 29],\n", - " [30, 31, 32, 33, 34],\n", - " [35, 36, 37, 38, 39]])" - ] - }, - "execution_count": 6, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "n_devices = jax.local_device_count() \n", - "xs = np.arange(5 * n_devices).reshape(-1, 5)\n", - "ws = np.stack([w] * n_devices)\n", - "\n", - "xs" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "mi-nysDWYbn4", - "outputId": "2d115fc3-52f5-4a68-c3a7-115111a83657" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[2., 3., 4.],\n", - " [2., 3., 4.],\n", - " [2., 3., 4.],\n", - " [2., 3., 4.],\n", - " [2., 3., 4.],\n", - " [2., 3., 4.],\n", - " [2., 3., 4.],\n", - " [2., 3., 4.]])" - ] - }, - "execution_count": 7, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "ws" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8kseIB09YWJw" - }, - "source": [ - "As before, we can vectorise using `jax.vmap`:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "TNb9HsFXYVOI", - "outputId": "2e60e07a-6687-49ab-a455-60d2ec484363" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[ 11., 20., 29.],\n", - " [ 56., 65., 74.],\n", - " [101., 110., 119.],\n", - " [146., 155., 164.],\n", - " [191., 200., 209.],\n", - " [236., 245., 254.],\n", - " [281., 290., 299.],\n", - " [326., 335., 344.]], dtype=float32)" - ] - }, - "execution_count": 8, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "jax.vmap(convolve)(xs, ws)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TDF1vzt_5GMC" - }, - "source": [ - "To spread out the computation across multiple devices, just replace `jax.vmap` with `jax.pmap`:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "KWoextrails4", - "outputId": "bad1fbb7-226a-4538-e442-20ce0c1c8fad" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[ 11., 20., 29.],\n", - " [ 56., 65., 74.],\n", - " [101., 110., 119.],\n", - " [146., 155., 164.],\n", - " [191., 200., 209.],\n", - " [236., 245., 
254.],\n", - " [281., 290., 299.],\n", - " [326., 335., 344.]], dtype=float32)" - ] - }, - "execution_count": 9, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "jax.pmap(convolve)(xs, ws)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "E69cVxQPksxe" - }, - "source": [ - "Note that the parallelized `convolve` returns a `jax.Array`. That is because the elements of this array are sharded across all of the devices used in the parallelism. If we were to run another parallel computation, the elements would stay on their respective devices, without incurring cross-device communication costs." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "P9dUyk-ciquy", - "outputId": "99ea4c6e-cff7-4611-e9e5-bf016fa9716c" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[ 78., 138., 198.],\n", - " [ 1188., 1383., 1578.],\n", - " [ 3648., 3978., 4308.],\n", - " [ 7458., 7923., 8388.],\n", - " [12618., 13218., 13818.],\n", - " [19128., 19863., 20598.],\n", - " [26988., 27858., 28728.],\n", - " [36198., 37203., 38208.]], dtype=float32)" - ] - }, - "execution_count": 11, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "jax.pmap(convolve)(xs, jax.pmap(convolve)(xs, ws))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "iuHqht-OYqca" - }, - "source": [ - "The outputs of the inner `jax.pmap(convolve)` never left their devices when being fed into the outer `jax.pmap(convolve)`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vEFAJXN2q3dV" - }, - "source": [ - "## Specifying `in_axes`\n", - "\n", - "Like with `vmap`, we can use `in_axes` to specify whether an argument to the parallelized function should be broadcast (`None`), or whether it should be split along a given axis. Note, however, that unlike `vmap`, only the leading axis (`0`) is supported by `pmap` at the time of writing this guide." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "6Es5WVuRlXnB", - "outputId": "7e9612ae-d6e0-4d79-a228-f0403fcf8237" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[ 11., 20., 29.],\n", - " [ 56., 65., 74.],\n", - " [101., 110., 119.],\n", - " [146., 155., 164.],\n", - " [191., 200., 209.],\n", - " [236., 245., 254.],\n", - " [281., 290., 299.],\n", - " [326., 335., 344.]], dtype=float32)" - ] - }, - "execution_count": 12, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "jax.pmap(convolve, in_axes=(0, None))(xs, w)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EoN6drHDOlk4" - }, - "source": [ - "Notice how we get equivalent output to what we observe above with `jax.pmap(convolve)(xs, ws)`, where we manually replicated `w` when creating `ws`. Here, it is replicated via broadcasting, by specifying it as `None` in `in_axes`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rRE8STSU5cjx" - }, - "source": [ - "Keep in mind that when calling the transformed function, the size of the specified axis in arguments must not exceed the number of devices available to the host." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0lZnqImd7G6U" - }, - "source": [ - "## `pmap` and `jit`\n", - "\n", - "`jax.pmap` JIT-compiles the function given to it as part of its operation, so there is no need to additionally `jax.jit` it." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1jZqk_2AwO4y" - }, - "source": [ - "## Communication between devices\n", - "\n", - "The above is enough to perform simple parallel operations, e.g. batching a simple MLP forward pass across several devices. However, sometimes we need to pass information between the devices. For example, perhaps we are interested in normalizing the output of each device so they sum to 1.\n", - "For that, we can use special [collective ops](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) (such as the `jax.lax.p*` ops `psum`, `pmean`, `pmax`, ...). In order to use the collective ops we must specify the name of the `pmap`-ed axis through `axis_name` argument, and then refer to it when calling the op. Here's how to do that:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "0nCxGwqmtd3w", - "outputId": "6f9c93b0-51ed-40c5-ca5a-eacbaf40e686" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[0.00816024, 0.01408451, 0.019437 ],\n", - " [0.04154303, 0.04577465, 0.04959785],\n", - " [0.07492582, 0.07746479, 0.07975871],\n", - " [0.10830861, 0.10915492, 0.10991956],\n", - " [0.14169139, 0.14084506, 0.14008042],\n", - " [0.17507419, 0.17253521, 0.17024128],\n", - " [0.20845698, 0.20422535, 0.20040214],\n", - " [0.24183977, 0.23591548, 0.23056298]], dtype=float32)" - ] - }, - "execution_count": 13, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "def normalized_convolution(x, w):\n", - " output = []\n", - " for i in range(1, len(x)-1):\n", - " output.append(jnp.dot(x[i-1:i+2], w))\n", - " output = jnp.array(output)\n", - " return output / jax.lax.psum(output, axis_name='p')\n", - "\n", - "jax.pmap(normalized_convolution, axis_name='p')(xs, ws)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9ENYsJS42YVK" - }, - "source": [ - "The `axis_name` is just a string label that allows collective operations like `jax.lax.psum` to refer to the axis bound by `jax.pmap`. It can be named anything you want -- in this case, `p`. 
This name is essentially invisible to anything but those functions, and those functions use it to know which axis to communicate across.\n", - "\n", - "`jax.vmap` also supports `axis_name`, which allows `jax.lax.p*` operations to be used in the vectorisation context in the same way they would be used in a `jax.pmap`:" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "nT61xAYJUqCW", - "outputId": "e8831025-78a6-4a2b-a60a-3c77b35214ef" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[0.00816024, 0.01408451, 0.019437 ],\n", - " [0.04154303, 0.04577465, 0.04959785],\n", - " [0.07492582, 0.07746479, 0.07975871],\n", - " [0.10830861, 0.10915492, 0.10991956],\n", - " [0.14169139, 0.14084506, 0.14008042],\n", - " [0.17507419, 0.17253521, 0.17024128],\n", - " [0.20845698, 0.20422535, 0.20040214],\n", - " [0.24183977, 0.23591548, 0.23056298]], dtype=float32)" - ] - }, - "execution_count": 14, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "jax.vmap(normalized_convolution, axis_name='p')(xs, ws)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JSK-9dbWWV2O" - }, - "source": [ - "Note that `normalized_convolution` will no longer work without being transformed by `jax.pmap` or `jax.vmap`, because `jax.lax.psum` expects there to be a named axis (`'p'`, in this case), and those two transformations are the only way to bind one.\n", - "\n", - "## Nesting `jax.pmap` and `jax.vmap`\n", - "\n", - "The reason we specify `axis_name` as a string is so we can use collective operations when nesting `jax.pmap` and `jax.vmap`. For example:\n", - "\n", - "```python\n", - "jax.vmap(jax.pmap(f, axis_name='i'), axis_name='j')\n", - "```\n", - "\n", - "A `jax.lax.psum(..., axis_name='i')` in `f` would refer only to the pmapped axis, since they share the `axis_name`. \n", - "\n", - "In general, `jax.pmap` and `jax.vmap` can be nested in any order, and with themselves (so you can have a `pmap` within another `pmap`, for instance)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WzQHxnHkCxej" - }, - "source": [ - "## Example\n", - "\n", - "Here's an example of a regression training loop with data parallelism, where each batch is split into sub-batches which are evaluated on separate devices.\n", - "\n", - "There are two places to pay attention to:\n", - "* the `update()` function\n", - "* the replication of parameters and splitting of data across devices.\n", - "\n", - "If this example is too confusing, you can find the same example, but without parallelism, in the next notebook, [State in JAX](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). Once that example makes sense, you can compare the differences to understand how parallelism changes the picture." 
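Before walking through the full training loop, here is a small sketch of the point above that a collective op only works under a transformation that binds its axis name; in recent JAX versions the unbound name surfaces as a `NameError`, though treat that exception type as an assumption:

```python
# Without pmap/vmap there is no axis named 'p' for jax.lax.psum to reduce over.
try:
    normalized_convolution(x, w)
except NameError as e:
    print(f'NameError: {e}')
```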
- ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "cI8xQqzRrc-4" - }, - "outputs": [], - "source": [ - "from typing import NamedTuple\n", - "import functools\n", - "\n", - "class Params(NamedTuple):\n", - " weight: jnp.ndarray\n", - " bias: jnp.ndarray\n", - "\n", - "\n", - "def init(rng) -> Params:\n", - " \"\"\"Returns the initial model params.\"\"\"\n", - " weights_key, bias_key = jax.random.split(rng)\n", - " weight = jax.random.normal(weights_key, ())\n", - " bias = jax.random.normal(bias_key, ())\n", - " return Params(weight, bias)\n", - "\n", - "\n", - "def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:\n", - " \"\"\"Computes the least squares error of the model's predictions on x against y.\"\"\"\n", - " pred = params.weight * xs + params.bias\n", - " return jnp.mean((pred - ys) ** 2)\n", - "\n", - "LEARNING_RATE = 0.005\n", - "\n", - "# So far, the code is identical to the single-device case. Here's what's new:\n", - "\n", - "\n", - "# Remember that the `axis_name` is just an arbitrary string label used\n", - "# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it\n", - "# 'num_devices', but could have used anything, so long as `pmean` used the same.\n", - "@functools.partial(jax.pmap, axis_name='num_devices')\n", - "def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> tuple[Params, jnp.ndarray]:\n", - " \"\"\"Performs one SGD update step on params using the given data.\"\"\"\n", - "\n", - " # Compute the gradients on the given minibatch (individually on each device).\n", - " loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)\n", - "\n", - " # Combine the gradient across all devices (by taking their mean).\n", - " grads = jax.lax.pmean(grads, axis_name='num_devices')\n", - "\n", - " # Also combine the loss. Unnecessary for the update, but useful for logging.\n", - " loss = jax.lax.pmean(loss, axis_name='num_devices')\n", - "\n", - " # Each device performs its own update, but since we start with the same params\n", - " # and synchronise gradients, the params stay in sync.\n", - " new_params = jax.tree_map(\n", - " lambda param, g: param - g * LEARNING_RATE, params, grads)\n", - "\n", - " return new_params, loss" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RWce8YZ4Pcmf" - }, - "source": [ - "Here's how `update()` works:\n", - "\n", - "Undecorated and without the `pmean`s, `update()` takes data tensors of shape `[batch, ...]`, computes the loss function on that batch and evaluates its gradients.\n", - "\n", - "We want to spread the `batch` dimension across all available devices. To do that, we add a new axis using `pmap`. The arguments to the decorated `update()` thus need to have shape `[num_devices, batch_per_device, ...]`. So, to call the new `update()`, we'll need to reshape data batches so that what used to be `batch` is reshaped to `[num_devices, batch_per_device]`. That's what `split()` does below. Additionally, we'll need to replicate our model parameters, adding the `num_devices` axis. This reshaping is how a pmapped function knows which devices to send which data.\n", - "\n", - "At some point during the update step, we need to combine the gradients computed by each device -- otherwise, the updates performed by each device would be different. That's why we use `jax.lax.pmean` to compute the mean across the `num_devices` axis, giving us the average gradient of the batch. 
That average gradient is what we use to compute the update.\n", - "\n", - "Aside on naming: here, we use `num_devices` for the `axis_name` for didactic clarity while introducing `jax.pmap`. However, in some sense that is tautologous: any axis introduced by a pmap will represent a number of devices. Therefore, it's common to see the axis be named something semantically meaningful, like `batch`, `data` (signifying data parallelism) or `model` (signifying model parallelism)." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "id": "_CTtLrsQ-0kK" - }, - "outputs": [], - "source": [ - "# Generate true data from y = w*x + b + noise\n", - "true_w, true_b = 2, -1\n", - "xs = np.random.normal(size=(128, 1))\n", - "noise = 0.5 * np.random.normal(size=(128, 1))\n", - "ys = xs * true_w + true_b + noise\n", - "\n", - "# Initialise parameters and replicate across devices.\n", - "params = init(jax.random.PRNGKey(123))\n", - "n_devices = jax.local_device_count()\n", - "replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dmCMyLP9SV99" - }, - "source": [ - "So far, we've just constructed arrays with an additional leading dimension. The params are all still on the host (CPU). `pmap` will communicate them to the devices when `update()` is first called, and each copy will stay on its own device subsequently." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "id": "YSCgHguTSdGW", - "outputId": "a8bf28df-3747-4d49-e340-b7696cf0c27d" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "jax.Array" - ] - }, - "execution_count": 19, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "type(replicated_params.weight)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "90VtjPbeY-hD" - }, - "source": [ - "The params will become a jax.Array when they are returned by our pmapped `update()` (see further down)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eGVKxk1CV-m1" - }, - "source": [ - "We do the same to the data:" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "id": "vY61QJoFWCII", - "outputId": "f436a15f-db97-44cc-df33-bbb4ff222987" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "numpy.ndarray" - ] - }, - "execution_count": 20, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "def split(arr):\n", - " \"\"\"Splits the first axis of `arr` evenly across the number of devices.\"\"\"\n", - " return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])\n", - "\n", - "# Reshape xs and ys for the pmapped `update()`.\n", - "x_split = split(xs)\n", - "y_split = split(ys)\n", - "\n", - "type(x_split)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RzfJ-oK5WERq" - }, - "source": [ - "The data is just a reshaped vanilla NumPy array. Hence, it cannot be anywhere but on the host, as NumPy runs on CPU only. Since we never modify it, it will get sent to the device at each `update` call, like in a real pipeline where data is typically streamed from CPU to the device at each step." 
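A quick shape check makes the reshaping described above concrete; this is just a sketch over the names defined in this example:

```python
# Replicated params carry a leading num_devices axis, and each data array
# now has shape [num_devices, batch_per_device, ...].
print(jax.tree_map(lambda x: x.shape, replicated_params))
print(x_split.shape, y_split.shape)
```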
- ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "atOTi7EeSQw-", - "outputId": "c8daf141-63c4-481f-afa5-684c5f7b698d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "after first `update()`, `replicated_params.weight` is a \n", - "after first `update()`, `loss` is a \n", - "after first `update()`, `x_split` is a \n", - "Step 0, loss: 0.228\n", - "Step 100, loss: 0.228\n", - "Step 200, loss: 0.228\n", - "Step 300, loss: 0.228\n", - "Step 400, loss: 0.228\n", - "Step 500, loss: 0.228\n", - "Step 600, loss: 0.228\n", - "Step 700, loss: 0.228\n", - "Step 800, loss: 0.228\n", - "Step 900, loss: 0.228\n" - ] - } - ], - "source": [ - "def type_after_update(name, obj):\n", - " print(f\"after first `update()`, `{name}` is a\", type(obj))\n", - "\n", - "# Actual training loop.\n", - "for i in range(1000):\n", - "\n", - " # This is where the params and data gets communicated to devices:\n", - " replicated_params, loss = update(replicated_params, x_split, y_split)\n", - "\n", - " # The returned `replicated_params` and `loss` are now both jax.Arrays,\n", - " # indicating that they're on the devices.\n", - " # `x_split`, of course, remains a NumPy array on the host.\n", - " if i == 0:\n", - " type_after_update('replicated_params.weight', replicated_params.weight)\n", - " type_after_update('loss', loss)\n", - " type_after_update('x_split', x_split)\n", - "\n", - " if i % 100 == 0:\n", - " # Note that loss is actually an array of shape [num_devices], with identical\n", - " # entries, because each device returns its copy of the loss.\n", - " # So, we take the first element to print it.\n", - " print(f\"Step {i:3d}, loss: {loss[0]:.3f}\")\n", - "\n", - "\n", - "# Plot results.\n", - "\n", - "# Like the loss, the leaves of params have an extra leading dimension,\n", - "# so we take the params from the first device.\n", - "params = jax.device_get(jax.tree_map(lambda x: x[0], replicated_params))" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "id": "rvVCACv9UZcF", - "outputId": "5c472d0f-1236-401b-be55-86e3dc43875d" - }, - "outputs": [ - { - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3de3iU1bn38e8iDpBQMCqBLUEI1i0KghwCUlArYAEVMSoWqQWF+uJhK8UqJ9sth20lioqtWq27oO6CiIqNArJBBEXYogQJKCdRRCEqIBIOEiSH9f4xmZnMZGYyM5lkMjO/z3V5XTxrnsMalTsr93OvtYy1FhERiV8NYt0BERGpGQVyEZE4p0AuIhLnFMhFROKcArmISJw7JRYPbd68uc3KyorFo0VE4taGDRu+t9Zm+LbHJJBnZWWRn58fi0eLiMQtY8xX/tqVWhERiXMK5CIicU6BXEQkzsUkR+5PSUkJe/fu5cSJE7HuikRR48aNad26NQ6HI9ZdEUlY9SaQ7927l6ZNm5KVlYUxJtbdkSiw1nLw4EH27t1Lu3btYt0dkYRVbwL5iRMnFMQTjDGGM844gwMHDsS6KyIxl7exkJnLdvBNUTGt0lMZP7A9OV0zo3LvehPIAQXxBKT/piLOID759U8oLikDoLComMmvfwIQlWCul50iIrVs5rId7iDuUlxSxsxlO6JyfwXySowx/Pa3v3Ufl5aWkpGRweDBg8O6T1ZWFt9//31E52RlZdGpUyc6d+7MgAED+O6778J6dmVTp07l0UcfBeCBBx5gxYoVAc8tKCjgrbfech+/+eab5ObmRvxsEfH4pqg4rPZwKZBX0qRJEz799FOKi53/ct9++20yM6OTwwrHqlWr2Lx5M9nZ2Tz00ENen1lrKS8vD/ue06dP5/LLLw/4uW8gHzJkCJMmTQr7OSJSVav01LDaw6VA7uPKK69kyZIlAMyfP5/hw4e7P/vhhx/Iycmhc+fO9OrVi82bNwNw8OBBBgwYQMeOHbn11lupvOvS3Llz6dmzJ126dOG2226jrMz716tgLr30Uj7//HN2795N+/btGTlyJBdccAF79uxh5syZ9OjRg86dOzNlyhT3NX/+858599xzufjii9mxw/Nr2y233MJrr70GwPr16+nduzcXXnghPXv25PDhwzzwwAMsWLCALl26sGDBAl544QXuuusuAHbv3k2/fv3o3Lkz/fv35+uvv3bfc+zYsfTu3Zuzzz7bfX8R8TZ+YHtSHSlebamOFMYPbB+V+9erl51u48ZBQUF079mlCzzxRLWn3XjjjUyfPp3BgwezefNmRo8ezfvvvw/AlClT6Nq1K3l5eaxcuZKRI0dSUFDAtGnTuPjii3nggQdYsmQJs2fPBmDbtm0sWLCAtWvX4nA4uPPOO5k3bx4jR44MqcuLFy+mU6dOAOzcuZMXX3yRXr16sXz5cnbu3MlHH32EtZYhQ4awevVqmjRpwssvv0xBQQGlpaV069aN7t27e93z5MmTDBs2jAULFtCjRw+OHDlCWloa06dPJz8/n6eeegqAF154wX3N3Xffzc0338zNN9/MnDlzGDt2LHl5eQB8++23rFmzhu3btzNkyBCGDh0a0ncTSSauF5pJUbVSH3Tu3Jndu3czf/58rrzySq/P1qxZw8KFCwHo168fBw8e5MiRI6xevZrXX38dgKuuuorTTjsNgHfeeYcNGzbQo0cPAIqLi2nRokW1fejbty8pKSl07tyZBx98kKKiItq2bUuvXr0AWL58OcuXL6dr164AHDt2jJ07d3L06FGuvfZa0tLSAGd6xNeOHTs488wz3X1q1qxZtf354IMP3N9vxIgRTJgwwf1ZTk4ODRo0oEOHDuzbt6/ae4kkq5yumVEL3L7qZyAPYeRcm4YMGcJ9993Hu+++y8GDByO+j7WWm2++mRkzZoR13apVq2jevLn7uKioiCZNmnjdd/Lkydx2221e1z0Rg39vjRo1cv9ZG3mLxIZy5H6MHj2aKVOmuNMaLpdccgnz5s0D4N1336V58+Y0a9aMSy+9lJdeegmApUuXcujQIQD69+/Pa6+9xv79+wFnjv2rr/yuQhmWgQMHMmfOHI4dOwZAYWEh+/fv59JLLyUvL4/i4mKOHj3KokWLqlzbvn17vv32W9avXw/A0aNHKS0tpWnTphw9etTv83r37s3LL78MwLx587jkkktq/B1EJHrq54g8xlq3bs3YsWOrtE+dOpXRo0fTuXNn0tLSePHFFwFn7nz48OF07NiR3r1706ZNGwA6dOjAgw8+yIABAygvL8fhcPD000/Ttm3bGvVvwIABbNu2jV/84hcA/OxnP2Pu3Ll069aNYcOGceGFF9KiRQt3+qSyhg0bsmDBAu6++26Ki4tJTU1lxYoV9O3bl9zcXLp06cLkyZO9rnnyyScZNWoUM2fOJCMjg+eff75G/ReR6DKx+HU4Ozvb+m4ssW3bNs4///w674vUPv23laRnLQwfDgsWwKZN0LlzRLcxxmyw1mb7tmtELiJSmxYtgsqFBxlVdmqrMQVyEZHa8P333kH7/POdZdUNG0b9UfXqZaeqHhKP/ptK0rEWbrnFO4hv3gxbt9ZKEId6FMgbN27MwYMH9Rc/gbjWI2/cuHGsuyJSN5YvhwYNoKIQgocecgZ2nwq4aKs3qZXWrVuzd+9erV2dYFw7BIkktEOH4PTTPcdt28L27VBHg5h6E8gdDod2kRGR+HPHHfDss57jDRugW7c67UK9Sa2IiMSVVavAGE8QnzLFmUap4yAO9WhELiISFw4fhhYt4ORJ53HLlrBrF1SscRQLGpGLiITqD3+A9HRPEF+3Dr77LqZBHBTIRUSqt3atM40ya5bzeMIEZxrlooti268KUUutGGNSgHyg0Fob3t5oIiLVqM1d6AM6dgxat3amUwCaNYO9e6Fp09p9bpiiOSL/PbAtivcTEQE8u9AXFhVj8exCn7exsPYeev/9zoDtCuLvv+/8cz0L4hClQG6MaQ1cBfwjGvcTEakskl3o8zYW0id3Je0mLaFP7srQg/5HHznTKK59BMaOdaZRLr440u7XumilVp4AJgABf1QZY8YAYwD3Mq8iIqEIdxd61wjeFfxdI3ggcDrm+HE45xz49lvnscMBBw7AqafWrPN1oMYjcmPMYGC/tXZDsPOstc9Za7OttdkZtbD6l4gkrnB3oQ97BD99OjRp4gni77zjrEyJgyAO0Umt9AGGGGN2Ay8D/Ywxc6NwXxERIPxd6EMewW/c6EyjTJniPB4zxplG6devxn2uSzVOrVhrJwOTAYwxlwH3WWt/W9P7ioi4hLsLfav0VAr9BHP3CP7ECejQAb780vPhwYPe66XEEc3sFJG4EM4u9OMHtvfKkUOlEfwjj8DEiZ6Tly6FQYOi3d06FdVAbq19F3g3mvcUEXEJtZbc3wh+WGoROd08K3F+fdX1tFn0qjO1Euc0IheRuBBuJYp7BH/yJEfO70SzXZ+5P+t29zyKTz2dGQXf1P6kojqgKfoiEhciqSXnL3
+BRo3cQfzW6/6TrImL+SHt1OqvjSMakYtIXAirlnzHDjjvPPfhW+37cOc1k6qkUQLdM95oRC4icSGkWvLSUsjO9grifPMNf75lut9ceKB7xhsFchGJC9XWkj/7rHM25oaKuYmvvOKsCT/zzLDr0OONUisiEhcC1pI3O+E92h48GN5806st3Dr0eGNisWt9dna2zc/Pr/Pnikj9k7exkGmLtnDoeAkA6akOpg7pWH2QLSuDyy6DNWs8bXv2OJedTVDGmA3W2mzfdqVWRCRm8jYWMv61Te4gDlBUXML4VzcFX63w+efhlFM8QXzuXGcaJYGDeDBKrYhIzMxctoOSsqpZgZJyy8xlO6qOyr/6CrKyPMeXXw7LlkED/2PSmGxGEQMK5CISM8HK/7w+Ky+HgQNhxQpP25dfegd1HxEtZRunlFoRkZgJVv7n/uyllyAlxRPE//EPZxolSBCHCCcQxSmNyEUkZsYPbM/41zZVSa84Ghj+s2sz72qU3r1h9WpnUA9BuJtRxDONyEUkZnK6ZjJz6IWcluZwt6U3PoXVa59g0KAenhN37nTuZB9iEIfwN6OIZxqRi0hMeS1Pu3AhDB3q+fBvf4M77ojovkGXsk0wCuQiEnVhV4t89x2ceabnuFs3WLfOOVMzQok+CagyBXIRqRHfoN33vAwWbigMrVrEWrjxRud0epetW+H884M+I9SAHM5mFPFMOXIRiZirxK+wqBiLM2jPW/d1aNUiixY5679dQfzxx52B3U8Q933G5Nc/CT5hKMloRC4iEfNX4hdo0Q93tciBA9CiheeD88+HggJo2DDkZ7h+MCTDaDsUGpGLSMTCKeVrdWpjuOUW7yC+ebMzlRIgiAd7RiKWEUZKgVxEIhaolM935e/+Xxew9v7L4cUXnQ0PPeRMo3TqFPEzErGMMFIK5CISsUDrfN/Uqw2Z6amkFx9l98ODmT3/T84Ps7KguBgmT67xMxKxjDBSypGLSMSClvjdfjv8/e+ekzdscJYVRvMZAmg9cpF6LS5X71u1Cvr18xxPmQJTp8asO4kk0HrkGpGL1FNxt3rf4cPOF5knTzqPW7aEXbsgLS22/UoCypGL1FNxtXrfuHGQnu4J4uvWOWdrKojXCQVykXoqLsru1qxxrlD4l784jydOdFajXHRRbPuVZJRaEamnWqWnUugnaNeLsrtjx5zbqh0+7Dxu1gz27oWmTWPbryRV4xG5MeYsY8wqY8xWY8wWY8zvo9ExkWRXb8vu7r/fGbBdQfz9951/VhCPmWiMyEuBe621HxtjmgIbjDFvW2u3RuHeIkmr3pXdffSRd8pk7FhPSkViqsaB3Fr7LfBtxZ+PGmO2AZmAArlIDdWL1fuOH4ezz4Z9+5zHDodzvZRTT43odnFZUlnPRfVlpzEmC+gKfBjN+4pIjEybBk2aeIL4O+84K1NqEMS1kmH0RS2QG2N+BiwExllrj/j5fIwxJt8Yk3/gwIFoPVZEasPGjc5qFNdEnjFjnNUolSf6RCCuSirjSFSqVowxDpxBfJ619nV/51hrnwOeA+fMzmg8V0Si7MQJ6NABvvzS03bwIJx+elRuHxcllXEoGlUrBpgNbLPWPl7zLolITDz8MKSmeoL40qXOUXiUgjhoJcPaEo3USh9gBNDPGFNQ8c+VUbiviNSFTz5xplEmTXIejxwJ5eUwaFDUH1VvSyrjXDSqVtZQdflhEanvTp6Erl2dGzu47N8PGRm19sh6V1KZIDSzUyQZPfEE3HOP5/iNN2DIkDp5dL0oqUwwCuQicSjiWuzt2703Nx461Ln5sdEv1fFM65GLxBnf5W3Bmdu0QGagoF5SAr16wccfu5uWLv+YBzccUoojjgRaj1yrH4rEmWA71xcWFTNuQQFdpy/3TLJ59lnn5sauIP7KK+R9vJc/rN6niTkJQoFcpB7I21hIn9yVtJu0hD65K4MGVH8rIvo6dLyEZ/6xzJkyueMOZ+Pgwc5qlBtu0MScBKMcuUiMhbsTUIoxlAVJiTYoL+OVlyaRXbjN07hnj3PZ2QqamJNYNCIXibFwR8fBgvgNm99m18xr3EF83OB7nZN6KgVx0MScRKMRuUgURVJNEu7o+LQ0B4eOl3i1ZR7ez9pnR7uP32/bhZHDptPqtCZ+7zF+YPsqL0w1MSd+KZCLREmkmyWHuxNQ5QG5seX8c8F/cvFXm9xtF98+m72ntgwamDUxJ7EokItESbAUSbAA2fe8DOau+9pvuz+Hi52j8SFb3+Wvix51t4+/Yiwrel1F0fGSwGWIlWhiTuJQIBeJkkhfIK7a7n9Z50DtF5pj5OXe6D5en9mBYb+ZQXmDFFJLypk1rIsCdJJRIBeJkkg3Sw72A8Ar535qYxYuzSXv/RXuc3455jm+Oq2V+ziU3wAk8SiQi0RJ3/MymLfuayrXlATLU7uCdKAalFNTHe6c+xXb1/DMG7nuzwomP0ROeWe/16mEMPkokItEQd7GQhZuKPQKyga4vrv/PLS/afa+iopLyDh2iG1Pj3C3fdLy5/zH3X9j9R8HkJm7MqLfACTxKJCLREGgafPzP9zDvHVfV6kK8Xe+98WWp954mME71rib+t/6DF+ccRbmqPNlp0oIxUWBXCQKAqUzXJN3fEsRg6U/+n/+IbMX/pf7eHq//8ecHte4j10jbpUQiosCuUgUBHrRWVnlF5H+zj/9+GE+fvIm9/HOM87iylF/pSTF4W7zHXGrhFBAU/RFosLfFmb+uEbiXudby2OLH/MK4gNHP8Wvbn3GK4gDNHbor6xUpRG5SBT4pjkaBFjYyjct8t6Tc5n1/CT354/3vYW/9hwa8DmHjpeENFtUkosCuUiUVE5z+KtK8UqL/PADOd1ak+P6MCsLtm3j7G0HyVy2g8Ki4oCrHKpWXHwpkIuEKZSFsYK+iLz9dvj73z0nb9gA3bq5r6t8r3aTlvitM1etuFSmQC5SIZQAHc7CWFVeRK5cCabScrJTpsDUqUH7FOlsUUkuCuQihB6gAy2MNW5BAfe+sokya6suWHX4MGRkOPfNBGjZEnbtgrS0avulWnEJhV6BixD65g7BUhq+NeN5Gwth3DhIT/cE8XXr4LvvQgri4PwhMuO6TmSmp2Jwbq4847pOyo+LF43IRQh95cJQ6sUBOn65mZxuV7iPPxv1H4w69zq++df3tFq1MqyJO6oVl+ookIsQei56/MD2jH9tEyVl/pe6SjtZzLqnb6bZyeMAHG2Uxqzn32H+1iKKK+5f3YYTkewyJMlNqRUR/E/oCZiLDrBc4cR3X2DrrBvcQXzoTQ/TadwrvLj5h5D35HTl6guLirH4pGlEAtCIXITQ1y2ZuWwHJeXekbzLNzvI++e97uPnu1/NtMtvcx8H2izZXzon0l2GJLlFJZAbYwYBfwFSgH9Ya3OruUSk3nHlol2pjXsWFDBz2Q53QM/bWOiVfmlccoL3/34rGT8WAfBTyilk3z2Po428NzwONLHHXwlhpLsMSXKrcSA3xqQATwO/AvYC640xb1prt9b03iJ1KW9jIdMWbfHaod6V2sj/6gcWbvCkN8atm
ce4tfPdx3ff+ij977yR0tc/AZ9Sweu7Z7JwQ2FIJYSqG5dIRCNH3hP43Fq7y1p7EngZuKaaa0TqFVduunIQdykuKWP+h3soLimj474v2P3wYHcQf+nCQZz/p6X0v/PGgKWCD+Z0CrmEMKxcvUiFaKRWMoE9lY73Ahf5nmSMGQOMAWjTpk0UHisSHXkbC92TeQI5peQn3v3HHZx1eJ+77cKx8zmc2pQnKgXlQKWCoZYQao1xiUSdvey01j4HPAeQnZ0d+G+MSB36U94nVfbZ9HXHuleZ+N6L7uObb5jGe2d3B5yj63CDbHXlhaobl3BFI5AXAmdVOm5d0SZSr+VtLAwaxNsf2M2yOXd5zu/Un3FXjANjAOeenH3Pywj7maGu1SISqmgE8vXAvxtj2uEM4DcCv4nCfUVqdXJMoB3sHWUlLJ1zN+f8sNfTuH8/+Wu/w1QK/BZYuKGQ7Lanh9wnlRdKbajxy05rbSlwF7AM2Aa8Yq3dUtP7itT25Bh/JX2j17/Bzkev9QTxN94AayEjg1XbD1QJ/IEm9oTzzGDtIqGISo7cWvsW8FY07iXiUtuj18qlftl7t/DavInuzwp/NZjMZW+60yhQfRAO5bcHlRdKbdDMTqm3ggXOQEEznFRM3/MyeHXNF+x47Fqv9tw5K5k0qi/gHZyDbd8Wau471GVptd6KhMPYICVXtSU7O9vm5+fX+XMlvvTJXRlwpUGD95InwSbe+KvZzttYSNmo0Vy/abm7bVaf33Dwvkk8mNPJfY5v0PXluv/Miu3ZfGWmp7J2Ur8qzw4WpANtE6fla8UYs8Fam+3brhG51Fv+Rq8u/nLV8z/cU2XE7DcVs349OT17ep3XbsKbWNOAzO0H3G3+UjvgnHJfbq1XEL5nQYHf7+Dvt4rqygv1QlTCpUAu9VblyTGhrAFe7eJUJSXQsKHXZ5f/7m983rxN1XMJnNopt5Yvc6/yaotm7lsvRCVcWsZW6rWcrpmsndQPU/2ppBj/Z7VKT4WxY72C+D9/eSNZExd7BXH3uX7+HOgcl2hOrQ/nuSKgQC5xoroglupIYfhFZ1UJph327WLt5P7w5JOextJSms56tNrA6y84Axw/WVqlBDKaW7JpvRUJl1IrEhf85ctdLzwrb3ac3fZ0pr65hSM/nmDXTO+124bd9jeG3zaEnJSUkNY0cf156ptbKCr2LKZ16HiJ34qUaE2t13orEi5VrUidi7S0LtTr/nnZcEa897L7uPJGD5FUfwSqnvFXkSJSm1S1IvVCpGuNhBTEt26Fjh0ZUanpnPvyKE3x/G8eSfWHXj5KfadALnUqktK6aoN/eTmkeOeUrxnxGJta+c8phxuANRtT6ju97JQ6FcnodtqiLYE3L54+3TuIjxpF3sd7+axth4D3CzcAB3r52Pe8DPrkrqTdpCX0yV2pDZIlZjQilzoV7ug2b2Oh31172h76hvceHuPdeOIENGpETsWh77Zt4HxBWlhUTJ/clSHn5v29fOx7XobXLFItRyuxpJedUqfCnX7eZdpyr4oRrGX3I1d7nbN69utcOvpa/HHl1guLiv1O64+0RFAvQCUW9LJT6oVwS+sqB/ExHy7k/nefdx8vPu8S7rpmIo4vDE2mLedwcUmV+7lKAv0FXn+5+VArY/QCVOoTBXKpc8HqrSsH0lNTHQC0PryPNc/+zuu89n9YyE+ORgCUlFl3wA+U4ggl8IZTUaMXoFKfKJBLVERSG+57jW/euej4ST594tf87KQnYP5m2IP8X1aXoPf1N9IOJfCGU1ET6nK0InVBVStSY5Hs5OPvmnnrvnYHxpEbFrH7kavdQfydn/cga+LiaoO4i+8IPJRp7+GkS6I5JV+kpjQilxqLpDbc3zUWaHn0ez782y1e7R3ueZXjDcNLWfimOELJzYebLtFu91JfKJBLjQUaybrK/PwFTn/XfPj0SFoe+8F9PGroFFb9vIfXOZkBgm1lgVIc1QVepUskXim1IjUW7AVfoHRL5Wt+vWk5ux8e7A7i6866gKyJi6sEcVdQzQzyvJqkOJQukXilEbnUWLCdfCqrnG4ZP7A9j/7PatY88RuvczqNW8DRRk2qXJvpM6Kvra3QlC6ReKRALhHxrTi5vnum363WfLlSKjk39iPns8/c7bfl3M+y9r39XuM7yUbLvIp4UyCXsPmrt164obDaIA5w8+7/AzPYfbzlzH8nZ9QTlJQFvjaUSTb5X/2gwC5JS4FcwhaoSiXFmIDBPL34CAV/9U6jdBn7EkWpzXBYOC3N4XdNFaiag/f3g2Tuuq/dn2vdE0k2etkpfuVtLAy4sl+gEXKZtTgaVN03880Xx3kF8bFX30fWxMUUpTYDoKTcktbwFJ4Y1iWkLc4C7W5fmXt1RJEkoBG5VFHdVPVA9dbpqQ5+PFnqPh644//4e95DnhPOPZd21z6OvzH7N0XFIee+Q13PROueSLJQIE9y/qbWVzfBJ9D+mSVl5ZSUWZr+9COfPDHM+0HffQctW9IqwKqBrvRJKFUjgX6QBLqnSKKrUWrFGDPTGLPdGLPZGPMvY0x6tDomtS/Q1PpAQdJdcdI1k+u7Z1I5iWKBH0+W8fJLk7yC+PgrxtJu4mJo2dJ5HIUd4gPtbl+Te4rEs5rmyN8GLrDWdgY+AybXvEtSV4K9tPSn8gh31fYDXimSy75Yz+6HB9Nrz6cAfPez08mauJhXOw+ggTHuHHs0Jt34u8dve7XRRB5JWlHbWMIYcy0w1Fp7U3XnamOJ+qHdpCV+89XgHNFWDvKOFEOThqe41/x2jdrTThazddYNXtdedOcL7GvavMr9FFxFaibQxhLRrFoZDSyN4v2klgXKIbtGtK4R7mlpDrDOTR5cKRgDzH5tmlcQf+Dy2+gydRnfN8uock9VkYjUnmoDuTFmhTHmUz//XFPpnD8CpcC8IPcZY4zJN8bkHzhwIDq9lxoJlq/O6ZrJ2kn9+DL3KtIankJJuWfs3md3AV8+PJj+X6wH4EjDNLImLOLVXjlMHdKR8gC/5amKRKR2VFu1Yq29PNjnxphbgMFAfxskT2OtfQ54DpyplfC6KbUh3HK/RiU/sePx670+u27CS2w0zTgtzYG1cM+CgoDPc+34IyLRVdOqlUHABGCItfZ4dLokdSXUXX1apafy1BsPewXxhy4bRZ8Z7/D6w8OZNawLJ0rK3amXQD+lA7xDFZEaqmkd+VNAI+Bt4/xbus5ae3uNeyW1LuT9Kd9/n7WT+7sPyzGcPeFNUhuewoyK8r5QZloCAafgi0jN1CiQW2vPiVZHpG5Vu6vPTz9B48Zen//63v9h/Smnu5eUBfzuTh9IoLJGEakZrbWSpILuT/m733kH8WnTwFpeeXQEs4Y598wct6CAexYUhBzEgZBWRxSR8GmKfgIJlvP2/Szdz2qDXb7ZQd4/7/W+aVkZNGjgvkfldEy4YTnYzj4iEjkF8gQRLOcNVPnM0cDgSDGUlFlOKSvl80dzvG+4ZQt06ODVFGou3B8DmjIvUksUyBNEsJz3jz+VVvmspNySnurgT8ueYegHeZ4PJk6E3Fy/zwil
Djw91cFPpeVVFtS6qVcbzeoUqSUK5Aki2E72/nTYt4u3Xhjr3VhaCimBF6M6NdVBUXHgypNURwpTh3QEtA2bSF1SII8T1dV8+8t5AzQwUGlSJg3Ky9g18xrvkzZuhC5dqn12sCDuuzmyArdI3VEgjwOh1HwHKgipHMQnvfs8t3+40H38xY2j+Pn8OWE92x8DXpsji0jdUiCPA9XWfAOHg4yWz/n+a1bMvtOrrccDS1g/7UqvtlA3mfDlb09NpVZE6o4CeRwIWvNdwd+uOcaW8+UjQ7zarhnxGJ+17cCMnE5e7YFG/dUFcd8NHEKeMSoiUaMJQXEg0HKz6WmeRah8VzL8/ZqXvIL4ouxBtJu4mO87dPG7Lni4m0yAMy9+ffdMZi7b4d6kedqiLQF/exCR2qEReRwYP7A941/bREmZdyL82IlS8jYWeu1zOX/eOyx47Gav86HSagMAAAwwSURBVNrf+y+aN2/GrCApjkCj/jJrq2wy4dokAqrWpweiJWxFao8CeZhikf/N6ZrJ1De3VKkaKSm3njy5teR0a03laT03jZzJ2jPPB6pOEPL9DoE2NM6slCv3/c59cleGPEFIGyGL1B4F8jDEMv8b6GXmN0XFMHMmTJjgblvR+TJuveK+KucWl5QxbdEWTpSUV/kO13fPZOGGwiojb1fQ9vf9Qh1layNkkdqlQB6GUKpHaou/EXPrw/tY8+zvvNq6TMyjKMh/Vn+15sUlZazafoAZ13UK67eNQKP49FQHTRqdoqoVkTqiQB6GUKpHosU3hdP3vAzPiNlats4aSlrJT54LVqygz3pDUYR9+aaoOODIO5DxA9tXqWxxze5U4BapO6paCUOgPG+087+uFE5hUbF7s+OFGwq5vnsmv9/6v+x+5GpPEL/6audsoP79q/2BkupIIT3AdmuRfIecrplemzS7Nm1WEBepWxqRhyHQCDTa+V9/KZxmP+zjwWuv8Gr7xeQ8TjROpWjSElqlpwZdC6XyZhDR/A7hjuJFJPoUyMMQ6mbFoQpUAeM7sv7oqRG0+PGQ+/iWoVN49+c9oByoyHkXFhXjSDE4GhivHe9dpYL+6saVwxZJDCbIxve1Jjs72+bn59f5c+sTf2uYuILuzGU7KCwq5teblvPI//7V/fm6Np24cfiMoPdNczTgtCaNFKRFEpAxZoO1Ntu3XSPyOlR5BN7AmCpbn7kqYP6UfTpX/Kqr12c97nuNAynee2j6c7yknIcUvEWSigJ5lAVKl/iOwAPtXzn3kRG0O/SN+/j2nPv55KL+/HFge+59ZVNI+17WRTmkiNQfCuRRFGzCUHWrCF6zZRV/WfyYpyE7G9av59lK54xbUBBSPzQdXiS5KJBHUbAJQ4GCa3rxEQr++hvvxoMH4fTTvZryNhZiCG3DY02HF0kuqiOPomAThvwF10Uv/N4riOf/+UlnTbhPEAfnD4lQgrimw4skHwXyKAo2YajyMrODdqxl98OD6bTvC+cJ7duDtWTff1fAewdLl2hCjkhyU2olioJNGMrpmonj6GGu+mVH74v27YMWLaq9d7DVCbXNmkhyUyCvId8qleu7Z7Jq+4Gqddy//CVXrV7tuXD2bPK6DmTmnE9Dqvmuq1mlIhJ/FMgj4ArehUXFXi8gXWuieKU3liyBbq3d1xZntOTyP7xE4WfF8JmnCqWwqJjxr24C/C+JG+1ZpSKSOKIys9MYcy/wKJBhrf2+uvNjNbMzkk0hgq5CGEBmeipr7+oJTZt6tf/vsnzuef9A0GvTUx0UTBkQ3hcTkaQQaGZnjV92GmPOAgYAX9f0XrXJ34qCk1//hLyNhUGvGf/qJq9r5q77utpdcR787wneQfypp8Ba/uvjw9VeG2jRKxGRQKKRWpkFTADeiMK9ak0om0L4jr5/+PEnrwWoqtNndwHzFvzJ09CsGRQVQcUGxpqoIyK1oUaB3BhzDVBord1kguy2XnHuGGAMQJs2bWry2IhUtymEv1mZoWpccoLtjw/1bty9G9q29WoKVHlS2Wlp/tcLFxEJpNrUijFmhTHmUz//XAPcDzwQyoOstc9Za7OttdkZGRk17XfYqtsUorop9IE8lZfrFcQ/Hfcn56QenyAOeNWS++NIMUy5umPAz0VE/Kl2RG6tvdxfuzGmE9AOcI3GWwMfG2N6Wmu/i2ovo6C68r1w0x499nzKqy9N8jQYA2VlXBDkNxPfypNTUx0YA0XHS1SFIiIRizi1Yq39BHDPZDHG7AayQ6laiYXqyvdCSXsAnFJWyo7HriPFlnsad+6Ec84JuR+RBOtIKm5EJDkkVR15sCAaaMR+ffdMlmz+lkPHS/jtx0t48O1n3J9vu/1ezn/m0Vrvd7BVFRXMRSRqgdxamxWte8VCsBH7gx0bw7nnus9dc/4v+H7uK+RUmuhTm0KpuBGR5JVUI/LqVBmxl5ZC797wwQeetj17uLh13QRwl+oqbkQkuSXd6od5Gwvpk7uSdpOW0Cd3ZeAJQbNng8PhCeJz5zqrUeo4iEP1FTciktySKpCHNLvzyy+dFSi33uo8/tWvoKwMbropJn0G/2WLWjBLRFySKpAHyzVTXg79+8PZZ3s+3L0bli+HBrH915TTNZMZ13XSuuMi4ldS5cgD5ZR7rn0LUvp7GubMgVGj6qhXoYm0bFFEEl9SBXLfWvEzjxzgg2cqBew+feC99yAl8OxLEZH6JqlSK65cs7HlzHl1qncQ//xzWLNGQVxE4k5SBfKcrpm82PQrvnxkCP12OddDL7h/hrMa5ec/j3HvREQikzyplW+/hVat6Ok67t4dPviALg6tNigi8S3xR+TWwtCh0KqVp23bNsjPd9aJi4jEucQO5G+84SwdXLjQeTxrljOwn3debPslIhJFiZlaOXAAWrTwHHfoABs3QsOGseuTiEgtSawRubUwYoR3EP/kE9iyRUFcRBJW4gTypUudaZS5c53HubnOwH7BBbHtl4hILYv/1MrBg9C8uee4XTvYuhUaN45dn0RE6lB8j8jHjPEO4h9/DLt2KYiLSFKJz0C+cqVzhcL//m/n8dSpzjRK164x7ZaISCzEV2rlyBE44wznhg8A//Zv8MUXkJYW236JiMRQfI3IL7zQE8TXrXPO1lQQF5EkF18j8nfegQ8/hOHDY90TEZF6I74C+dlne2/8ICIi8RPI8zYW+t3hXkQk2cVFIHfttenaps211yagYC4iSS8uXnYG3WtTRCTJxUUgD7TXZqB2EZFkEheBvFV6aljtIiLJJC4CuWuvzcpSHSmMH9g+Rj0SEak/ahzIjTF3G2O2G2O2GGMeiUanfOV0zWTGdZ3ITE/FAJnpqcy4rpNedIqIUMOqFWNMX+Aa4EJr7U/GmBbVXROpnK6ZCtwiIn7UdER+B5Brrf0JwFq7v+ZdEhGRcNQ0kJ8LXGKM+dAY854xpkegE40xY4wx+caY/AMHDtTwsSIi4lJtasUYswL4Nz8f/bHi+tOBXkAP4BVjzNnWWut7srX2OeA5gOzs7Cqfi4hIZKoN5NbaywN9Zoy5A3i9InB/ZIwpB5oDGnKLiNSRmqZW8oC
+AMaYc4GGwPc17ZSIiITO+MmChH6xMQ2BOUAX4CRwn7V2ZQjXHQC+ivjB4WlOcv1wSbbvC8n3nfV9E1+g79zWWpvh21ijQB4PjDH51trsWPejriTb94Xk+876vokv3O8cFzM7RUQkMAVyEZE4lwyB/LlYd6COJdv3heT7zvq+iS+s75zwOXIRkUSXDCNyEZGEpkAuIhLnEj6QG2NmViyzu9kY8y9jTHqs+1TbjDE3VCwrXG6MSdiyLWPMIGPMDmPM58aYSbHuT20zxswxxuw3xnwa677UBWPMWcaYVcaYrRX/P/8+1n2qTcaYxsaYj4wxmyq+77RQr034QA68DVxgre0MfAZMjnF/6sKnwHXA6lh3pLYYY1KAp4ErgA7AcGNMh9j2qta9AAyKdSfqUClwr7W2A871nP4jwf8b/wT0s9ZeiHOS5SBjTK9QLkz4QG6tXW6tLa04XAe0jmV/6oK1dpu1NtF3pu4JfG6t3WWtPQm8jHNt/IRlrV0N/BDrftQVa+231tqPK/58FNgGJOymBNbpWMWho+KfkKpREj6Q+xgNLI11JyQqMoE9lY73ksB/yZOdMSYL6Ap8GNue1C5jTIoxpgDYD7xtrQ3p+9Zoh6D6IthSu9baNyrO+SPOX9Xm1WXfakso31kkERhjfgYsBMZZa4/Euj+1yVpbBnSpeJf3L2PMBdbaat+JJEQgD7bULoAx5hZgMNDf31rp8ai675wECoGzKh23rmiTBGKMceAM4vOsta/Huj91xVpbZIxZhfOdSLWBPOFTK8aYQcAEYIi19nis+yNRsx74d2NMu4pVOG8E3oxxnySKjDEGmA1ss9Y+Huv+1DZjTIarqs4Ykwr8CtgeyrUJH8iBp4CmwNvGmAJjzLOx7lBtM8Zca4zZC/wCWGKMWRbrPkVbxQvsu4BlOF+CvWKt3RLbXtUuY8x84AOgvTFmrzHmd7HuUy3rA4wA+lX83S0wxlwZ607VojOBVcaYzTgHKm9baxeHcqGm6IuIxLlkGJGLiCQ0BXIRkTinQC4iEucUyEVE4pwCuYhInFMgFxGJcwrkIiJx7v8DFIKP9D3NNnoAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light", - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "plt.scatter(xs, ys)\n", - "plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')\n", - "plt.legend()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4wFJcqbhbn81" - }, - "source": [ - "## Aside: hosts and devices in JAX\n", - "\n", - "When running on TPU, the idea of a 'host' becomes important. A host is the CPU that manages several devices. A single host can only manage so many devices (usually 8), so when running very large parallel programs, multiple hosts are needed, and some finesse is required to manage them." - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "id": "3DO8NwW5hurX", - "outputId": "6df0bdd7-fee2-4805-9bfe-38e41bdaeb50" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[TpuDevice(id=0, host_id=0, coords=(0,0,0), core_on_chip=0),\n", - " TpuDevice(id=1, host_id=0, coords=(0,0,0), core_on_chip=1),\n", - " TpuDevice(id=2, host_id=0, coords=(1,0,0), core_on_chip=0),\n", - " TpuDevice(id=3, host_id=0, coords=(1,0,0), core_on_chip=1),\n", - " TpuDevice(id=4, host_id=0, coords=(0,1,0), core_on_chip=0),\n", - " TpuDevice(id=5, host_id=0, coords=(0,1,0), core_on_chip=1),\n", - " TpuDevice(id=6, host_id=0, coords=(1,1,0), core_on_chip=0),\n", - " TpuDevice(id=7, host_id=0, coords=(1,1,0), core_on_chip=1)]" - ] - }, - "execution_count": 24, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "jax.devices()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sJwayfCoy15a" - }, - "source": [ - "When running on CPU you can always emulate an arbitrary number of devices with a nifty `--xla_force_host_platform_device_count` XLA flag, e.g. by executing the following before importing JAX:\n", - "```python\n", - "import os\n", - "os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'\n", - "jax.devices()\n", - "```\n", - "```\n", - "[CpuDevice(id=0),\n", - " CpuDevice(id=1),\n", - " CpuDevice(id=2),\n", - " CpuDevice(id=3),\n", - " CpuDevice(id=4),\n", - " CpuDevice(id=5),\n", - " CpuDevice(id=6),\n", - " CpuDevice(id=7)]\n", - "```\n", - "This is especially useful for debugging and testing locally or even for prototyping in Colab since a CPU runtime is faster to (re-)start." 
- ] - } - ], - "metadata": { - "accelerator": "TPU", - "colab": { - "name": "JAX Parallelism", - "provenance": [] - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/jax-101/06-parallelism.md b/docs/jax-101/06-parallelism.md deleted file mode 100644 index f69301bcb74c..000000000000 --- a/docs/jax-101/06-parallelism.md +++ /dev/null @@ -1,414 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.0 -kernelspec: - display_name: Python 3 - name: python3 ---- - -+++ {"id": "tCOWitsAS1EE"} - -# Parallel Evaluation in JAX - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/06-parallelism.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/06-parallelism.ipynb) - -*Authors: Vladimir Mikulik & Roman Ring* - -In this section we will discuss the facilities built into JAX for single-program, multiple-data (SPMD) code. - -SPMD refers to a parallelism technique where the same computation (e.g., the forward pass of a neural net) is run on different input data (e.g., different inputs in a batch) in parallel on different devices (e.g., several TPUs). - -Conceptually, this is not very different from vectorisation, where the same operations occur in parallel in different parts of memory on the same device. We have already seen that vectorisation is supported in JAX as a program transformation, `jax.vmap`. JAX supports device parallelism analogously, using `jax.pmap` to transform a function written for one device into a function that runs in parallel on multiple devices. This colab will teach you all about it. - -+++ {"id": "7mCgBzix2fd3"} - -## TPU Setup - -This notebook requires multiple accelerators and we recommend running it using Kaggle TPU VMs. - -+++ {"id": "gN6VbcdRTcdE"} - -Next run the following to see the TPU devices you have available: - -```{code-cell} ipython3 -:id: tqbpCcqY3Cn7 -:outputId: 1fb88cf7-35f7-4565-f370-51586213b988 - -import jax -jax.devices() -``` - -+++ {"id": "4_EDa0Dlgtf8"} - -## The basics - -The most basic use of `jax.pmap` is completely analogous to `jax.vmap`, so let's return to the convolution example from the [Vectorisation notebook](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/03-vectorization.ipynb). - -```{code-cell} ipython3 -:id: IIQKBr-CgtD2 -:outputId: 6e7f8755-fdfd-4cf9-e2b5-a10c5a870dd4 - -import numpy as np -import jax.numpy as jnp - -x = np.arange(5) -w = np.array([2., 3., 4.]) - -def convolve(x, w): - output = [] - for i in range(1, len(x)-1): - output.append(jnp.dot(x[i-1:i+2], w)) - return jnp.array(output) - -convolve(x, w) -``` - -+++ {"id": "lqxz9NNJOQ9Z"} - -Now, let's convert our `convolve` function into one that runs on entire batches of data. 
In anticipation of spreading the batch across several devices, we'll make the batch size equal to the number of devices: - -```{code-cell} ipython3 -:id: ll-hEa0jihzx -:outputId: 788be05a-10d4-4a05-8d9d-49d0083541ab - -n_devices = jax.local_device_count() -xs = np.arange(5 * n_devices).reshape(-1, 5) -ws = np.stack([w] * n_devices) - -xs -``` - -```{code-cell} ipython3 -:id: mi-nysDWYbn4 -:outputId: 2d115fc3-52f5-4a68-c3a7-115111a83657 - -ws -``` - -+++ {"id": "8kseIB09YWJw"} - -As before, we can vectorise using `jax.vmap`: - -```{code-cell} ipython3 -:id: TNb9HsFXYVOI -:outputId: 2e60e07a-6687-49ab-a455-60d2ec484363 - -jax.vmap(convolve)(xs, ws) -``` - -+++ {"id": "TDF1vzt_5GMC"} - -To spread out the computation across multiple devices, just replace `jax.vmap` with `jax.pmap`: - -```{code-cell} ipython3 -:id: KWoextrails4 -:outputId: bad1fbb7-226a-4538-e442-20ce0c1c8fad - -jax.pmap(convolve)(xs, ws) -``` - -+++ {"id": "E69cVxQPksxe"} - -Note that the parallelized `convolve` returns a `jax.Array`. That is because the elements of this array are sharded across all of the devices used in the parallelism. If we were to run another parallel computation, the elements would stay on their respective devices, without incurring cross-device communication costs. - -```{code-cell} ipython3 -:id: P9dUyk-ciquy -:outputId: 99ea4c6e-cff7-4611-e9e5-bf016fa9716c - -jax.pmap(convolve)(xs, jax.pmap(convolve)(xs, ws)) -``` - -+++ {"id": "iuHqht-OYqca"} - -The outputs of the inner `jax.pmap(convolve)` never left their devices when being fed into the outer `jax.pmap(convolve)`. - -+++ {"id": "vEFAJXN2q3dV"} - -## Specifying `in_axes` - -Like with `vmap`, we can use `in_axes` to specify whether an argument to the parallelized function should be broadcast (`None`), or whether it should be split along a given axis. Note, however, that unlike `vmap`, only the leading axis (`0`) is supported by `pmap` at the time of writing this guide. - -```{code-cell} ipython3 -:id: 6Es5WVuRlXnB -:outputId: 7e9612ae-d6e0-4d79-a228-f0403fcf8237 - -jax.pmap(convolve, in_axes=(0, None))(xs, w) -``` - -+++ {"id": "EoN6drHDOlk4"} - -Notice how we get equivalent output to what we observe above with `jax.pmap(convolve)(xs, ws)`, where we manually replicated `w` when creating `ws`. Here, it is replicated via broadcasting, by specifying it as `None` in `in_axes`. - -+++ {"id": "rRE8STSU5cjx"} - -Keep in mind that when calling the transformed function, the size of the specified axis in arguments must not exceed the number of devices available to the host. - -+++ {"id": "0lZnqImd7G6U"} - -## `pmap` and `jit` - -`jax.pmap` JIT-compiles the function given to it as part of its operation, so there is no need to additionally `jax.jit` it. - -+++ {"id": "1jZqk_2AwO4y"} - -## Communication between devices - -The above is enough to perform simple parallel operations, e.g. batching a simple MLP forward pass across several devices. However, sometimes we need to pass information between the devices. For example, perhaps we are interested in normalizing the output of each device so they sum to 1. -For that, we can use special [collective ops](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) (such as the `jax.lax.p*` ops `psum`, `pmean`, `pmax`, ...). In order to use the collective ops we must specify the name of the `pmap`-ed axis through `axis_name` argument, and then refer to it when calling the op. 
Here's how to do that: - -```{code-cell} ipython3 -:id: 0nCxGwqmtd3w -:outputId: 6f9c93b0-51ed-40c5-ca5a-eacbaf40e686 - -def normalized_convolution(x, w): - output = [] - for i in range(1, len(x)-1): - output.append(jnp.dot(x[i-1:i+2], w)) - output = jnp.array(output) - return output / jax.lax.psum(output, axis_name='p') - -jax.pmap(normalized_convolution, axis_name='p')(xs, ws) -``` - -+++ {"id": "9ENYsJS42YVK"} - -The `axis_name` is just a string label that allows collective operations like `jax.lax.psum` to refer to the axis bound by `jax.pmap`. It can be named anything you want -- in this case, `p`. This name is essentially invisible to anything but those functions, and those functions use it to know which axis to communicate across. - -`jax.vmap` also supports `axis_name`, which allows `jax.lax.p*` operations to be used in the vectorisation context in the same way they would be used in a `jax.pmap`: - -```{code-cell} ipython3 -:id: nT61xAYJUqCW -:outputId: e8831025-78a6-4a2b-a60a-3c77b35214ef - -jax.vmap(normalized_convolution, axis_name='p')(xs, ws) -``` - -+++ {"id": "JSK-9dbWWV2O"} - -Note that `normalized_convolution` will no longer work without being transformed by `jax.pmap` or `jax.vmap`, because `jax.lax.psum` expects there to be a named axis (`'p'`, in this case), and those two transformations are the only way to bind one. - -## Nesting `jax.pmap` and `jax.vmap` - -The reason we specify `axis_name` as a string is so we can use collective operations when nesting `jax.pmap` and `jax.vmap`. For example: - -```python -jax.vmap(jax.pmap(f, axis_name='i'), axis_name='j') -``` - -A `jax.lax.psum(..., axis_name='i')` in `f` would refer only to the pmapped axis, since they share the `axis_name`. - -In general, `jax.pmap` and `jax.vmap` can be nested in any order, and with themselves (so you can have a `pmap` within another `pmap`, for instance). - -+++ {"id": "WzQHxnHkCxej"} - -## Example - -Here's an example of a regression training loop with data parallelism, where each batch is split into sub-batches which are evaluated on separate devices. - -There are two places to pay attention to: -* the `update()` function -* the replication of parameters and splitting of data across devices. - -If this example is too confusing, you can find the same example, but without parallelism, in the next notebook, [State in JAX](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). Once that example makes sense, you can compare the differences to understand how parallelism changes the picture. - -```{code-cell} ipython3 -:id: cI8xQqzRrc-4 - -from typing import NamedTuple -import functools - -class Params(NamedTuple): - weight: jnp.ndarray - bias: jnp.ndarray - - -def init(rng) -> Params: - """Returns the initial model params.""" - weights_key, bias_key = jax.random.split(rng) - weight = jax.random.normal(weights_key, ()) - bias = jax.random.normal(bias_key, ()) - return Params(weight, bias) - - -def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray: - """Computes the least squares error of the model's predictions on x against y.""" - pred = params.weight * xs + params.bias - return jnp.mean((pred - ys) ** 2) - -LEARNING_RATE = 0.005 - -# So far, the code is identical to the single-device case. Here's what's new: - - -# Remember that the `axis_name` is just an arbitrary string label used -# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it -# 'num_devices', but could have used anything, so long as `pmean` used the same. 
-@functools.partial(jax.pmap, axis_name='num_devices') -def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> tuple[Params, jnp.ndarray]: - """Performs one SGD update step on params using the given data.""" - - # Compute the gradients on the given minibatch (individually on each device). - loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys) - - # Combine the gradient across all devices (by taking their mean). - grads = jax.lax.pmean(grads, axis_name='num_devices') - - # Also combine the loss. Unnecessary for the update, but useful for logging. - loss = jax.lax.pmean(loss, axis_name='num_devices') - - # Each device performs its own update, but since we start with the same params - # and synchronise gradients, the params stay in sync. - new_params = jax.tree_map( - lambda param, g: param - g * LEARNING_RATE, params, grads) - - return new_params, loss -``` - -+++ {"id": "RWce8YZ4Pcmf"} - -Here's how `update()` works: - -Undecorated and without the `pmean`s, `update()` takes data tensors of shape `[batch, ...]`, computes the loss function on that batch and evaluates its gradients. - -We want to spread the `batch` dimension across all available devices. To do that, we add a new axis using `pmap`. The arguments to the decorated `update()` thus need to have shape `[num_devices, batch_per_device, ...]`. So, to call the new `update()`, we'll need to reshape data batches so that what used to be `batch` is reshaped to `[num_devices, batch_per_device]`. That's what `split()` does below. Additionally, we'll need to replicate our model parameters, adding the `num_devices` axis. This reshaping is how a pmapped function knows which devices to send which data. - -At some point during the update step, we need to combine the gradients computed by each device -- otherwise, the updates performed by each device would be different. That's why we use `jax.lax.pmean` to compute the mean across the `num_devices` axis, giving us the average gradient of the batch. That average gradient is what we use to compute the update. - -Aside on naming: here, we use `num_devices` for the `axis_name` for didactic clarity while introducing `jax.pmap`. However, in some sense that is tautologous: any axis introduced by a pmap will represent a number of devices. Therefore, it's common to see the axis be named something semantically meaningful, like `batch`, `data` (signifying data parallelism) or `model` (signifying model parallelism). - -```{code-cell} ipython3 -:id: _CTtLrsQ-0kK - -# Generate true data from y = w*x + b + noise -true_w, true_b = 2, -1 -xs = np.random.normal(size=(128, 1)) -noise = 0.5 * np.random.normal(size=(128, 1)) -ys = xs * true_w + true_b + noise - -# Initialise parameters and replicate across devices. -params = init(jax.random.PRNGKey(123)) -n_devices = jax.local_device_count() -replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params) -``` - -+++ {"id": "dmCMyLP9SV99"} - -So far, we've just constructed arrays with an additional leading dimension. The params are all still on the host (CPU). `pmap` will communicate them to the devices when `update()` is first called, and each copy will stay on its own device subsequently. - -```{code-cell} ipython3 -:id: YSCgHguTSdGW -:outputId: a8bf28df-3747-4d49-e340-b7696cf0c27d - -type(replicated_params.weight) -``` - -+++ {"id": "90VtjPbeY-hD"} - -The params will become a jax.Array when they are returned by our pmapped `update()` (see further down). 
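As a quick sanity check (a sketch, assuming the `replicated_params` defined above and a host with 8 devices), you can confirm that replication only added a leading `num_devices` axis to each parameter leaf, matching the leading device axis the pmapped `update()` expects on every argument:

```python
# Sketch: inspect the leading device axis added by replication.
# Assumes `replicated_params` from the cell above; shapes shown are for 8 devices.
print(jax.tree_map(lambda x: x.shape, replicated_params))
# e.g. Params(weight=(8,), bias=(8,))
```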
- -+++ {"id": "eGVKxk1CV-m1"} - -We do the same to the data: - -```{code-cell} ipython3 -:id: vY61QJoFWCII -:outputId: f436a15f-db97-44cc-df33-bbb4ff222987 - -def split(arr): - """Splits the first axis of `arr` evenly across the number of devices.""" - return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:]) - -# Reshape xs and ys for the pmapped `update()`. -x_split = split(xs) -y_split = split(ys) - -type(x_split) -``` - -+++ {"id": "RzfJ-oK5WERq"} - -The data is just a reshaped vanilla NumPy array. Hence, it cannot be anywhere but on the host, as NumPy runs on CPU only. Since we never modify it, it will get sent to the device at each `update` call, like in a real pipeline where data is typically streamed from CPU to the device at each step. - -```{code-cell} ipython3 -:id: atOTi7EeSQw- -:outputId: c8daf141-63c4-481f-afa5-684c5f7b698d - -def type_after_update(name, obj): - print(f"after first `update()`, `{name}` is a", type(obj)) - -# Actual training loop. -for i in range(1000): - - # This is where the params and data gets communicated to devices: - replicated_params, loss = update(replicated_params, x_split, y_split) - - # The returned `replicated_params` and `loss` are now both jax.Arrays, - # indicating that they're on the devices. - # `x_split`, of course, remains a NumPy array on the host. - if i == 0: - type_after_update('replicated_params.weight', replicated_params.weight) - type_after_update('loss', loss) - type_after_update('x_split', x_split) - - if i % 100 == 0: - # Note that loss is actually an array of shape [num_devices], with identical - # entries, because each device returns its copy of the loss. - # So, we take the first element to print it. - print(f"Step {i:3d}, loss: {loss[0]:.3f}") - - -# Plot results. - -# Like the loss, the leaves of params have an extra leading dimension, -# so we take the params from the first device. -params = jax.device_get(jax.tree_map(lambda x: x[0], replicated_params)) -``` - -```{code-cell} ipython3 -:id: rvVCACv9UZcF -:outputId: 5c472d0f-1236-401b-be55-86e3dc43875d - -import matplotlib.pyplot as plt -plt.scatter(xs, ys) -plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction') -plt.legend() -plt.show() -``` - -+++ {"id": "4wFJcqbhbn81"} - -## Aside: hosts and devices in JAX - -When running on TPU, the idea of a 'host' becomes important. A host is the CPU that manages several devices. A single host can only manage so many devices (usually 8), so when running very large parallel programs, multiple hosts are needed, and some finesse is required to manage them. - -```{code-cell} ipython3 -:id: 3DO8NwW5hurX -:outputId: 6df0bdd7-fee2-4805-9bfe-38e41bdaeb50 - -jax.devices() -``` - -+++ {"id": "sJwayfCoy15a"} - -When running on CPU you can always emulate an arbitrary number of devices with a nifty `--xla_force_host_platform_device_count` XLA flag, e.g. by executing the following before importing JAX: -```python -import os -os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' -jax.devices() -``` -``` -[CpuDevice(id=0), - CpuDevice(id=1), - CpuDevice(id=2), - CpuDevice(id=3), - CpuDevice(id=4), - CpuDevice(id=5), - CpuDevice(id=6), - CpuDevice(id=7)] -``` -This is especially useful for debugging and testing locally or even for prototyping in Colab since a CPU runtime is faster to (re-)start. 
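For instance, here is a minimal runnable sketch (assuming only this tutorial's `convolve` example and 8 emulated devices, nothing else) of trying the earlier `jax.pmap` example on a CPU-only machine with that flag:

```python
# Sketch: emulate 8 devices on a CPU-only machine and run the pmap example
# from this tutorial locally. XLA_FLAGS must be set before JAX is imported.
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np

def convolve(x, w):
  output = []
  for i in range(1, len(x) - 1):
    output.append(jnp.dot(x[i - 1:i + 2], w))
  return jnp.array(output)

n_devices = jax.local_device_count()   # 8 emulated CpuDevices
xs = np.arange(5 * n_devices).reshape(-1, 5)
ws = np.stack([np.array([2., 3., 4.])] * n_devices)
print(jax.pmap(convolve)(xs, ws))      # one sub-batch per emulated device
```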
diff --git a/docs/jax-101/07-state.ipynb b/docs/jax-101/07-state.ipynb deleted file mode 100644 index c7d75abf66b3..000000000000 --- a/docs/jax-101/07-state.ipynb +++ /dev/null @@ -1,420 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "Ga0xSM8xhBIm" - }, - "source": [ - "# Stateful Computations in JAX\n", - "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/07-state.ipynb)\n", - "\n", - "*Authors: Vladimir Mikulik*\n", - "\n", - "This section explores how JAX constrains the implementation of stateful programs." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Avjnyrjojo8z" - }, - "source": [ - "## Motivation\n", - "\n", - "In machine learning, program state most often comes in the form of:\n", - "* model parameters,\n", - "* optimizer state, and\n", - "* stateful layers, such as [BatchNorm](https://en.wikipedia.org/wiki/Batch_normalization).\n", - "\n", - "Some JAX transformations, most notably `jax.jit`, impose constraints on the functions they transform. In particular, the function transformed by `jax.jit` must have no side-effects. This is because any such side-effects will only be executed once, when the python version of the function is run during compilation. These side-effects will not be executed by the compiled function on subsequent runs.\n", - "\n", - "Changing program state is one kind of side-effect. So, if we can't have side effects, how do we update model parameters, the optimizer state, and use stateful layers in our models? This colab will explain this in detail, but the short answer is: with [functional programming](https://en.wikipedia.org/wiki/Functional_programming)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "s_-6semKkSzp" - }, - "source": [ - "## A simple example: Counter\n", - "\n", - "Let's start by looking at a simple stateful program: a counter." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "B3aoCHpjg8gm", - "outputId": "5cbcfbf5-5c42-498f-a175-050438518337" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1\n", - "2\n", - "3\n" - ] - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "class Counter:\n", - " \"\"\"A simple counter.\"\"\"\n", - "\n", - " def __init__(self):\n", - " self.n = 0\n", - "\n", - " def count(self) -> int:\n", - " \"\"\"Increments the counter and returns the new value.\"\"\"\n", - " self.n += 1\n", - " return self.n\n", - "\n", - " def reset(self):\n", - " \"\"\"Resets the counter to zero.\"\"\"\n", - " self.n = 0\n", - "\n", - "\n", - "counter = Counter()\n", - "\n", - "for _ in range(3):\n", - " print(counter.count())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SQ-RNLfdiw04" - }, - "source": [ - "The `n` attribute maintains the counter's _state_ between successive calls of `count`. It is modified as a side effect of calling `count`.\n", - "\n", - "Let's say we want to count fast, so we `jax.jit` the `count` method. (In this example, this wouldn't actually help speed anyway, for many reasons, but treat this as a toy model of wanting to JIT-compile the update of model parameters, where `jax.jit` makes an enormous difference)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "5jSjmJMon03W", - "outputId": "d952f16b-9b30-4753-ed94-cc914a929a36" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1\n", - "1\n", - "1\n" - ] - } - ], - "source": [ - "counter.reset()\n", - "fast_count = jax.jit(counter.count)\n", - "\n", - "for _ in range(3):\n", - " print(fast_count())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "weiI0V7_pKGv" - }, - "source": [ - "Oh no! Our counter isn't working. This is because the line\n", - "```\n", - "self.n += 1\n", - "```\n", - "in `count` is only called once, when JAX compiles the method call. Moreover, since the return value doesn't depend on the arguments to `count`, once it returns the first 1, subsequent calls to `fast_count` will always return 1. This won't do. So, how do we fix it?\n", - "\n", - "## The solution: explicit state\n", - "\n", - "Part of the problem with our counter was that the returned value didn't depend on the arguments, meaning a constant was \"baked into\" the compiled output. But it shouldn't be a constant -- it should depend on the state. Well, then why don't we make the state into an argument?" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "53pSdK4KoOEZ", - "outputId": "5ac72b9c-7029-4bf2-de8d-1d412bd74c79" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1\n", - "2\n", - "3\n" - ] - } - ], - "source": [ - "CounterState = int\n", - "\n", - "class CounterV2:\n", - "\n", - " def count(self, n: CounterState) -> tuple[int, CounterState]:\n", - " # You could just return n+1, but here we separate its role as \n", - " # the output and as the counter state for didactic purposes.\n", - " return n+1, n+1\n", - "\n", - " def reset(self) -> CounterState:\n", - " return 0\n", - "\n", - "counter = CounterV2()\n", - "state = counter.reset()\n", - "\n", - "for _ in range(3):\n", - " value, state = counter.count(state)\n", - " print(value)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PrBjmgZtq89b" - }, - "source": [ - "In this new version of `Counter`, we moved `n` to be an argument of `count`, and added another return value that represents the new, updated, state. To use this counter, we now need to keep track of the state explicitly. But in return, we can now safely `jax.jit` this counter:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "LO4Xzcq_q8PH", - "outputId": "25c06a56-f2bf-4c54-a3c3-6e093d484362" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1\n", - "2\n", - "3\n" - ] - } - ], - "source": [ - "state = counter.reset()\n", - "fast_count = jax.jit(counter.count)\n", - "\n", - "for _ in range(3):\n", - " value, state = fast_count(state)\n", - " print(value)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MzMSWD2_sgnh" - }, - "source": [ - "## A general strategy\n", - "\n", - "We can apply the same process to any stateful method to convert it into a stateless one. 
We took a class of the form\n", - "\n", - "```python\n", - "class StatefulClass\n", - "\n", - " state: State\n", - "\n", - " def stateful_method(*args, **kwargs) -> Output:\n", - "```\n", - "\n", - "and turned it into a class of the form\n", - "\n", - "```python\n", - "class StatelessClass\n", - "\n", - " def stateless_method(state: State, *args, **kwargs) -> (Output, State):\n", - "```\n", - "\n", - "This is a common [functional programming](https://en.wikipedia.org/wiki/Functional_programming) pattern, and, essentially, is the way that state is handled in all JAX programs.\n", - "\n", - "Notice that the need for a class becomes less clear once we have rewritten it this way. We could just keep `stateless_method`, since the class is no longer doing any work. This is because, like the strategy we just applied, object-oriented programming (OOP) is a way to help programmers understand program state. \n", - "\n", - "In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?\n", - "\n", - "Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I2SqRx14_z98" - }, - "source": [ - "## Simple worked example: Linear Regression\n", - "\n", - "Let's apply this strategy to a simple machine learning model: linear regression via gradient descent.\n", - "\n", - "Here, we only deal with one kind of state: the model parameters. But generally, you'll see many kinds of state being threaded in and out of JAX functions, like optimizer state, layer statistics for batchnorm, and others.\n", - "\n", - "The function to look at carefully is `update`." 
- ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "wQdU7DoAseW6" - }, - "outputs": [], - "source": [ - "from typing import NamedTuple\n", - "\n", - "class Params(NamedTuple):\n", - " weight: jnp.ndarray\n", - " bias: jnp.ndarray\n", - "\n", - "\n", - "def init(rng) -> Params:\n", - " \"\"\"Returns the initial model params.\"\"\"\n", - " weights_key, bias_key = jax.random.split(rng)\n", - " weight = jax.random.normal(weights_key, ())\n", - " bias = jax.random.normal(bias_key, ())\n", - " return Params(weight, bias)\n", - "\n", - "\n", - "def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:\n", - " \"\"\"Computes the least squares error of the model's predictions on x against y.\"\"\"\n", - " pred = params.weight * x + params.bias\n", - " return jnp.mean((pred - y) ** 2)\n", - "\n", - "\n", - "LEARNING_RATE = 0.005\n", - "\n", - "@jax.jit\n", - "def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:\n", - " \"\"\"Performs one SGD update step on params using the given data.\"\"\"\n", - " grad = jax.grad(loss)(params, x, y)\n", - "\n", - " # If we were using Adam or another stateful optimizer,\n", - " # we would also do something like\n", - " # ```\n", - " # updates, new_optimizer_state = optimizer(grad, optimizer_state)\n", - " # ```\n", - " # and then use `updates` instead of `grad` to actually update the params.\n", - " # (And we'd include `new_optimizer_state` in the output, naturally.)\n", - "\n", - " new_params = jax.tree_map(\n", - " lambda param, g: param - g * LEARNING_RATE, params, grad)\n", - "\n", - " return new_params" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dKySWouu2-Hu" - }, - "source": [ - "Notice that we manually pipe the params in and out of the update function." 
- ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "jQCYYy0yxO6K", - "outputId": "1f3b69d2-e90b-4065-cbcc-6422978d25c2" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deYDV8/7H8ednplHTOtGEJiouJVqGKWm7t3LrIhlxyZVsCdm5UZaSm1vUvdxf1pCtKJRBcVsUKaJpFUm4baO0GUqTZvn8/jhzzsw58z3LzDnTWXo9/uG7nO/3O9L7fOb9fX/eH2OtRURE4ldStB9ARETCo0AuIhLnFMhFROKcArmISJxTIBcRiXM1onHTRo0a2ebNm0fj1iIicWv58uW7rLXpvvujEsibN29Obm5uNG4tIhK3jDGbnPYrtSIiEucUyEVE4pwCuYhInItKjtxJYWEhW7du5cCBA9F+FImgWrVq0bRpU1JSUqL9KCIJK2YC+datW6lXrx7NmzfHGBPtx5EIsNaye/dutm7dSosWLaL9OCIJK2YC+YEDBxTEE4wxhqOOOoqdO3dG+1FEoi5nZR7j56znx/wCmqSlMqxPS7IzMyJy7ZgJ5ICCeALSn6mIK4iPmPklBYXFAOTlFzBi5pcAEQnmetkpIlLNxs9Z7wnibgWFxYyfsz4i11cgL8cYw8CBAz3bRUVFpKen07dv30pdp3nz5uzatatK5zRv3pw2bdrQtm1bevfuzfbt2yt17/IefPBBJkyYAMDIkSOZP3++33NXrVrF+++/79l+9913GTduXJXvLSJlfswvqNT+ylIgL6dOnTqsXbuWggLXf9x58+aRkRGZHFZlLFy4kDVr1pCVlcU///lPr2PWWkpKSip9zYceeoizzz7b73HfQN6vXz+GDx9e6fuISEVN0lIrtb+yFMh9nHvuucyePRuA119/ncsuu8xzbM+ePWRnZ9O2bVs6derEmjVrANi9eze9e/fm1FNPZfDgwZRfdWnKlCl07NiR9u3bc/3111Nc7P3rVSDdu3fnu+++Y+PGjbRs2ZJBgwZx2mmnsWXLFsaPH0+HDh1o27Yto0aN8nzm4Ycf5uSTT6Zr166sX1/2a9tVV13FW2+9BcCyZcvo3Lkz7dq1o2PHjvzyyy+MHDmS6dOn0759e6ZPn85LL73EzTffDMDGjRvp2bMnbdu2pVevXmzevNlzzVtvvZXOnTtzwgkneK4vIt6G9WlJakqy177UlGSG9WkZkevH1MtOj9tvh1WrInvN9u3h8ceDnjZgwAAeeugh+vbty5o1a7jmmmv45JNPABg1ahSZmZnk5OSwYMECBg0axKpVqxg9ejRdu3Zl5MiRzJ49mxdeeAGAdevWMX36dJYsWUJKSgpDhw5l6tSpDBo0KKRHnjVrFm3atAFgw4YNvPzyy3Tq1Im5c+eyYcMGvvjiC6y19OvXj0WLFlGnTh2mTZvGqlWrKCoq4vTTT+eMM87wuubBgwe59NJLmT59Oh06dODXX3+ldu3aPPTQQ+Tm5vLEE08A8NJLL3k+c8stt3DllVdy5ZVXMnnyZG699VZycnIA2LZtG4sXL+abb76hX79+XHzxxSH9bCKHE/cLzcOiaiUWtG3blo0bN/L6669z7rnneh1bvHgxM2bMAKBnz57s3r2bX3/9lUWLFjFz5kwAzjvvPBo2bAjAhx9+yPLly+nQoQMABQUFNG7cOOgz9OjRg+TkZNq2bcuYMWPIz8+nWbNmdOrUCYC5c+cyd+5cMjMzAdi3bx8bNmxg7969XHjhhdSuXRtwpUd8rV+/nmOPPdbzTPXr1w/6PJ999pnn57viiiu4++67Pceys7NJSkqidevW/PTTT0GvJXK4ys7MiFjg9hWbgTyEkXN16tevH3//+9/56KOP2L17d5WvY63lyiuvZOzYsZX63MKFC2nUqJFnOz8/nzp16nhdd8SIEVx//fVen3s8Cv/datas6fl3LeQtEh3KkTu45pprGDVqlCet4datWzemTp0KwEcffUSjRo2oX78+3bt357XXXgPggw8+4OeffwagV69evPXWW+zYsQNw5dg3bXLsQlkpffr0YfLkyezbtw+AvLw8duzYQffu3cnJyaGgoIC9e/fy3nvvVfhsy5Yt2bZtG8uWLQNg7969FBUVUa9ePfbu3et4v86dOzNt2jQApk6dSrdu3cL+GUQkcmJzRB5lTZs25dZbb62w/8EHH+Saa66hbdu21K5dm5dffhlw5c4vu+wyTj31VDp37szxxx8PQOvWrRkzZgy9e/empKSElJQUnnzySZo1axbW8/Xu3Zt169Zx1llnAVC3bl2mTJnC6aefzqWXXkq7du1o3LixJ31S3hFHHMH06dO55ZZbKCgoIDU1lfnz59OjRw/GjRtH+/btGTFihNdnJk6cyNVXX8348eNJT0/nxRdfDOv5RSSyTDR+Hc7KyrK+C0usW7eOU0455ZA/i1Q//dmKRIYxZrm1Nst3v1IrIiJxToFcRCTOxVQgV9VD4tGfqUj1i5lAXqtWLXbv3q2/+AnE3Y+8Vq1a0X4UkYQWsaoVY0wykAvkWWsr12UKV6XI1q1b1bs6wbhXCBKR6hPJ8sPbgHVA8KmCDlJSUrSKjIgkHPeCEufOnUq/75eyacoM+nZrFdF7RCSQG2OaAucBDwN3RuKaIiLxLmdlHlOemMGSybd59l3/9iqK6taL6HT9SI3IHwfuBur5O8EYMwQYAngmzIiIJKy9e+nZ5RSyC8pmTHe46VV2HlGP8XPWRzSQh/2y0xjTF9hhrV0e6Dxr7SRrbZa1Nis9PT3c24qIxCZr4YYboH596pcG8SsueYjm98xiZ11XQ71ILSjhFokReRegnzHmXKAWUN8YM8VaOzDI50REEst770G5rqPTulzE8K5XVzgtUgtKuIU9IrfWjrDWNrXWNgcGAAsUxEXksJKXB8aUBfGmTWHfPmpN/E+1LijhFjN15CIicae4GHr2dAVut9WrYcsWqFOH7MwMxvZvQ0ZaKgbISEtlbP82Ee9LHjNNs0RE4soTT8Att5RtP/kkDB1arbf01zRLbWxFRAJw14G7l2j7R4tieg7oXXZCr14wZw4kJ/u/SDVTIBcR8SNnZR4jZn5JQWExqQcPMPPhQRy9b0/ZCVu3Qkb1LN9WGcqRi4j4MX7OegoKixk5fxLrHrvYE8SHDRrjKjOMgSAOGpGLiPj1hxWLWfLmKM/2K5nnMbL3jRhgfPQeqwIFchERX9u3w7HH8nLp5u7U+nS//nl+q1kbiHwdeLgUyEVE3EpKoG9f+OADz66Lrp3I8kZlDf2qow48XMqRi4
gAPPecq/LEHcT//W+wlituurDa68DDpRG5iBzevv4aTj21bLtLF/joI6jhCo/ZmRkxF7h9KZCLyOGpoMAVwP/3v7J9GzdCs2ZRe6SqUmpFRA4/w4dD7dplQXzGDFc5YRwGcdCIXETikO9sy2F9WoaW/liwwDUT0+2aa+D5510Nr+KYArmIxJXysy0B8vILGDHzSwD/wXznTmjcuGy7bl3XrMwGDar7cQ8JpVZEJK64Z1uWV1BYzPg56yuebC307+8dxJcuhb17Kx3Ec1bm0WXcAloMn02XcQvIWZlXlcevFgrkIhJX/K2uU2H/K69AUhK8/bZre+xYV2A/88xK39P9W0BefgGWst8CYiWYK7UiInGlSVoqeQ7B3DPb8ttvoWW5CTunnw6ffQZHHFHlewb6LSAWShM1IheRuDKsT0vHVXfu6dEcWrf2DuLffw/Ll4cVxKESvwVEiQK5iMQVp1V33t4xl36dToR161wnvf66K41ywgkRuae/3iqx0nNFqRURiTue2ZaLF0O3bmUH/vY3mDIl4uWEw/q09KqUgdjquaJALiLxZ88eSE93NbkCV4+UHTvgyCOr5XbuPHiVatcPAQVyEYkf1sLll7tSJ26ffAJdu1b7rWO554oCuYjEjIAzNqdPhwEDyk4eNQoefDAqzxlrFMhFJCb4m7FZe+smevfrUnZi69awYgXUrBmlJ409CuQiEhN8a7VTigt586XbOG3M92UnrV8PJ58chaeLbSo/FJGYUL4m+8alb7JhwoWc9lNpEH/5ZVd+XEHckUbkIhITmqSlkv71KnJevcuz74OTOzPmqodYMqhXgE+KArlIAqlye9do++UXPh7dlxoHykblp98ylYIGRzL2L62i+GDxQYFcJEFUqb1rkOtV+5eCtXDttfDii55gdMvgCcw6qhVN0lIZGS9fRFEWdiA3xhwHvAIcDVhgkrX2P+FeV0QqJ5KNnSL9peDo7bddLWbdhg+HsWOZCEyMzB0OG5EYkRcBd1lrVxhj6gHLjTHzrLVfR+DaIhKiSDZ2qtZuf5s3ey+pdsIJsHYtpMZG35J4FHbVirV2m7V2Rem/7wXWAfpdSOQQi2Rjp2rp9ldUBJ07ewfxr75ydShUEA9LRMsPjTHNgUzgc4djQ4wxucaY3J07d0bytiKC//auVWnsFPFuf489Bikprr7gAJMmufLjrVtX7XriJWKB3BhTF5gB3G6t/dX3uLV2krU2y1qblZ6eHqnbikgpp/auY/u3qVIqJNQvhWDLny187b+uToR33gnA9q49obgYrruu0s8k/kWkasUYk4IriE+11s6MxDVFpPIi1dgplG5/AV+I/qE+B487nh6/5HvO73DTq+xr2Iixq7epEiXCIlG1YoAXgHXW2n+H/0giEgucvhTKlyQmGUOxtV7HCwqLKblxKHz+Lu41eQb9dTSLTjjDtRFDy6MlkkiMyLsAVwBfGmNWle6711r7fgSuLSLVqDK14r4jcN8g3uP7Zbz41mjP9gtZF/CPXhVTKE7rbUp4wg7k1trFQGSX4xCRalfZWnGnkkSAo/fu4vOnrvJsb697JMvmLGXyoi3gELRN6b01Ko8cNc0SSTDBXkC6BaoVd+JbephUUszUafd6BfFzrv4/Ot30CuMWbWFYn5aOIzxbem+JHAVykQTiHmXn5RdgKRtlOwXzytaKly89HLhiNj+Mv4Aum9YA8MCfb6D5PbNY1/gEzzWyMzOwjleKndXnE4UCuUgCqcwou7K14sP6tKTdns1sfKQvY+Y9DcCnx7flhGHv8OrpfR2vkRHjq88nCgVykQRSmVF2KLXiOSvzaD96LqfcOYNO3dvyznNDPccuGD6N9/8zhZo1j/B7jUhOUhL/1P1QJIE0SUt1rApxGgEHqxXPWZnHsDdXc8+8SQzOfcfzuSF/HcW5I67jndLzspod6fcasb76fKIw1vrLYlWfrKwsm5ube8jvK5LofCtRwDUCrsoMzzuueZTHXrzHs/1q5rk80Ns1Is9IS2XJ8J6ReWgJmTFmubU2y3e/RuQiCaB8PXha7RRq1kjil4JCmqSl0qNVOuPnrOeO6atCGxFv3w7HHstjpZs/16pH1xte4LeatT2nqBY8tiiQi8Q531H4z/sLSU1J5rFL2wM41ornbtrDwm92eqc72h0L558P75fN5et75eOsPeYPFe6pWvDYokAuEueCVao4HZu6dLOnNDAvv4DcByaQPfvxspP+9S9yelzK+jdXQ0nF9Ku7FtwpkMftcnNxTIFcJIZUJQhWpXe4OzSfuGsLH75wY9mBs86CRYugRg2yS3fdPn2V78f9Xv+QrCwkFSiQi8SIqgZBf5UqScZQP7UGP+8vrHCsZuHvzJ18E83yt3v2dblhMjRrxo/3z/H6Ehk/Z33IlTDVurKQ+KU6cpEYUdkp825Otdrgamq170ARKcneE+WHffwy6/99kSeI35A9gub3zOLHBo29ZoTeMX0VzYfPZv/BIlKSvK/hrxa8WlYWkqA0IheJEVUNgu6R7l1vrK7QkbCwxJKWmkKdmjU4bs3nTHv9Xs+xN9qczd3n3AbGYKDCdHr39s/7C0lJNqSlpngqYfylfCpTxy6Ro0AuEiPCCYLZmRnc4SeXbXbvYsnEy8t2pKYy+7+5/OfT7ZjSXHywcsLCYkudmjVYNap3wPOG9WnpWMeumZzVS4FcJEZUNgj6vhhtkJpCfkG5fLi1PJ0zlnO+/dSza8iNE5lXvwVNPt3uNaruMm5B0GDu+5tBoBezqlo5tBTIRWJEZYKg04vR8rnwC9cu4LHZZQt2Tex5FU91vtRvPXlefoFjeqW88r8ZBHsxq8B9aCmQixwCoZYVhhoEnV6MFhZbmu/J46PnrvfsW3v0iVx4xQQKk1MgSD25BU8w9w3qvr8ZqDoltiiQi1SzqpYVBgr+vmmOI4oKee/l22i5a7NnX/chz7G54bEBn83pBWdG6b0CffGoOiW2KJCLVLOqjF6DBf/yLyhvXfI6dy6e6vnsXdn3MKNlN892akoytVKSHOvJnbgXhahK7bqqU6JDdeQi1awqo9dgNeXD+rSk8/Zv2PhIX08Qn3Xqn8hZvoVuI28hIy0Vg2t0PbZ/G0adf2qFWnN/C+2GEozVZzy2aEQuUs2qMnr1F+Tz8guY/fFXZJ/dnuyiIgBKMJx3/wyu79+xQh9wX+XTJT1apTNjeV6VSgVVnRJbFMhFqllVaqsdg7+1PD5rAuc98nHZvkWLSOrWjQ9CeA6ndEmgRSGqcj2JDgVykQjx93KyKqNX3+B/3rpPePLdRzzHX+w5kKs/fDXsZ1YwTgwK5CIREEpddWUCZnZmBrmb9vDRf79g8bODPfs3HHUc5131fxTWSOHqyP4IEscUyEXClLMyjzumr6pQyhdOXfU7X2xkwNCLGLNtg2dfz8HP8MNRTQH/q9PL4UmBXCQMOSvzuNMhiLtVqa76kUe4YPhwz+Zd597BjDa9PNuqDhFfCuQiYXjw3a8oCXA81LrqnJV5zHr+HZ5/6ibPvv+efBY3Zo/AGu8q4aospCyJLSKB3BjzF+A/Q
DLwvLV2XCSuKxLrvJpUOQhl5Dxr0Tr+/OczyD5YNnrPunkKu+qkVTjXnVLpMm6Byv7EI+xAboxJBp4E/gxsBZYZY9611n4d7rVF4l3AAGstXHcdfV94wbPrb5eO4dPmrkWTnfqd9GiVHrGl1LS2ZuKIxMzOjsB31tofrLUHgWnABRG4rkjMa1g7xe8xY1zB0lFODiQlQWkQf+bMi2h+zyxPEIeyviflZ2gu/GZnlVYRqnD70iqb8isCjZj5pf/nlZgWidRKBrCl3PZW4Ezfk4wxQ4AhAMcff3wEbisSfaPOP5W73lxNsdNK85aKo+XNm6FZs7KTWrSgx9VP8r/fKmbaM9JSWTK8p9c+f4tHVPalqroXJpZD1mvFWjvJWptlrc1KT08/VLcV8ZKzMo8u4xbQYvhsuoxbUKkRqNNnszMz+Ndf2/kdmXtGy0VF0LWrdxBfuxZ++IHb+rYNuW+Jv5enlW1Wpe6FiSUSI/I84Lhy201L94lEXfk8cFrtFPYdKKKwdPRcmfxyKK1ob/czWu4zbxqMKCsf5NlnYcgQz2ZlZn5Gaik1dS9MLJEI5MuAk4wxLXAF8AHA3yJwXZGw+AZfpzauoaYTgqUinHLUp/70PbNfuq1sxznnwKxZrty4j1BnfkaqWZXW1kwsYQdya22RMeZmYA6u8sPJ1tqvwn4ykTA5BV8noaQTgqUiyh+vfbCAT565lqMKfi134o9wbOBFHkIVif4o6l6YWCJSR26tfR94PxLXEomUUPO9oaQT/KUi0kpz4+7jo+c9zZUrZnuO33T5GJ6ccl+IT3xoqWFW4tDMTklY/oKvr7z8AtqPnsuD/U4FnEepPVqlM2Xp5gqf/aWgkJyVeUyos5WzRlzp2f/iGefz6DlDGdu/TeR+IBE/FMglYTnlgVOSDTWSDAWF3uV++QWF3PnGKpKNcXwZuvCbnY73aPTrbrJPb+rZ3lXvKLoPfpaGjRsyVqkKOUQUyCVh+csDj5+z3nGkXmKhxHrXg7tfaPqmaZJKinnljZF03bTas2/h63PoMaA3mtIsh5oCuSQ0pzywv0k1/ri/BNzBf+DK9xkz9ynP8VFnX8/LZ5xPxsYaLAn/kUUqTYFcDjuh5s7Lnz+sT0uee/IdZr9Q1p1w6XGn8bcBD1OS5JrMo8k0Ei0K5BL3Ktv8aViflgx7azWFxd5plCSDV44cXDn14r376PjH9mTv3eXZ3+nGl9hev5HX5zWZRqJFgVziWigzLn25949+7yvPJKG01JQKVStptVO4afYzDP7ibc9nh14yiiMvu4hfludBFSfTqOugRJqx1t/aJtUnKyvL5ubmHvL7SuLpMm6BY5rEqeFUpcybB717ezantv8L9/W+CYwho9xL08oGY98vHnB9CWixCAmFMWa5tTbLd79G5BLXIt786aef4JhjPJu/1KxDlxtfZF/N2l7XrspkmpyVedz1xmqK/VTGKJBLVR2y7oci1cFfXjrJmJA7HOaszKPrP+fz4R86egXxa296ina3T/cK4oHuGeweI2Z+WSGIu+lFqYRDI3KJqnDzxf5mXLoDZrCcec7KPJaNnMDiWY979j3a61pOHj+a84FPI9RYKljfF70olXAokEvU+HtRmbtpDwu/2RlScPc347I839SF+8uj1vff8uHzN5Jdet7KY1vy18sfoSi5Bhlz1nty7JF4MRloxK2ugxIuBXKJGn+tYacu3exZqzLYiDrUlERefgEths+mVkoSB38vZODK9xk9/1nP8a43vMDWBkdX+rqh8le7nmyMXnRK2BTIJWr8BUvfLHKgl4GVmdxjgcwNKxj14SRa7trMiiYtea7DhXzQqqvjde/P+bJSXypO3KP/vPwCx8WUFcQlEhTIJWoqE4T9BX2nxlhOmv7yE/cteIFzvv2ULQ2OZsiF9zH3pE6uFZJ9uFerLx/E3covdBws5eKbOrLgCeYZqh+XCFIgl6hxCsK+o1Y3fy8DnRpj9WiV7smx1yw8wI1LZ3D9FzMoMYYJ3QbyXIcL+T2lptd1ko2hxFqvxlr+Zli4R+bBJiE5pY7cQTysGncRHwrkEjX+gvCM5XmVqhRxrOm2Ft58k+1DbuGYX3bw7indGfunq9lW33nh78vOPI4x2WW9wwM11ko2JqQV6LXAsRwqCuQSVU5BOKvZkeFViqxZA7feCh9/TOrJrbmi39/5pEnrgB+ZtXqbV6VMWu0UxzU+DYRcC64FjuVQUSCXmFPlJch274aRI+GZZ6BhQ3j6aRpcdx0XrdnOD6VfDP7SJfkFheQXuAK3v7y9AS7vdDwLv9kZUoDWAsdyqCiQS/wrLoZJk+D++yE/H4YOhdGj4cgjAe8vBn+9Wfxxejnpr1+Kb4DWAsdyqCiQS1SF3Qnw449daZQ1a1hxQjvuzX6IvRmnMGxTAdlHVryPUxlgIE4vJysToLXAsRwKCuQSNVVpQeuxZQsMGwbTp7P/mAxGXHQv75x4lquc0Oc6vvXgvmWA+w8WOebD3ZxeTipASyxRIJeo8TezM2AnwIICmDABxo51VaaMGkXfpI78sN+5oyDgWA9ucfUgXzK8p2OqpDy9nJRYp0AuUeOvDC8vv4Au4xZ4py3aN4G334a77oKNG+Hii10BvVkzfhg+2+/1A9WD5xcUkrMyz/Ol8eC7X3leeLrp5aTEA7WxlWqRszKPLuMWBGwl62+ka3AFc1v6zxeencWOTt3hoougbl1YsADefBOaNSNnZR4V52aWXT9YzbZ71J6dmcGqUb15/NL2ZKSlYnClXTSFXuKBRuQScaHmvp2mwZd/EVn/wD7uWDyVK1bMZn+tOjBxItxwA9Qo+9/W34jbgGeGZqAqFd9jyn1LPNKIXCIuUO7bLWdlHjOW5zkG8aSSYgas+i8LJw3hyuWzmNauD38a/AzcfLNXEIfAjbeyMzMY1qclqSnJfp812aHXiki8CWtEbowZD5wPHAS+B6621uZH4sEkfoUyNd1fH5Iztn7N6PnPctpP3/NF09aMPvt6vjr6RDL8pGH8zZ50n+8eXd/uZ8q9v1maIvEk3BH5POA0a21b4FtgRPiPJPHOX+67/H7fYH/03l089t4EZky9m6N+y+fW84dxyd8e4aujTwz4wtFpxO17fnZmht8vAn/7ReJJWIHcWjvXWltUurkUaBr+I0m8cwquBldO3M0d1GsWHWToZ2+w4LkbOHf9EiaedSk9r3uWd1v/0dNitmYN//+bZmdmMLZ/m6AvKEMJ+CLxytgI/WppjHkPmG6tneLn+BBgCMDxxx9/xqZNmyJyX4lNlz/3GUu+3+O1r/xCCjkrtjL3kee5e+4kmudvY85JnRjTczBb0o5xvF4kFmEIexapSJQZY5Zba7Mq7A8WyI0x8wGnv133WWvfKT3nPiAL6G9D+GbIysqyubm5IT24xCZ/QTFnZZ5jPbZbRloqS7KbwO23w5w5/K/x8Tzwp+tY3CIz6D3Vx1sOd/4CedCXndbas4Nc+CqgL9ArlCAu8al84G6QmsJvB4soLPZeqT53054KvcTLq/f7b1w983l4YBbUrg2PPUaLm27i4rU7
WB7CKj/q4y3iLNyqlb8AdwN/tNbuj8wjSazx7VXiNNouKCzm9c+3OFaBGFvCxV9+yN0fv8xRBb/AtdfCww9D48aAdxOqQDXfmiov4iysHLkx5jugJrC7dNdSa+0NwT6n1Er8yFmZxx3TV4XcLdBX+x/X8+D8Z2i/bQPLm7Ri76P/4k+Xn+v3fH9tZsv3AleOWw5XVU6tBGKt/UM4n5fYF6hXia9kYzwj8vR9e7jn45e5eO2H/FT3SO44707qXjOIf/RvF/AagSb4lE/bVGVFe5FEpZmdElCoeenUlGQuO/M46ieVMOTzGSx87nr6ff0xT595MT0HP8Ostr04o0WjoNfxlz4JtE6myOFOvVYEqFiF4l6JPtBovGHtFPL3F5alObat5p7Xbqfeph+Yf2IHxvQczMYjS0fLJTZwe9pS/pZH8/ciVC9ARRTIBecmV1OWbvZ7vjtf7Vl1fsMGuON6mD2beiefzFUXP8hHJ1ZI43kFXX/li/5W3/H3IlQvQEUUyAXnvif+pKWm8GC/U10Bd+9eGDMGHnsMatWC8ePh1lvZ8O/F4BB0k4yhxfDZpNVOYd+BIgpLvMsXoaz7oNPIXQsZizhTjlwqlZ6oU7MG2e2OhVdfhZNPhkcfhcsvh2+/hb//HY44wm/HwWJrsaNYFowAAA1rSURBVMDP+ws9QdwtWL471Kn4IocjjcjFbwdBJ0etWwNd74fPPoMOHSAnB8480+sc3/RIUrlqlkCCfaGoV7iIMwVycXzB6Ouo3/IZtugVLvlyHqSnw4svwqBBkOT8S135oNvCz1JsvpTvFqkaBXIJOLOyRnERV66YxW2LXyO16He+HziEkyY+Ag0ahHz9UEb8qSnJ9GiVXnGtTo3ARYJSjlwAVzBfMrwnj1/a3pPf7v7Dcj548RYeWPA8KzNa0Xfwk3x1xwOVCuLg3EI2JdmQlpriyXdfdEYGM5bnea3VOWLml45rfYqIt4i1sa0MTdGPbUuemUaXGy8DYGPasfyj12A+PLEjGFfwrVOzRqVHzcFayPqbmq+OhyJlqmWKvsQnv0F1xw44+mi6lJ63s3Yava99ioM1UjyfzS8o9DTNqsw0+WAvKkNZHk5EnCm1cphxT/4pn8K4d8Zqtnf/Mxx9tOe8a4c+SYdbpngFcScFhcXc9cbqsFMgoSwPJyLOFMgPM76Tf/66Zh5fP3wex3wy37Xj0UfBWs4ffEHA1efLK7Y27Hy2lmITqTqlVg4z7lTFibu38OHzN3r2rzy2JZmbvoSUFE/qpaCw2NPRMCMtlf0Hi/h5v/PKP+4JPVWtMvE3NV9VKyLBKZAfZprXSeb5xwdz4p6y0XPXG17gt2ObsrI0iJevKS+21mtkHKjePNx8tib8iFSNAvnh5IEHWDhmjGdz6AXDeb9VVwCS9hd6jcTLc4+23dUjd7yxCqdip7TagfPpIlI9lCM/DHzy/FtgjKvBFTDjtJ40v/s9TxAHKKEsreHEvT87M4MGtZwDtlZsFYkOjcgT2e7d0KgR3Uo3f09OocPNr/JrrbqOp7tz08Haxf7isGZnoP0iUr00Ik9E1sIll0CjshV5+l8+npZ/f9tvEAc8LxiDVY+oVFAktiiQJ5rXXnM1snrzTQD+1W0gze+ZxYqmpwT8WEqS8VSJBGsXq1JBkdii1Eqi+O47OOmksu127eCLL5jpd5EHcLcE91osguDVIyoVFIkt6rUSA4L1IQl4/OBBV1/wNWs8519y1yssq3GkZ+3N8qvPg2v0rEUZROKPeq3EKKf1Msv3Lwl4/P2X4P77PdfKHfN/XHHgJK9zZyzP46IzMlj4zU6NnkUSlAJ5lAWq287OzHA8fsrGtWSffk7ZjksugWnTuO2RhRQUFlS41sJvdqqDoEgCUyCPsmB12+WP1z+wj2VPDKRmcVHZibt2wVFHhXQtEUlMqlqJsmClfE3SUsFa/jXrX6z5zwBPEL/pun+7ygxLg3igazVITaHLuAW0GD6bLuMWaLEGkQSjQB5FOSvz2H+wqML+8qV8jyVvYOOj53PRVwsBmHjWpZxy/wf8+cZLKnzOcSWeJMNvB4u08o5IAotIasUYcxcwAUi31u6KxDUTWc7KPEa/95VjJ0FPKWDDQjCGjqX7N6UfR+9B/6FRowaM9fOy0qks0KljYbidCkUktoQdyI0xxwG9gc3hP07i861C8VW/BmTf0B+++KJs57p1NGvVivUhXN+3BtzfCvbKm4skjkikVh4D7gbUMikETlUobtd9PpNF9/cpC+KTJ7vy4K1aVfl+mk4vkvjCCuTGmAuAPGvt6hDOHWKMyTXG5O7cuTOc28Y1p5Fw223fsvGRvtz30WTXjvPPh+JiuPrqsO+n6fQiiS9oasUYMx84xuHQfcC9uNIqQVlrJwGTwDWzsxLPGNd8Z2U2SE3xLF5c9/f9fPrUVdQ/uN9z/tgXFzJrewk/3vtBRCbvaDq9SOKr8hR9Y0wb4EPAHYWaAj8CHa212wN99nCZou+UD09JNhQWlfDPOU/wt9VzPPsHXvIPFrfIJDUlWdPpRcRRxKfoW2u/BBqXu8FGIEtVK2Wc8uF//GYpz8/8h2f72Y79GdvjGgCSjQk4y1NExIlmdlaj8vnwY37dxdKnryo7Vr8xPQc/xYGUWgAVRuL+riMi4itiE4Kstc01GvfWJC2V5JJipr823CuIX3Hbc3zx0QqOSm/o1fM7QxUmIlIFGpFXQrB2s76e/nkJbceP9Gzf13soMzue78l5O33WN6euChMRCUaBPETB2s16WbUKMjNpW7r52Ymnc3n/UZQkJZNWw/8vQaowEZGqUCAPUbB2swDs2wd/+AP89JPnnP/OyeWOT3ZSUvrZ/IJC/18ABF+dR0TEl5pmhShoi9jbboN69cqC+OzZYC3/WPGL3y8AEZFI0Ig8RE3SUslzCOYX/vQlmL5lO26+GSZO9GyqR7iIVDcF8hAN69PSK0eevm8Py54cVHZCejr88APUrev1OX9fAKpEEZFIictAXtnqkUhwX3/CB+t4+Pl7+OP/VpQdXLECMjMdP+f7BQCqRBGRyIq7QF6p6pEIy/5iFtn33VC24/HHXbnxQJ9RJYqIVLO4C+QhVY9E2tq10KZN2Xa3brBgAdQI7T+fKlFEpDrFXSAP5eVhxFIv+/dD69awaVPZvk2b4PjjK38tEZFqEnflh8EWSnCnXsJeo/Luu6FOnbIg/vbbrkUeFMRFJMbEXSAPtlBCoNRLSD78EIyB8eNd29ddByUlkJ3tOSVnZZ5WpReRmBF3qZVgLw+rXLe9YwccfXTZdv36sHkzNGjgdVo0X7aKiDiJu0AOgV8eVrpuu6QE+veHd94p2/f559Cxo+PpUXnZKiISQNylVoKp1BqVL70EycllQfyRR1x5cD9BHDRTU0RiT1yOyAMJqW57/Xrvlek7dIAlSyAlJej1NVNTRGJNwgVyCJB6OXAA2rd3BXK3H36AFi1CvrZmaopIrEm41IpfDzwAqallQXz6dFcapRJBHFxfEu7VfMqv7qP8uIhES0KOyL0sWgR//GPZ9sCB8MorrhLDKtJMTRG
JJYkbyHfvhkaNyraPOAK2b4eGDaP3TCIi1SDxUivWwoAB3kF88WL4/XcFcRFJSIkVyF97DZKSXPlvgNGjXYG9S5foPpeISDVKjNTK99+71sp0a9MGli2DmjWj90wiIodIfI/IDx50lROWD+Lffgtr1iiIi8hhI34D+cMPu4L16tWu7VdfdaVRTjopus8lInKIxV9q5bPPoHPnsu2LL4Y33iBn1Y+MH7cgYA/yaCwRJyJS3eIrkM+Y4Qrcbjt3QqNGIXUkVNdCEUlUYadWjDG3GGO+McZ8ZYx5NBIP5VfDhlCvHixc6EqjlJYYhtKDPOw+5SIiMSqsEbkxpgdwAdDOWvu7MaZxZB7Lj5494ddfK+wOpSOhuhaKSKIKN7VyIzDOWvs7gLV2R/iP5CxQfjutdgo/7y+s8JnyHQnVtVBEElW4qZWTgW7GmM+NMR8bYzr4O9EYM8QYk2uMyd25c2elbhJoHc6clXnsO1BU4TMpycarI2Gl+pSLiMSRoCNyY8x84BiHQ/eVfv5IoBPQAXjDGHOCtdb6nmytnQRMAsjKyqpwPJBg+e3CkoqXq3NEDa+XmCH1KRcRiUNBA7m19mx/x4wxNwIzSwP3F8aYEqARULkhdxBVyW//UlAx1aKuhSKSiMJNreQAPQCMMScDRwC7wn0oX/7y2E3SUgMeExE5HIQbyCcDJxhj1gLTgCud0irhCpTfVu5bRA53YVWtWGsPAgMj9Cx+hZLfVu5bRA5XphoG0EFlZWXZ3NzcQ35fEZF4ZoxZbq3N8t0fv02zREQEUCAXEYl7CuQiInFOgVxEJM4pkIuIxLmoVK0YY3YCmw75jaumEdUwySnGHW4/8+H284J+5njVzFqb7rszKoE8nhhjcp3KfRLZ4fYzH24/L+hnTjRKrYiIxDkFchGROKdAHtykaD9AFBxuP/Ph9vOCfuaEohy5iEic04hcRCTOKZCLiMQ5BfIQGGPGG2O+McasMca8bYxJi/YzVSdjzF+NMV8ZY0qMMQlZruVmjPmLMWa9MeY7Y8zwaD9PdTPGTDbG7ChdQyDhGWOOM8YsNMZ8Xfr/9G3RfqbqoEAemnnAadbatsC3wIgoP091Wwv0BxZF+0GqkzEmGXgSOAdoDVxmjGkd3aeqdi8Bf4n2QxxCRcBd1trWuNYWvikR/4wVyENgrZ1rrS0q3VwKNI3m81Q3a+06a+36aD/HIdAR+M5a+0PpIinTgAui/EzVylq7CNgT7ec4VKy126y1K0r/fS+wDki4VWcUyCvvGuCDaD+EREQGsKXc9lYS8C+5uBhjmgOZwOfRfZLIC2upt0RijJkPHONw6D5r7Tul59yH61e1qYfy2apDKD+vSKIwxtQFZgC3W2t/jfbzRJoCeSlr7dmBjhtjrgL6Ar2qY4HpQy3Yz3uYyAOOK7fdtHSfJBBjTAquID7VWjsz2s9THZRaCYEx5i/A3UA/a+3+aD+PRMwy4CRjTAtjzBHAAODdKD+TRJAxxgAvAOustf+O9vNUFwXy0DwB1APmGWNWGWOeifYDVSdjzIXGmK3AWcBsY8ycaD9TdSh9gX0zMAfXS7A3rLVfRfepqpcx5nXgM6ClMWarMebaaD9TNesCXAH0LP27u8oYc260HyrSNEVfRCTOaUQuIhLnFMhFROKcArmISJxTIBcRiXMK5CIicU6BXEQkzimQi4jEuf8HT7WLcq1/wiUAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light", - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "rng = jax.random.PRNGKey(42)\n", - "\n", - "# Generate true data from y = w*x + b + noise\n", - "true_w, true_b = 2, -1\n", - "x_rng, noise_rng = jax.random.split(rng)\n", - "xs = jax.random.normal(x_rng, (128, 1))\n", - "noise = jax.random.normal(noise_rng, (128, 1)) * 0.5\n", - "ys = xs * true_w + true_b + noise\n", - "\n", - "# Fit regression\n", - "params = init(rng)\n", - "for _ in range(1000):\n", - " params = update(params, xs, ys)\n", - "\n", - "plt.scatter(xs, ys)\n", - "plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')\n", - "plt.legend();" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1wq3L6Xg1UHP" - }, - "source": [ - "## Taking it further\n", - "\n", - "The strategy described above is how any (jitted) JAX program must handle state. \n", - "\n", - "Handling parameters manually seems fine if you're dealing with two parameters, but what if it's a neural net with dozens of layers? You might already be getting worried about two things:\n", - "\n", - "1) Are we supposed to initialize them all manually, essentially repeating what we already write in the forward pass definition?\n", - "\n", - "2) Are we supposed to pipe all these things around manually?\n", - "\n", - "The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/google/jax#neural-network-libraries) for some examples." - ] - } - ], - "metadata": { - "colab": { - "name": "The Problem of State", - "provenance": [] - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/jax-101/08-pjit.rst b/docs/jax-101/08-pjit.rst deleted file mode 100644 index 2c6f35eac797..000000000000 --- a/docs/jax-101/08-pjit.rst +++ /dev/null @@ -1,9 +0,0 @@ -:orphan: - -Introduction to `pjit` -====================== - -This content is no longer relevant, because :func:`~jax.pjit` and :func:`~jax.jit` -have been merged into a single unified interface. -For an updated guide to compiling and executing JAX functions in multi-host or multi-core environments, -see :doc:`../notebooks/Distributed_arrays_and_automatic_parallelization`. diff --git a/docs/jax-101/index.rst b/docs/jax-101/index.rst deleted file mode 100644 index 8192d10a2b16..000000000000 --- a/docs/jax-101/index.rst +++ /dev/null @@ -1,24 +0,0 @@ -:orphan: - -.. _Jax-101: - -Tutorial: JAX 101 -================= - -This is a tutorial developed by engineers and researchers at DeepMind_. - -.. toctree:: - :maxdepth: 1 - :caption: Tutorials - - 01-jax-basics - 02-jitting - 03-vectorization - 04-advanced-autodiff - 05-random-numbers - 05.1-pytrees - 06-parallelism - 07-state - - -.. _Deepmind: http://deepmind.com diff --git a/docs/jax.experimental.key_reuse.rst b/docs/jax.experimental.key_reuse.rst index 5c7caf80f0ce..7255afabfc10 100644 --- a/docs/jax.experimental.key_reuse.rst +++ b/docs/jax.experimental.key_reuse.rst @@ -2,12 +2,3 @@ ===================================== .. 
automodule:: jax.experimental.key_reuse - -API ---- - -.. autosummary:: - :toctree: _autosummary - - reuse_key - KeyReuseError diff --git a/docs/jax.experimental.pallas.rst b/docs/jax.experimental.pallas.rst new file mode 100644 index 000000000000..f10bb252460d --- /dev/null +++ b/docs/jax.experimental.pallas.rst @@ -0,0 +1,23 @@ +``jax.experimental.pallas`` module +================================== + +.. automodule:: jax.experimental.pallas + +Classes +------- + +.. autosummary:: + :toctree: _autosummary + + BlockSpec + +Functions +--------- + +.. autosummary:: + :toctree: _autosummary + + pallas_call + program_id + num_programs + diff --git a/docs/jax.experimental.rst b/docs/jax.experimental.rst index 37cab679d891..3052e391e356 100644 --- a/docs/jax.experimental.rst +++ b/docs/jax.experimental.rst @@ -26,6 +26,9 @@ Experimental Modules jax.experimental.compilation_cache jax.experimental.key_reuse jax.experimental.mesh_utils + jax.experimental.serialize_executable + jax.experimental.shard_map + jax.experimental.pallas Experimental APIs ----------------- diff --git a/docs/jax.experimental.serialize_executable.rst b/docs/jax.experimental.serialize_executable.rst new file mode 100644 index 000000000000..67fa76cf8d81 --- /dev/null +++ b/docs/jax.experimental.serialize_executable.rst @@ -0,0 +1,13 @@ +``jax.experimental.serialize_executable`` module +================================================ + +.. automodule:: jax.experimental.serialize_executable + +API +--- + +.. autosummary:: + :toctree: _autosummary + + serialize + deserialize_and_load diff --git a/docs/jax.experimental.shard_map.rst b/docs/jax.experimental.shard_map.rst new file mode 100644 index 000000000000..65be7f21ba1e --- /dev/null +++ b/docs/jax.experimental.shard_map.rst @@ -0,0 +1,12 @@ +``jax.experimental.shard_map`` module +===================================== + +.. automodule:: jax.experimental.shard_map + +API +--- + +.. autosummary:: + :toctree: _autosummary + + shard_map diff --git a/docs/jax.export.rst b/docs/jax.export.rst new file mode 100644 index 000000000000..d458b6c64e8e --- /dev/null +++ b/docs/jax.export.rst @@ -0,0 +1,52 @@ +``jax.export`` module +===================== + +.. automodule:: jax.export + +:mod:`jax.export` is a library for exporting and serializing JAX functions +for persistent archival. + +See the :ref:`export` documentation. + +Classes +------- + +.. autosummary:: + :toctree: _autosummary + + Exported + DisabledSafetyCheck + +Functions +--------- + +.. autosummary:: + :toctree: _autosummary + + export + deserialize + minimum_supported_calling_convention_version + maximum_supported_calling_convention_version + default_export_platform + +Functions related to shape polymorphism +--------------------------------------- + +.. autosummary:: + :toctree: _autosummary + + symbolic_shape + symbolic_args_specs + is_symbolic_dim + SymbolicScope + +Constants +--------- + +.. data:: jax.export.minimum_supported_serialization_version + + The minimum supported serialization version; see :ref:`export-calling-convention-version`. + +.. data:: jax.export.maximum_supported_serialization_version + + The maximum supported serialization version; see :ref:`export-calling-convention-version`. diff --git a/docs/jax.extend.ffi.rst b/docs/jax.extend.ffi.rst new file mode 100644 index 000000000000..5928189eb647 --- /dev/null +++ b/docs/jax.extend.ffi.rst @@ -0,0 +1,11 @@ +``jax.extend.ffi`` module +========================= + +.. automodule:: jax.extend.ffi + +.. 
autosummary:: + :toctree: _autosummary + + ffi_call + ffi_lowering + pycapsule diff --git a/docs/jax.extend.rst b/docs/jax.extend.rst index 3b4ec41ea680..9cbee08e8e50 100644 --- a/docs/jax.extend.rst +++ b/docs/jax.extend.rst @@ -11,6 +11,7 @@ Modules .. toctree:: :maxdepth: 1 + jax.extend.ffi jax.extend.linear_util jax.extend.mlir jax.extend.random diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index b916458a33b3..32db1ba77dea 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -158,7 +158,6 @@ Operators sub tan tanh - tie_in top_k transpose zeros_like_array diff --git a/docs/jax.lib.rst b/docs/jax.lib.rst index 2513ccf1d6b9..5da79f0a80ae 100644 --- a/docs/jax.lib.rst +++ b/docs/jax.lib.rst @@ -21,4 +21,6 @@ jax.lib.xla_client .. currentmodule:: jaxlib.xla_client .. autosummary:: - :toctree: _autosummary + :toctree: _autosummary + + register_custom_call_target diff --git a/docs/jax.nn.initializers.rst b/docs/jax.nn.initializers.rst index d96ba43f0d49..246e0cdbe9a1 100644 --- a/docs/jax.nn.initializers.rst +++ b/docs/jax.nn.initializers.rst @@ -14,8 +14,8 @@ consistent with definitions used in Keras and Sonnet. An initializer is a function that takes three arguments: ``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and -data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random -key used when generating random numbers to initialize the array. +data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from +:func:`jax.random.key`), used to generate random numbers to initialize the array. .. autosummary:: :toctree: _autosummary diff --git a/docs/jax.nn.rst b/docs/jax.nn.rst index ae2c6f24a10a..33223ee755e5 100644 --- a/docs/jax.nn.rst +++ b/docs/jax.nn.rst @@ -22,6 +22,8 @@ Activation functions relu6 sigmoid softplus + sparse_plus + sparse_sigmoid soft_sign silu swish @@ -37,6 +39,7 @@ Activation functions gelu glu squareplus + mish Other functions --------------- diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index cdc557477c4e..b96dfcdfb208 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -138,6 +138,7 @@ namespace; they are listed below. csingle cumprod cumsum + cumulative_sum deg2rad degrees delete @@ -391,6 +392,7 @@ namespace; they are listed below. tensordot tile trace + trapezoid transpose tri tril @@ -416,6 +418,7 @@ namespace; they are listed below. unique_values unpackbits unravel_index + unstack unsignedinteger unwrap vander @@ -492,6 +495,7 @@ jax.numpy.linalg tensordot tensorinv tensorsolve + trace vector_norm vecdot diff --git a/docs/jax.random.rst b/docs/jax.random.rst index ea7845f8ce0e..9d6369d2d2b1 100644 --- a/docs/jax.random.rst +++ b/docs/jax.random.rst @@ -18,6 +18,7 @@ Key Creation & Manipulation wrap_key_data fold_in split + clone Random Samplers ~~~~~~~~~~~~~~~ diff --git a/docs/jax.rst b/docs/jax.rst index 2c4e4e5ea6a2..b112490a0912 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -27,6 +27,7 @@ Subpackages jax.tree jax.tree_util jax.typing + jax.export jax.extend jax.example_libraries jax.experimental @@ -113,6 +114,7 @@ jax.Array (:code:`jax.Array`) Array make_array_from_callback make_array_from_single_device_arrays + make_array_from_process_local_data Vectorization (:code:`vmap`) ---------------------------- diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 7ed9729f9b5a..f6d8a151440b 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -1,6 +1,16 @@ ``jax.scipy`` module ==================== +jax.scipy.cluster +----------------- + +.. automodule:: jax.scipy.cluster.vq + +.. 
autosummary:: + :toctree: _autosummary + + vq + jax.scipy.fft ------------- @@ -24,6 +34,17 @@ jax.scipy.integrate trapezoid +jax.scipy.interpolate +--------------------- + +.. automodule:: jax.scipy.interpolate + +.. autosummary:: + :toctree: _autosummary + + RegularGridInterpolator + + jax.scipy.linalg ---------------- @@ -43,6 +64,7 @@ jax.scipy.linalg expm_frechet funm hessenberg + hilbert inv lu lu_factor @@ -92,6 +114,7 @@ jax.scipy.signal correlate correlate2d csd + detrend istft stft welch @@ -145,12 +168,15 @@ jax.scipy.special gammainc gammaincc gammaln + gammasgn hyp1f1 i0 i0e i1 i1e + kl_div log_ndtr + log_softmax logit logsumexp lpmn @@ -160,13 +186,13 @@ jax.scipy.special ndtri poch polygamma + rel_entr + softmax spence sph_harm xlog1py xlogy zeta - kl_div - rel_entr jax.scipy.stats @@ -393,6 +419,7 @@ jax.scipy.stats.poisson logpmf pmf + cdf jax.scipy.stats.t ~~~~~~~~~~~~~~~~~ diff --git a/docs/jax.sharding.rst b/docs/jax.sharding.rst index 7b1393d8e2c4..954f62b8a52d 100644 --- a/docs/jax.sharding.rst +++ b/docs/jax.sharding.rst @@ -10,9 +10,6 @@ Classes .. autoclass:: Sharding :members: -.. autoclass:: XLACompatibleSharding - :members: - :show-inheritance: .. autoclass:: SingleDeviceSharding :members: :show-inheritance: diff --git a/docs/jax.stages.rst b/docs/jax.stages.rst index f8adce32b7c6..804019ee1cc6 100644 --- a/docs/jax.stages.rst +++ b/docs/jax.stages.rst @@ -9,7 +9,7 @@ Classes .. currentmodule:: jax.stages .. autoclass:: Wrapped - :members: lower + :members: trace, lower :special-members: __call__ .. autoclass:: Lowered diff --git a/docs/jax.tree_util.rst b/docs/jax.tree_util.rst index da52d87868e0..35bce340d4de 100644 --- a/docs/jax.tree_util.rst +++ b/docs/jax.tree_util.rst @@ -14,23 +14,33 @@ List of Functions Partial all_leaves build_tree + register_dataclass register_pytree_node register_pytree_node_class register_pytree_with_keys register_pytree_with_keys_class + register_static + tree_flatten_with_path + tree_leaves_with_path + tree_map_with_path + treedef_children + treedef_is_leaf + treedef_tuple + keystr + +Legacy APIs +----------- +These APIs are now accessed via :mod:`jax.tree`. + +.. autosummary:: + :toctree: _autosummary + tree_all tree_flatten - tree_flatten_with_path tree_leaves - tree_leaves_with_path tree_map - tree_map_with_path tree_reduce tree_structure tree_transpose tree_unflatten - treedef_children - treedef_is_leaf - treedef_tuple - keystr diff --git a/docs/jax_array_migration.md b/docs/jax_array_migration.md index 860197becf6e..95d4a632a295 100644 --- a/docs/jax_array_migration.md +++ b/docs/jax_array_migration.md @@ -1,6 +1,8 @@ (jax-array-migration)= # jax.Array migration + + **yashkatariya@** ## TL;DR diff --git a/docs/jaxpr.rst b/docs/jaxpr.rst index b5eb8c2c1390..56be62162a9e 100644 --- a/docs/jaxpr.rst +++ b/docs/jaxpr.rst @@ -34,7 +34,7 @@ There are two related representations in the code for jaxprs, :py:class:`jax.core.Jaxpr`, and is what you obtain when you use :py:func:`jax.make_jaxpr` to inspect jaxprs. It has the following fields: - * ``jaxpr``: is a :py:class:`jax.core.Jaxpr` representing the actual + * ``jaxpr`` is a :py:class:`jax.core.Jaxpr` representing the actual computation content of the function (described below). * ``consts`` is a list of constants. @@ -42,9 +42,9 @@ The most interesting part of the ClosedJaxpr is the actual execution content, represented as a :py:class:`jax.core.Jaxpr` as printed using the following grammar:: - jaxpr ::= { lambda Var* ; Var+. 
- let Eqn* - in [Expr+] } + Jaxpr ::= { lambda Var* ; Var+. let + Eqn* + in [Expr+] } where: * The parameters of the jaxpr are shown as two lists of variables separated by @@ -62,7 +62,7 @@ where: Equations are printed as follows:: - Eqn ::= let Var+ = Primitive [ Param* ] Expr+ + Eqn ::= Var+ = Primitive [ Param* ] Expr+ where: * ``Var+`` are one or more intermediate variables to be defined as the output @@ -76,7 +76,7 @@ where: square brackets. Each parameter is shown as ``Name = Value``. -Most jaxpr primitives are first-order (they take just one or more Expr as arguments):: +Most jaxpr primitives are first-order (they take just one or more ``Expr`` as arguments):: Primitive := add | sub | sin | mul | ... @@ -103,8 +103,8 @@ Here there are no constvars, ``a`` and ``b`` are the input variables and they correspond respectively to ``first`` and ``second`` function parameters. The scalar literal ``3.0`` is kept inline. -The ``reduce_sum`` primitive has named parameters ``axes`` and ``input_shape``, in -addition to the operand ``e``. +The ``reduce_sum`` primitive has named parameter ``axes``, in addition to the +operand ``e``. Note that even though execution of a program that calls into JAX builds a jaxpr, Python-level control-flow and Python-level functions execute normally. @@ -218,18 +218,12 @@ For example: { lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) } { lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) } ) - linear=(False,) ] d b in (e,) } -The cond primitive has a number of parameters: - - * `branches` are jaxprs that correspond to the branch - functionals. In this example, those functionals each take one - input variable, corresponding to ``x``. - * `linear` is a tuple of booleans that is used internally by the - auto-differentiation machinery to encode which of the input - parameters are used linearly in the conditional. +The `branches` parameter to the cond primitive corresponds to the branch +functionals. In this example, those functionals each take one input variable, +corresponding to ``x``. The above instance of the cond primitive takes two operands. The first one (``d``) is the branch index, then ``b`` is the operand (``arg``) to @@ -255,7 +249,6 @@ Another example, using :py:func:`lax.cond`: { lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) } { lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) } ) - linear=(False,) ] c a in (d,) } @@ -287,7 +280,6 @@ contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar` in (l,) } { lambda ; m_:i32[1] n:f32[1] o:f32[]. let in (n,) } ) - linear=(False, False, False) ] f a c d in (g,) } @@ -306,7 +298,7 @@ and :py:func:`jax.lax.fori_loop` lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C -In the above signature, “C” stands for the type of a the loop “carry” value. +In the above signature, “C” stands for the type of the loop “carry” value. For example, here is an example fori loop >>> import numpy as np @@ -371,6 +363,7 @@ For the example consider the function ``func11`` below { lambda ; a:f32[16] b:f32[]. let c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0 d:f32[] e:f32[16] = scan[ + _split_transpose=False jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. 
let j:f32[] = mul h i k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g diff --git a/docs/jep/10657-sequencing-effects.md b/docs/jep/10657-sequencing-effects.md index f6fe615c200e..5f7eb0da4c04 100644 --- a/docs/jep/10657-sequencing-effects.md +++ b/docs/jep/10657-sequencing-effects.md @@ -132,7 +132,7 @@ def f(token, x): return token, x ``` If we rewrite `jax.print` to take in and return a token, we have now sequenced -the two prints since the input to the second print depends is the output of the first print. +the two prints since the input to the second print depends on the output of the first print. The actual value of `token` can be anything really, but we'll see in practice that the tokens are invisible to users. @@ -207,7 +207,7 @@ computations will take the token as input and return it as an output. The implementation of this token threading involves upgrading the JAX lowering machinery to do this bookkeeping automatically. The main challenges involve dealing with higher-order primitives like call primitives -and control-flow primitives. We won't go into details in how to handle those in this design note. +and control-flow primitives. We won't go into details on how to handle those in this design note. ## Blocking on output tokens diff --git a/docs/jep/18137-numpy-scipy-scope.md b/docs/jep/18137-numpy-scipy-scope.md index 1b96f5762e3e..2371e11ee07e 100644 --- a/docs/jep/18137-numpy-scipy-scope.md +++ b/docs/jep/18137-numpy-scipy-scope.md @@ -140,7 +140,7 @@ incompatible with JAX’s computation model. We instead focus on {mod}`jax.rando which offers similar functionality using a counter-based PRNG. #### ❌ `numpy.ma` & `numpy.polynomial` -The {mod}`numpy.ma` andd {mod}`numpy.polynomial` submodules are mostly concerned with +The {mod}`numpy.ma` and {mod}`numpy.polynomial` submodules are mostly concerned with providing object-oriented interfaces to computations that can be expressed via other functional means (Axis 5); for this reason, we deem them out-of-scope for JAX. @@ -187,7 +187,7 @@ evaluations. {func}`jax.experimental.ode.odeint` is related, but rather limited under any active development. JAX does currently include {func}`jax.scipy.integrate.trapezoid`, but this is only because -{func}`numpy.trapz` was recently deprecated in favor of this. For any particular inputs, +{func}`numpy.trapz` was recently deprecated in favor of this. For any particular input, its implementation could be replaced with one line of {mod}`jax.numpy` expressions, so it’s not a particularly useful API to provide. diff --git a/docs/jep/263-prng.md b/docs/jep/263-prng.md index 66e2ae24b567..7ef10ae0e9c4 100644 --- a/docs/jep/263-prng.md +++ b/docs/jep/263-prng.md @@ -1,3 +1,4 @@ +(prng-design-jep)= # JAX PRNG Design We want a PRNG design that 1. is **expressive** in that it is convenient to use and it doesn’t constrain the user’s ability to write numerical programs with exactly the behavior that they want, diff --git a/docs/jep/9263-typed-keys.md b/docs/jep/9263-typed-keys.md index 3f203aaef2df..828b95e8ce00 100644 --- a/docs/jep/9263-typed-keys.md +++ b/docs/jep/9263-typed-keys.md @@ -21,6 +21,7 @@ Array([0, 0], dtype=uint32) (2,) >>> key.dtype dtype('uint32') + ``` Starting now, new-style RNG keys can be created with {func}`jax.random.key`: @@ -33,6 +34,7 @@ Array((), dtype=key) overlaying: () >>> key.dtype key + ``` This (scalar-shaped) array behaves the same as any other JAX array, except that its element type is a key (and associated metadata). 
We can make @@ -48,6 +50,7 @@ Array((4,), dtype=key) overlaying: [0 3]] >>> key_arr.shape (4,) + ``` Aside from switching to a new constructor, most PRNG-related code should continue to work as expected. You can continue to use keys in @@ -62,14 +65,17 @@ data = jax.random.uniform(key, shape=(5,)) However, not all numerical operations work on key arrays. They now intentionally raise errors: ```python ->>> key = key + 1 -ValueError: dtype=key is not a valid dtype for JAX type promotion. +>>> key = key + 1 # doctest: +SKIP +Traceback (most recent call last): +TypeError: add does not accept dtypes key, int32. + ``` If for some reason you need to recover the underlying buffer (the old-style key), you can do so with {func}`jax.random.key_data`: ```python >>> jax.random.key_data(key) Array([0, 0], dtype=uint32) + ``` For old-style keys, {func}`~jax.random.key_data` is an identity operation. @@ -108,6 +114,7 @@ True >>> raw_key = jax.random.PRNGKey(0) >>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key) False + ``` ### Type annotations for PRNG Keys @@ -173,6 +180,7 @@ Array((), dtype=key) overlaying: [0 0 0 0] >>> jax.random.uniform(key, shape=(3,)) Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32) + ``` ### Safe PRNG key use @@ -237,15 +245,15 @@ time boils down to this: :class: red-background # Incorrect keys = random.split(random.PRNGKey(0)) -data = jax.vmap(random.uniform, axis=1)(keys) +data = jax.vmap(random.uniform, in_axes=1)(keys) ``` ```{code-block} python :class: green-background # Correct keys = random.split(random.PRNGKey(0)) -data = jax.vmap(random.uniform, axis=0)(keys) +data = jax.vmap(random.uniform, in_axes=0)(keys) ``` -The bug here is subtle. By mapping over `axis=1`, this code makes new keys by +The bug here is subtle. By mapping over `in_axes=1`, this code makes new keys by combining a single element from each key buffer in the batch. The resulting keys are different from one another, but are effectively "derived" in a non-standard way. Again, the PRNG is not designed or tested to produce @@ -322,6 +330,7 @@ which has the following property: ```python >>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended) True + ``` PRNG key arrays then have a dtype with the following properties: ```python @@ -330,6 +339,7 @@ PRNG key arrays then have a dtype with the following properties: True >>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key) True + ``` And in addition to `key.dtype._rules` as outlined for extended dtypes in general, PRNG dtypes define `key.dtype._impl`, which contains the metadata diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb index ebaec82356c9..2aef1768112f 100644 --- a/docs/jep/9407-type-promotion.ipynb +++ b/docs/jep/9407-type-promotion.ipynb @@ -711,7 +711,7 @@ "source": [ "This behavior gives motivation to our `*` notation for scalar values: the `*` is reminiscent of a wildcard that can take on any desired value.\n", "\n", - "The benefit of these semantics are that you can readily express sequences of operations with clean Python code, without having to explicitly cast scalars to the appropriate type. Imagine if rather than writing this:\n", + "The benefit of these semantics is that you can readily express sequences of operations with clean Python code, without having to explicitly cast scalars to the appropriate type. Imagine if rather than writing this:\n", "```python\n", "3 * (x + 1) ** 2\n", "```\n", @@ -801,7 +801,7 @@ "\n", "5. 
Wherever possible, binary promotion should avoid resulting in types that are wider than the inputs. This is to ensure that JAX's implicit promotions remain friendly to accelerator-based workflows, in which users often want to restrict types to 32-bit (or in some cases 16-bit) values.\n", "\n", - "Each new connection on the lattice introduces some level of convenience to the user (a new set of types that can interact without explicit casting), but the convenience may becomes too costly if any of the above criteria are violated. Developing a full promotion lattice involves striking a balance between this convenience and this cost." + "Each new connection on the lattice introduces some level of convenience to the user (a new set of types that can interact without explicit casting), but the convenience may become too costly if any of the above criteria are violated. Developing a full promotion lattice involves striking a balance between this convenience and this cost." ] }, { @@ -946,7 +946,7 @@ "source": [ "### How to handle `uint64`?\n", "\n", - "The approached to mixed signed/unsigned integer promotion leaves out one type: `uint64`. Following the pattern above, the output of a mixed-integer operation involving `uint64` should result in `int128`, but this is not a standard available dtype.\n", + "The approach to mixed signed/unsigned integer promotion leaves out one type: `uint64`. Following the pattern above, the output of a mixed-integer operation involving `uint64` should result in `int128`, but this is not a standard available dtype.\n", "\n", "Numpy's choice here is to promote to `float64`:" ] @@ -1335,7 +1335,7 @@ "However, these advantages comes with a few tradeoffs:\n", "\n", "- mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \\times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \\times 10^4$), meaning most representable values will become `inf`.\n", - "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but whether as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", + "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", "\n", "Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`." ] diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index fc78946ceac9..107bcd8c968b 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -311,7 +311,7 @@ for dtype in [np.int8, np.int16, np.int32, np.int64]: This behavior gives motivation to our `*` notation for scalar values: the `*` is reminiscent of a wildcard that can take on any desired value. 
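As a minimal sketch of this wildcard behavior (using only public `jax.numpy` calls; the printed dtypes are what JAX's current promotion rules produce), a Python scalar adopts the dtype of the array it is combined with, while an explicitly typed constant forces a promotion:

```python
import jax.numpy as jnp

x = jnp.arange(3, dtype=jnp.int8)

# A Python scalar behaves like the `*` wildcard above: it takes on the
# other operand's dtype, so no widening occurs.
print((x + 1).dtype)             # int8

# An explicitly 32-bit constant, by contrast, forces promotion to int32.
print((x + jnp.int32(1)).dtype)  # int32
```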
-The benefit of these semantics are that you can readily express sequences of operations with clean Python code, without having to explicitly cast scalars to the appropriate type. Imagine if rather than writing this: +The benefit of these semantics is that you can readily express sequences of operations with clean Python code, without having to explicitly cast scalars to the appropriate type. Imagine if rather than writing this: ```python 3 * (x + 1) ** 2 ``` @@ -373,7 +373,7 @@ Broadly speaking, we want any additional connections to satisfy a few properties 5. Wherever possible, binary promotion should avoid resulting in types that are wider than the inputs. This is to ensure that JAX's implicit promotions remain friendly to accelerator-based workflows, in which users often want to restrict types to 32-bit (or in some cases 16-bit) values. -Each new connection on the lattice introduces some level of convenience to the user (a new set of types that can interact without explicit casting), but the convenience may becomes too costly if any of the above criteria are violated. Developing a full promotion lattice involves striking a balance between this convenience and this cost. +Each new connection on the lattice introduces some level of convenience to the user (a new set of types that can interact without explicit casting), but the convenience may become too costly if any of the above criteria are violated. Developing a full promotion lattice involves striking a balance between this convenience and this cost. +++ {"id": "GSqwTTS8nYdn"} @@ -457,7 +457,7 @@ Again, the connections added here are precisely the promotion semantics implemen ### How to handle `uint64`? -The approached to mixed signed/unsigned integer promotion leaves out one type: `uint64`. Following the pattern above, the output of a mixed-integer operation involving `uint64` should result in `int128`, but this is not a standard available dtype. +The approach to mixed signed/unsigned integer promotion leaves out one type: `uint64`. Following the pattern above, the output of a mixed-integer operation involving `uint64` should result in `int128`, but this is not a standard available dtype. Numpy's choice here is to promote to `float64`: @@ -680,7 +680,7 @@ This is important because `f16` and `bf16` are not comparable because they utili However, these advantages comes with a few tradeoffs: - mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \times 10^4$), meaning most representable values will become `inf`. -- as mentioned above, `f*` can no longer be thought of as a "scalar type", but whether as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values. +- as mentioned above, `f*` can no longer be thought of as a "scalar type", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values. 
Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`. diff --git a/docs/tutorials/jit-compilation.md b/docs/jit-compilation.md similarity index 65% rename from docs/tutorials/jit-compilation.md rename to docs/jit-compilation.md index 68c69b7bf54a..2d442c8411aa 100644 --- a/docs/tutorials/jit-compilation.md +++ b/docs/jit-compilation.md @@ -5,16 +5,25 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python name: python3 --- +```{code-cell} +:tags: [remove-cell] + +# This ensures that code cell tracebacks appearing below will be concise. +%xmode minimal +``` + (jit-compilation)= # Just-in-time compilation + + In this section, we will further explore how JAX works, and how we can make it performant. We will discuss the {func}`jax.jit` transformation, which will perform *Just In Time* (JIT) compilation of a JAX Python function so it can be executed efficiently in XLA. @@ -22,10 +31,10 @@ compilation of a JAX Python function so it can be executed efficiently in XLA. ## How JAX transformations work In the previous section, we discussed that JAX allows us to transform Python functions. -This is done by first converting the Python function into a simple intermediate language called jaxpr. -The transformations then work on the jaxpr representation. +JAX accomplishes this by reducing each function into a sequence of {term}`primitive` operations, each +representing one fundamental unit of computation. -We can show a representation of the jaxpr of a function by using {func}`jax.make_jaxpr`: +One way to see the sequence of primitives behind a function is using {func}`jax.make_jaxpr`: ```{code-cell} import jax @@ -44,9 +53,14 @@ print(jax.make_jaxpr(log2)(3.0)) The {ref}`understanding-jaxprs` section of the documentation provides more information on the meaning of the above output. -Importantly, note how the jaxpr does not capture the side-effect of the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX is designed to understand side-effect-free (a.k.a. functionally pure) code. If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). +Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`. +This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code. +If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). -Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour once converted to jaxpr. However, as a rule of thumb, you can expect (but shouldn't rely on) the side-effects of a JAX-transformed function to run once (during the first call), and never again. This is because of the way that JAX generates jaxpr, using a process called 'tracing'. 
+Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers. +Moreover, JAX often can't detect when side effects are present. +(If you want debug printing, use {func}`jax.debug.print`. To express general side-effects at the cost of performance, see {func}`jax.experimental.io_callback`. +To check for tracer leaks at the cost of performance, use with {func}`jax.check_tracer_leaks`). When tracing, JAX wraps each argument by a *tracer* object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record the Python side-effects, they do not appear in the jaxpr. However, the side-effects still happen during the trace itself. @@ -66,7 +80,8 @@ See how the printed `x` is a `Traced` object? That's the JAX internals at work. The fact that the Python code runs at least once is strictly an implementation detail, and so shouldn't be relied upon. However, it's useful to understand as you can use it when debugging to print out intermediate values of a computation. -A key thing to understand is that jaxpr captures the function as executed on the parameters given to it. For example, if we have a conditional, jaxpr will only know about the branch we take: +A key thing to understand is that a jaxpr captures the function as executed on the parameters given to it. +For example, if we have a Python conditional, the jaxpr will only know about the branch we take: ```{code-cell} def log2_if_rank_2(x): @@ -116,7 +131,6 @@ Here's what just happened: 1) We defined `selu_jit` as the compiled version of `selu`. 2) We called `selu_jit` once on `x`. This is where JAX does its tracing -- it needs to have some inputs to wrap in tracers, after all. The jaxpr is then compiled using XLA into very efficient code optimized for your GPU or TPU. Finally, the compiled code is executed to satisfy the call. Subsequent calls to `selu_jit` will use the compiled code directly, skipping the python implementation entirely. - (If we didn't include the warm-up call separately, everything would still work, but then the compilation time would be included in the benchmark. It would still be faster, because we run many loops in the benchmark, but it wouldn't be a fair comparison.) 3) We timed the execution speed of the compiled version. (Note the use of {func}`~jax.block_until_ready`, which is required due to JAX's {ref}`async-dispatch`). @@ -136,8 +150,7 @@ def f(x): else: return 2 * x -f_jit = jax.jit(f) -f_jit(10) # Should raise an error. +jax.jit(f)(10) # Raises an error ``` ```{code-cell} @@ -151,19 +164,17 @@ def g(x, n): i += 1 return x + i -g_jit = jax.jit(g) -g_jit(10, 20) # Should raise an error. +jax.jit(g)(10, 20) # Raises an error ``` -The problem is that we tried to condition on the *value* of an input to the function being jitted. The reason we can't do this is related to the fact mentioned above that jaxpr depends on the actual values used to trace it. - -The more specific information about the values we use in the trace, the more we can use standard Python control flow to express ourselves. However, being too specific means we can't reuse the same traced function for other values. JAX solves this by tracing at different levels of abstraction for different purposes. 
- -For {func}`jax.jit`, the default level is {class}`~jax.core.ShapedArray` -- that is, each tracer has a concrete shape (which we're allowed to condition on), but no concrete value. This allows the compiled function to work on all possible inputs with the same shape -- the standard use case in machine learning. However, because the tracers have no concrete value, if we attempt to condition on one, we get the error above. - -In {func}`jax.grad`, the constraints are more relaxed, so you can do more. If you compose several transformations, however, you must satisfy the constraints of the most strict one. So, if you `jit(grad(f))`, `f` mustn't condition on value. For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow). +The problem in both cases is that we tried to condition the trace-time flow of the program using runtime values. +Traced values within JIT, like `x` and `n` here, can only affect control flow via their static attributes: such as +`shape` or `dtype`, and not via their values. +For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow). -One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special {ref}`lax-control-flow` like {func}`jax.lax.cond`. However, sometimes that is impossible. In that case, you can consider jitting only part of the function. For example, if the most computationally expensive part of the function is inside the loop, we can JIT just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot): +One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special {ref}`lax-control-flow` like {func}`jax.lax.cond`. However, sometimes that is not possible or practical. +In that case, you can consider JIT-compiling only part of the function. +For example, if the most computationally expensive part of the function is inside the loop, we can JIT-compile just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot): ```{code-cell} # While loop conditioned on x and n with a jitted body. @@ -181,7 +192,11 @@ def g_inner_jitted(x, n): g_inner_jitted(10, 20) ``` -If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to get limited different values. +## Marking arguments as static + +If we really need to JIT-compile a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. +The cost of this is that the resulting jaxpr and compiled artifact depends on the particular value passed, and so JAX will have to re-compile the function for every new value of the specified static input. +It is only a good strategy if the function is guaranteed to see a limited set of static values. 
```{code-cell} f_jit_correct = jax.jit(f, static_argnums=0) @@ -208,22 +223,6 @@ def g_jit_decorated(x, n): print(g_jit_decorated(10, 20)) ``` -## When to use JIT - -In many of the examples above, using `jit` is not worth it: - -```{code-cell} -print("g jitted:") -%timeit g_jit_correct(10, 20).block_until_ready() - -print("g:") -%timeit g(10, 20) -``` - -This is because {func}`jax.jit` introduces some overhead itself. Therefore, it usually only saves time if the compiled function is complex and you will run it numerous times. Fortunately, this is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations. - -Generally, you want to jit the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise. - ## JIT and caching With the compilation overhead of the first JIT call, understanding how and when {func}`jax.jit` caches previous compilations is key to using it effectively. diff --git a/docs/key-concepts.md b/docs/key-concepts.md new file mode 100644 index 000000000000..4b114c857460 --- /dev/null +++ b/docs/key-concepts.md @@ -0,0 +1,191 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.1 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +(key-concepts)= +# Key Concepts + + + +This section briefly introduces some key concepts of the JAX package. + +(key-concepts-jax-arrays)= +## JAX arrays ({class}`jax.Array`) + +The default array implementation in JAX is {class}`jax.Array`. In many ways it is similar to +the {class}`numpy.ndarray` type that you may be familar with from the NumPy package, but it +has some important differences. + +### Array creation + +We typically don't call the {class}`jax.Array` constructor directly, but rather create arrays via JAX API functions. +For example, {mod}`jax.numpy` provides familar NumPy-style array construction functionality +such as {func}`jax.numpy.zeros`, {func}`jax.numpy.linspace`, {func}`jax.numpy.arange`, etc. + +```{code-cell} +import jax +import jax.numpy as jnp + +x = jnp.arange(5) +isinstance(x, jax.Array) +``` + +If you use Python type annotations in your code, {class}`jax.Array` is the appropriate +annotation for jax array objects (see {mod}`jax.typing` for more discussion). + +### Array devices and sharding + +JAX Array objects have a `devices` method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device: + +```{code-cell} +x.devices() +``` + +In general, an array may be *sharded* across multiple devices, in a manner that can be inspected via the `sharding` attribute: + +```{code-cell} +x.sharding +``` + +Here the array is on a single device, but in general a JAX array can be +sharded across multiple devices, or even multiple hosts. +To read more about sharded arrays and parallel computation, refer to {ref}`sharded-computation` + +(key-concepts-transformations)= +## Transformations +Along with functions to operate on arrays, JAX includes a number of +{term}`transformations ` which operate on JAX functions. These include + +- {func}`jax.jit`: Just-in-time (JIT) compilation; see {ref}`jit-compilation` +- {func}`jax.vmap`: Vectorizing transform; see {ref}`automatic-vectorization` +- {func}`jax.grad`: Gradient transform; see {ref}`automatic-differentiation` + +as well as several others. 
Transformations accept a function as an argument, and return a +new transformed function. For example, here's how you might JIT-compile a simple SELU function: + +```{code-cell} +def selu(x, alpha=1.67, lambda_=1.05): + return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) + +selu_jit = jax.jit(selu) +print(selu_jit(1.0)) +``` + +Often you'll see transformations applied using Python's decorator syntax for convenience: + +```{code-cell} +@jax.jit +def selu(x, alpha=1.67, lambda_=1.05): + return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) +``` + +Transformations like {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`, and others are +key to using JAX effectively, and we'll cover them in detail in later sections. + +(key-concepts-tracing)= +## Tracing + +The magic behind transformations is the notion of a {term}`Tracer`. +Tracers are abstract stand-ins for array objects, and are passed to JAX functions in order +to extract the sequence of operations that the function encodes. + +You can see this by printing any array value within transformed JAX code; for example: + +```{code-cell} +@jax.jit +def f(x): + print(x) + return x + 1 + +x = jnp.arange(5) +result = f(x) +``` + +The value printed is not the array `x`, but a {class}`~jax.core.Tracer` instance that +represents essential attributes of `x`, such as its `shape` and `dtype`. By executing +the function with traced values, JAX can determine the sequence of operations encoded +by the function before those operations are actually executed: transformations like +{func}`~jax.jit`, {func}`~jax.vmap`, and {func}`~jax.grad` can then map this sequence +of input operations to a transformed sequence of operations. + +(key-concepts-jaxprs)= +## Jaxprs + +JAX has its own intermediate representation for sequences of operations, known as a {term}`jaxpr`. +A jaxpr (short for *JAX exPRession*) is a simple representation of a functional program, comprising a sequence of {term}`primitive` operations. + +For example, consider the `selu` function we defined above: + +```{code-cell} +def selu(x, alpha=1.67, lambda_=1.05): + return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) +``` + +We can use the {func}`jax.make_jaxpr` utility to convert this function into a jaxpr +given a particular input: + +```{code-cell} +x = jnp.arange(5.0) +jax.make_jaxpr(selu)(x) +``` + +Comparing this to the Python function definition, we see that it encodes the precise +sequence of operations that the function represents. We'll go into more depth about +jaxprs later in {ref}`jax-internals-jaxpr`. + +(key-concepts-pytrees)= +## Pytrees + +JAX functions and transformations fundamentally operate on arrays, but in practice it is +convenient to write code that work with collections of arrays: for example, a neural +network might organize its parameters in a dictionary of arrays with meaningful keys. +Rather than handle such structures on a case-by-case basis, JAX relies on the {term}`pytree` +abstraction to treat such collections in a uniform matter. 
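For instance, here is a minimal sketch (with a hypothetical two-layer parameter dictionary) of how {func}`jax.tree.map` updates every array in such a nested structure with a single call:

```python
import jax
import jax.numpy as jnp

# A hypothetical parameter pytree: a dict of per-layer dicts of arrays.
params = {
    'layer1': {'W': jnp.ones((2, 3)), 'b': jnp.zeros(3)},
    'layer2': {'W': jnp.ones((3, 1)), 'b': jnp.zeros(1)},
}

# jax.tree.map applies the function to every leaf array while preserving
# the surrounding structure, so all parameters are scaled in one call.
scaled = jax.tree.map(lambda p: 0.5 * p, params)
print(jax.tree.structure(scaled))
```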
+ +Here are some examples of objects that can be treated as pytrees: + +```{code-cell} +# (nested) list of parameters +params = [1, 2, (jnp.arange(3), jnp.ones(2))] + +print(jax.tree.structure(params)) +print(jax.tree.leaves(params)) +``` + +```{code-cell} +# Dictionary of parameters +params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)} + +print(jax.tree.structure(params)) +print(jax.tree.leaves(params)) +``` + +```{code-cell} +# Named tuple of parameters +from typing import NamedTuple + +class Params(NamedTuple): + a: int + b: float + +params = Params(1, 5.0) +print(jax.tree.structure(params)) +print(jax.tree.leaves(params)) +``` + +JAX has a number of general-purpose utilities for working with PyTrees; for example +the functions {func}`jax.tree.map` can be used to map a function to every leaf in a +tree, and {func}`jax.tree.reduce` can be used to apply a reduction across the leaves +in a tree. + +You can learn more in the {ref}`working-with-pytrees` tutorial. diff --git a/docs/multi_process.md b/docs/multi_process.md index a779eb4f12e5..7d7083bde10f 100644 --- a/docs/multi_process.md +++ b/docs/multi_process.md @@ -1,5 +1,7 @@ # Using JAX in multi-host and multi-process environments + + ## Introduction This guide explains how to use JAX in environments such as @@ -12,7 +14,7 @@ operations (e.g. {func}`jax.lax.psum` ) in multi-process settings, although other communication methods may be useful too depending on your use case (e.g. RPC, [mpi4jax](https://github.com/mpi4jax/mpi4jax)). If you’re not already familiar with JAX’s collective operations, we recommend starting with the -{doc}`/jax-101/06-parallelism` notebook. An important requirement of +{doc}`/sharded-computation` section. An important requirement of multi-process environments in JAX is direct communication links between accelerators, e.g. the high-speed interconnects for Cloud TPUs or [NCCL](https://developer.nvidia.com/nccl) for GPUs. These links allow @@ -28,12 +30,15 @@ Key concepts: * Each process has a distinct set of *local* devices it can address. The *global* devices are the set of all devices across all processes. - * Use standard JAX parallelism APIs like {func}`~jax.pmap` and - {func}`~jax.experimental.maps.xmap` . Each process “sees” *local* input and - output to parallelized functions, but communication inside the computations - is *global*. - * Make sure all processes run the same parallel computations in the same + * Use standard JAX parallelism APIs like {func}`~jax.jit` (see + {doc}`/sharded-computation` tutorial) and + {func}`~jax.experimental.shard_map.shard_map`. jax.jit only accepts + globally shaped arrays. shard_map allows you to drop to per-device + shape. + * Make sure all processes run the same parallel computations in the same order. + * Make sure all processes has the same number of local devices. + * Make sure all devices are the same (e.g., all V100, or all H100). ### Launching JAX processes @@ -62,6 +67,9 @@ The API {func}`jax.distributed.initialize` takes several arguments, namely: with a port available on that process. Process 0 will start a JAX service exposed via that IP address and port, to which the other processes in the cluster will connect. + * `coordinator_bind_address`: the IP address and port to which the JAX service + on process 0 in your cluster will bind. By default, it will bind to all + available interfaces using the same port as `coordinator_address`. 
* `num_processes`: the number of processes in the cluster * `process_id`: the ID number of this process, in the range `[0 .. num_processes)`. @@ -106,13 +114,13 @@ only launch computations on the 8 TPU cores attached directly to that host (see the [Cloud TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture) documentation for more details). You can see a process’s local devices via -{func}`jax.local_devices()` . +{func}`jax.local_devices()`. **The *global* devices are the devices across all processes.** A computation can span devices across processes and perform collective operations via the direct communication links between devices, as long as each process launches the computation on its local devices. You can see all available global devices via -{func}`jax.devices()` . A process’s local devices are always a subset of the +{func}`jax.devices()`. A process’s local devices are always a subset of the global devices. ### Running multi-process computations @@ -120,17 +128,13 @@ global devices. So how do you actually run a computation involving cross-process communication? **Use the same parallel evaluation APIs that you would in a single process!** -For example, {func}`~jax.pmap` can be used to run a parallel computation across -multiple processes. (If you’re not already familiar with how to use -{func}`~jax.pmap` to run across multiple devices within a single process, check -out the {doc}`/jax-101/06-parallelism` notebook.) Each process should call the -same pmapped function and pass in arguments to be mapped across its *local* -devices (i.e., the pmapped axis size is equal to the number of local devices). -Similarly, the function will return outputs sharded across *local* devices only. -Inside the function, however, collective communication operations are run across -all *global* devices, across all processes. Conceptually, this can be thought of -as running a pmap over a single array sharded across hosts, where each host -“sees” only its local shard of the input and output. +For example, {func}`~jax.experimental.shard_map.shard_map` can be used +to run a parallel computation across multiple processes. (If you’re +not already familiar with how to use `shard_map` to run across +multiple devices within a single process, check out the +{doc}`/sharded-computation` tutorial.) Conceptually, this can be +thought of as running a pmap over a single array sharded across hosts, +where each host “sees” only its local shard of the input and output. Here’s an example of multi-process pmap in action: @@ -148,12 +152,6 @@ Here’s an example of multi-process pmap in action: ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32) ``` -{func}`~jax.experimental.maps.xmap` works similarly when using a physical -hardware mesh (see the {doc}`xmap tutorial` if you’re -not familiar with the single-process version). Like {func}`~jax.pmap` , the -inputs and outputs are local and any parallel communication inside the xmapped -function is global. The mesh is also global. - **It’s very important that all processes run the same cross-process computations in the same order.** Running the same JAX Python program in each process is usually sufficient. 
Some common pitfalls to look out for that may cause diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index cd89e425956d..2665e25fdd43 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -8,6 +8,8 @@ "source": [ "# 🔪 JAX - The Sharp Bits 🔪\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)" ] }, @@ -920,7 +922,7 @@ "id": "ORMVVGZJgSVi" }, "source": [ - "Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up." + "Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up." ] }, { @@ -1006,7 +1008,7 @@ "source": [ "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", "\n", - "The random state is described by two unsigned-int32s that we call a __key__:" + "The random state is described by a special array element that we call a __key__:" ] }, { @@ -1030,7 +1032,7 @@ ], "source": [ "from jax import random\n", - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "key" ] }, @@ -1946,9 +1948,9 @@ "\n", "* setting the `JAX_DEBUG_NANS=True` environment variable;\n", "\n", - "* adding `from jax import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n", + "* adding `jax.config.update(\"jax_debug_nans\", True)` near the top of your main file;\n", "\n", - "* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n", + "* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n", "\n", "This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. 
For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n", "\n", @@ -2121,7 +2123,7 @@ } ], "source": [ - "x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n", + "x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n", "x.dtype" ] }, @@ -2135,30 +2137,30 @@ "\n", "There are a few ways to do this:\n", "\n", - "1. You can enable 64bit mode by setting the environment variable `JAX_ENABLE_X64=True`.\n", + "1. You can enable 64-bit mode by setting the environment variable `JAX_ENABLE_X64=True`.\n", "\n", "2. You can manually set the `jax_enable_x64` configuration flag at startup:\n", "\n", " ```python\n", " # again, this only works on startup!\n", - " from jax import config\n", - " config.update(\"jax_enable_x64\", True)\n", + " import jax\n", + " jax.config.update(\"jax_enable_x64\", True)\n", " ```\n", "\n", "3. You can parse command-line flags with `absl.app.run(main)`\n", "\n", " ```python\n", - " from jax import config\n", - " config.config_with_absl()\n", + " import jax\n", + " jax.config.config_with_absl()\n", " ```\n", "\n", "4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use\n", "\n", " ```python\n", - " from jax import config\n", + " import jax\n", " if __name__ == '__main__':\n", - " # calls config.config_with_absl() *and* runs absl parsing\n", - " config.parse_flags_with_absl()\n", + " # calls jax.config.config_with_absl() *and* runs absl parsing\n", + " jax.config.parse_flags_with_absl()\n", " ```\n", "\n", "Note that #2-#4 work for _any_ of JAX's configuration options.\n", @@ -2188,7 +2190,7 @@ "source": [ "import jax.numpy as jnp\n", "from jax import random\n", - "x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n", + "x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n", "x.dtype # --> dtype('float64')" ] }, @@ -2223,6 +2225,7 @@ "\n", " >>> jnp.arange(254.0, 258.0).astype('uint8')\n", " Array([254, 255, 255, 255], dtype=uint8)\n", + "\n", " ```\n", " This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n", "\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 93695cb65c75..58fcb4310bc7 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -16,6 +16,8 @@ kernelspec: # 🔪 JAX - The Sharp Bits 🔪 + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) +++ {"id": "4k5PVzEo2uJO"} @@ -407,7 +409,7 @@ print(np.random.random()) +++ {"id": "ORMVVGZJgSVi"} -Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. 
The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32bit unsigned ints__ and a __position__ indicating how much of this "entropy" has been used up. +Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this "entropy" has been used up. ```{code-cell} ipython3 :id: 7Pyp2ajzfPO2 @@ -463,14 +465,14 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation. -The random state is described by two unsigned-int32s that we call a __key__: +The random state is described by a special array element that we call a __key__: ```{code-cell} ipython3 :id: yPHE7KTWgAWs :outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3 from jax import random -key = random.PRNGKey(0) +key = random.key(0) key ``` @@ -938,9 +940,9 @@ If you want to trace where NaNs are occurring in your functions or gradients, yo * setting the `JAX_DEBUG_NANS=True` environment variable; -* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file; +* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file; -* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; +* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time. @@ -1071,7 +1073,7 @@ At the moment, JAX by default enforces single-precision numbers to mitigate the :id: CNNGtzM3NDkO :outputId: b422bb23-a784-44dc-f8c9-57f3b6c861b8 -x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64) +x = random.uniform(random.key(0), (1000,), dtype=jnp.float64) x.dtype ``` @@ -1081,30 +1083,30 @@ To use double-precision numbers, you need to set the `jax_enable_x64` configurat There are a few ways to do this: -1. You can enable 64bit mode by setting the environment variable `JAX_ENABLE_X64=True`. +1. You can enable 64-bit mode by setting the environment variable `JAX_ENABLE_X64=True`. 2. You can manually set the `jax_enable_x64` configuration flag at startup: ```python # again, this only works on startup! - from jax import config - config.update("jax_enable_x64", True) + import jax + jax.config.update("jax_enable_x64", True) ``` 3. 
You can parse command-line flags with `absl.app.run(main)` ```python - from jax import config - config.config_with_absl() + import jax + jax.config.config_with_absl() ``` 4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use ```python - from jax import config + import jax if __name__ == '__main__': - # calls config.config_with_absl() *and* runs absl parsing - config.parse_flags_with_absl() + # calls jax.config.config_with_absl() *and* runs absl parsing + jax.config.parse_flags_with_absl() ``` Note that #2-#4 work for _any_ of JAX's configuration options. @@ -1117,7 +1119,7 @@ We can then confirm that `x64` mode is enabled: import jax.numpy as jnp from jax import random -x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64) +x = random.uniform(random.key(0), (1000,), dtype=jnp.float64) x.dtype # --> dtype('float64') ``` @@ -1143,6 +1145,7 @@ Many such cases are discussed in detail in the sections above; here we list seve >>> jnp.arange(254.0, 258.0).astype('uint8') Array([254, 255, 255, 255], dtype=uint8) + ``` This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa. diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index 70f499a89a5d..3abb6d9cbaec 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -8,6 +8,8 @@ "source": [ "# Custom derivative rules for JAX-transformable Python functions\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", "\n", "*mattjj@ Mar 19 2020, last updated Oct 14 2020*\n", diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 4ed598b2988e..ad577d55cd0d 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 name: python3 @@ -15,6 +15,8 @@ kernelspec: # Custom derivative rules for JAX-transformable Python functions + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) *mattjj@ Mar 19 2020, last updated Oct 14 2020* diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index b6ea0ffe14e2..2face1d4a0b2 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -6,7 +6,9 @@ "id": 
"PxHrg4Cjuapm" }, "source": [ - "# Distributed arrays and automatic parallelization" + "# Distributed arrays and automatic parallelization\n", + "\n", + "" ] }, { @@ -17,11 +19,7 @@ "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n", "\n", - "This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer.\n", - "\n", - "Refer to the [`jax.Array migration`](https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration) guide to learn how to migrate the existing JAX pre-v0.4.1 codebases to `jax.Array`.\n", - "\n", - "**Note:** The features required by `jax.Array` are not supported by the Colab TPU runtime at this time, but are available on Google Cloud TPU and Kaggle TPU VMs." + "This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer." ] }, { @@ -131,7 +129,7 @@ ], "source": [ "# Create an array of random values:\n", - "x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n", + "x = jax.random.normal(jax.random.key(0), (8192, 8192))\n", "# and use jax.device_put to distribute it across devices:\n", "y = jax.device_put(x, sharding.reshape(4, 2))\n", "jax.debug.visualize_array_sharding(y)" @@ -272,7 +270,7 @@ "outputs": [], "source": [ "import jax\n", - "x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))" + "x = jax.random.normal(jax.random.key(0), (8192, 8192))" ] }, { @@ -416,6 +414,8 @@ "id": "uRLpOcmNj_Vt" }, "source": [ + "The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n", + "\n", "By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. 
Then we can reshape it:" ] }, @@ -1513,7 +1513,7 @@ }, "outputs": [], "source": [ - "x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n", + "x = jax.random.normal(jax.random.key(0), (8192, 8192))\n", "x = jax.device_put(x, sharding.reshape(4, 2))" ] }, @@ -1738,7 +1738,7 @@ "layer_sizes = [784, 8192, 8192, 8192, 10]\n", "batch_size = 8192\n", "\n", - "params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)" + "params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)" ] }, { @@ -2184,7 +2184,7 @@ " numbers = jax.random.uniform(key, x.shape)\n", " return x + numbers\n", "\n", - "key = jax.random.PRNGKey(42)\n", + "key = jax.random.key(42)\n", "x_sharding = jax.sharding.PositionalSharding(jax.devices())\n", "x = jax.device_put(jnp.arange(24), x_sharding)" ] @@ -2369,6 +2369,7 @@ } ], "metadata": { + "accelerator": "TPU", "colab": { "provenance": [], "toc_visible": true diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index 256e2622410f..b9ec9dc694d2 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 name: python3 @@ -15,16 +15,14 @@ kernelspec: # Distributed arrays and automatic parallelization + + +++ {"id": "pFtQjv4SzHRj"} [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer. -Refer to the [`jax.Array migration`](https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration) guide to learn how to migrate the existing JAX pre-v0.4.1 codebases to `jax.Array`. - -**Note:** The features required by `jax.Array` are not supported by the Colab TPU runtime at this time, but are available on Google Cloud TPU and Kaggle TPU VMs. - ```{code-cell} :id: FNxScTfq3vGF @@ -81,7 +79,7 @@ sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) :outputId: 3b518df8-5c29-4848-acc3-e41df939f30b # Create an array of random values: -x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192)) +x = jax.random.normal(jax.random.key(0), (8192, 8192)) # and use jax.device_put to distribute it across devices: y = jax.device_put(x, sharding.reshape(4, 2)) jax.debug.visualize_array_sharding(y) @@ -144,7 +142,7 @@ For example, here's a value with a single-device `Sharding`: :id: VmoX4SUp3vGJ import jax -x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192)) +x = jax.random.normal(jax.random.key(0), (8192, 8192)) ``` ```{code-cell} @@ -196,6 +194,8 @@ sharding +++ {"id": "uRLpOcmNj_Vt"} +The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device. + By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. 
Then we can reshape it: ```{code-cell} @@ -609,7 +609,7 @@ sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) ```{code-cell} :id: Q1wuDp-L3vGT -x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192)) +x = jax.random.normal(jax.random.key(0), (8192, 8192)) x = jax.device_put(x, sharding.reshape(4, 2)) ``` @@ -720,7 +720,7 @@ def init_model(key, layer_sizes, batch_size): layer_sizes = [784, 8192, 8192, 8192, 10] batch_size = 8192 -params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size) +params, batch = init_model(jax.random.key(0), layer_sizes, batch_size) ``` +++ {"id": "sJv_h0AS2drh"} @@ -902,7 +902,7 @@ def f(key, x): numbers = jax.random.uniform(key, x.shape) return x + numbers -key = jax.random.PRNGKey(42) +key = jax.random.key(42) x_sharding = jax.sharding.PositionalSharding(jax.devices()) x = jax.device_put(jnp.arange(24), x_sharding) ``` diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb index 89e774d84321..f42e3f74b4e3 100644 --- a/docs/notebooks/How_JAX_primitives_work.ipynb +++ b/docs/notebooks/How_JAX_primitives_work.ipynb @@ -8,6 +8,8 @@ "source": [ "# How JAX primitives work\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n", "\n", "*necula@google.com*, October 2019.\n", diff --git a/docs/notebooks/How_JAX_primitives_work.md b/docs/notebooks/How_JAX_primitives_work.md index a24854009da0..0ebf202f2258 100644 --- a/docs/notebooks/How_JAX_primitives_work.md +++ b/docs/notebooks/How_JAX_primitives_work.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 name: python3 @@ -15,6 +15,8 @@ kernelspec: # How JAX primitives work + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) *necula@google.com*, October 2019. 
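The recurring change in the hunks above is the move from `random.PRNGKey` to `random.key`, which returns a typed key array rather than a raw pair of unsigned ints. As a hedged illustration (assuming a JAX release where `jax.random.key` is available, as this patch does), forking a key for independent draws looks like:

```python
import jax

# Create a typed PRNG key, then fork it so each draw gets its own stream.
key = jax.random.key(0)
key, subkey = jax.random.split(key)

# The subkey drives one batch of pseudorandom values; the carried-over
# `key` can be split again later without reusing entropy.
x = jax.random.normal(subkey, (3,))
print(x)
```

Reusing a single key for multiple draws would produce correlated values, which is why the updated examples always construct a fresh key or split an existing one.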
diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index fb0ac165be16..f0c157655790 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -8,6 +8,8 @@ "source": [ "# Training a Simple Neural Network, with PyTorch Data Loading\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n", "\n", "**Copyright 2018 The JAX Authors.**\n", @@ -32,7 +34,7 @@ "source": [ "![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n", "\n", - "Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n", + "Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n", "\n", "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model." 
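For context on the notebook above: the MLP parameters referenced later via `init_network_params(layer_sizes, random.key(0))` are plain pytrees of arrays. A rough, illustrative reconstruction of such helpers with the new-style keys (not the notebook's exact code, and the layer sizes here are placeholders) might look like:

```python
from jax import random

def random_layer_params(m, n, key, scale=1e-2):
    # One weight matrix and one bias vector, drawn from independent subkeys.
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

def init_network_params(sizes, key):
    # One (weights, bias) pair per layer, each with its own subkey.
    keys = random.split(key, len(sizes) - 1)
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

params = init_network_params([784, 512, 512, 10], random.key(0))
```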
] @@ -84,7 +86,7 @@ "num_epochs = 8\n", "batch_size = 128\n", "n_targets = 10\n", - "params = init_network_params(layer_sizes, random.PRNGKey(0))" + "params = init_network_params(layer_sizes, random.key(0))" ] }, { @@ -150,7 +152,7 @@ ], "source": [ "# This works on single examples\n", - "random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n", + "random_flattened_image = random.normal(random.key(1), (28 * 28,))\n", "preds = predict(params, random_flattened_image)\n", "print(preds.shape)" ] @@ -173,7 +175,7 @@ ], "source": [ "# Doesn't work with a batch\n", - "random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))\n", + "random_flattened_images = random.normal(random.key(1), (10, 28 * 28))\n", "try:\n", " preds = predict(params, random_flattened_images)\n", "except TypeError:\n", diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index ebbb6da3d107..2c53bb1e4ab5 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -16,6 +16,8 @@ kernelspec: # Training a Simple Neural Network, with PyTorch Data Loading + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) **Copyright 2018 The JAX Authors.** @@ -35,7 +37,7 @@ limitations under the License. ![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png) -Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library). +Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library). Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model. @@ -71,7 +73,7 @@ step_size = 0.01 num_epochs = 8 batch_size = 128 n_targets = 10 -params = init_network_params(layer_sizes, random.PRNGKey(0)) +params = init_network_params(layer_sizes, random.key(0)) ``` +++ {"id": "BtoNk_yxWtIw"} @@ -109,7 +111,7 @@ Let's check that our prediction function only works on single images. 
:outputId: 9d3b29e8-fab3-4ecb-9f63-bc8c092f9006 # This works on single examples -random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,)) +random_flattened_image = random.normal(random.key(1), (28 * 28,)) preds = predict(params, random_flattened_image) print(preds.shape) ``` @@ -119,7 +121,7 @@ print(preds.shape) :outputId: d5d20211-b6da-44e9-f71e-946f2a9d0fc4 # Doesn't work with a batch -random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28)) +random_flattened_images = random.normal(random.key(1), (10, 28 * 28)) try: preds = predict(params, random_flattened_images) except TypeError: diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index aac86fd8b710..7e65aefe359c 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -8,6 +8,8 @@ "source": [ "# Writing custom Jaxpr interpreters in JAX\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)" ] }, @@ -66,7 +68,7 @@ }, "outputs": [], "source": [ - "x = random.normal(random.PRNGKey(0), (5000, 5000))\n", + "x = random.normal(random.key(0), (5000, 5000))\n", "def f(w, b, x):\n", " return jnp.tanh(jnp.dot(x, w) + b)\n", "fast_f = jit(f)" diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index af4379b03802..e52c6a5f8742 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -16,6 +16,8 @@ kernelspec: # Writing custom Jaxpr interpreters in JAX + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) +++ {"id": "r-3vMiKRYXPJ"} @@ -48,7 +50,7 @@ JAX provides a NumPy-like API for numerical computing which can be used as is, b ```{code-cell} ipython3 :id: HmlMcICOcSXR -x = random.normal(random.PRNGKey(0), (5000, 5000)) +x = random.normal(random.key(0), (5000, 5000)) def f(w, b, x): return jnp.tanh(jnp.dot(x, w) + b) fast_f = jit(f) diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index 39aad749de47..edfd0d4535f8 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -8,6 +8,8 @@ "source": [ "# The Autodiff Cookbook\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in 
Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n", "\n", "*alexbw@, mattjj@* \n", @@ -27,7 +29,7 @@ "from jax import grad, jit, vmap\n", "from jax import random\n", "\n", - "key = random.PRNGKey(0)" + "key = random.key(0)" ] }, { @@ -1055,7 +1057,7 @@ " outs, = vmap(vjp_fun)(M)\n", " return outs\n", "\n", - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "num_covecs = 128\n", "U = random.normal(key, (num_covecs,) + y.shape)\n", "\n", @@ -1306,7 +1308,7 @@ "outputs": [], "source": [ "def check(seed):\n", - " key = random.PRNGKey(seed)\n", + " key = random.key(seed)\n", "\n", " # random coeffs for u and v\n", " key, subkey = random.split(key)\n", @@ -1399,7 +1401,7 @@ "outputs": [], "source": [ "def check(seed):\n", - " key = random.PRNGKey(seed)\n", + " key = random.key(seed)\n", "\n", " # random coeffs for u and v\n", " key, subkey = random.split(key)\n", diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index fa50db4f2194..c24d05c0e7c9 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -16,6 +16,8 @@ kernelspec: # The Autodiff Cookbook + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) *alexbw@, mattjj@* @@ -29,7 +31,7 @@ import jax.numpy as jnp from jax import grad, jit, vmap from jax import random -key = random.PRNGKey(0) +key = random.key(0) ``` +++ {"id": "YxnjtAGN6vu2"} @@ -614,7 +616,7 @@ def vmap_mjp(f, x, M): outs, = vmap(vjp_fun)(M) return outs -key = random.PRNGKey(0) +key = random.key(0) num_covecs = 128 U = random.normal(key, (num_covecs,) + y.shape) @@ -770,7 +772,7 @@ Here's a check: :id: BGZV__zupIMS def check(seed): - key = random.PRNGKey(seed) + key = random.key(seed) # random coeffs for u and v key, subkey = random.split(key) @@ -833,7 +835,7 @@ Here's a check of the VJP rules: :id: 4J7edvIBttcU def check(seed): - key = random.PRNGKey(seed) + key = random.key(seed) # random coeffs for u and v key, subkey = random.split(key) diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index 9aec8b1a23df..f0552e52688f 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -6,7 +6,9 @@ "id": "29WqUVkCXjDD" }, "source": [ - "## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)" + "## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)\n", + "\n", + "" ] }, { diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index 6e12fb5c3155..b31e093b6f91 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 name: python3 @@ -15,6 +15,8 @@ kernelspec: ## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`) + + 
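As a quick, hedged sketch of what this section's decorator does (standard public API only, not the notebook's exact cells): `jax.checkpoint` makes the backward pass recompute intermediates instead of storing them.

```python
import jax
import jax.numpy as jnp

def g(x):
    return jnp.sin(jnp.sin(x))

# Under reverse-mode autodiff, the plain version saves the inner sin's output;
# the rematerialized version recomputes it, trading FLOPs for memory.
g_remat = jax.checkpoint(g)

print(jax.grad(g)(2.0))        # same gradient value either way
print(jax.grad(g_remat)(2.0))  # only the memory/compute trade-off differs
```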
```{code-cell} import jax import jax.numpy as jnp diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb index f8dcaa3685ad..0a823353068b 100644 --- a/docs/notebooks/convolutions.ipynb +++ b/docs/notebooks/convolutions.ipynb @@ -8,6 +8,8 @@ "source": [ "# Generalized Convolutions in JAX\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)\n", "\n", "JAX provides a number of interfaces to compute convolutions across data, including:\n", @@ -60,7 +62,7 @@ "import jax.numpy as jnp\n", "import numpy as np\n", "\n", - "key = random.PRNGKey(1701)\n", + "key = random.key(1701)\n", "\n", "x = jnp.linspace(0, 10, 500)\n", "y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))\n", @@ -130,7 +132,7 @@ "ax[0].set_title('original')\n", "\n", "# Create a noisy version by adding random Gaussian noise\n", - "key = random.PRNGKey(1701)\n", + "key = random.key(1701)\n", "noisy_image = image + 50 * random.normal(key, image.shape)\n", "ax[1].imshow(noisy_image, cmap='binary_r')\n", "ax[1].set_title('noisy')\n", diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md index 5d34ef950021..3de8f261aa5b 100644 --- a/docs/notebooks/convolutions.md +++ b/docs/notebooks/convolutions.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -16,6 +16,8 @@ kernelspec: # Generalized Convolutions in JAX + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb) JAX provides a number of interfaces to compute convolutions across data, including: @@ -43,7 +45,7 @@ from jax import random import jax.numpy as jnp import numpy as np -key = random.PRNGKey(1701) +key = random.key(1701) x = jnp.linspace(0, 10, 500) y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,)) @@ -84,7 +86,7 @@ ax[0].imshow(image, cmap='binary_r') ax[0].set_title('original') # Create a noisy version by adding random Gaussian noise -key = random.PRNGKey(1701) +key = random.key(1701) noisy_image = image + 50 * random.normal(key, image.shape) ax[1].imshow(noisy_image, cmap='binary_r') ax[1].set_title('noisy') diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb index 5cda80620961..bdf71004c01b 100644 --- a/docs/notebooks/external_callbacks.ipynb +++ b/docs/notebooks/external_callbacks.ipynb @@ -6,7 +6,9 @@ "id": "7XNMxdTwURqI" }, "source": [ - "# External Callbacks in JAX" + "# External Callbacks in JAX\n", + "\n", + "" ] }, { diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md index 8387a5160a24..857eef42e2b3 100644 --- a/docs/notebooks/external_callbacks.md +++ b/docs/notebooks/external_callbacks.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 
name: python3 @@ -15,6 +15,8 @@ kernelspec: # External Callbacks in JAX + + +++ {"id": "h6lXo6bSUYGq"} This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation. diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index 8a14bb7bbec9..95c00bf1e689 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -38,13 +38,15 @@ "source": [ "# Training a Simple Neural Network, with tensorflow/datasets Data Loading\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n", "\n", "_Forked from_ `neural_network_and_data_loading.ipynb`\n", "\n", "![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n", "\n", - "Let's combine everything we showed in the [quickstart notebook](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", + "Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", "\n", "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model." 
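Tying back to the external-callbacks hunk earlier in this patch: a minimal, hedged sketch of running host-side Python under `jit` with `jax.pure_callback` (the helper name here is illustrative):

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_log1p(x):
    # Ordinary NumPy, executed on the host rather than traced/compiled.
    return np.log1p(x)

@jax.jit
def f(x):
    # The ShapeDtypeStruct tells JAX the shape/dtype the callback will return.
    return jax.pure_callback(host_log1p, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

print(f(jnp.arange(4.0)))
```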
] @@ -97,7 +99,7 @@ "num_epochs = 10\n", "batch_size = 128\n", "n_targets = 10\n", - "params = init_network_params(layer_sizes, random.PRNGKey(0))" + "params = init_network_params(layer_sizes, random.key(0))" ] }, { @@ -163,7 +165,7 @@ ], "source": [ "# This works on single examples\n", - "random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n", + "random_flattened_image = random.normal(random.key(1), (28 * 28,))\n", "preds = predict(params, random_flattened_image)\n", "print(preds.shape)" ] @@ -186,7 +188,7 @@ ], "source": [ "# Doesn't work with a batch\n", - "random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))\n", + "random_flattened_images = random.normal(random.key(1), (10, 28 * 28))\n", "try:\n", " preds = predict(params, random_flattened_images)\n", "except TypeError:\n", diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index b3c2be06bffb..8f795484d5b9 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -36,13 +36,15 @@ limitations under the License. # Training a Simple Neural Network, with tensorflow/datasets Data Loading + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) _Forked from_ `neural_network_and_data_loading.ipynb` ![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png) -Let's combine everything we showed in the [quickstart notebook](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). +Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model. @@ -79,7 +81,7 @@ step_size = 0.01 num_epochs = 10 batch_size = 128 n_targets = 10 -params = init_network_params(layer_sizes, random.PRNGKey(0)) +params = init_network_params(layer_sizes, random.key(0)) ``` +++ {"id": "BtoNk_yxWtIw"} @@ -117,7 +119,7 @@ Let's check that our prediction function only works on single images. 
:outputId: ce9d86ed-a830-4832-e04d-10d1abb1fb8a # This works on single examples -random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,)) +random_flattened_image = random.normal(random.key(1), (28 * 28,)) preds = predict(params, random_flattened_image) print(preds.shape) ``` @@ -127,7 +129,7 @@ print(preds.shape) :outputId: f43bbc9d-bc8f-4168-ee7b-79ee9d33f245 # Doesn't work with a batch -random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28)) +random_flattened_images = random.normal(random.key(1), (10, 28 * 28)) try: preds = predict(params, random_flattened_images) except TypeError: diff --git a/docs/notebooks/quickstart.ipynb b/docs/notebooks/quickstart.ipynb deleted file mode 100644 index 722047adadbd..000000000000 --- a/docs/notebooks/quickstart.ipynb +++ /dev/null @@ -1,609 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "xtWX4x9DCF5_" - }, - "source": [ - "# JAX Quickstart\n", - "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/quickstart.ipynb)\n", - "\n", - "**JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.**\n", - "\n", - "With its updated version of [Autograd](https://github.com/hips/autograd), JAX\n", - "can automatically differentiate native Python and NumPy code. It can\n", - "differentiate through a large subset of Python’s features, including loops, ifs,\n", - "recursion, and closures, and it can even take derivatives of derivatives of\n", - "derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily\n", - "to any order.\n", - "\n", - "What’s new is that JAX uses\n", - "[XLA](https://www.tensorflow.org/xla)\n", - "to compile and run your NumPy code on accelerators, like GPUs and TPUs.\n", - "Compilation happens under the hood by default, with library calls getting\n", - "just-in-time compiled and executed. But JAX even lets you just-in-time compile\n", - "your own Python functions into XLA-optimized kernels using a one-function API.\n", - "Compilation and automatic differentiation can be composed arbitrarily, so you\n", - "can express sophisticated algorithms and get maximal performance without having\n", - "to leave Python." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "SY8mDvEvCGqk" - }, - "outputs": [], - "source": [ - "import jax.numpy as jnp\n", - "from jax import grad, jit, vmap\n", - "from jax import random" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FQ89jHCYfhpg" - }, - "source": [ - "## Multiplying Matrices" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Xpy1dSgNqCP4" - }, - "source": [ - "We'll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers. 
For more details, see [Common Gotchas in JAX].\n", - "\n", - "[Common Gotchas in JAX]: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "u0nseKZNqOoH", - "outputId": "03e20e21-376c-41bb-a6bb-57431823691b" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[-0.3721109 0.26423115 -0.18252768 -0.7368197 -0.44030377 -0.1521442\n", - " -0.67135346 -0.5908641 0.73168886 0.5673026 ]\n" - ] - } - ], - "source": [ - "key = random.PRNGKey(0)\n", - "x = random.normal(key, (10,))\n", - "print(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hDJF0UPKnuqB" - }, - "source": [ - "Let's dive right in and multiply two big matrices." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "eXn8GUl6CG5N", - "outputId": "ffce6bdc-86e6-4af0-ab5d-65d235022db9" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "13.5 ms ± 1.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "size = 3000\n", - "x = random.normal(key, (size, size), dtype=jnp.float32)\n", - "%timeit jnp.dot(x, x.T).block_until_ready() # runs on the GPU" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0AlN7EbonyaR" - }, - "source": [ - "We added that `block_until_ready` because JAX uses asynchronous execution by default (see {ref}`async-dispatch`).\n", - "\n", - "JAX NumPy functions work on regular NumPy arrays." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "ZPl0MuwYrM7t", - "outputId": "71219657-b559-474e-a877-5441ee39f18f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "80 ms ± 30.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "x = np.random.normal(size=(size, size)).astype(np.float32)\n", - "%timeit jnp.dot(x, x.T).block_until_ready()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_SrcB2IurUuE" - }, - "source": [ - "That's slower because it has to transfer data to the GPU every time. You can ensure that an NDArray is backed by device memory using {func}`~jax.device_put`." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "Jj7M7zyRskF0", - "outputId": "a649a6d3-cf28-445e-c3fc-bcfe3069482c" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "15.8 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" - ] - } - ], - "source": [ - "from jax import device_put\n", - "\n", - "x = np.random.normal(size=(size, size)).astype(np.float32)\n", - "x = device_put(x)\n", - "%timeit jnp.dot(x, x.T).block_until_ready()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "clO9djnen8qi" - }, - "source": [ - "The output of {func}`~jax.device_put` still acts like an NDArray, but it only copies values back to the CPU when they're needed for printing, plotting, saving to disk, branching, etc. The behavior of {func}`~jax.device_put` is equivalent to the function `jit(lambda x: x)`, but it's faster." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ghkfKNQttDpg" - }, - "source": [ - "If you have a GPU (or TPU!) 
these calls run on the accelerator and have the potential to be much faster than on CPU.\n", - "See {ref}`faq-jax-vs-numpy` for more comparison of performance characteristics of NumPy and JAX" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "iOzp0P_GoJhb" - }, - "source": [ - "JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:\n", - "\n", - " - {func}`~jax.jit`, for speeding up your code\n", - " - {func}`~jax.grad`, for taking derivatives\n", - " - {func}`~jax.vmap`, for automatic vectorization or batching.\n", - "\n", - "Let's go over these, one-by-one. We'll also end up composing these in interesting ways." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bTTrTbWvgLUK" - }, - "source": [ - "## Using {func}`~jax.jit` to speed up functions" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YrqE32mvE3b7" - }, - "source": [ - "JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the `@jit` decorator to compile multiple operations together using [XLA](https://www.tensorflow.org/xla). Let's try that." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "qLGdCtFKFLOR", - "outputId": "870253fa-ba1b-47ec-c5a4-1c6f706be996" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.07 ms ± 261 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "def selu(x, alpha=1.67, lmbda=1.05):\n", - " return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n", - "\n", - "x = random.normal(key, (1000000,))\n", - "%timeit selu(x).block_until_ready()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "a_V8SruVHrD_" - }, - "source": [ - "We can speed it up with `@jit`, which will jit-compile the first time `selu` is called and will be cached thereafter." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "fh4w_3NpFYTp", - "outputId": "4d56b4f2-5d58-4689-ecc2-ac361c0245cd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "127 µs ± 1.43 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" - ] - } - ], - "source": [ - "selu_jit = jit(selu)\n", - "%timeit selu_jit(x).block_until_ready()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HxpBc4WmfsEU" - }, - "source": [ - "## Taking derivatives with {func}`~jax.grad`\n", - "\n", - "In addition to evaluating numerical functions, we also want to transform them. One transformation is [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation). In JAX, just like in [Autograd](https://github.com/HIPS/autograd), you can compute gradients with the {func}`~jax.grad` function." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "IMAgNJaMJwPD", - "outputId": "6646cc65-b52f-4825-ff7f-e50b67083493" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0.25 0.19661194 0.10499357]\n" - ] - } - ], - "source": [ - "def sum_logistic(x):\n", - " return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))\n", - "\n", - "x_small = jnp.arange(3.)\n", - "derivative_fn = grad(sum_logistic)\n", - "print(derivative_fn(x_small))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PtNs881Ohioc" - }, - "source": [ - "Let's verify with finite differences that our result is correct." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "JXI7_OZuKZVO", - "outputId": "18c1f913-d5d6-4895-f71e-e62180c3ad1b" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0.24998187 0.1965761 0.10502338]\n" - ] - } - ], - "source": [ - "def first_finite_differences(f, x):\n", - " eps = 1e-3\n", - " return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)\n", - " for v in jnp.eye(len(x))])\n", - "\n", - "\n", - "print(first_finite_differences(sum_logistic, x_small))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Q2CUZjOWNZ-3" - }, - "source": [ - "Taking derivatives is as easy as calling {func}`~jax.grad`. {func}`~jax.grad` and {func}`~jax.jit` compose and can be mixed arbitrarily. In the above example we jitted `sum_logistic` and then took its derivative. We can go further:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "TO4g8ny-OEi4", - "outputId": "1a0421e6-60e9-42e3-dc9c-e558a69bbf17" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-0.0353256\n" - ] - } - ], - "source": [ - "print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yCJ5feKvhnBJ" - }, - "source": [ - "For more advanced autodiff, you can use {func}`jax.vjp` for reverse-mode vector-Jacobian products and {func}`jax.jvp` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose them to make a function that efficiently computes full Hessian matrices:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "Z-JxbiNyhxEW" - }, - "outputs": [], - "source": [ - "from jax import jacfwd, jacrev\n", - "def hessian(fun):\n", - " return jit(jacfwd(jacrev(fun)))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TI4nPsGafxbL" - }, - "source": [ - "## Auto-vectorization with {func}`~jax.vmap`" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PcxkONy5aius" - }, - "source": [ - "JAX has one more transformation in its API that you might find useful: {func}`~jax.vmap`, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with {func}`~jax.jit`, it can be just as fast as adding the batch dimensions by hand." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TPiX4y-bWLFS" - }, - "source": [ - "We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using {func}`~jax.vmap`. 
Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "8w0Gpsn8WYYj" - }, - "outputs": [], - "source": [ - "mat = random.normal(key, (150, 100))\n", - "batched_x = random.normal(key, (10, 100))\n", - "\n", - "def apply_matrix(v):\n", - " return jnp.dot(mat, v)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0zWsc0RisQWx" - }, - "source": [ - "Given a function such as `apply_matrix`, we can loop over a batch dimension in Python, but usually the performance of doing so is poor." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "KWVc9BsZv0Ki", - "outputId": "bea78b6d-cd17-45e6-c361-1c55234e77c0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Naively batched\n", - "3.12 ms ± 176 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "def naively_batched_apply_matrix(v_batched):\n", - " return jnp.stack([apply_matrix(v) for v in v_batched])\n", - "\n", - "print('Naively batched')\n", - "%timeit naively_batched_apply_matrix(batched_x).block_until_ready()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qHfKaLE9stbA" - }, - "source": [ - "We know how to batch this operation manually. In this case, `jnp.dot` handles extra batch dimensions transparently." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "ipei6l8nvrzH", - "outputId": "335cdc4c-c603-497b-fc88-3fa37c5630c2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Manually batched\n", - "45.6 µs ± 5.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" - ] - } - ], - "source": [ - "@jit\n", - "def batched_apply_matrix(v_batched):\n", - " return jnp.dot(v_batched, mat.T)\n", - "\n", - "print('Manually batched')\n", - "%timeit batched_apply_matrix(batched_x).block_until_ready()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1eF8Nhb-szAb" - }, - "source": [ - "However, suppose we had a more complicated function without batching support. We can use {func}`~jax.vmap` to add batching support automatically." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "67Oeknf5vuCl", - "outputId": "9c680e74-ebb5-4563-ebfc-869fd82de091" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Auto-vectorized with vmap\n", - "48.3 µs ± 1.06 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" - ] - } - ], - "source": [ - "@jit\n", - "def vmap_batched_apply_matrix(v_batched):\n", - " return vmap(apply_matrix)(v_batched)\n", - "\n", - "print('Auto-vectorized with vmap')\n", - "%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pYVl3Z2nbZhO" - }, - "source": [ - "Of course, {func}`~jax.vmap` can be arbitrarily composed with {func}`~jax.jit`, {func}`~jax.grad`, and any other JAX transformation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WwNnjaI4th_8" - }, - "source": [ - "This is just a taste of what JAX can do. We're really excited to see what you do with it!" 
- ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "name": "JAX Quickstart.ipynb", - "provenance": [], - "toc_visible": true - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/notebooks/quickstart.md b/docs/notebooks/quickstart.md deleted file mode 100644 index 46c37f208a8f..000000000000 --- a/docs/notebooks/quickstart.md +++ /dev/null @@ -1,293 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.0 -kernelspec: - display_name: Python 3 - language: python - name: python3 ---- - -+++ {"id": "xtWX4x9DCF5_"} - -# JAX Quickstart - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/quickstart.ipynb) - -**JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.** - -With its updated version of [Autograd](https://github.com/hips/autograd), JAX -can automatically differentiate native Python and NumPy code. It can -differentiate through a large subset of Python’s features, including loops, ifs, -recursion, and closures, and it can even take derivatives of derivatives of -derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily -to any order. - -What’s new is that JAX uses -[XLA](https://www.tensorflow.org/xla) -to compile and run your NumPy code on accelerators, like GPUs and TPUs. -Compilation happens under the hood by default, with library calls getting -just-in-time compiled and executed. But JAX even lets you just-in-time compile -your own Python functions into XLA-optimized kernels using a one-function API. -Compilation and automatic differentiation can be composed arbitrarily, so you -can express sophisticated algorithms and get maximal performance without having -to leave Python. - -```{code-cell} ipython3 -:id: SY8mDvEvCGqk - -import jax.numpy as jnp -from jax import grad, jit, vmap -from jax import random -``` - -+++ {"id": "FQ89jHCYfhpg"} - -## Multiplying Matrices - -+++ {"id": "Xpy1dSgNqCP4"} - -We'll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers. For more details, see [Common Gotchas in JAX]. - -[Common Gotchas in JAX]: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers - -```{code-cell} ipython3 -:id: u0nseKZNqOoH -:outputId: 03e20e21-376c-41bb-a6bb-57431823691b - -key = random.PRNGKey(0) -x = random.normal(key, (10,)) -print(x) -``` - -+++ {"id": "hDJF0UPKnuqB"} - -Let's dive right in and multiply two big matrices. 
- -```{code-cell} ipython3 -:id: eXn8GUl6CG5N -:outputId: ffce6bdc-86e6-4af0-ab5d-65d235022db9 - -size = 3000 -x = random.normal(key, (size, size), dtype=jnp.float32) -%timeit jnp.dot(x, x.T).block_until_ready() # runs on the GPU -``` - -+++ {"id": "0AlN7EbonyaR"} - -We added that `block_until_ready` because JAX uses asynchronous execution by default (see {ref}`async-dispatch`). - -JAX NumPy functions work on regular NumPy arrays. - -```{code-cell} ipython3 -:id: ZPl0MuwYrM7t -:outputId: 71219657-b559-474e-a877-5441ee39f18f - -import numpy as np -x = np.random.normal(size=(size, size)).astype(np.float32) -%timeit jnp.dot(x, x.T).block_until_ready() -``` - -+++ {"id": "_SrcB2IurUuE"} - -That's slower because it has to transfer data to the GPU every time. You can ensure that an NDArray is backed by device memory using {func}`~jax.device_put`. - -```{code-cell} ipython3 -:id: Jj7M7zyRskF0 -:outputId: a649a6d3-cf28-445e-c3fc-bcfe3069482c - -from jax import device_put - -x = np.random.normal(size=(size, size)).astype(np.float32) -x = device_put(x) -%timeit jnp.dot(x, x.T).block_until_ready() -``` - -+++ {"id": "clO9djnen8qi"} - -The output of {func}`~jax.device_put` still acts like an NDArray, but it only copies values back to the CPU when they're needed for printing, plotting, saving to disk, branching, etc. The behavior of {func}`~jax.device_put` is equivalent to the function `jit(lambda x: x)`, but it's faster. - -+++ {"id": "ghkfKNQttDpg"} - -If you have a GPU (or TPU!) these calls run on the accelerator and have the potential to be much faster than on CPU. -See {ref}`faq-jax-vs-numpy` for more comparison of performance characteristics of NumPy and JAX - -+++ {"id": "iOzp0P_GoJhb"} - -JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones: - - - {func}`~jax.jit`, for speeding up your code - - {func}`~jax.grad`, for taking derivatives - - {func}`~jax.vmap`, for automatic vectorization or batching. - -Let's go over these, one-by-one. We'll also end up composing these in interesting ways. - -+++ {"id": "bTTrTbWvgLUK"} - -## Using {func}`~jax.jit` to speed up functions - -+++ {"id": "YrqE32mvE3b7"} - -JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the `@jit` decorator to compile multiple operations together using [XLA](https://www.tensorflow.org/xla). Let's try that. - -```{code-cell} ipython3 -:id: qLGdCtFKFLOR -:outputId: 870253fa-ba1b-47ec-c5a4-1c6f706be996 - -def selu(x, alpha=1.67, lmbda=1.05): - return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) - -x = random.normal(key, (1000000,)) -%timeit selu(x).block_until_ready() -``` - -+++ {"id": "a_V8SruVHrD_"} - -We can speed it up with `@jit`, which will jit-compile the first time `selu` is called and will be cached thereafter. - -```{code-cell} ipython3 -:id: fh4w_3NpFYTp -:outputId: 4d56b4f2-5d58-4689-ecc2-ac361c0245cd - -selu_jit = jit(selu) -%timeit selu_jit(x).block_until_ready() -``` - -+++ {"id": "HxpBc4WmfsEU"} - -## Taking derivatives with {func}`~jax.grad` - -In addition to evaluating numerical functions, we also want to transform them. One transformation is [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation). 
In JAX, just like in [Autograd](https://github.com/HIPS/autograd), you can compute gradients with the {func}`~jax.grad` function. - -```{code-cell} ipython3 -:id: IMAgNJaMJwPD -:outputId: 6646cc65-b52f-4825-ff7f-e50b67083493 - -def sum_logistic(x): - return jnp.sum(1.0 / (1.0 + jnp.exp(-x))) - -x_small = jnp.arange(3.) -derivative_fn = grad(sum_logistic) -print(derivative_fn(x_small)) -``` - -+++ {"id": "PtNs881Ohioc"} - -Let's verify with finite differences that our result is correct. - -```{code-cell} ipython3 -:id: JXI7_OZuKZVO -:outputId: 18c1f913-d5d6-4895-f71e-e62180c3ad1b - -def first_finite_differences(f, x): - eps = 1e-3 - return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps) - for v in jnp.eye(len(x))]) - - -print(first_finite_differences(sum_logistic, x_small)) -``` - -+++ {"id": "Q2CUZjOWNZ-3"} - -Taking derivatives is as easy as calling {func}`~jax.grad`. {func}`~jax.grad` and {func}`~jax.jit` compose and can be mixed arbitrarily. In the above example we jitted `sum_logistic` and then took its derivative. We can go further: - -```{code-cell} ipython3 -:id: TO4g8ny-OEi4 -:outputId: 1a0421e6-60e9-42e3-dc9c-e558a69bbf17 - -print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0)) -``` - -+++ {"id": "yCJ5feKvhnBJ"} - -For more advanced autodiff, you can use {func}`jax.vjp` for reverse-mode vector-Jacobian products and {func}`jax.jvp` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose them to make a function that efficiently computes full Hessian matrices: - -```{code-cell} ipython3 -:id: Z-JxbiNyhxEW - -from jax import jacfwd, jacrev -def hessian(fun): - return jit(jacfwd(jacrev(fun))) -``` - -+++ {"id": "TI4nPsGafxbL"} - -## Auto-vectorization with {func}`~jax.vmap` - -+++ {"id": "PcxkONy5aius"} - -JAX has one more transformation in its API that you might find useful: {func}`~jax.vmap`, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with {func}`~jax.jit`, it can be just as fast as adding the batch dimensions by hand. - -+++ {"id": "TPiX4y-bWLFS"} - -We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using {func}`~jax.vmap`. Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions. - -```{code-cell} ipython3 -:id: 8w0Gpsn8WYYj - -mat = random.normal(key, (150, 100)) -batched_x = random.normal(key, (10, 100)) - -def apply_matrix(v): - return jnp.dot(mat, v) -``` - -+++ {"id": "0zWsc0RisQWx"} - -Given a function such as `apply_matrix`, we can loop over a batch dimension in Python, but usually the performance of doing so is poor. - -```{code-cell} ipython3 -:id: KWVc9BsZv0Ki -:outputId: bea78b6d-cd17-45e6-c361-1c55234e77c0 - -def naively_batched_apply_matrix(v_batched): - return jnp.stack([apply_matrix(v) for v in v_batched]) - -print('Naively batched') -%timeit naively_batched_apply_matrix(batched_x).block_until_ready() -``` - -+++ {"id": "qHfKaLE9stbA"} - -We know how to batch this operation manually. In this case, `jnp.dot` handles extra batch dimensions transparently. 
- -```{code-cell} ipython3 -:id: ipei6l8nvrzH -:outputId: 335cdc4c-c603-497b-fc88-3fa37c5630c2 - -@jit -def batched_apply_matrix(v_batched): - return jnp.dot(v_batched, mat.T) - -print('Manually batched') -%timeit batched_apply_matrix(batched_x).block_until_ready() -``` - -+++ {"id": "1eF8Nhb-szAb"} - -However, suppose we had a more complicated function without batching support. We can use {func}`~jax.vmap` to add batching support automatically. - -```{code-cell} ipython3 -:id: 67Oeknf5vuCl -:outputId: 9c680e74-ebb5-4563-ebfc-869fd82de091 - -@jit -def vmap_batched_apply_matrix(v_batched): - return vmap(apply_matrix)(v_batched) - -print('Auto-vectorized with vmap') -%timeit vmap_batched_apply_matrix(batched_x).block_until_ready() -``` - -+++ {"id": "pYVl3Z2nbZhO"} - -Of course, {func}`~jax.vmap` can be arbitrarily composed with {func}`~jax.jit`, {func}`~jax.grad`, and any other JAX transformation. - -+++ {"id": "WwNnjaI4th_8"} - -This is just a taste of what JAX can do. We're really excited to see what you do with it! diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index 671dc7fbfa79..ed0a13d8702a 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -7,6 +7,8 @@ "source": [ "# SPMD multi-device parallelism with `shard_map`\n", "\n", + "\n", + "\n", "`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n", "\n", "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", @@ -15,7 +17,7 @@ "\n", "By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. 
We'll also give some basic examples of neural network parallelization strategies.\n", "\n", - "We'll assume this tutorial is being run in an environment with eight devices\"" + "We'll assume this tutorial is being run in an environment with eight devices:" ] }, { @@ -259,7 +261,7 @@ "id": "985ff202", "metadata": {}, "source": [ - "Recall that jnp.split slices its input into equally-sized blocks with the same\n", + "Recall that `jnp.split` slices its input into equally-sized blocks with the same\n", "rank, so that if in the above example `y` had shape `f32[8,5]` then each\n", "`y_blk` would have shape `f32[2,5]`, and if each `f(y_blk)` had shape\n", "`f32[3,7]` then the final concatenated result `shard_map(f, ...)(y)` would have\n", @@ -1184,7 +1186,7 @@ "metadata": {}, "source": [ "That's great, but we're not getting any compute/communication overlap\n", - "here: before we can start the matmul, we need the all_gather to complete.\n", + "here: before we can start the matmul, we need the `all_gather` to complete.\n", "Here's a profile using the same code, but on larger example shapes (`(8192,\n", "8192)` for `lhs` and `(8192, 1024)` for `rhs`):\n", "\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 90ca7ca2b4c3..67494cfd4a02 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -7,7 +7,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -16,6 +16,8 @@ kernelspec: # SPMD multi-device parallelism with `shard_map` + + `shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations. `shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. @@ -24,7 +26,7 @@ If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies. 
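To give a flavor of the manual-collective style described above, here is a minimal sketch of a sharded matrix multiply. This is a hedged illustration, not part of this patch: it assumes the eight-device environment mentioned just below, an illustrative 4x2 mesh, and axis names `'x'` and `'y'` chosen only for this example.

```python
# Minimal sketch of manual SPMD code with shard_map (assumes 8 devices).
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x', 'y'), P('y', None)), out_specs=P('x', None))
def matmul_psum(a_blk, b_blk):
  # Each instance sees a (2, 8) block of `a` and an (8, 4) block of `b`;
  # the partial products are combined across the 'y' axis with an
  # explicit collective.
  return jax.lax.psum(a_blk @ b_blk, 'y')

a = jnp.arange(8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)
print(jnp.allclose(matmul_psum(a, b), a @ b))  # True
```

Compare this with letting `jit` partition the same `a @ b` automatically: the result is identical, but here the per-device blocks and the `psum` are spelled out by hand.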
-We'll assume this tutorial is being run in an environment with eight devices" +We'll assume this tutorial is being run in an environment with eight devices: ```{code-cell} import os @@ -166,7 +168,7 @@ def check_shmap(f, y): check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4)) ``` -Recall that jnp.split slices its input into equally-sized blocks with the same +Recall that `jnp.split` slices its input into equally-sized blocks with the same rank, so that if in the above example `y` had shape `f32[8,5]` then each `y_blk` would have shape `f32[2,5]`, and if each `f(y_blk)` had shape `f32[3,7]` then the final concatenated result `shard_map(f, ...)(y)` would have @@ -845,7 +847,7 @@ print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) ``` That's great, but we're not getting any compute/communication overlap -here: before we can start the matmul, we need the all_gather to complete. +here: before we can start the matmul, we need the `all_gather` to complete. Here's a profile using the same code, but on larger example shapes (`(8192, 8192)` for `lhs` and `(8192, 1024)` for `rhs`): diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index bcaa1f42b9a7..1c1c9729b654 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -8,6 +8,8 @@ "source": [ "# How to Think in JAX\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n", "\n", "JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively." diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index ce7721bf9d50..14089fa36e32 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 name: python3 @@ -15,6 +15,8 @@ kernelspec: # How to Think in JAX + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively. 
diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb index 4ee2e4924d53..96b334296667 100644 --- a/docs/notebooks/vmapped_log_probs.ipynb +++ b/docs/notebooks/vmapped_log_probs.ipynb @@ -8,6 +8,8 @@ "source": [ "# Autobatching for Bayesian Inference\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n", "\n", "This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n", @@ -483,7 +485,7 @@ "\n", "normal_sample = jax.jit(normal_sample, static_argnums=(1,))\n", "\n", - "key = random.PRNGKey(10003)\n", + "key = random.key(10003)\n", "\n", "beta_loc = jnp.zeros(num_features, jnp.float32)\n", "beta_log_scale = jnp.zeros(num_features, jnp.float32)\n", diff --git a/docs/notebooks/vmapped_log_probs.md b/docs/notebooks/vmapped_log_probs.md index 8b6d7ceeb61a..ea8b4fce2f70 100644 --- a/docs/notebooks/vmapped_log_probs.md +++ b/docs/notebooks/vmapped_log_probs.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -16,6 +16,8 @@ kernelspec: # Autobatching for Bayesian Inference + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs. @@ -210,7 +212,7 @@ def normal_sample(key, shape): normal_sample = jax.jit(normal_sample, static_argnums=(1,)) -key = random.PRNGKey(10003) +key = random.key(10003) beta_loc = jnp.zeros(num_features, jnp.float32) beta_log_scale = jnp.zeros(num_features, jnp.float32) diff --git a/docs/notebooks/xmap_tutorial.ipynb b/docs/notebooks/xmap_tutorial.ipynb index bfde9da799f8..a8eb76c353ed 100644 --- a/docs/notebooks/xmap_tutorial.ipynb +++ b/docs/notebooks/xmap_tutorial.ipynb @@ -8,7 +8,9 @@ "source": [ "# Named axes and easy-to-revise parallelism with `xmap`\n", "\n", - "**_UPDATE:_** The recommended ways to do multi-device programming in JAX are using: 1) [`jit` (automatic partitioning of computation and `jax.Array` sharding)](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html); and/or 2) [`shard_map` (manual data sharding)](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Learn more in [Why don’t `pmap` or `xmap` already solve this?](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this) in the [`shard_map` JEP document](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html).\n", + "\n", + "\n", + "**_UPDATE:_** `xmap` is deprecated and will be removed in a future release. 
The recommended ways to do multi-device programming in JAX are using: 1) [`jit` (automatic partitioning of computation and `jax.Array` sharding)](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html); and/or 2) [`shard_map` (manual data sharding)](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Learn more in [Why don’t `pmap` or `xmap` already solve this?](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this) in the [`shard_map` JEP document](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html).\n", "\n", "This tutorial introduces JAX `xmap` (`jax.experimental.maps.xmap`) and the named-axis programming model that comes with it. By reading this, you'll learn how to write error-avoiding, self-documenting functions using named axes, then control how they're executed on hardware at any scale, from your laptop CPU to the largest TPU supercomputer.\n", "\n", diff --git a/docs/notebooks/xmap_tutorial.md b/docs/notebooks/xmap_tutorial.md index 69addc5cc3bd..c4b511dbe711 100644 --- a/docs/notebooks/xmap_tutorial.md +++ b/docs/notebooks/xmap_tutorial.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 name: python3 @@ -15,7 +15,9 @@ kernelspec: # Named axes and easy-to-revise parallelism with `xmap` -**_UPDATE:_** The recommended ways to do multi-device programming in JAX are using: 1) [`jit` (automatic partitioning of computation and `jax.Array` sharding)](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html); and/or 2) [`shard_map` (manual data sharding)](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Learn more in [Why don’t `pmap` or `xmap` already solve this?](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this) in the [`shard_map` JEP document](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). + + +**_UPDATE:_** `xmap` is deprecated and will be removed in a future release. The recommended ways to do multi-device programming in JAX are using: 1) [`jit` (automatic partitioning of computation and `jax.Array` sharding)](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html); and/or 2) [`shard_map` (manual data sharding)](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Learn more in [Why don’t `pmap` or `xmap` already solve this?](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this) in the [`shard_map` JEP document](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). This tutorial introduces JAX `xmap` (`jax.experimental.maps.xmap`) and the named-axis programming model that comes with it. By reading this, you'll learn how to write error-avoiding, self-documenting functions using named axes, then control how they're executed on hardware at any scale, from your laptop CPU to the largest TPU supercomputer. diff --git a/docs/pallas/design.md b/docs/pallas/design.md index b7876af0c788..3ba32f25a376 100644 --- a/docs/pallas/design.md +++ b/docs/pallas/design.md @@ -1,23 +1,96 @@ # Pallas Design -In this document, we explain the initial Pallas design. This is a snapshot of some of the earlier design decisions made and Pallas's specific APIs might have changed since. 
+ -## Introduction - -JAX is being used for a diverse set of workloads, from large scale machine learning to scientific computing. JAX’s success story is as much a success story for XLA, the primary compiler that JAX targets – XLA compiles JAX programs for accelerators and has enabled JAX to scale to the largest ML models. JAX describes logical computations in XLA’s representation, HLO. HLO describes how computations happen logically but not physically. Given a logical HLO computation, XLA decides how that computation is to be executed physically. For a wide variety of ML applications, XLA does a good job of compiling user programs but inevitably some users hit XLA's limitations. In these cases, we need to provide an “escape hatch” to allow experts to write hand-tuned kernels that outperform XLA at that point in time. Furthermore, advances in ML systems research take some time to be incorporated into XLA and users often want to run ahead with them. Over time, the compiler can incorporate the optimizations that were proven out experimentally through hand-tuned kernels. - -XLA does offer the `CustomCall` mechanism as an escape hatch, but it requires users to write C++ and on GPU it requires users to learn the CUDA programming model. The CUDA programming model is arguably too low-level for many machine learning GPU kernels, like matrix multiplication, and even expert users will have trouble using CUDA to implement efficient matrix multiplication or multi-headed attention. Not only this, JAX users are usually familiar with Python and NumPy-style array programming which doesn’t involve writing any C++ or thinking about GPU parallelism. All popular machine learning frameworks share this idea: manipulating (usually) arrays with high level operations like `matmul` or `convolution`. Unfortunately, this means implementing a custom operation via `CustomCall` is a big investment, involving potentially learning C++ and/or GPU programming. - -[Triton](https://triton-lang.org/main/index.html), a GPU compiler built and maintained by OpenAI, has taken the ML compiler world by storm. Triton offers the best of both worlds: an array-based programming model for GPU kernels. Triton is the primary code generation route for `torch.compile` in PyTorch 2.0, via the Torch Inductor library. Triton actively hides some aspects of GPU programming in the name of a more accessible programming model that can be used from Python and to generate optimized code from a higher-level representation. While GPUs are more flexible than what Triton offers, in the ML domain, Triton seems to be expressive enough for many applications. +In this document, we explain the initial Pallas design. +This is a snapshot of some of the earlier design decisions made +and Pallas's specific APIs might have changed since. -In this document, we describe Pallas, an extension to JAX that enables kernel programming for both GPUs and TPUs using a Triton-like model. A JAX-based kernel language offers several advantages: -* Although Triton exposes a TPU-like programming model to users, i.e. writing programs for tiles of arrays in L1-cache, it is specialized enough to GPU that we cannot directly compile Triton for TPU. For example, Triton offers atomic operations specifically meant to handle parallel writes that don’t necessarily make sense on TPU. A higher level front end can abstract away details of the platform while surfacing just that tile-based programming model. The kernels will thus be portable across different hardware platforms. 
-* JAX as a tracing-based frontend for numerical computing is both mature and well-used. By embedding the kernel programming language in JAX itself, we can re-use JAX’s tracing infrastructure and provide a NumPy-like frontend that’s already familiar to users. -* JAX transformations are key to its success, allowing users to express simple programs but transform them to achieve complex functionality. We can leverage the same transformations (vmap, jvp, etc.) to transform user-written kernels. - -The open question is: is JAX a good fit for a kernel language at all? We think so. Triton demonstrates that an array programming language can be practical for writing GPU kernels and JAX is just that. JAX has also proven to be a flexible front-end for compilers and for program transformations. +## Introduction -We describe Pallas as follows: we first describe the ways in which we extend JAX to support writing custom kernels. We then show how we can lower Pallas to both Triton and Mosaic. We conclude by describing existing and potential ways to transform Pallas kernels via JAX transformations. +JAX is being used for a diverse set of workloads, from large scale machine +learning to scientific computing. +JAX’s success story is as much a success story for XLA, +the primary compiler that JAX targets – XLA compiles JAX +programs for accelerators and has enabled JAX to scale to the largest ML +models. +JAX describes logical computations in XLA’s representation, HLO. +HLO describes how computations happen logically but not physically. +Given a logical HLO computation, XLA decides how that computation is to be +executed physically. +For a wide variety of ML applications, XLA does a good +job of compiling user programs but inevitably some users hit XLA's +limitations. +In these cases, we need to provide an “escape hatch” to allow +experts to write hand-tuned kernels that outperform XLA at that +point in time. +Furthermore, advances in ML systems research take some time to be +incorporated into XLA and users often want to run ahead with them. +Over time, the compiler can incorporate the optimizations that were proven +out experimentally through hand-tuned kernels. + +XLA does offer the `CustomCall` mechanism as an escape hatch, but it +requires users to write C++ and on GPU it requires users to learn the +CUDA programming model. +The CUDA programming model is arguably too low-level for many machine +learning GPU kernels, like matrix multiplication, +and even expert users will have trouble using CUDA to implement efficient +matrix multiplication or multi-headed attention. +Not only this, JAX users are usually familiar with Python and NumPy-style +array programming which doesn’t involve writing any C++ or thinking about +GPU parallelism. +All popular machine learning frameworks share this +idea: manipulating (usually) arrays with high level operations +like `matmul` or `convolution`. +Unfortunately, this means implementing a custom operation via `CustomCall` +is a big investment, involving potentially learning C++ and/or GPU +programming. + +[Triton](https://triton-lang.org/main/index.html), a GPU compiler built +and maintained by OpenAI, has taken the ML compiler world by storm. +Triton offers the best of both worlds: an array-based programming model +for GPU kernels. Triton is the primary code generation route +for `torch.compile` in PyTorch 2.0, via the Torch Inductor library. 
+Triton actively hides some aspects of GPU programming in the name of a +more accessible programming model that can be used from Python and to +generate optimized code from a higher-level representation. +While GPUs are more flexible than what Triton offers, in the ML domain, +Triton seems to be expressive enough for many applications. + +In this document, we describe Pallas, an extension to JAX that enables +kernel programming for both GPUs and TPUs using a Triton-like model. +A JAX-based kernel language offers several advantages: +* Although Triton exposes a TPU-like programming model to users, + i.e. writing programs for tiles of arrays in L1-cache, it is specialized + enough to GPU that we cannot directly compile Triton for TPU. + For example, Triton offers atomic operations specifically meant to + handle parallel writes that don’t necessarily make sense on TPU. + A higher level front end can abstract away details of the platform + while surfacing just that tile-based programming model. + The kernels will thus be portable across different hardware platforms. +* JAX as a tracing-based frontend for numerical computing is both + mature and well-used. + By embedding the kernel programming language in JAX itself, + we can re-use JAX’s tracing infrastructure and provide a + NumPy-like frontend that’s already familiar to users. +* JAX transformations are key to its success, allowing users to + express simple programs but transform them to achieve complex + functionality. + We can leverage the same transformations (vmap, jvp, etc.) to + transform user-written kernels. + +The open question is: is JAX a good fit for a kernel language at all? +We think so. +Triton demonstrates that an array programming language can be +practical for writing GPU kernels and JAX is just that. +JAX has also proven to be a flexible front-end for compilers and +for program transformations. + +We describe Pallas as follows: we first describe the ways in which +we extend JAX to support writing custom kernels. +We then show how we can lower Pallas to both Triton and Mosaic. +We conclude by describing existing and potential ways to transform +Pallas kernels via JAX transformations.
@@ -28,10 +101,17 @@ Visualization of Pallas lowering paths ## Pallas: Extending JAX for kernels -The key point we’d like to make is that Pallas is just JAX, with some extensions: -1. Users now use reference types called `Ref`s in their JAX code. This gives users more precise control over memory access and layout in JAX will more closely resemble physical layout. -2. Users write their JAX programs using a subset of JAX primitives, along with a set of Pallas-specific primitives. -3. Users embed their Pallas kernels in an outer JAX program via a special `pallas_call` higher-order function, that executes the kernel in a map. It is analogous to `pmap` or `shard_map`, except with references to shared memory. +The key point we’d like to make is that Pallas is just JAX, with some +extensions: +1. Users now use reference types called `Ref`s in their JAX code. + This gives users more precise control over memory access and + layout in JAX will more closely resemble physical layout. +2. Users write their JAX programs using a subset of JAX primitives, + along with a set of Pallas-specific primitives. +3. Users embed their Pallas kernels in an outer JAX program via a + special `pallas_call` higher-order function, that executes the + kernel in a map. It is analogous to `pmap` or `shard_map`, + except with references to shared memory. We’ll go over these three extensions one at a time, by example. @@ -56,13 +136,28 @@ add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32) add(x, y) ``` -Unlike a regular JAX program, `add_kernel` does not receive immutable array arguments. Instead, it’s provided with references that can be read from and updated in-place using NumPy-like syntax. `Ref`s are not a Pallas-specific concept – they were introduced to JAX to represent stateful computations. However, we can leverage them when writing kernels that operate on mutable memory too. - -Pallas kernels not only receive `Ref`s corresponding to the inputs to the kernel, but also receive `Ref`s for the outputs as well (specified in `pallas_call` via `out_shape`). `Ref`s are special types that cannot be passed into the usual set of JAX primitives without being read from first. When you read from a `Ref` you get a JAX `Array` type out, and you must write an `Array` into a `Ref`. +Unlike a regular JAX program, `add_kernel` does not receive immutable +array arguments. +Instead, it’s provided with references that can be read from and +updated in-place using NumPy-like syntax. +`Ref`s are not a Pallas-specific concept – they were introduced to +JAX to represent stateful computations. +However, we can leverage them when writing kernels that operate on +mutable memory too. + +Pallas kernels not only receive `Ref`s corresponding to the inputs +to the kernel, but also receive `Ref`s for the outputs as well +(specified in `pallas_call` via `out_shape`). +`Ref`s are special types that cannot be passed into the usual set of +JAX primitives without being read from first. +When you read from a `Ref` you get a JAX `Array` type out, and you +must write an `Array` into a `Ref`. #### Reading from/writing into Refs -Reading from a `Ref` corresponds to loading an array into the lowest level of the memory hierarchy (L1-cache on GPU and vector registers on TPU). Writing into a `Ref` is analogous. +Reading from a `Ref` corresponds to loading an array into the +lowest level of the memory hierarchy (L1-cache on GPU and vector +registers on TPU). Writing into a `Ref` is analogous. 
```python def f(x_ref, o_ref): @@ -77,18 +172,37 @@ def f(x_ref): x = x_ref[jnp.arange(2)[..., None], jnp.arange(3)[None, ...]] ``` -Writing to `Ref`s can be done via analogous `__setitem__` style indexing. +Writing to `Ref`s can be done via analogous `__setitem__` style +indexing. -Other forms of indexing (for example, dynamic slicing) can be done via `pallas.load` and `pallas.store`, new JAX primitives designed to make loading from/storing into memory easier. We’ll discuss these new primitives later. +Other forms of indexing (for example, dynamic slicing) can be done +via `pallas.load` and `pallas.store`, new JAX primitives designed to +make loading from/storing into memory easier. +We’ll discuss these new primitives later. ### Extending JAX with new Pallas primitives -Because JAX was designed with HLO in mind, the set of JAX primitives closely mirrors the set of HLO operations. Targeting a new compiler (e.g. Triton or Mosaic) means we might need to supplement JAX’s primitives with new ones specific to the new compiler. At the same time, we may not be able to lower all JAX primitives, so we need to restrict it to a subset. +Because JAX was designed with HLO in mind, the set of JAX primitives +closely mirrors the set of HLO operations. +Targeting a new compiler (e.g. Triton or Mosaic) means we might need +to supplement JAX’s primitives with new ones specific to the new +compiler. +At the same time, we may not be able to lower all JAX primitives, +so we need to restrict it to a subset. -Because Pallas was initially designed with Triton in mind, we offer a set of new primitives targeting the Triton programming model. As we’ll show later, we can lower these primitives to Mosaic as well. +Because Pallas was initially designed with Triton in mind, +we offer a set of new primitives targeting the Triton programming model. +As we’ll show later, we can lower these primitives to Mosaic as well. #### `pallas.load` and `pallas.store` -`pallas.load` and `pallas.store` are primitives that allow loading from memory and storing into memory. Unlike `__getitem__` and `__setitem__` they are more flexible at the cost of being more verbose. Specifically, you can use the `pallas.dynamic_slice` (`pallas.ds` for short) construct (which should maybe be upstreamed into JAX to be used with Ref `__getitem__` and `__setitem__`). + +`pallas.load` and `pallas.store` are primitives that allow loading +from memory and storing into memory. +Unlike `__getitem__` and `__setitem__` they are more flexible at the +cost of being more verbose. +Specifically, you can use the `pallas.dynamic_slice` (`pallas.ds` for +short) construct (which should maybe be upstreamed into JAX to be +used with Ref `__getitem__` and `__setitem__`). ```python def f(x_ref, o_ref): @@ -101,7 +215,8 @@ def f(x_ref, o_ref): ``` -`pallas.load` and `pallas.store` also support masking via the mask argument. +`pallas.load` and `pallas.store` also support masking via the mask +argument. ```python def f(x_ref, o_ref): @@ -111,12 +226,25 @@ def f(x_ref, o_ref): x = pl.load(x_ref, (idx,), mask=mask, other=float('-inf')) ``` -Masking is important when doing out-of-bounds loads/stores. The operational semantics of masking can be compiler-determined (if we understand the documentation properly, Triton avoids the read from/write to memory if it’s masked). +Masking is important when doing out-of-bounds loads/stores. 
+The operational semantics of masking can be compiler-determined +(if we understand the documentation properly, Triton avoids the read +from/write to memory if it’s masked). #### `pallas.program_id` and `pallas.num_programs` -As we’ll soon see, we’ll be executing the same Pallas kernels many times (either in parallel or in a pipeline depending on the backend). These new primitives tell us “where” we are in the execution of the kernel. -`pallas.program_id` takes in an axis argument, which tells us which index in an axis of a multidimensional grid this kernel is currently executing in (analogous to `threadId` from CUDA programming or `lax.axis_index` in `jax.pmap`). Note that we are currently borrowing the “program” terminology from Triton and in the future we might want to change it to something more familiar to JAX users. +As we’ll soon see, we’ll be executing the same Pallas kernels many +times (either in parallel or in a pipeline depending on the backend). +These new primitives tell us “where” we are in the execution of the +kernel. + +`pallas.program_id` takes in an axis argument, which tells us which +index in an axis of a multidimensional grid this kernel is currently +executing in (analogous to `threadId` from CUDA programming or +`lax.axis_index` in `jax.pmap`). +Note that we are currently borrowing the “program” terminology from +Triton and in the future we might want to change it to something more +familiar to JAX users. ```python def f(x_ref, o_ref): @@ -124,42 +252,61 @@ def f(x_ref, o_ref): o_ref[i] = jnp.exp(x_ref[i]) ``` -`pallas.num_programs` also takes in an axis and returns the grid size for that axis. +`pallas.num_programs` also takes in an axis and returns the grid size +for that axis. -Note that while `program_id` and `num_programs` are Triton-specific terminology they are easily generalized to make sense on TPU as well. +Note that while `program_id` and `num_programs` are Triton-specific +terminology they are easily generalized to make sense on TPU as well. #### Using a subset of JAX primitives in Pallas -Because we’re writing kernels, not high-level HLO programs, some JAX primitives may not be able to be represented in our underlying substrate efficiently. However, we know we can support most elementwise operations, simple dot products, and JAX control flow. +Because we’re writing kernels, not high-level HLO programs, some JAX +primitives may not be able to be represented in our underlying +substrate efficiently. +However, we know we can support most elementwise operations, +simple dot products, and JAX control flow. -While we haven’t yet mapped out exactly all the JAX primitives that we can support in Pallas kernels, we can certainly identify some that are not easy to lower or are unlikely to be useful: -* `conv_general` - convolution usually isn’t offered as primitive in the underlying hardware. -* `gather/scatter` - the underlying compiler may not support noncontiguous memory reads and writes +While we haven’t yet mapped out exactly all the JAX primitives that +we can support in Pallas kernels, we can certainly identify some that +are not easy to lower or are unlikely to be useful: +* `conv_general` - convolution usually isn’t offered as primitive in + the underlying hardware. +* `gather/scatter` - the underlying compiler may not support + noncontiguous memory reads and writes ### Executing Pallas kernels with `pallas_call` -Now that we’ve written our Pallas kernels (a.k.a. JAX with `Ref`s and the extra Pallas primitives), how do we execute them on a GPU or TPU? 
We use `pallas_call`, a higher order function (akin to `jax.jit` and `jax.pmap`) that executes the kernel. +Now that we’ve written our Pallas kernels (a.k.a. JAX with `Ref`s and +the extra Pallas primitives), how do we execute them on a GPU or TPU? +We use `pallas_call`, a higher order function (akin to `jax.jit` and +`jax.pmap`) that executes the kernel. The signature of `pallas_call` is as follows: ```python def pallas_call( kernel: Callable, + out_shape: Sequence[jax.ShapeDtypeStruct], + *, in_specs: Sequence[Spec], out_specs: Sequence[Spec], - out_shapes: Sequence[jax.ShapeDtypeStruct], grid: Optional[Tuple[int, ...]] = None) -> Callable: ... ``` -When we provide a kernel to `pallas_call` we provide additional information. The first is `out_shape` which tells the kernel what the outputs look like (`pallas_call` will pass a `Ref` corresponding to these into the kernel to be written to). The rest of the information (`in_specs`, `out_specs`, and `grid`) are information about how the kernel will be scheduled on the accelerator. +When we provide a kernel to `pallas_call` we provide additional +information. The first is `out_shape` which tells the kernel what the +outputs look like (`pallas_call` will pass a `Ref` corresponding to +these into the kernel to be written to). +The rest of the information (`in_specs`, `out_specs`, and `grid`) are +information about how the kernel will be scheduled on the accelerator. The (rough) semantics for `pallas_call` are as follows: ```python -def pallas_call(kernel, in_specs, out_specs, out_shapes, grid): +def pallas_call(kernel, out_shape, *, in_specs, out_specs, grid): def execute(*args): - outputs = map(empty_ref, out_shapes) + outputs = map(empty_ref, out_shape) grid_indices = map(range, grid) for indices in itertools.product(*grid_indices): # Could run in parallel! local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in @@ -170,13 +317,37 @@ def pallas_call(kernel, in_specs, out_specs, out_shapes, grid): return execute ``` -Specifically, `pallas_call` will “loop” over grid iteration space, applying a transformation to the inputs and outputs specified via the `in_specs` and `out_specs`. In each iteration, the kernel will be called on the transformed inputs and outputs. Note that the “loop” over the iteration space could be executed in parallel (e.g. on GPU). `pallas_call` also provides no guarantees on the order of loop iterations over the iteration space, just that every member of the iteration space will be looped over. Compilers like Triton and Mosaic will have more specific operational semantics associated with the grid. +Specifically, `pallas_call` will “loop” over grid iteration space, +applying a transformation to the inputs and outputs specified via +the `in_specs` and `out_specs`. +In each iteration, the kernel will be called on the transformed +inputs and outputs. Note that the “loop” over the iteration space +could be executed in parallel (e.g. on GPU). +`pallas_call` also provides no guarantees on the order of loop +iterations over the iteration space, just that every member of the +iteration space will be looped over. +Compilers like Triton and Mosaic will have more specific operational +semantics associated with the grid. #### Transformation functions -The `in_specs` and `out_specs` arguments to `pallas_call` allow inputs and outputs to be transformed in some way. 
The two options that Pallas offers right now are an identity transformation (where inputs and outputs are left unchanged), and `BlockSpec`s, take fixed-size slices of `Ref`s determined by the loop index. - -A `BlockSpec` takes an `index_map` function and a `block_shape`. Logically, it takes an array and slices it along each axis into `block_shape` sizes blocks. The `index_map` function takes loop indices (from the grid index set) and maps them to block indices. The transform function converts `Ref`s into logical views of the `Ref` at the corresponding block. When we specify `None` in an entry in block_shape, that corresponds to “mapping” over that dimension, removing it from the block within the kernel. +The `in_specs` and `out_specs` arguments to `pallas_call` allow +inputs and outputs to be transformed in some way. +The two options that Pallas offers right now are an identity +transformation (where inputs and outputs are left unchanged), +and `BlockSpec`s, take fixed-size slices of `Ref`s determined by the +loop index. + +A `BlockSpec` takes an `index_map` function and a `block_shape`. +Logically, it takes an array and slices it along each axis into +`block_shape` sizes blocks. +The `index_map` function takes loop indices (from the grid index set) +and maps them to block indices. +The transform function converts `Ref`s into logical views of the +`Ref` at the corresponding block. +When we specify `None` in an entry in block_shape, +that corresponds to “mapping” over that dimension, +removing it from the block within the kernel. ```python class BlockSpec: @@ -189,16 +360,28 @@ class BlockSpec: ... ``` -We could also imagine other `Spec`s that are used with `pallas_call`, for example a `Spec` that corresponds to overlapping windows to, say, implement convolutions. +We could also imagine other `Spec`s that are used with `pallas_call`, +for example a `Spec` that corresponds to overlapping windows to, say, +implement convolutions. ### Immediate benefits of Pallas as a front-end -By offering a JAX front-end for kernel writing, we can immediately reap some benefits. +By offering a JAX front-end for kernel writing, we can immediately +reap some benefits. #### More flexible front end -The first is that JAX users are already accustomed to the benefits (and limitations) of programming with JAX and its tracing-based transformations. This means users can use closures and other familiar Python constructs when writing Pallas kernels. This is unlike the existing AST-parsing-based Triton front end or the MLIR builders for Mosaic. For example, this makes Pallas far more amenable to templating than Triton. +The first is that JAX users are already accustomed to the benefits +(and limitations) of programming with JAX and its tracing-based +transformations. +This means users can use closures and other familiar Python constructs +when writing Pallas kernels. +This is unlike the existing AST-parsing-based Triton front end or the +MLIR builders for Mosaic. +For example, this makes Pallas far more amenable to templating than +Triton. -See this example of how we can use higher-order functions in Python to template a kernel. +See this example of how we can use higher-order functions in Python +to template a kernel. ```python def make_kernel(eltwise_kernel): @@ -217,13 +400,25 @@ pl.pallas_call(kernel2, out_shape=x, grid=1)(1., 1.) 
#### Emulation mode -By representing kernels as programs with JAX primitives and some new Pallas primitives, we can also lower Pallas programs to StableHLO directly and compile/execute them with XLA. Specifically, a `pallas_call` can be implemented as a `lax.scan` over the grid. This enables us to develop GPU or TPU kernels on any XLA-supported platform (even CPU!) and debug them using JAX/XLA debugging tools (like `jax.debug.print`). We can also use the more reliable and better tested XLA numerics to verify the correctness of the Triton and Mosaic compilers. One could also imagine perturbing the `scan` ordering to simulate the parallel reads and writes that happen on GPU. +By representing kernels as programs with JAX primitives and some new +Pallas primitives, we can also lower Pallas programs to StableHLO +directly and compile/execute them with XLA. +Specifically, a `pallas_call` can be implemented as a `lax.scan` over +the grid. +This enables us to develop GPU or TPU kernels on any XLA-supported +platform (even CPU!) and debug them using JAX/XLA debugging tools +(like `jax.debug.print`). +We can also use the more reliable and better tested XLA numerics to +verify the correctness of the Triton and Mosaic compilers. +One could also imagine perturbing the `scan` ordering to simulate the +parallel reads and writes that happen on GPU. ### Examples #### `add` -We modify our `add_kernel` example to operate over (2,)-sized blocks using `BlockSpec`s. +We modify our `add_kernel` example to operate over (2,)-sized blocks +using `BlockSpec`s. ```python def add_kernel(x_ref, y_ref, o_ref): @@ -236,28 +431,32 @@ add = pl.pallas_call( add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32), in_specs=[ - pl.BlockSpec(lambda i:, i, (2,)), - pl.BlockSpec(lambda i:, i, (2,)) + pl.BlockSpec((2,), lambda i: i), + pl.BlockSpec((2,), lambda i: i) ], - out_specs=pl.BlockSpec(lambda i: i, (2,)) + out_specs=pl.BlockSpec((2,), lambda i: i), + grid=(4,)) add(x, y) ``` #### Templated matmul -In this example, we compute tiles of the output by doing an unrolled accumulation over blocks of rows and columns from our input arrays. We inline an activation function into the body of the kernel using a higher order function so we can emit a fused kernel. +In this example, we compute tiles of the output by doing an unrolled +accumulation over blocks of rows and columns from our input arrays. +We inline an activation function into the body of the kernel using a +higher order function so we can emit a fused kernel. 
```python def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k): - acc = jnp.zeros((x_ref.shape[0], x_ref.shape[1]), jnp.float32) - for k in range(x_ref.shape[1] // block_k) + acc = jnp.zeros((x_ref.shape[0], y_ref.shape[1]), jnp.float32) + for k in range(x_ref.shape[1] // block_k): x = x_ref[:, k*block_k:(k+1)*block_k] y = y_ref[k*block_k:(k+1)*block_k, :] acc += x @ y o_ref[:, :] = activation(acc).astype(o_ref.dtype) x, y = jnp.ones((512, 256)), jnp.ones((256, 1024)) -block_shape = 256, 256, 128 +block_shape = 128, 256, 128 @partial(jax.jit, static_argnames=["block_shape", "activation"]) def matmul(x, y, *, block_shape, activation): @@ -266,52 +465,124 @@ def matmul(x, y, *, block_shape, activation): partial(matmul_kernel, block_k=block_k, activation=activation), out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32), in_specs=[ - pl.BlockSpec(lambda i, j:, (i, 0), (block_m, x.shape[1])), - pl.BlockSpec(lambda i, j:, (0, j), (y.shape[0], block_n)) + pl.BlockSpec((block_m, x.shape[1]), lambda i, j: (i, 0)), + pl.BlockSpec((y.shape[0], block_n), lambda i, j: (0, j)) ], - out_specs=pl.BlockSpec(lambda i, j: (i, j), (block_m, block_n)) + out_specs=pl.BlockSpec((block_m, block_n), lambda i, j: (i, j)), + grid=(4, 4), + ) return fused_matmul(x, y) z = matmul(x, y, block_shape=block_shape, activation=jax.nn.gelu) ``` ### Lowering Pallas -After users express their Pallas kernels, we lower them to different representations depending on the target backend. On GPUs, we lower Pallas to Triton IR, and on TPU we lower Pallas to Mosaic. +After users express their Pallas kernels, we lower them to different +representations depending on the target backend. +On GPUs, we lower Pallas to Triton IR, and on TPU we lower Pallas to +Mosaic. #### Lowering Pallas to Triton for GPU -Lowering Pallas to Triton is easy because Pallas was designed with Triton as a target language in mind. The main differences between Pallas and Triton is that Triton doesn’t have a notion of `BlockSpec`s and also uses pointers when doing memory loads and stores as opposed to indices. - -Triton supports pointers as an array element type in its language and in Triton you can load from and store to arrays of pointers. In Pallas, when given a `(4, 5)`-shaped `Ref`, `x_ref`, and then do like `x_ref[3, 2]`, we need to lower this to computing a Triton pointer to the appropriate row-major position in `x_ref` (that is, doing 5 * 3 + 2 * 1). Similarly, when we lower slices to Triton, e.g. `x_ref[4, :]` we need to produce an array of pointers `5 * 4 + jnp.arange(3)`. - -Other than that, lowering to Triton is fairly straightforward. JAX dot products can be lowered to Triton dot products and JAX unary primitives are lowered to their Triton equivalents. Triton’s atomic operations are lowered via new Pallas atomic primitives. +Lowering Pallas to Triton is easy because Pallas was designed with +Triton as a target language in mind. +The main differences between Pallas and Triton is that Triton doesn’t +have a notion of `BlockSpec`s and also uses pointers when doing +memory loads and stores as opposed to indices. + +Triton supports pointers as an array element type in its language +and in Triton you can load from and store to arrays of pointers. +In Pallas, when given a `(4, 5)`-shaped `Ref`, `x_ref`, and then do +like `x_ref[3, 2]`, we need to lower this to computing a Triton +pointer to the appropriate row-major position in `x_ref` (that is, +doing 5 * 3 + 2 * 1). +Similarly, when we lower slices to Triton, e.g. 
`x_ref[4, :]` we need +to produce an array of pointers `5 * 4 + jnp.arange(3)`. + +Other than that, lowering to Triton is fairly straightforward. +JAX dot products can be lowered to Triton dot products and JAX unary +primitives are lowered to their Triton equivalents. +Triton’s atomic operations are lowered via new Pallas atomic +primitives. #### Lowering Pallas to Mosaic for TPU -Mosaic consumes (mostly) standard dialect MLIR and emits LLO to be compiled for TPU. Pallas can be lowered to Mosaic via translating JAX primitives to MLIR (mostly the `vector` and `arith` dialects). The `BlockSpec`s can be converted into pipeline schedules (i.e. the `transform_func`s in Mosaic). +Mosaic consumes (mostly) standard dialect MLIR and emits LLO to be +compiled for TPU. +Pallas can be lowered to Mosaic via translating JAX primitives to +MLIR (mostly the `vector` and `arith` dialects). +The `BlockSpec`s can be converted into pipeline schedules +(i.e. the `transform_func`s in Mosaic). ### Transforming Pallas -A natural question is how do JAX transformations interact with Pallas kernels? There are two main ways: transformations inside Pallas kernels and transformations outside Pallas kernels. +A natural question is how do JAX transformations interact with Pallas +kernels? +There are two main ways: transformations inside Pallas kernels and +transformations outside Pallas kernels. -Transformation inside Pallas kernels should actually “just work”, so long as we are able to lower the transformed code. For example, we could use `jax.grad(jnp.sin)(...)` inside of a JAX kernel because we can lower a `cos` to both Triton and Mosaic. However, we might not be able to lower a `jax.vmap(lax.dynamic_slice)` because it could turn into a gather that we cannot lower. +Transformation inside Pallas kernels should actually “just work”, +so long as we are able to lower the transformed code. +For example, we could use `jax.grad(jnp.sin)(...)` inside of a JAX +kernel because we can lower a `cos` to both Triton and Mosaic. +However, we might not be able to lower a `jax.vmap(lax.dynamic_slice)` +because it could turn into a gather that we cannot lower. -Transformations of Pallas kernels from the outer JAX programs is perhaps the more interesting case. How do we handle things like `vmap(pallas_call)` and `grad(pallas_call)`? +Transformations of Pallas kernels from the outer JAX programs is +perhaps the more interesting case. How do we handle things like +`vmap(pallas_call)` and `grad(pallas_call)`? #### `vmap-of-pallas_call` -vmap automatically vectorizes JAX programs. While kernel writers might want precise control over how a batched kernel will behave differently from its unbatched variant, we can offer a reasonable default `vmap` rule for `pallas_call` while offering the `jax.custom_vmap` customization mechanism. When `pallas_call` is `vmap`-ed, we augment the `pallas_call` to have an extra grid dimension corresponding to the new batch dimension and transform the `BlockSpec`s to handle indexing along that dimension. +vmap automatically vectorizes JAX programs. While kernel writers might +want precise control over how a batched kernel will behave differently +from its unbatched variant, we can offer a reasonable default `vmap` +rule for `pallas_call` while offering the `jax.custom_vmap` +customization mechanism. When `pallas_call` is `vmap`-ed, we augment +the `pallas_call` to have an extra grid dimension corresponding to the +new batch dimension and transform the `BlockSpec`s to handle indexing +along that dimension. 
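As a concrete illustration of this default rule, here is a small sketch (not one of this document's examples) that simply applies `jax.vmap` to the earlier `add` kernel; `interpret=True` is used only so the sketch also runs on CPU via the emulation path described above.

```python
# Sketch: batching an unbatched Pallas kernel with vmap (illustrative only).
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

# Unbatched kernel over length-8 vectors; interpret=True runs it via XLA.
add = pl.pallas_call(add_kernel,
                     out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
                     interpret=True)

xs = jnp.arange(32, dtype=jnp.int32).reshape(4, 8)
ys = jnp.ones((4, 8), jnp.int32)

# vmap maps `add` over the new leading batch dimension; conceptually the
# pallas_call gains one extra grid axis of size 4.
print(jax.vmap(add)(xs, ys).shape)  # (4, 8)
```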
#### `grad-of-pallas_call` -`grad` of `pallas_call` enables automatic differentiation of kernels. `jax.grad` breaks down into applications of three distinct transforms: `jvp`, `partial_eval` and `transpose`. In principle, we can re-use most of JAX’s infrastructure when implementing these rules for `pallas_call` (since it behaves much like existing JAX higher order primitives). - -However, automatic differentiation of kernels can result in a performance hit due to how memory access is transposed. If we write a GPU kernel with overlapping-and-parallel reads and disjoint-but-parallel writes, we automatically transpose it into a kernel that has overlapping-but-parallel writes (which are slow when done atomically) and disjoint-and-parallel reads. To emit a kernel that better uses parallelism with shared memory, we would need to reorder loops and change how the kernel is vectorized. Unfortunately, we do not have a program representation amenable to that in Pallas. A potential direction to automatically differentiating kernels efficiently is to explore a different representation, perhaps one like that in Dex. We could also look at how Enzyme approaches this problem. However, AD of Pallas kernels may still be useful for a class of kernels that does transpose efficiently (for example elementwise kernels). - -In general, though, `jax.custom_vjp` is a viable escape hatch to express Pallas kernels that work with `jax.grad`. +`grad` of `pallas_call` enables automatic differentiation of kernels. +`jax.grad` breaks down into applications of three distinct transforms: +`jvp`, `partial_eval` and `transpose`. +In principle, we can re-use most of JAX’s infrastructure when +implementing these rules for `pallas_call` (since it behaves much like +existing JAX higher order primitives). + +However, automatic differentiation of kernels can result in a +performance hit due to how memory access is transposed. +If we write a GPU kernel with overlapping-and-parallel reads and +disjoint-but-parallel writes, we automatically transpose it into a +kernel that has overlapping-but-parallel writes (which are slow when +done atomically) and disjoint-and-parallel reads. +To emit a kernel that better uses parallelism with shared memory, +we would need to reorder loops and change how the kernel is vectorized. +Unfortunately, we do not have a program representation amenable to +that in Pallas. +A potential direction to automatically differentiating kernels +efficiently is to explore a different representation, perhaps one +like that in Dex. +We could also look at how Enzyme approaches this problem. +However, AD of Pallas kernels may still be useful for a class of +kernels that does transpose efficiently (for example elementwise +kernels). + +In general, though, `jax.custom_vjp` is a viable escape hatch to +express Pallas kernels that work with `jax.grad`. #### Other transformations -We could imagine other JAX transformations applying to Pallas kernels that we haven’t explicitly explored yet. For example, `checkify` is a JAX transformation that does functional error handling. We could imagine using `checkify` with pallas_call to allow plumbing out error codes from GPU kernels that indicate if OOB access or NaNs were produced. - -Another potential transformation to integrate with is custom_partitioning to enable automatically partitionable kernels to be used with pjit. +We could imagine other JAX transformations applying to Pallas kernels +that we haven’t explicitly explored yet. 
+For example, `checkify` is a JAX transformation that does functional +error handling. +We could imagine using `checkify` with pallas_call to allow plumbing +out error codes from GPU kernels that indicate if OOB access or NaNs +were produced. + +Another potential transformation to integrate with is +custom_partitioning to enable automatically partitionable kernels to +be used with pjit. diff --git a/docs/pallas/grid_blockspec.md b/docs/pallas/grid_blockspec.md new file mode 100644 index 000000000000..c89c536a70f9 --- /dev/null +++ b/docs/pallas/grid_blockspec.md @@ -0,0 +1,215 @@ +(pallas_grids_and_blockspecs)= + +# Grids and BlockSpecs + + + +(pallas_grid)= +### `grid`, a.k.a. kernels in a loop + +When using {func}`jax.experimental.pallas.pallas_call` the kernel function +is executed multiple times on different inputs, as specified via the `grid` argument +to `pallas_call`. Conceptually: +```python +pl.pallas_call(some_kernel, grid=(n,))(...) +``` +maps to +```python +for i in range(n): + some_kernel(...) +``` +Grids can be generalized to be multi-dimensional, corresponding to nested +loops. For example, + +```python +pl.pallas_call(some_kernel, grid=(n, m))(...) +``` +is equivalent to +```python +for i in range(n): + for j in range(m): + some_kernel(...) +``` +This generalizes to any tuple of integers (a length `d` grid will correspond +to `d` nested loops). +The kernel is executed as many times +as `prod(grid)`. Each of these invocations is referred to as a "program". +To access which program (i.e. which element of the grid) the kernel is currently +executing, we use {func}`jax.experimental.pallas.program_id`. +For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and +`program_id(axis=1)` returns `2`. +You can also use {func}`jax.experimental.pallas.num_programs` to get the +grid size for a given axis. + +Here's an example kernel that uses a `grid` and `program_id`. + +```python +>>> import jax +>>> from jax.experimental import pallas as pl + +>>> def iota_kernel(o_ref): +... i = pl.program_id(0) +... o_ref[i] = i + +``` + +We now execute it using `pallas_call` with an additional `grid` argument. + +```python +>>> def iota(size: int): +... return pl.pallas_call(iota_kernel, +... out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), +... grid=(size,), interpret=True)() +>>> iota(8) +Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32) + +``` + +On GPUs, each program is executed in parallel on separate thread blocks. +Thus, we need to think about race conditions on writes to HBM. +A reasonable approach is to write our kernels in such a way that different +programs write to disjoint places in HBM to avoid these parallel writes. + +On TPUs, programs are executed in a combination of parallel and sequential +(depending on the architecture) so there are slightly different considerations. +See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions). + +(pallas_blockspec)= + +### `BlockSpec`, a.k.a. how to chunk up inputs + +```{note} +The documentation here applies to the ``indexing_mode == Blocked``, which +is the default. +The documentation for the ``indexing_mode == Unblocked`` is coming. +``` + +In conjunction with the `grid` argument, we need to provide Pallas +the information on how to slice up the input for each invocation. +Specifically, we need to provide a mapping between *the iteration of the loop* +to *which block of our inputs and outputs to be operated on*. 
+This is provided via {class}`jax.experimental.pallas.BlockSpec` objects. + +Before we get into the details of `BlockSpec`s, you may want +to revisit the +[Pallas Quickstart BlockSpecs example](https://jax.readthedocs.io/en/latest/pallas/quickstart.html#block-specs-by-example). + +`BlockSpec`s are provided to `pallas_call` via the +`in_specs` and `out_specs`, one for each input and output respectively. + +Informally, the `index_map` of the `BlockSpec` takes as arguments +the invocation indices (as many as the length of the `grid` tuple), +and returns **block indices** (one block index for each axis of +the overall array). Each block index is then multiplied by the +corresponding axis size from `block_shape` +to get the actual element index on the corresponding array axis. + +```{note} +This documentation applies to the case when the block shape divides +the array shape. +The documentation for the other cases is pending. +``` + +More precisely, the slices for each axis of the input `x` of +shape `x_shape` are computed as in the function `slice_for_invocation` +below: + +```python +>>> def slices_for_invocation(x_shape: tuple[int, ...], +... x_spec: pl.BlockSpec, +... grid: tuple[int, ...], +... invocation_indices: tuple[int, ...]) -> tuple[slice, ...]: +... assert len(invocation_indices) == len(grid) +... assert all(0 <= i < grid_size for i, grid_size in zip(invocation_indices, grid)) +... block_indices = x_spec.index_map(*invocation_indices) +... assert len(x_shape) == len(x_spec.block_shape) == len(block_indices) +... elem_indices = [] +... for x_size, block_size, block_idx in zip(x_shape, x_spec.block_shape, block_indices): +... assert block_size <= x_size # Blocks must be smaller than the array +... start_idx = block_idx * block_size +... # For now, we document only the case when the entire iteration is in bounds +... assert start_idx + block_size <= x_size +... elem_indices.append(slice(start_idx, start_idx + block_size)) +... return elem_indices + +``` + +For example: +```python +>>> slices_for_invocation(x_shape=(100, 100), +... x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)), +... grid = (10, 5), +... invocation_indices = (2, 3)) +[slice(20, 30, None), slice(60, 80, None)] + +>>> # Same shape of the array and blocks, but we iterate over each block 4 times +>>> slices_for_invocation(x_shape=(100, 100), +... x_spec = pl.BlockSpec((10, 20), lambda i, j, k: (i, j)), +... grid = (10, 5, 4), +... invocation_indices = (2, 3, 0)) +[slice(20, 30, None), slice(60, 80, None)] + +``` + +The function `show_invocations` defined below uses Pallas to show the +invocation indices. The `iota_2D_kernel` will fill each output block +with a decimal number where the first digit represents the invocation +index over the first axis, and the second the invocation index +over the second axis: + +```python +>>> def show_invocations(x_shape, block_shape, grid, out_index_map=lambda i, j: (i, j)): +... def iota_2D_kernel(o_ref): +... axes = 0 +... for axis in range(len(grid)): +... axes += pl.program_id(axis) * 10**(len(grid) - 1 - axis) +... o_ref[...] = jnp.full(o_ref.shape, axes) +... res = pl.pallas_call(iota_2D_kernel, +... out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32), +... grid=grid, +... in_specs=[], +... out_specs=pl.BlockSpec(block_shape, out_index_map), +... interpret=True)() +... 
print(res) + +``` + +For example: +```python +>>> show_invocations(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2)) +[[ 0 0 0 1 1 1] + [ 0 0 0 1 1 1] + [10 10 10 11 11 11] + [10 10 10 11 11 11] + [20 20 20 21 21 21] + [20 20 20 21 21 21] + [30 30 30 31 31 31] + [30 30 30 31 31 31]] + +``` + +When multiple invocations write to the same elements of the output +array the result is platform dependent. + +In the example below, we have a 3D grid with the last grid dimension +not used in the block selection (`out_index_map=lambda i, j, k: (i, j)`). +Hence, we iterate over the same output block 10 times. +The output shown below was generated on CPU using `interpret=True` +mode, which at the moment executes the invocation sequentially. +On TPUs, programs are executed in a combination of parallel and sequential, +and this function generates the output shown. +See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions). + +```python +>>> show_invocations(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2, 10), +... out_index_map=lambda i, j, k: (i, j)) +[[ 9 9 9 19 19 19] + [ 9 9 9 19 19 19] + [109 109 109 119 119 119] + [109 109 109 119 119 119] + [209 209 209 219 219 219] + [209 209 209 219 219 219] + [309 309 309 319 319 319] + [309 309 309 319 319 319]] + +``` diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst index bd086bd47303..9fbb560d1b8b 100644 --- a/docs/pallas/index.rst +++ b/docs/pallas/index.rst @@ -4,13 +4,15 @@ Pallas: a JAX kernel language ============================= Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. This section contains tutorials, guides and examples for using Pallas. +See also the :class:`jax.experimental.pallas` module API documentation. .. toctree:: :caption: Guides :maxdepth: 2 - design quickstart + design + grid_blockspec .. toctree:: :caption: Platform Features diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index ee5fd44ed6f6..5a8608f494c3 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -7,9 +7,15 @@ "source": [ "# Pallas Quickstart\n", "\n", - "Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. Pallas allows you to use the same JAX functions and APIs but operates at a *lower* level of abstraction.\n", + "\n", "\n", - "Specifically, Pallas requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator. On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic.\n", + "Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU.\n", + "Pallas allows you to use the same JAX functions and APIs but operates at a\n", + "*lower* level of abstraction.\n", + "\n", + "Specifically, Pallas requires users to think about memory access and how to\n", + "divide up computations across multiple compute units in a hardware accelerator.\n", + "On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic.\n", "\n", "Let's dive into some examples.\n", "\n", @@ -64,15 +70,24 @@ "source": [ "**`Ref` types**\n", "\n", - "Let's dissect this function a bit. Unlike most JAX functions you've probably written, it does not take in `jax.Array`s as inputs and doesn't return any values. Instead it takes in *`Ref`* objects as inputs. 
Note that we also don't have any outputs but we are given an `o_ref`, which corresponds to the desired output.\n", + "Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n", + "it does not take in `jax.Array`s as inputs and doesn't return any values.\n", + "Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n", + "but we are given an `o_ref`, which corresponds to the desired output.\n", "\n", "**Reading from `Ref`s**\n", "\n", - "In the body, we are first reading from `x_ref` and `y_ref`, indicated by the `[...]` (the ellipsis means we are reading the whole `Ref`; alternatively we also could have used `x_ref[:]`). Reading from a `Ref` like this returns a `jax.Array`.\n", + "In the body, we are first reading from `x_ref` and `y_ref`, indicated by the `[...]`\n", + "(the ellipsis means we are reading the whole `Ref`;\n", + "alternatively we also could have used `x_ref[:]`).\n", + "Reading from a `Ref` like this returns a `jax.Array`.\n", "\n", "**Writing to `Ref`s**\n", "\n", - "We then write `x + y` to `o_ref`. Mutation has not historically been supported in JAX -- `jax.Array`s are immutable! `Ref`s are new (experimental) types that allow mutation under certain circumstances. We can interpret writing to a `Ref` as mutating its underlying buffer." + "We then write `x + y` to `o_ref`.\n", + "Mutation has not historically been supported in JAX -- `jax.Array`s are immutable!\n", + "`Ref`s are new (experimental) types that allow mutation under certain circumstances.\n", + "We can interpret writing to a `Ref` as mutating its underlying buffer." ] }, { @@ -80,7 +95,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "So we've written what we call a \"kernel\", which we define as a program that will run as an atomic unit of execution on an accelerator, without any interaction with the host. How do we invoke it from a JAX computation? We use the `pallas_call` higher-order function." + "So we've written what we call a \"kernel\", which we define as a program that will\n", + "run as an atomic unit of execution on an accelerator,\n", + "without any interaction with the host.\n", + "How do we invoke it from a JAX computation?\n", + "We use the `pallas_call` higher-order function." ] }, { @@ -102,9 +121,10 @@ "source": [ "@jax.jit\n", "def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:\n", - " return pl.pallas_call(add_vectors_kernel,\n", - " out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", - " )(x, y)\n", + " return pl.pallas_call(\n", + " add_vectors_kernel,\n", + " out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", + " )(x, y)\n", "add_vectors(jnp.arange(8), jnp.arange(8))" ] }, @@ -113,7 +133,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`pallas_call` lifts the Pallas kernel function into an operation that can be called as part of a larger JAX program. But, to do so, it needs a few more details. Here we specify `out_shape`, an object that has a `.shape` and `.dtype` (or a list thereof).\n", + "`pallas_call` lifts the Pallas kernel function into an operation that can be called\n", + "as part of a larger JAX program. But, to do so, it needs a few more details.\n", + "Here we specify `out_shape`, an object that has a `.shape` and `.dtype` (or a list\n", + "thereof).\n", "`out_shape` determines the shape/dtype of `o_ref` in our `add_vector_kernel`.\n", "\n", "`pallas_call` returns a function that takes in and returns `jax.Array`s." 
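Aside (not part of the patched docs): `out_shape` may also be a list of `jax.ShapeDtypeStruct`s, in which case the kernel receives one output `Ref` per entry, after the input `Ref`s, and `pallas_call` returns a matching list of arrays. A minimal sketch with illustrative names, using `interpret=True` only so it also runs without an accelerator:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def sum_and_prod_kernel(x_ref, y_ref, sum_ref, prod_ref):
    # Input Refs come first, then one output Ref per entry of out_shape.
    x, y = x_ref[...], y_ref[...]
    sum_ref[...] = x + y
    prod_ref[...] = x * y

@jax.jit
def sum_and_prod(x: jax.Array, y: jax.Array):
    return pl.pallas_call(
        sum_and_prod_kernel,
        out_shape=[jax.ShapeDtypeStruct(x.shape, x.dtype),
                   jax.ShapeDtypeStruct(x.shape, x.dtype)],
        interpret=True,  # interpreter mode, so this sketch also runs on CPU
    )(x, y)

s, p = sum_and_prod(jnp.arange(8), jnp.arange(8))
```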
@@ -126,11 +149,20 @@ "source": [ "**What's actually happening here?**\n", "\n", - "Thus far we've described how to think about Pallas kernels but what we've actually accomplished is we're writing a function that's executed very close to the compute units.\n", + "Thus far we've described how to think about Pallas kernels but what we've actually\n", + "accomplished is we're writing a function that's executed very close to the compute units.\n", "\n", - "On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM) (this is a costly operation generally speaking!). We then use GPU vector compute to execute the addition, then copy the resulting value in SRAM back to HBM.\n", + "On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when\n", + "we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)\n", + "(this is a costly operation generally speaking!).\n", + "We then use GPU vector compute to execute the addition, then copy the resulting value\n", + "in SRAM back to HBM.\n", "\n", - "On TPU, we do something slightly different. Before the kernel is ever executed, we fetch the value from HBM into SRAM. `x_ref` therefore corresponds to a value in SRAM and when we do `x_ref[...]` we are copying the value from SRAM into a register. We then use TPU vector compute to execute the addition, then copy the resulting value back into SRAM. After the kernel is executed, the SRAM value is copied back into HBM.\n", + "On TPU, we do something slightly different. Before the kernel is ever executed,\n", + "we fetch the value from HBM into SRAM. `x_ref` therefore corresponds to a value in\n", + "SRAM and when we do `x_ref[...]` we are copying the value from SRAM into a register.\n", + "We then use TPU vector compute to execute the addition, then copy the resulting\n", + "value back into SRAM. After the kernel is executed, the SRAM value is copied back into HBM.\n", "\n", "We are in the process of writing backend-specific Pallas guides. Coming soon!" ] @@ -148,7 +180,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In our \"hello world\" example, we wrote a very simple kernel. It takes advantage of the fact that our 8-sized arrays can comfortably fit inside the SRAM of hardware accelerators. In most real-world applications, this will not be the case!" + "In our \"hello world\" example, we wrote a very simple kernel.\n", + "It takes advantage of the fact that our 8-sized arrays can comfortably fit inside\n", + "the SRAM of hardware accelerators.\n", + "In most real-world applications, this will not be the case!" 
] }, { @@ -156,15 +191,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Part of writing Pallas kernels is thinking about how to take big arrays that live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations that operate on \"blocks\" of those arrays that can fit in SRAM.\n", + "Part of writing Pallas kernels is thinking about how to take big arrays that\n", + "live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n", + "that operate on \"blocks\" of those arrays that can fit in SRAM.\n", "\n", - "### Grids\n", + "### Grids by example\n", "\n", - "To automatically \"carve\" up the inputs and outputs, you provide a `grid` and `BlockSpec`s to `pallas_call`.\n", + "To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n", + "`BlockSpec`s to `pallas_call`.\n", "\n", - "A `grid` is a tuple of integers (e.g. `()`, `(2, 3, 4)`, or `(8,)`) that specifies an iteration space.\n", - "For example, a grid `(4, 5)` would have 20 elements: `(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)`.\n", - "We run the kernel function once for each element, a style of single-program multiple-data (SPMD) programming.\n", + "A `grid` is a tuple of integers (e.g. `()`, `(2, 3, 4)`, or `(8,)`) that specifies\n", + "an iteration space.\n", + "For example, a grid `(4, 5)` would have 20 elements:\n", + "`(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)`.\n", + "We run the kernel function once for each element, a style of single-program\n", + "multiple-data (SPMD) programming.\n", "\n", "
\n", "\n", @@ -173,7 +214,12 @@ "A 2D grid\n", "
\n", "\n", - "When we provide a `grid` to `pallas_call`, the kernel is executed as many times as `prod(grid)`. Each of these invocations is referred to as a \"program\", To access which program (i.e. which element of the grid) the kernel is currently executing, we use `program_id(axis=...)`. For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`.\n", + "When we provide a `grid` to `pallas_call`, the kernel is executed as many times\n", + "as `prod(grid)`. Each of these invocations is referred to as a \"program\".\n", + "To access which program (i.e. which element of the grid) the kernel is currently\n", + "executing, we use `program_id(axis=...)`.\n", + "For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and\n", + "`program_id(axis=1)` returns `2`.\n", "\n", "Here's an example kernel that uses a `grid` and `program_id`." ] @@ -214,10 +260,10 @@ } ], "source": [ - "def iota(len: int):\n", + "def iota(size: int):\n", " return pl.pallas_call(iota_kernel,\n", - " out_shape=jax.ShapeDtypeStruct((len,), jnp.int32),\n", - " grid=(len,))()\n", + " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n", + " grid=(size,))()\n", "iota(8)" ] }, @@ -226,9 +272,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "On GPUs, each program is executed in parallel on separate threads. Thus, we need to think about race conditions on writes to HBM. A reasonable approach is to write our kernels in such a way that different programs write to disjoint places in HBM to avoid these parallel writes. On the other hand, parallelizing the computation is how we can execute operations like matrix multiplications really quickly.\n", + "On GPUs, each program is executed in parallel on separate threads.\n", + "Thus, we need to think about race conditions on writes to HBM.\n", + "A reasonable approach is to write our kernels in such a way that different\n", + "programs write to disjoint places in HBM to avoid these parallel writes.\n", + "On the other hand, parallelizing the computation is how we can execute\n", + "operations like matrix multiplications really quickly.\n", "\n", - "On TPUs, programs are executed in a combination of parallel and sequential (depending on the architecture) so there are slightly different considerations." + "On TPUs, programs are executed in a combination of parallel and sequential\n", + "(depending on the architecture) so there are slightly different considerations.\n", + "\n", + "You can read more details at {ref}`pallas_grid`." ] }, { @@ -236,7 +290,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Block specs" + "### Block specs by example" ] }, { @@ -244,12 +298,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "With `grid` and `program_id` in mind, Pallas provides an abstraction that takes care of some common indexing patterns seen in a lot of kernels.\n", - "To build intution, let's try to implement a matrix multiplication.\n", + "With `grid` and `program_id` in mind, Pallas provides an abstraction that\n", + "takes care of some common indexing patterns seen in a lot of kernels.\n", + "To build intuition, let's try to implement a matrix multiplication.\n", "\n", - "A simple strategy for implementing a matrix multiplication in Pallas is to implement it recursively. 
We know our underlying hardware has support for small matrix multiplications (using GPU and TPU tensorcores), so we just express a big matrix multiplication in terms of smaller ones.\n", + "A simple strategy for implementing a matrix multiplication in Pallas is to\n", + "implement it recursively.\n", + "We know our underlying hardware has support for small matrix multiplications\n", + "(using GPU and TPU tensorcores), so we just express a big matrix multiplication\n", + "in terms of smaller ones.\n", "\n", - "Suppose we have input matrices $X$ and $Y$ and are computing $Z = XY$. We first express $X$ and $Y$ as block matrices. $X$ will have \"row\" blocks and $Y$ will have \"column\" blocks.\n", + "Suppose we have input matrices $X$ and $Y$ and are computing $Z = XY$.\n", + "We first express $X$ and $Y$ as block matrices. $X$ will have \"row\" blocks\n", + "and $Y$ will have \"column\" blocks.\n", "\n", "$$\n", "\\begin{align*}\n", @@ -289,7 +350,10 @@ "\\end{align*}\n", "$$\n", "\n", - "Our strategy is that because $Z$ is also a block matrix, we can assign each of the programs in our Pallas kernel one of the output blocks. Computing each output block corresponds to doing a smaller matrix multiply between a \"row\" block of $X$ and a \"column\" block of $Y$." + "Our strategy is that because $Z$ is also a block matrix, we can assign each of\n", + "the programs in our Pallas kernel one of the output blocks.\n", + "Computing each output block corresponds to doing a smaller matrix multiply\n", + "between a \"row\" block of $X$ and a \"column\" block of $Y$." ] }, { @@ -297,7 +361,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To express this pattern, we use `BlockSpec`s. A `BlockSpec` specifies a block shape for each input and output, and an \"index map\" function, that maps a set of program indices to a block index.\n", + "To express this pattern, we use `BlockSpec`s. A `BlockSpec` specifies a block\n", + "shape for each input and output, and an \"index map\" function, that maps a\n", + "set of program indices to a block index.\n", "\n", "
\n", "\n", @@ -307,13 +373,25 @@ "\n", "
\n", "\n", - "For a concrete example, let's say we'd like to multiply two `(1024, 1024)` matrices `x` and `y` together to produce `z`, and would like to parallelize the computation 4 ways. We split up `z` into 4 `(512, 512)` blocks where each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication. To express this, we'd first use a `(2, 2)` grid (one block for each program).\n", + "For a concrete example, let's say we'd like to multiply two `(1024, 1024)`\n", + "matrices `x` and `y` together to produce `z`, and would like to parallelize\n", + "the computation 4 ways. We split up `z` into 4 `(512, 512)` blocks where\n", + "each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication.\n", + "To express this, we'd first use a `(2, 2)` grid (one block for each program).\n", "\n", - "For `x`, we use `BlockSpec(lambda i, j: (i, 0), (512, 1024))` -- this carves `x` up into \"row\" blocks. To see this see how both program instances `(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`. For `y`, we use a transposed version `BlockSpec(lambda i, j: (0, j), (1024, 512))`. Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`.\n", + "For `x`, we use `BlockSpec((512, 1024), lambda i, j: (i, 0))` -- this\n", + "carves `x` up into \"row\" blocks.\n", + "To see this see how both program instances\n", + "`(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`.\n", + "For `y`, we use a transposed version `BlockSpec((1024, 512), lambda i, j: (0, j))`.\n", + "Finally, for `z` we use `BlockSpec((512, 512), lambda i, j: (i, j))`.\n", "\n", "These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.\n", "\n", - "Underneath the hood, `pallas_call` will automatically carve up your inputs and outputs into `Ref`s for each block that will be passed into the kernel." + "For more detail on `BlockSpec`s see {ref}`pallas_blockspec`.\n", + "\n", + "Underneath the hood, `pallas_call` will automatically carve up your inputs and\n", + "outputs into `Ref`s for each block that will be passed into the kernel." ] }, { @@ -331,14 +409,14 @@ " out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),\n", " grid=(2, 2),\n", " in_specs=[\n", - " pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),\n", - " pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))\n", + " pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),\n", + " pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))\n", " ],\n", " out_specs=pl.BlockSpec(\n", - " lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n", + " (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),\n", " )\n", " )(x, y)\n", - "k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n", + "k1, k2 = jax.random.split(jax.random.key(0))\n", "x = jax.random.normal(k1, (1024, 1024))\n", "y = jax.random.normal(k2, (1024, 1024))\n", "z = matmul(x, y)\n", @@ -350,8 +428,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Note that this is a very naive implementation of a matrix multiplication but consider it a starting point for various types of optimizations.\n", - "Let's add an additional feature to our matrix multiply: fused activation. It's actually really easy! Just pass a higher-order activation function into the kernel." 
+ "Note that this is a very naive implementation of a matrix multiplication but\n", + "consider it a starting point for various types of optimizations.\n", + "Let's add an additional feature to our matrix multiply: fused activation.\n", + "It's actually really easy! Just pass a higher-order activation function into the kernel." ] }, { @@ -369,14 +449,14 @@ " out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),\n", " grid=(2, 2),\n", " in_specs=[\n", - " pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),\n", - " pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))\n", + " pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),\n", + " pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))\n", " ],\n", " out_specs=pl.BlockSpec(\n", - " lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n", + " (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)\n", " ),\n", " )(x, y)\n", - "k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n", + "k1, k2 = jax.random.split(jax.random.key(0))\n", "x = jax.random.normal(k1, (1024, 1024))\n", "y = jax.random.normal(k2, (1024, 1024))\n", "z = matmul(x, y, activation=jax.nn.relu)\n", @@ -388,7 +468,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To conclude, let's highlight a cool feature of Pallas: it composes with `jax.vmap`! To turn this matrix multiplication into a batched version, we just need to `vmap` it." + "To conclude, let's highlight a cool feature of Pallas: it composes with `jax.vmap`!\n", + "To turn this matrix multiplication into a batched version, we just need to `vmap` it." ] }, { @@ -397,7 +478,7 @@ "metadata": {}, "outputs": [], "source": [ - "k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n", + "k1, k2 = jax.random.split(jax.random.key(0))\n", "x = jax.random.normal(k1, (4, 1024, 1024))\n", "y = jax.random.normal(k2, (4, 1024, 1024))\n", "z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)\n", diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index 61e68ef1ea9a..36cc14bf5c34 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -14,9 +14,15 @@ kernelspec: # Pallas Quickstart -Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. Pallas allows you to use the same JAX functions and APIs but operates at a *lower* level of abstraction. + -Specifically, Pallas requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator. On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic. +Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. +Pallas allows you to use the same JAX functions and APIs but operates at a +*lower* level of abstraction. + +Specifically, Pallas requires users to think about memory access and how to +divide up computations across multiple compute units in a hardware accelerator. +On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic. Let's dive into some examples. @@ -45,30 +51,47 @@ def add_vectors_kernel(x_ref, y_ref, o_ref): **`Ref` types** -Let's dissect this function a bit. Unlike most JAX functions you've probably written, it does not take in `jax.Array`s as inputs and doesn't return any values. Instead it takes in *`Ref`* objects as inputs. 
Note that we also don't have any outputs but we are given an `o_ref`, which corresponds to the desired output. +Let's dissect this function a bit. Unlike most JAX functions you've probably written, +it does not take in `jax.Array`s as inputs and doesn't return any values. +Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs +but we are given an `o_ref`, which corresponds to the desired output. **Reading from `Ref`s** -In the body, we are first reading from `x_ref` and `y_ref`, indicated by the `[...]` (the ellipsis means we are reading the whole `Ref`; alternatively we also could have used `x_ref[:]`). Reading from a `Ref` like this returns a `jax.Array`. +In the body, we are first reading from `x_ref` and `y_ref`, indicated by the `[...]` +(the ellipsis means we are reading the whole `Ref`; +alternatively we also could have used `x_ref[:]`). +Reading from a `Ref` like this returns a `jax.Array`. **Writing to `Ref`s** -We then write `x + y` to `o_ref`. Mutation has not historically been supported in JAX -- `jax.Array`s are immutable! `Ref`s are new (experimental) types that allow mutation under certain circumstances. We can interpret writing to a `Ref` as mutating its underlying buffer. +We then write `x + y` to `o_ref`. +Mutation has not historically been supported in JAX -- `jax.Array`s are immutable! +`Ref`s are new (experimental) types that allow mutation under certain circumstances. +We can interpret writing to a `Ref` as mutating its underlying buffer. +++ -So we've written what we call a "kernel", which we define as a program that will run as an atomic unit of execution on an accelerator, without any interaction with the host. How do we invoke it from a JAX computation? We use the `pallas_call` higher-order function. +So we've written what we call a "kernel", which we define as a program that will +run as an atomic unit of execution on an accelerator, +without any interaction with the host. +How do we invoke it from a JAX computation? +We use the `pallas_call` higher-order function. ```{code-cell} ipython3 @jax.jit def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: - return pl.pallas_call(add_vectors_kernel, - out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) - )(x, y) + return pl.pallas_call( + add_vectors_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) + )(x, y) add_vectors(jnp.arange(8), jnp.arange(8)) ``` -`pallas_call` lifts the Pallas kernel function into an operation that can be called as part of a larger JAX program. But, to do so, it needs a few more details. Here we specify `out_shape`, an object that has a `.shape` and `.dtype` (or a list thereof). +`pallas_call` lifts the Pallas kernel function into an operation that can be called +as part of a larger JAX program. But, to do so, it needs a few more details. +Here we specify `out_shape`, an object that has a `.shape` and `.dtype` (or a list +thereof). `out_shape` determines the shape/dtype of `o_ref` in our `add_vector_kernel`. `pallas_call` returns a function that takes in and returns `jax.Array`s. @@ -77,11 +100,20 @@ add_vectors(jnp.arange(8), jnp.arange(8)) **What's actually happening here?** -Thus far we've described how to think about Pallas kernels but what we've actually accomplished is we're writing a function that's executed very close to the compute units. +Thus far we've described how to think about Pallas kernels but what we've actually +accomplished is we're writing a function that's executed very close to the compute units. 
-On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM) (this is a costly operation generally speaking!). We then use GPU vector compute to execute the addition, then copy the resulting value in SRAM back to HBM. +On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when +we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM) +(this is a costly operation generally speaking!). +We then use GPU vector compute to execute the addition, then copy the resulting value +in SRAM back to HBM. -On TPU, we do something slightly different. Before the kernel is ever executed, we fetch the value from HBM into SRAM. `x_ref` therefore corresponds to a value in SRAM and when we do `x_ref[...]` we are copying the value from SRAM into a register. We then use TPU vector compute to execute the addition, then copy the resulting value back into SRAM. After the kernel is executed, the SRAM value is copied back into HBM. +On TPU, we do something slightly different. Before the kernel is ever executed, +we fetch the value from HBM into SRAM. `x_ref` therefore corresponds to a value in +SRAM and when we do `x_ref[...]` we are copying the value from SRAM into a register. +We then use TPU vector compute to execute the addition, then copy the resulting +value back into SRAM. After the kernel is executed, the SRAM value is copied back into HBM. We are in the process of writing backend-specific Pallas guides. Coming soon! @@ -91,19 +123,28 @@ We are in the process of writing backend-specific Pallas guides. Coming soon! +++ -In our "hello world" example, we wrote a very simple kernel. It takes advantage of the fact that our 8-sized arrays can comfortably fit inside the SRAM of hardware accelerators. In most real-world applications, this will not be the case! +In our "hello world" example, we wrote a very simple kernel. +It takes advantage of the fact that our 8-sized arrays can comfortably fit inside +the SRAM of hardware accelerators. +In most real-world applications, this will not be the case! +++ -Part of writing Pallas kernels is thinking about how to take big arrays that live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations that operate on "blocks" of those arrays that can fit in SRAM. +Part of writing Pallas kernels is thinking about how to take big arrays that +live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations +that operate on "blocks" of those arrays that can fit in SRAM. -### Grids +### Grids by example -To automatically "carve" up the inputs and outputs, you provide a `grid` and `BlockSpec`s to `pallas_call`. +To automatically "carve" up the inputs and outputs, you provide a `grid` and +`BlockSpec`s to `pallas_call`. -A `grid` is a tuple of integers (e.g. `()`, `(2, 3, 4)`, or `(8,)`) that specifies an iteration space. -For example, a grid `(4, 5)` would have 20 elements: `(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)`. -We run the kernel function once for each element, a style of single-program multiple-data (SPMD) programming. +A `grid` is a tuple of integers (e.g. `()`, `(2, 3, 4)`, or `(8,)`) that specifies +an iteration space. +For example, a grid `(4, 5)` would have 20 elements: +`(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)`. +We run the kernel function once for each element, a style of single-program +multiple-data (SPMD) programming.
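Aside (not part of the patched docs): the iteration space described above can be spelled out in plain Python, which makes the "20 elements" claim for a `(4, 5)` grid easy to check:

```python
import itertools

grid = (4, 5)
# Enumerate the grid elements in the order listed in the text.
programs = list(itertools.product(*(range(n) for n in grid)))
assert len(programs) == 4 * 5              # one kernel invocation per element
assert programs[:2] == [(0, 0), (0, 1)]    # matches the listing in the text
assert programs[-1] == (3, 4)
```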
@@ -112,7 +153,12 @@ We run the kernel function once for each element, a style of single-program mult A 2D grid
-When we provide a `grid` to `pallas_call`, the kernel is executed as many times as `prod(grid)`. Each of these invocations is referred to as a "program", To access which program (i.e. which element of the grid) the kernel is currently executing, we use `program_id(axis=...)`. For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`. +When we provide a `grid` to `pallas_call`, the kernel is executed as many times +as `prod(grid)`. Each of these invocations is referred to as a "program". +To access which program (i.e. which element of the grid) the kernel is currently +executing, we use `program_id(axis=...)`. +For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and +`program_id(axis=1)` returns `2`. Here's an example kernel that uses a `grid` and `program_id`. @@ -125,29 +171,44 @@ def iota_kernel(o_ref): We now execute it using `pallas_call` with an additional `grid` argument. ```{code-cell} ipython3 -def iota(len: int): +def iota(size: int): return pl.pallas_call(iota_kernel, - out_shape=jax.ShapeDtypeStruct((len,), jnp.int32), - grid=(len,))() + out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), + grid=(size,))() iota(8) ``` -On GPUs, each program is executed in parallel on separate threads. Thus, we need to think about race conditions on writes to HBM. A reasonable approach is to write our kernels in such a way that different programs write to disjoint places in HBM to avoid these parallel writes. On the other hand, parallelizing the computation is how we can execute operations like matrix multiplications really quickly. +On GPUs, each program is executed in parallel on separate threads. +Thus, we need to think about race conditions on writes to HBM. +A reasonable approach is to write our kernels in such a way that different +programs write to disjoint places in HBM to avoid these parallel writes. +On the other hand, parallelizing the computation is how we can execute +operations like matrix multiplications really quickly. -On TPUs, programs are executed in a combination of parallel and sequential (depending on the architecture) so there are slightly different considerations. +On TPUs, programs are executed in a combination of parallel and sequential +(depending on the architecture) so there are slightly different considerations. + +You can read more details at {ref}`pallas_grid`. +++ -### Block specs +### Block specs by example +++ -With `grid` and `program_id` in mind, Pallas provides an abstraction that takes care of some common indexing patterns seen in a lot of kernels. -To build intution, let's try to implement a matrix multiplication. +With `grid` and `program_id` in mind, Pallas provides an abstraction that +takes care of some common indexing patterns seen in a lot of kernels. +To build intuition, let's try to implement a matrix multiplication. -A simple strategy for implementing a matrix multiplication in Pallas is to implement it recursively. We know our underlying hardware has support for small matrix multiplications (using GPU and TPU tensorcores), so we just express a big matrix multiplication in terms of smaller ones. +A simple strategy for implementing a matrix multiplication in Pallas is to +implement it recursively. +We know our underlying hardware has support for small matrix multiplications +(using GPU and TPU tensorcores), so we just express a big matrix multiplication +in terms of smaller ones. -Suppose we have input matrices $X$ and $Y$ and are computing $Z = XY$. 
We first express $X$ and $Y$ as block matrices. $X$ will have "row" blocks and $Y$ will have "column" blocks. +Suppose we have input matrices $X$ and $Y$ and are computing $Z = XY$. +We first express $X$ and $Y$ as block matrices. $X$ will have "row" blocks +and $Y$ will have "column" blocks. $$ \begin{align*} @@ -187,11 +248,16 @@ X_1 Y_0 & X_1 Y_1 \end{align*} $$ -Our strategy is that because $Z$ is also a block matrix, we can assign each of the programs in our Pallas kernel one of the output blocks. Computing each output block corresponds to doing a smaller matrix multiply between a "row" block of $X$ and a "column" block of $Y$. +Our strategy is that because $Z$ is also a block matrix, we can assign each of +the programs in our Pallas kernel one of the output blocks. +Computing each output block corresponds to doing a smaller matrix multiply +between a "row" block of $X$ and a "column" block of $Y$. +++ -To express this pattern, we use `BlockSpec`s. A `BlockSpec` specifies a block shape for each input and output, and an "index map" function, that maps a set of program indices to a block index. +To express this pattern, we use `BlockSpec`s. A `BlockSpec` specifies a block +shape for each input and output, and an "index map" function, that maps a +set of program indices to a block index.
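Aside (not part of the patched docs): the block-matrix identity above is easy to sanity-check with plain `jax.numpy` before bringing `BlockSpec`s into it; a small sketch using illustrative 512-sized blocks:

```python
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))

x0, x1 = x[:512, :], x[512:, :]   # "row" blocks of X
y0, y1 = y[:, :512], y[:, 512:]   # "column" blocks of Y

# Each output block is a smaller matmul of one row block with one column block.
z_blocks = jnp.block([[x0 @ y0, x0 @ y1],
                      [x1 @ y0, x1 @ y1]])
assert jnp.allclose(z_blocks, x @ y, atol=1e-3)
```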
@@ -201,13 +267,25 @@ A visualization of a `BlockSpec`
-For a concrete example, let's say we'd like to multiply two `(1024, 1024)` matrices `x` and `y` together to produce `z`, and would like to parallelize the computation 4 ways. We split up `z` into 4 `(512, 512)` blocks where each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication. To express this, we'd first use a `(2, 2)` grid (one block for each program). +For a concrete example, let's say we'd like to multiply two `(1024, 1024)` +matrices `x` and `y` together to produce `z`, and would like to parallelize +the computation 4 ways. We split up `z` into 4 `(512, 512)` blocks where +each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication. +To express this, we'd first use a `(2, 2)` grid (one block for each program). -For `x`, we use `BlockSpec(lambda i, j: (i, 0), (512, 1024))` -- this carves `x` up into "row" blocks. To see this see how both program instances `(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`. For `y`, we use a transposed version `BlockSpec(lambda i, j: (0, j), (1024, 512))`. Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`. +For `x`, we use `BlockSpec((512, 1024), lambda i, j: (i, 0))` -- this +carves `x` up into "row" blocks. +To see this see how both program instances +`(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`. +For `y`, we use a transposed version `BlockSpec((1024, 512), lambda i, j: (0, j))`. +Finally, for `z` we use `BlockSpec((512, 512), lambda i, j: (i, j))`. These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`. -Underneath the hood, `pallas_call` will automatically carve up your inputs and outputs into `Ref`s for each block that will be passed into the kernel. +For more detail on `BlockSpec`s see {ref}`pallas_blockspec`. + +Underneath the hood, `pallas_call` will automatically carve up your inputs and +outputs into `Ref`s for each block that will be passed into the kernel. ```{code-cell} ipython3 def matmul_kernel(x_ref, y_ref, z_ref): @@ -219,22 +297,24 @@ def matmul(x: jax.Array, y: jax.Array): out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), grid=(2, 2), in_specs=[ - pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])), - pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2)) + pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), + pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)) ], out_specs=pl.BlockSpec( - lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2) + (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j), ) )(x, y) -k1, k2 = jax.random.split(jax.random.PRNGKey(0)) +k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.normal(k1, (1024, 1024)) y = jax.random.normal(k2, (1024, 1024)) z = matmul(x, y) np.testing.assert_allclose(z, x @ y) ``` -Note that this is a very naive implementation of a matrix multiplication but consider it a starting point for various types of optimizations. -Let's add an additional feature to our matrix multiply: fused activation. It's actually really easy! Just pass a higher-order activation function into the kernel. +Note that this is a very naive implementation of a matrix multiplication but +consider it a starting point for various types of optimizations. +Let's add an additional feature to our matrix multiply: fused activation. +It's actually really easy! Just pass a higher-order activation function into the kernel. 
```{code-cell} ipython3 def matmul_kernel(x_ref, y_ref, z_ref, *, activation): @@ -246,24 +326,25 @@ def matmul(x: jax.Array, y: jax.Array, *, activation): out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), grid=(2, 2), in_specs=[ - pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])), - pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2)) + pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), + pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)) ], out_specs=pl.BlockSpec( - lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2) + (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j) ), )(x, y) -k1, k2 = jax.random.split(jax.random.PRNGKey(0)) +k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.normal(k1, (1024, 1024)) y = jax.random.normal(k2, (1024, 1024)) z = matmul(x, y, activation=jax.nn.relu) np.testing.assert_allclose(z, jax.nn.relu(x @ y)) ``` -To conclude, let's highlight a cool feature of Pallas: it composes with `jax.vmap`! To turn this matrix multiplication into a batched version, we just need to `vmap` it. +To conclude, let's highlight a cool feature of Pallas: it composes with `jax.vmap`! +To turn this matrix multiplication into a batched version, we just need to `vmap` it. ```{code-cell} ipython3 -k1, k2 = jax.random.split(jax.random.PRNGKey(0)) +k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.normal(k1, (4, 1024, 1024)) y = jax.random.normal(k2, (4, 1024, 1024)) z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y) diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index 73a8446fdaed..718ade0c7046 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -65,7 +65,8 @@ Noteworthy properties and restrictions ``BlockSpec``\s and grid iteration ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -``BlockSpec``\s generally behave as expected in Pallas --- every invocation of +``BlockSpec``\s (see :ref:`pallas_blockspec`) generally behave as expected +in Pallas --- every invocation of the kernel body gets access to slices of the inputs and is meant to initialize a slice of the output. @@ -147,8 +148,10 @@ grid axes over cores. This is an opt-in procedure. To allow that, .. pallas_call( ..., - mosaic_params=dict( - dimension_semantics=["parallel", "parallel", "arbitrary"] + compiler_params=dict( + mosaic=dict( + dimension_semantics=["parallel", "parallel", "arbitrary"] + ) ), ) @@ -266,7 +269,7 @@ Elementwise operations ^^^^^^^^^^^^^^^^^^^^^^ Many elementwise operations are supported. It is worth noting that the hardware -generally only supports elementwise compute using 32-bit types. When loading +generally only supports elementwise computation using 32-bit types. When loading operands that use lower-precision types, they should generally be upcast to a 32-bit type before applying elementwise ops. @@ -297,7 +300,7 @@ expensive (🔴). Many JAX functions are implemented in terms of other JAX primitives, so this list might not be comprehensive. For example, ``jax.nn.relu`` is implemented -in terms of comparisons and ``jnp.where`` and will work in Pallas kernels too. +in terms of comparisons and ``jnp.where`` will work in Pallas kernels too. Array constructors ^^^^^^^^^^^^^^^^^^ @@ -344,5 +347,5 @@ However, loop primitives get fully unrolled during the compilation at the moment, so try to keep the loop trip count reasonably small. 
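Aside (not part of the patched docs): a loop with a small, fixed trip count is the kind of control flow the paragraph above has in mind. A sketch of such a kernel, with illustrative names and `interpret=True` so it runs without a TPU:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def horner_kernel(x_ref, o_ref):
    x = x_ref[...]

    def body(i, acc):
        # Horner-style update; on TPU the whole loop is unrolled at compile time.
        return acc * x + 1.0

    # Eight iterations: a small, fixed trip count keeps the unrolled code cheap.
    o_ref[...] = jax.lax.fori_loop(0, 8, body, jnp.zeros_like(x))

x = jnp.linspace(0.0, 1.0, 1024)
out = pl.pallas_call(
    horner_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # interpreter mode so the sketch runs anywhere
)(x)
```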
Overusing control flow can lead to significant regressions in low-level code -generation, and it is recommended to try to squeeze as many computationaly +generation, and it is recommended to try to squeeze as many computationally expensive operations into a single basic block as possible. diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index fe9271e92d94..275a72f3837b 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -6,7 +6,9 @@ "id": "teoJ_fUwlu0l" }, "source": [ - "# Pipelining and `BlockSpec`s" + "# Pipelining\n", + "\n", + "" ] }, { @@ -15,7 +17,8 @@ "id": "gAJDZh1gBh-h" }, "source": [ - "In this guide we'll cover how memory spaces in TPU work and how to write pipelines in Pallas that overlap memory I/O with compute." + "In this guide we'll cover how memory spaces in TPU work and how to write\n", + "pipelines in Pallas that overlap memory I/O with compute." ] }, { @@ -42,17 +45,33 @@ "source": [ "## TPU and its memory spaces\n", "\n", - "A TPU and its TensorCore consist of memory spaces (where arrays can reside), registers (which temporarily store scalar and array values) and compute units (that do computation with values in registers). Below is a diagram of a TPU in which `x` and `y` are arrays that live in high-bandwidth memory (HBM):\n", + "A TPU and its TensorCore consist of memory spaces (where arrays can reside),\n", + "registers (which temporarily store scalar and array values) and compute units\n", + "(that do computation with values in registers).\n", + "Below is a diagram of a TPU in which `x` and `y` are arrays that live in\n", + "high-bandwidth memory (HBM):\n", "\n", "![TPU Memory Space Cartoon.png](data:image/png;base64,<embedded base64 PNG data omitted>)\n",
dfLkya434UW5RcQMp2RTttHcZjOckk1zmxERk8WuXbvu0nVcxMxEX+jfv/+2xYsXu96EFwna5NTcZkTEZHHKlCkStMPNTPSFioqKc5s2bXK9CS8StMmpuc2IiMniqlWrDus67m0zE32huLj4xtGjR11vwosEbXJqbjMiYrJYOyDqkJmJvhCLxe6cPHnS9Sa8SNAmp+Y2IyImizr37uk67qyZib4QiUTunT9/3vUmvEjQJqfmNiMiJouSe7qO+8bMRF+QB8rLF+Oab8KLBG1yam4zImKyKLmn67i7Zib6QnNUpgRtcmpuMyJiMil1nJmJvtAclSlBm5ya24yImEwStA7Pb9yohnbvrsb162d9jZ1z2fRhw9QQvezM+vWu1zWn5n7yO2j3fLJd9R7QQ02dO7mBZdusZbMXTXcta07NbUZETCYJWsNXhw5Vv/rVr9ScUaPq5q2eOdOaN3HAAFf55tbcT34HrRjNjajftHtaffjZ3nrzp82bYm33ktULXK9pTs1tRkRMJglaQ2nJluXnq18/8YQ6XF2tzm7YoJ5/5hlVnJ1tPYvYLN/cmvupJYL21ddfqfnnYvGMevPzCnPVi2kv+P6sZXObERGTSYK2AaV7+LfPPafyQiHVq6REPdeunTq5dq2rnB+a+6klgvbAqV3qqaefUgUd8+vmbTm4QT366KNq+PghrvLNrbnNiIjJJEHbiFvnz7daeRI2G+bOdS33S3M/tUTQin0H9ba2VQJWpsdWjbKmN3+wzlW2uTW3GRExmSRoG/GNsWOtoBXfHD/etdwvzf3UUkH73vYV1raOnDjMmg6E2qv84jxXOT80txkRMZkkaBtw9+LF6onHH7dGGfcpLbV+l3lmOT8091NLBa0og6LS0n+rVm+rtkJ35oJXXWX80NxmRMRkkqA1lMFPL73wggqmpalLW7aoC5s2qXQ9nfb8877f2iOa+6klg9YeFCWDoOSa7Qend7vK+KG5zYiIySRB6/Dr2hHHjz36aL0WrFyvleuVHXNzfR95bO6nlgxae1CUhG3vgT1dy/3S3GZExGSSoHV4pLpaTRsyRK2cPt21bPnUqdayD99+27WsOTX3U0sGrVjcqdD6p2LNzpWuZX5pbjMiYjJJ0LYxzf3UkkEr12YlZMu6dXIt81NzmxERk0mCto1p7ie/g3bvp9vV5g/Xq9eWzFTPv/i8eva5Z9S2wxtd5fzU3GZExGSSoG1jmvvJ76Cdv3xu3W1MWTkRtWHve64yfmtuMyJiMknQtjHN/eR30MqXCizfsFRt3L/WtaylNLcZETGZTKigzQwGrZHBZjgli7JtmZmZ9bdZT/v9rOHWVLYtMzPo+qwREZPFhArawoICdWrdOldAJYuybbKN9ba5sFDtPLbZFVDJomxbQWG+67NGREwWEypoR48apZY3cOtNsviO3jbZxnrbPHq0emPJXFdAJYtzl8xSI0ePcH3WiIjJYkIF7d69e1WxbuHJE5vMkEp0ZZtk22QbzW0uKi5SB061zFOaWlLZpsLiAtc2IyImkwkVtOK0adNUZffuSRW2si2VPXpY22Zur73Nvfv2SqqwlW3p2beHmjptqmt7ERGTyYQL2suXL1vBU1RQoKpnzFAnE/iarbz36pkzrW2RbZJtM7fXuc2FRYXq9aWvqR1HE/earbz3eUvn6G2RbZ7a6DYjIiaLCRe0ttLdOGbMGFVcVGRtRCIq7122oaldp3XbXFzsWleiKN3gXrYZETHRlbrPzERfkD9k/nFERMRkl6BFRET0UYIWERHRRwlaREREHyVoERERfZSgRURE9FGCFhER0UcJWkRERB8laBEREX2UoEVERPRRghYREdFHCVpEREQfJWgRERF9lKBFRET0UYIWERHRRwlaREREHyVoERERfZSgRURE9FGCFhER0UcJWkRERB8laBEREX2UoEVERPRRghYREdFHCVpEREQfJWgRERF9lKBFRET0UYIWERHRRwlaREREHyVoERERfZSgRURE9FGCFhER0UcJWkRERB8laBEREX2UoEVERPRRghYREdFHCVpEREQfJWgRERF9lKBFRET0UYIWERHRRwlaREREHyVoERERfZSgRURE9FGCFhER0UcJWkRERB8laBEREX2UoEVERPRRghYREdFHCVpEREQfJWgRERF9lKBFRET0UYIWERHRRwlaREREHyVoERERfZSgRURE9FGCFhER0UcJWkRERB8laBEREX2UoEVERPRRghYREdFHCVpEREQfJWgRERF9lKBFRET0UYIWERHRRwlaREREHyVoERERfZSgRURE9FGCFhER0UcJWkRERB8laBEREX2UoEVERPRRghYREdFHCVpEREQfJWgRERF9lKBFRET0UYIWERHRRwlaREREHyVoERERfZSgRURE9FGCFhER0UcJWkRERB8laBEREX2UoEVERPRRghYREdFHCVpEREQfJWgRERF9lKBFRET0UYIWERHRRwlaREREHyVoERERfZSgRURE9FGCFhER0UcJWkRERB8laBEREX2UoEVERPRRghaxjbp06VI1fPhw13zb7du3q6FDh6qPP/64bt78+fOteU7Hjh2rZs2apfbt2+dax5QpU6wyq1atci1zumTJEqvca6+95lqGiPElaBHbqFOnTlVFRUWu+barV69WsVhMHTp0qG7euHHjrNc4g3bAgAGqU6dOVtlFixbVW0fPnj2t+fLTXL/txYsXVWlpqVUuXvAjYsMStIht1AcN2s6dO7vKSlhOnDhRFRQUqE8//bRuvgRshw4drPXs37/f9Tpx48aN1nIpR9AiepegRWyjNmfQirt377bKb926tW6eBK10LUuIzpgxw/UaUZZXVFSowYMHE7SIDyBBi9hGbe6g3bRpk1X+gw8+qJsnQTthwgRVVVVldQ9Ly9f5mjNnzqjCwkL15ptvWl3QBC2id1ssaAOBgLpy5YrrDSBiwz5o0JaXl6tPPvmkntL9K6FqBqXMk9fs2rXLWpeUcy5fvny5ys/PV8ePH1eVlZWu1yNifCX3dNDeNTPRFyKRyL3z58+73gQiNqwErYTf/TSD1lxuK6H62Wef1fsbdtBKZSDdw/K7c7m0YocMGWL9TtAieldyTzc0vzYz0Rf0iX7n5MmTrjeBiA0rQSuDl+RnQ8o104aCVkYYv/fee/WU23O6dOmi+vfvb534dnk7aOX3BQsWWC1oO4yPHj1qrV9azjJN0CJ698SJE/d0i/YzMxN9obi4+IacuOabQMSGfdCu48au0X700UdWcEug2vOcQSvdw9JNXF1dbU3PmzdP6fO2LpgJWkTvHjx48KoO2kNmJvpCRUXFORmMYb4JRGzY5g5aUYJ15MiR9aad3cXSTTxw4ECrK1lawPJAC3sZQYvo3VWrVh3OyMh4x8xEX+jfv/+2xYsXu94EIjZscwethGdZWZkaP3583TwzaO11yiAo+blz5866ZQQtonf1P6u7AoHACDMTfWHw4MFTJ0+e7HoTiNiwzRm0ly9ftm7RkfJr1qypm28GrXQTyz21cktP165d690pQNAiereiomJ
n+/bto2Ym+oI+oX9ZXl5+z3wTiNiwDxq0ch1WRhA7tR/BKPfLOsPTDFpRykjZN954o958ghbRu6FQ6Lxu0f5HMxN9QZ+kP4xGo3dOnTrleiOI6FaCdNq0aa75tnv37rXC+MSJE3XzpMvXHJ0szpkzp143sK0MeJLXOOfJAy3kNTI4yjlfBlG99dZbrnUgYsMePnz4im7NnjTz0Fe6dOly1L5VABERMZnV/7BKt/FUMwt9pXv37p3l3j/zzSAiIiab0Wj044yMjEfMLPSVtLS0P9F/+LY8P9V8Q4iIiMni3r17L+jW7Cc6+n5gZqHvlJeXv++8YR4RETHZ7NGjx24dtD3MDGwRsrOz/7GgoOD2559/7npjiIiIie7Ro0e/DAQCl4LB4L8xM7DFKCkp2bds2TLXm0NEREx0e/XqtT8jI6O3mX0tSiQS+XleXt53Z8+edb1BRETERHXbtm1y3+zp7OzsPzWzr8UpLi5+s6qqigdYICJiUiiXRKPR6Bndmn3GzLxWQSf+X+jE/6ahm+gRERETzREjRhxr3779AjPvWpVwOPxkfn7+zdOnT7veMCIiYqK4YsWKc7oB+UlWVtafm1nX6sRisQn9+vW7eenSJdcbR0REbOvu27fvWjAYvKL9qZlxbYK0tLTfzc3NXT9q1KhbzoedIyIitnUPHjz4XTgcvqpD9jdmvrUpdNj+kQ7bD6qqqr4jbBERMRE8dOjQnUgkciUjIyPdzLU2ifRrZ2dn79ctW7qRERGxTbtly5abuiX7RcKErI20bHXYru3bt+91BkghImJbdNGiRd/okJWvwPtXM8cSArlmG41GR8disevc+oOIiG1F+T71Pn36XNEhuzcQCPx7M78SjvT09MezsrK+nDx58i2eIIWIiK3l5cuX1bJly25HIpFvgsHg4EceeeT3zMxKWGofajFTe12ejcwXESAiYkv6/vvvq06dOn2tG37b2uztO81BRkbGP+Xm5m7Iycm5sWDBgnt8ny0iIvrlhQsXVHV1terQocM3uqEnT3t6zsylpEX+m8jLy1sSDodvDhw48Prq1asVoYuIiA+rXKJct26dPEZRRhPf1FmzIRAI/NLMoZShXbt2P9I7IFxQULBZdkhJScm1yZMn3128eLHatGmTfBegOn78uLp48aJrZyIiYmoqmSDZcOzYMSsrlixZol555RVVWlp6PRQK3dKZsjsSiRSmpaX9pZk7qc4PdOj+XAduuW7mv15YWLg3JyfnsvaLzMzMO7rJrxARESUTJBtyc3OvFBcX79eZMT8rK6urzpB/ljtezHABAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABIav5/PEJoH3//LlQAAAAASUVORK5CYII=)\n", "\n", "Let's talk about the components of this diagram in more detail:\n", "\n", - "* **Memory spaces**: A TPU has high-bandwidth memory (HBM) which is what we often think of as \"device memory\". There is also vector memory (VMEM), a cache meant for storing vector and array values, and scalar memory (SMEM), a cache designed to store scalar values.\n", - "* **Registers**: A TensorCore has two main types of registers: vector registers (VREGs) store array values, and scalar registers (SREGs) store scalar values. Values can be loaded into memory from their respective caches (VMEM for VREGs and SMEM for SREGs).\n", - "* **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and matrix unit (MXU) that can do numerical computation. Compute units operate on values that live in SREGs and VREGs and output values into those registers as well.\n", - "\n", - "In order to do a vectorized computation on our values `x` and `y` that live in HBM, we need to:\n", + "* **Memory spaces**: A TPU has high-bandwidth memory (HBM) which is what we\n", + " often think of as \"device memory\".\n", + " There is also vector memory (VMEM),\n", + " a cache meant for storing vector and array values, and scalar memory (SMEM),\n", + " a cache designed to store scalar values.\n", + "* **Registers**: A TensorCore has two main types of registers: vector\n", + " registers (VREGs) store array values, and scalar registers (SREGs) store\n", + " scalar values.\n", + " Values can be loaded into memory from their respective caches (VMEM for\n", + " VREGs and SMEM for SREGs).\n", + "* **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and\n", + " matrix unit (MXU) that can do numerical computation.\n", + " Compute units operate on values that live in SREGs and VREGs and output\n", + " values into those registers as well.\n", + "\n", + "In order to do a vectorized computation on our values `x` and `y` that live\n", + "in HBM, we need to:\n", "\n", "1. Copy the values `x` and `y` into VMEM.\n", "2. 
Load the values from VMEM into VREGs.\n", @@ -128,9 +147,20 @@ "source": [ "We've written two functions: `add_matrices_kernel` and `add_matrices`.\n", "\n", - "`add_matrices_kernel` operates using `Ref`s that live in VMEM. Loading from a VMEM `Ref` produces a value that lives in VREGs. Values in VREGs behave like `jax.Array`s in that we can use `jnp` and `jax.lax` operations on then to produce new values that live in VREGs. When we produce the values we'd like to return, we store them in the output VMEM `Ref`.\n", + "`add_matrices_kernel` operates using `Ref`s that live in VMEM.\n", + "Loading from a VMEM `Ref` produces a value that lives in VREGs.\n", + "Values in VREGs behave like `jax.Array`s in that we can use `jnp` and\n", + "`jax.lax` operations on them to produce new values that live in VREGs.\n", + "When we produce the values we'd like to return, we store them in the output\n", + "VMEM `Ref`.\n", "\n", - "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into `pallas_call`. `pallas_call` is responsible for copying `x` and `y` into VMEM and for allocating the VMEM buffers that the kernel operates on (including allocating `z_vmem_ref`, the output VMEM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`." + "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`.\n", + "Inside it, we pass `x` and `y` into `pallas_call`.\n", + "`pallas_call` is responsible for copying `x` and `y` into VMEM and for\n", + "allocating the VMEM buffers that the kernel operates on (including allocating\n", + "`z_vmem_ref`, the output VMEM buffer).\n", + "After the kernel function is finished running, `pallas_call` will also copy\n", + "the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`." ] }, { @@ -141,13 +171,22 @@ "source": [ "## Constraints of using VMEM/SMEM\n", "\n", - "Pallas exposes access to lower level memory spaces like VMEM and SMEM but writing kernels utilizing them adds some considerations.\n", + "Pallas exposes access to lower level memory spaces like VMEM and SMEM but\n", + "writing kernels utilizing them adds some considerations.\n", "\n", - "1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB and SMEM ranges in the tens to hundreds of KiB. If our arrays are too big, we won't even be able to fit them into VMEM at all. For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays.\n", + "1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB\n", + " and SMEM ranges in the tens to hundreds of KiB.\n", + " If our arrays are too big, we won't even be able to fit them into VMEM at all.\n", + " For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't\n", + " scale beyond moderately sized arrays.\n", "\n", - "2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least compared to most compute instructions. The `add_matrices` function above will likely spend more time copying between HBM and VMEM than actually performing the addition itself.\n", + "2. Memory bandwidth. 
Copying to/from HBM and VMEM takes a long time, at least\n", + " compared to most compute instructions.\n", + " The `add_matrices` function above will likely spend more time copying\n", + " between HBM and VMEM than actually performing the addition itself.\n", "\n", - "With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our TPUs." + "With these two constraints in mind, we'll have to rethink our strategy for\n", + "getting performance out of our TPUs." ] }, { @@ -158,13 +197,26 @@ "source": [ "## Primer: Pipelining\n", "\n", - "Pipelining our computation offers a way of dealing with both the memory capacity and bandwidth constraints in one fell swoop. What do we mean by pipelining?\n", + "Pipelining our computation offers a way of dealing with both the memory\n", + "capacity and bandwidth constraints in one fell swoop.\n", + "What do we mean by pipelining?\n", "\n", - "The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our compute units. Naively this is difficult because in our program above we copy *all* of `x` and `y` before we start doing any compute with them, creating a dependence between the copy and the compute.\n", + "The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our\n", + "compute units.\n", + "Naively this is difficult because in our program above we copy *all* of `x`\n", + "and `y` before we start doing any compute with them, creating a dependence\n", + "between the copy and the compute.\n", "\n", - "However, if we can chunk up our computation into several subcomputations (e.g. when we add two matrices, we can express that as addition of \"blocks\" of the original matrices together), we can now overlap the copies of one of those subcomputations with the compute of the other. Let's walk through a simple example:\n", + "However, if we can chunk up our computation into several subcomputations\n", + "(e.g. when we add two matrices, we can express that as addition of \"blocks\"\n", + "of the original matrices together), we can now overlap the copies of one of\n", + "those subcomputations with the compute of the other. Let's walk through a\n", + "simple example:\n", "\n", - "Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for example, split along the leading axis, resulting in two `(256, 512)` arrays for each input. We can now execute the following pipelined computation.\n", + "Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for\n", + "example, split along the leading axis, resulting in two `(256, 512)` arrays\n", + "for each input).\n", + "We can now execute the following pipelined computation.\n", "\n", "1. Copy `x1` and `y1` into VMEM.\n", "1. Start copying `x2` and `y2` into VMEM\n", @@ -180,9 +232,15 @@ "10. Start copying `z2` from VMEM back into HBM.\n", "10. Wait until `z2` is copied into HBM.\n", "\n", - "Any time we are doing compute here, we are asynchronously copying something. This means that some of the time spent copying is not wasted.\n", + "Any time we are doing compute here, we are asynchronously copying something.\n", + "This means that some of the time spent copying is not wasted.\n", "\n", - "The two most important numbers for determining how efficient a pipelined computation are a) how many floating point operations (FLOPs) we need to execute and b) how many bytes we need to copy to execute that computation. 
The ratio of these two (FLOPs/memory usage) is called the *arithmetic intensity* of an operation and determines if our pipeline will be compute bound or memory bound." + "The two most important numbers for determining how efficient a pipelined\n", + "computation are a) how many floating point operations (FLOPs) we need to\n", + "execute and b) how many bytes we need to copy to execute that computation.\n", + "The ratio of these two (FLOPs/memory usage) is called the\n", + "*arithmetic intensity* of an operation and determines if our pipeline will\n", + "be compute bound or memory bound." ] }, { @@ -200,66 +258,29 @@ "id": "U-dPTjlBverB" }, "source": [ - "How do we implement a pipeline like the one above in Pallas? It seems like a complex sequence of asynchronous data operations and executing kernels that would be a pain to implement manually. Fear not! Pallas offers an API for expressing pipelines without too much boilerplate, namely through `grid`s and `BlockSpec`s." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "x-LQKu8HwED7" - }, - "source": [ - "### `grid`, a.k.a. kernels in a loop\n", + "How do we implement a pipeline like the one above in Pallas?\n", + "It seems like a complex sequence of asynchronous data operations and\n", + "executing kernels that would be a pain to implement manually.\n", + "Fear not! Pallas offers an API for expressing pipelines without too much\n", + "boilerplate, namely through `grid`s and `BlockSpec`s.\n", "\n", - "See how in the above pipelined example, we are executing the same logic multiple times: steps 3-5 and 8-10 both execute the same operations, only on different inputs. The generalized version of this is a loop in which the same kernel is executed multiple times. `pallas_call` provides an option to do exactly that.\n", + "See how in the above pipelined example, we are executing the same logic\n", + "multiple times: steps 3-5 and 8-10 both execute the same operations,\n", + "only on different inputs.\n", + "The {func}`jax.experimental.pallas.pallas_call` provides a way to\n", + "execute a kernel multiple times, by using the `grid` argument.\n", + "See {ref}`pallas_grid`.\n", "\n", - "The number of iterations in the loop is specified via the `grid` argument to `pallas_call`. Conceptually:\n", - "```python\n", - "pl.pallas_call(some_kernel, grid=n)(...)\n", - "```\n", - "maps to\n", - "```python\n", - "for i in range(n):\n", - " # do HBM -> VMEM copies\n", - " some_kernel(...)\n", - " # do VMEM -> HBM copies\n", - "```\n", - "Grids can be generalized to be multi-dimensional, corresponding to nested loops. For example,\n", + "We also use {class}`jax.experimental.pallas.BlockSpec` to specify\n", + "how to construct the input of each kernel invocation.\n", + "See {ref}`pallas_blockspec`.\n", "\n", - "```python\n", - "pl.pallas_call(some_kernel, grid=(n, m))(...)\n", - "```\n", - "is equivalent to\n", - "```python\n", - "for i in range(n):\n", - " for j in range(m):\n", - " # do HBM -> VMEM copies\n", - " some_kernel(...)\n", - " # do VMEM -> HBM copies\n", - "```\n", - "This generalizes to any tuple of integers (a length `d` grid will correspond to `d` nested loops)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hRLr5JeyyEwM" - }, - "source": [ - "### `BlockSpec`, a.k.a. 
how to chunk up inputs" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "miWgPkytyIIa" - }, - "source": [ - "The next piece of information we need to provide Pallas in order to automatically pipeline our computation is information on how to chunk it up. Specifically, we need to provide a mapping between *the iteration of the loop* to *which block of our inputs and outputs to be operated on*. A `BlockSpec` is exactly these two pieces of information.\n", - "\n", - " First we pick a `block_shape` for our inputs. In the pipelining example above, we had `(512, 512)`-shaped arrays and split them along the leading dimension into two `(256, 512)`-shaped arrays. In this pipeline, our `block_shape` would be `(256, 512)`.\n", - "\n", - "We then provide an `index_map` function that maps the iteration space to the blocks. Specifically, in the aforementioned pipeline, on the 1st iteration we'd like to select `x1` and on the second iteration we'd like to use `x2`. This can be expressed with the following `index_map`:\n", + "In the pipelining example above, we had `(512, 512)`-shaped arrays and\n", + "split them along the leading dimension into two `(256, 512)`-shaped arrays.\n", + "In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`.\n", + "On the 1st iteration we'd\n", + "like to select `x1` and on the second iteration we'd like to use `x2`.\n", + "This can be expressed with the following `index_map`:\n", "\n", "```python\n", "def x_index_map(i):\n", @@ -268,7 +289,7 @@ "\n", "We'd then construct the `BlockSpec`:\n", "```python\n", - "block_spec = pl.BlockSpec(x_index_map, (256, 512))\n", + "block_spec = pl.BlockSpec((256, 512), x_index_map)\n", "```\n", "\n", "The `BlockSpec`s for `y` and `z` will be the same as the one for `x`." @@ -282,7 +303,9 @@ "source": [ "### Putting it together\n", "\n", - "We provide these arguments to `pallas_call` via `grid`, `in_specs` and `out_specs` (`in_specs` corresponds to the tuple of positional arguments, and `out_specs` corresponds to the output)." + "We provide these arguments to `pallas_call` via `grid`, `in_specs` and\n", + "`out_specs` (`in_specs` corresponds to the tuple of positional arguments,\n", + "and `out_specs` corresponds to the output)." ] }, { @@ -312,13 +335,14 @@ ], "source": [ "def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:\n", - " block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))\n", + " block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n", " return pl.pallas_call(\n", " add_matrices_kernel,\n", " out_shape=x,\n", " in_specs=[block_spec, block_spec],\n", " out_specs=block_spec,\n", - " grid=(2,))(x, y)\n", + " grid=(2,)\n", + " )(x, y)\n", "\n", "add_matrices_pipelined(x, y)" ] @@ -329,9 +353,17 @@ "id": "rkytgIZYzz4t" }, "source": [ - "We've only added a little bit of code to our original function to add automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy lifting!\n", + "We've only added a little bit of code to our original function to add\n", + "automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy\n", + "lifting!\n", "\n", - "How does it work? Well, the `BlockSpec`s provide enough information to start *prefetching* blocks of our input from HBM into VMEM. For example, if we are starting iteration `i` of our `grid`, we can pass `i + 1` into the `index_map` functions to obtain the blocks needed for the next iteration. We can then start an asynchronous copy for those blocks. 
Similarly for outputs, we can wait for the outputs of the previous iteration to be copied before starting the copy for the current iteration's outputs." + "How does it work? Well, the `BlockSpec`s provide enough information to start\n", + "*prefetching* blocks of our input from HBM into VMEM.\n", + "For example, if we are starting iteration `i` of our `grid`, we can pass\n", + "`i + 1` into the `index_map` functions to obtain the blocks needed for the\n", + "next iteration. We can then start an asynchronous copy for those blocks.\n", + "Similarly for outputs, we can wait for the outputs of the previous iteration\n", + "to be copied before starting the copy for the current iteration's outputs." ] }, { @@ -349,9 +381,15 @@ "id": "esY4GcIB0bqQ" }, "source": [ - "It's common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do).\n", + "It's common to parameterize the block shapes in our kernel. Block sizes are\n", + "perhaps the most important parameter to tune when optimizing the performance\n", + "of Pallas kernels! They give us control over the pipeline (for example,\n", + "picking smaller blocks adds more iterations to our pipelined loop where each\n", + "iteration has less work to do).\n", "\n", - "Furthermore, we could also carve up the inputs and outputs along the 2nd dimension (we are only splitting along the first right now). Let's write a more general kernel that handles both of these features." + "Furthermore, we could also carve up the inputs and outputs along the 2nd\n", + "dimension (we are only splitting along the first right now). Let's write a\n", + "more general kernel that handles both of these features." ] }, { @@ -366,8 +404,7 @@ " x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n", ") -> jax.Array:\n", " m, n = x.shape\n", - " block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn))\n", - "\n", + " block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))\n", " return pl.pallas_call(\n", " add_matrices_kernel,\n", " out_shape=x,\n", @@ -376,7 +413,6 @@ " grid=(m // bm, n // bn),\n", " )(x, y)\n", "\n", - "\n", "np.testing.assert_array_equal(\n", " add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y\n", ")\n", @@ -403,9 +439,11 @@ "id": "P3SqEKDe3Mar" }, "source": [ - "How would you implement something like `jnp.sum` using `pallas_call`? Specifically, we'd like to pipeline across the reduction dimension.\n", + "How would you implement something like `jnp.sum` using `pallas_call`?\n", + "Specifically, we'd like to pipeline across the reduction dimension.\n", "\n", - "Take the example of reducing a `(8, 512, 512)`-shaped array to a `(512, 512)`-shaped one." + "Take the example of reducing a `(8, 512, 512)`-shaped array to a\n", + "`(512, 512)`-shaped one." ] }, { @@ -444,7 +482,10 @@ "id": "5O3ByvuT3iyC" }, "source": [ - "To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration `i` load `x[i]` into VMEM. Then we could add `x[i]` to an output VMEM buffer. Let's implement this naively first." + "To do this using `pallas_call`, we could use a grid of size `(8,)` and in\n", + "each iteration `i` load `x[i]` into VMEM.\n", + "Then we could add `x[i]` to an output VMEM buffer. Let's implement this\n", + "naively first." 
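For reference, the reduction that this `(8,)`-sized grid is meant to compute can be sketched in plain JAX before reading the Pallas version in the hunks below; `sum_reference` is an illustrative name, not something defined in the patched notebook:

```python
import jax.numpy as jnp

def sum_reference(x):
  # Each step of the (8,)-sized grid corresponds to one pass of this loop:
  # take one (512, 512) block of `x` and accumulate it into the output.
  out = jnp.zeros(x.shape[1:], x.dtype)
  for i in range(x.shape[0]):
    out = out + x[i]
  return out

x = jnp.ones((8, 512, 512), jnp.float32)
assert jnp.allclose(sum_reference(x), jnp.sum(x, axis=0))
```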
] }, { @@ -484,10 +525,10 @@ " naive_sum_kernel,\n", " grid=grid,\n", " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", - " in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],\n", - " out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),\n", - " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)\n", - " )(x)\n", + " in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n", + " out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n", + " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n", + " )(x)\n", "naive_sum(x)" ] }, @@ -497,11 +538,29 @@ "id": "Kv9qJYJY4jbK" }, "source": [ - "Notice how we've set up the `BlockSpec`s: we're loading the entirety of the `(512, 512)` dimension into VMEM (no pipelining there) but selecting the `i`-th dimension of `x` each iteration in the `index_map`. We are using a `None` for that dimension in the block shape, which indicates that we are selecting a singleton dimension from `x` that we would like squeeze away in the kernel. Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well.\n", + "Notice how we've set up the `BlockSpec`s: we're loading the entirety of\n", + "the `(512, 512)` dimension into VMEM (no pipelining there) but selecting\n", + "the `i`-th dimension of `x` each iteration in the `index_map`.\n", + "We are using a `None` for that dimension in the block shape, which indicates\n", + "that we are selecting a singleton dimension from `x` that we would like\n", + "to squeeze away in the kernel.\n", + "Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well.\n", "\n", - "`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that `o_ref` is unchanged over the course of the pipeline. This means that we can update its value each iteration by reading from and writing to it. Or can it? Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll be accumulating into garbage. This will result in the overall function outputting the incorrect value!\n", + "`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that\n", + "`o_ref` is unchanged over the course of the pipeline.\n", + "This means that we can update its value each iteration by reading from and\n", + "writing to it. Or can it?\n", + "Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll\n", + "be accumulating into garbage.\n", + "This will result in the overall function outputting the incorrect value!\n", "\n", - "Therefore, **whenever we do a reduction in a kernel, we need to make sure to initialize the `Ref` that is storing the reduced value**. We can accomplish this by conditionally writing a value to `out_ref` when we're on iteration 0. We can do this with the helper function `pl.when`, a convenience wrapper around `jax.lax.cond`, and `pl.program_id`, which queries which iteration in a grid axis we are in." + "Therefore, **whenever we do a reduction in a kernel, we need to make sure\n", + "to initialize the `Ref` that is storing the reduced value**.\n", + "We can accomplish this by conditionally writing a value to `out_ref`\n", + "when we're on iteration 0.\n", + "We can do this with the helper function `pl.when`, a convenience wrapper\n", + "around `jax.lax.cond`, and `pl.program_id`,\n", + "which queries which iteration in a grid axis we are in." 
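The markdown above describes, but the hunks here do not show, the kernel body that performs this conditional initialization. A rough sketch of what such an initialize-then-accumulate kernel can look like, assuming only the `pl.when` and `pl.program_id` helpers named above (the name `sum_kernel` matches the wrapper in the next hunk; the body is illustrative, not quoted from the notebook):

```python
import jax.numpy as jnp
from jax.experimental import pallas as pl

def sum_kernel(x_ref, o_ref):
  # On the first grid iteration, overwrite the (initially garbage) output block.
  @pl.when(pl.program_id(axis=0) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)

  # On every iteration, accumulate the current input block into the output.
  o_ref[...] += x_ref[...]
```

Because `out_specs` in the following hunk uses `lambda i: (0, 0)` as its `index_map`, `o_ref` refers to the same block on every grid step, so the accumulation carries across iterations.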
] }, { @@ -543,10 +602,11 @@ " sum_kernel,\n", " grid=grid,\n", " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", - " in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],\n", - " out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),\n", + " in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n", + " out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n", " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)\n", - " )(x)\n", + " )(x)\n", + "\n", "sum(x)" ] }, @@ -558,7 +618,16 @@ "source": [ "This `sum` function now outputs the correct values!\n", "\n", - "One last thing to note about reductions in Pallas are that **they must be done in the minormost (rightmost) dimensions of our grid** (our grid is 1-dimensional in the above example so we are reducing over its minormost dimension). This is because the pipeline that Pallas generates using the `BlockSpec`s, `grid` and kernel function *does not read outputs back from HBM*. Once you've written an output value back to HBM you cannot revisit it. Therefore, you cannot do a reduction across a grid dimension that has any revisiting and therefore all reductions need to happen in the rightmost dimensions." + "One last thing to note about reductions in Pallas is that **they must be\n", + "done in the minormost (rightmost) dimensions of our grid** (our grid is\n", + "1-dimensional in the above example so we are reducing over its minormost\n", + "dimension). This is because the pipeline that Pallas generates using\n", + "the `BlockSpec`s, `grid` and kernel function *does not read outputs back\n", + "from HBM*.\n", + "Once you've written an output value back to HBM you cannot revisit it.\n", + "Therefore, you cannot do a reduction across a grid dimension that has any\n", + "revisiting and therefore all reductions need to happen in the rightmost\n", + "dimensions." ] }, { @@ -576,13 +645,21 @@ "id": "0f4HAVzQ8n71" }, "source": [ - "Some TPU chips have two TensorCores but appear as one device to JAX users. This is called \"megacore\". 
The separate TensorCores have their own separate VMEM, VREGs, SMEM, SREGs and compute units but *share HBM*.\n", + "Some TPU chips have two TensorCores but appear as one device to JAX users.\n", + "This is called \"megacore\".\n", + "The separate TensorCores have their own separate VMEM, VREGs, SMEM, SREGs\n", + "and compute units but *share HBM*.\n", "\n", "![TPU Memory Space Cartoon (Megacore).png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAcYAAAFdCAYAAACZw82pAABFq0lEQVR4Xu2dB3gU17n3TS/CuODYccpNrhPHSZxqO7YTOzfJTRzHyWcn33fja6NqgVClmF6NQfQqUYQoQggQiN6bKKIIAaJjg6gGgenN3dgGc775H3nWo3NWYnbZ1e7O/t/n+T3anZk9e87o/PadM/WOO2o2In7/3y+0eva/n1/xyyd++85/PPTDjx548NvX69dvcNOYJwjxltq1awv0pW9/93sf/+QXvy7/9VPPrP/Vb37byZj30B3OiIg//+OfqX947u+FRruOf++hhz+mO8QX1K5TR7rz3e899NHPfvXE8cd/+/tVjz35TJIx7+47GH6LiMef+n3vHzzyk5ONIyJuPvHb34tWbbuIIeOmioKVm8XyLQdE6bHLYs+pjwi5LdCX5q4pFZm5c0SnNweLv/z9X+K++7954+57m733H9//wRSjLz6mds4gj4in/+u/+z3845+dpjvEX+w++aHsS7NXbREjJs4UaZ17i9/94S8iokmTL7//0MMnHv3l472MvthY7ZwM7+L+n/z8sdlN7mz6xR//+g8xbPx0UXr0kvZPIcTfQPjE17uJbzzw4LV77m1WZvTNfxvUUjtssERERMQDv3z8qYV3Nr3rOt0hgQIbXUiUf37hJWH0xc9/+vNf5Rrds6naXxn2ov5DP3gk00iI11+NTxYrt5ZpK5yQQLCr/AMxclKBePjHj37YsHHEfqOv/l7tvAGO+o/87Bfj6A4JNhZv3Cv++b/R4q57m10zBjzt1Y7LqCbqNW78xD3NvnH5D8/9/UsMzdWVS0iwMGhM7o2md939fu3atccZXbeR2pdrOh544NtPN7vvgauGOzfpDglWpi1eLx559Bc3f/jjR9+62wi1HzOUePDb3+1h/NB8gd0+6sokJBjZXHZW/Ndzf79Qt169w0YXfljt0zUV3//hj9686+57rtMdEgrsOH5VNI9PwfH7T55++tmn1P7MqIha93/zwakPPfzjz5eV7NdWIiHBTrf+wy/XqVv3otGXn1Q7t5+j1ne++/2ZP3zkp9fpDgk1Bo6eLO6+t9n1Z/703Etqxw73qNXsvvsX//o3v/2i+MAZbcUREipkTp71gZEcrxh9+jdqJ/dT1Prmg99Z8fhTz9ygOyRUyZq+UNx1z703/vjXF/6ldvCwjbubfWPyL594+jOeMUecwMhJBVdq16lzwejaP1T7uq/jgW9/Zzo2KOkOCXWQHDFy/P2fnntG7edhF40aRbT/zx/+6Bq3domT6PzmkDO1a9c+iC6u9nlfxT3NmnXBoQe6Q5wCdqve/81vXfvhr3/9DbW/h1M8HtHkzk95XIQ4kd8888djRh8fq3Z6X0S9evWeuLPpXZ/RHeI0YhLbih/99OdwJ2ivEfZn1I9o0uRdnkFHnArOVm3UOAIn4zyrdv7bjPpN77rnHN0hTmTn8ffEr554Wjzx9DMj1Y7v+GjYuHHvZ/703CfqSiHESaSPnHCpTp06O+/w4dbvPc2a9f/9X174XP0uQpzCwvW7RdO77r5+94MPfk/t/06O+xs2avzJiq28AJk4G9w/8lvf+Y93jT7/P6oEXsb9jSMirtEd4nRSO/YSP3jkJztUARwbjSIiRrwc0/KauiIIcSLDsqd9VLdu3S2qB97E3fc0G/2/sYnX1e8gxGlsO3JRNPvGAzca33VXqN2036to0qBho495/0YSLuw68b5ocmfTy0bf/7Uqg4fRBHta6A4JF7r1GyG++72HwmLUGPv07//7fXUFEOJkolqmnTf6/mBVBk+iTp06cc/+6a+fqmUT4lQwajQ2Kj83uv9/qD44Ku5s2nTL8PH52gogxMngkVX16zc4qfrgSTS77xs76Q4JN5rHp9y89777R6g+OCki6tWr/xm2AtTGk8CA/8Wmt0+JnSfe0+YR39KkyZ3vGw78pyqFzYgwEuvndCc42H7ssvRmxztXtHm+ouTgWbH5wGlteriBh2nf2fSuc6oQTornf/KzX15RGx5u9Bs8QnTs3FW0btNWtGvfQb4eM6lmr0nbVf6+GDUhT3Tt8YboO3CY6NSlm5i+YJW23O2yducR0a1nb7G8eK82L9x48tk/Qu5IVQqb8fyjv3ziQ7XMcGJIRpZ0pU3bdqJtu9fl6xFjJ2nL+ROcZTxuykzRpXsvlze5sxZry90uYyZNEz3e6Ct6vtlPDByWKb9XXSZcQNvvuvueTwwHfqBK4ZTo27J1J56N+hXZeQVi9vINlaZtO3xBFO9/V1sWIzpsoVoF2V3+gdvR3vZjV+R0JD+1nE1vn5RlTJg+V8r39WcuS9mRyKzLbjXqo5aBaVger8167Tx+VVtu7Y7D4o30AaJ3+kAmRoO23dOvGQ5kqFLYifoNGw5o1a7rTbXMcASJaPr8lZWmYSRdbPRDdVk4IL0xXDGnof+767MY/bnzCcBJlJE3d5kYPmaiy0N8Bslr+eZ9lZbdeui8VgbqaI74TXfhqrrcxn3lskzz/Rt9+4vCbeF9wtXfXnr5A0ODJNULR0Tjxk0WD87K0xodrqiJcWL+PClE/yEjZTLBjaHX7zkuur/RR7zZf7BIHzRc9OjdVwoNeTASG5o5TnTu1kMs27RHljF35SYjwfUU/YwyMB/JCdMxMhw4fJSUDJ/v2LmbUX5FcjNZXVomNr51Uu7CwXL4PpSBemE+RpR9Bw6VdZy5ZJ1Ys+OQrBvKxXcuXFdaqbwtB8/JH5mho7KZGA0yc+eIho0jVqte2Il7m32jkO5UoCbGvDlLZT8cMDRD9s0th85JsKHXZ8AQuYcGr7Exh12TWBbeyD67dpssA33X9AaurNpScas9eIgRG3zYYngBH/DXWh84Bk+R9PB9GEniO7Imz5Dz4ST87dk7XUydt1wuC49NbzBfbaN1A7i/0S6zPuFKt34jvqxTp84E1QtHRESTJkewv1htdLhiTYxIMpDHnDdh2hz5AwCJXm/f0bV1iySzuGi7/GHAMpiGZLZ04y4pJnbtmFulq7YekCM2vIbsZuLCViqSqVofk8zxuWLm4rXyNbZsUQaSJhIjfjhMaTF93a6j8jW2kvHdalmAibGCuWtKReOIiBOqF3birrvvOUF3K
rAmRvR9JBxzRIgRHbxCYkxr3drlAg4bzFmx0fBtvRg9caqchiQJJ7Dxhr5bUlZxM/ai3cdkAsRrJEZ8xvzu9h07a/UxwfdOmb1UvoYjcAVeIvHBFbOOmL6i5C35GiNLlGkd0VpBO+Cqu1FsODF57ircXnGX6oUjon7DRlfW7Kz4ISWVEyOSDkZx2OIEEHPkuByZGNONLVDrZyAapIZsWBYJEoIh+SAJWb+jQ6cuUjokRnOEiN1LmK7WxwTfbd2Fav4QoY6TCxbJaSgzJSXVVV+QnJzidvctE2MFRXtO4MxUnIDjcRgJ9X26U4E1McIFJBazD2KkNnjkGJlQcIzO+pn8hYWyX2M5JLzsvFnSI4z4kKys3wEHcFgDy1l30b7eoVOVx/ve7DdIJmrzff6i1WJS/nxZR3hrTscxUqs3qWmtZX3V8uAYRqtI6Oq8cAPX7jZo0NCZJ+DUrVvvWslBfd97uGJNjNgqzZ7ytTwmMjEOGl7pM9ZdL0iI2JWE3S0Y1eGvOQ9imVu4SIzW4xnY5WTuZjXBcUfsskHCtQqO7yxYWiQTY+7sJa7pMulW8SNhhYmxAvwoGw58pnphJ3A2N92pwJoYMeLLzM7VlkGisR6nMxOj+R4jyRmL1rj6urlnBaBP48QeHD+UidFyzB8uqn0Z/i1Zv8PloPU7MYKVidHiNkan6vFNFbiKXa0oW50XjuDRaoY7H6teOCJq1ap1084PabhgTYxIcDjesHLL21JEbCmu2LyvysSIYxUZxogSAmHXKrY8sW6xxTx3VbHcEsay5o+GmhghMqZBcpwEgF2neA9hpxk/OoNGjJZbyqhDp67d5fFCNTFilyuOP2IXFNqhjlZNmBgrwP8HDqhe2Am68zXWxIg+DW9wjB19ftjo8fK4YVWJcdayImNEOVYeb4Rr8AXzkQDhAMrAXhGcAWtOtyZGfAa7Npds2Cm9wWgO73E+ABzA8UVMx2EMTMcJbGpihJdjc/KlN6grjo1a24ffAoxYscsX3gB3I8pw4nbcCYXQGhzOIIFBNPP96u0H5W4gJEKc3IJpONBvHrcwPwPpsMsSSQkJEYkHJ+NgPsQemTVJTh8/dbbr2IT1tYnc9Zo5Ti47etI01zEWgOSIXbiDM8bK45+YtnTTbpmEzWWQRCE5Po9jOEie1vJNZixabST4d7Tp4QgcUKWwGVpZ4coiow/i2J35HscEkciQlMyEicMG2I1Z6TNG/8VeFGzc4SQyJEjzGDn6LjYi0Zdx0oy5EZkzc6F2ZjbOEB0+ZoJcFp9BIjTnwVv4i9Ge6Tb2wswr/Pr4MOpguovDJeq1itgghcNWwv3kGwAHVCmcElpjCQkn4IAqhc3QyiIknIADqhROCa2xhIQTcECVwmZoZRESTsABVQqnhNZYQsIJOKBKYTO0sggJJ+CAKoVTQmssIeEEHFClsBlaWYSEE3BAlcIpoTWWkHACDqhS2AytLELCCTigSuGU0BpLSDgBB1QpbIZWFiHhBBxQpXBKaI0lJJyAA6oUNkMri5BwAg6oUjgltMYSEk7AAVUKm6GVRUg4AQdUKZwSWmMJCSfggCqFzdDKIiScgAOqFE4JrbHeEt8yQTRv3jxg4PuDrU52cFdvX9IiSNdBi5YttboGAjigSmEztLK8Jd5YF+r6qUnc9cFg7TcmLVrodfY16KPq9wYD+N+odQ0EcECVwimhNdZb8A/bW/5uwMD3u6vT4bcuBDXu6u1LUL74cGfQ4e922wUOqFLYDK0sb6E7nuOuzr6G7lQPHFClcEpojfUWyu0d7urtSyh39cABVQqboZXlLXTHc9zV2dfQneqBA6oUTgmtsd5Cub3DXb19CeWuHjigSmEztLK8he54jrs6+xq6Uz1wQJXCKaE11lsot3e4q7cvodzVAwdUKWyGVpa30B3PcVdnX0N3qgcOqFI4JbTGegvl9g539fYllLt64IAqhc3QyvIWuuM57ursa+hO9cABVQqnhNZYb6Hc3uGu3r6EclcPHFClsBlaWd5CdzzHXZ19Dd2pHjigSuGU0BrrLZ7I/ebAQfJvydtlYsioMWLa/IViy4FDYsrsuWL5ps1iVck2OR/T1c9WhbvO4o3c03Lnyb+L560TO0qOaPPBysXF2jRvcVdvX+Kp3Ls3zxJL5owVmwrztHkmc/MztGme4u922wUOqFLYDK0sb/HGnbWlO8WYnFwxuWC22HH0uMiZWSAWrF4nNu55S84PpDuz85eJt3e9q80HvnLHXZ19jafubFk3TSyePUZs3zhDm2dCd0IjtMZ6iydyt+3YSazYvEVMmD5DDBuTJXYceUf0HTxUyr24aINYtLZILgfx1c9WhbvO4o3ck7LzxZ7S42L0yIli6fz18v3cgpVi2YIN8vWs/KWiV/d0sXLJZuP9DDFu9BRRtGqHGDlsnFxGLe9WuKu3L/FU7tnTRogT+5eLC8fXivIDy8WiWaPF6iUTxa7iWWKh8RqJE3If3rNILCgYJTaumiI2rMyVr9WyqsPf7bYLHFClsBlaWd7iiTstk5LFpr1vi7GTp4jMCZPkxuUb/fqL/IVLJEiYWC4Q7owdNVmU7Tkj3YEz8AUOzZ+1SkwcN10smrtGdO3US6xdWSqyx+SJMRmTxJYN+8Wg/iPF6mVbtfKqw12dfY2n7kybNFicPbpaulO2c4HcwCxenSdK1k6TfmAa3Nm3da50CQm0cPEEsXz+OK2s6qiJttsBDqhSOCW0xnqLJ3KPnpQjxR09abJkz4lTovubfUTB0uViyfqNYuGawCXG7ZsPyYQHkQekDxdTJ8+VwgPM37v9hMgalSv27zotJmRNk9PGZuZIydWy7OCu3r7EU7mvXy2VMudPHiqTJKa9f3qjlBoS544fIOU+dXClWLEgW2RlvCkK8oaLmx/s0MqqDn+32y5wQJXCZmhleYsn7ozMHi/G5U11+bPz2AnRvktXsXT9JjFjUWAT47qV28WUibNF/pQFom/vQdIdOKK6s23TQTFz6mJxaN95+d4bd9zV2dd46s7nl7eJdctzxLwZGZXcQSJctWi8mJE7VLpzZO9i+X5SVj/pjlrOraiJttsBDqhSOCW0xnqLJ3JD2szxE8Xs5Sul3BPzZ4q123eJ4WPHya3hASNGGqPHWcYW8GLts1XhrrN4Izfo2L6bIW+ZyBk/Q8yduULMmbFcCl4wbbGUfvKEArF80UYxeECmyM9bKKdDcLUcO7irty/xVG6IumPTTJkAMVJEklw6N0tKjekTx6ZLubHFi/eZw3qKOdNHauXcCn+32y5wQJXCZmhleYun7vQfNlysKC75asMyR5TsLxODM0eJop17xNDRY0T21Gli3qrV2merwt3/wlt3UpPbin3by6UPcAejRCS+WdOXyF2s2MNSuGyL4U6GTKKL5q71yh13dfY1nrozdeIg6QQ2KrGnBbtWsTE51RhJblufL+fDnVlTR8jlskf18WrXak203Q5wQJXCKaE11ls8kbv00FG5+xQjRbzG8UVMx7ESiTFv89sHtM9Vh7vO4q3cENt8ja3br1+Xya1cvN6z7R1Rtues6zik
9TOe4K7evsRTucG5Y6vlyBGvL51YJ764sk2+v3JyvfjkQokEI8SLx9e53qtl3Ap/t9sucECVwmZoZXmLp+4A87XpzvbDx8Sud8rFtkNHxJb9B7XPVYe7/4Xv3al4fXDfObkM9rjs2npM+4xd3NXZ13jqDpyAOzfe2y7fY5cqXn92aat4790NLlfg0uXyonB2J+hDa6y3eCK3P3DXWbyVuyZxV29f4qncNYW/220XOKBKYTO0sryF7niOuzr7GrpTPXBAlcIpoTXWWyi3d7irty+h3NUDB1QpbIZWlrfQHc9xV2dfQ3eqBw6oUjgltMZ6C+X2Dnf19iWUu3rggCqFzdDK8ha64znu6uxr6E71wAFVCqeE1lhvodze4a7evoRyVw8cUKWwGVpZ3kJ3PMddnX0N3akeOKBK4ZTQGustlNs73NXbl1Du6oEDqhQ2QyvLW+iO57irs6+hO9UDB1QpnBJaY70l0A8Fdvew1UDXyQ7u6u1LgvWBs3xQ8dfwQcWewwcV6/WtaeCAKoVTQmssIeEEHFClsBlaWYSEE3BAlcIpoTU2XNl84LRISWsjlm3cpc0jzgUOqFLYDK2scMV0Z+kGuhNOwAFVCqeE1thgxdy1Ex/fQoqozr9dMrMni/R+/USrVolMjmEEHFClsBlaWcFKTbrD5Bg+wAFVCqeE1thgBWJfvXpVjJ+YI3r2GSBS27TTlrkdsMXbMiFBktq6rTY/3DFHBUsc9sMHB1QpbIZWVrBiupPtR3cS4I6RgFPojgbdCb3QGhusYKs3vkULsWbNGtG6dRspu7rM7WIe3Fank4/EyHG5om96P5FgjAqcJDgcUKWwGVpZwQrdCSwZ2XQn1EJrbDCDTgXBsfXrawkxSoyOjhaRkZFy606dH+7ExsWJ8vJyOaJ20qgADqhS2AytrGCG7gSOONMdh42o4YAqhVNCa2ywYx4v8fUpy/jhwBbd2AmTxeayM9r8cMfcHefrH9VAAwdUKWyGVlaw4093cHxxzHi64w66E3qhNZYQd1BuLbSyCHEH3Qm90Bob7Cxav0N06NpDxMTEuo5rBBvYpdSl55ti7c4jWv1DFbSLclcKraxgh+4EBrSL7oRWaI0NZiB2YnKqmLO8VGw5+LEoPfJZULKl7CMxZfZqkdAqyTGCU24ttLKCGdOd2SHgTq7hTku6E/TAAVUKp4TW2GCmY9eeUmxVpmAld/Ya0aVHb60doQjl1kIrK5jp2LWX4c52rY8GK0iOdCe4gQOqFE4JrbHBDHYBmVu78wr3iPadewb9bqGU1q+LtTtCf8sX7aHclUIrK5gJRXeS09oZo8ajWltCDbSH7oRWaI0NZtCxTLFbJaWKWcYWcNDvFpq1xhG7hSi3FlpZwQzdCRx0J/RCa2wwY8rdocvXu4VCYeu3ObZ+U9uEtOBxr73muo4xObW1Nj9UgQOqFDZDKyuYQT8MVXeSUlo7w52WCSIpNU2bH6rAAVUKp4TW2GAGokBoc7cQt35rDvN+mLx7hyu0soIZuhM46E7ohdbYYMaU2/xr3foNBSB45xA9oYD3e9RCKyuYcYI7oXoyDt0JvdAaG8yocltPKAgFsPUbFRWttYsEDjigSmEztLKCGbpDfA0cUKVwSmiNDWZUuc2/oQTqrLaLBA44oEphM7SyghnVGbpDbhc4oErhlNAaG8yoUlNucrvAAVUKm6GVFcyoztAdcrvAAVUKp4TW2GBGlZpy22f19kOiY9cecncU6lAT4KkCffoNlt+t1idYgAOqFDZDKyuYwf/D7H/Wv6EE6qy2qyagO+6BA6oUTgmtscEMOowpiPVvKIE6q+2qCSD2gmlzxMGNO0T/Xn1EbA2doh8TEyPiYuNEVm6BVqdgAA6oUtgMraxgBv8Ls/9Z/4YSqLParpqgY5ceYuG02QFxB9+1ZGNwnrQDB1QpnBJaY+0yuWCR2LivXL7OmblQrN1xWHR/o4/IHD9FjM3JF1sOnhU9e6eLkVmTxPips8XO41fFqAl5YoLx44xpm94+qZV5K9BZTEGsf0MJ1FltV00QHR0jDmzYLpJaJYr5uTPEmV0HxcV9x/zO2d2HxcK8mUZyjBUrSt7W6hVo4IAqhc3QyrLLuCkFYuvhC/J1Vu5MUbjtgHgjfYAYPXGqyM4rEOt2HRW90wdKT3JnLxHbjlwUGdmTxcT8eXLatq8+6wmqM3THPnjWZKDcmT9lhkhIaOU0d4I+tMbaZeNbJ6W0eA1pS49eluLi/dR5yw3Zy2TCNJefPn+lTJ54jWWLdh8Ti4q2y0Q5Y/EarXx3qFJTbvvge7G1u2DKTLFvbYno1rGziI6KqrSF6k+iIiPl7qGi3e9odQskcECVwmZoZdll1Zb9omBpkdhd/oHckNyw94SYvmCVnIeNyFVbD4iZi9e6lodXOOUfr7ccOmdsVJ4Sc1duksvir1q+O/A/MPuf9W8ogTqr7aoJ8L1wZ15uvpgxPle0b91W3rJO7eP+QroT6yh3gj60xnoCRoC46BZSI9n1GTBEjJ40TYq85eA5kT5wmJiUP18sWLNVjiJ3n/xQLNu0R271zl6+Xm45z1+9RW4hq2W7wzzFHJ3FFEWVJ9hBndV21QT4XuyWwdYukuKiqbPEhb1Hta1Uf4KRI/YiqHULJHBAlcJmaGV5AkaHSzfuMvp/iUyM/YdmyCSZN3eZ9GHgsEzpDpaBU/jMwrXbpDv4C/cWGxuW6/fY+7GkO95jupMxcKgY0Luv2L26uMbdwcixZ+9+Wt0CCRxQpXBKaI31BIwMh4+ZILYeOu8aMa4uLZPTMc06YixYsk6sKHlLvl6/57j8YcCoc+WWt8XQzHFa2e7o2A1P19hOub3A3PqEZDHR0eLcniOafP7m7K5DcreUWrdAAgdUKWyGVpYnmLtFcYjBHDEuXr/D6N8bZGK0jhhx2ALO4PXyzftEzowF8jPYyBwxdpJWtjvMp2vQHc8x3cHhgHd3lGn9uiaocCdGq1sggQOqFE4JrbGeUFJ2Ro4E8RqCY/cQXmOrF8cYIT62eiE9dhuNmzJT7jrFNCRKLI/3QC3bHeYz5dBBrFu/oUSg5YZk5t9AEKj2VwUcUKWwGVpZnoBDCfAAr7GbFHtV8BqHJ3DsfpQxesT8WcuKxI53rshRIzyZMnup3Nsybf5KMWH6XJE7a7FWtjvojveY7gTSGxCo9lcFHFClcEpojQ12IDjum2jd+g0lAtW5mRjdAwdUKWyGVlawA3daJaXQHQ9hYnQPHFClcEpojQ0FpOCWrV9VoGAFdcWxHrU9NQETo3vggCqFzdDKCgUqkqPhTgzdsQsTo3vggCqFU0JrbKhgCh5KN0JGXXGsR21LTcDE6B44oEphM7SyQgW4k5TSJvTc6RZYdwLpDXCQO0EfWmNDCTM5Bv2jc4y6oY6JRl1RZ7UdNUGoJEZc0pM1eYaYs2KjfI/j0DjWhpNScJwax6bxHuD6PlwPax57Kz16SSvvVsABVQqboZUVSiwq2hlS7mAPUaDdCaQ34Fbu4DIguIOzlvEe53bAk/mFJfI9Too03dlV/r48oQvXzU6bv0KeI6K
WdyvggCqFU0JrbKgBWTp0Ce6HraJuqGOgxAZmXUzBVOlqCny3Wjcr8kzNE+/Jk05wZjNEx/SF60plkjRP9gKQGcvjjOji/e/KszcxHWc6272MAQ6oUtgMraxQA8mR7twaa33U/lyT4PvVulkZPmai3HjENeNwyHQHZzjj0h6rO7hJRGZ2rnQIriA54nK65cV75dUCatnugAOqFE4JrbHEmeB4rCl2IAW/ldzYqsUZmGt2VNwjEpf14IJ2nIUJuQePHCu3eHNnLZES4zpA/BhgGSRHvMeWct6cpWL7sSta+SpwQJXCZmhlEWeCS4xCITHCC1wbbm4U4rpYJDlcNocbrvQbMtJ1lcDcVcXy0h8kULiDKwwwHXtq4BacUstXgQOqFE4JrbGkAvPhorjAWp0XisibIEdGyuuwAin4reTG7lEkNHkHmC375VYwtoDHTJou5+PCdjMJ4g4x2MWK5THC7G+IX3LwrNw9hEuD7OxahQOqFDZDK4tU4HJnw05tXiiCe6VGRVbcJUrtzzVJde5gtIf+Dn+wYYndpLg+HJfz4LpXLIMRItzB/we7W/HbhpEj3Bk0YrRMonBH3sLTSJjqd6jAAVUKp4TWWFIBbnOX3q+faNUq0RHJEXfpb5HQSvTq0k1uAQfiAn98560u8McNI3CTCBw3xC5Rc3cQZEUyzBiXI0eKAAkRiRPLzSvcLO+8hP8VbjWIrV4kVrV8FTigSmEztLJIBZmGO/1MdxyQHOFOS6MtkUZiCtQF/hXuVH2BP0Z4SITY04KNSCRG0x0kSiRMuGW6g8MPpms4RDFw+CjpEEaRSI7mfbCrAw6oUjgltMaSCrDF2zIhQZLSuq02PxTZdeJ9MT5vtmjfpl2N3xIO34Xv7NSlu1YvK9jqxf0/ISzem4IiCSIxrt5+0CU3jo9gOpaH0NgaxrLY5Yr5atnugAOqFDZDK4tUAHcS4E5LZ7nTrkNn0btbzxq/JZx0Zxrc6aHVywpuw4k9KLh5BN6b7iApYk8K9rCY7mCEielYHjexhztIrhhJYjm1bHfAAVUKp4TWWPI15nEFdXqos25bmejaqWuN3kQ8OipadO7cLeieLwcHVClshlYW+Rrz/65OD2WQHCdMmSXaGRuWNXkT8Qp3ujvJnaAPrbGkAtdWr0Fyapo2nzgDOKBKYTO0skgFqZXcaa3NJ84ADqhSOCW0xpIKRo7LFX3S+4uEVoliyYbQP8ZI3AMHVClshlYWqYDuhAdwQJXCKaE1llRgnllHsZ0NHFClsBlaWaQCuhMewAFVCqeE1lhCwgk4oEphM7SyCAkn4IAqhVNCaywh4QQcUKWwGVpZhIQTcECVwimhNZYQdyzdsEseMxqVPVlsLjujzQ9V4IAqhc3QyiLEHbiWE9d04vpOuhMaoTWWEHfEt2ghioqK5GnqOH6kzg9V4IAqhc3QyiLEHXQn9EJrLCHuwPVUV69elX/VeaEMHFClsBlaWYS4g+6EXmiNJcQdlFsLrSxC3EF3Qi+0xhLiDsqthVYWIe6gO6EXWmMJcQfl1kIrixB30J3QC62xhLgj7rXXRHl5ubwxdJKDbvMFB1QpbIZWFiHuMN1pQXdCJrTGEuKO0eOd9SghEzigSmEztLIIcceYSu44525AcECVwimhNbamWLl1v+g/bJRo276jiIr6+gnZxD64uXn64JFi7c4j2vr1NbjNV2pr5zx81gQOqFLYDK2smgLuDIA7r9Mdb8HNzWvcHQc819UKHFClcEpoja0JsnILRMuEViI3d4ooLi4WBw8eFIcPHyYesn37djE9f4a88L4mBHcicECVwmZoZdUEdMc30J3bBw6oUjgltMb6m/xFq+VFrrt27dI6K/GO6dPzRd+BQ7V1TW4NHFClsBlaWf6mwp3WdMeHTM/PF30GDtPWNbk1cECVwimhNdbftG3fSRSu3aB1UOI9paWlIiU1VVvXJnjKQYcuPWr04ap2eb1TN1G4rUyrc00BB1QpbIZWlr+pcGe99v8n3lPhTtXPW6U7VQMHVCmcElpj/Q2OiRw8pHdQ4j3YnRYVFaWta1C0+x15NtyKwjXi4sWL8rTxYGLN2nWiVWKSrKda95oADqhS2AytLH9T4c4h7f9PvKfCnWhtXQPTnZVB6k7hmnUiITTdCfrQGutvsKWjdk5y+2C9qusaDB2VLWYWFGhSBROzZ88WQzPGanWvCeCAKoXN0MryN3THP1TlzjDDnYKCWVp/DSZmGe4MCT13gj60xvobyu0fqpI7rU1bceTIEU2oYOKQMQqqblewP4EDqhQ2QyvL39Ad/1CVO63pTrXAAVUKp4TWWH9Duf1DVXJHRUeLS5cuaUIFE6hfVbuC/Q0cUKWwGVpZ/obu+Ieq3ImmO9UCB1QpnBJaY/0N5fYPVcmN6apMwUhV9fc3cECVwmZoZfkbuuMfqup7dKd64IAqhVNCa6y/CRa5S0pKxNtvvy1f79ixQ+zcuVMsXbrUNX/btm3yOjGctbZgwQLJpk2bXNPM5ayfCSRVyUG5qwcOqFLYDK0sfxMs7mzcuNF1/eTWrVvF7t27xfLly13z4Ramb9myxeUOpq1fv156hmWwC3DZsmVa2YGgqr5Hd6oHDqhSOCW0xvqbYJEbks6fP1++njRpkti3b5/o0KGDlBnTJkyYIHJzc8XChQvFqlWrZPJ86623xJQpU8T48ePlMlj29ddf18oOBFXJQbmrBw6oUtgMrSx/EyzuIKHBCbzOysoSe/bsEe3bt5cJEtMyMjLErFmzRH5+vkyicAcboaZTWGbdunWiZ8+eWtmBoKq+R3eqBw6oUjgltMb6m2CRG1uskBpbvviLaWPHjpWJb//+/ZUSY15entzqheCYj3lYZurUqWL06NFa2YGgKjkod/XAAVUKm6GV5W+CxZ0DBw7IjUPTE2wwwp2ZM2fK5AifzMRYUFAg3UHyxAboxIkTpXN4PW7cOK3sQFBV36M71QMHVCmcElpj/U2wyA2Q+LAr1NylA1kBkiFGlOqIET8ESIzmaBPL4odBLTcQVCWHp3Jjg2Hy5MkS7PbavHmz/HEz5w8fPlycOHFC/sCZnDx5UgwcOFCcPn1aLoPdzdbP2KGq+vsbOKBKYTO0svxNMLmDxAg3MPJDYsRGIlzA/x3T1BEjkil8Wb16tfQNbpl7XgJNVX3PU3f27t3rcgftRVuXLFki5125ckWkp6fL6yHHjBnjcufChQuiT58+ruskCwsLXZ+xS1X19zdwQJXCKaE11t8Ek9z4Ae/WrZvsxHgPcdEx8SOPJGgmxnnz5knZcX9Fc0Q5aNAguazTEiM2FHDLMbzG7q+ioiKZDHH2G7b0IfexY8fEtGnTKn1u1KhRYuXKlfI15k2fPl0ruzqqqr+/gQOqFDZDK8vfBJM76Pv4QceGlJkYFy1aJDIzM+UxeDMxoj/BHWxkwS8s37t3b+me0xIj2ozLO+BKWVmZTHAYPWMefjvgDhKh6k7//v3l+Qt4jaQ6Z84crezqqKr+/gYOqFI4JbTG+ptgkhusXbvW9RqdE+JiKxd/qzr5BvPwGokCy6plBoKq5P
BUbmy5YkMAgqJ9SIwQHCNHCIt5SIzYdYb3GDnjc/gRhPCnTp2SGxOq/Leiqvr7GzigSmEztLL8TTC5AweQ8PAa/QT9AxuM5klt7k6+gTNYfsOGiltCOs2dc+fOyeSIDep33nlHegNw/sKMGTMkSIw4Bgt3sNGAz5nuoC5Ynokx8KE11t8Ek9xOoio5PJUbW7bY4r18+bLIycmRiRE/dBAXYmM3mLsRI+TGiBHTjx8/rs2/FVXV39/AAVUKm6GV5W/ojn+oqu956g42BrDLFBuX2LOEJGe6gmSHpOluxAh3sIGJ6ThMwcQY+NAa628ot3+oSg5P5cZuMezigqw4RoKRAEYH2E2GpIlpkBdbxRAYvPvuu1LsM2fOuHYdzZ07Vyu7Oqqqv7+BA6oUNkMry9/QHf9QVd/z1B34geOGSHAYHcMV7EHBBiUO18AhJE34ZbqD93AHCRR7ac6fP89jjEEQWmP9DeX2D1XJ4ancgaKq+vsbOKBKYTO0svwN3fEPVfU9ulM9cECVwimhNdbfUG7/UJUclLt64IAqhc3QyvI3dMc/VNX36E71wAFVCqeE1lh/Q7n9Q1VyUO7qgQOqFDZDK8vf0B3/UFXfozvVAwdUKZwSWmP9DeX2D1XJQbmrBw6oUtgMrSx/Q3f8Q1V9j+5UDxxQpXBKaI31N3goqHndIPENOF0eTwJQ17Vc3yHyhICq6u9v4IAqhc3QyvI3dMf3VOdOqDxdo6r6+xs4oErhlNAa629at+sgn4itdlDiPThztG2717V1Ldd3CDxTDvVrY9RTrXtNAAdUKWyGVpa/gTurCgu1/z/xHie4g3qqda8J4IAqhSOiVq1aN3ef/FBrsD+ZOGOBeOPNvvISALWTEu/ApRNDRmRq6xoMlU8h9+z2bDUN6jdidLZWd3+Dvg8HVC/sRODcqbjbjNoHiHdUuDNKW9dgWIi4g3qqdfc3t+NO0EedunU/LTl4Xmu0P8EK7dLjDZEdJLeDCnVw+7ZWiYmiaNcxbV2Dot3viBYtE+Qo3bwfY7CA+qwqXC3rh3qqdfc36Pt169a7pnphJ+rVq/9ZwNzJpju+wOXO7lu5szpI3VkTUHdq16nzieqFI6J+w0ZX1uw8qjXa32w+cFoKjq3fovXruQXsBbj91rJly0VKaprIyp2prWMrSzbsEh269BCRkZHyQH2wgPqgXqifWueaAH0fDqhe2IlGjRu/FzB3un/lTlER3fECuLOU7twW6Pv16jW4pHrhiGjYKOJQwcrNWqNrgt3lH4icgkWia68+IiYmVvvHk+qBGO06dhH5i1Zr65bYA32/UUSTw6oXduLOpk2PBdydnnTHG+jO7YO+X69e/f2qF46IOnXrzhuSlXdTbTQh4QD6ft369eerXtiJiDubLh2claeVSUg4MHDMZByCmK164ZToG5vc7qraaELCgddSXn8fDqhS2InatWunt2zT+XO1TELCgeYtUk/f4aU7oRDP/+CRn55TG00qjuWkpLUJ2D78YMRp6wR9Hw6oUtiM53/86C8uq2US5/UTX+C0dfL9Hzx8Cg6oUjglInBW3rYjF7WGhzsjx00WfdP7i4RWiY7pzLdLRnauSO/XzxHrBH3+qzNSI1QpbEYEzkylOzpO6ie+ItNB6wR9Hlc0wAFVCseE8eNQNCx72hdq48MdbN0lJCSIlgaprQNzAW2wERsXJ8rLy+U6SQnxdYI+X69eg/WqD55Eo8aNi4dPyOcxeoXU1nRHJc50p2XouzM4K++jOnXqrFV9cFrE/vo3vz2jNp5U3CsRqNPDFawLXEPlhHXy2JO/O4u+r8rgYcT+5nd/4O5UN9CdyjjJnZ/96onj6PuqDE6LJsao8f2VW8u0FRDuUO7KOEVu9PV69ep/gL6vyuBhNKlfv8GHdEeH7lTGMe5sK7tZp27d99D3VRmcGINe/HfkeXUlhDuUuzJOkfulf0ddQJ9XJfAmateuPfj/vvra++p3hDt0pzJOcedvL/3PiTt85E4oxP3Ygl6x9YC2IsIZyl0ZJ8iNPv7VaPF+VQIv436MGulOZehOZZzgzvKtZV9+NVr0lTshEd1x3EVdGeEM5a6ME+R+7KlncYlGd7Xz304Yo8YeT/JYYyXoTmWc4M4vfv1k+R0+dicUon69+vWPDc6a8om6QsIVyl2ZUJd7SFbep8Zo8Rj6utr5bzPqN2jQ8PjQcVN5dvdX0J3KhLo7A0fnvle7bt0j6Otq5w+HeLx+g0bvLyvZr62YcIRyVyaU5UafbmD0bfRxtdP7KB5v2CjiQ7pTAd2pTEi7s/ntG8YG5VX0cbXTh1O0fuDBb18sPnBGW0HhBuWuTKjKjb78zW99B08CaK12dh9H629953tX6Q7dUQlld+77xgO4tMnf7gR/1K5de9TDP/7ZhdKjl7QVFU5Q7sqEotzow+jL6NNqP/dH1KtXb8xPfv6rq3SH7lgJVXceevgR3BM1U+3n4Rq16tSpM/1HP/l52I4ccceO6Oho+Yga3AVHnR+OyIeiFhWFzDpB333kp7+4ZCSrfPRptZP7KWrVq19/BpIj3QmNflIThKI7P/jRT84aG5TT0KfVTh7OUQtb2Q9++7uXw/G4Ce5niPsajsqeLDaXhecPnIq5TjJDYJ2gz6LvfjVSrGmxa2Hk+O3vfv89uhPc/aSmMNdJRgisE/TZ+7/5LVzri5FiTbsTMpHWoEHDD4Zk5X2mrkBCgpEh46Z+3qBhow/Rd9XOXMOR1rBR44+GZk+7odaRkGAEVyXUr98AJ6kF2p2QiMeM5Hjiyd/915XlW3ghMwlO0DfRR9FX0WfVThygeKxhw0blT/3+T+/THRKsoG8+/vSzF+vVq/8O+qzaiRlVR31cyGxsTXz0/yJf4/0hSdCAvvj/IuM/Qt9EH0VfVTtvgKN+7bp1e9Zv0PDj/4lu8QndIcEC+uK/Xol7z0iI2MOCi/eDzZ2Qifvr1q07xNgq/+ipZ//0/rDx0+XZS+oKJ8SfoM+h72EkZiQcIynWH4q+qXbWIAvcPm5Yg4aNPv7dH/78Ed0hgcDlzjN/uIrbGeJ+v+ibamdleBd4SGXcnU2bbjVW7ueP/vLxDxLadL45ZNxUUbBysxyalx67rP1TCPEE9CH0JfQp9C30MfQ19Dn0PfTBr/piKIV0p+ldd5dWuPPEh63adpHtozvEV6jutGzd6cajv3jsfTxgu3GTJiXog1/1RYafAiv3+UYREYPuu/+Btffc2+ykIf17derWvW5MF4R4C/oQ+hL6FPoW+pgx/fk7nCO0dCfizjuH3PfAN9fd0+y+U3SH+ALTnbvvube82X33r2nYuPHAO5zlDoPBYDAYDAaDwWAwGAwGg8FgMBgMBoPBYDAYDAaDwXBC1IqMjHy8efPmbRITE/NSU1N3xMfHX2jZsuV7UVFR18276hNCCAlfoqOjv0BeSEhIuGjkiZ1paWnTYmNjWxt54udqUgnJePHFFxsbDY1u1arVBqNh19q0afNRVlbW9fnz54vi4mJx6NAhcfToUXHx4kX5aBVCCCHhzYULF2ReOHjwoNi4caOYN2+eGD169
E0jQV4z8sinRrLcGBcXF/Xyyy83UnNOUMcrr7zysxYtWsyJiYm51rdv348LCwvFqVOntBVACCGE2KW8vFysXLlSpKenf24kyc9SUlKWGvnml2oOCqqIjIz8lZEQ1xl8Onv27BtMhoQQQvzByZMn5WgyMTHxs+Tk5O3Nmzd/Vs1JAY3XXnvtbiMZ5hh8smjRoi8xFFYbQQghhPiaS5cuiRUrVoikpKRrCQkJy6Oior6j5qgaD2OU+Je4uLirY8eOvXb69Gmt0oQQQoi/wbkq06dPv2Hko09jYmLaqrmqRuLll1+uY4wUh7Zq1eqT0tJSrZKEEEJITXP48GHRqVOnT+Lj49cYeepeNXf5LYyE2ND40tU9e/b8GPt51YoRQgghgQK7V3Nycj43Ro/nX3nllR+pOcznYWTgu4ykuGvYsGGf4svVChFCCCHBwIoVK27Exsa+FxkZ+aSay3wWGCm2aNFid3Z29udXrlzRKkEIIYQEE8XFxTeN5Pjhq6+++pSa0247cEyxZcuWa4cPH/4ZkyIhhJBQoaSkBMkRd1n7qZrbbisSEhIy3njjjWvcfUoIISTUKCwsvI5jjs2bN39AzW9ehZFl/5qYmHiNJ9oQQggJVaZOnfqRMXLcjD2gap7zKHDxfnx8/Ie8JIMQQkgog8OAPXr0uGIM9vqruc6jSElJmZGdnf2l+gWEEEJIqIE9n3FxcR+88sorj6n5zlbg3qcJCQmf8442hBBCnMLy5cs/jo6O3tOnT5/aat67ZbRp02bn4sWLtUIJIYSQUKZjx45njMFfopr3qg08FDIpKen67T4rcffu3SIjI0OkpaVpD6QMBYwVJ9q1ayfy8vLEmTNntPa5IxzbTAghoYTxO33dGDWeTkxMrKfmvyqjffv2RXPmzNEKs8vly5dxSx6RmpIilk2bJk6sWyc+3rkz5Phoxw5xcNUqMX7ECJGSnCz27NmjtVVrc2qKmLVwpth2sFgcvvh2yHHowluieG+RyMwaKVJSqm8zIYSEKp07dy5/9dVXW6n5z23ExMRExMfHX7+d5ylOmjRJ9OnRQ1wqKdGSTaiyFc/+SkioMlGgzb169xL7Tu7Skk2osmzDItGqVdVtJoSQUKW0tPRaZGTkfjUHug0ji7YbOHDgTbUQu2BXYpoxUnRSUjRBcsTIUd3FiDanpqUaSXGnllxCHSTH5BS9zYQQEsrg8g1jEHjeSI6Pq3lQi44dOx5cvXq1VohdMkaOFMsmT9aSilMYP2yYPP5Wqc0ZGWLmgulaUnEKI7OGa20mhJBQZ9KkSceaN28+Ws2DlQKnr2I36u3c5QYjqpNFRVpCcQpHCgvF6+3aVW6zMUIuPbRZSyhOofitItGuXVvtf00IIaHMgQMHvjRGjO+oubBSGKPF/+rQoYPXu1FBdFSU+HDHDi2hOIUPtm+XbazU5uhocfD8W1pCcQoHz+0TUdGV20wIIU4gNjb2ipEcv6fmQ1f0798/e9y4cdoHPQGn/KvJxGmgjWqb1WTiNNQ2E0KIE+jVq1eZ8fsWp+ZDV/Tp02fr/PnztQ96AhOjM1HbTAghTiA3N/cd4/dtqJoPXdGlS5czxcXF2gc9gYnRmahtJoQQJ7B27dorxu/bEjUfuiItLe3aoUOHtA96AhOjM1HbTAghTqCsrOxmZGTkQTUfuiIxMfFGeXm59kFPYGJ0JmqbCSHECSDnRUdHn1fzoSvi4uJunjt3TvugJzAxOhO1zYQQ4gSQ86Kioj5W86ErcANp3A1A/aAnMDE6E7XNhBDiBJDzjN+3L9V86Apf/PgxMToTtc2EEOIU8Pum5kNX+OLHj4nRmahtJoQQpxDyifHIqlViYNeuYsrQodq8CQMGiEHGvLMbN2rzfIm6nvydGFeWLBW90ruLhWvnavPWlK6U85auX6DN8yVqmwkhxCmEfGIE/Tt3Fn/+85/Fmrw817QlEybIaTmDB2vL+xp1Pfk7Me45XipeeOFvIrl1ojbvzYG9xHPP/cXvz39U20wIIU7BEYnx8pYtIvbVV8X/feklcXrDBnFy3Trx0v/5P6J9UpJ8uLC6vK9R15O/EyPo0aeb+Mtf/iI2v73eNa3s7F7x4j9fFK3bp2nL+xq1zYQQ4hQckRjBgWXLxN//9jfRo00b0TUtTfz7X/8S765fry3nD9T1VBOJcUXxYjkiHj52iGvanBUz5bSZi6dpy/satc2EEOIUHJMYwfysLJkYnjNGUltmzdLm+wt1PdVEYgTRcVHi3//7P+LQhYoneXTq0VGOGA+c3aMt62vUNhNCiFNwVGKcPmKETIzYxVg8c6Y231+o66mmEuP4qWNlexetnSv2lu8QfzNGzOlD+mjL+QO1zYQQ4hQckxhL584Vf33uOTGke3eRFBcn/vXii+Kd1au15fyBup5qKjHiJBwkw65vdBZ5c3IqTkDatkJbzh+obSaEEKfgiMSIY4k4pogTcK5s3Sov4fjHCy+IxNhYcXXbNm15X6Oup5pKjKD7m13F3//xd9EquaWIjY/R5vsLtc2EEOIUQj4xfrB9u2jbqpV4/q9/FXsWLXJNN483DurWTfuMr1HXU00mxuWbKk7CAdl5Y7T5/kJtMyGEOIWQT4w4G3XSoEGiMDdXm1eQmSnnnSoq0ub5EnU91WRixIk3L/3zJfHC318Qe05s1+b7C7XNhBDiFEI+MQYD6nqqycSYvzBPjhaHjBqozfMnapsJIcQpMDH6AHU9+TsxYmSIO9tMm58rR4qRMc3F2+/u0pbzJ2qbCSHEKTAx+gB1Pfk7MQ4bM9h1XBEn3Wwt26Qt42/UNhNCiFNgYvQB6nryd2LcsHuNvLtN0c5CbV5NobaZEEKcgt8TY3RUlPiwBu5XGijQtujo6MptNt4fPF9xNxongrapbSaEEKfg98SYkpwsTvr5rNBAgrahjZXanJIsSg9t1hKKU0DbklMqt5kQQpyC3xNjxogRYllOjpZQnMLS3FyRkZFRuc3G+xkLpmoJxSnkG20bmTFS+18TQogT8Hti3L17t0gzRlSXSkq0pBLqoE1pKSmyjWqbU9NSxL6TO7WkEuqgTSlpepsJIcQp+D0xgkmTJok+3bo5KjmiLX26d5dtU9trtrln7x6OSo5oS4/eVbeZEEKcQI0kxsuXL4sc48c01Rg5Ls/NFSfWrdMSTaiAuqMNaEtOTo5sm9peV5uN+SmpKWLGwmnyukM10YQKqDvakJJafZsJIcQJ1EhiNMHut4yRI0Vaaqr84lAEdccxRLu7EmWbjeVT00K3zai7J20mhJBQBr97aj50BWaqHyCEEEKcDBMjIYQQYoGJkRBCCLHAxEgIIYRYYGIkhBBCLDAxEkIIIRaYGAkhhBALTIyEEEKIBSZGQgghxAITIyGEEGKBiZEQQgixwMRICCGEWGBiJIQQQiwwMRJCCCEWmBgJIYQQC0yMhBBCiAUmRkIIIcQCEyMhhBBigYmREEIIscDESAghhFhgYiSEEEIsMDESQgghFpgYCSGEEAtMjIQQQogFJkZCCCHEAhMjIYQQYoGJkRBCCLHAxEgIIYRYYGIkhBBCLDAx
EkIIIRaYGAkhhBALTIyEEEKIBSZGQgghxAITIyGEEGKBiZEQQgixwMRICCGEWGBiJIQQQiwwMRJCCCEWmBgJIYQQC0yMhBBCiAUmRkIIIcQCEyMhhBBigYmREEIIscDESAghhFhgYiSEEEIsMDESQgghFpgYCSGEEAtMjIQQQogFJkZCCCHEAhMjIYQQYoGJkRBCCLHAxEgIIYRYYGIkhBBCLDAxEkIIIRaYGAkhhBALTIyEEEKIBSZGQgghxAITIyGEEGKBiZEQQgixwMRICCGEWGBiJIQQQiwwMRJCCCEWmBgJIYQQC0yMhBBCiAUmRkIIIcQCEyMhhBBigYmREEIIscDESAghhFhgYiSEEEIsMDESQgghFpgYCSGEEAtMjIQQQogFJkZCCCHEAhMjIYQQYoGJkRBCCLHAxEgIIYRYYGIkhBBCLDAxEkIIIRaYGAkhhBALTIyEEEKIBSZGQgghxAITIyGEEGKBiZEQQgixwMRICCGEWGBiJIQQQiwwMRJCCCEWmBgJIYQQC0yMhBBCiAUmRkIIIcQCEyMhhBBigYmREEIIscDESAghhFhgYiSEEEIsMDESQgghFpgYCfERV65cEXPmzBFbt27V5pkUFhaK1atXu96fO3dOfsYdxcXF4tKlS5U+f+TIEdf8CxcuaOWbYJ653LFjx7T5hJCqYWIkxEdcvnxZJCYmiry8PG2eSXp6uhg4cKDr/YkTJ+RnquLNN98Up06dci1fUlLimrdu3TqtfBPMM5fbvn27Np8QUjVMjIT4iNtJjPPnz9eW3bRpk0hNTRUTJ050TTMTY3Jyshg+fLj2GZOhQ4fKZZgYCfEcJkZCfISvEyMYOXKk6NGjh+u9mRgzMjJk4jt58qT2mfLycjkvMzOTiZEQL2BiJMRH+CMxjhgxQn7GfG8mRuwqTUpKEosXL9Y+s3DhQjnP3J3KxEiIZ1SbGCMjI+UJBeqHCCE6vkyMFy9elIkNu1JXrlzpmm4mxl27dolBgwbJY5Dqd7zxxhtyVyoSIhMjIZ6BnGckxptqPnRFXFzcTZw1p36QEKJjJsZOnTqJwYMHu6VNmzZuE2NVFBQUVPoOa2LE2a14feDAAdf8t956yzWiZGIkxHOQ86Kioj5W86ErDKlu4HiF+kFCiM7tJEaM8KyXasyaNUv0799fpKWlic2bN7uWtyZGCNy6dWuRn5/vmj916lT5Hbhcg4mREM+Bk9HR0efVfOgKQ7rPDh06pH2QEKLjy12pJmPHjhXt2rUTZ8+ele+tiRHvx48fLxMxvhvXPHbo0MF1FisTIyGeU1ZWhhHjITUfuqJr167ncJGx+kFCiI4/EqOZCM3kpiZGM/lt27ZNjizdzWNiJMQ+RUVFH7366qtL1XzoCkPi7VUJSwipjD8S48aNG6tNjDhRoEuXLiI7O1uMHj1adOvWzXXCHBMjIZ5j+HuyefPmQ9V86IohQ4bkjhs3TvsgIUTH14kRxwlx5imOGZ4/f15OUxMjwPFIHIvEGaw4PmlOZ2IkxHP69Olz1EiM8Wo+dEXfvn2f79ixo/ZBQojO7SRGjPTUE3VwvBAX6q9du9a1vLvEiPun4rpFYL0vKhMjIZ4TFxd3OTo6+odqPnSFkTlrx8fH33B3dw1CSGWQGJHQli5dqs0zmTx5cqXEifugqgnRBCfR7Nu3r9Ln9+7dK+dZL9EAWDYnJ6fSNFy6gWXxV60HIUQHJ5tGRUWdUnOhFl26dDlmfRoAIYQQ4kTy8/PPNm/efKKaB7Xo27dvT+uuH0IIIcSJJCcnIzH+Qc2DWsTExERgd6r10TeEEEKIk9i7d++X0dHR5Ubaq6XmQbfRtWvXbdaz3QghhBAnMWDAgFPGaLGbmv+qjJSUlMeTkpK+xI2N1cIIIYSQUAYn3RijxatRUVFN1fxXbXTs2PGAu0fcEEIIIaGMMVrEscWeat67ZSQnJz+ZmJh44/Tp01qhhBBCSCiybdu2GzExMWdffvnlJmresxXGqHEJblqsFkwIIYSEGri7VFJS0lVjtPiimu9sx2uvvXZ3QkLCp6WlpdoXEEIIIaFEVlbW+8Zocb6a6zyOtLS0fyYnJ1/n3XAIIYSEKoWFhV/ExsaefPnll+9S85xX0a5du5w+ffrcwPPf1C8jhBBCgpk9e/YIIyl+GBUV9XM1v3kdRoatk5qaujUjI+Om+YgbQgghJNjZv3+/iI+P//jVV1/9h5rbbjtee+21hikpKYcmTJjgev4bIYQQEqwgKSYkJHwSHR39v2pO81lg32xycvLBESNGfMndqoQQQoKVrVu3CmNA92lsbOwrai7zeWDkmJqaugXHHHlCDiGEkGADDwKPj4//KCoq6hk1h/ktcMyxTZs245OSkr7gpRyEEEKCAQzW0tPTv2jVqtXhmJiY/1RzV41EYmLiP1q2bPlxdnb2Td4hhxBCSCDAQ8NxC9OEhIQvUlJSRv/xj3+sq+arGg3cBMAYPc41MvQXqBjuLKBWmhBCCPEHGzZsEB06dLiemppa1rx580fVHBXQiI+Pf8xIkLuMUeR1PLKKz3MkhBDiD86fPy+WL18u2rVrd8NIiGcMmqs5KagCF1C2b9++KC4u7ouBAwfeLCwsZJIkhBByW5w5c0YUFRWJ4cOH42zTG0aeeSslJeUFNQcFdbz44ouN27Ztm/b666/vQ5I0hrpfZmdny7OFiouL5fOwjh49KvjMR0IIIQD5AHnh8OHDMk8gX+Da+U6dOt2MjY1FMjzarVu3NyMjI+9Rc04oRi0jw//GaNDQrl27rjKS5NHExMQPkpKSPo2JifmyefPmghBCSHiDfGDkhWtGfviwY8eOx3v06FHUq1evzDZt2vwBV0OoiYXBYDAYDAaDwWAwGAwGg8FgMBgMBoPBYDAYDAaDwWAwGAwGg8FgMBgMBoPBYDAYDAYjlOP/Aw1x9j2Es5w/AAAAAElFTkSuQmCC)\n", "\n", - "Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have only two threads. How do we modify our kernels to utilize both TensorCores simultaneously?\n", + "Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have\n", + "only two threads.\n", + "How do we modify our kernels to utilize both TensorCores simultaneously?\n", "\n", - "The basic idea is that if we have embarassingly parallel dimensions in our computation, we can split up those dimensions across the TensorCores. 
We can indicate which dimensions are parallelizable by providing an annotation to `pallas_call` called `dimension_semantics`." + "The basic idea is that if we have embarrassingly parallel dimensions in our\n", + "computation, we can split up those dimensions across the TensorCores.\n", + "We can indicate which dimensions are parallelizable by providing an\n", + "annotation to `pallas_call` called `dimension_semantics`." ] }, { @@ -612,15 +689,15 @@ ], "source": [ "def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:\n", - " block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))\n", + " block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n", " return pl.pallas_call(\n", " add_matrices_kernel,\n", " out_shape=x,\n", " in_specs=[block_spec, block_spec],\n", " out_specs=block_spec,\n", " grid=(2,),\n", - " compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",))))(\n", - " x, y)\n", + " compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",)))\n", + " )(x, y)\n", "\n", "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", "add_matrices_pipelined_megacore(x, y)" @@ -632,9 +709,12 @@ "id": "xG51AiUC-8cl" }, "source": [ - "`dimension_semantics` should be a tuple of same length as `grid` where each entry is either `\"parallel\"` or `\"arbitrary\"`. `\"parallel\"` indicates to Pallas that the iterations of the for loop corresponding to that dimension can be executed independently without affecting the correctness of the program. `\"arbitrary\"` indicates to Pallas that there can be no assumptions made about this grid dimension and it therefore cannot be parallelized.\n", + "`dimension_semantics` should be a tuple of same length as `grid` where each\n", + "entry is either `\"parallel\"` or `\"arbitrary\"`. `\"parallel\"` indicates to Pallas that the iterations of the for loop corresponding to that dimension can be executed independently without affecting the correctness of the program. `\"arbitrary\"` indicates to Pallas that there can be no assumptions made about this grid dimension and it therefore cannot be parallelized.\n", "\n", - "By specifying `dimension_semantics`, we now execute the kernel simultaneously on each TensorCore. Pallas will handle splitting up the grid automatically.\n", + "By specifying `dimension_semantics`, we now execute the kernel\n", + "simultaneously on each TensorCore. Pallas will handle splitting up the grid\n", + "automatically.\n", "\n", "> Note that Megacore is only currently available on TPU `v4` and TPU `v5p`. Supplying `dimension_semantics` annotations is a no-op on other platforms, but *not* specifying it will result in only one TensorCore being used (even if there are more than one available)." ] @@ -647,7 +727,11 @@ "source": [ "## Conclusion\n", "\n", - "In this guide we covered how to express TPU pipelines using `pallas_call`, `grid` and `BlockSpec`s. We covered how to express nested loops via a multi-dimensional grid and how to handle reductions by initialize our accumulators at the beginning of the reduction. We also learned how to handle Megacore by adding annotations to the kernel.\n", + "In this guide we covered how to express TPU pipelines using `pallas_call`,\n", + "`grid` and `BlockSpec`s. 
We covered how to express nested loops via a\n", + "multi-dimensional grid and how to handle reductions by initialize our\n", + "accumulators at the beginning of the reduction.\n", + "We also learned how to handle Megacore by adding annotations to the kernel.\n", "\n", "Exercises left to the reader:\n", "* Try implementing a `sum` kernel that pipelines the other dimensions as well\n", diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index 8e42364c2307..d753b404db1a 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 name: python3 @@ -13,11 +13,14 @@ kernelspec: +++ {"id": "teoJ_fUwlu0l"} -# Pipelining and `BlockSpec`s +# Pipelining + + +++ {"id": "gAJDZh1gBh-h"} -In this guide we'll cover how memory spaces in TPU work and how to write pipelines in Pallas that overlap memory I/O with compute. +In this guide we'll cover how memory spaces in TPU work and how to write +pipelines in Pallas that overlap memory I/O with compute. ```{code-cell} :id: ejAVO6ikUUuF @@ -34,17 +37,33 @@ import numpy as np ## TPU and its memory spaces -A TPU and its TensorCore consist of memory spaces (where arrays can reside), registers (which temporarily store scalar and array values) and compute units (that do computation with values in registers). Below is a diagram of a TPU in which `x` and `y` are arrays that live in high-bandwidth memory (HBM): +A TPU and its TensorCore consist of memory spaces (where arrays can reside), +registers (which temporarily store scalar and array values) and compute units +(that do computation with values in registers). +Below is a diagram of a TPU in which `x` and `y` are arrays that live in +high-bandwidth memory (HBM): ![TPU Memory Space 
Cartoon.png](data:image/png;base64,...)

Let's talk about the components of this diagram in more detail:

-* **Memory spaces**: A TPU has high-bandwidth memory (HBM) which is what we often think of as "device memory". 
There is also vector memory (VMEM), a cache meant for storing vector and array values, and scalar memory (SMEM), a cache designed to store scalar values. -* **Registers**: A TensorCore has two main types of registers: vector registers (VREGs) store array values, and scalar registers (SREGs) store scalar values. Values can be loaded into memory from their respective caches (VMEM for VREGs and SMEM for SREGs). -* **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and matrix unit (MXU) that can do numerical computation. Compute units operate on values that live in SREGs and VREGs and output values into those registers as well. - -In order to do a vectorized computation on our values `x` and `y` that live in HBM, we need to: +* **Memory spaces**: A TPU has high-bandwidth memory (HBM) which is what we + often think of as "device memory". + There is also vector memory (VMEM), + a cache meant for storing vector and array values, and scalar memory (SMEM), + a cache designed to store scalar values. +* **Registers**: A TensorCore has two main types of registers: vector + registers (VREGs) store array values, and scalar registers (SREGs) store + scalar values. + Values can be loaded into memory from their respective caches (VMEM for + VREGs and SMEM for SREGs). +* **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and + matrix unit (MXU) that can do numerical computation. + Compute units operate on values that live in SREGs and VREGs and output + values into those registers as well. + +In order to do a vectorized computation on our values `x` and `y` that live +in HBM, we need to: 1. Copy the values `x` and `y` into VMEM. 2. Load the values from VMEM into VREGs. @@ -88,33 +107,66 @@ add_matrices(x, y) We've written two functions: `add_matrices_kernel` and `add_matrices`. -`add_matrices_kernel` operates using `Ref`s that live in VMEM. Loading from a VMEM `Ref` produces a value that lives in VREGs. Values in VREGs behave like `jax.Array`s in that we can use `jnp` and `jax.lax` operations on then to produce new values that live in VREGs. When we produce the values we'd like to return, we store them in the output VMEM `Ref`. - -The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into `pallas_call`. `pallas_call` is responsible for copying `x` and `y` into VMEM and for allocating the VMEM buffers that the kernel operates on (including allocating `z_vmem_ref`, the output VMEM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`. +`add_matrices_kernel` operates using `Ref`s that live in VMEM. +Loading from a VMEM `Ref` produces a value that lives in VREGs. +Values in VREGs behave like `jax.Array`s in that we can use `jnp` and +`jax.lax` operations on them to produce new values that live in VREGs. +When we produce the values we'd like to return, we store them in the output +VMEM `Ref`. + +The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. +Inside it, we pass `x` and `y` into `pallas_call`. +`pallas_call` is responsible for copying `x` and `y` into VMEM and for +allocating the VMEM buffers that the kernel operates on (including allocating +`z_vmem_ref`, the output VMEM buffer). +After the kernel function is finished running, `pallas_call` will also copy +the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`. 
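For reference, here is a minimal sketch of the pair of functions being described, mirroring the `add_matrices_kernel` / `add_matrices` cell defined earlier in this guide (it runs on a TPU backend like the other cells here; the indexing style in the original cell may differ slightly):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
  # Loading from the VMEM Refs produces values that live in vector registers.
  x, y = x_vmem_ref[...], y_vmem_ref[...]
  # jnp operations act on those VREG values; the store writes the result back
  # into the output VMEM Ref.
  z_vmem_ref[...] = x + y


def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
  # pallas_call copies x and y from HBM into VMEM, allocates the VMEM buffers
  # the kernel operates on (including the output buffer), runs the kernel,
  # and copies the result back to HBM as a jax.Array.
  return pl.pallas_call(add_matrices_kernel, out_shape=x)(x, y)


x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
```

With no `grid` or `BlockSpec`s, the whole arrays are staged through VMEM at once, which is exactly why the constraints discussed next matter.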
+++ {"id": "5kWr-1tKpYro"} ## Constraints of using VMEM/SMEM -Pallas exposes access to lower level memory spaces like VMEM and SMEM but writing kernels utilizing them adds some considerations. +Pallas exposes access to lower level memory spaces like VMEM and SMEM but +writing kernels utilizing them adds some considerations. -1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB and SMEM ranges in the tens to hundreds of KiB. If our arrays are too big, we won't even be able to fit them into VMEM at all. For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays. +1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB + and SMEM ranges in the tens to hundreds of KiB. + If our arrays are too big, we won't even be able to fit them into VMEM at all. + For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't + scale beyond moderately sized arrays. -2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least compared to most compute instructions. The `add_matrices` function above will likely spend more time copying between HBM and VMEM than actually performing the addition itself. +2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least + compared to most compute instructions. + The `add_matrices` function above will likely spend more time copying + between HBM and VMEM than actually performing the addition itself. -With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our TPUs. +With these two constraints in mind, we'll have to rethink our strategy for +getting performance out of our TPUs. +++ {"id": "_NTqvlbetB3P"} ## Primer: Pipelining -Pipelining our computation offers a way of dealing with both the memory capacity and bandwidth constraints in one fell swoop. What do we mean by pipelining? +Pipelining our computation offers a way of dealing with both the memory +capacity and bandwidth constraints in one fell swoop. +What do we mean by pipelining? -The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our compute units. Naively this is difficult because in our program above we copy *all* of `x` and `y` before we start doing any compute with them, creating a dependence between the copy and the compute. +The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our +compute units. +Naively this is difficult because in our program above we copy *all* of `x` +and `y` before we start doing any compute with them, creating a dependence +between the copy and the compute. -However, if we can chunk up our computation into several subcomputations (e.g. when we add two matrices, we can express that as addition of "blocks" of the original matrices together), we can now overlap the copies of one of those subcomputations with the compute of the other. Let's walk through a simple example: +However, if we can chunk up our computation into several subcomputations +(e.g. when we add two matrices, we can express that as addition of "blocks" +of the original matrices together), we can now overlap the copies of one of +those subcomputations with the compute of the other. Let's walk through a +simple example: -Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for example, split along the leading axis, resulting in two `(256, 512)` arrays for each input. We can now execute the following pipelined computation. 
+Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for +example, split along the leading axis, resulting in two `(256, 512)` arrays +for each input. +We can now execute the following pipelined computation. 1. Copy `x1` and `y1` into VMEM. 1. Start copying `x2` and `y2` into VMEM @@ -130,9 +182,15 @@ Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for exampl 10. Start copying `z2` from VMEM back into HBM. 10. Wait until `z2` is copied into HBM. -Any time we are doing compute here, we are asynchronously copying something. This means that some of the time spent copying is not wasted. +Any time we are doing compute here, we are asynchronously copying something. +This means that some of the time spent copying is not wasted. -The two most important numbers for determining how efficient a pipelined computation are a) how many floating point operations (FLOPs) we need to execute and b) how many bytes we need to copy to execute that computation. The ratio of these two (FLOPs/memory usage) is called the *arithmetic intensity* of an operation and determines if our pipeline will be compute bound or memory bound. +The two most important numbers for determining how efficient a pipelined +computation are a) how many floating point operations (FLOPs) we need to +execute and b) how many bytes we need to copy to execute that computation. +The ratio of these two (FLOPs/memory usage) is called the +*arithmetic intensity* of an operation and determines if our pipeline will +be compute bound or memory bound. +++ {"id": "gutx7y8uvZKH"} @@ -140,51 +198,29 @@ The two most important numbers for determining how efficient a pipelined computa +++ {"id": "U-dPTjlBverB"} -How do we implement a pipeline like the one above in Pallas? It seems like a complex sequence of asynchronous data operations and executing kernels that would be a pain to implement manually. Fear not! Pallas offers an API for expressing pipelines without too much boilerplate, namely through `grid`s and `BlockSpec`s. - -+++ {"id": "x-LQKu8HwED7"} - -### `grid`, a.k.a. kernels in a loop - -See how in the above pipelined example, we are executing the same logic multiple times: steps 3-5 and 8-10 both execute the same operations, only on different inputs. The generalized version of this is a loop in which the same kernel is executed multiple times. `pallas_call` provides an option to do exactly that. - -The number of iterations in the loop is specified via the `grid` argument to `pallas_call`. Conceptually: -```python -pl.pallas_call(some_kernel, grid=n)(...) -``` -maps to -```python -for i in range(n): - # do HBM -> VMEM copies - some_kernel(...) - # do VMEM -> HBM copies -``` -Grids can be generalized to be multi-dimensional, corresponding to nested loops. For example, - -```python -pl.pallas_call(some_kernel, grid=(n, m))(...) -``` -is equivalent to -```python -for i in range(n): - for j in range(m): - # do HBM -> VMEM copies - some_kernel(...) - # do VMEM -> HBM copies -``` -This generalizes to any tuple of integers (a length `d` grid will correspond to `d` nested loops). - -+++ {"id": "hRLr5JeyyEwM"} - -### `BlockSpec`, a.k.a. how to chunk up inputs - -+++ {"id": "miWgPkytyIIa"} - -The next piece of information we need to provide Pallas in order to automatically pipeline our computation is information on how to chunk it up. Specifically, we need to provide a mapping between *the iteration of the loop* to *which block of our inputs and outputs to be operated on*. 
A `BlockSpec` is exactly these two pieces of information. - - First we pick a `block_shape` for our inputs. In the pipelining example above, we had `(512, 512)`-shaped arrays and split them along the leading dimension into two `(256, 512)`-shaped arrays. In this pipeline, our `block_shape` would be `(256, 512)`. - -We then provide an `index_map` function that maps the iteration space to the blocks. Specifically, in the aforementioned pipeline, on the 1st iteration we'd like to select `x1` and on the second iteration we'd like to use `x2`. This can be expressed with the following `index_map`: +How do we implement a pipeline like the one above in Pallas? +It seems like a complex sequence of asynchronous data operations and +executing kernels that would be a pain to implement manually. +Fear not! Pallas offers an API for expressing pipelines without too much +boilerplate, namely through `grid`s and `BlockSpec`s. + +See how in the above pipelined example, we are executing the same logic +multiple times: steps 3-5 and 8-10 both execute the same operations, +only on different inputs. +The {func}`jax.experimental.pallas.pallas_call` provides a way to +execute a kernel multiple times, by using the `grid` argument. +See {ref}`pallas_grid`. + +We also use {class}`jax.experimental.pallas.BlockSpec` to specify +how to construct the input of each kernel invocation. +See {ref}`pallas_blockspec`. + +In the pipelining example above, we had `(512, 512)`-shaped arrays and +split them along the leading dimension into two `(256, 512)`-shaped arrays. +In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`. +On the 1st iteration we'd +like to select `x1` and on the second iteration we'd like to use `x2`. +This can be expressed with the following `index_map`: ```python def x_index_map(i): @@ -193,7 +229,7 @@ def x_index_map(i): We'd then construct the `BlockSpec`: ```python -block_spec = pl.BlockSpec(x_index_map, (256, 512)) +block_spec = pl.BlockSpec((256, 512), x_index_map) ``` The `BlockSpec`s for `y` and `z` will be the same as the one for `x`. @@ -202,29 +238,40 @@ The `BlockSpec`s for `y` and `z` will be the same as the one for `x`. ### Putting it together -We provide these arguments to `pallas_call` via `grid`, `in_specs` and `out_specs` (`in_specs` corresponds to the tuple of positional arguments, and `out_specs` corresponds to the output). +We provide these arguments to `pallas_call` via `grid`, `in_specs` and +`out_specs` (`in_specs` corresponds to the tuple of positional arguments, +and `out_specs` corresponds to the output). ```{code-cell} :id: ehKAYAwIojfv :outputId: 504bab29-83f3-4e1f-8664-1860ad15b6de def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array: - block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512)) + block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0)) return pl.pallas_call( add_matrices_kernel, out_shape=x, in_specs=[block_spec, block_spec], out_specs=block_spec, - grid=(2,))(x, y) + grid=(2,) + )(x, y) add_matrices_pipelined(x, y) ``` +++ {"id": "rkytgIZYzz4t"} -We've only added a little bit of code to our original function to add automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy lifting! +We've only added a little bit of code to our original function to add +automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy +lifting! -How does it work? Well, the `BlockSpec`s provide enough information to start *prefetching* blocks of our input from HBM into VMEM. 
For example, if we are starting iteration `i` of our `grid`, we can pass `i + 1` into the `index_map` functions to obtain the blocks needed for the next iteration. We can then start an asynchronous copy for those blocks. Similarly for outputs, we can wait for the outputs of the previous iteration to be copied before starting the copy for the current iteration's outputs. +How does it work? Well, the `BlockSpec`s provide enough information to start +*prefetching* blocks of our input from HBM into VMEM. +For example, if we are starting iteration `i` of our `grid`, we can pass +`i + 1` into the `index_map` functions to obtain the blocks needed for the +next iteration. We can then start an asynchronous copy for those blocks. +Similarly for outputs, we can wait for the outputs of the previous iteration +to be copied before starting the copy for the current iteration's outputs. +++ {"id": "7Xtz9oMs0ZRL"} @@ -232,9 +279,15 @@ How does it work? Well, the `BlockSpec`s provide enough information to start *pr +++ {"id": "esY4GcIB0bqQ"} -It's common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do). +It's common to parameterize the block shapes in our kernel. Block sizes are +perhaps the most important parameter to tune when optimizing the performance +of Pallas kernels! They give us control over the pipeline (for example, +picking smaller blocks adds more iterations to our pipelined loop where each +iteration has less work to do). -Furthermore, we could also carve up the inputs and outputs along the 2nd dimension (we are only splitting along the first right now). Let's write a more general kernel that handles both of these features. +Furthermore, we could also carve up the inputs and outputs along the 2nd +dimension (we are only splitting along the first right now). Let's write a +more general kernel that handles both of these features. ```{code-cell} :id: VartelFd0YfY @@ -243,8 +296,7 @@ def add_matrices_pipelined_2d( x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256 ) -> jax.Array: m, n = x.shape - block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn)) - + block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j)) return pl.pallas_call( add_matrices_kernel, out_shape=x, @@ -253,7 +305,6 @@ def add_matrices_pipelined_2d( grid=(m // bm, n // bn), )(x, y) - np.testing.assert_array_equal( add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y ) @@ -271,9 +322,11 @@ np.testing.assert_array_equal( +++ {"id": "P3SqEKDe3Mar"} -How would you implement something like `jnp.sum` using `pallas_call`? Specifically, we'd like to pipeline across the reduction dimension. +How would you implement something like `jnp.sum` using `pallas_call`? +Specifically, we'd like to pipeline across the reduction dimension. -Take the example of reducing a `(8, 512, 512)`-shaped array to a `(512, 512)`-shaped one. +Take the example of reducing a `(8, 512, 512)`-shaped array to a +`(512, 512)`-shaped one. ```{code-cell} :id: JoT-ZKEk1R7l @@ -285,7 +338,10 @@ jnp.sum(x, axis=0) +++ {"id": "5O3ByvuT3iyC"} -To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration `i` load `x[i]` into VMEM. Then we could add `x[i]` to an output VMEM buffer. Let's implement this naively first. 
+To do this using `pallas_call`, we could use a grid of size `(8,)` and in +each iteration `i` load `x[i]` into VMEM. +Then we could add `x[i]` to an output VMEM buffer. Let's implement this +naively first. ```{code-cell} :id: hqvv_WRQ3bvP @@ -302,20 +358,38 @@ def naive_sum(x: jax.Array) -> jax.Array: naive_sum_kernel, grid=grid, # None in `block_shape` means we pick a size of 1 and squeeze it away - in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))], - out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape), - out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype) - )(x) + in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))], + out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)), + out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype), + )(x) naive_sum(x) ``` +++ {"id": "Kv9qJYJY4jbK"} -Notice how we've set up the `BlockSpec`s: we're loading the entirety of the `(512, 512)` dimension into VMEM (no pipelining there) but selecting the `i`-th dimension of `x` each iteration in the `index_map`. We are using a `None` for that dimension in the block shape, which indicates that we are selecting a singleton dimension from `x` that we would like squeeze away in the kernel. Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well. - -`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that `o_ref` is unchanged over the course of the pipeline. This means that we can update its value each iteration by reading from and writing to it. Or can it? Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll be accumulating into garbage. This will result in the overall function outputting the incorrect value! - -Therefore, **whenever we do a reduction in a kernel, we need to make sure to initialize the `Ref` that is storing the reduced value**. We can accomplish this by conditionally writing a value to `out_ref` when we're on iteration 0. We can do this with the helper function `pl.when`, a convenience wrapper around `jax.lax.cond`, and `pl.program_id`, which queries which iteration in a grid axis we are in. +Notice how we've set up the `BlockSpec`s: we're loading the entirety of +the `(512, 512)` dimension into VMEM (no pipelining there) but selecting +the `i`-th dimension of `x` each iteration in the `index_map`. +We are using a `None` for that dimension in the block shape, which indicates +that we are selecting a singleton dimension from `x` that we would like +to squeeze away in the kernel. +Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well. + +`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that +`o_ref` is unchanged over the course of the pipeline. +This means that we can update its value each iteration by reading from and +writing to it. Or can it? +Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll +be accumulating into garbage. +This will result in the overall function outputting the incorrect value! + +Therefore, **whenever we do a reduction in a kernel, we need to make sure +to initialize the `Ref` that is storing the reduced value**. +We can accomplish this by conditionally writing a value to `out_ref` +when we're on iteration 0. +We can do this with the helper function `pl.when`, a convenience wrapper +around `jax.lax.cond`, and `pl.program_id`, +which queries which iteration in a grid axis we are in. 
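As an aside on the helper itself: `pl.when(condition)` can be pictured as roughly the following decorator-style wrapper over `jax.lax.cond` with a no-op false branch (a simplified sketch for intuition, not Pallas's actual implementation):

```python
import jax


def when(condition):
  # Run the decorated body only when `condition` holds; otherwise do nothing.
  def _decorator(body_fn):
    jax.lax.cond(condition, body_fn, lambda: None)
  return _decorator


# The initialization idiom then looks like this inside a kernel body
# (hypothetical kernel fragment, written with the real pl.when helper):
#
#   @pl.when(pl.program_id(axis=0) == 0)
#   def _():
#     o_ref[...] = jnp.zeros_like(o_ref)
```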
```{code-cell} :id: JXN2RthX5cSw @@ -334,10 +408,11 @@ def sum(x: jax.Array) -> jax.Array: sum_kernel, grid=grid, # None in `block_shape` means we pick a size of 1 and squeeze it away - in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))], - out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape), + in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))], + out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)), out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype) - )(x) + )(x) + sum(x) ``` @@ -345,7 +420,16 @@ sum(x) This `sum` function now outputs the correct values! -One last thing to note about reductions in Pallas are that **they must be done in the minormost (rightmost) dimensions of our grid** (our grid is 1-dimensional in the above example so we are reducing over its minormost dimension). This is because the pipeline that Pallas generates using the `BlockSpec`s, `grid` and kernel function *does not read outputs back from HBM*. Once you've written an output value back to HBM you cannot revisit it. Therefore, you cannot do a reduction across a grid dimension that has any revisiting and therefore all reductions need to happen in the rightmost dimensions. +One last thing to note about reductions in Pallas are that **they must be +done in the minormost (rightmost) dimensions of our grid** (our grid is +1-dimensional in the above example so we are reducing over its minormost +dimension). This is because the pipeline that Pallas generates using +the `BlockSpec`s, `grid` and kernel function *does not read outputs back +from HBM*. +Once you've written an output value back to HBM you cannot revisit it. +Therefore, you cannot do a reduction across a grid dimension that has any +revisiting and therefore all reductions need to happen in the rightmost +dimensions. +++ {"id": "KvPFez9N8cKJ"} @@ -353,28 +437,36 @@ One last thing to note about reductions in Pallas are that **they must be done i +++ {"id": "0f4HAVzQ8n71"} -Some TPU chips have two TensorCores but appear as one device to JAX users. This is called "megacore". The separate TensorCores have their own separate VMEM, VREGs, SMEM, SREGs and compute units but *share HBM*. +Some TPU chips have two TensorCores but appear as one device to JAX users. +This is called "megacore". +The separate TensorCores have their own separate VMEM, VREGs, SMEM, SREGs +and compute units but *share HBM*. 
![TPU Memory Space Cartoon (Megacore).png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAcYAAAFdCAYAAACZw82pAABFq0lEQVR4Xu2dB3gU17n3TS/CuODYccpNrhPHSZxqO7YTOzfJTRzHyWcn33fja6NqgVClmF6NQfQqUYQoQggQiN6bKKIIAaJjg6gGgenN3dgGc775H3nWo3NWYnbZ1e7O/t/n+T3anZk9e87o/PadM/WOO2o2In7/3y+0eva/n1/xyyd++85/PPTDjx548NvX69dvcNOYJwjxltq1awv0pW9/93sf/+QXvy7/9VPPrP/Vb37byZj30B3OiIg//+OfqX947u+FRruOf++hhz+mO8QX1K5TR7rz3e899NHPfvXE8cd/+/tVjz35TJIx7+47GH6LiMef+n3vHzzyk5ONIyJuPvHb34tWbbuIIeOmioKVm8XyLQdE6bHLYs+pjwi5LdCX5q4pFZm5c0SnNweLv/z9X+K++7954+57m733H9//wRSjLz6mds4gj4in/+u/+z3845+dpjvEX+w++aHsS7NXbREjJs4UaZ17i9/94S8iokmTL7//0MMnHv3l472MvthY7ZwM7+L+n/z8sdlN7mz6xR//+g8xbPx0UXr0kvZPIcTfQPjE17uJbzzw4LV77m1WZvTNfxvUUjtssERERMQDv3z8qYV3Nr3rOt0hgQIbXUiUf37hJWH0xc9/+vNf5Rrds6naXxn2ov5DP3gk00iI11+NTxYrt5ZpK5yQQLCr/AMxclKBePjHj37YsHHEfqOv/l7tvAGO+o/87Bfj6A4JNhZv3Cv++b/R4q57m10zBjzt1Y7LqCbqNW78xD3NvnH5D8/9/UsMzdWVS0iwMGhM7o2md939fu3atccZXbeR2pdrOh544NtPN7vvgauGOzfpDglWpi1eLx559Bc3f/jjR9+62wi1HzOUePDb3+1h/NB8gd0+6sokJBjZXHZW/Ndzf79Qt169w0YXfljt0zUV3//hj9686+57rtMdEgrsOH5VNI9PwfH7T55++tmn1P7MqIha93/zwakPPfzjz5eV7NdWIiHBTrf+wy/XqVv3otGXn1Q7t5+j1ne++/2ZP3zkp9fpDgk1Bo6eLO6+t9n1Z/703Etqxw73qNXsvvsX//o3v/2i+MAZbcUREipkTp71gZEcrxh9+jdqJ/dT1Prmg99Z8fhTz9ygOyRUyZq+UNx1z703/vjXF/6ldvCwjbubfWPyL594+jOeMUecwMhJBVdq16lzwejaP1T7uq/jgW9/Zzo2KOkOCXWQHDFy/P2fnntG7edhF40aRbT/zx/+6Bq3domT6PzmkDO1a9c+iC6u9nlfxT3NmnXBoQe6Q5wCdqve/81vXfvhr3/9DbW/h1M8HtHkzk95XIQ4kd8888djRh8fq3Z6X0S9evWeuLPpXZ/RHeI0YhLbih/99OdwJ2ivEfZn1I9o0uRdnkFHnArOVm3UOAIn4zyrdv7bjPpN77rnHN0hTmTn8ffEr554Wjzx9DMj1Y7v+GjYuHHvZ/703CfqSiHESaSPnHCpTp06O+/w4dbvPc2a9f/9X174XP0uQpzCwvW7RdO77r5+94MPfk/t/06O+xs2avzJiq28AJk4G9w/8lvf+Y93jT7/P6oEXsb9jSMirtEd4nRSO/YSP3jkJztUARwbjSIiRrwc0/KauiIIcSLDsqd9VLdu3S2qB97E3fc0G/2/sYnX1e8gxGlsO3JRNPvGAzca33VXqN2036to0qBho495/0YSLuw68b5ocmfTy0bf/7Uqg4fRBHta6A4JF7r1GyG++72HwmLUGPv07//7fXUFEOJkolqmnTf6/mBVBk+iTp06cc/+6a+fqmUT4lQwajQ2Kj83uv9/qD44Ku5s2nTL8PH52gogxMngkVX16zc4qfrgSTS77xs76Q4JN5rHp9y89777R6g+OCki6tWr/xm2AtTGk8CA/8Wmt0+JnSfe0+YR39KkyZ3vGw78pyqFzYgwEuvndCc42H7ssvRmxztXtHm+ouTgWbH5wGlteriBh2nf2fSuc6oQTornf/KzX15RGx5u9Bs8QnTs3FW0btNWtGvfQb4eM6lmr0nbVf6+GDUhT3Tt8YboO3CY6NSlm5i+YJW23O2yducR0a1nb7G8eK82L9x48tk/Qu5IVQqb8fyjv3ziQ7XMcGJIRpZ0pU3bdqJtu9fl6xFjJ2nL+ROcZTxuykzRpXsvlze5sxZry90uYyZNEz3e6Ct6vtlPDByWKb9XXSZcQNvvuvueTwwHfqBK4ZTo27J1J56N+hXZeQVi9vINlaZtO3xBFO9/V1sWIzpsoVoF2V3+gdvR3vZjV+R0JD+1nE1vn5RlTJg+V8r39WcuS9mRyKzLbjXqo5aBaVger8167Tx+VVtu7Y7D4o30AaJ3+kAmRoO23dOvGQ5kqFLYifoNGw5o1a7rTbXMcASJaPr8lZWmYSRdbPRDdVk4IL0xXDGnof+767MY/bnzCcBJlJE3d5kYPmaiy0N8Bslr+eZ9lZbdeui8VgbqaI74TXfhqrrcxn3lskzz/Rt9+4vCbeF9wtXfXnr5A0ODJNULR0Tjxk0WD87K0xodrqiJcWL+PClE/yEjZTLBjaHX7zkuur/RR7zZf7BIHzRc9OjdVwoNeTASG5o5TnTu1kMs27RHljF35SYjwfUU/YwyMB/JCdMxMhw4fJSUDJ/v2LmbUX5FcjNZXVomNr51Uu7CwXL4PpSBemE+RpR9Bw6VdZy5ZJ1Ys+OQrBvKxXcuXFdaqbwtB8/JH5mho7KZGA0yc+eIho0jVqte2Il7m32jkO5UoCbGvDlLZT8cMDRD9s0th85JsKHXZ8AQuYcGr7Exh12TWBbeyD67dpssA33X9AaurNpScas9eIgRG3zYYngBH/DXWh84Bk+R9PB9GEniO7Imz5Dz4ST87dk7XUydt1wuC49NbzBfbaN1A7i/0S6zPuFKt34jvqxTp84E1QtHRESTJkewv1htdLhiTYxIMpDHnDdh2hz5AwCJXm/f0bV1iySzuGi7/GHAMpiGZLZ04y4pJnbtmFulq7YekCM2vIbsZuLCViqSqVofk8zxuWLm4rXyNbZsUQaSJhIjfjhMaTF93a6j8jW2kvHdalmAibGCuWtKReOIiBOqF3birrvvOUF3KrAmRvR9JBxzRIgRHbxCYkxr3drlAg4bzFmx0fBtvRg9caqchiQJJ7Dxhr5bUlZxM/ai3cdkAsRrJEZ8xvzu9h07a/UxwfdOmb1UvoYjcAVeIvHBFbOOmL6i5C35GiNLlGkd0VpBO+Cqu1FsODF57ircXnGX6oUjon7DRlfW7Kz4ISWVEyOSDkZx2OIEEHPkuByZGNONLVDrZyAapIZsWBYJEoIh+SAJWb+jQ6cuUjokRnOEiN1LmK7WxwTfbd2Fav4QoY6TCxbJaSgzJSXVVV+QnJzidvctE2MFRXtO4MxUnIDjcRgJ9X26U4E1McIFJBazD2KkNnjkGJlQcIzO+pn8hYWyX2M5JLzsvFnSI4
[... base64-encoded PNG image data elided ...])

-Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have only two threads. How do we modify our kernels to utilize both TensorCores simultaneously?
+Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have
+only two threads.
+How do we modify our kernels to utilize both TensorCores simultaneously?

-The basic idea is that if we have embarassingly parallel dimensions in our computation, we can split up those dimensions across the TensorCores. We can indicate which dimensions are parallelizable by providing an annotation to `pallas_call` called `dimension_semantics`.
+The basic idea is that if we have embarrassingly parallel dimensions in our
+computation, we can split up those dimensions across the TensorCores.
+We can indicate which dimensions are parallelizable by providing an
+annotation to `pallas_call` called `dimension_semantics`.

```{code-cell}
:id: nQNa8RaQ-TR1
:outputId: 385ed87c-d95c-466c-af77-df3845c979f2

def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
-  block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
+  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
   return pl.pallas_call(
       add_matrices_kernel,
       out_shape=x,
       in_specs=[block_spec, block_spec],
       out_specs=block_spec,
       grid=(2,),
-      compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))))(
-          x, y)
+      compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",)))
+  )(x, y)

x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
@@ -382,9 +474,12 @@ add_matrices_pipelined_megacore(x, y)

+++ {"id": "xG51AiUC-8cl"}

-`dimension_semantics` should be a tuple of same length as `grid` where each entry is either `"parallel"` or `"arbitrary"`. `"parallel"` indicates to Pallas that the iterations of the for loop corresponding to that dimension can be executed independently without affecting the correctness of the program. `"arbitrary"` indicates to Pallas that there can be no assumptions made about this grid dimension and it therefore cannot be parallelized.
+`dimension_semantics` should be a tuple of the same length as `grid` where each
+entry is either `"parallel"` or `"arbitrary"`. `"parallel"` indicates to Pallas that the iterations of the for loop corresponding to that dimension can be executed independently without affecting the correctness of the program. `"arbitrary"` indicates to Pallas that there can be no assumptions made about this grid dimension and it therefore cannot be parallelized.

-By specifying `dimension_semantics`, we now execute the kernel simultaneously on each TensorCore. Pallas will handle splitting up the grid automatically.
+By specifying `dimension_semantics`, we now execute the kernel
+simultaneously on each TensorCore. Pallas will handle splitting up the grid
+automatically.

> Note that Megacore is only currently available on TPU `v4` and TPU `v5p`. Supplying `dimension_semantics` annotations is a no-op on other platforms, but *not* specifying it will result in only one TensorCore being used (even if there are more than one available).

@@ -392,7 +487,11 @@ By specifying `dimension_semantics`, we now execute the kernel simultaneously on

## Conclusion

-In this guide we covered how to express TPU pipelines using `pallas_call`, `grid` and `BlockSpec`s. We covered how to express nested loops via a multi-dimensional grid and how to handle reductions by initialize our accumulators at the beginning of the reduction. We also learned how to handle Megacore by adding annotations to the kernel.
+In this guide we covered how to express TPU pipelines using `pallas_call`,
+`grid` and `BlockSpec`s. We covered how to express nested loops via a
+multi-dimensional grid and how to handle reductions by initializing our
+accumulators at the beginning of the reduction.
+We also learned how to handle Megacore by adding annotations to the kernel.

Exercises left to the reader:
* Try implementing a `sum` kernel that pipelines the other dimensions as well
diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md
new file mode 100644
index 000000000000..2f748825af1f
--- /dev/null
+++ b/docs/persistent_compilation_cache.md
@@ -0,0 +1,76 @@
+# Persistent Compilation Cache
+
+
+
+JAX has an optional disk cache for compiled programs.
If enabled, JAX will +store copies of compiled programs on disk, which can save recompilation time +when running the same or similar tasks repeatedly. + +## Usage + +The compilation cache is enabled when the +[cache-location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206) +is set. This should be done prior to the first compilation. Set the location as +follows: + +``` +import jax + +# Make sure this is called before jax runs any operations! +jax.config.update("jax_compilation_cache_dir", "cache-location") +``` + +See the sections below for more detail on `cache-location`. + +[`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18) +is an alternate way of setting `cache-location`. + +### Local filesystem + +`cache-location` can be a directory on the local filesystem. For example: + +``` +import jax + +jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache") +``` + +Note: the cache does not have an eviction mechanism implemented. If the +cache-location is a directory in the local filesystem, its size will continue +to grow unless files are manually deleted. + +### Google Cloud + +When running on Google Cloud, the compilation cache can be placed on a Google +Cloud Storage (GCS) bucket. We recommend the following configuration: + +* Create the bucket in the same region as where the workload will run. + +* Create the bucket in the same project as the workload’s VM(s). Ensure that + permissions are set so that the VM(s) can write to the bucket. + +* There is no need for replication for smaller workloads. Larger workloads + could benefit from replication. + +* Use “Standard” for the default storage class for the bucket. + +* Set the soft delete policy to its shortest: 7 days. + +* Set the object lifecycle to the expected duration of the workload run. + For example, if the workload is expected to run for 10 days, set the object + lifecycle to 10 days. That should cover restarts that occur during the entire + run. Use `age` for the lifecycle condition and `Delete` for the action. See + [Object Lifecycle Management](https://cloud.google.com/storage/docs/lifecycle) + for details. If the object lifecycle is not set, the cache will continue to + grow since there is no eviction mechanism implemented. + +* All encryption policies are supported. + +Assuming that `gs://jax-cache` is the GCS bucket, set `cache-location` as +follows: + +``` +import jax + +jax.config.update("jax_compilation_cache_dir", "gs://jax-cache") +``` diff --git a/docs/profiling.md b/docs/profiling.md index 86b539c0a7fb..6eceec8f54b8 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -1,5 +1,7 @@ # Profiling JAX programs + + ## Viewing program traces with Perfetto We can use the JAX profiler to generate traces of a JAX program that can be @@ -11,7 +13,7 @@ check out the Tensorboard profiler below. ```python with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): # Run the operations to be profiled - key = jax.random.PRNGKey(0) + key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() @@ -50,7 +52,7 @@ active for a portion of your script, you can shut it down by calling `jax.profiler.stop_server()`. 
Once the script is running and after the profiler server has started, we can
-manually capture an trace by running:
+manually capture a trace by running:
```bash
$ python -m jax.collect_profile
```
@@ -107,7 +109,7 @@ import jax
jax.profiler.start_trace("/tmp/tensorboard")
# Run the operations to be profiled
-key = jax.random.PRNGKey(0)
+key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
@@ -126,7 +128,7 @@ alternative to `start_trace` and `stop_trace`:
import jax
with jax.profiler.trace("/tmp/tensorboard"):
-  key = jax.random.PRNGKey(0)
+  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()
@@ -216,16 +218,6 @@ from a running program. You can also use the `memory_viewer`, `op_profile`, and
`graph_viewer` tools.
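As a complement to the `jax.collect_profile` workflow shown earlier in this hunk, here is a minimal sketch of what the script side of that workflow might look like; the port number (9999) and the toy matmul workload are illustrative choices, not part of the original document.

```python
import jax
import jax.numpy as jnp

# Start the profiler server early, so an external `python -m jax.collect_profile`
# (or the TensorBoard profiler) can attach while the program is running.
jax.profiler.start_server(9999)

key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
for _ in range(10):
  x = (x @ x) / 5000.0  # some repeated work to capture in the trace
x.block_until_ready()

# Shut the server down once the region of interest has finished.
jax.profiler.stop_server()
```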

-### Concurrent kernel tracing on GPU
-
-By default, traces captured on GPU in a mode that prevents CUDA kernels from
-running concurrently. This allows for more accurate kernel timings, but removes
-any concurrency between streams (for example, between compute and
-communication). To enable concurrent kernel tracing, set the environment
-variable `TF_GPU_CUPTI_FORCE_CONCURRENT_KERNEL=1` when launching the JAX
-program.
-
-
### Adding custom trace events

By default, the events in the trace viewer are mostly low-level internal JAX
diff --git a/docs/pytrees.md b/docs/pytrees.md
index f7c491a10c95..a39c36db5de6 100644
--- a/docs/pytrees.md
+++ b/docs/pytrees.md
@@ -16,6 +16,8 @@ language_info:

# Pytrees

+
+
## What is a pytree?

In JAX, we use the term *pytree* to refer to a tree-like structure built out of
@@ -280,7 +282,7 @@ class RegisteredSpecial2(Special):
show_example(RegisteredSpecial2(1., 2.))
```

-When defining an unflattening functions, in general `children` should contain all the
+When defining unflattening functions, in general `children` should contain all the
dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while `aux_data` should contain all the static elements that will be rolled into the `treedef` structure. JAX sometimes needs to compare `treedef` for equality, or compute its hash
diff --git a/docs/tutorials/quickstart.md b/docs/quickstart.md
similarity index 59%
rename from docs/tutorials/quickstart.md
rename to docs/quickstart.md
index 13f923969842..91ac5a63be20 100644
--- a/docs/tutorials/quickstart.md
+++ b/docs/quickstart.md
@@ -5,7 +5,7 @@ jupytext:
    extension: .md
    format_name: myst
    format_version: 0.13
-    jupytext_version: 1.16.0
+    jupytext_version: 1.16.1
kernelspec:
  display_name: Python 3
  language: python
@@ -14,7 +14,9 @@ kernelspec:

# Quickstart

-**JAX a library for array-oriented numerical computation (*a la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
+
+
+**JAX is a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.

This document provides a quick overview of essential JAX features, so you can get started with JAX quickly:

@@ -27,9 +29,13 @@ This document provides a quick overview of essential JAX features, so you can ge
JAX can be installed for CPU on Linux, Windows, and macOS directly from the
[Python Package Index](https://pypi.org/project/jax/):
```
-pip install "jax[cpu]"
+pip install jax
+```
+or, for NVIDIA GPU:
```
-For more detailed installation information, including installation with GPU support, check out {ref}`installation`.
+pip install -U "jax[cuda12]"
+```
+For more detailed platform-specific installation information, check out {ref}`installation`.

## JAX as NumPy

@@ -54,8 +60,8 @@ print(selu(x))

You'll find a few differences between JAX arrays and NumPy arrays once you begin digging-in; these are explored in [🔪 JAX - The Sharp Bits 🔪](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).

-## Using {func}`~jax.jit` to speed up functions
-JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above example, JAX is dispatching kernels to the chip one operation at a time. If we have a sequence of operations, we can use the `@jit` decorator to compile multiple operations together using XLA.
+## Just-in-time compilation with {func}`jax.jit`
+JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above example, JAX is dispatching kernels to the chip one operation at a time. If we have a sequence of operations, we can use the {func}`jax.jit` function to compile this sequence of operations together using XLA.

We can use IPython's `%timeit` to quickly benchmark our `selu` function, using `block_until_ready()` to account for JAX's dynamic dispatch (See {ref}`async-dispatch`):

@@ -71,8 +77,8 @@ x = random.normal(key, (1_000_000,))
(notice we've used {mod}`jax.random` to generate some random numbers; for details on how to generate random numbers in JAX, check out {ref}`pseudorandom-numbers`).

-We can speed the execution of this function with `@jit`, which will jit-compile the
-first time `selu` is called and will be cached thereafter.
+We can speed the execution of this function with the {func}`jax.jit` transformation,
+which will jit-compile the first time `selu` is called and will be cached thereafter.

```{code-cell}
from jax import jit

selu_jit = jit(selu)
_ = selu_jit(x)  # compiles on first call
%timeit selu_jit(x).block_until_ready()
```

-The above timing represent execution on CPU, but the same code can be run on GPU or TPU for
-an even greater speedup.
+The above timing represents execution on CPU, but the same code can be run on GPU or TPU,
+typically for an even greater speedup.

For more on JIT compilation in JAX, check out {ref}`jit-compilation`.

-## Taking derivatives with {func}`~jax.grad`
+## Taking derivatives with {func}`jax.grad`

-In addition to evaluating numerical functions, we can also to transform them.
-One transformation is [automatic differentiation (autodiff)](https://en.wikipedia.org/wiki/Automatic_differentiation).
-In JAX, you can compute gradients with the {func}`~jax.grad` function.
+In addition to transforming functions via JIT compilation, JAX also provides other
+transformations. One such transformation is {func}`jax.grad`, which performs
+[automatic differentiation (autodiff)](https://en.wikipedia.org/wiki/Automatic_differentiation):

```{code-cell}
from jax import grad
@@ -121,8 +127,18 @@ In the above example we jitted `sum_logistic` and then took its derivative. We c
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
```

-For more advanced autodiff, you can use {func}`jax.vjp` for reverse-mode vector-Jacobian products and {func}`jax.jvp` for forward-mode Jacobian-vector products.
+Beyond scalar-valued functions, the {func}`jax.jacobian` transformation can be
+used to compute the full Jacobian matrix for vector-valued functions:
+
+```{code-cell}
+from jax import jacobian
+print(jacobian(jnp.exp)(x_small))
+```
+
+For more advanced autodiff operations, you can use {func}`jax.vjp` for reverse-mode vector-Jacobian products,
+and {func}`jax.jvp` and {func}`jax.linearize` for forward-mode Jacobian-vector products.
The two can be composed arbitrarily with one another, and with other JAX transformations.
+For example, {func}`jax.jvp` and {func}`jax.vjp` are used to define the forward-mode {func}`jax.jacfwd` and reverse-mode {func}`jax.jacrev` for computing Jacobians in forward- and reverse-mode, respectively.
Here's one way to compose them to make a function that efficiently computes full Hessian matrices:

```{code-cell}
@@ -136,24 +152,28 @@ This kind of composition produces efficient code in practice; this is more-or-le
For more on automatic differentiation in JAX, check out {ref}`automatic-differentiation`.
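The Hessian code-cell referenced just above is collapsed by the diff context. For readers of this patch, here is a minimal, self-contained sketch of the kind of composition the quickstart describes; `sum_logistic` and `x_small` are re-defined locally so the snippet runs on its own, and the exact contents of the collapsed cell may differ.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev, jit

def hessian(f):
  # Reverse-mode for the gradient, forward-mode to differentiate the gradient:
  # an efficient composition for scalar-output functions of many inputs.
  return jit(jacfwd(jacrev(f)))

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.0)
print(hessian(sum_logistic)(x_small))  # prints a 3x3 Hessian matrix
```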
-## Auto-vectorization with {func}`~jax.vmap`
+## Auto-vectorization with {func}`jax.vmap`

Another useful transformation is {func}`~jax.vmap`, the vectorizing map.
-It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance.
-When composed with {func}`~jax.jit`, it can be just as fast as adding the batch dimensions manually.
+It has the familiar semantics of mapping a function along array axes, but instead of explicitly looping
+over function calls, it transforms the function into a natively vectorized version for better performance.
+When composed with {func}`~jax.jit`, it can be just as performant as manually rewriting your function to
+operate over an extra batch dimension.

We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using {func}`~jax.vmap`. Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.

```{code-cell}
-mat = random.normal(key, (150, 100))
-batched_x = random.normal(key, (10, 100))
+key1, key2 = random.split(key)
+mat = random.normal(key1, (150, 100))
+batched_x = random.normal(key2, (10, 100))

-def apply_matrix(v):
-  return jnp.dot(mat, v)
+def apply_matrix(x):
+  return jnp.dot(mat, x)
```

-Given a function such as `apply_matrix`, we can loop over a batch dimension in Python, but usually the performance of doing so is poor.
+The `apply_matrix` function maps a vector to a vector, but we may want to apply it row-wise across a matrix.
+We could do this by looping over the batch dimension in Python, but this usually results in poor performance.

```{code-cell}
def naively_batched_apply_matrix(v_batched):
@@ -163,32 +183,40 @@ print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
```

-We know how to batch this operation manually.
-In this case, `jnp.dot` handles extra batch dimensions transparently.
+A programmer familiar with the `jnp.dot` function might recognize that `apply_matrix` can
+be rewritten to avoid explicit looping, using the built-in batching semantics of `jnp.dot`:

```{code-cell}
+import numpy as np
+
@jit
-def batched_apply_matrix(v_batched):
-  return jnp.dot(v_batched, mat.T)
+def batched_apply_matrix(batched_x):
+  return jnp.dot(batched_x, mat.T)

+np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
+                           batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
```

-However, suppose we had a more complicated function without batching support. We can use {func}`~jax.vmap` to add batching support automatically.
+However, as functions become more complicated, this kind of manual batching becomes more difficult and error-prone.
+The {func}`~jax.vmap` transformation is designed to automatically transform a function into a batch-aware version: ```{code-cell} from jax import vmap @jit -def vmap_batched_apply_matrix(v_batched): - return vmap(apply_matrix)(v_batched) +def vmap_batched_apply_matrix(batched_x): + return vmap(apply_matrix)(batched_x) +np.testing.assert_allclose(naively_batched_apply_matrix(batched_x), + vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4) print('Auto-vectorized with vmap') %timeit vmap_batched_apply_matrix(batched_x).block_until_ready() ``` -Of course, {func}`~jax.vmap` can be arbitrarily composed with {func}`~jax.jit`, {func}`~jax.grad`, and any other JAX transformation. +As you would expect, {func}`~jax.vmap` can be arbitrarily composed with {func}`~jax.jit`, +{func}`~jax.grad`, and any other JAX transformation. For more on automatic vectorization in JAX, check out {ref}`automatic-vectorization`. diff --git a/docs/tutorials/random-numbers.md b/docs/random-numbers.md similarity index 92% rename from docs/tutorials/random-numbers.md rename to docs/random-numbers.md index f41c1caf5c26..85bb5ce01974 100644 --- a/docs/tutorials/random-numbers.md +++ b/docs/random-numbers.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python @@ -15,6 +15,8 @@ kernelspec: (pseudorandom-numbers)= # Pseudorandom numbers + + In this section we focus on {mod}`jax.random` and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution. PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next. @@ -26,7 +28,7 @@ To better understand the difference between the approaches taken by JAX and NumP ## Random numbers in NumPy Pseudo random number generation is natively supported in NumPy by the {mod}`numpy.random` module. -In NumPy, pseudo random number generation is based on a global `state`, which can be set to a deterministic initial condition using {func}`np.random.seed`. +In NumPy, pseudo random number generation is based on a global `state`, which can be set to a deterministic initial condition using {func}`numpy.random.seed`. ```{code-cell} import numpy as np @@ -192,4 +194,17 @@ key = random.key(42) print("all at once: ", random.normal(key, shape=(3,))) ``` -Note that contrary to our recommendation above, we use `key` directly as an input to {func}`random.normal` in the second example. This is because we won't reuse it anywhere else, so we don't violate the single-use principle. +The lack of sequential equivalence gives us freedom to write code more efficiently; for example, +instead of generating `sequence` above via a sequential loop, we can use {func}`jax.vmap` to +compute the same result in a vectorized manner: + +```{code-cell} +import jax +print("vectorized:", jax.vmap(random.normal)(subkeys)) +``` + +## Next Steps + +For more information on JAX random numbers, refer to the documentation of the {mod}`jax.random` +module. If you're interested in the details of the design of JAX's random number generator, +see {ref}`prng-design-jep`. 
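As an aside for readers of this patch: the `subkeys` array used in the vectorized example above is created earlier in the tutorial with {func}`jax.random.split` (not shown in this hunk). A self-contained sketch of the same idea, using the tutorial's seed of 42 and an illustrative split count of three, looks like this:

```python
import jax
from jax import random

key = random.key(42)
subkeys = random.split(key, 3)  # three statistically independent subkeys

# Sequential draws, one per subkey ...
sequence = [random.normal(k) for k in subkeys]
print("sequential:", sequence)

# ... and the same per-key values computed in a single vectorized call.
print("vectorized:", jax.vmap(random.normal)(subkeys))
```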
diff --git a/docs/rank_promotion_warning.rst b/docs/rank_promotion_warning.rst index e81509e2a941..5e4e7ec65cbc 100644 --- a/docs/rank_promotion_warning.rst +++ b/docs/rank_promotion_warning.rst @@ -40,8 +40,8 @@ One is by using :code:`jax.config` in your code: .. code-block:: python - from jax import config - config.update("jax_numpy_rank_promotion", "warn") + import jax + jax.config.update("jax_numpy_rank_promotion", "warn") You can also set the option using the environment variable :code:`JAX_NUMPY_RANK_PROMOTION`, for example as diff --git a/docs/requirements.txt b/docs/requirements.txt index 104a7e06d9d3..643d4086d8be 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,11 +1,11 @@ absl-py ipython>=8.8.0 # 8.7.0 has ipython3 lexer error -sphinx>=6.0.0 -sphinx-autodoc-typehints +sphinx>=7.3.2 # 7.3.0 breaks sphinx-book-theme sphinx-book-theme>=1.0.1 # Older versions fail to pin pydata-sphinx-theme sphinx-copybutton>=0.5.0 sphinx-remove-toctrees sphinx-design +sphinxext-rediraffe myst-nb>=1.0.0 # Packages used for CI tests. diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb new file mode 100644 index 000000000000..8fa2107795fd --- /dev/null +++ b/docs/sharded-computation.ipynb @@ -0,0 +1,779 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(sharded-computation)=\n", + "# Introduction to sharded computation\n", + "\n", + "\n", + "\n", + "This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.\n", + "\n", + "The tutorial covers three modes of parallel computation:\n", + "\n", + "- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. \"the compiler takes the wheel\").\n", + "- _Semi-automated parallelism_ using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`\n", + "- _Fully manual parallelism with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives\n", + "\n", + "Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices.\n", + "\n", + "If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "outputId": "18905ae4-7b5e-4bb9-acb4-d8ab914cb456" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", + " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", + " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", + " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", + " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", + " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", + " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", + " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import jax\n", + "jax.devices()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Key concept: Data sharding\n", + "\n", + "Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.\n", + "\n", + "How can JAX can understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.\n", + "\n", + "In the simplest cases, arrays are sharded on a single device, as demonstrated below:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "outputId": "39fdbb79-d5c0-4ea6-8b20-88b2c502a27a" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import jax.numpy as jnp\n", + "arr = jnp.arange(32.0).reshape(4, 8)\n", + "arr.devices()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "outputId": "536f773a-7ef4-4526-c58b-ab4d486bf5a1" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "arr.sharding" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For a more visual representation of the storage layout, the {mod}`jax.debug` module provides some helpers to visualize the sharding of an array. For example, {func}`jax.debug.visualize_array_sharding` displays how the array is stored in memory of a single device:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "outputId": "74a793e9-b13b-4d07-d8ec-7e25c547036d" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
                                                  \n",
+       "                                                  \n",
+       "                                                  \n",
+       "                                                  \n",
+       "                                                  \n",
+       "                      TPU 0                       \n",
+       "                                                  \n",
+       "                                                  \n",
+       "                                                  \n",
+       "                                                  \n",
+       "                                                  \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "jax.debug.visualize_array_sharding(arr)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To create an array with a non-trivial sharding, you can define a {mod}`jax.sharding` specification for the array and pass this to {func}`jax.device_put`.\n", + "\n", + "Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimensional grid of devices with named axes, where {class}`jax.sharding.Mesh` allows for precise device placement:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "outputId": "0b397dba-3ddc-4aca-f002-2beab7e6b8a5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))\n" + ] + } + ], + "source": [ + "# Pardon the boilerplate; constructing a sharding will become easier in future!\n", + "from jax.sharding import Mesh\n", + "from jax.sharding import PartitionSpec\n", + "from jax.sharding import NamedSharding\n", + "from jax.experimental import mesh_utils\n", + "\n", + "P = jax.sharding.PartitionSpec\n", + "devices = mesh_utils.create_device_mesh((2, 4))\n", + "mesh = jax.sharding.Mesh(devices, ('x', 'y'))\n", + "sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))\n", + "print(sharding)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Passing this `Sharding` object to {func}`jax.device_put`, you can obtain a sharded array:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "outputId": "c8ceedba-05ca-4156-e6e4-1e98bb664a66" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n", + " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", + " [16. 17. 18. 19. 20. 21. 22. 23.]\n", + " [24. 25. 26. 27. 28. 29. 30. 31.]]\n" + ] + }, + { + "data": { + "text/html": [ + "
                                                \n",
+       "                                                \n",
+       "   TPU 0       TPU 1       TPU 2       TPU 3    \n",
+       "                                                \n",
+       "                                                \n",
+       "                                                \n",
+       "                                                \n",
+       "                                                \n",
+       "   TPU 6       TPU 7       TPU 4       TPU 5    \n",
+       "                                                \n",
+       "                                                \n",
+       "                                                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "arr_sharded = jax.device_put(arr, sharding)\n", + "\n", + "print(arr_sharded)\n", + "jax.debug.visualize_array_sharding(arr_sharded)" + ] + }, + { 
+ "cell_type": "markdown", + "metadata": {}, + "source": [ + "The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n", + "\n", + "## 1. Automatic parallelism via `jit`\n", + "\n", + "Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.\n", + "\n", + "The XLA compiler behind `jit` includes heuristics for optimizing computations across multiple devices.\n", + "In the simplest of cases, those heuristics boil down to *computation follows data*.\n", + "\n", + "To demonstrate how auto-parallelization works in JAX, below is an example that uses a {func}`jax.jit`-decorated staged-out function: it's a simple element-wise function, where the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "outputId": "de46f86a-6907-49c8-f36c-ed835e78bc3d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shardings match: True\n" + ] + } + ], + "source": [ + "@jax.jit\n", + "def f_elementwise(x):\n", + " return 2 * jnp.sin(x) + 1\n", + "\n", + "result = f_elementwise(arr_sharded)\n", + "\n", + "print(\"shardings match:\", result.sharding == arr_sharded.sharding)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data.\n", + "\n", + "Here, you sum along the leading axis of `x`, and visualize how the result values are stored across multiple devices (with {func}`jax.debug.visualize_array_sharding`):" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "outputId": "90c3b997-3653-4a7b-c8ff-12a270f11d02" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
 TPU 0,6  TPU 1,7  TPU 2,4  TPU 3,5 \n",
+       "                                    \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0,6\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 1,7\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mTPU 2,4\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mTPU 3,5\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[48. 52. 56. 60. 64. 68. 72. 76.]\n" + ] + } + ], + "source": [ + "@jax.jit\n", + "def f_contract(x):\n", + " return x.sum(axis=0)\n", + "\n", + "result = f_contract(arr_sharded)\n", + "jax.debug.visualize_array_sharding(result)\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n", + "\n", + "## 2. Semi-automated sharding with constraints\n", + "\n", + "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n", + "\n", + "For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "outputId": "8468f5c6-76ca-4367-c9f2-93c723687cfd" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
  TPU 0    TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4    TPU 5  \n",
+       "                                                                        \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[48. 52. 56. 60. 64. 68. 72. 76.]\n" + ] + } + ], + "source": [ + "@jax.jit\n", + "def f_contract_2(x):\n", + " out = x.sum(axis=0)\n", + " # mesh = jax.create_mesh((8,), 'x')\n", + " devices = mesh_utils.create_device_mesh(8)\n", + " mesh = jax.sharding.Mesh(devices, 'x')\n", + " sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", + " return jax.lax.with_sharding_constraint(out, sharding)\n", + "\n", + "result = f_contract_2(arr_sharded)\n", + "jax.debug.visualize_array_sharding(result)\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This gives you a function with the particular output sharding you'd like.\n", + "\n", + "## 3. Manual parallelism with `shard_map`\n", + "\n", + "In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.experimental.shard_map.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.\n", + "\n", + "`shard_map` works by mapping a function across a particular *mesh* of devices (`shard_map` maps over shards). In the example below:\n", + "\n", + "- As before, {class}`jax.sharding.Mesh` allows for precise device placement, with the axis names parameter for logical and physical axis names.\n", + "- The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together.\n", + "\n", + "**Note:** {func}`jax.experimental.shard_map.shard_map` code can work inside {func}`jax.jit` if you need it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "outputId": "435c32f3-557a-4676-c11b-17e6bab8c1e2" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([ 1. , 2.682942 , 2.818595 , 1.28224 , -0.513605 ,\n", + " -0.9178486 , 0.44116896, 2.3139732 , 2.9787164 , 1.824237 ,\n", + " -0.08804226, -0.99998045, -0.07314599, 1.8403342 , 2.9812148 ,\n", + " 2.3005757 , 0.42419332, -0.92279506, -0.50197446, 1.2997544 ,\n", + " 2.8258905 , 2.6733112 , 0.98229736, -0.69244075, -0.81115675,\n", + " 0.7352965 , 2.525117 , 2.912752 , 1.5418116 , -0.32726777,\n", + " -0.97606325, 0.19192469], dtype=float32)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from jax.experimental.shard_map import shard_map\n", + "P = jax.sharding.PartitionSpec\n", + "mesh = jax.sharding.Mesh(jax.devices(), 'x')\n", + "\n", + "f_elementwise_sharded = shard_map(\n", + " f_elementwise,\n", + " mesh=mesh,\n", + " in_specs=P('x'),\n", + " out_specs=P('x'))\n", + "\n", + "arr = jnp.arange(32)\n", + "f_elementwise_sharded(arr)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The function you write only \"sees\" a single batch of the data, which you can check by printing the device local shape:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "outputId": "99a3dc6e-154a-4ef6-8eaa-3dd0b68fb1da" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "global shape: x.shape=(32,)\n", + "device local shape: x.shape=(4,)\n" + ] + } + ], + "source": [ + "x = jnp.arange(32)\n", + "print(f\"global shape: {x.shape=}\")\n", + "\n", + "def f(x):\n", + " print(f\"device local shape: {x.shape=}\")\n", + " return x * 2\n", + "\n", + "y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Because each of your functions only \"sees\" the device-local part of the data, it means that aggregation-like functions require some extra thought.\n", + "\n", + "For example, here's what a `shard_map` of a {func}`jax.numpy.sum` looks like:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "outputId": "1e9a45f5-5418-4246-c75b-f9bc6dcbbe72" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([ 6, 22, 38, 54, 70, 86, 102, 118], dtype=int32)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def f(x):\n", + " return jnp.sum(x, keepdims=True)\n", + "\n", + "shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Your function `f` operates separately on each shard, and the resulting summation reflects this.\n", + "\n", + "If you want to sum across shards, you need to explicitly request it using collective operations like {func}`jax.lax.psum`:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "outputId": "4fd29e80-4fee-42b7-ff80-29f9887ab38d" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(496, dtype=int32)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def f(x):\n", + " sum_in_shard = x.sum()\n", + " return jax.lax.psum(sum_in_shard, 'x')\n", + "\n", + "shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Because the output 
no longer has a sharded dimension, set `out_specs=P()` (recall that the `out_specs` argument identifies how the blocks are assembled back together in `shard_map`).\n", + "\n", + "## Comparing the three approaches\n", + "\n", + "With these concepts fresh in our mind, let's compare the three approaches for a simple neural network layer.\n", + "\n", + "Start by defining your canonical function like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "1TdhfTsoiqS1" + }, + "outputs": [], + "source": [ + "@jax.jit\n", + "def layer(x, weights, bias):\n", + " return jax.nn.sigmoid(x @ weights + bias)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "outputId": "f3007fe4-f6f3-454e-e7c5-3638de484c0a" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([0.02138912, 0.893112 , 0.59892005, 0.97742504], dtype=float32)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "rng = np.random.default_rng(0)\n", + "\n", + "x = rng.normal(size=(32,))\n", + "weights = rng.normal(size=(32, 4))\n", + "bias = rng.normal(size=(4,))\n", + "\n", + "layer(x, weights, bias)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.\n", + "\n", + "If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "outputId": "80be899e-8dbc-4bfc-acd2-0f3d554a0aa5" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([0.02138912, 0.893112 , 0.59892005, 0.97742504], dtype=float32)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "P = jax.sharding.PartitionSpec\n", + "mesh = jax.sharding.Mesh(jax.devices(), 'x')\n", + "sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", + "\n", + "x_sharded = jax.device_put(x, sharding)\n", + "weights_sharded = jax.device_put(weights, sharding)\n", + "\n", + "layer(x_sharded, weights_sharded, bias)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "outputId": "bb63e8da-ff4f-4e95-f083-10584882daf4" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@jax.jit\n", + "def layer_auto(x, weights, bias):\n", + " x = jax.lax.with_sharding_constraint(x, sharding)\n", + " weights = jax.lax.with_sharding_constraint(weights, sharding)\n", + " return layer(x, weights, bias)\n", + "\n", + "layer_auto(x, weights, bias) # pass in unsharded inputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, you can do the same thing with `shard_map`, using {func}`jax.lax.psum` to indicate the cross-shard collective required for the matrix product:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "outputId": "568d1c85-39a7-4dba-f09a-0e4f7c2ea918" + }, + "outputs": [ + { + "data": { + "text/plain": [ + 
"Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from functools import partial\n", + "\n", + "@jax.jit\n", + "@partial(shard_map, mesh=mesh,\n", + " in_specs=(P('x'), P('x', None), P(None)),\n", + " out_specs=P(None))\n", + "def layer_sharded(x, weights, bias):\n", + " return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)\n", + "\n", + "layer_sharded(x, weights, bias)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Next steps\n", + "\n", + "This tutorial serves as a brief introduction of sharded and parallel computation in JAX.\n", + "\n", + "To learn about each SPMD method in-depth, check out these docs:\n", + "- {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization`\n", + "- {doc}`../notebooks/shard_map`" + ] + } + ], + "metadata": { + "accelerator": "TPU", + "colab": { + "gpuType": "V28", + "provenance": [], + "toc_visible": true + }, + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md new file mode 100644 index 000000000000..345ca7987b41 --- /dev/null +++ b/docs/sharded-computation.md @@ -0,0 +1,318 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.1 +kernelspec: + display_name: Python 3 + name: python3 +--- + +(sharded-computation)= +# Introduction to sharded computation + + + +This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs. + +The tutorial covers three modes of parallel computation: + +- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. "the compiler takes the wheel"). +- _Semi-automated parallelism_ using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint` +- _Fully manual parallelism with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives + +Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices. + +If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with). + +```{code-cell} +:outputId: 18905ae4-7b5e-4bb9-acb4-d8ab914cb456 + +import jax +jax.devices() +``` + +## Key concept: Data sharding + +Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices. + +How can JAX can understand how the data is laid out across devices? 
JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {class}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.
+
+In the simplest cases, arrays are sharded on a single device, as demonstrated below:
+
+```{code-cell}
+:outputId: 39fdbb79-d5c0-4ea6-8b20-88b2c502a27a
+
+import jax.numpy as jnp
+arr = jnp.arange(32.0).reshape(4, 8)
+arr.devices()
+```
+
+```{code-cell}
+:outputId: 536f773a-7ef4-4526-c58b-ab4d486bf5a1
+
+arr.sharding
+```
+
+For a more visual representation of the storage layout, the {mod}`jax.debug` module provides some helpers to visualize the sharding of an array. For example, {func}`jax.debug.visualize_array_sharding` displays how the array is stored in the memory of a single device:
+
+```{code-cell}
+:outputId: 74a793e9-b13b-4d07-d8ec-7e25c547036d
+
+jax.debug.visualize_array_sharding(arr)
+```
+
+To create an array with a non-trivial sharding, you can define a {mod}`jax.sharding` specification for the array and pass this to {func}`jax.device_put`.
+
+Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimensional grid of devices with named axes, where {class}`jax.sharding.Mesh` allows for precise device placement:
+
+```{code-cell}
+:outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5
+
+# Pardon the boilerplate; constructing a sharding will become easier in future!
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec
+from jax.sharding import NamedSharding
+from jax.experimental import mesh_utils
+
+P = jax.sharding.PartitionSpec
+devices = mesh_utils.create_device_mesh((2, 4))
+mesh = jax.sharding.Mesh(devices, ('x', 'y'))
+sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
+print(sharding)
+```
+
+Passing this `Sharding` object to {func}`jax.device_put`, you can obtain a sharded array:
+
+```{code-cell}
+:outputId: c8ceedba-05ca-4156-e6e4-1e98bb664a66
+
+arr_sharded = jax.device_put(arr, sharding)
+
+print(arr_sharded)
+jax.debug.visualize_array_sharding(arr_sharded)
+```
+
+The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the devices.
+
+## 1. Automatic parallelism via `jit`
+
+Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you only need to specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.
+
+The XLA compiler behind `jit` includes heuristics for optimizing computations across multiple devices.
+In the simplest of cases, those heuristics boil down to *computation follows data*.
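+
+For instance, when an operation combines a sharded array with an unsharded one, you can generally expect the output to be laid out like the sharded operand. The short, hypothetical check below (it reuses `arr_sharded` from above and introduces the throwaway names `ones` and `total`) is one way to inspect what layout the compiler chose:
+
+```{code-cell}
+# A minimal sketch: add a plain, unsharded array to the sharded `arr_sharded`
+# and inspect the layout the compiler picked for the result. Following
+# "computation follows data", it will typically match `arr_sharded.sharding`.
+ones = jnp.ones(arr_sharded.shape)   # created unsharded, on the default device
+total = jax.jit(jnp.add)(arr_sharded, ones)
+jax.debug.visualize_array_sharding(total)
+```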
+
+To demonstrate how auto-parallelization works in JAX, below is an example that uses a {func}`jax.jit`-decorated staged-out function: it's a simple element-wise function, where the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:
+
+```{code-cell}
+:outputId: de46f86a-6907-49c8-f36c-ed835e78bc3d
+
+@jax.jit
+def f_elementwise(x):
+  return 2 * jnp.sin(x) + 1
+
+result = f_elementwise(arr_sharded)
+
+print("shardings match:", result.sharding == arr_sharded.sharding)
+```
+
+As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data.
+
+Here, you sum along the leading axis of `x`, and visualize how the result values are stored across multiple devices (with {func}`jax.debug.visualize_array_sharding`):
+
+```{code-cell}
+:outputId: 90c3b997-3653-4a7b-c8ff-12a270f11d02
+
+@jax.jit
+def f_contract(x):
+  return x.sum(axis=0)
+
+result = f_contract(arr_sharded)
+jax.debug.visualize_array_sharding(result)
+print(result)
+```
+
+The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second two on `1` and `7`, and so on.
+
+## 2. Semi-automated sharding with constraints
+
+If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put`) together with {func}`jax.jit` for more control over how the compiler distributes intermediate values and outputs.
+
+For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:
+
+```{code-cell}
+:outputId: 8468f5c6-76ca-4367-c9f2-93c723687cfd
+
+@jax.jit
+def f_contract_2(x):
+  out = x.sum(axis=0)
+  # mesh = jax.create_mesh((8,), 'x')
+  devices = mesh_utils.create_device_mesh(8)
+  mesh = jax.sharding.Mesh(devices, 'x')
+  sharding = jax.sharding.NamedSharding(mesh, P('x'))
+  return jax.lax.with_sharding_constraint(out, sharding)
+
+result = f_contract_2(arr_sharded)
+jax.debug.visualize_array_sharding(result)
+print(result)
+```
+
+This gives you a function with the particular output sharding you'd like.
+
+## 3. Manual parallelism with `shard_map`
+
+In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.experimental.shard_map.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.
+
+`shard_map` works by mapping a function across a particular *mesh* of devices (`shard_map` maps over shards). In the example below:
+
+- As before, {class}`jax.sharding.Mesh` allows for precise device placement, with the axis names parameter for logical and physical axis names.
+- The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together.
+
+**Note:** {func}`jax.experimental.shard_map.shard_map` code can work inside {func}`jax.jit` if you need it.
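+
+As a quick, minimal sketch of that note (it assumes the same eight-device setup used throughout, defines a throwaway one-axis mesh `mesh_1d`, and reuses `f_elementwise` from above; the names `mesh_1d`, `spec`, and `f_inside_jit` are only for illustration), a `shard_map`-transformed function can simply be called from inside a {func}`jax.jit`-decorated function:
+
+```{code-cell}
+from jax.experimental.shard_map import shard_map
+
+# A one-dimensional mesh over all available devices, just for this sketch.
+mesh_1d = jax.sharding.Mesh(jax.devices(), 'x')
+spec = jax.sharding.PartitionSpec('x')
+
+@jax.jit
+def f_inside_jit(x):
+  # The per-shard body (f_elementwise) is traced and compiled as part of the
+  # surrounding jitted function.
+  return shard_map(f_elementwise, mesh=mesh_1d, in_specs=spec, out_specs=spec)(x)
+
+f_inside_jit(jnp.arange(32.0))
+```
+
+The example below uses `shard_map` on its own; a fuller `jit`-plus-`shard_map` composition appears in the `layer_sharded` example at the end of this tutorial.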
+ +```{code-cell} +:outputId: 435c32f3-557a-4676-c11b-17e6bab8c1e2 + +from jax.experimental.shard_map import shard_map +P = jax.sharding.PartitionSpec +mesh = jax.sharding.Mesh(jax.devices(), 'x') + +f_elementwise_sharded = shard_map( + f_elementwise, + mesh=mesh, + in_specs=P('x'), + out_specs=P('x')) + +arr = jnp.arange(32) +f_elementwise_sharded(arr) +``` + +The function you write only "sees" a single batch of the data, which you can check by printing the device local shape: + +```{code-cell} +:outputId: 99a3dc6e-154a-4ef6-8eaa-3dd0b68fb1da + +x = jnp.arange(32) +print(f"global shape: {x.shape=}") + +def f(x): + print(f"device local shape: {x.shape=}") + return x * 2 + +y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) +``` + +Because each of your functions only "sees" the device-local part of the data, it means that aggregation-like functions require some extra thought. + +For example, here's what a `shard_map` of a {func}`jax.numpy.sum` looks like: + +```{code-cell} +:outputId: 1e9a45f5-5418-4246-c75b-f9bc6dcbbe72 + +def f(x): + return jnp.sum(x, keepdims=True) + +shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) +``` + +Your function `f` operates separately on each shard, and the resulting summation reflects this. + +If you want to sum across shards, you need to explicitly request it using collective operations like {func}`jax.lax.psum`: + +```{code-cell} +:outputId: 4fd29e80-4fee-42b7-ff80-29f9887ab38d + +def f(x): + sum_in_shard = x.sum() + return jax.lax.psum(sum_in_shard, 'x') + +shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x) +``` + +Because the output no longer has a sharded dimension, set `out_specs=P()` (recall that the `out_specs` argument identifies how the blocks are assembled back together in `shard_map`). + +## Comparing the three approaches + +With these concepts fresh in our mind, let's compare the three approaches for a simple neural network layer. + +Start by defining your canonical function like this: + +```{code-cell} +:id: 1TdhfTsoiqS1 + +@jax.jit +def layer(x, weights, bias): + return jax.nn.sigmoid(x @ weights + bias) +``` + +```{code-cell} +:outputId: f3007fe4-f6f3-454e-e7c5-3638de484c0a + +import numpy as np +rng = np.random.default_rng(0) + +x = rng.normal(size=(32,)) +weights = rng.normal(size=(32, 4)) +bias = rng.normal(size=(4,)) + +layer(x, weights, bias) +``` + +You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data. 
+ +If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel: + +```{code-cell} +:outputId: 80be899e-8dbc-4bfc-acd2-0f3d554a0aa5 + +P = jax.sharding.PartitionSpec +mesh = jax.sharding.Mesh(jax.devices(), 'x') +sharding = jax.sharding.NamedSharding(mesh, P('x')) + +x_sharded = jax.device_put(x, sharding) +weights_sharded = jax.device_put(weights, sharding) + +layer(x_sharded, weights_sharded, bias) +``` + +Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs: + +```{code-cell} +:outputId: bb63e8da-ff4f-4e95-f083-10584882daf4 + +@jax.jit +def layer_auto(x, weights, bias): + x = jax.lax.with_sharding_constraint(x, sharding) + weights = jax.lax.with_sharding_constraint(weights, sharding) + return layer(x, weights, bias) + +layer_auto(x, weights, bias) # pass in unsharded inputs +``` + +Finally, you can do the same thing with `shard_map`, using {func}`jax.lax.psum` to indicate the cross-shard collective required for the matrix product: + +```{code-cell} +:outputId: 568d1c85-39a7-4dba-f09a-0e4f7c2ea918 + +from functools import partial + +@jax.jit +@partial(shard_map, mesh=mesh, + in_specs=(P('x'), P('x', None), P(None)), + out_specs=P(None)) +def layer_sharded(x, weights, bias): + return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias) + +layer_sharded(x, weights, bias) +``` + +## Next steps + +This tutorial serves as a brief introduction to sharded and parallel computation in JAX. + +To learn about each SPMD method in-depth, check out these docs: +- {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization` +- {doc}`../notebooks/shard_map` diff --git a/docs/jax-101/07-state.md b/docs/stateful-computations.md similarity index 60% rename from docs/jax-101/07-state.md rename to docs/stateful-computations.md index bd2d2aa390d9..5a8af2b74142 100644 --- a/docs/jax-101/07-state.md +++ b/docs/stateful-computations.md @@ -1,49 +1,40 @@ --- jupytext: - formats: ipynb,md:myst + formats: md:myst text_representation: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 + language: python name: python3 --- -+++ {"id": "Ga0xSM8xhBIm"} +# Stateful Computations -# Stateful Computations in JAX + -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/07-state.ipynb) +JAX transformations like {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad` require the functions +they wrap to be pure: that is, functions whose outputs depend *solely* on the inputs, and which have +no side effects such as updating of global state. +You can find a discussion of this in [JAX sharp bits: Pure functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). -*Authors: Vladimir Mikulik* +This constraint can pose some challenges in the context of machine learning, where state may exist in +many forms. For example:
- -+++ {"id": "Avjnyrjojo8z"} - -## Motivation - -In machine learning, program state most often comes in the form of: * model parameters, * optimizer state, and * stateful layers, such as [BatchNorm](https://en.wikipedia.org/wiki/Batch_normalization). -Some JAX transformations, most notably `jax.jit`, impose constraints on the functions they transform. In particular, the function transformed by `jax.jit` must have no side-effects. This is because any such side-effects will only be executed once, when the python version of the function is run during compilation. These side-effects will not be executed by the compiled function on subsequent runs. - -Changing program state is one kind of side-effect. So, if we can't have side effects, how do we update model parameters, the optimizer state, and use stateful layers in our models? This colab will explain this in detail, but the short answer is: with [functional programming](https://en.wikipedia.org/wiki/Functional_programming). - -+++ {"id": "s_-6semKkSzp"} +This section offers some advice of how to properly handle state in a JAX program. ## A simple example: Counter Let's start by looking at a simple stateful program: a counter. -```{code-cell} ipython3 -:id: B3aoCHpjg8gm -:outputId: 5cbcfbf5-5c42-498f-a175-050438518337 - +```{code-cell} import jax import jax.numpy as jnp @@ -69,16 +60,12 @@ for _ in range(3): print(counter.count()) ``` -+++ {"id": "SQ-RNLfdiw04"} - -The `n` attribute maintains the counter's _state_ between successive calls of `count`. It is modified as a side effect of calling `count`. +The counter's `n` attribute maintains the counter's _state_ between successive calls of `count`. It is modified as a side effect of calling `count`. -Let's say we want to count fast, so we `jax.jit` the `count` method. (In this example, this wouldn't actually help speed anyway, for many reasons, but treat this as a toy model of wanting to JIT-compile the update of model parameters, where `jax.jit` makes an enormous difference). - -```{code-cell} ipython3 -:id: 5jSjmJMon03W -:outputId: d952f16b-9b30-4753-ed94-cc914a929a36 +Let's say we want to count fast, so we JIT-compile the `count` method. +(In this example, this wouldn't actually help speed anyway, for many reasons, but treat this as a toy model of JIT-compiling the update of model parameters, where {func}`~jax.jit` makes an enormous difference). +```{code-cell} counter.reset() fast_count = jax.jit(counter.count) @@ -86,22 +73,19 @@ for _ in range(3): print(fast_count()) ``` -+++ {"id": "weiI0V7_pKGv"} - Oh no! Our counter isn't working. This is because the line ``` self.n += 1 ``` -in `count` is only called once, when JAX compiles the method call. Moreover, since the return value doesn't depend on the arguments to `count`, once it returns the first 1, subsequent calls to `fast_count` will always return 1. This won't do. So, how do we fix it? +in `count` involves a side effect: it modifies the input counter in-place, and so this function is not supported by `jit`. +Such side effects are executed only once when the function is first traced, and subsequent calls will not repeat the side effect. +So, how do we fix it? ## The solution: explicit state Part of the problem with our counter was that the returned value didn't depend on the arguments, meaning a constant was "baked into" the compiled output. But it shouldn't be a constant -- it should depend on the state. Well, then why don't we make the state into an argument? 
-```{code-cell} ipython3 -:id: 53pSdK4KoOEZ -:outputId: 5ac72b9c-7029-4bf2-de8d-1d412bd74c79 - +```{code-cell} CounterState = int class CounterV2: @@ -122,14 +106,9 @@ for _ in range(3): print(value) ``` -+++ {"id": "PrBjmgZtq89b"} - In this new version of `Counter`, we moved `n` to be an argument of `count`, and added another return value that represents the new, updated, state. To use this counter, we now need to keep track of the state explicitly. But in return, we can now safely `jax.jit` this counter: -```{code-cell} ipython3 -:id: LO4Xzcq_q8PH -:outputId: 25c06a56-f2bf-4c54-a3c3-6e093d484362 - +```{code-cell} state = counter.reset() fast_count = jax.jit(counter.count) @@ -138,8 +117,6 @@ for _ in range(3): print(value) ``` -+++ {"id": "MzMSWD2_sgnh"} - ## A general strategy We can apply the same process to any stateful method to convert it into a stateless one. We took a class of the form @@ -162,25 +139,25 @@ class StatelessClass This is a common [functional programming](https://en.wikipedia.org/wiki/Functional_programming) pattern, and, essentially, is the way that state is handled in all JAX programs. -Notice that the need for a class becomes less clear once we have rewritten it this way. We could just keep `stateless_method`, since the class is no longer doing any work. This is because, like the strategy we just applied, object-oriented programming (OOP) is a way to help programmers understand program state. +Notice that the need for a class becomes less clear once we have rewritten it this way. We could just keep `stateless_method`, since the class is no longer doing any work. +This is because, like the strategy we just applied, object-oriented programming (OOP) is a way to help programmers understand program state. In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class? -Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey. +Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, {mod}`jax.random`, shown in the :ref:`pseudorandom-numbers` section. +Unlike Numpy, which manages random state using implicitly updated stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key. -+++ {"id": "I2SqRx14_z98"} ## Simple worked example: Linear Regression Let's apply this strategy to a simple machine learning model: linear regression via gradient descent. -Here, we only deal with one kind of state: the model parameters. But generally, you'll see many kinds of state being threaded in and out of JAX functions, like optimizer state, layer statistics for batchnorm, and others. +Here, we only deal with one kind of state: the model parameters. +But generally, you'll see many kinds of state being threaded in and out of JAX functions, like optimizer state, layer statistics for batchnorm, and others. The function to look at carefully is `update`. 
-```{code-cell} ipython3 -:id: wQdU7DoAseW6 - +```{code-cell} from typing import NamedTuple class Params(NamedTuple): @@ -211,9 +188,9 @@ def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params: # If we were using Adam or another stateful optimizer, # we would also do something like - # ``` - # updates, new_optimizer_state = optimizer(grad, optimizer_state) - # ``` + # + # updates, new_optimizer_state = optimizer(grad, optimizer_state) + # # and then use `updates` instead of `grad` to actually update the params. # (And we'd include `new_optimizer_state` in the output, naturally.) @@ -223,17 +200,12 @@ def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params: return new_params ``` -+++ {"id": "dKySWouu2-Hu"} - Notice that we manually pipe the params in and out of the update function. -```{code-cell} ipython3 -:id: jQCYYy0yxO6K -:outputId: 1f3b69d2-e90b-4065-cbcc-6422978d25c2 - +```{code-cell} import matplotlib.pyplot as plt -rng = jax.random.PRNGKey(42) +rng = jax.random.key(42) # Generate true data from y = w*x + b + noise true_w, true_b = 2, -1 @@ -252,11 +224,9 @@ plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction' plt.legend(); ``` -+++ {"id": "1wq3L6Xg1UHP"} - ## Taking it further -The strategy described above is how any (jitted) JAX program must handle state. +The strategy described above is how any JAX program must handle state when using transformations like `jit`, `vmap`, `grad`, etc. Handling parameters manually seems fine if you're dealing with two parameters, but what if it's a neural net with dozens of layers? You might already be getting worried about two things: diff --git a/docs/tutorials.rst b/docs/tutorials.rst new file mode 100644 index 000000000000..2f90e4226e50 --- /dev/null +++ b/docs/tutorials.rst @@ -0,0 +1,18 @@ +.. _jax-tutorials: + +JAX tutorials +============= + +.. toctree:: + :maxdepth: 1 + + quickstart + key-concepts + jit-compilation + automatic-vectorization + automatic-differentiation + debugging + random-numbers + working-with-pytrees + sharded-computation + stateful-computations diff --git a/docs/tutorials/advanced-debugging.md b/docs/tutorials/advanced-debugging.md deleted file mode 100644 index 64579654ea83..000000000000 --- a/docs/tutorials/advanced-debugging.md +++ /dev/null @@ -1,16 +0,0 @@ ---- -jupytext: - formats: md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.0 -kernelspec: - display_name: Python 3 - language: python - name: python3 ---- - -(advanced-debugging)= -# Advanced debugging diff --git a/docs/tutorials/debugging.md b/docs/tutorials/debugging.md deleted file mode 100644 index cf4150aa57e5..000000000000 --- a/docs/tutorials/debugging.md +++ /dev/null @@ -1,162 +0,0 @@ ---- -jupytext: - formats: md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.0 -kernelspec: - display_name: Python 3 - language: python - name: python3 ---- - -(debugging)= -# Debugging 101 - -This tutorial introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations. - -Let's begin with {func}`jax.debug.print`. - -## JAX `debug.print` for high-level debugging - -**TL;DR** Here is a rule of thumb: - -- Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others. 
-- Use Python `print` for static values, such as dtypes and array shapes. - -With some JAX transformations, such as {func}`jax.grad` and {func}`jax.vmap`, you can use Python’s built-in `print` function to print out numerical values. However, with {func}`jax.jit` for example, you need to use {func}`jax.debug.print`, because those transformations delay numerical evaluation. - -Below is a basic example with {func}`jax.jit`: - -```{code-cell} -import jax -import jax.numpy as jnp - -@jax.jit -def f(x): - jax.debug.print("This is `jax.debug.print` of x {x}", x=x) - y = jnp.sin(x) - jax.debug.print("This is `jax.debug.print` of y {y} 🤯", y=y) - return y - -f(2.) -``` - -{func}`jax.debug.print` can reveal the information about how computations are evaluated. - -Here's an example with {func}`jax.vmap`: - -```{code-cell} -def f(x): - jax.debug.print("This is `jax.debug.print` of x: {}", x) - y = jnp.sin(x) - jax.debug.print("This is `jax.debug.print` of y: {}", y) - return y - -xs = jnp.arange(3.) - -jax.vmap(f)(xs) -``` - -Here's an example with {func}`jax.lax.map`: - -```{code-cell} -jax.lax.map(f, xs) -``` - -Notice the order is different, as {func}`jax.vmap` and {func}`jax.lax.map` compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect. - -Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only prints the forward pass. In this case, the behavior is similar to Python's `print`, but it's consistent if you apply {func}`jax.jit` during the call. - -```{code-cell} -def f(x): - jax.debug.print("This is `jax.debug.print` of x: {}", x) - return x ** 2 - -jax.grad(f)(1.) -``` - -Sometimes, when the arguments don't depend on one another, calls to {func}`jax.debug.print` may print them in a different order when staged out with a JAX transformation. If you need the original order, such as `x: ...` first and then `y: ...` second, add the `ordered=True` parameter. - -For example: - -```{code-cell} -@jax.jit -def f(x, y): - jax.debug.print("This is `jax.debug.print of x: {}", x, ordered=True) - jax.debug.print("This is `jax.debug.print of y: {}", y, ordered=True) - return x + y -``` - -To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`. - - -## JAX `debug.breakpoint` for `pdb`-like debugging - -**TL;DR** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. - -To pause your compiled JAX program during certain points during debugging, you can use {func}`jax.debug.breakpoint`. The prompt is similar to Python `pdb`, and it allows you to inspect the values in the call stack. In fact, {func}`jax.debug.breakpoint` is an application of {func}`jax.debug.callback` that captures information about the call stack. - -To print all available commands during a `breakpoint` debugging session, use the `help` command. (Full debugger commands, the Sharp Bits, its strengths and limitations are covered in {ref}`advanced-debugging`.) - -Example: - -```{code-cell} -:tags: [raises-exception] - -def breakpoint_if_nonfinite(x): - is_finite = jnp.isfinite(x).all() - def true_fn(x): - pass - def false_fn(x): - jax.debug.breakpoint() - jax.lax.cond(is_finite, true_fn, false_fn, x) - -@jax.jit -def f(x, y): - z = x / y - breakpoint_if_nonfinite(z) - return z -f(2., 0.) 
# ==> Pauses during execution -``` - -![JAX debugger](../_static/debugger.gif) - -## JAX `debug.callback` for more control during debugging - -As mentioned in the beginning, {func}`jax.debug.print` is a small wrapper around {func}`jax.debug.callback`. The {func}`jax.debug.callback` method allows you to have greater control over string formatting and the debugging output, like printing or plotting. It is compatible with {func}`jax.jit`, {func}`jax.vmap`, {func}`jax.grad` and other transformations (refer to the {ref}`external-callbacks-flavors-of-callback` table in {ref]`external-callbacks` for more information). - -For example: - -```{code-cell} -def log_value(x): - print("log:", x) - -@jax.jit -def f(x): - jax.debug.callback(log_value, x) - return x - -f(1.0); -``` - -This callback is compatible with {func}`jax.vmap` and {func}`jax.grad`: - -```{code-cell} -x = jnp.arange(5.0) -jax.vmap(f)(x); -``` - -```{code-cell} -jax.grad(f)(1.0); -``` - -This can make {func}`jax.debug.callback` useful for general-purpose debugging. - -You can learn more about different flavors of JAX callbacks in {ref}`external-callbacks-flavors-of-callback` and {ref}`external-callbacks-exploring-debug-callback`. - -## Next steps - -Check out the {ref}`advanced-debugging` to learn more about debugging in JAX. diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst deleted file mode 100644 index a51445292f07..000000000000 --- a/docs/tutorials/index.rst +++ /dev/null @@ -1,57 +0,0 @@ -:orphan: - -.. _jax-tutorials: - -JAX tutorials -============= - -.. note:: - - The tutorials below are a work in progress; for the time being, please refer - to the older tutorial content at :ref:`Jax-101`, :ref:`beginner-guide` and - :ref:`user-guides`. - -JAX 101 -------- - -.. toctree:: - :maxdepth: 1 - - installation - quickstart - jax-as-accelerated-numpy - thinking-in-jax - jit-compilation - automatic-vectorization - automatic-differentiation - debugging - random-numbers - working-with-pytrees - single-host-sharding - stateful-computations - simple-neural-network - - -JAX 201 -------- - -.. toctree:: - :maxdepth: 1 - - parallelism - advanced-autodiff - gradient-checkpointing - advanced-debugging - external-callbacks - profiling-and-performance - - -JAX 301 -------- - -.. toctree:: - :maxdepth: 1 - - jax-primitives - jaxpr - advanced-compilation diff --git a/docs/tutorials/installation.md b/docs/tutorials/installation.md deleted file mode 100644 index 34508c4fd2a3..000000000000 --- a/docs/tutorials/installation.md +++ /dev/null @@ -1,318 +0,0 @@ -(installation)= -# How to install JAX - -This guide provides instructions for: - -- Installing JAX binary packages for supported platforms using `pip` or `conda` -- Using Docker containers (for example {ref}`docker-containers-nvidia-gpu`) -- {ref}`building-jax-from-source` - -**TL;DR** For most users, a typical JAX installation may look something like this: - -| Hardware | Installation | -|------------------------------------|--------------------------------------------| -| CPU-only, Linux/macOS/Windows | `pip install -U "jax[cpu]"` | -| NVIDIA, CUDA 12, x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`| - - -(install-supported-platforms)= -## Supported platforms - -The table below shows all supported platforms and installation options. Check if your setup is supported; and if it says _"yes"_ or _"experimental"_, then click on the corresponding link to learn how to install JAX in greater detail. 
- -| | Linux, x86_64 | Linux, aarch64 | macOS, Intel x86_64, AMD GPU | macOS, Apple Silicon, ARM-based | Windows, x86_64 | Windows WSL2, x86_64 | -|------------------|---------------------------------------|--------------------------------|----------------------------------------|----------------------------------------|-------------------------|-----------------------------------------| -| CPU | {ref}`yes ` | {ref}`yes ` | {ref}`yes `| {ref}`yes `| {ref}`yes ` | {ref}`yes `| -| NVIDIA GPU | {ref}`yes ` | {ref}`yes ` | no | n/a | no | {ref}`experimental ` | -| Google Cloud TPU | {ref}`yes ` | n/a | n/a | n/a | n/a | n/a | -| AMD GPU | {ref}`experimental ` | no | no | n/a | no | no | -| Apple GPU | n/a | no | {ref}`experimental ` | {ref}`experimental ` | n/a | n/a | - - -(install-cpu)= -## CPU - -### pip installation: CPU - -Currently, the JAX team releases `jaxlib` wheels for the following -operating systems and architectures: - -- Linux, x86_64 -- macOS, Intel -- macOS, Apple ARM-based -- Windows, x86_64 (*experimental*) - -To install a CPU-only version of JAX, which might be useful for doing local -development on a laptop, you can run: - -```bash -pip install --upgrade pip -pip install --upgrade "jax[cpu]" -``` - -On Windows, you may also need to install the -[Microsoft Visual Studio 2019 Redistributable](https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170#visual-studio-2015-2017-2019-and-2022) -if it is not already installed on your machine. - -Other operating systems and architectures require building from source. Trying -to pip install on other operating systems and architectures may lead to `jaxlib` -not being installed alongside `jax`, although `jax` may successfully install -(but fail at runtime). - - -(install-nvidia-gpu)= -## NVIDIA GPU - -JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer. -Note that Kepler-series GPUs are no longer supported by JAX since -NVIDIA has dropped support for Kepler GPUs in its software. - -You must first install the NVIDIA driver. You're -recommended to install the newest driver available from NVIDIA, but the driver -version must be >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux. - -If you need to use a newer CUDA toolkit with an older driver, for example -on a cluster where you cannot update the NVIDIA driver easily, you may be -able to use the -[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/) -that NVIDIA provides for this purpose. - -### pip installation: NVIDIA GPU (CUDA, installed via pip, easier) - -There are two ways to install JAX with NVIDIA GPU support: - -- Using NVIDIA CUDA and cuDNN installed from pip wheels -- Using a self-installed CUDA/cuDNN - -The JAX team strongly recommends installing CUDA and cuDNN using the pip wheels, -since it is much easier! - -This method is only supported on x86_64, because NVIDIA has not released aarch64 -CUDA pip packages. - -```bash -pip install --upgrade pip - -# NVIDIA CUDA 12 installation -# Note: wheels only available on linux. -pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - -# NVIDIA CUDA 11 installation -# Note: wheels only available on linux. 
-pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -``` - -If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things -you need to check: - -* Make sure that `LD_LIBRARY_PATH` is not set, since `LD_LIBRARY_PATH` can - override the NVIDIA CUDA libraries. -* Make sure that the NVIDIA CUDA libraries installed are those requested by JAX. - Rerunning the installation command above should work. - -### pip installation: NVIDIA GPU (CUDA, installed locally, harder) - -If you prefer to use a preinstalled copy of NVIDIA CUDA, you must first -install NVIDIA [CUDA](https://developer.nvidia.com/cuda-downloads) and -[cuDNN](https://developer.nvidia.com/CUDNN). - -JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. Other -combinations of operating system and architecture are possible, but require -building from source (refer to {ref}`building-from-source` to learn more}. - -You should use an NVIDIA driver version that is at least as new as your -[NVIDIA CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions). -If you need to use a newer CUDA toolkit with an older driver, for example -on a cluster where you cannot update the NVIDIA driver easily, you may be -able to use the -[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/) -that NVIDIA provides for this purpose. - -JAX currently ships two NVIDIA CUDA wheel variants: - -- CUDA 12.2, cuDNN 8.9, NCCL 2.16 -- CUDA 11.8, cuDNN 8.6, NCCL 2.16 - -You may use a JAX wheel provided the major version of your CUDA, cuDNN, and NCCL -installations match, and the minor versions are the same or newer. -JAX checks the versions of your libraries, and will report an error if they are -not sufficiently new. - -NCCL is an optional dependency, required only if you are performing multi-GPU -computations. - -To install, run: - -```bash -pip install --upgrade pip - -# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 8.9 or newer. -# Note: wheels only available on linux. -pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - -# Installs the wheel compatible with NVIDIA CUDA 11 and cuDNN 8.6 or newer. -# Note: wheels only available on linux. -pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -``` - -**These `pip` installations do not work with Windows, and may fail silently; refer to the table -[above](#supported-platforms).** - -You can find your CUDA version with the command: - -```bash -nvcc --version -``` - -JAX uses `LD_LIBRARY_PATH` to find CUDA libraries and `PATH` to find binaries -(`ptxas`, `nvlink`). Please make sure that these paths point to the correct CUDA -installation. - -Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues) -if you run into any errors or problems with the pre-built wheels. - -(docker-containers-nvidia-gpu)= -### NVIDIA GPU Docker containers - -NVIDIA provides the [JAX -Toolbox](https://github.com/NVIDIA/JAX-Toolbox) containers, which are -bleeding edge containers containing nightly releases of jax and some -models/frameworks. - -## JAX nightly installation - -Nightly releases reflect the state of the main JAX repository at the time they are -built, and may not pass the full test suite. 
- -- `jax`: - -```bash -pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -``` - -- `jaxlib` CPU: - -```bash -pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -``` - -- `jaxlib` Google Cloud TPU: - -```bash -pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -``` - -- `jaxlib` NVIDIA GPU (CUDA 12): - -```bash -pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html -``` - -- `jaxlib` NVIDIA GPU (CUDA 11): - -```bash -pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html -``` - -(install-google-tpu)= -## Google Cloud TPU - -### pip installation: Google Cloud TPU - -JAX provides pre-built wheels for -[Google Cloud TPU](https://cloud.google.com/tpu/docs/users-guide-tpu-vm). -To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can run -the following in your cloud TPU VM: - -```bash -pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -``` - -For interactive notebook users: Colab TPUs no longer support JAX as of -JAX version 0.4. However, for an interactive TPU notebook in the cloud, you can -use [Kaggle TPU notebooks](https://www.kaggle.com/docs/tpu), which fully -support JAX. - -(install-apple-gpu)= -## Apple Silicon GPU (ARM-based) - -### pip installation: Apple ARM-based Silicon GPUs - -Apple provides an experimental Metal plugin for Apple ARM-based GPU hardware. For details, -refer to -[Apple's JAX on Metal documentation](https://developer.apple.com/metal/jax/). - -**Note:** There are several caveats with the Metal plugin: - -* The Metal plugin is new and experimental and has a number of - [known issues](https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22). - Please report any issues on the JAX issue tracker. -* The Metal plugin currently requires very specific versions of `jax` and - `jaxlib`. This restriction will be relaxed over time as the plugin API - matures. - -(install-amd-gpu)= -## AMD GPU - -JAX has experimental ROCm support. There are two ways to install JAX: - -* Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax); or -* Build from source (refer to {ref}`building-from-source` — a section called _Additional notes for building a ROCM `jaxlib` for AMD GPUs_). - -## Conda (community-supported) - -### Conda installation - -There is a community-supported Conda build of `jax`. To install it using `conda`, -simply run: - -```bash -conda install jax -c conda-forge -``` - -To install it on a machine with an NVIDIA GPU, run: - -```bash -conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia -``` - -Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which -JAX requires. You must therefore either install the `cuda-nvcc` package from -the `nvidia` channel, or install CUDA on your machine separately so that `ptxas` -is in your path. The channel order above is important (`conda-forge` before -`nvidia`). 
- -If you would like to override which release of CUDA is used by JAX, or to -install the CUDA build on a machine without GPUs, follow the instructions in the -[Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch) -section of the `conda-forge` website. - -Go to the `conda-forge` -[jaxlib](https://github.com/conda-forge/jaxlib-feedstock#installing-jaxlib) and -[jax](https://github.com/conda-forge/jax-feedstock#installing-jax) repositories -for more details. - - -(building-jax-from-source)= -## Building JAX from source - -Refer to {ref}`building-from-source`. - -## Installing older `jaxlib` wheels - -Due to storage limitations on the Python package index, the JAX team periodically removes -older `jaxlib` wheels from the releases on http://pypi.org/project/jax. These can -still be installed directly via the URLs here. For example: - -```bash -# Install jaxlib on CPU via the wheel archive -pip install jax[cpu]==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html - -# Install the jaxlib 0.3.25 CPU wheel directly -pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html -``` -For specific older GPU wheels, be sure to use the `jax_cuda_releases.html` URL; for example -```bash -pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -``` \ No newline at end of file diff --git a/docs/tutorials/jax-as-accelerated-numpy.md b/docs/tutorials/jax-as-accelerated-numpy.md deleted file mode 100644 index 95eddbbc1e3e..000000000000 --- a/docs/tutorials/jax-as-accelerated-numpy.md +++ /dev/null @@ -1,8 +0,0 @@ -# JAX as accelerated NumPy - -```{note} -This is a placeholder for a section in the new {ref}`jax-tutorials`. - -For the time being, you may find some related content in the old documentation: -- {doc}`../jax-101/01-jax-basics` -``` diff --git a/docs/tutorials/single-host-sharding.md b/docs/tutorials/single-host-sharding.md deleted file mode 100644 index 9d918907ad59..000000000000 --- a/docs/tutorials/single-host-sharding.md +++ /dev/null @@ -1,5 +0,0 @@ -# Sharded data on a single host - -```{note} -This is a placeholder for a section in the new {ref}`jax-tutorials`. -``` diff --git a/docs/tutorials/stateful-computations.md b/docs/tutorials/stateful-computations.md deleted file mode 100644 index 318d4c4c6c05..000000000000 --- a/docs/tutorials/stateful-computations.md +++ /dev/null @@ -1,8 +0,0 @@ -# Stateful computations - -```{note} -This is a placeholder for a section in the new {ref}`jax-tutorials`. - -For the time being, you may find some related content in the old documentation: -- {doc}`../jax-101/07-state` -``` diff --git a/docs/tutorials/thinking-in-jax.md b/docs/tutorials/thinking-in-jax.md deleted file mode 100644 index df601ddaef31..000000000000 --- a/docs/tutorials/thinking-in-jax.md +++ /dev/null @@ -1,417 +0,0 @@ ---- -jupytext: - formats: md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.0 -kernelspec: - display_name: Python 3 - language: python - name: python3 ---- - -(thinking-in-jax)= -# How to think in JAX - -In this tutorial you will learn about how JAX operates, so that you can use it more effectively. JAX provides a simple and powerful API for writing accelerated numerical code, and working effectively in JAX sometimes requires extra consideration. This document will help you build a ground-up understanding of the JAX API. 
- - -## JAX versus NumPy - -**Key concepts:** - -- JAX provides a NumPy-inspired interface for convenience. -- Through [duck typing](https://en.wikipedia.org/wiki/Duck_typing), JAX arrays (`jax.Array`s) can often be used as drop-in replacements of NumPy arrays ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html)s). -- Unlike NumPy arrays, JAX arrays are always immutable. - -NumPy provides a well-known, powerful API for working with numerical data. For convenience, JAX provides `jax.numpy` which closely mirrors the numpy API and provides easy entry into JAX. Almost anything that can be done with `numpy` can be done with `jax.numpy`: - -```{code-cell} -import matplotlib.pyplot as plt -import numpy as np # Ordinary NumPy - -x_np = np.linspace(0, 10, 1000) -y_np = 2 * np.sin(x_np) * np.cos(x_np) -plt.plot(x_np, y_np); -``` - -```{code-cell} -import jax.numpy as jnp # JAX NumPy - -x_jnp = jnp.linspace(0, 10, 1000) -y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp) -plt.plot(x_jnp, y_jnp); -``` - -The code blocks are identical aside from replacing NumPy (`np`) with JAX NumPy (`jnp`), and the results are the same. JAX arrays can often be used directly in place of NumPy arrays for things like plotting. - -The arrays themselves are implemented as different Python types: - -```{code-cell} -type(x_np) -``` - -```{code-cell} -type(x_jnp) -``` - -Python's [duck typing](https://en.wikipedia.org/wiki/Duck_typing) allows JAX arrays `jax.Array`s and NumPy arrays `numpy.ndarray`s to be used interchangeably in many places. - -However, there is one important difference between JAX and NumPy arrays: JAX arrays are immutable, meaning that once created their contents cannot be changed. - -Here is an example of mutating an array in NumPy: - -```{code-cell} -# NumPy: mutable arrays -x = np.arange(10) -x[0] = 10 -print(x) -``` - -The equivalent in JAX results in an error, as JAX arrays are immutable: - -```{code-cell} -%xmode minimal -``` - -```python -:tags: [raises-exception] - -# JAX: immutable arrays -x = jnp.arange(10) -x[0] = 10 -``` - -For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/jax.ops.html#indexed-update-operators) that returns an updated copy: - -```python -y = x.at[0].set(10) -print(x) -print(y) -``` - -(thinking-in-jax-jax-arrays)= -## JAX arrays (`jax.Array`s) - -**Key concepts:** - -- `jax.Array` is the default array implementation in JAX. -- The JAX array is a unified distributed datatype for representing arrays, even with physical storage spanning multiple devices -- Automatic parallelization: You can operate over sharded `jax.Array`s without copying data onto a device using the {func}`jax.jit` transformation. You can also replicate a `jax.Array` to every device on a mesh. - -Consider this simple example: - -```{code-cell} -import jax -from jax import Array -import jax.numpy as jnp - -x = jnp.arange(5) -isinstance(x, jax.Array) # Returns True both inside and outside traced functions. - -def f(x: Array) -> Array: # Type annotations are valid for traced and non-traced types. - return x -``` - -The `jax.Array` type also helps make parallelism a core feature of JAX. - -(thinking-in-jax-pytrees)= -# Pytrees - -**Key concepts:** - -- JAX supports a special data structure called a pytree when you need to operate on dictionaries of lists, for example. -- Use cases: machine learning model parameters, dataset entries, lists of lists of dictionaries. 
- -JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — they are called JAX pytrees (also known as nests, or just trees). In the context of machine learning, a pytree can contain model parameters, dataset entries, and reinforcement learning agent observations. - -Below is an example of a simple pytree. In JAX, you can use {func}`jax.tree_util.tree_leaves`, to extract the flattened leaves from the trees, as demonstrated here: - -```{code-cell} -example_trees = [ - [1, 'a', object()], - (1, (2, 3), ()), - [1, {'k1': 2, 'k2': (3, 4)}, 5], - {'a': 2, 'b': (2, 3)}, - jnp.array([1, 2, 3]), -] - -# Let's see how many leaves they have: -for pytree in example_trees: - leaves = jax.tree_util.tree_leaves(pytree) - print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}") -``` - -{func}`jax.tree_map` is the most commonly used pytree function in JAX. It works analogously to Python's native map, but on entire pytrees. - -You can learn more in the {ref}`working-with-pytrees` tutorial. - -(thinking-in-jax-jax-api-layering)= -## NumPy, `jax.lax` and XLA: JAX API layering - -**Key concepts:** - -- {mod}`jax.numpy` is a high-level wrapper that provides a familiar interface. -- {mod}`jax.lax` is a lower-level API that is stricter and often more powerful. -- All JAX operations are implemented in terms of operations in [XLA](https://www.tensorflow.org/xla/) — the Accelerated Linear Algebra compiler. - -If you look at the source of {mod}`jax.numpy`, you'll see that all the operations are eventually expressed in terms of functions defined in {mod}`jax.lax`. You can think of {mod}`jax.lax` as a stricter, but often more powerful, API for working with multi-dimensional arrays. - -For example, while {mod}`jax.numpy` will implicitly promote arguments to allow operations between mixed data types, {mod}`jax.lax` will not: - -```{code-cell} -import jax.numpy as jnp -jnp.add(1, 1.0) # `jax.numpy` API implicitly promotes mixed types. -``` - -```{code-cell} -:tags: [raises-exception] - -from jax import lax -lax.add(1, 1.0) # `jax.lax` API requires explicit type promotion. -``` - -If using {mod}`jax.lax` directly, you'll have to do type promotion explicitly in such cases: - -```{code-cell} -lax.add(jnp.float32(1), 1.0) -``` - -Along with this strictness, {mod}`jax.lax` also provides efficient APIs for some more general operations than are supported by NumPy. - -For example, consider a 1D convolution, which can be expressed in NumPy this way: - -```{code-cell} -x = jnp.array([1, 2, 1]) -y = jnp.ones(10) -jnp.convolve(x, y) -``` - -Under the hood, this NumPy operation is translated to a much more general convolution implemented by {func}`jax.lax.conv_general_dilated`: - -```{code-cell} -from jax import lax -result = lax.conv_general_dilated( - x.reshape(1, 1, 3).astype(float), # note: explicit promotion - y.reshape(1, 1, 10), - window_strides=(1,), - padding=[(len(y) - 1, len(y) - 1)]) # equivalent of `padding='full'`` in NumPy -result[0, 0] -``` - -This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (Refer to [Convolutions in JAX](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html) for more details on JAX convolutions). - -At their heart, all {mod}`jax.lax` operations are Python wrappers for operations in XLA. 
Here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution). -Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation. - - -(thinking-in-jax-to-jit-or-not-to-jit)= -## To JIT or not to JIT (`jax.jit`) - -**Key concepts:** - -- By default JAX executes operations one at a time, in sequence. -- Using a just-in-time (JIT) compilation decorator — {func}`jax.jit` — sequences of operations can be optimized together and run at once. -- Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time. - -The fact that all JAX operations are expressed in terms of XLA allows JAX to use the XLA compiler to execute blocks of code very efficiently. - -For example, consider this function that normalizes the rows of a 2D matrix, expressed in terms of {mod}`jax.numpy` operations: - -```{code-cell} -import jax.numpy as jnp - -def norm(X): - X = X - X.mean(0) - return X / X.std(0) -``` - -A JIT-compiled version of the function can be created using the {func}`jax.jit` transform: - -```{code-cell} -from jax import jit -norm_compiled = jit(norm) -``` - -This function returns the same results as the original, up to standard floating-point accuracy: - -```{code-cell} -np.random.seed(1701) -X = jnp.array(np.random.rand(10000, 10)) -np.allclose(norm(X), norm_compiled(X), atol=1E-6) -``` - -But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of {func}`jax.block_until_ready` to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)): - -```{code-cell} -%timeit norm(X).block_until_ready() -%timeit norm_compiled(X).block_until_ready() -``` - -That said, {func}`jax.jit` does have limitations: in particular, it requires all arrays to have static shapes. This means some JAX operations are incompatible with JIT compilation. - -For example, this operation can be executed in op-by-op mode: - -```{code-cell} -def get_negatives(x): - return x[x < 0] - -x = jnp.array(np.random.randn(10)) -get_negatives(x) -``` - -But it returns an error if you attempt to execute it in {func}`jax.jit` mode: - -```{code-cell} -:tags: [raises-exception] - -jit(get_negatives)(x) -``` - -This is because the function generates an array whose shape is not known at compile time: the size of the output depends on the values of the input array, and so it is not compatible with JIT. - -(thinking-in-jax-jit-mechanics)= -## JIT mechanics: tracing and static variables - -**Key concepts:** - -- JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type. - -- Variables that you don't want to be traced can be marked as *static* - -To use {func}`jax.jit` effectively, it is useful to understand how it works. 
Let's put a few `print()` statements within a JIT-compiled function and then call the function: - -```{code-cell} -@jit -def f(x, y): - print("Running f():") - print(f" x = {x}") - print(f" y = {y}") - result = jnp.dot(x + 1, y + 1) - print(f" result = {result}") - return result - -x = np.random.randn(3, 4) -y = np.random.randn(4) -f(x, y) -``` - -Notice that the print statements execute, but rather than printing the data you passed to the function, though, it prints *tracer* objects that stand-in for them. - -These tracer objects are what {func}`jax.jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code. - -When you call the compiled function again on matching inputs, no recompilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python: - -```{code-cell} -x2 = np.random.randn(3, 4) -y2 = np.random.randn(4) -f(x2, y2) -``` - -The extracted sequence of operations is encoded in a JAX expression, or *jaxpr* for short. You can view the jaxpr using the {func}`jax.make_jaxpr` transformation: - -```python -from jax import make_jaxpr - -def f(x, y): - return jnp.dot(x + 1, y + 1) - -make_jaxpr(f)(x, y) -``` - -Note one consequence of this: because JIT compilation is done *without* information on the content of the array, control flow statements in the function cannot depend on traced values. For example, this fails: - -```{code-cell} -:tags: [raises-exception] - -@jit -def f(x, neg): - return -x if neg else x - -f(1, True) -``` - -If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation: - -```{code-cell} -from functools import partial - -@partial(jit, static_argnums=(1,)) -def f(x, neg): - return -x if neg else x - -f(1, True) -``` - -Note that calling a JIT-compiled function with a different static argument results in recompilation, so the function still works as expected: - -```{code-cell} -f(1, False) -``` - -Understanding which values and operations will be static and which will be traced is a key part of using {func}`jax.jit` effectively. - - -(thinking-in-jax-static-versus-traced-operations)= -## Static versus traced operations - -**Key concepts:** - -- Just as values can be either static or traced, operations can be static or traced. -- Static operations are evaluated at compile-time in Python. Traced operations are compiled & evaluated at run-time in XLA. -- Use NumPy (`numpy`) for operations that you want to be static. Use JAX NumPy {mod}`jax.numpy` for operations that you want to be traced. - -This distinction between static and traced values makes it important to think about how to keep a static value static. Consider this function: - -```{code-cell} -:tags: [raises-exception] - -import jax.numpy as jnp -from jax import jit - -@jit -def f(x): - return x.reshape(jnp.array(x.shape).prod()) - -x = jnp.ones((2, 3)) -f(x) -``` - -This fails with an error specifying that a tracer was found instead of a 1D sequence of concrete values of integer type. 
Let's add some print statements to the function to understand why this is happening: - -```{code-cell} -@jit -def f(x): -  print(f"x = {x}") -  print(f"x.shape = {x.shape}") -  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}") -  # comment this out to avoid the error: -  # return x.reshape(jnp.array(x.shape).prod()) - -f(x) -``` - -Notice that although `x` is traced, `x.shape` is a static value. However, when you use {func}`jnp.array` and {func}`jnp.prod` on this static value, it becomes a traced value, at which point it cannot be used in a function like `reshape()` that requires a static input (recall: array shapes must be static). - -A useful pattern is to: - -- Use NumPy (`numpy`) for operations that should be static (i.e., done at compile-time); and -- Use JAX NumPy (`jax.numpy`) for operations that should be traced (i.e. compiled and executed at run-time). - -For this function, it might look like this: - -```{code-cell} -from jax import jit -import jax.numpy as jnp -import numpy as np - -@jit -def f(x): -  return x.reshape((np.prod(x.shape),)) - -f(x) -``` - -For this reason, a standard convention in JAX programs is to `import numpy as np` and `import jax.numpy as jnp` so that both interfaces are available for finer control over whether operations are performed in a static matter (with `numpy`, once at compile-time) or a traced manner (with {mod}`jax.numpy`, optimized at run-time). diff --git a/docs/type_promotion.rst b/docs/type_promotion.rst index 3fbf6e88b839..103a8331df2b 100644 --- a/docs/type_promotion.rst +++ b/docs/type_promotion.rst @@ -209,4 +209,45 @@ strongly-typed array value: .. code-block:: python >>> jnp.asarray(2, dtype='int32') - Array(2, dtype=int32) \ No newline at end of file + Array(2, dtype=int32) + + +.. _strict-dtype-promotion: + +Strict dtype promotion +---------------------- +In some contexts it can be useful to disable implicit type promotion behavior, and +instead require all promotions to be explicit. This can be done in JAX by setting the +``jax_numpy_dtype_promotion`` flag to ``'strict'``. Locally, it can be done with a +context manager: + +.. code-block:: python + + >>> x = jnp.float32(1) + >>> y = jnp.int32(1) + >>> with jax.numpy_dtype_promotion('strict'): + ... z = x + y # doctest: +SKIP + ... + Traceback (most recent call last): + TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit + dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting + inputs to the desired output type, or set jax_numpy_dtype_promotion=standard. + +For convenience, strict promotion mode will still allow safe weakly-typed promotions, +so you can still write code that mixes JAX arrays and Python scalars: + +.. code-block:: python + + >>> with jax.numpy_dtype_promotion('strict'): + ... z = x + 1 + >>> print(z) + 2.0 + +If you would prefer to set the configuration globally, you can do so using the standard +configuration update:: + + jax.config.update('jax_numpy_dtype_promotion', 'strict') + +To restore the default standard type promotion, set this configuration to ``'standard'``:: + + jax.config.update('jax_numpy_dtype_promotion', 'standard') diff --git a/docs/user_guides.rst b/docs/user_guides.rst index aaa9f883d62e..f46d6b027471 100644 --- a/docs/user_guides.rst +++ b/docs/user_guides.rst @@ -11,11 +11,12 @@ or deployed codebases.
:maxdepth: 1 :caption: Debugging and Performance + notebooks/thinking_in_jax profiling device_memory_profiling debugging/index gpu_performance_tips - + persistent_compilation_cache .. toctree:: :maxdepth: 1 @@ -31,6 +32,7 @@ or deployed codebases. :caption: Run Time aot + export/index errors transfer_guard diff --git a/docs/tutorials/working-with-pytrees.md b/docs/working-with-pytrees.md similarity index 77% rename from docs/tutorials/working-with-pytrees.md rename to docs/working-with-pytrees.md index f082462401cf..2bd1cc08ecdf 100644 --- a/docs/tutorials/working-with-pytrees.md +++ b/docs/working-with-pytrees.md @@ -5,17 +5,27 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 language: python name: python3 --- +```{code-cell} +:tags: [remove-cell] + +# This ensures that code cell tracebacks appearing below will be concise. +%xmode minimal +``` + (working-with-pytrees)= # Working with pytrees -JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — they are called JAX pytrees (also known as nests, or just trees). Often, in JAX, you want to operate over these nested pytrees. This tutorial will explain how to use them, provide useful code example, and point out common "gotchas" and patterns. + + +JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — in JAX these are called pytrees. +This section will explain how to use them, provide useful code examples, and point out common "gotchas" and patterns. (pytrees-what-is-a-pytree)= @@ -25,13 +35,13 @@ A pytree is a container-like structure built out of container-like Python object In the context of machine learning (ML), a pytree can contain: -- Model parameters ({ref}`pytrees-example-jax-tree-map-ml`) +- Model parameters - Dataset entries - Reinforcement learning agent observations When working with datasets, you can often come across pytrees (such as lists of lists of dicts). -Below is an example of a simple pytree. In JAX, you can use {func}`jax.tree_util.tree_leaves`, to extract the flattened leaves from the trees, as demonstrated here: +Below is an example of a simple pytree. In JAX, you can use {func}`jax.tree.leaves`, to extract the flattened leaves from the trees, as demonstrated here: ```{code-cell} import jax @@ -47,28 +57,26 @@ example_trees = [ # Print how many leaves the pytrees have. for pytree in example_trees: - # This `jax.tree_util.tree_leaves()` method extracts the flattened leaves from the pytrees. - leaves = jax.tree_util.tree_leaves(pytree) + # This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees. + leaves = jax.tree.leaves(pytree) print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}") ``` -Any tree-like structure built out of container-like Python objects can be referred to as pytrees in JAX. Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts. Any object whose type is *not* in the pytree container registry will be treated as a leaf node in the tree. - -The pytree registry can be extended to include user-defined container classes by registering a pair of functions that specify: +Any tree-like structure built out of container-like Python objects can be treated as a pytree in JAX. 
+Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts. Any object whose type is *not* in the pytree container registry will be treated as a leaf node in the tree. -1) How to convert an instance of the container type to a `(children, metadata)` pair; and -2) How to convert this pair back to an instance of the container type. - -JAX will use these functions to canonicalize any tree of registered container objects into a flat tuple, and then reassemble the tree-like container before returning the processed data to the user. +The pytree registry can be extended to include user-defined container classes by registering the class +with functions that specify how to flatten the tree; see {ref}`pytrees-custom-pytree-nodes` below. (pytrees-common-pytree-functions)= ## Common pytree functions -JAX provides a number of utilities to operate over pytrees. These can be found in the {mod}`jax.tree_util` subpackage. +JAX provides a number of utilities to operate over pytrees. These can be found in the {mod}`jax.tree_util` subpackage; +for convenience many of these have aliases in the {mod}`jax.tree` module. -### Common function: `jax.tree_map` +### Common function: `jax.tree.map` -The most commonly used pytree function is {func}`jax.tree_map`. It works analogously to Python's native `map`, but transparently operates over entire pytrees. +The most commonly used pytree function is {func}`jax.tree.map`. It works analogously to Python's native `map`, but transparently operates over entire pytrees. Here's an example: @@ -79,20 +87,20 @@ list_of_lists = [ [1, 2, 3, 4] ] -jax.tree_map(lambda x: x*2, list_of_lists) +jax.tree.map(lambda x: x*2, list_of_lists) ``` -{func}`jax.tree_map` also allows to map a [N-ary](https://en.wikipedia.org/wiki/N-ary) function over multiple arguments. For example: +{func}`jax.tree.map` also allows mapping a [N-ary](https://en.wikipedia.org/wiki/N-ary) function over multiple arguments. For example: ```{code-cell} another_list_of_lists = list_of_lists -jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists) +jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists) ``` -When using multiple arguments with {func}`jax.tree_map`, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc. +When using multiple arguments with {func}`jax.tree.map`, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc. (pytrees-example-jax-tree-map-ml)= -### Example of `jax.tree_map` with ML model parameters +### Example of `jax.tree.map` with ML model parameters This example demonstrates how pytree operations can be useful when training a simple [multi-layer perceptron (MLP)](https://en.wikipedia.org/wiki/Multilayer_perceptron). @@ -114,10 +122,10 @@ def init_mlp_params(layer_widths): params = init_mlp_params([1, 128, 128, 1]) ``` -Use {func}`jax.tree_map` to check the shapes of the initial parameters: +Use {func}`jax.tree.map` to check the shapes of the initial parameters: ```{code-cell} -jax.tree_map(lambda x: x.shape, params) +jax.tree.map(lambda x: x.shape, params) ``` Next, define the functions for training the MLP model: @@ -147,7 +155,7 @@ def update(params, x, y): # `jax.grad` is one of many JAX functions that has # built-in support for pytrees. # This is useful - you can apply the SGD update using JAX pytree utilities. 
- return jax.tree_map( + return jax.tree.map( lambda p, g: p - LEARNING_RATE * g, params, grads ) ``` @@ -155,7 +163,7 @@ def update(params, x, y): (pytrees-custom-pytree-nodes)= ## Custom pytree nodes -This section explains how in JAX you can extend the set of Python types that will be considered _internal nodes_ in pytrees (pytree nodes) by using {func}`jax.tree_util.register_pytree_node` with {func}`jax.tree_map`. +This section explains how in JAX you can extend the set of Python types that will be considered _internal nodes_ in pytrees (pytree nodes) by using {func}`jax.tree_util.register_pytree_node` with {func}`jax.tree.map`. Why would you need this? In the previous examples, pytrees were shown as lists, tuples, and dicts, with everything else as pytree leaves. This is because if you define your own container class, it will be considered to be a pytree leaf unless you _register_ it with JAX. This is also the case even if your container class has trees inside it. For example: @@ -165,23 +173,22 @@ class Special(object): self.x = x self.y = y -jax.tree_util.tree_leaves([ +jax.tree.leaves([ Special(0, 1), Special(2, 4), ]) ``` -Accordingly, if you try to use a {func}`jax.tree_map` expecting the leaves to be elements inside the container, you will get an error: +Accordingly, if you try to use a {func}`jax.tree.map` expecting the leaves to be elements inside the container, you will get an error: ```{code-cell} -try: - jax.tree_map(lambda x: x + 1, - [ - Special(0, 1), - Special(2, 4), - ]) -except TypeError as e: - print(f'TypeError: {e}') +:tags: [raises-exception] + +jax.tree.map(lambda x: x + 1, + [ + Special(0, 1), + Special(2, 4) + ]) ``` As a solution, JAX allows to extend the set of types to be considered internal pytree nodes through a global registry of types. Additionally, the values of registered types are traversed recursively. @@ -235,11 +242,11 @@ register_pytree_node( Now you can traverse the special container structure: ```{code-cell} -jax.tree_map(lambda x: x + 1, -[ - RegisteredSpecial(0, 1), - RegisteredSpecial(2, 4), -]) +jax.tree.map(lambda x: x + 1, + [ + RegisteredSpecial(0, 1), + RegisteredSpecial(2, 4), + ]) ``` Modern Python comes equipped with helpful tools to make defining containers easier. Some will work with JAX out-of-the-box, but others require more care. @@ -257,7 +264,7 @@ class MyOtherContainer(NamedTuple): # NamedTuple subclasses are handled as pytree nodes, so # this will work out-of-the-box. -jax.tree_util.tree_leaves([ +jax.tree.leaves([ MyOtherContainer('Alice', 1, 2, 3), MyOtherContainer('Bob', 4, 5, 6) ]) @@ -267,7 +274,7 @@ Notice that the `name` field now appears as a leaf, because all tuple elements a (pytree-and-jax-transformations)= -## Pytree and JAX's transformations +## Pytrees and JAX transformations Many JAX functions, like {func}`jax.lax.scan`, operate over pytrees of arrays. In addition, all JAX function transformations can be applied to functions that accept as input and produce as output pytrees of arrays. 
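As a small sketch of what this means in practice (the parameter names and loss function below are invented for illustration), gradients computed with {func}`jax.grad` come back with exactly the same pytree structure as the parameters passed in:

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

def loss(params, x):
  # A toy scalar loss over a dict of parameters.
  return jnp.sum(params["w"] @ x + params["b"])

grads = jax.grad(loss)(params, jnp.arange(2.0))
# `grads` mirrors `params`: a dict with keys "w" and "b" and matching leaf shapes.
print(jax.tree.map(jnp.shape, grads))
```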
@@ -275,28 +282,28 @@ Some JAX function transformations take optional parameters that specify how cert For example, if you pass the following input to {func}`jax.vmap` (note that the input arguments to a function are considered a tuple): -``` -(a1, {"k1": a2, "k2": a3}) +```python +vmap(f, in_axes=(a1, {"k1": a2, "k2": a3})) ``` then you can use the following `in_axes` pytree to specify that only the `k2` argument is mapped (`axis=0`), and the rest aren’t mapped over (`axis=None`): -``` -(None, {"k1": None, "k2": 0}) +```python +vmap(f, in_axes=(None, {"k1": None, "k2": 0})) ``` The optional parameter pytree structure must match that of the main input pytree. However, the optional parameters can optionally be specified as a “prefix” pytree, meaning that a single leaf value can be applied to an entire sub-pytree. For example, if you have the same {func}`jax.vmap` input as above, but wish to only map over the dictionary argument, you can use: -``` -(None, 0) # equivalent to (None, {"k1": 0, "k2": 0}) +```python +vmap(f, in_axes=(None, 0)) # equivalent to (None, {"k1": 0, "k2": 0}) ``` Alternatively, if you want every argument to be mapped, you can write a single leaf value that is applied over the entire argument tuple pytree: -``` -0 +```python +vmap(f, in_axes=0) # equivalent to (0, {"k1": 0, "k2": 0}) ``` This happens to be the default `in_axes` value for {func}`jax.vmap`. @@ -312,8 +319,8 @@ For built-in pytree node types, the set of keys for any pytree node instance is JAX has the following `jax.tree_util.*` methods for working with key paths: -- {func}`jax.tree_util.tree_flatten_with_path`: Works similarly to {func}`jax.tree_util.tree_flatten`, but returns key paths. -- {func}`jax.tree_util.tree_map_with_path``: Works similarly to {func}`jax.tree_util.tree_map`, but the function also takes key paths as arguments. +- {func}`jax.tree_util.tree_flatten_with_path`: Works similarly to {func}`jax.tree.flatten`, but returns key paths. +- {func}`jax.tree_util.tree_map_with_path`: Works similarly to {func}`jax.tree.map`, but the function also takes key paths as arguments. - {func}`jax.tree_util.keystr`: Given a general key path, returns a reader-friendly string expression. For example, one use case is to print debugging information related to a certain leaf value: @@ -327,7 +334,7 @@ tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')] flattened, _ = jax.tree_util.tree_flatten_with_path(tree) for key_path, value in flattened: - print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}') + print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}') ``` To express key paths, JAX provides a few default key types for the built-in pytree node types, namely: @@ -340,7 +347,7 @@ You are free to define your own key types for your custom nodes. They will work ```{code-cell} for key_path, _ in flattened: - print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}') + print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}') ``` (pytrees-common-pytree-gotchas)= @@ -356,26 +363,30 @@ A common gotcha to look out for is accidentally introducing _tree nodes_ instead a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))] # Try to make another pytree with ones instead of zeros. -shapes = jax.tree_map(lambda x: x.shape, a_tree) -jax.tree_map(jnp.ones, shapes) +shapes = jax.tree.map(lambda x: x.shape, a_tree) +jax.tree.map(jnp.ones, shapes) ``` What happened here is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. 
Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`. The solution will depend on the specifics, but there are two broadly applicable options: -- Rewrite the code to avoid the intermediate {func}`jax.tree_map`. +- Rewrite the code to avoid the intermediate {func}`jax.tree.map`. - Convert the tuple into a NumPy array (`np.array`) or a JAX NumPy array (`jnp.array`), which makes the entire sequence a leaf. -### Handling of `None` by `jax.tree_utils` +### Handling of `None` by `jax.tree_util` -`jax.tree_utils` treats `None` as the absence of a pytree node, not as a leaf: +`jax.tree_util` functions treat `None` as the absence of a pytree node, not as a leaf: ```{code-cell} -jax.tree_util.tree_leaves([None, None, None]) +jax.tree.leaves([None, None, None]) ``` -Note that this is different from how the (now deprecated) [`tree` (`dm_tree`)](https://github.com/google-deepmind/tree) library used to treat `None`. +To treat `None` as a leaf, you can use the `is_leaf` argument: + +```{code-cell} +jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None) +``` ### Custom pytrees and initialization with unexpected values @@ -394,6 +405,11 @@ register_pytree_node(MyTree, lambda tree: ((tree.a,), None), tree = MyTree(jnp.arange(5.0)) jax.vmap(lambda x: x)(tree) # Error because object() is passed to `MyTree`. +``` + +```{code-cell} +:tags: [raises-exception] + jax.jacobian(lambda x: x)(tree) # Error because MyTree(...) is passed to `MyTree`. ``` @@ -429,30 +445,30 @@ def tree_unflatten(aux_data, children): This section covers some of the most common patterns with JAX pytrees. -### Transposing pytrees with `jax.tree_map` and `jax.tree_util.tree_transpose` +### Transposing pytrees with `jax.tree.map` and `jax.tree.transpose` -To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func} `jax.tree_map` (more basic) and {func}`jax.tree_util.tree_transpose` (more flexible, complex and verbose). +To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func} `jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose). -**Option 1:** Use {func}`jax.tree_map`. Here's an example: +**Option 1:** Use {func}`jax.tree.map`. Here's an example: ```{code-cell} def tree_transpose(list_of_trees): """ Converts a list of trees of identical structure into a single tree of lists. """ - return jax.tree_map(lambda *xs: list(xs), *list_of_trees) + return jax.tree.map(lambda *xs: list(xs), *list_of_trees) # Convert a dataset from row-major to column-major. episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)] tree_transpose(episode_steps) ``` -**Option 2:** For more complex transposes, use {func}`jax.tree_util.tree_transpose`, which is more verbose, but allows you specify the structure of the inner and outer pytree for more flexibility. For example: +**Option 2:** For more complex transposes, use {func}`jax.tree.transpose`, which is more verbose, but allows you specify the structure of the inner and outer pytree for more flexibility. 
For example: ```{code-cell} -jax.tree_util.tree_transpose( - outer_treedef = jax.tree_util.tree_structure([0 for e in episode_steps]), - inner_treedef = jax.tree_util.tree_structure(episode_steps[0]), +jax.tree.transpose( + outer_treedef = jax.tree.structure([0 for e in episode_steps]), + inner_treedef = jax.tree.structure(episode_steps[0]), pytree_to_transpose = episode_steps ) ``` diff --git a/examples/advi.py b/examples/advi.py index 35ee94a58d2f..68092b2cf74a 100644 --- a/examples/advi.py +++ b/examples/advi.py @@ -78,7 +78,7 @@ def funnel_log_density(params): @jit def objective(params, t): - rng = random.PRNGKey(t) + rng = random.key(t) return -batch_elbo(funnel_log_density, rng, params, num_samples) # Set up figure. @@ -107,7 +107,7 @@ def callback(params, t): # Plot random samples from variational distribution. # Here we clone the rng used in computing the objective # so that we can show exactly the same samples. - rngs = random.split(random.PRNGKey(t), num_samples) + rngs = random.split(random.key(t), num_samples) samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params) ax.plot(samples[:, 0], samples[:, 1], 'b.') diff --git a/examples/differentially_private_sgd.py b/examples/differentially_private_sgd.py index 4777554b1127..ca368098d243 100644 --- a/examples/differentially_private_sgd.py +++ b/examples/differentially_private_sgd.py @@ -182,7 +182,7 @@ def main(_): num_train = train_images.shape[0] num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value) num_batches = num_complete_batches + bool(leftover) - key = random.PRNGKey(_SEED.value) + key = random.key(_SEED.value) def data_stream(): rng = npr.RandomState(_SEED.value) diff --git a/examples/examples_test.py b/examples/examples_test.py index e2ca51d78155..007e8e65824d 100644 --- a/examples/examples_test.py +++ b/examples/examples_test.py @@ -22,20 +22,22 @@ import numpy as np -from jax import lax +import jax from jax import random import jax.numpy as jnp +from jax._src import test_util as jtu + +del jtu # Needed for flags sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from examples import kernel_lsq sys.path.pop() -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape): - jax_rng = random.PRNGKey(0) + jax_rng = random.key(0) result_shape, params = init_fun(jax_rng, input_shape) result = apply_fun(params, test_case.rng.normal(size=input_shape).astype("float32")) test_case.assertEqual(result.shape, result_shape) @@ -52,12 +54,13 @@ def testKernelRegressionGram(self): kernel = lambda x, y: jnp.dot(x, y) np.testing.assert_allclose(kernel_lsq.gram(kernel, xs), jnp.dot(xs, xs.T), atol=1E-5) + @jax.default_matmul_precision("float32") def testKernelRegressionTrainAndPredict(self): n, d = 100, 20 truth = self.rng.normal(size=d) xs = self.rng.normal(size=(n, d)) ys = jnp.dot(xs, truth) - kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH) + kernel = lambda x, y: jnp.dot(x, y) predict = kernel_lsq.train(kernel, xs, ys) np.testing.assert_allclose(predict(xs), ys, atol=1e-3, rtol=1e-3) diff --git a/examples/gaussian_process_regression.py b/examples/gaussian_process_regression.py index 070943b72413..75f7398d12ca 100644 --- a/examples/gaussian_process_regression.py +++ b/examples/gaussian_process_regression.py @@ -17,10 +17,11 @@ from absl import app from functools import partial + +import jax from jax import grad from jax import jit from jax import vmap -from 
jax import config import jax.numpy as jnp import jax.random as random import jax.scipy as scipy @@ -30,7 +31,7 @@ def main(unused_argv): numpts = 7 - key = random.PRNGKey(0) + key = random.key(0) eye = jnp.eye(numpts) def cov_map(cov_func, xs, xs2=None): @@ -125,5 +126,5 @@ def train_step(params, momentums, scales, x, y): mu.flatten() - std * 2, mu.flatten() + std * 2) if __name__ == "__main__": - config.config_with_absl() + jax.config.config_with_absl() app.run(main) diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index a36b6c1949c1..fccf0cc37048 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -21,14 +21,11 @@ cc_binary( srcs = ["main.cc"], tags = ["manual"], deps = [ + "//third_party/absl/status:statusor", "@xla//xla:literal", "@xla//xla:literal_util", - "@xla//xla:shape_util", - "@xla//xla:status", - "@xla//xla:statusor", "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:tfrt_cpu_pjrt_client", - "@xla//xla/service:hlo_proto_cc", + "@xla//xla/pjrt/cpu:cpu_client", "@xla//xla/tools:hlo_module_loader", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:platform_port", diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index a55a6c3e1ea4..2a8f8d4debba 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -27,10 +27,10 @@ limitations under the License. // // To load and run the HloModule, // -// $ bazel build examples/jax_cpp:main --experimental_repo_remote_exec --check_visibility=false -// $ bazel-bin/examples/jax_cpp/main -// 2021-01-12 15:35:28.316880: I examples/jax_cpp/main.cc:65] result = ( -// f32[2,2] { +// $ bazel build examples/jax_cpp:main --experimental_repo_remote_exec \ +// --check_visibility=false +// $ bazel-bin/examples/jax_cpp/main 2021-01-12 +// 15:35:28.316880: I examples/jax_cpp/main.cc:65] result = ( f32[2,2] { // { 1.5, 1.5 }, // { 3.5, 3.5 } // } @@ -40,12 +40,11 @@ limitations under the License. 
#include #include +#include "third_party/absl/status/statusor.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" -#include "xla/status.h" -#include "xla/statusor.h" #include "xla/tools/hlo_module_loader.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" diff --git a/examples/mnist_classifier.py b/examples/mnist_classifier.py index a0fe8b996c98..a7730ab2b6aa 100644 --- a/examples/mnist_classifier.py +++ b/examples/mnist_classifier.py @@ -50,7 +50,7 @@ def accuracy(params, batch): Dense(10), LogSoftmax) if __name__ == "__main__": - rng = random.PRNGKey(0) + rng = random.key(0) step_size = 0.001 num_epochs = 10 diff --git a/examples/mnist_vae.py b/examples/mnist_vae.py index 141be978f635..df207afd8749 100644 --- a/examples/mnist_vae.py +++ b/examples/mnist_vae.py @@ -87,14 +87,14 @@ def image_grid(nrow, ncol, imagevecs, imshape): batch_size = 32 nrow, ncol = 10, 10 # sampled image grid size - test_rng = random.PRNGKey(1) # fixed prng key for evaluation + test_rng = random.key(1) # fixed prng key for evaluation imfile = os.path.join(os.getenv("TMPDIR", "/tmp/"), "mnist_vae_{:03d}.png") train_images, _, test_images, _ = datasets.mnist(permute_train=True) num_complete_batches, leftover = divmod(train_images.shape[0], batch_size) num_batches = num_complete_batches + bool(leftover) - enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2)) + enc_init_rng, dec_init_rng = random.split(random.key(2)) _, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28)) _, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10)) init_params = init_encoder_params, init_decoder_params @@ -131,7 +131,7 @@ def evaluate(opt_state, images): opt_state = opt_init(init_params) for epoch in range(num_epochs): tic = time.time() - opt_state = run_epoch(random.PRNGKey(epoch), opt_state, train_images) + opt_state = run_epoch(random.key(epoch), opt_state, train_images) test_elbo, sampled_images = evaluate(opt_state, test_images) print(f"{epoch: 3d} {test_elbo} ({time.time() - tic:.3f} sec)") plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray) diff --git a/jax/BUILD b/jax/BUILD index 8386660f3c19..a4de4550c092 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -15,6 +15,7 @@ # JAX is Autograd and XLA load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "if_building_jaxlib", @@ -25,6 +26,7 @@ load( "jax_internal_test_harnesses_visibility", "jax_test_util_visibility", "jax_visibility", + "mosaic_gpu_internal_users", "mosaic_internal_users", "pallas_gpu_internal_users", "pallas_tpu_internal_users", @@ -33,7 +35,6 @@ load( "pytype_library", "pytype_strict_library", ) -load("@rules_python//python:defs.bzl", "py_library") package( default_applicable_licenses = [], @@ -104,6 +105,7 @@ package_group( name = "pallas_gpu_users", packages = [ "//...", + "//learning/brain/research/jax", ] + pallas_gpu_internal_users, ) @@ -111,9 +113,18 @@ package_group( name = "pallas_tpu_users", packages = [ "//...", + "//learning/brain/research/jax", ] + pallas_tpu_internal_users, ) +package_group( + name = "mosaic_gpu_users", + packages = [ + "//...", + "//learning/brain/research/jax", + ] + mosaic_gpu_internal_users, +) + # JAX-private test utilities. 
py_library( # This build target is required in order to use private test utilities in jax._src.test_util, @@ -170,14 +181,16 @@ py_library( ] + jax_internal_export_back_compat_test_util_visibility, deps = [ ":jax", - "//jax/experimental/export", ] + py_deps("numpy"), ) py_library( name = "internal_export_back_compat_test_data", testonly = 1, - srcs = glob(["_src/internal_test_util/export_back_compat_test_data/*.py"]), + srcs = glob([ + "_src/internal_test_util/export_back_compat_test_data/*.py", + "_src/internal_test_util/export_back_compat_test_data/pallas/*.py", + ]), visibility = [ ":internal", ], @@ -191,6 +204,7 @@ py_library_providing_imports_info( "_src/ad_checkpoint.py", "_src/api.py", "_src/array.py", + "_src/blocked_sampler.py", "_src/callback.py", "_src/checkify.py", "_src/custom_batching.py", @@ -199,6 +213,7 @@ py_library_providing_imports_info( "_src/debugging.py", "_src/dispatch.py", "_src/dlpack.py", + "_src/earray.py", "_src/flatten_util.py", "_src/interpreters/__init__.py", "_src/interpreters/ad.py", @@ -210,6 +225,7 @@ py_library_providing_imports_info( "_src/public_test_util.py", "_src/random.py", "_src/shard_alike.py", + "_src/sourcemap.py", "_src/stages.py", "_src/tree.py", ] + glob( @@ -218,6 +234,7 @@ py_library_providing_imports_info( "_src/debugger/**/*.py", "_src/extend/**/*.py", "_src/image/**/*.py", + "_src/export/**/*.py", "_src/lax/**/*.py", "_src/nn/**/*.py", "_src/numpy/**/*.py", @@ -271,6 +288,7 @@ py_library_providing_imports_info( ":cloud_tpu_init", ":compilation_cache_internal", ":compiler", + ":compute_on", ":config", ":core", ":custom_api_util", @@ -303,7 +321,7 @@ py_library_providing_imports_info( ":xla", ":xla_bridge", "//jax/_src/lib", - ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + jax_extra_deps, + ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps, ) pytype_strict_library( @@ -357,7 +375,9 @@ pytype_strict_library( name = "cloud_tpu_init", srcs = ["_src/cloud_tpu_init.py"], deps = [ + ":config", ":hardware_utils", + ":version", ], ) @@ -370,6 +390,7 @@ pytype_strict_library( ":compilation_cache_interface", ":config", ":gfile_cache", + ":lru_cache", ":monitoring", ":path", "//jax/_src/lib", @@ -395,6 +416,15 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "lru_cache", + srcs = ["_src/lru_cache.py"], + deps = [ + ":compilation_cache_interface", + ":config", + ] + py_deps("filelock"), +) + pytype_strict_library( name = "config", srcs = ["_src/config.py"], @@ -432,7 +462,9 @@ pytype_strict_library( "_src/linear_util.py", ], deps = [ + ":compute_on", ":config", + ":deprecations", ":dtypes", ":effects", ":pretty_printer", @@ -548,6 +580,7 @@ pytype_strict_library( ":partial_eval", ":path", ":pickle_util", + ":sharding", ":sharding_impls", ":source_info_util", ":state_types", @@ -580,6 +613,7 @@ pytype_strict_library( exclude = [ "experimental/pallas/gpu.py", "experimental/pallas/tpu.py", + "experimental/pallas/ops/gpu/**/*.py", "experimental/pallas/ops/tpu/**/*.py", ], ), @@ -600,9 +634,14 @@ pytype_strict_library( ":pallas_tpu_users", ], deps = [ - ":pallas", # buildcleaner: keep + ":pallas", # build_cleaner: keep ":tpu_custom_call", - "//jax/_src/pallas/mosaic", + "//jax/_src/pallas/mosaic:core", + "//jax/_src/pallas/mosaic:lowering", + "//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/mosaic:pipeline", + "//jax/_src/pallas/mosaic:primitives", + "//jax/_src/pallas/mosaic:random", ], ) @@ -640,16 +679,48 @@ 
pytype_strict_library( ], deps = [ ":pallas", - "//jax/_src/pallas/triton", + "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/triton:primitives", ], ) +# This target only supports sm_90 GPUs. +py_library( + name = "mosaic_gpu", + srcs = glob(["experimental/mosaic/gpu/*.py"]), + visibility = [ + ":mosaic_gpu_users", + ], + deps = [ + ":config", + ":core", + ":jax", + ":mlir", + "//jax/_src/lib", + "//jaxlib/mlir:arithmetic_dialect", + "//jaxlib/mlir:builtin_dialect", + "//jaxlib/mlir:func_dialect", + "//jaxlib/mlir:gpu_dialect", + "//jaxlib/mlir:ir", + "//jaxlib/mlir:llvm_dialect", + "//jaxlib/mlir:math_dialect", + "//jaxlib/mlir:memref_dialect", + "//jaxlib/mlir:nvgpu_dialect", + "//jaxlib/mlir:nvvm_dialect", + "//jaxlib/mlir:pass_manager", + "//jaxlib/mlir:scf_dialect", + "//jaxlib/mlir:vector_dialect", + ] + py_deps("absl/flags") + py_deps("numpy"), +) + pytype_strict_library( name = "partial_eval", srcs = ["_src/interpreters/partial_eval.py"], deps = [ ":ad_util", ":api_util", + ":compute_on", ":config", ":core", ":dtypes", @@ -711,20 +782,28 @@ pytype_strict_library( name = "sharding", srcs = ["_src/sharding.py"], deps = [ + ":op_shardings", ":util", ":xla_bridge", "//jax/_src/lib", ], ) +pytype_strict_library( + name = "compute_on", + srcs = ["_src/compute_on.py"], + deps = [], +) + pytype_strict_library( name = "layout", srcs = ["_src/layout.py"], deps = [ - ":util", - ":xla_bridge", + ":dtypes", + ":sharding", + ":sharding_impls", "//jax/_src/lib", - ], + ] + py_deps("numpy"), ) pytype_strict_library( @@ -732,6 +811,7 @@ pytype_strict_library( srcs = ["_src/sharding_impls.py"], deps = [ ":config", + ":core", ":mesh", ":op_shardings", ":partition_spec", @@ -826,6 +906,8 @@ pytype_strict_library( "//jax/_src/lib", ] + if_building_jaxlib([ "//jaxlib/mlir:ir", + "//jaxlib/mlir:mhlo_dialect", + "//jaxlib/mlir:pass_manager", "//jaxlib/mlir:stablehlo_dialect", ]) + py_deps("numpy") + py_deps("absl/flags"), ) @@ -870,6 +952,7 @@ pytype_strict_library( "_src/clusters/cluster.py", "_src/clusters/ompi_cluster.py", "_src/clusters/slurm_cluster.py", + "_src/clusters/mpi4py_cluster.py", "_src/distributed.py", "_src/xla_bridge.py", ], @@ -993,7 +1076,11 @@ pytype_library( pytype_library( name = "experimental_host_callback", - srcs = ["experimental/host_callback.py"], + srcs = [ + "experimental/__init__.py", # To support JAX_HOST_CALLBACK_LEGACY=False + "experimental/host_callback.py", + "experimental/x64_context.py", # To support JAX_HOST_CALLBACK_LEGACY=False + ], visibility = ["//visibility:public"], deps = [ ":jax", @@ -1024,6 +1111,7 @@ pytype_strict_library( visibility = [":jax_extend_users"], deps = [ "//jax/extend", + "//jax/extend:backend", "//jax/extend:core", "//jax/extend:linear_util", "//jax/extend:random", diff --git a/jax/__init__.py b/jax/__init__.py index 7086b9e9c66a..cedbec4b8d75 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -47,7 +47,7 @@ from jax._src.config import ( config as config, enable_checks as enable_checks, - enable_key_reuse_checks as enable_key_reuse_checks, + debug_key_reuse as debug_key_reuse, check_tracer_leaks as check_tracer_leaks, checking_leaks as checking_leaks, enable_custom_prng as enable_custom_prng, @@ -81,7 +81,7 @@ from jax._src.api import block_until_ready as block_until_ready from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint from jax._src.ad_checkpoint import checkpoint_policies as 
checkpoint_policies -from jax._src.api import clear_backends as clear_backends +from jax._src.api import clear_backends as _deprecated_clear_backends from jax._src.api import clear_caches as clear_caches from jax._src.custom_derivatives import closure_convert as closure_convert from jax._src.custom_derivatives import custom_gradient as custom_gradient @@ -118,13 +118,14 @@ from jax._src.api import pmap as pmap from jax._src.xla_bridge import process_count as process_count from jax._src.xla_bridge import process_index as process_index -from jax._src.callback import pure_callback_api as pure_callback +from jax._src.callback import pure_callback as pure_callback from jax._src.ad_checkpoint import checkpoint_wrapper as remat from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct from jax._src.api import value_and_grad as value_and_grad from jax._src.api import vjp as vjp from jax._src.api import vmap as vmap -from jax._src.api import xla_computation as xla_computation +from jax._src.api import xla_computation as _deprecated_xla_computation +from jax._src.sharding_impls import NamedSharding as NamedSharding # Force import, allowing jax.interpreters.* to be used after import jax. from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla @@ -133,6 +134,7 @@ from jax._src.array import ( make_array_from_single_device_arrays as make_array_from_single_device_arrays, make_array_from_callback as make_array_from_callback, + make_array_from_process_local_data as make_array_from_process_local_data, ) from jax._src.tree_util import ( @@ -176,9 +178,10 @@ import jax.experimental.compilation_cache.compilation_cache as _ccache del _ccache -# TODO(jakevdp): remove this when jax/config.py is removed. from jax._src.deprecations import register as _register_deprecation -_register_deprecation("jax.config", "config-module") +_register_deprecation("jax-experimental-maps-module") +_register_deprecation('jax-scipy-beta-args') +_register_deprecation('tracer-hash') del _register_deprecation _deprecations = { @@ -218,10 +221,25 @@ "or jax.tree_util.tree_map (any JAX version).", _deprecated_tree_map ), + # Added Mar 18, 2024 + "clear_backends": ( + "jax.clear_backends is deprecated.", + _deprecated_clear_backends + ), + # Added Jun 16, 2024 + "xla_computation": ( + "jax.xla_computation is deprecated. Please use the AOT APIs; see " + "https://jax.readthedocs.io/en/latest/aot.html. For example, replace " + "xla_computation(f)(*xs) with jit(f).lower(*xs).compiler_ir('hlo'). 
See " + "CHANGELOG.md for 0.4.30 for more examples.", + _deprecated_xla_computation + ), } import typing as _typing if _typing.TYPE_CHECKING: + from jax._src.api import clear_backends as clear_backends + from jax._src.api import xla_computation as xla_computation from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf from jax._src.tree_util import tree_flatten as tree_flatten from jax._src.tree_util import tree_leaves as tree_leaves diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 1a1484e27eef..423dbaf9ce58 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -70,4 +70,4 @@ def _make_concrete_python_scalar(t, x): for t in dtypes.python_scalar_dtypes: core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t) -core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) # type: ignore +core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index f3112c837eab..6f752e0e2e4a 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -14,11 +14,11 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools from functools import partial import logging -from typing import Any, Callable +from typing import Any import types import numpy as np @@ -40,6 +40,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax import convolution as lax_convolution +from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.traceback_util import api_boundary from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr @@ -465,11 +466,14 @@ def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]: src = 'from the argument at flattened index {i}' results.append((v.aval, src)) + named_vars = {v: e for e in jaxpr.eqns if e.primitive is name_p + for v in e.invars} + for eqn in jaxpr.eqns: src = source_info_util.summarize(eqn.source_info) for v in eqn.outvars: if v in res_vars: - if eqn.primitive is name_p: + if eqn.primitive is name_p or v in named_vars and (eqn := named_vars[v]): results.append((v.aval, f"named '{eqn.params['name']}' from {src}")) elif str(eqn.primitive) == 'xla_call': results.append((v.aval, @@ -655,7 +659,7 @@ def transposed(*args_flat): transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(transposed, in_avals) transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts) - return transposed_jaxpr, cell.in_cts_zero # type: ignore + return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error def remat_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *, jaxpr, **params): @@ -684,7 +688,7 @@ def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn new_eqn = pe.new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [v for v, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info) + eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info, eqn.ctx) return used_inputs, new_eqn pe.dce_rules[remat_p] = remat_dce @@ -692,9 +696,9 @@ def _has_effects(effects) -> bool: return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)}) -def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool, - differentiated: bool, is_gpu_platform: bool = False, - **_): +def remat_expansion(*args, 
jaxpr: core.Jaxpr, prevent_cse: bool, + differentiated: bool, is_gpu_platform: bool = False, + **_): assert not jaxpr.constvars if differentiated and prevent_cse: @@ -763,13 +767,33 @@ def dummy_comp(*args): unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=()) return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args) -mlir.register_lowering( - remat_p, mlir.lower_fun(remat_lowering, multiple_results=True)) -mlir.register_lowering( - remat_p, - mlir.lower_fun(partial(remat_lowering, is_gpu_platform=True), - multiple_results=True), - platform="gpu") +def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool, + differentiated: bool, policy, is_gpu_platform=False): + jaxpr_args: Sequence[Sequence[ir.Value]] + if differentiated and prevent_cse: + # If we're using the loop or cond lowerings, use the slower lower_fun + # based path. + if not config.remat_opt_barrier.value: + return mlir.lower_fun(remat_expansion, multiple_results=True)( + ctx, *args, jaxpr=jaxpr, prevent_cse=prevent_cse, + differentiated=differentiated, policy=policy, + is_gpu_platform=is_gpu_platform) + + arg_types = map(mlir.aval_to_ir_types, ctx.avals_in) + flat_args = mlir.flatten_lowering_ir_args(args) + barrier_op = hlo.OptimizationBarrierOp(flat_args) + jaxpr_args = util.unflatten(barrier_op.results, map(len, arg_types)) + else: + jaxpr_args = map(mlir.wrap_singleton_ir_values, args) + outs, tokens_out = mlir.jaxpr_subcomp( + ctx.module_context, jaxpr, ctx.name_stack.extend('checkpoint'), + ctx.tokens_in, (), *jaxpr_args, dim_var_values=ctx.dim_var_values) + ctx.set_tokens_out(tokens_out) + return outs + +mlir.register_lowering(remat_p, _remat_lowering) +mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True), + platform="gpu") def _optimization_barrier_abstract_eval(*args): return args diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index 257fcc7c527c..90ae6c1413ec 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -13,8 +13,9 @@ # limitations under the License. 
from __future__ import annotations +from collections.abc import Callable import types -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar from jax._src import core from jax._src import traceback_util diff --git a/jax/_src/api.py b/jax/_src/api.py index 005e0ceca3a1..4a42693c2e8f 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -23,12 +23,12 @@ from __future__ import annotations import collections -from collections.abc import Generator, Hashable, Iterable, Sequence -from functools import partial +from collections.abc import Callable, Generator, Hashable, Iterable, Sequence +from functools import partial, lru_cache import inspect import math import typing -from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload, +from typing import (Any, Literal, NamedTuple, TypeVar, overload, cast) import weakref @@ -65,14 +65,14 @@ from jax._src.lax import lax as lax_internal from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib import pmap_lib from jax._src.sharding import Sharding -from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind, - XLACompatibleSharding) +from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind +from jax._src.layout import Layout, AutoLayout from jax._src.traceback_util import api_boundary from jax._src import tree_util -from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps +from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps, + split_list) from jax._src import util from jax._src.interpreters import ad @@ -156,60 +156,37 @@ def jit( """Sets up ``fun`` for just-in-time compilation with XLA. Args: - fun: Function to be jitted. ``fun`` should be a pure function, as - side-effects may only be executed once. - - The arguments and return value of ``fun`` should be arrays, - scalars, or (nested) standard Python containers (tuple/list/dict) thereof. - Positional arguments indicated by ``static_argnums`` can be anything at - all, provided they are hashable and have an equality operation defined. - Static arguments are included as part of a compilation cache key, which is - why hash and equality operators must be defined. - - JAX keeps a weak reference to ``fun`` for use as a compilation cache key, - so the object ``fun`` must be weakly-referenceable. Most :class:`Callable` - objects will already satisfy this requirement. - in_shardings: Pytree of structure matching that of arguments to ``fun``, - with all actual arguments replaced by resource assignment specifications. - It is also valid to specify a pytree prefix (e.g. one value in place of a - whole subtree), in which case the leaves get broadcast to all values in - that subtree. - - The ``in_shardings`` argument is optional. JAX will infer the shardings - from the input :py:class:`jax.Array`'s and defaults to replicating the input - if the sharding cannot be inferred. - - The valid resource assignment specifications are: - - :py:class:`XLACompatibleSharding`, which will decide how the value - will be partitioned. With this, using a mesh context manager is not - required. - - :py:obj:`None`, will give JAX the freedom to choose whatever sharding - it wants. - For in_shardings, JAX will mark is as replicated but this behavior - can change in the future. - For out_shardings, we will rely on the XLA GSPMD partitioner to - determine the output shardings. 
- - The size of every dimension has to be a multiple of the total number of - resources assigned to it. This is similar to pjit's in_shardings. - out_shardings: Like ``in_shardings``, but specifies resource - assignment for function outputs. This is similar to pjit's - out_shardings. - - The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit` - will use GSPMD's sharding propagation to figure out what the sharding of the - output(s) should be. - static_argnums: An optional int or collection of ints that specify which - positional arguments to treat as static (compile-time constant). - Operations that only depend on static arguments will be constant-folded in - Python (during tracing), and so the corresponding argument values can be - any Python object. + fun: Function to be jitted. ``fun`` should be a pure function. + + The arguments and return value of ``fun`` should be arrays, scalar, or + (nested) standard Python containers (tuple/list/dict) thereof. Positional + arguments indicated by ``static_argnums`` can be any hashable type. Static + arguments are included as part of a compilation cache key, which is why + hash and equality operators must be defined. JAX keeps a weak reference to + ``fun`` for use as a compilation cache key, so the object ``fun`` must be + weakly-referenceable. + in_shardings: optional, a :py:class:`Sharding` or pytree with + :py:class:`Sharding` leaves and structure that is a tree prefix of the + positional arguments tuple to ``fun``. If provided, the positional + arguments passed to ``fun`` must have shardings that are compatible with + ``in_shardings`` or an error is raised, and the compiled computation has + input shardings corresponding to ``in_shardings``. If not provided, the + compiled computation's input shardings are inferred from argument + shardings. + out_shardings: optional, a :py:class:`Sharding` or pytree with + :py:class:`Sharding` leaves and structure that is a tree prefix of the + output of ``fun``. If provided, it has the same effect as applying + corresponding :py:func:`jax.lax.with_sharding_constraint`s to the output + of ``fun``. + static_argnums: optional, an int or collection of ints that specify which + positional arguments to treat as static (trace- and compile-time + constant). Static arguments should be hashable, meaning both ``__hash__`` and - ``__eq__`` are implemented, and immutable. Calling the jitted function - with different values for these constants will trigger recompilation. - Arguments that are not arrays or containers thereof must be marked as - static. + ``__eq__`` are implemented, and immutable. Otherwise they can be arbitrary + Python objects. Calling the jitted function with different values for + these constants will trigger recompilation. Arguments that are not + array-like or containers thereof must be marked as static. If neither ``static_argnums`` nor ``static_argnames`` is provided, no arguments are treated as static. If ``static_argnums`` is not provided but @@ -220,17 +197,18 @@ def jit( provided, ``inspect.signature`` is not used, and only actual parameters listed in either ``static_argnums`` or ``static_argnames`` will be treated as static. - static_argnames: An optional string or collection of strings specifying + static_argnames: optional, a string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on ``static_argnums`` for details. 
If not provided but ``static_argnums`` is set, the default is based on calling ``inspect.signature(fun)`` to find corresponding named arguments. - donate_argnums: Specify which positional argument buffers are "donated" to - the computation. It is safe to donate argument buffers if you no longer - need them once the computation has finished. In some cases XLA can make - use of donated buffers to reduce the amount of memory needed to perform a + donate_argnums: optional, collection of integers to specify which positional + argument buffers can be overwritten by the computation and marked deleted + in the caller. It is safe to donate argument buffers if you no longer need + them once the computation has started. In some cases XLA can make use of + donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a - result. You should not reuse buffers that you donate to a computation, JAX + result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. By default, no argument buffers are donated. @@ -246,15 +224,16 @@ def jit( For more details on buffer donation see the `FAQ `_. - donate_argnames: An optional string or collection of strings specifying + donate_argnames: optional, a string or collection of strings specifying which named arguments are donated to the computation. See the comment on ``donate_argnums`` for details. If not provided but ``donate_argnums`` is set, the default is based on calling ``inspect.signature(fun)`` to find corresponding named arguments. - keep_unused: If `False` (the default), arguments that JAX determines to be - unused by `fun` *may* be dropped from resulting compiled XLA executables. - Such arguments will not be transferred to the device nor provided to the - underlying executable. If `True`, unused arguments will not be pruned. + keep_unused: optional boolean. If `False` (the default), arguments that JAX + determines to be unused by `fun` *may* be dropped from resulting compiled + XLA executables. Such arguments will not be transferred to the device nor + provided to the underlying executable. If `True`, unused arguments will + not be pruned. device: This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via :py:func:`jax.devices`.) The default is inherited @@ -263,9 +242,8 @@ def jit( backend: This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or ``'tpu'``. - inline: Specify whether this function should be inlined into enclosing - jaxprs (rather than being represented as an application of the xla_call - primitive with its own subjaxpr). Default False. + inline: Optional boolean. Specify whether this function should be inlined + into enclosing jaxprs. Default False. Returns: A wrapped version of ``fun``, set up for just-in-time compilation. @@ -280,14 +258,14 @@ def jit( ... def selu(x, alpha=1.67, lmbda=1.05): ... 
return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha) >>> - >>> key = jax.random.PRNGKey(0) + >>> key = jax.random.key(0) >>> x = jax.random.normal(key, (10,)) >>> print(selu(x)) # doctest: +SKIP [-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748 -0.85743 -0.78232 0.76827 0.59566 ] - To pass arguments such as ``static_argnames`` when decorating a function, a common - pattern is to use :func:`functools.partial`: + To pass arguments such as ``static_argnames`` when decorating a function, a + common pattern is to use :func:`functools.partial`: >>> from functools import partial >>> @@ -300,34 +278,10 @@ def jit( >>> g(jnp.arange(4), 3) Array([ 0, 1, 256, 6561], dtype=int32) """ - (in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, - static_argnames) = pjit.pre_infer_params( + return pjit.make_jit( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes) - - fun_sourceinfo = api_util.fun_sourceinfo(fun) - fun_signature = api_util.fun_signature(fun) - - def infer_params(*args, **kwargs): - # TODO(yashkatariya): Remove this when it's added on jit. - in_layouts = kwargs.pop('_in_layouts', None) - out_layouts = kwargs.pop('_out_layouts', None) - pjit_info_args = pjit.PjitInfo( - fun=fun, fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature, - in_shardings=in_shardings, - out_shardings=out_shardings, static_argnums=static_argnums, - static_argnames=static_argnames, donate_argnums=donate_argnums, - donate_argnames=donate_argnames, device=device, backend=backend, - keep_unused=keep_unused, inline=inline, resource_env=None, - abstracted_axes=abstracted_axes, in_layouts=in_layouts, - out_layouts=out_layouts) - return pjit.common_infer_params(pjit_info_args, *args, **kwargs) - - has_explicit_sharding = pjit._pjit_explicit_sharding( - in_shardings, out_shardings, device, backend) - return pjit.post_infer_params(fun, infer_params, static_argnums, - static_argnames, donate_argnums, - abstracted_axes, has_explicit_sharding) + static_argnums, static_argnames, device, backend, abstracted_axes, + keep_unused, inline, use_resource_env=False) @contextmanager @@ -558,7 +512,7 @@ def computation_maker(*args, **kwargs): f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False) args_flat, in_tree = tree_flatten((dyn_args, kwargs)) if donate_argnums: - donated_invars = donation_vector(donate_argnums, (), dyn_args, kwargs) + donated_invars = donation_vector(donate_argnums, (), in_tree) else: donated_invars = (False,) * len(args_flat) @@ -589,10 +543,10 @@ def computation_maker(*args, **kwargs): arg_shardings=None, result_shardings=None, lowering_parameters=mlir.LoweringParameters()) + + m = mlir.module_to_bytecode(lowering_result.module) built = xc._xla.mlir.mlir_module_to_xla_computation( - mlir.module_to_string(lowering_result.module), - use_tuple_args=tuple_args, - return_tuple=True) + m, use_tuple_args=tuple_args, return_tuple=True) out_shapes_flat = [ ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals] out_shape = tree_unflatten(out_tree(), out_shapes_flat) @@ -1676,7 +1630,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, args, in_tree = tree_flatten((dyn_args, kwargs)) if donate_tuple and not config.debug_nans.value: - donated_invars = donation_vector(donate_tuple, (), dyn_args, kwargs) + donated_invars = donation_vector(donate_tuple, (), in_tree) else: donated_invars = (False,) * len(args) try: @@ -1853,61 
+1807,55 @@ def cache_miss(*args, **kwargs): cpp_mapped_f = pmap_lib.pmap( fun, cache_miss, static_broadcasted_tuple, - pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg, # type: ignore + lambda x, s: pxla.shard_args([s], [x])[0], pytree_registry=tree_util.default_registry) _pmap_cache_clears.add(cpp_mapped_f) pmap_f = wraps(fun)(cpp_mapped_f) - pmap_f.lower = _pmap_lower( - fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices, - backend, axis_size, donate_tuple) - - return pmap_f - -_pmap_cache_clears = weakref.WeakSet() # type: ignore - + @api_boundary + def lower(*args, **kwargs): + return trace(*args, **kwargs).lower() -def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, - devices, backend, axis_size, donate_tuple): # noqa: F811 - """Make a ``lower`` method for pmapped functions.""" - # If the function we returned from ``pmap`` were a class instance, - # this might naturally be a method, with ``fun`` as a ``self`` and - # all the other arguments stored as attributes. @api_boundary - def lower(*args, **kwargs) -> stages.Lowered: - """Lower a parallel-mapped form of this function for the given arguments. - - A parallel-mapped and lowered function is staged out of Python and - translated to a compiler's input language, possibly in a - backend-dependent manner. It is ready for compilation but is not yet - compiled. It represents a function intended for SPMD execution on - multiple devices. - - Returns: - A ``Lowered`` instance representing the post-map lowering. - """ - lowering_parameters = kwargs.pop( - '_experimental_lowering_parameters', mlir.LoweringParameters()) + def trace(*args, **kwargs): p = _prepare_pmap( fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, devices, backend, axis_size, args, kwargs) abstract_args = list(map(shaped_abstractify, p.flat_args)) - computation = pxla.lower_parallel_callable( + closed_jaxpr, xc_backend, replicas, shards, pci = pxla.get_pmap_jaxpr( p.flat_fun, backend, axis_name, axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, devices=p.devices, name=p.flat_fun.__name__, in_axes=p.in_axes_flat, out_axes_thunk=p.out_axes_thunk, + avals=abstract_args) + lower_callable = partial( + pxla.lower_parallel_callable, p.flat_fun, axis_name, + axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, + devices=p.devices, + name=p.flat_fun.__name__, + in_axes=p.in_axes_flat, donated_invars=p.donated_invars, is_explicit_global_axis_size=p.is_explicit_global_axis_size, avals=abstract_args, - lowering_parameters=lowering_parameters) - return stages.Lowered.from_flat_info( - computation, p.in_tree, abstract_args, donate_tuple, p.out_tree()) + closed_jaxpr=closed_jaxpr, + backend=xc_backend, + replicas=replicas, + shards=shards, + pci=pci) + args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple) + return stages.Traced(closed_jaxpr, args_info, p.flat_fun.__name__, + p.out_tree(), lower_callable) + + pmap_f.lower = lower + pmap_f.trace = trace + + return pmap_f + +_pmap_cache_clears = weakref.WeakSet() # type: ignore - return lower def jvp( fun: Callable, primals, tangents, has_aux: bool = False @@ -1965,7 +1913,7 @@ def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False): raise TypeError("primal and tangent arguments to jax.jvp must have the same tree " f"structure; primals have tree structure {tree_def} whereas tangents have " f"tree structure {tree_def_2}.") - for p, t in safe_zip(ps_flat, ts_flat): + for p, t in zip(ps_flat, ts_flat): if 
core.primal_dtype_to_tangent_dtype(_dtype(p)) != _dtype(t): raise TypeError("primal and tangent arguments to jax.jvp do not match; " "dtypes must be equal, or in case of int/bool primal dtype " @@ -2111,8 +2059,7 @@ def fun(*tangents): return apply_flat_fun_nokwargs(fun, io_tree, py_args) -def _vjp_pullback_wrapper(name, cotangent_dtypes, cotangent_shapes, io_tree, - fun, *py_args_): +def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_): if len(py_args_) != 1: msg = (f"The function returned by `jax.vjp` applied to {name} was called " f"with {len(py_args_)} arguments, but functions returned by " @@ -2138,23 +2085,27 @@ def _vjp_pullback_wrapper(name, cotangent_dtypes, cotangent_shapes, io_tree, in_tree_expected, out_tree = io_tree args, in_tree = tree_flatten(py_args) if in_tree != in_tree_expected: - raise TypeError(f"Tree structure of cotangent input {in_tree}, does not match structure of " - f"primal output {in_tree_expected}.") - for arg, ct_dtype, ct_shape in safe_zip(args, cotangent_dtypes, cotangent_shapes): - expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(_dtype(arg)) - if expected_tangent_dtype != ct_dtype: - raise TypeError( - f"Type of cotangent input to vjp pullback function ({ct_dtype}) is not " - f"the expected tangent type ({expected_tangent_dtype}) of corresponding primal output " - f"with dtype {_dtype(arg)}.") - if np.shape(arg) != ct_shape: + raise ValueError(f"unexpected tree structure of argument to vjp function: " + f"got {in_tree}, but expected to match {in_tree_expected}") + for arg, aval in zip(args, out_primal_avals): + ct_aval = shaped_abstractify(arg) + ct_aval_expected = aval.at_least_vspace() + if (not core.typecompat(ct_aval, ct_aval_expected) and + not _temporary_dtype_exception(ct_aval, ct_aval_expected)): raise ValueError( - f"Shape of cotangent input to vjp pullback function {np.shape(arg)} " - "must be the same as the shape of corresponding primal input " - f"{ct_shape}.") + "unexpected JAX type (e.g. shape/dtype) for argument to vjp function: " + f"got {ct_aval.str_short()}, but expected {ct_aval_expected.str_short()} " + f"because the corresponding output of the function {name} had JAX type " + f"{aval.str_short()}") ans = fun(*args) return tree_unflatten(out_tree, ans) +# TODO(mattjj): see similar function in custom_derivatives.py +def _temporary_dtype_exception(a, a_) -> bool: + if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray): + return a.shape == a_.shape and a_.dtype == float0 + return False + @overload def vjp(fun: Callable[..., T], *primals: Any, @@ -2167,7 +2118,7 @@ def vjp(fun: Callable[..., tuple[T, U]], *primals: Any, has_aux: Literal[True], reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable, U]: ... -def vjp( # type: ignore +def vjp( fun: Callable, *primals, has_aux: bool = False, reduce_axes=() ) -> tuple[Any, Callable] | tuple[Any, Callable, Any]: """Compute a (reverse-mode) vector-Jacobian product of ``fun``. 
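For reference, a minimal usage sketch (not from the patch) of the cotangent validation that `_vjp_pullback_wrapper` now performs against the saved output avals: cotangents passed to the vjp function must match the JAX type (shape/dtype) of the corresponding primal output, and mismatches raise `ValueError`.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)

x = jnp.arange(4.0)
y, f_vjp = jax.vjp(f, x)
(dx,) = f_vjp(jnp.ones_like(y))     # well-typed cotangent

try:
    f_vjp(jnp.ones((2,)))           # wrong shape for the output of f
except ValueError as err:
    print("rejected:", type(err).__name__)
```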
@@ -2222,21 +2173,16 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): for arg in primals_flat: dispatch.check_arg(arg) if not has_aux: flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree) - out_primal, out_vjp = ad.vjp(flat_fun, primals_flat) + out_primals, vjp = ad.vjp(flat_fun, primals_flat) out_tree = out_tree() else: flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree) - out_primal, out_vjp, aux = ad.vjp( - flat_fun, primals_flat, has_aux=True) + out_primals, vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True) out_tree, aux_tree = out_aux_trees() - out_primal_py = tree_unflatten(out_tree, out_primal) - ct_dtypes = [core.primal_dtype_to_tangent_dtype(_dtype(x)) for x in out_primal] - ct_shapes = [np.shape(x) for x in out_primal] - # Ensure that vjp_py is a PyTree so that we can pass it from the forward to the - # backward pass in a custom VJP. + out_primal_avals = map(shaped_abstractify, out_primals) + out_primal_py = tree_unflatten(out_tree, out_primals) vjp_py = Partial(partial(_vjp_pullback_wrapper, fun.__name__, - ct_dtypes, ct_shapes, (out_tree, in_tree)), - out_vjp) + out_primal_avals, (out_tree, in_tree)), vjp) if not has_aux: return out_primal_py, vjp_py else: @@ -2340,7 +2286,7 @@ def make_jaxpr(fun: Callable, # type: ignore ... @overload -def make_jaxpr(fun: Callable, # type: ignore +def make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[True] = ..., @@ -2369,11 +2315,11 @@ def make_jaxpr(fun: Callable, specifies the axis name/size environment that would be set up by applications of :py:func:`jax.pmap`. return_shape: Optional boolean, defaults to ``False``. If ``True``, the - wrapped function returns a pair where the first element is the XLA - computation and the second element is a pytree with the same structure as - the output of ``fun`` and where the leaves are objects with ``shape``, - ``dtype``, and ``named_shape`` attributes representing the corresponding - types of the output leaves. + wrapped function returns a pair where the first element is the + ``ClosedJaxpr`` representation of ``fun`` and the second element is a + pytree with the same structure as the output of ``fun`` and where the + leaves are objects with ``shape``, ``dtype``, and ``named_shape`` + attributes representing the corresponding types of the output leaves. 
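A short usage sketch (assumed, not part of the patch) of the `return_shape` behavior described above: `make_jaxpr` returns the `ClosedJaxpr`, and with `return_shape=True` it also returns a pytree of shape/dtype descriptors for the outputs.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

closed_jaxpr = jax.make_jaxpr(f)(jnp.arange(3.0))
print(closed_jaxpr)                          # a ClosedJaxpr

closed_jaxpr, out_shape = jax.make_jaxpr(f, return_shape=True)(jnp.arange(3.0))
print(out_shape.shape, out_shape.dtype)      # shape/dtype of the single output
```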
Returns: A wrapped version of ``fun`` that when applied to example arguments returns @@ -2411,43 +2357,34 @@ def make_jaxpr(fun: Callable, g:f32[] = mul f c in (g,) } """ - check_callable(fun) - static_argnums = _ensure_index_tuple(static_argnums) - - def abstractify(args, kwargs): - flat_args, in_tree = tree_flatten((args, kwargs)) - if abstracted_axes is None: - return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args) - else: - axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs) - in_type = pe.infer_lambda_input_type(axes_specs, flat_args) - in_avals, keep_inputs = unzip2(in_type) - return in_avals, in_tree, keep_inputs + try: + hash(fun) + weakref.ref(fun) + except TypeError: + fun = partial(fun) @wraps(fun) @api_boundary def make_jaxpr_f(*args, **kwargs): - f = lu.wrap_init(fun) - if static_argnums: - dyn_argnums = [i for i in range(len(args)) if i not in static_argnums] - f, args = argnums_partial(f, dyn_argnums, args) - in_avals, in_tree, keep_inputs = abstractify(args, kwargs) - in_type = tuple(zip(in_avals, keep_inputs)) - f, out_tree = flatten_fun(f, in_tree) - f = lu.annotate(f, in_type) - debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr') with ExitStack() as stack: for axis_name, size in axis_env or []: stack.enter_context(core.extend_axis_env(axis_name, size, None)) - jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2( - f, debug_info=debug_info) - closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) + traced = jit(fun, static_argnums=static_argnums, + abstracted_axes=abstracted_axes).trace(*args, **kwargs) + # `jit` converts tracers in consts to args but that breaks the semantics of + # `make_jaxpr`. Hence convert the tracers in args back to consts in jaxpr. + if traced._num_consts: + consts, _ = split_list(traced._args_flat, [traced._num_consts]) + jaxpr_ = pe.convert_invars_to_constvars(traced.jaxpr.jaxpr, + traced._num_consts) + jaxpr = core.ClosedJaxpr(jaxpr_, consts) + else: + jaxpr = traced.jaxpr if return_shape: - out_avals, _ = unzip2(out_type) - out_shapes_flat = [ - ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals] - return closed_jaxpr, tree_unflatten(out_tree(), out_shapes_flat) - return closed_jaxpr + out = [ShapeDtypeStruct(o.shape, o.dtype, getattr(o, 'named_shape', None)) + for o in jaxpr.out_avals] + return jaxpr, tree_unflatten(tree_structure(traced.out_info), out) + return jaxpr make_jaxpr_f.__module__ = "jax" if hasattr(fun, "__qualname__"): @@ -2458,7 +2395,8 @@ def make_jaxpr_f(*args, **kwargs): def _infer_src_sharding(src, x) -> Sharding | None: if src is not None: - return src + # TODO(slebedev): This looks like an error and needs investigation. + return src # pytype: disable=bad-return-type if isinstance(x, array.ArrayImpl): return x.sharding elif isinstance(x, core.Tracer): @@ -2468,12 +2406,20 @@ def _infer_src_sharding(src, x) -> Sharding | None: return None -# TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that -# to check if shardings are compatible with the input. -def _check_sharding(x, s): +# TODO(yashkatariya): Generalize check_compatible_aval (maybe renamed) and use +# that to check if shardings are compatible with the input. +@lru_cache(maxsize=2048) +def _check_sharding(aval, s): + if (s is not None and + not isinstance(s, (xc.Device, Sharding, Layout, TransferToMemoryKind))): + raise ValueError( + "`jax.device_put` only accepts `None`, `jax.sharding.Sharding`," + " `jax.Device`, `Layout` or a pytree of these values. 
Received" + f" invalid value: {s}") if isinstance(s, Sharding): - aval = shaped_abstractify(x) - if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding): + if isinstance(aval, core.AbstractToken): + aval = core.token_shaped_array + if not isinstance(s, PmapSharding): pjit.pjit_check_aval_sharding( (s,), (aval,), None, "device_put args", allow_uneven_sharding=False) s.shard_shape(aval.shape) # should raise an Error if incompatible @@ -2481,16 +2427,16 @@ def _check_sharding(x, s): def device_put( x, - device: None | xc.Device | Sharding | Any | TransferToMemoryKind = None, - *, src: None | xc.Device | Sharding | Any | TransferToMemoryKind = None): + device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None, + *, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None): """Transfers ``x`` to ``device``. Args: x: An array, scalar, or (nested) standard Python container thereof. - device: The (optional) :py:class:`Device`, `Sharding`, or a (nested) - `Sharding` in standard Python container (must be a tree prefix of ``x``), - representing the device(s) to which ``x`` should be transferred. If - given, then the result is committed to the device(s). + device: The (optional) :py:class:`Device`, :py:class:`Sharding`, or a + (nested) :py:class:`Sharding` in standard Python container (must be a tree + prefix of ``x``), representing the device(s) to which ``x`` should be + transferred. If given, then the result is committed to the device(s). Returns: A copy of ``x`` that resides on ``device``. @@ -2506,25 +2452,25 @@ def device_put( blocking the calling Python thread until any transfers are completed. """ with config.explicit_device_put_scope(): - if ((device is None or - isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and - (src is None or - isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))): - for leaf in tree_leaves(x): - _check_sharding(leaf, s=device) - return tree_map( - lambda y: dispatch.device_put_p.bind( - y, device=device, src=_infer_src_sharding(src, y)), x) - x_flat, treedef = tree_flatten(x) - device_flat = flatten_axes("device_put device", treedef, device) - src_flat = flatten_axes("device_put source", treedef, src) - for x_leaf, device_leaf in zip(x_flat, device_flat): - _check_sharding(x_leaf, device_leaf) - out_flat = [ - dispatch.device_put_p.bind(xf, device=d, src=_infer_src_sharding(s, xf)) - for xf, d, s in zip(x_flat, device_flat, src_flat) - ] + if (device is None or + isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))): + device_flat = [device] * len(x_flat) + else: + device_flat = flatten_axes("device_put device", treedef, device) + + if (src is None or + isinstance(src, (xc.Device, Sharding, TransferToMemoryKind))): + src_flat = [_infer_src_sharding(src, xf) for xf in x_flat] + else: + src_flat = flatten_axes("device_put source", treedef, src) + src_flat = list(map(_infer_src_sharding, src_flat, x_flat)) + + for xf, d in zip(x_flat, device_flat): + _check_sharding(shaped_abstractify(xf), d) + out_flat = dispatch.device_put_p.bind( + *x_flat, devices=device_flat, srcs=src_flat + ) return tree_unflatten(treedef, out_flat) @@ -2578,7 +2524,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): # # TODO(jakevdp): provide a default for devices that considers both local # devices and pods if not isinstance(shards, Sequence): - raise ValueError("device_put_sharded `shards` input must be a sequence; " + raise TypeError("device_put_sharded `shards` 
input must be a sequence; " f"got {type(shards)}") if len(shards) != len(devices): raise ValueError(f"len(shards) = {len(shards)} must equal " @@ -2729,22 +2675,34 @@ class ShapeDtypeStruct: named_shape: (optional) a dictionary representing a named shape sharding: (optional) a :class:`jax.Sharding` object """ - __slots__ = ["shape", "dtype", "named_shape", "sharding"] + __slots__ = ["shape", "dtype", "named_shape", "sharding", "_dll"] + def __init__(self, shape, dtype, named_shape=None, sharding=None): self.shape = tuple(shape) if dtype is None: raise ValueError("ShapeDtypeStruct: dtype must be specified.") self.dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype) - if sharding is not None and not isinstance(sharding, Sharding): + if sharding is not None and not isinstance(sharding, (Sharding, Layout)): raise ValueError( - "sharding should be an instance of `jax.sharding.Sharding`. " - f"Got {sharding} of type {type(sharding)}.") - self.sharding = sharding + "sharding should be an instance of `jax.sharding.Sharding` or" + f" `jax.experimental.layout.Layout`. Got {sharding} of type" + f" {type(sharding)}.") + if (isinstance(sharding, Layout) and + isinstance(sharding.device_local_layout, AutoLayout)): + raise TypeError( + "`DeviceLocalLayout.AUTO` cannot be used in place of a device-local" + f" layout in a `ShapeDtypeStruct`. Got {sharding}") + self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding + self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None self.named_shape = {} if named_shape is None else dict(named_shape) size = property(lambda self: math.prod(self.shape)) ndim = property(lambda self: len(self.shape)) + @property + def layout(self): + return Layout(self._dll, self.sharding) + def __len__(self): try: return self.shape[0] @@ -2754,8 +2712,9 @@ def __len__(self): def __repr__(self): ns = f", named_shape={self.named_shape}" if self.named_shape else "" sh = f", sharding={self.sharding}" if self.sharding is not None else "" + l = f", layout={self.layout}" if self._dll is not None else "" return (f"{type(self).__name__}(shape={self.shape}, " - f"dtype={self.dtype.name}{ns}{sh})") + f"dtype={self.dtype.name}{ns}{sh}{l})") __str__ = __repr__ @@ -2763,19 +2722,21 @@ def __eq__(self, other): if not isinstance(other, ShapeDtypeStruct): return False else: - return ((other.shape, other.dtype, other.named_shape, other.sharding) == - (self.shape, self.dtype, self.named_shape, self.sharding)) + return ((other.shape, other.dtype, other.named_shape, other.sharding, other.layout) == + (self.shape, self.dtype, self.named_shape, self.sharding, self.layout)) def __hash__(self): # TODO(frostig): avoid the conversion from dict by addressing # https://github.com/google/jax/issues/8182 named = frozenset(self.named_shape.items()) - return hash((self.shape, self.dtype, named, self.sharding)) + return hash((self.shape, self.dtype, named, self.sharding, self.layout)) + core.pytype_aval_mappings[ShapeDtypeStruct] = ( lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), weak_type=False, named_shape=x.named_shape)) + @api_boundary def eval_shape(fun: Callable, *args, **kwargs): """Compute the shape/dtype of ``fun`` without any FLOPs. @@ -2930,7 +2891,7 @@ def named_scope( ... 
return jax.nn.relu(logits) """ if not isinstance(name, str): - raise ValueError("named_scope name argument must be a string.") + raise TypeError("named_scope name argument must be a string.") with source_info_util.extend_name_stack(name): yield @@ -2954,7 +2915,26 @@ def try_to_block(x): return x.block_until_ready() except AttributeError: return x - return tree_map(try_to_block, x) + + arrays = [] + for leaf in tree_leaves(x): + if isinstance(leaf, array.ArrayImpl): + arrays.append(leaf) + else: + try_to_block(leaf) + + if not arrays: + # `arrays` will be empty if tree_leaves(x) is empty or all leaves are not + # jax.Array. + pass + elif len(arrays) == 1: + # Fast path for single array. + try_to_block(arrays[0]) + else: + # Optimized for multiple arrays. + xc.batched_block_until_ready(arrays) + + return x def clear_backends(): @@ -2965,6 +2945,7 @@ def clear_backends(): xb.local_devices.cache_clear() xb.process_count.cache_clear() dispatch.xla_primitive_callable.cache_clear() + pjit._infer_params_cached.cache_clear() pjit._pjit_lower_cached.cache_clear() pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error pjit._cpp_pjit_cache.clear() @@ -2978,14 +2959,19 @@ def live_arrays(platform=None): return xb.get_backend(platform).live_arrays() def clear_caches(): - """Clear all compilation and staging caches.""" - # Clear all lu.cache and util.weakref_lru_cache instances (used for staging - # and Python-dispatch compiled executable caches). - lu.clear_all_caches() + """Clear all compilation and staging caches. + + This doesn't clear the persistent cache; to disable it (e.g. for benchmarks), + set the jax_enable_compilation_cache config option to False. + """ + # Clear all lu.cache, util.cache and util.weakref_lru_cache instances + # (used for staging and Python-dispatch compiled executable caches). 
+ util.clear_all_caches() util.clear_all_weakref_lru_caches() # Clear all C++ compiled executable caches for pjit pjit._cpp_pjit_cache.clear() + pjit._infer_params_cached.cache_clear() xc._xla.PjitFunctionCache.clear_all() # Clear all C++ compiled executable caches for pmap diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 1d4225840bce..16a29e699bbc 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -14,12 +14,11 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence import inspect import operator -from functools import partial -from typing import Any, Callable, Type -import warnings +from functools import partial, lru_cache +from typing import Any import numpy as np @@ -28,14 +27,14 @@ from jax._src.abstract_arrays import numpy_scalar_types from jax._src.core import ShapedArray from jax._src.tree_util import ( - PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure, + PyTreeDef, tree_flatten, tree_unflatten, tree_map, treedef_children, generate_key_paths, keystr, broadcast_prefix, prefix_errors) from jax._src.tree_util import _replace_nones from jax._src import linear_util as lu from jax._src.linear_util import TracingDebugInfo from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction, - Unhashable) + Unhashable, safe_zip) from jax._src import traceback_util traceback_util.register_exclusion(__file__) @@ -151,7 +150,7 @@ def __eq__(self, other): inspect.Parameter.POSITIONAL_OR_KEYWORD ) -def validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None: +def _validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None: """ Validate that the argnums are sensible for a given function. @@ -168,24 +167,22 @@ def validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_n return if argnums and (-min(argnums) > n_pos_args or max(argnums) >= n_pos_args): - # raise ValueError(f"Jitted function has {argnums_name}={argnums}, " - # f"but only accepts {n_pos_args} positional arguments.") - # TODO: 2022-08-20 or later: replace with error - warnings.warn(f"Jitted function has {argnums_name}={argnums}, " - f"but only accepts {n_pos_args} positional arguments. " - "This warning will be replaced by an error after 2022-08-20 " - "at the earliest.", SyntaxWarning) + raise ValueError(f"Jitted function has {argnums_name}={argnums}, " + f"but only accepts {n_pos_args} positional arguments.") _INVALID_KEYWORD_ARGUMENTS = ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.VAR_POSITIONAL ) + _KEYWORD_ARGUMENTS = ( inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY, ) -def validate_argnames(sig: inspect.Signature, argnames: tuple[str, ...], argnames_name: str) -> None: +def _validate_argnames( + sig: inspect.Signature, argnames: tuple[str, ...], argnames_name: str +) -> None: """ Validate that the argnames are sensible for a given function. @@ -206,33 +203,19 @@ def validate_argnames(sig: inspect.Signature, argnames: tuple[str, ...], argname elif param.kind in _INVALID_KEYWORD_ARGUMENTS: invalid_kwargs.add(param_name) - # Check whether any kwargs are invalid due to position only - invalid_argnames = invalid_kwargs & set(argnames) - if invalid_argnames: - # raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} " - # f"in {argnames_name}. 
These are positional-only") - # TODO: 2022-08-20 or later: replace with error - warnings.warn(f"Jitted function has invalid argnames {invalid_argnames} " - f"in {argnames_name}. These are positional-only. " - "This warning will be replaced by an error after 2022-08-20 " - "at the earliest.", SyntaxWarning) + if invalid_argnames := (invalid_kwargs & set(argnames)): + raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} " + f"in {argnames_name}. These are positional-only") # Takes any kwargs if var_kwargs: return # Check that all argnames exist on function - invalid_argnames = set(argnames) - valid_kwargs - if invalid_argnames: - # TODO: 2022-08-20 or later: replace with error - # raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} " - # f"in {argnames_name}. Function does not take these args.") - warnings.warn(f"Jitted function has invalid argnames {invalid_argnames} " - f"in {argnames_name}. Function does not take these args." - "This warning will be replaced by an error after 2022-08-20 " - "at the earliest.", SyntaxWarning) - + if invalid_argnames := (set(argnames) - valid_kwargs): + raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} " + f"in {argnames_name}. Function does not take these args.") def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True): @@ -290,7 +273,7 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], f"to unexpected cache-misses. Static argument (index {i}) of type " f"{type(static_arg)} for function {f.__name__} is non-hashable.") else: - fixed_args.append(_HashableWithStrictTypeEquality(static_arg)) # type: ignore + fixed_args.append(_HashableWithStrictTypeEquality(static_arg)) return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args @@ -324,7 +307,7 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], f"to unexpected cache-misses. Static argument (name {k}) of type " f"{type(arg)} for function {f.__name__} is non-hashable.") else: - fixed_kwargs[k] = Hashable(arg) # type: ignore + fixed_kwargs[k] = Hashable(arg) return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs @@ -335,7 +318,9 @@ def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): yield ans -def donation_vector(donate_argnums, donate_argnames, args, kwargs) -> tuple[bool, ...]: +@lru_cache(maxsize=4096) +def donation_vector(donate_argnums, donate_argnames, in_tree, + kws: bool = True) -> tuple[bool, ...]: """Returns a tuple with a boolean value for each leaf in args and kwargs. What if a user specifies donate_argnums but calls the function with kwargs @@ -349,12 +334,17 @@ def donation_vector(donate_argnums, donate_argnames, args, kwargs) -> tuple[bool kwargs specified are donated. 
""" res: list[bool] = [] - for i, arg in enumerate(args): + if kws: + args_tree, kwargs_tree = treedef_children(in_tree) + else: + args_tree, kwargs_tree = in_tree, None + for i, arg in enumerate(args_tree.children()): donate = bool(i in donate_argnums) - res.extend((donate,) * tree_structure(arg).num_leaves) - for key, val in kwargs.items(): - donate = key in donate_argnames - res.extend((donate,) * tree_structure(val).num_leaves) + res.extend((donate,) * arg.num_leaves) + if kwargs_tree is not None: + for key, val in safe_zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore + donate = key in donate_argnames + res.extend((donate,) * val.num_leaves) return tuple(res) def rebase_donate_argnums(donate_argnums, static_argnums) -> tuple[int, ...]: @@ -506,11 +496,23 @@ def infer_argnums_and_argnames( def resolve_argnums( - fun, donate_argnums, donate_argnames, static_argnums, static_argnames + fun: Callable, + signature: inspect.Signature | None, + donate_argnums: int | Sequence[int] | None, + donate_argnames: str | Iterable[str] | None, + static_argnums: int | Sequence[int] | None, + static_argnames: str | Iterable[str] | None, ) -> tuple[tuple[int, ...], tuple[str, ...], tuple[int, ...], tuple[str, ...]]: - try: - sig = inspect.signature(fun) - except ValueError as e: + """Validates and completes the argnum/argname specification for a jit. + + * fills in any missing pieces (e.g., names given numbers, or vice versa), + * validates the argument names/numbers against the function signature, + * validates that donated and static arguments don't intersect. + * rebases the donated arguments so they index into the dynamic arguments, + (after static arguments have been removed), in the order that parameters + are passed into the compiled function. + """ + if signature is None: # Some built-in functions don't support signature. # See: https://github.com/python/cpython/issues/73485 # In this case no validation is done @@ -522,7 +524,7 @@ def resolve_argnums( donate_argnums) if donate_argnames is not None: raise ValueError(f"Getting the signature of function {fun} failed. " - "Pass donate_argnums instead of donate_argnames.") from e + "Pass donate_argnums instead of donate_argnames.") assert donate_argnames is None donate_argnames = () else: @@ -530,23 +532,23 @@ def resolve_argnums( # If nums is None and names is not None, then nums are inferred from the # names and vice-versa. 
static_argnums, static_argnames = infer_argnums_and_argnames( - sig, static_argnums, static_argnames) + signature, static_argnums, static_argnames) donate_argnums, donate_argnames = infer_argnums_and_argnames( - sig, donate_argnums, donate_argnames) + signature, donate_argnums, donate_argnames) # Validation - validate_argnums(sig, static_argnums, "static_argnums") - validate_argnames(sig, static_argnames, "static_argnames") - validate_argnums(sig, donate_argnums, "donate_argnums") - validate_argnames(sig, donate_argnames, "donate_argnames") + _validate_argnums(signature, static_argnums, "static_argnums") + _validate_argnames(signature, static_argnames, "static_argnames") + _validate_argnums(signature, donate_argnums, "donate_argnums") + _validate_argnames(signature, donate_argnames, "donate_argnames") # Compensate for static argnums absorbing args - assert_no_intersection(static_argnames, donate_argnames) + _assert_no_intersection(static_argnames, donate_argnames) donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums) return donate_argnums, donate_argnames, static_argnums, static_argnames -def assert_no_intersection(static_argnames, donate_argnames): +def _assert_no_intersection(static_argnames, donate_argnames): out = set(static_argnames).intersection(set(donate_argnames)) if out: raise ValueError( @@ -580,10 +582,9 @@ def _shaped_abstractify_slow(x): # TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior def shaped_abstractify(x): - try: - return _shaped_abstractify_handlers[type(x)](x) - except KeyError: - return _shaped_abstractify_slow(x) + handler = _shaped_abstractify_handlers.get(type(x), None) + return handler(x) if handler is not None else _shaped_abstractify_slow(x) + _shaped_abstractify_handlers: dict[Any, Callable[[Any], core.ShapedArray]] = {} @@ -606,6 +607,13 @@ def _np_scalar_abstractify(x: np.generic) -> ShapedArray: _shaped_abstractify_handlers.update((t, _np_scalar_abstractify) for t in numpy_scalar_types) +def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray: + typ = type(x) + dtype = dtypes._scalar_type_to_dtype(typ, x) + return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types) +_shaped_abstractify_handlers.update((t, _python_scalar_abstractify) + for t in dtypes.python_scalar_dtypes) + # This decorator exists to make it easier to monkey-patch APIs in JAX. # By default it does nothing, but it can be monkey-patched to do other things. def api_hook(fun, tag: str): @@ -664,7 +672,7 @@ def result_paths(*args, **kwargs): yield ans, [keystr(path) for path, _ in generate_key_paths(ans)] def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None, - result_paths: tuple[str | None, ...] | None = None, + result_paths: tuple[str, ...] 
| None = None, ) -> core.Jaxpr: """Add debug info to jaxpr, given trace-time debug info and result paths.""" if trace_debug is None: @@ -705,6 +713,6 @@ def __hash__(self): def __eq__(self, other): return self.val is other.val -def register_class_with_attrs(t: Type) -> None: +def register_class_with_attrs(t: type) -> None: _class_with_attrs.add(t) -_class_with_attrs: set[Type] = set() +_class_with_attrs: set[type] = set() diff --git a/jax/_src/array.py b/jax/_src/array.py index ff13ab7acab9..26e3e2e6efca 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -15,14 +15,12 @@ from __future__ import annotations from collections import defaultdict +from collections.abc import Callable, Sequence import enum +import functools import math import operator as op -import numpy as np -import functools -from typing import Any, Callable, cast, TYPE_CHECKING -import warnings -from collections.abc import Sequence +from typing import Any, TYPE_CHECKING, cast from jax._src import abstract_arrays from jax._src import api @@ -30,25 +28,26 @@ from jax._src import basearray from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import errors from jax._src import profiler from jax._src import tree_util from jax._src import xla_bridge -from jax._src.lib import xla_client as xc from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.interpreters import xla +from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout +from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension as xe from jax._src.sharding import Sharding from jax._src.sharding_impls import ( - SingleDeviceSharding, XLACompatibleSharding, PmapSharding, - device_replica_id_map, hashed_index) -from jax._src.typing import ArrayLike -from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method + PmapSharding, SingleDeviceSharding, + device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable +from jax._src.typing import ArrayLike, DLDeviceType +from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache +import numpy as np -deprecations.register(__name__, "device-method") Shape = tuple[int, ...] Device = xc.Device @@ -56,8 +55,9 @@ PRNGKeyArray = Any # TODO(jakevdp): fix cycles and import this. 
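A small illustration (not from the patch) of the invariant the per-shard device lookup in this file relies on: every addressable shard of a `jax.Array` is backed by a single-device array, so its device can be read off that shard's own sharding.

```python
import jax
import jax.numpy as jnp

arr = jax.device_put(jnp.arange(8.0), jax.devices()[0])
for shard in arr.addressable_shards:
    # Each shard's backing buffer is a single-device array.
    print(shard.device, shard.data.sharding.device_set)
```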
def _get_device(a: ArrayImpl) -> Device: - assert len(a.devices()) == 1 - return next(iter(a.devices())) + devices = a.sharding._internal_device_list # pytype: disable=attribute-error + assert len(devices) == 1 + return devices[0] class Shard: @@ -120,30 +120,20 @@ def _reconstruct_array(fun, args, arr_state, aval_state): return jnp_value -@functools.lru_cache(maxsize=4096) +@cache(max_size=4096, trace_context_in_key=False) def _cached_index_calc(s, shape): map_ = s.addressable_devices_indices_map(shape) seen_h_indices = set() - m = {} - for d, index in map_.items(): + l = [] + for array_index, index in enumerate(map_.values()): h_index = hashed_index(index) if h_index not in seen_h_indices: seen_h_indices.add(h_index) - m[d] = index - return m - + l.append((array_index, index)) + return l -def _create_copy_plan(arrays, s: Sharding, shape: Shape): - di_map = _cached_index_calc(s, shape) - copy_plan = [] - for a in arrays: - ind = di_map.get(_get_device(a), None) - if ind is not None: - copy_plan.append((ind, a)) - return copy_plan - -@functools.lru_cache(maxsize=4096) +@cache(max_size=4096, trace_context_in_key=False) def _process_has_full_value_in_mcjax(s, shape): # Return False for single host as a fast path. if xla_bridge.process_count() == 1: @@ -156,6 +146,28 @@ def _process_has_full_value_in_mcjax(s, shape): return num_unique_indices == num_addressable_unique_indices +def _validate_shape_and_dtype_for_per_device_arrays( + arrays: Sequence[ArrayImpl | np.ndarray], + sharding: Sharding, + aval: core.ShapedArray, + expected_shape: Shape, +): + """Validates that per-device arrays are valid and consistent.""" + expected_dtype = aval.dtype + for db in arrays: + if db.dtype != expected_dtype: + raise ValueError( + "Input buffers to `Array` must have matching dtypes. " + f"Got {db.dtype}, expected {expected_dtype} for buffer: {db}" + ) + if db.shape != expected_shape: + raise ValueError( + f"Expected shard shape {expected_shape} doesn't match the single " + f"device array shape {db.shape}. Shape of Array is " + f"{aval.str_short()} with sharding {sharding}" + ) + + class ArrayImpl(basearray.Array): # TODO(yashkatariya): Add __slots__ here. @@ -186,12 +198,6 @@ def __init__(self, aval: core.ShapedArray, sharding: Sharding, self._check_and_rearrange() def _check_and_rearrange(self): - for db in self._arrays: - if db.dtype != self.dtype: - raise ValueError( - "Input buffers to `Array` must have matching dtypes. " - f"Got {db.dtype}, expected {self.dtype} for buffer: {db}") - device_id_to_buffer = {_get_device(db).id: db for db in self._arrays} addressable_dev = self.sharding.addressable_devices @@ -219,18 +225,15 @@ def _check_and_rearrange(self): "that are not present in the sharding.") raise ValueError(err_msg) - ss = self.sharding.shard_shape(self.shape) - for db in self._arrays: - if db.shape != ss: - raise ValueError( - f"Expected shard shape {ss} doesn't match the single device array " - f"shape {db.shape}. Shape of Array is " - f"{self.aval.str_short()} with sharding {self.sharding}") - + _validate_shape_and_dtype_for_per_device_arrays( + self._arrays, + sharding=self.sharding, + aval=self.aval, + expected_shape=self.sharding.shard_shape(self.shape), + ) # Rearrange arrays based on the device assignment. 
- if isinstance(self.sharding, XLACompatibleSharding): - addressable_da = self.sharding._addressable_device_assignment - self._arrays = [device_id_to_buffer[device.id] for device in addressable_da] + addressable_da = self.sharding._addressable_device_assignment + self._arrays = [device_id_to_buffer[device.id] for device in addressable_da] @property def shape(self) -> Shape: @@ -283,11 +286,11 @@ def __complex__(self): def __hex__(self): core.check_integer_conversion(self) - return hex(self._value) # type: ignore + return hex(self._value) def __oct__(self): core.check_integer_conversion(self) - return oct(self._value) # type: ignore + return oct(self._value) def __index__(self): core.check_integer_conversion(self) @@ -351,9 +354,9 @@ def __iter__(self): else: assert self.is_fully_replicated or self.is_fully_addressable if dispatch.is_single_device_sharding(self.sharding) or self.is_fully_replicated: - return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) # type: ignore + return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) elif isinstance(self.sharding, PmapSharding): - return (self[i] for i in range(self.shape[0])) # type: ignore + return (self[i] for i in range(self.shape[0])) else: # TODO(yashkatariya): Don't bounce to host and use `_chunk_iter` path # here after uneven partitioning support is added. @@ -404,15 +407,29 @@ def __array__(self, dtype=None, context=None, copy=None): kwds = {} if copy is None else {'copy': copy} return np.asarray(self._value, dtype=dtype, **kwds) - def __dlpack__(self, *, stream: int | Any | None = None): - if len(self._arrays) != 1: - raise ValueError("__dlpack__ only supported for unsharded arrays.") + def __dlpack__(self, *, stream: int | Any | None = None, + max_version: tuple[int, int] | None = None, + dl_device: tuple[DLDeviceType, int] | None = None, + copy: bool | None = None): from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top - return to_dlpack(self, stream=stream) + + device_set = self.sharding.device_set + if len(device_set) > 1: + raise BufferError( + "to_dlpack can only pack a dlpack tensor from an array on a singular " + f"device, but an array with a Sharding over {len(device_set)} devices " + "was provided." 
+ ) + device, = device_set + return to_dlpack(self, stream=stream, + max_version=max_version, + src_device=device, + dl_device=dl_device, + copy=copy) def __dlpack_device__(self) -> tuple[enum.Enum, int]: if len(self._arrays) != 1: - raise ValueError("__dlpack__ only supported for unsharded arrays.") + raise BufferError("__dlpack__ only supported for unsharded arrays.") from jax._src.dlpack import DLDeviceType # pylint: disable=g-import-not-at-top @@ -426,23 +443,23 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]: elif "rocm" in platform_version: dl_device_type = DLDeviceType.kDLROCM else: - raise ValueError("Unknown GPU platform for __dlpack__: " + raise BufferError("Unknown GPU platform for __dlpack__: " f"{platform_version}") local_hardware_id = _get_device(self).local_hardware_id if local_hardware_id is None: - raise ValueError("Couldn't get local_hardware_id for __dlpack__") + raise BufferError("Couldn't get local_hardware_id for __dlpack__") return dl_device_type, local_hardware_id else: - raise ValueError( + raise BufferError( "__dlpack__ device only supported for CPU and GPU, got platform: " f"{self.platform()}" ) def __reduce__(self): - fun, args, arr_state = self._value.__reduce__() # type: ignore + fun, args, arr_state = self._value.__reduce__() aval_state = {'weak_type': self.aval.weak_type, 'named_shape': self.aval.named_shape} return (_reconstruct_array, (fun, args, arr_state, aval_state)) @@ -466,52 +483,22 @@ def __cuda_array_interface__(self): def on_device_size_in_bytes(self): """Returns the total global on-device size of the array in bytes.""" arr = self._arrays[0] - per_shard_size = arr.on_device_size_in_bytes() # type: ignore + per_shard_size = arr.on_device_size_in_bytes() return per_shard_size * len(self.sharding.device_set) - # TODO(yashkatariya): Remove this method when everyone is using devices(). - def device(self) -> Device: - if deprecations.is_accelerated(__name__, "device-method"): - raise NotImplementedError("arr.device() is deprecated. Use arr.devices() instead.") - else: - warnings.warn("arr.device() is deprecated. Use arr.devices() instead.", - DeprecationWarning, stacklevel=2) - self._check_if_deleted() - device_set = self.sharding.device_set - if len(device_set) == 1: - single_device, = device_set - return single_device - raise ValueError('Length of devices is greater than 1. ' - 'Please use `.devices()`.') - def devices(self) -> set[Device]: self._check_if_deleted() return self.sharding.device_set - # TODO(https://github.com/google/jax/issues/12380): Remove this when DA is - # deleted. @property - def device_buffer(self) -> ArrayImpl: - # Added 2023 Dec 6 - warnings.warn( - "arr.device_buffer is deprecated. Use arr.addressable_data(0)", - DeprecationWarning, stacklevel=2) - self._check_if_deleted() - if len(self._arrays) == 1: - return self._arrays[0] - raise ValueError('Length of buffers is greater than 1. Please use ' - '`.device_buffers` instead.') + def device_buffer(self): + raise AttributeError( + "arr.device_buffer has been deprecated. Use arr.addressable_data(0)") - # TODO(https://github.com/google/jax/issues/12380): Remove this when SDA is - # deleted. @property - def device_buffers(self) -> Sequence[ArrayImpl]: - # Added 2023 Dec 6 - warnings.warn( - "arr.device_buffers is deprecated. Use [x.data for x in arr.addressable_shards]", - DeprecationWarning, stacklevel=2) - self._check_if_deleted() - return cast(Sequence[ArrayImpl], self._arrays) + def device_buffers(self): + raise AttributeError( + "arr.device_buffers has been deprecated. 
Use [x.data for x in arr.addressable_shards]") def addressable_data(self, index: int) -> ArrayImpl: self._check_if_deleted() @@ -527,6 +514,21 @@ def addressable_shards(self) -> Sequence[Shard]: out.append(Shard(_get_device(a), self.sharding, self.shape, a)) return out + @property + def layout(self): + # TODO(yashkatariya): Remove the deleted check from here. + if self.is_deleted(): + return Layout(None, self.sharding) + try: + return Layout(DeviceLocalLayout.from_pjrt_layout(self._pjrt_layout), + self.sharding) + except xe.XlaRuntimeError as e: + msg, *_ = e.args + if type(msg) is str and msg.startswith("UNIMPLEMENTED"): + return Layout(None, self.sharding) + else: + raise + @property def global_shards(self) -> Sequence[Shard]: """Returns list of all `Shard`s of the Array across all devices. @@ -593,9 +595,8 @@ def copy_to_host_async(self): if self.is_fully_replicated: self._copy_single_device_array_to_host_async() return - copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape) - for _, arr in copy_plan: - arr._copy_single_device_array_to_host_async() + for i, _ in _cached_index_calc(self.sharding, self.shape): + self._arrays[i]._copy_single_device_array_to_host_async() @property @functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)") @@ -604,7 +605,7 @@ def _value(self) -> np.ndarray: if self._npy_value is None: if self.is_fully_replicated: - self._npy_value = self._single_device_array_to_np_array() # type: ignore + self._npy_value = self._single_device_array_to_np_array() self._npy_value.flags.writeable = False return cast(np.ndarray, self._npy_value) @@ -612,19 +613,21 @@ def _value(self) -> np.ndarray: # is_fully_addressable. if (not self.is_fully_addressable and not _process_has_full_value_in_mcjax(self.sharding, self.shape)): - raise RuntimeError("Fetching value for `jax.Array` that spans " - "non-addressable devices is not possible. You can use " - "`jax.experimental.multihost_utils.process_allgather` " - "for this use case.") + raise RuntimeError( + "Fetching value for `jax.Array` that spans non-addressable" + " (non process local) devices is not possible. You can use" + " `jax.experimental.multihost_utils.process_allgather` to print the" + " global array or use `.addressable_shards` method of jax.Array to" + " inspect the addressable (process local) shards." + ) - copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape) - for _, arr in copy_plan: - arr._copy_single_device_array_to_host_async() + for i, _ in _cached_index_calc(self.sharding, self.shape): + self._arrays[i]._copy_single_device_array_to_host_async() npy_value = np.empty(self.shape, self.dtype) - for ind, arr in copy_plan: - npy_value[ind] = arr._single_device_array_to_np_array() - self._npy_value = npy_value # type: ignore + for i, ind in _cached_index_calc(self.sharding, self.shape): + npy_value[ind] = self._arrays[i]._single_device_array_to_np_array() + self._npy_value = npy_value self._npy_value.flags.writeable = False # https://docs.python.org/3/library/typing.html#typing.cast return cast(np.ndarray, self._npy_value) @@ -637,13 +640,24 @@ def _value(self) -> np.ndarray: ArrayImpl = use_cpp_class(xc.ArrayImpl)(ArrayImpl) -# explicitly set to be unhashable. Same as what device_array.py does. +def _get_shape_from_index(slc: Index, shape: Shape) -> Shape: + return tuple( + (s.stop or dim) - (s.start or 0) + for s, dim in safe_zip(slc, shape) + if isinstance(s, slice) # If element is int, this dimension is reduced + ) + + +# explicitly set to be unhashable. 
setattr(ArrayImpl, "__hash__", None) setattr(ArrayImpl, "__array_priority__", 100) +# TODO(yashkatariya): Remove None from callback input type. + def make_array_from_callback( - shape: Shape, sharding: Sharding, + shape: Shape, sharding: Sharding | Layout, data_callback: Callable[[Index | None], ArrayLike]) -> ArrayImpl: + # pyformat: disable """Returns a ``jax.Array`` via data fetched from ``data_callback``. ``data_callback`` is used to fetch the data for each addressable shard of the @@ -662,7 +676,7 @@ def make_array_from_callback( Returns: A ``jax.Array`` via data fetched from ``data_callback``. - Example: + Examples: >>> import math >>> from jax.sharding import Mesh @@ -681,47 +695,256 @@ def make_array_from_callback( >>> arr.addressable_data(0).shape (4, 2) """ - has_device_assignment = False + # pyformat: enable + dll = sharding.device_local_layout if isinstance(sharding, Layout) else None + if isinstance(dll, AutoLayout): + raise TypeError( + "`DeviceLocalLayout.AUTO` cannot be used in place of a device-local" + f" layout when calling `jax.make_array_from_callback`. Got {sharding}") + sharding = sharding.sharding if isinstance(sharding, Layout) else sharding # type: ignore + if not isinstance(sharding, Sharding): + raise TypeError( + f"sharding should be an instance of `jax.sharding`. Got {sharding} of" + f" type {type(sharding)}") + + def get_data(index: Index | None) -> ArrayImpl | np.ndarray: + # Perhaps cache on index here, then we can unify fully_replicated + # and non-fully_replicated cases below and become faster for + # partially replicated cases. + assert index is not None + r = data_callback(index) + if isinstance(r, core.Tracer): + raise errors.UnexpectedTracerError( + "jax.make_array_from_callback cannot be called within a traced" + " context." + ) + # Value can be python scalar, resolve it into something with dtype. + return xla.canonicalize_dtype(r) + if sharding.is_fully_replicated: - if isinstance(sharding, XLACompatibleSharding): - devices = list(sharding._addressable_device_assignment) - has_device_assignment = True - else: - devices = list(sharding.addressable_devices) - per_device_values = [data_callback((slice(None),) * len(shape))] * len(devices) + devices = list(sharding._internal_device_list.addressable_device_list) # type: ignore + # Only compute data once. + per_device_values = [get_data((slice(None),) * len(shape))] * len(devices) else: device_to_index_map = sharding.addressable_devices_indices_map(shape) devices = list(device_to_index_map.keys()) - per_device_values = [data_callback(device_to_index_map[device]) - for device in devices] - - if isinstance(per_device_values[0], core.Tracer): - raise errors.UnexpectedTracerError( - "jax.make_array_from_callback cannot be called within a traced context.") - - first_value = xla.canonicalize_dtype(per_device_values[0]) - aval = core.ShapedArray(shape, first_value.dtype, weak_type=False) - - # TODO(yashkatariya): Look into taking this path for non-fully replicated - # shardings too. - if (sharding.is_fully_replicated and has_device_assignment and - not dtypes.issubdtype(aval.dtype, dtypes.extended)): - # Do this check outside because `batched_device_put` won't do these checks - # like ArrayImpl. This is a fast path for fully replicated arrays with - # xla compatible sharding. - if shape != first_value.shape: - raise ValueError( - f"Expected shard shape {shape} doesn't match the single device " - f"array shape {first_value.shape}. 
Shape of Array is " - f"{aval.str_short()} with sharding {sharding}") - return pxla.batched_device_put( - aval, sharding, per_device_values, devices, committed=True) + per_device_values = [ + get_data(device_to_index_map[device]) for device in devices + ] + + first_value = per_device_values[0] + expected_dtype = first_value.dtype + expected_shape = sharding.shard_shape(shape) + aval = core.ShapedArray(shape, expected_dtype) + _validate_shape_and_dtype_for_per_device_arrays( + per_device_values, + expected_shape=expected_shape, + aval=aval, + sharding=sharding, + ) + if (isinstance(first_value, ArrayImpl) + and first_value._committed + and sharding.is_fully_replicated + and first_value.is_fully_replicated + and first_value.sharding._device_assignment == tuple(devices) + and (first_value.layout.device_local_layout == + pxla._maybe_get_default_layout(Layout(dll, sharding), None, sharding, aval))): + return first_value - arrays = api.device_put(per_device_values, devices) if dtypes.issubdtype(aval.dtype, dtypes.extended): - return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, - committed=True) - return ArrayImpl(aval, sharding, arrays, committed=True) + # TODO(yashkatariya): Can this also use batched_device_put? + arrays = api.device_put(per_device_values, devices) + return aval.dtype._rules.make_sharded_array( + aval, sharding, arrays, committed=True + ) + + if dll is not None: + devices = [Layout(dll, SingleDeviceSharding(d)) for d in devices] + # pxla.batched_device_put doesn't support Layout... Take the slow route + arrays = api.device_put(per_device_values, devices) + return ArrayImpl(aval, sharding, arrays, committed=True) + + if isinstance(first_value, ArrayImpl) and len(first_value.devices()) > 1: + # The output of the callback is already a sharded array, move it to + # to target device. + per_device_values = api.device_put(per_device_values, devices) + + return pxla.batched_device_put(aval, sharding, per_device_values, devices) + + +def make_array_from_process_local_data( + sharding: Sharding, + local_data: np.ndarray, + global_shape: Shape | None = None, +) -> ArrayImpl: + # pyformat: disable + """Creates distributed tensor using the data available in process. + + This function is a common special case of `make_array_from_callback`. It + assumes that the data is available in the process and takes care of the + index wrangling. + + The most common case is when the sharding is sharded across the batch + dimension and each host just loads its corresponding sub-batch. This function + supports more general cases as well, such as mixed multi-host and multi-axis + replication and sharding but you would need to compute the size and the + contents of process-local data correctly to satisfy the sharding constraints. + + In particular, if any two hosts are replicas, host_local_data should be + identical as well. + + The global_shape is optional. If not provided it will be be inferred from + the local_data and sharding, under the assumption that + each host represents only their own data for uniform sharding. If sharding + is non-uniform, (see note below) an exception will be raised. + + Setting global_shape explicitly allows for finer grain control and works with + non-uniform shardings. Each dimension of global_shape must either match + host_local_data, or match the inferred global shape of the sharding (in which + case it is equivalent to setting it to None, but is more explicit). 
+ + For example if dimension `i` is fully sharded then this size would be + `per_device_shape[i] * jax.local_device_count()`. Each device will be mapped + into local slice of `local_data` array. For example, if given process + addresses slices (8, 12) and (24, 28), then these slices will be mapped + into (0, 4) and (4, 8) of the `local_data`. + + For each dimension where global_shapes matches local_shape, each device + will lookup the slice in the local_data. For example if + global_shape == local_data.shape, the local data is assumed to be the + actual target array that will be sharded into device. + + If global_shape is the same as local_data.shape, then the data must + be the same across all hosts. + + Examples: + >>> from jax.sharding import PartitionSpec as P + >>> mesh_rows = 2 + >>> mesh_cols = jax.device_count() // 2 + ... + >>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y')) + + >>> sharding = jax.sharding.NamedSharding(mesh, P(('x', 'y'),)) + >>> rows_per_device = 2 + >>> feature_length = 32 + >>> per_device_shape = (rows_per_device, feature_length) + >>> per_host_shape = (rows_per_device * len(mesh.local_devices), feature_length) + >>> per_host_generator = lambda : np.arange(np.prod(per_host_shape)).reshape(per_host_shape) + >>> per_host_data = per_host_generator() # replace with your own per-host data pipeline that outputs numpy arrays + >>> global_shape = (rows_per_device * len(sharding.device_set), ) + per_device_shape[1:] + >>> output_global_array = jax.make_array_from_process_local_data(sharding, per_host_data, global_shape) + ... + >>> assert output_global_array.addressable_data(0).shape == per_device_shape + >>> assert output_global_array.shape == global_shape + + NB: While most shardings are uniform, It is possible to design am exotic + sharding mesh where each process's devices will be arranged in a non-grid + like pattern in some dimensions, or for indices to overlap non-trivially. + Such sharding is called "non-uniform" in those dimensions. In that case, + the global shape along those directions must match local shape as there is + no meaningful way to represent all needed + per-process data in non-overlapping fashion. For example for global_shape 4x4 + if sharding looks like this: + + 0123 + 2103 + 4675 + 4567 + + with 4 processes, containing devices (0,1), (2, 3), (4, 5), (6, 7) respectively. + Then the data for each host look like + + xx.. ..xx .... .... + .xx. x..x .... .... + .... .... x..x .xx. + .... .... xx.. ..xx + + the sharding is uniform on rows (each host requires either rows 1-2, or rows 3-4) + and non-uniform on columns (hosts require overlapping but not matching + set of columns). Thus local data must have the shape 2x4 or 4x4 + for all hosts, even though each host can potentially fit into 2x2 shape. + In this case user must provide global_shape explicitly and for + local_shape=(2, 4), potentially valid global shapes are (2, 4) and (4, 4). + + On the other hand for sharding: + + 0213 x.x. .x.x. .... .... + 0213 x.x. .x.x. .... .... + 4657 .... .... .x.x x.x. + 4657 .... .... .x.x x.x. + + for local_shape=(2, 2) this function can accept a choice of 2x2, 2x4, 4x2 + and 4x4 global shapes. Setting global_shape to None, is equivalent to + setting it to (4, 4) in this case. + + Args: + sharding: sharding of the global tensor. + host_local_data: data on the host to be placed on local devices. Each + dimension should either match global_shape, or match + num_addressable_indices(dim). 
+ global_shape: the target shape of the global tensor. If None, + will infer from host_local_data and sharding. + + Returns: + Tensor that will have sharding=sharding and of shape global_shape. + """ + # pyformat: enable + # TODO(sandler): consider supporting partially specified global_shape or + # making local_to_global_shape available in the api. + local_shape = local_data.shape + if global_shape is None: + global_shape = local_to_global_shape(sharding, local_shape) # type: ignore[assignment] + assert global_shape is not None + if None in global_shape: + raise ValueError( + "Unable to compute global_shape due to non-uniform sharding." + f" Specify global shape directly. Partially computed {global_shape=}." + ) + elif None in global_shape: + raise ValueError(f"{global_shape=} has Nones. This is not supported.") + full_dim = [] + for i, (data_dim, global_dim) in enumerate( + zip(local_data.shape, global_shape) + ): + full_dim.append(data_dim == global_dim) + if data_dim != global_dim: + process_slice = num_addressable_indices(sharding, i, global_shape) + if process_slice != data_dim: + raise ValueError( + "Invalid host data, each dimension should match either global or " + f"process shape. In dimension {i=}, the process data has {data_dim}" + f"elements. Process addresses {process_slice} elements and " + f"{global_shape=}." + ) + addressable_shards = sharding.addressable_devices_indices_map(global_shape) + shard = next(iter(addressable_shards.values())) + assert shard is not None + shard_shape = _get_shape_from_index(shard, global_shape) + slices_for_each_dim: list[list[int]] = [[] for _ in global_shape] + for shard_index in addressable_shards.values(): + assert shard_index is not None + for i, slc in enumerate(shard_index): + slices_for_each_dim[i].append(slc.start or 0) + for i in range(len(global_shape)): + slices_for_each_dim[i] = sorted(set(slices_for_each_dim[i])) + + @functools.lru_cache(maxsize=4096) + def local_slice(i, start): + # Looks up the index of this slice in the list of slices for this dimension. + # This will determine the slice in host_local_data + start = slices_for_each_dim[i].index(start or 0) * shard_shape[i] + end = start + shard_shape[i] + return slice(start, end) + + def cb(index: Index | None) -> ArrayLike: + assert index is not None + data_slice = ( + slc if full_dim[i] else local_slice(i, slc.start) + for i, slc in enumerate(index) + ) + return local_data[tuple(data_slice)] + + return make_array_from_callback(global_shape, sharding, cb) def make_array_from_single_device_arrays( @@ -806,11 +1029,13 @@ def make_array_from_single_device_arrays( # All input arrays should be committed. Checking it is expensive on # single-controller systems. if any(isinstance(arr, core.Tracer) for arr in arrays): - raise ValueError("jax.make_array_from_single_device_arrays requires a list of concrete arrays as input. " - f"got types {set(map(type, arrays))}") + raise ValueError( + "jax.make_array_from_single_device_arrays requires a list of concrete" + f" arrays as input. got types {set(map(type, arrays))}") aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False) if dtypes.issubdtype(aval.dtype, dtypes.extended): - return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True) + return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, + committed=True) # TODO(phawkins): ideally the cast() could be checked. 
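For reference, the assembly path that ends in the `ArrayImpl` constructor below is normally driven through the public `jax.make_array_from_single_device_arrays` API. A minimal usage sketch follows; the 2xN mesh, the (8, 8) shape, and the assumption of an even number of local devices are illustrative only, not taken from this patch:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a mesh over the available devices (assumes an even device count).
mesh = Mesh(np.array(jax.devices()).reshape(2, -1), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
global_shape = (8, 8)
global_data = np.arange(np.prod(global_shape)).reshape(global_shape)

# One committed single-device array per addressable device, sliced per the sharding.
arrays = [
    jax.device_put(global_data[index], d)
    for d, index in sharding.addressable_devices_indices_map(global_shape).items()
]
arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
assert arr.shape == global_shape
```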
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays), committed=True) @@ -825,7 +1050,18 @@ def make_array_from_single_device_arrays( def _array_mlir_constant_handler(val): - return mlir.ir_constants(val._value) + try: + return mlir.ir_constants(val._value) + except RuntimeError as e: + # TODO(yashkatariya): Ideally we would catch a custom exception from + # `_value` function in ArrayImpl instead of checking the error string. + if 'Fetching value for `jax.Array` that spans non-addressable' in str(e): + raise RuntimeError( + "Closing over jax.Array that spans non-addressable (non process" + " local) devices is not allowed. Please pass such arrays as arguments" + f" to the function. Got jax.Array: {val.aval.str_short()}") from e + raise + mlir.register_constant_handler(ArrayImpl, _array_mlir_constant_handler) @@ -855,7 +1091,7 @@ def as_slice_indices(arr: Any, idx: Index) -> tuple[ start_indices[dim] = sub_idx.start limit_indices[dim] = sub_idx.stop - return tuple(start_indices), tuple(limit_indices), tuple(removed_dims) # type: ignore + return tuple(start_indices), tuple(limit_indices), tuple(removed_dims) def shard_device_array(x, devices, indices, sharding): @@ -872,15 +1108,11 @@ def _hashable_index(idx): return tree_util.tree_map( lambda x: (x.start, x.stop) if type(x) == slice else x, idx) -# The fast path is handled directly in shard_args(). + def shard_sharded_device_array_slow_path(x, devices, indices, sharding): candidates = defaultdict(list) - if isinstance(x, ArrayImpl): - bufs = [buf.data for buf in x.addressable_shards] - arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values()) - else: - bufs = x.device_buffers - arr_indices = x.indices + bufs = [buf.data for buf in x.addressable_shards] + arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values()) for buf, idx in safe_zip(bufs, arr_indices): candidates[_hashable_index(idx)].append(buf) @@ -891,7 +1123,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): if not candidates_list: # This array isn't sharded correctly. Reshard it via host roundtrip. # TODO(skye): more efficient reshard? - return pxla.shard_arg(x._value, sharding, canonicalize=False) + return pxla.shard_args([sharding], [x._value], canonicalize=False)[0] # Try to find a candidate buffer already on the correct device, # otherwise copy one of them. 
for buf in candidates_list: @@ -904,33 +1136,58 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): return pxla.batched_device_put(x.aval, sharding, bufs, devices) -def _array_shard_arg(x, sharding): - x._check_if_deleted() +@cache(max_size=4096, trace_context_in_key=False) +def _sharding_indices_and_eq(src_sharding, shape, dst_sharding): + src_indices = src_sharding.addressable_devices_indices_map(shape).values() + dst_indices = dst_sharding.addressable_devices_indices_map(shape).values() + return dst_indices, tuple(src_indices) == tuple(dst_indices) - x_indices = x.sharding.addressable_devices_indices_map(x.shape).values() - indices = sharding.addressable_devices_indices_map(x.shape).values() - if not x.is_fully_addressable: - if tuple(x_indices) == tuple(indices): - return x - else: - raise NotImplementedError( - "Cannot reshard an input that is not fully addressable") - else: - devices = pxla.get_addressable_devices_for_shard_arg(sharding) - if tuple(x_indices) == tuple(indices): - return xc.copy_array_to_devices_with_sharding(x, list(devices), sharding) - # Resharding starts here: - if dispatch.is_single_device_sharding(x.sharding): - return shard_device_array(x, devices, indices, sharding) + +def _array_shard_arg(xs, shardings): + results = [] + batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], [] + for i, (x, sharding) in enumerate(safe_zip(xs, shardings)): + x._check_if_deleted() + + indices, same_indices = _sharding_indices_and_eq( + x.sharding, x.shape, sharding) + if not x.is_fully_addressable: + if same_indices: + results.append(x) + else: + raise NotImplementedError( + "Cannot reshard an input that is not fully addressable") else: - return shard_sharded_device_array_slow_path(x, devices, indices, sharding) + devices = sharding._addressable_device_assignment + if same_indices: + # Add a placeholder result that will be filled in later. + results.append(None) + # Accumulate arguments to `batched_copy_array_to_devices_with_sharding`. + batch_xs.append(x) + batch_devs.append(list(devices)) + batch_shardings.append(sharding) + batch_indices.append(i) + # Resharding starts here: + elif dispatch.is_single_device_sharding(x.sharding): + results.append(shard_device_array(x, devices, indices, sharding)) + else: + results.append( + shard_sharded_device_array_slow_path(x, devices, indices, sharding)) + + copy_outs = xc.batched_copy_array_to_devices_with_sharding( + batch_xs, batch_devs, batch_shardings) + for i, copy_out in safe_zip(batch_indices, copy_outs): + assert results[i] is None + results[i] = copy_out + return results + pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg def _array_global_result_handler(global_aval, out_sharding, committed): if global_aval.dtype == dtypes.float0: - return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore + return lambda _: np.zeros(global_aval.shape, dtypes.float0) if dtypes.issubdtype(global_aval.dtype, dtypes.extended): return global_aval.dtype._rules.global_sharded_result_handler( global_aval, out_sharding, committed) @@ -939,13 +1196,11 @@ def _array_global_result_handler(global_aval, out_sharding, committed): ) pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler -pxla.global_result_handlers[core.AbstractToken] = lambda *_: lambda *_: core.token - # Only used for Arrays that come out of pmap. 
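The batched shard-arg handling above is the sort of path exercised when an already committed `jax.Array` is passed to a jitted function that requests a different input sharding. A rough, hypothetical illustration from the user side (mesh shape and array sizes are assumptions, not part of this change):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(2, -1), ('x', 'y'))
row_sharded = NamedSharding(mesh, P('x', None))
fully_sharded = NamedSharding(mesh, P('x', 'y'))

x = jax.device_put(np.ones((8, 8)), row_sharded)  # committed input

# jit asks for a different input sharding, so the argument is copied or
# resharded through the shard_arg handlers before execution.
f = jax.jit(lambda a: a + 1, in_shardings=fully_sharded,
            out_shardings=fully_sharded)
y = f(x)  # y now carries the requested sharding
```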
def _array_local_result_handler(aval, sharding, indices): if aval.dtype == dtypes.float0: - return lambda _: np.zeros(aval.shape, dtypes.float0) # type: ignore + return lambda _: np.zeros(aval.shape, dtypes.float0) if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.local_sharded_result_handler( aval, sharding, indices) @@ -954,3 +1209,21 @@ def _array_local_result_handler(aval, sharding, indices): ) pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler + + +# Token handlers + +def _token_shard_arg(xs, shardings): + return _array_shard_arg([x._buf for x in xs], shardings) +pxla.shard_arg_handlers[core.Token] = _token_shard_arg + + +def _token_global_result_handler(global_aval, out_sharding, committed): + array_handler = _array_global_result_handler( + core.token_shaped_array, out_sharding, committed) + + def wrapper(*args, **kwargs): + out_buf = array_handler(*args, **kwargs) + return core.Token(out_buf) + return wrapper +pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index cc33d753b899..5809b9649f26 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -115,6 +115,14 @@ def sharding(self) -> Sharding: Array.__module__ = "jax" +# StaticScalar is the Union of all scalar types that can be converted to +# JAX arrays, and are possible to mark as static arguments. +StaticScalar = Union[ + np.bool_, np.number, # NumPy scalar types + bool, int, float, complex, # Python scalar types +] +StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars." + # ArrayLike is a Union of all objects that can be implicitly converted to a # standard JAX array (i.e. not including future non-standard array types like @@ -123,7 +131,6 @@ def sharding(self) -> Sharding: ArrayLike = Union[ Array, # JAX array type np.ndarray, # NumPy array type - np.bool_, np.number, # NumPy scalar types - bool, int, float, complex, # Python scalar types + StaticScalar, # valid scalars ] ArrayLike.__doc__ = "Type annotation for JAX array-like objects." diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index 304c1f0707c0..fbdd4894843e 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Callable, Sequence +from typing import Any, Union import numpy as np from jax._src.sharding import Sharding @@ -57,12 +58,12 @@ class Array(abc.ABC): # Comparisons # these return bool for object, so ignore override errors. - def __lt__(self, other) -> Array: ... # type: ignore[override] - def __le__(self, other) -> Array: ... # type: ignore[override] + def __lt__(self, other) -> Array: ... + def __le__(self, other) -> Array: ... def __eq__(self, other) -> Array: ... # type: ignore[override] def __ne__(self, other) -> Array: ... # type: ignore[override] - def __gt__(self, other) -> Array: ... # type: ignore[override] - def __ge__(self, other) -> Array: ... # type: ignore[override] + def __gt__(self, other) -> Array: ... + def __ge__(self, other) -> Array: ... # Unary arithmetic @@ -109,25 +110,28 @@ class Array(abc.ABC): def __float__(self) -> float: ... def __index__(self) -> int: ... + def __buffer__(self, flags: int) -> memoryview: ... 
+ def __release_buffer__(self, view: memoryview) -> None: ... + # np.ndarray methods: - def all(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, - keepdims=None, *, where: Optional[ArrayLike] = ...) -> Array: ... - def any(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, - keepdims=None, *, where: Optional[ArrayLike] = ...) -> Array: ... - def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ... - def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ... + def all(self, axis: int | Sequence[int] | None = None, out=None, + keepdims=None, *, where: ArrayLike | None = ...) -> Array: ... + def any(self, axis: int | Sequence[int] | None = None, out=None, + keepdims=None, *, where: ArrayLike | None = ...) -> Array: ... + def argmax(self, axis: int | None = None, out=None, keepdims=None) -> Array: ... + def argmin(self, axis: int | None = None, out=None, keepdims=None) -> Array: ... def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ... - def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ... + def argsort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ... def astype(self, dtype) -> Array: ... def choose(self, choices, out=None, mode='raise') -> Array: ... def clip(self, min=None, max=None, out=None) -> Array: ... - def compress(self, condition, axis: Optional[int] = None, out=None) -> Array: ... + def compress(self, condition, axis: int | None = None, out=None) -> Array: ... def conj(self) -> Array: ... def conjugate(self) -> Array: ... def copy(self) -> Array: ... - def cumprod(self, axis: Optional[Union[int, Sequence[int]]] = None, + def cumprod(self, axis: int | Sequence[int] | None = None, dtype=None, out=None) -> Array: ... - def cumsum(self, axis: Optional[Union[int, Sequence[int]]] = None, + def cumsum(self, axis: int | Sequence[int] | None = None, dtype=None, out=None) -> Array: ... def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Array: ... def dot(self, b, *, precision=None) -> Array: ... @@ -135,35 +139,35 @@ class Array(abc.ABC): @property def imag(self) -> Array: ... def item(self, *args) -> Any: ... - def max(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, + def max(self, axis: int | Sequence[int] | None = None, out=None, keepdims=None, initial=None, where=None) -> Array: ... - def mean(self, axis: Optional[Union[int, Sequence[int]]] = None, dtype=None, + def mean(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, keepdims=False, *, where=None,) -> Array: ... - def min(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, + def min(self, axis: int | Sequence[int] | None = None, out=None, keepdims=None, initial=None, where=None) -> Array: ... @property def nbytes(self) -> int: ... def nonzero(self, *, size=None, fill_value=None) -> Array: ... - def prod(self, axis: Optional[Union[int, Sequence[int]]] = None, dtype=None, + def prod(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, keepdims=None, initial=None, where=None) -> Array: ... - def ptp(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, + def ptp(self, axis: int | Sequence[int] | None = None, out=None, keepdims=False,) -> Array: ... def ravel(self, order='C') -> Array: ... @property def real(self) -> Array: ... - def repeat(self, repeats, axis: Optional[int] = None, *, + def repeat(self, repeats, axis: int | None = None, *, total_repeat_length=None) -> Array: ... 
def reshape(self, *args, order='C') -> Array: ... def round(self, decimals=0, out=None) -> Array: ... def searchsorted(self, v, side='left', sorter=None) -> Array: ... - def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ... - def squeeze(self, axis: Optional[Union[int, Sequence[int]]] = None) -> Array: ... - def std(self, axis: Optional[Union[int, Sequence[int]]] = None, + def sort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ... + def squeeze(self, axis: int | Sequence[int] | None = None) -> Array: ... + def std(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ... - def sum(self, axis: Optional[Union[int, Sequence[int]]] = None, dtype=None, + def sum(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, keepdims=None, initial=None, where=None) -> Array: ... def swapaxes(self, axis1: int, axis2: int) -> Array: ... - def take(self, indices, axis: Optional[int] = None, out=None, + def take(self, indices, axis: int | None = None, out=None, mode=None) -> Array: ... def tobytes(self, order='C') -> bytes: ... def tolist(self) -> list[Any]: ... @@ -174,15 +178,15 @@ class Array(abc.ABC): def T(self) -> Array: ... @property def mT(self) -> Array: ... - def var(self, axis: Optional[Union[int, Sequence[int]]] = None, + def var(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ... def view(self, dtype=None, type=None) -> Array: ... # Even though we don't always support the NumPy array protocol, e.g., for # tracer types, for type checking purposes we must declare support so we # implement the NumPy ArrayLike protocol. - def __array__(self, dtype: Optional[np.dtype] = ..., - copy: Optional[bool] = ...) -> np.ndarray: ... + def __array__(self, dtype: np.dtype | None = ..., + copy: bool | None = ...) -> np.ndarray: ... def __dlpack__(self) -> Any: ... # JAX extensions @@ -196,7 +200,6 @@ class Array(abc.ABC): def block_until_ready(self) -> Array: ... def copy_to_host_async(self) -> None: ... def delete(self) -> None: ... - def device(self) -> Device: ... def devices(self) -> set[Device]: ... @property def sharding(self) -> Sharding: ... @@ -213,15 +216,17 @@ class Array(abc.ABC): @property def traceback(self) -> Traceback: ... def unsafe_buffer_pointer(self) -> int: ... - @property - def device_buffers(self) -> Any: ... +StaticScalar = Union[ + np.bool_, np.number, # NumPy scalar types + bool, int, float, complex, # Python scalar types +] + ArrayLike = Union[ Array, # JAX array type np.ndarray, # NumPy array type - np.bool_, np.number, # NumPy scalar types - bool, int, float, complex, # Python scalar types + StaticScalar, # valid scalars ] @@ -233,23 +238,23 @@ class _IndexUpdateHelper: class _IndexUpdateRef: def get(self, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: Optional[str] = None, fill_value: Optional[ArrayLike] = None) -> Array: ... + mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ... def set(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: Optional[str] = None, fill_value: Optional[ArrayLike] = None) -> Array: ... + mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ... def add(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... 
def mul(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def multiply(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def divide(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def power(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def min(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def max(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def apply(self, func: Callable[[ArrayLike], ArrayLike], indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... diff --git a/jax/_src/blocked_sampler.py b/jax/_src/blocked_sampler.py new file mode 100644 index 000000000000..16da61d75b3f --- /dev/null +++ b/jax/_src/blocked_sampler.py @@ -0,0 +1,165 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +from typing import Any, Protocol +import jax +from jax._src import random +from jax._src.typing import Array, ArrayLike +from jax import numpy as jnp + +NdKeyList = Any +Shape = random.Shape + +class SampleFn(Protocol): + def __call__(self, key: random.KeyArrayLike, *args, shape: Shape, + **kwargs) -> Array: + ... + + +def _compute_scalar_index(iteration_index: Sequence[int], + total_size: Shape, + block_size: Shape, + block_index: Sequence[int]) -> int: + ndims = len(iteration_index) + dim_size = 1 + total_idx = 0 + for i in range(ndims-1, -1, -1): + dim_idx = block_index[i] + iteration_index[i] * block_size[i] + total_idx += dim_idx * dim_size + dim_size *= total_size[i] + return total_idx + + +def blocked_fold_in( + global_key: random.KeyArrayLike, + total_size: Shape, + block_size: Shape, + tile_size: Shape, + block_index: Sequence[ArrayLike], + ) -> NdKeyList: + """Computes a grid of keys for block-invariant sampling. + + Suppose we wished to construct a 16x512 array of random numbers, using + block sizes of 16x128 and 16x256. 
We could select an tile size of 8x128 + (which divides both 16x128 and 16x256) and divide the total array in tiles as: + --------------------------------- + | 8x128 | 8x128 | 8x128 | 8x128 | + --------------------------------- + | 8x128 | 8x128 | 8x128 | 8x128 | + --------------------------------- + + We generate a key for each tile as: + tile_key = fold_in(global_key, tile_idx) + + Where the tile_idx is the row-major raveled index of each element: + ----------------- + | 0 | 1 | 2 | 3 | + ----------------- + | 4 | 5 | 6 | 7 | + ----------------- + + We then compute and return the keys required to sample the tiles that make + up the current block (specified via `block_index`). + With a 16x256 block size, each block requires 4 (2x2) tile keys: + --------------- + | 0, 1 | 2, 3 | + | 4, 5 | 6, 7 | + --------------- + Therefore, we return a grid of 2x2 keys for each block (2 blocks total). + + With a 16x128 block size, each block requires 2 (2x1) tile keys: + ----------------- + | 0 | 1 | 2 | 3 | + | 4 | 5 | 6 | 7 | + ----------------- + Therefore, we return a grid of 2x1 keys for each block (4 blocks total). + + Args: + global_key: The global key shared between all blocks. + total_size: The shape of the array being generated. + block_size: The shape of an individual block. + tile_size: The shape of a `tile`, which is the smallest unit at + which samples are generated. This should be selected to be a divisor + of all block sizes one needs to be invariant to. + block_index: The index denoting which block to generate keys for. + + Returns: + An N-dimensional nested list of keys required to sample the tiles + corresponding to the block specified by `block_index`. + """ + size_in_blocks = tuple( + _shape // _element for _shape, _element in zip(block_size, tile_size)) + + def _keygen_loop(axis, prefix): + if axis == len(size_in_blocks): + subtile_key = jax.random.fold_in( + global_key, _compute_scalar_index( + block_index, total_size, size_in_blocks, prefix)) + return subtile_key + else: + keys = [] + for i in range(size_in_blocks[axis]): + keys.append(_keygen_loop(axis+1, prefix+(i,))) + return keys + return _keygen_loop(0, tuple()) + + +def sample_block( + sampler_fn: SampleFn, + keys: NdKeyList, + block_size: Shape, + tile_size: Shape, + *args, + **kwargs + ) -> jax.Array: + """Draws random samples for a single block. + + This function is intended to be used in conjunction with `blocked_fold_in`: + ``` + key_list = blocked_fold_in(global_key, total_size, block_size, tile_size, + block_index) + samples = sample_block(jax.random.uniform, key_list, block_size, tile_size) + ``` + + Args: + sampler_fn: A random sampling function, e.g. jax.random.uniform. + keys: A grid of keys generated by `blocked_fold_in`. + block_size: The shape of an individual block. + tile_size: The shape of a `tile`, which is the smallest unit at + which samples are generated. This should be selected to be a divisor + of all block sizes one needs to be invariant to. + args: varargs for sampler_fn. + kwargs: kwargs for sampler_fn. + + Returns: + An array of random samples drawn using sampler_fn. 
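As a concrete, illustrative check of the block-invariance these two helpers are designed for, one could sample the same 16x512 array with two different block sizes that share an 8x128 tile and compare the results. The sizes below and the direct use of the private `jax._src.blocked_sampler` module are assumptions for the sketch, not part of a documented API:

```python
import jax
import jax.numpy as jnp
from jax._src import blocked_sampler

global_key = jax.random.key(0)
total_size, tile_size = (16, 512), (8, 128)

def sample(block_size):
  nblocks = [t // b for t, b in zip(total_size, block_size)]
  rows = []
  for bi in range(nblocks[0]):
    row = []
    for bj in range(nblocks[1]):
      # Keys for the tiles making up block (bi, bj), then the block's samples.
      keys = blocked_sampler.blocked_fold_in(
          global_key, total_size, block_size, tile_size, (bi, bj))
      row.append(blocked_sampler.sample_block(
          jax.random.uniform, keys, block_size, tile_size))
    rows.append(jnp.concatenate(row, axis=1))
  return jnp.concatenate(rows, axis=0)

# Same tile size, different block sizes -> identical samples.
assert jnp.array_equal(sample((16, 128)), sample((16, 256)))
```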
+ """ + size_in_tiles = tuple( + _shape // _element for _shape, _element in zip(block_size, tile_size)) + def _nested_index(arr: jax.Array, idx: Sequence[int]) -> jax.Array: + if len(idx) == 1: + return arr[idx[0]] + return _nested_index(arr[idx[0]], idx[1:]) + + def _sample_loop(axis: int, prefix: tuple[int, ...]) -> jax.Array: + if axis == len(size_in_tiles): + return sampler_fn(_nested_index(keys, prefix), *args, + shape=tile_size, **kwargs) + else: + samples = [] + for i in range(size_in_tiles[axis]): + samples.append(_sample_loop(axis+1, prefix+(i,))) + return jnp.concatenate(samples, axis=axis) + return _sample_loop(0, tuple()) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 13f3e2933ae9..6fdf0c600b7d 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -18,6 +18,7 @@ import logging import os import sys +from typing import cast as type_cast from jax._src import config from jax._src.lib import version_str as jaxlib_version_str @@ -90,7 +91,10 @@ def get(module: ir.Module, lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())), ("compile_options", lambda hash_obj: _hash_serialized_compile_options( - hash_obj, compile_options)), + hash_obj, compile_options, + # In case of GPU multi-process tasks we need to strip device + # assignment to use cache key as invariant between processes. + strip_device_assignment=(backend.platform == "gpu"))), ("accelerator_config", lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)), ("compression", @@ -133,7 +137,7 @@ def _serialize_ir(m: ir.Module) -> bytes: def _canonicalize_ir(m_original: ir.Module) -> bytes: with m_original.context: - m = m_original.operation.clone() + m = type_cast(ir.Module, m_original.operation.clone()) passes = pm.PassManager.parse( "builtin.module(strip-debuginfo)" ) @@ -172,7 +176,8 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): _hash_platform(hash_obj, backend) -def _hash_serialized_compile_options(hash_obj, compile_options_obj): +def _hash_serialized_compile_options(hash_obj, compile_options_obj, + strip_device_assignment=False): # Do not mess with the original CompileOptions object since it is passed to # the compiler. Create a deep copy for the purpose of cache key generation. 
compile_options_copy = copy.deepcopy(compile_options_obj) @@ -211,6 +216,13 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj): debug_options.xla_gpu_cuda_data_dir = "" # LINT.ThenChange(:xla_flags) + if strip_device_assignment and compile_options_copy.device_assignment: + replica_count = compile_options_copy.device_assignment.replica_count() + computation_count = compile_options_copy.device_assignment.computation_count() + compile_options_copy.device_assignment = xla_client.DeviceAssignment.create( + np.arange(replica_count * computation_count).reshape( + [replica_count, computation_count]) + ) return hash_obj.update(compile_options_copy.SerializeAsString()) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 506545446682..2820d8bf2eae 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -14,12 +14,13 @@ """Module for JAX callbacks.""" from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence +import dataclasses import functools -from typing import Any, Callable - -import numpy as np +import logging +from typing import Any +import jax from jax._src import core from jax._src import dispatch from jax._src import dtypes @@ -30,9 +31,13 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.lib import xla_client as xc from jax._src.lax.control_flow.loops import map as lax_map +from jax._src.lib import xla_client as xc from jax._src.sharding_impls import SingleDeviceSharding +import numpy as np + +logger = logging.getLogger(__name__) + # `pure_callback_p` is the main primitive for staging out Python pure callbacks. pure_callback_p = core.Primitive("pure_callback") @@ -42,15 +47,46 @@ map, unsafe_map = util.safe_map, map +@dataclasses.dataclass(frozen=True) +class _FlatCallback: + """A Python function callable with flat arguments and results. + + An instance of this class is used as a parameter for the callback primitives. + We prefer it to an anonymous flattened function because it produces + equal objects when we call the same Python function with the same argument + structure. + """ + callback_func: Callable[..., Any] + in_tree: tree_util.PyTreeDef # (args, kwargs) pytree for `callback_func`. + + def __call__(self, *flat_args: jax.Array) -> Sequence[jax.Array]: + args, kwargs = tree_util.tree_unflatten(self.in_tree, flat_args) + return tree_util.tree_leaves(self.callback_func(*args, **kwargs)) + + def pure_callback_impl( *args, result_avals, - callback: Callable[..., Any], + callback: _FlatCallback, sharding: SingleDeviceSharding | None, vectorized: bool, ): del sharding, vectorized, result_avals - return callback(*args) + try: + cpu_device, *_ = jax.local_devices(backend="cpu") + except RuntimeError as e: + raise RuntimeError( + "jax.pure_callback failed to find a local CPU device to place the" + " inputs on. Make sure \"cpu\" is listed in --jax_platforms or the" + " JAX_PLATFORMS environment variable." 
+ ) from e + args = jax.device_put(args, cpu_device) + with jax.default_device(cpu_device): + try: + return tree_util.tree_map(np.asarray, callback(*args)) + except BaseException: + logger.exception("jax.pure_callback failed") + raise pure_callback_p.def_impl(functools.partial(dispatch.apply_primitive, @@ -60,7 +96,7 @@ def pure_callback_impl( @pure_callback_p.def_abstract_eval def pure_callback_abstract_eval( *avals, - callback: Callable[..., Any], + callback: _FlatCallback, result_avals, sharding: SingleDeviceSharding | None, vectorized: bool, @@ -88,14 +124,14 @@ def pure_callback_transpose_rule(*args, **kwargs): ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule -def pure_callback_batching_rule( +def callback_batching_rule( + prim, args, dims, *, - callback, - sharding: SingleDeviceSharding | None, vectorized: bool, result_avals: Sequence[core.ShapedArray], + **kwargs: Any, ): axis_size = next(a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped) @@ -105,30 +141,30 @@ def pure_callback_batching_rule( result_avals = tuple( core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) # type: ignore for aval in result_avals) - outvals = pure_callback_p.bind( + outvals = prim.bind( *new_args, - callback=callback, - sharding=sharding, vectorized=vectorized, result_avals=result_avals, + **kwargs, ) else: is_batched = [d is not batching.not_mapped for d in dims] unbatched_args, batched_args = util.partition_list(is_batched, new_args) def _batch_fun(batched_args): merged_args = util.merge_lists(is_batched, unbatched_args, batched_args) - return pure_callback_p.bind( + return prim.bind( *merged_args, - callback=callback, - sharding=sharding, result_avals=result_avals, vectorized=vectorized, + **kwargs, ) outvals = lax_map(_batch_fun, batched_args) return tuple(outvals), (0,) * len(outvals) -batching.primitive_batchers[pure_callback_p] = pure_callback_batching_rule +batching.primitive_batchers[pure_callback_p] = functools.partial( + callback_batching_rule, pure_callback_p +) def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None): @@ -185,7 +221,7 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None): def pure_callback_lowering( - ctx, *args, callback, sharding: SingleDeviceSharding | None, **params + ctx, *args, callback: _FlatCallback, sharding: SingleDeviceSharding | None, **params ): def _callback(*flat_args): return tuple( @@ -205,7 +241,7 @@ def _callback(*flat_args): list(args), ctx.avals_in, ctx.avals_out, - False, + has_side_effect=False, sharding=op_sharding, ) return result @@ -217,7 +253,7 @@ def _check_shape_dtype(shape_dtype): dt = np.dtype(shape_dtype.dtype) if dtypes.canonicalize_dtype(dt) != dt: raise ValueError( - "Cannot return 64-bit values when `jax_enable_x64` is disabled") + "result_shape_dtypes cannot specify 64-bit types when `jax_enable_x64` is disabled") def pure_callback( @@ -228,14 +264,39 @@ def pure_callback( vectorized: bool = False, **kwargs: Any, ): - """Calls a pure Python callback. + """Calls a pure Python callback. Works under :func:`jit`/:func:`~vmap`/etc. For more explanation, see `External Callbacks`_. + ``pure_callback`` enables calling a Python function in JIT-ed JAX functions. + The input ``callback`` will be passed JAX arrays placed on a local CPU, and + it should also return JAX arrays on CPU. + + The callback is treated as functionally pure, meaning it has no side-effects + and its output value depends only on its argument values. 
As a consequence, it + is safe to be called multiple times (e.g. when transformed by :func:`~vmap` or + :func:`~pmap`), or not to be called at all when e.g. the output of a + `jit`-decorated function has no data dependence on its value. Pure callbacks + may also be reordered if data-dependence allows. + + When `vmap`-ed the behavior will depend on the value of the + ``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback + is assumed to obey + ``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``. + Therefore, the callback will be called directly on batched inputs (where the + batch axes are the leading dimensions). Additionally, the callbacks should + return outputs that have corresponding leading batch axes. If not vectorized + ``callback`` will be mapped sequentially across the batched axis. + For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free + to set ``vectorized=True`` because the ``np.matmul`` function handles + arbitrary leading batch dimensions. + Args: callback: function to execute on the host. The callback is assumed to be a pure function (i.e. one without side-effects): if an impure function is passed, it - may behave in unexpected ways, particularly under transformation. + may behave in unexpected ways, particularly under transformation. The callable + will be passed PyTrees of arrays as arguments, and should return a PyTree of + arrays that matches ``result_shape_dtypes``. result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes, whose structure matches the expected output of the callback function at runtime. :class:`jax.ShapeDtypeStruct` is often used to define leaf values. @@ -257,10 +318,6 @@ def pure_callback( .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html """ - def _flat_callback(*flat_args): - args, kwargs = tree_util.tree_unflatten(in_tree, flat_args) - return tree_util.tree_leaves(callback(*args, **kwargs)) - flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) tree_util.tree_map(_check_shape_dtype, result_shape_dtypes) result_avals = tree_util.tree_map( @@ -268,7 +325,7 @@ def _flat_callback(*flat_args): flat_result_avals, out_tree = tree_util.tree_flatten(result_avals) out_flat = pure_callback_p.bind( *flat_args, - callback=_flat_callback, + callback=_FlatCallback(callback, in_tree), result_avals=tuple(flat_result_avals), sharding=sharding, vectorized=vectorized, @@ -276,75 +333,6 @@ def _flat_callback(*flat_args): return tree_util.tree_unflatten(out_tree, out_flat) -def pure_callback_api( - callback: Callable[..., Any], - result_shape_dtypes: Any, - *args: Any, - sharding: SingleDeviceSharding | None = None, - vectorized: bool = False, - **kwargs: Any, -): - """Applies a functionally pure Python callable. Works under :func:`jit`/:func:`~pmap`/etc. - - ``pure_callback`` enables calling a Python function in JIT-ed JAX functions. - The input ``callback`` will be passed NumPy arrays in place of JAX arrays and - should also return NumPy arrays. Execution takes place on CPU, like any - Python+NumPy function. - - The callback is treated as functionally pure, meaning it has no side-effects - and its output value depends only on its argument values. As a consequence, it - is safe to be called multiple times (e.g. when transformed by :func:`~vmap` or - :func:`~pmap`), or not to be called at all when e.g. the output of a - `jit`-decorated function has no data dependence on its value. 
Pure callbacks - may also be reordered if data-dependence allows. - - When :func:`~pmap`-ed, the pure callback will be called several times (one on each - axis of the map). When `vmap`-ed the behavior will depend on the value of the - ``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback - is assumed to obey - ``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``. - Therefore, the callback will be called directly on batched inputs (where the - batch axes are the leading dimensions). Additionally, the callbacks should - return outputs that have corresponding leading batch axes. If not vectorized - ``callback`` will be mapped sequentially across the batched axis. - For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free - to set ``vectorized=True`` because the ``np.matmul`` function handles - arbitrary leading batch dimensions. - - Args: - callback: A Python callable. The callable will be passed PyTrees of NumPy - arrays as arguments, and should return a PyTree of NumPy arrays that - matches ``result_shape_dtypes``. - result_shape_dtypes: A PyTree with leaves that are objects with ``shape`` - and ``dtype`` attributes which represent to the shapes and dtypes of the - value of ``callback`` applied to ``args`` and ``kwargs``. - *args: The positional arguments to the callback. Must be PyTrees of JAX - types. - sharding: optional sharding that specifies the device from which the - callback should be invoked. - vectorized: A boolean that indicates whether or not ``callback`` is - vectorized, meaning it can handle arrays with additional leading - dimensions. If ``vectorized`` is `True`, when the callback is mapped - via `jax.vmap`, it will be called directly on inputs with leading batch - dimensions instead of executing ``callback`` on each mapped input - individually. The callback should also return outputs batched across the - leading axis. By default, ``vectorized`` is ``False``. - **kwargs: The keyword arguments to the callback. Must be PyTrees of JAX - types. - - Returns: - The value of ``callback(*args, **kwargs)``. - """ - return pure_callback( - callback, - result_shape_dtypes, - *args, - sharding=sharding, - vectorized=vectorized, - **kwargs, - ) - - # IO Callback io_callback_p = core.Primitive("io_callback") @@ -370,12 +358,26 @@ class OrderedIOEffect(effects.Effect): def io_callback_impl( *args, result_avals, - callback: Callable[..., Any], + callback: _FlatCallback, sharding: SingleDeviceSharding | None, ordered: bool, ): del result_avals, sharding, ordered - return callback(*args) + try: + cpu_device, *_ = jax.local_devices(backend="cpu") + except RuntimeError as e: + raise RuntimeError( + "jax.io_callback failed to find a local CPU device to place the" + " inputs on. Make sure \"cpu\" is listed in --jax_platforms or the" + " JAX_PLATFORMS environment variable." 
+ ) from e + args = jax.device_put(args, cpu_device) + with jax.default_device(cpu_device): + try: + return tree_util.tree_map(np.asarray, callback(*args)) + except BaseException: + logger.exception("jax.io_callback failed") + raise io_callback_p.def_impl(functools.partial(dispatch.apply_primitive, @@ -385,7 +387,7 @@ def io_callback_impl( @io_callback_p.def_effectful_abstract_eval def io_callback_abstract_eval( *avals, - callback: Callable[..., Any], + callback: _FlatCallback, result_avals, sharding: SingleDeviceSharding | None, ordered: bool, @@ -412,16 +414,16 @@ def io_callback_batching_rule( ): if ordered: raise ValueError("Cannot `vmap` ordered IO callback.") - return pure_callback_batching_rule( - args, - dims, - callback=callback, - sharding=sharding, - vectorized=False, - result_avals=result_avals, - ) - - + is_batched = [d is not batching.not_mapped for d in dims] + new_args = [arg if dim is batching.not_mapped else + batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)] + unbatched_args, batched_args = util.partition_list(is_batched, new_args) + def _batch_fun(batched_args): + merged = util.merge_lists(is_batched, unbatched_args, batched_args) + return io_callback_p.bind(*merged, callback=callback, sharding=sharding, + result_avals=result_avals, ordered=False) + out_vals = lax_map(_batch_fun, batched_args) + return out_vals, (0,) * len(out_vals) batching.primitive_batchers[io_callback_p] = io_callback_batching_rule @@ -447,7 +449,7 @@ def _callback(*flat_args): list(args), ctx.avals_in, ctx.avals_out, - True, + has_side_effect=True, sharding=op_sharding, ) ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: (token,)})) @@ -459,7 +461,7 @@ def _callback(*flat_args): list(args), ctx.avals_in, ctx.avals_out, - True, + has_side_effect=True, sharding=op_sharding, ) return result @@ -481,7 +483,7 @@ def io_callback( For more explanation, see `External Callbacks`_. Args: - callback: function to execute on the host. It is assumet to be an impure function. + callback: function to execute on the host. It is assumed to be an impure function. If ``callback`` is pure, using :func:`jax.pure_callback` instead may lead to more efficient execution. result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes, @@ -504,10 +506,6 @@ def io_callback( .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html """ - def _flat_callback(*flat_args): - args, kwargs = tree_util.tree_unflatten(in_tree, flat_args) - return tree_util.tree_leaves(callback(*args, **kwargs)) - flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) tree_util.tree_map(_check_shape_dtype, result_shape_dtypes) flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes) @@ -516,7 +514,7 @@ def _flat_callback(*flat_args): flat_args = map(core.raise_as_much_as_possible, flat_args) out_flat = io_callback_p.bind( *flat_args, - callback=_flat_callback, + callback=_FlatCallback(callback, in_tree), result_avals=tuple(flat_result_avals), sharding=sharding, ordered=ordered, diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 303d1d9af0d4..1167914e51c9 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -13,11 +13,11 @@ # limitations under the License. 
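Taken together, the callback changes above mean the wrapped Python function now receives its operands as JAX arrays committed to a local CPU device. A short, hypothetical usage sketch of the two public entry points involved (function names and shapes are illustrative):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def host_log1p(x):
  # Runs on the host; receives a CPU-resident array, returns a NumPy array.
  return np.log1p(np.asarray(x))

@jax.jit
def f(x):
  y = jax.pure_callback(host_log1p, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
  # Impure, ordered side effect: log the mean from the host.
  io_callback(lambda m: print("mean:", m), None, jnp.mean(y), ordered=True)
  return y * 2

f(jnp.arange(4.0))
```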
from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools import itertools as it -from typing import Callable, TypeVar, Any, Union +from typing import TypeVar, Any, Union import numpy as np @@ -271,9 +271,9 @@ def _get_batched_exception(self) -> BatchedError | None: return None def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload): - new_errs = {**self._pred, **{effect_type: pred}} # type: ignore - new_codes = {**self._code, **{effect_type: code}} # type: ignore - new_payload = {**self._payload, **{effect_type: payload}} # type: ignore + new_errs = {**self._pred, **{effect_type: pred}} + new_codes = {**self._code, **{effect_type: code}} + new_payload = {**self._payload, **{effect_type: payload}} new_metadata = {**self._metadata, **metadata} return Error(new_errs, new_codes, new_metadata, new_payload) @@ -751,7 +751,7 @@ def jaxpr_to_checkify_jaxpr( out_tree, error_effects = metadata() return checked_jaxpr, out_tree, error_effects -def cond_error_check(error: Error, enabled_errors, index, *ops, branches, linear): +def cond_error_check(error: Error, enabled_errors, index, *ops, branches): # Get the error-effects out of all branches so the cond can be called with # a merged error with all these effects. err_vals, err_tree = jtu.tree_flatten(error) @@ -763,7 +763,6 @@ def get_error_effects_from_jaxpr(jxpr): effects = [get_error_effects_from_jaxpr(jxpr) for jxpr in branches] merged_error = error._add_placeholder_effects(set().union(*effects)) err_vals, err_tree = jtu.tree_flatten(merged_error) - new_linear = (*[False] * len(err_vals), *linear) # Update branch jaxprs to be checkified jaxprs. in_avals = map(get_shaped_aval, [*err_vals, *ops]) @@ -773,7 +772,7 @@ def get_error_effects_from_jaxpr(jxpr): err_and_outs = lax.cond_p.bind( index, *err_vals, *ops, - branches=tuple(new_branches), linear=new_linear) + branches=tuple(new_branches)) # we need to merge metadata across out_trees (a tuple) err0, out = tree_unflatten(out_trees[0], err_and_outs) @@ -785,7 +784,7 @@ def get_error_effects_from_jaxpr(jxpr): error_checks[lax.cond_p] = cond_error_check def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, - num_consts, num_carry, linear, unroll): + num_consts, num_carry, linear, unroll, _split_transpose): consts, carry, xs = split_list(in_flat, [num_consts, num_carry]) xs_mapped = [core.mapped_aval(length, 0, get_shaped_aval(val)) for val in xs] @@ -812,7 +811,7 @@ def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, err_and_out = lax.scan_p.bind( *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr, num_consts=len(consts), num_carry=len(carry)+len(err_vals), - linear=new_linear, unroll=unroll) + linear=new_linear, unroll=unroll, _split_transpose=_split_transpose) err, out = tree_unflatten(out_tree, err_and_out) return err, out @@ -895,9 +894,9 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, error_checks[lax.while_p] = while_loop_error_check def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, - in_shardings, out_shardings, resource_env, - donated_invars, name, - inline, keep_unused): + in_shardings, out_shardings, + in_layouts, out_layouts, + resource_env, donated_invars, name, inline, keep_unused): # jaxpr to checked_jaxpr err_vals, err_tree = jtu.tree_flatten(error) new_vals_in = [*err_vals, *vals_in] @@ -908,10 +907,12 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, 
# Update pjit params to account for extra error values. num_error_vals = len(err_vals) num_out_error_vals = out_tree.num_leaves - len(out_shardings) - sharding = sharding_impls.UNSPECIFIED + sharding = sharding_impls.UNSPECIFIED new_in_shardings = (*[sharding] * num_error_vals, *in_shardings) new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings) + new_in_layouts = (*[None] * num_error_vals, *in_layouts) + new_out_layouts = (*[None] * num_out_error_vals, *out_layouts) new_donated_invars = (*[False] * num_error_vals, *donated_invars) err_and_out = pjit.pjit_p.bind( @@ -919,6 +920,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, jaxpr=checked_jaxpr, in_shardings=new_in_shardings, out_shardings=new_out_shardings, + in_layouts=new_in_layouts, + out_layouts=new_out_layouts, resource_env=resource_env, donated_invars=new_donated_invars, name=name, @@ -1296,6 +1299,6 @@ def check_error(error: Error) -> None: >>> error, _ = checkify.checkify(with_inner_jit)(-1) """ if not isinstance(error, Error): - raise ValueError('check_error takes an Error as argument, ' + raise TypeError('check_error takes an Error as argument, ' f'got type {type(error)} instead.') _check_error(error, debug=False) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 4b89cf12d91a..73e61de68008 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -13,6 +13,8 @@ # limitations under the License. import os +from jax import version +from jax._src import config from jax._src import hardware_utils running_in_cloud_tpu_vm: bool = False @@ -32,6 +34,18 @@ def maybe_import_libtpu(): return libtpu +def get_tpu_library_path() -> str | None: + path_from_env = os.getenv("TPU_LIBRARY_PATH") + if path_from_env is not None and os.path.isfile(path_from_env): + return path_from_env + + libtpu_module = maybe_import_libtpu() + if libtpu_module is not None: + return libtpu_module.get_library_path() + + return None + + def jax_force_tpu_init() -> bool: return 'JAX_FORCE_TPU_INIT' in os.environ @@ -54,11 +68,10 @@ def cloud_tpu_init() -> None: """ global running_in_cloud_tpu_vm - # We assume we are in a correctly-configured Cloud TPU environment - # if the following hold: a) libtpu is installed b) JAX_FORCE_TPU_INIT is set - # Exit early if we're not running on Cloud TPU. - libtpu_module = maybe_import_libtpu() - if libtpu_module is None and not jax_force_tpu_init(): + # Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed. 
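(Context for the `checkify` hunks above, not part of the patch: they are mostly plumbing updates for the new `cond`/`scan`/`pjit` parameters, plus `check_error` now raising `TypeError` instead of `ValueError` for a non-`Error` argument. A small sketch with the public `jax.experimental.checkify` API:)

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def safe_sqrt(x):
    checkify.check(x >= 0, "x must be non-negative, got {x}", x=x)
    return jnp.sqrt(x)

checked = checkify.checkify(jax.jit(safe_sqrt))
err, out = checked(-1.0)
err.throw()  # raises, reporting "x must be non-negative, got -1.0"
```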
+ libtpu_path = get_tpu_library_path() + num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0] + if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init(): return running_in_cloud_tpu_vm = True @@ -66,5 +79,17 @@ def cloud_tpu_init() -> None: os.environ.setdefault('GRPC_VERBOSITY', 'ERROR') os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu') os.environ['TPU_ML_PLATFORM'] = 'JAX' + os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__ + os.environ['ENABLE_RUNTIME_UPTIME_TELEMETRY'] = '1' if hardware_utils.tpu_enhanced_barrier_supported(): os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true" + + # this makes tensorstore serialization work better on TPU + os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60') + os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES', '256') + + if config.jax_pjrt_client_create_options.value is None: + config.update( + 'jax_pjrt_client_create_options', + f'ml_framework_name:JAX;ml_framework_version:{version.__version__}' + ) diff --git a/jax/_src/clusters/__init__.py b/jax/_src/clusters/__init__.py index d933af613810..73e4ac9412f7 100644 --- a/jax/_src/clusters/__init__.py +++ b/jax/_src/clusters/__init__.py @@ -22,6 +22,6 @@ # available one from the list will be picked. from .ompi_cluster import OmpiCluster from .slurm_cluster import SlurmCluster +from .mpi4py_cluster import Mpi4pyCluster from .cloud_tpu_cluster import GkeTpuCluster -from .cloud_tpu_cluster import MultisliceGceTpuCluster -from .cloud_tpu_cluster import SingleSliceGceTpuCluster +from .cloud_tpu_cluster import GceTpuCluster diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index a978cf4beff7..c85abb2f83de 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -14,6 +14,7 @@ from __future__ import annotations +import logging import os import re import socket @@ -21,10 +22,14 @@ from jax._src import clusters from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm +logger = logging.getLogger(__name__) + # We use an arbitrarily chosen port for the coordinator since we cannot # rely on communication to choose one in real time. 
coordinator_port = '8476' +metadata_response_code_success = 200 + def get_metadata(key): import requests # pytype: disable=import-error import time # pytype: disable=import-error @@ -47,11 +52,11 @@ def get_metadata(key): if api_resp is None: raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries") - return api_resp.text + return api_resp.text, api_resp.status_code def get_tpu_env_value(key): def get_tpu_env_value_from_metadata(key): - tpu_env_data = get_metadata('tpu-env') + tpu_env_data = get_metadata('tpu-env')[0] key_value_pairs = tpu_env_data.split('\n') for key_value_pair in key_value_pairs: # Typical line is MEGASCALE_NUM_SLICES: '2' @@ -65,112 +70,155 @@ def get_tpu_env_value_from_metadata(key): value = os.environ.get(key, None) return value if value is not None else get_tpu_env_value_from_metadata(key) -def is_gce_env(): - worker_number_string = get_metadata('agent-worker-number') - try: - worker_number = int(worker_number_string) - return True - except: - return False - -def is_multislice_gce_env(): - return is_gce_env() and get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None - -def is_gke_env(): - return os.environ.get("TPU_WORKER_HOSTNAMES", None) is not None +def has_megascale_address(): + return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None -def get_gce_worker_endpoints() -> str: - return get_metadata('worker-network-endpoints').split(',') - -class SingleSliceGceTpuCluster(clusters.ClusterEnv): - @classmethod - def is_env_present(cls) -> bool: - return running_in_cloud_tpu_vm and is_gce_env() and not is_multislice_gce_env() - - @classmethod - def get_coordinator_address(cls) -> str: - return f"{get_gce_worker_endpoints()[0].split(':')[2]}:{coordinator_port}" +class BaseTpuCluster(clusters.ClusterEnv): - @classmethod - def get_process_count(cls) -> int: - return len(get_gce_worker_endpoints()) + name: str = "tpu" - @classmethod - def get_process_id(cls) -> int: - return int(get_metadata('agent-worker-number')) + """Abstract cluster supports both single and multislice TPU environments. - @classmethod - def get_local_process_id(cls) -> int | None: - return None + If MEGASCALE_COORDINATOR_ADDRESS is not set, we assume single slice topology. + Concrete extensions of this class must implement methods for generating a list + of within-slice workers and a within-slice process ID. + `get_coordinator_address` must return the address of the host with + process ID 0 (as returned by `get_process_id`), since the coordinator service + is started on the host with process ID = 0. + """ -class MultisliceGceTpuCluster(clusters.ClusterEnv): @classmethod def is_env_present(cls) -> bool: - return running_in_cloud_tpu_vm and is_multislice_gce_env() + """Override this method to return True if the environment is present.""" + return False @classmethod - def get_coordinator_address(cls) -> str: - coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') + def get_coordinator_address(cls, timeout_secs: int | None) -> str: + if has_megascale_address(): + # For both GCE via QueuedResources and GKE via JobSet, the + # Megascale coordinator address is set as the host with process id = 0, + # so can be used as the jax distributed system coordinator. + coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') + else: + # For both GCE (QueuedResources and TPUVM create) and GKE via Job API, + # the workers lists are sorted by process ID so the first one can + # be used as the jax distributed system coordinator. 
+ coordinator_address = cls._get_worker_list_in_slice()[0] coordinator_address = coordinator_address.split(':')[0] + logger.debug("TPU Cluster using coordinator address: %s", coordinator_address) + cls.wait_for_coordinator(coordinator_address, timeout_secs) + return f'{coordinator_address}:{coordinator_port}' + @classmethod + def wait_for_coordinator(cls, coordinator_address, timeout_secs): # The coordinator may not be up before the other hosts try to # communicate with it. We check for its existence with retries. coordinator_found = False - lookup_attempt = 1 - max_coordinator_lookups = 50 - while not coordinator_found and lookup_attempt <= max_coordinator_lookups: + max_time = time.time() + timeout_secs + coordinator_retry_secs = 5 + while not coordinator_found and time.time() < max_time: try: ip_address = socket.gethostbyname(coordinator_address) coordinator_found = True + logger.debug("Found coordinator with address %s", coordinator_address) except socket.gaierror: - print(f"Failed to recognize coordinator address {coordinator_address} on attempt {lookup_attempt}, retrying...") - lookup_attempt += 1 - time.sleep(5) - + logger.debug( + "Failed to recognize coordinator address %s" + " retrying...", coordinator_address + ) + time.sleep(coordinator_retry_secs) if not coordinator_found: raise RuntimeError(f"Failed to recognize coordinator address {coordinator_address}") - # Use a different port for the jax coordinator than the MXLA coordinator, - # which is set to 8080 in multislice GCE. - return f'{coordinator_address}:{coordinator_port}' - @classmethod def get_process_count(cls) -> int: - processes_per_slice = cls._get_process_count_per_slice() - num_slices = int(get_tpu_env_value('MEGASCALE_NUM_SLICES')) - return processes_per_slice * num_slices + processes_per_slice = len(cls._get_worker_list_in_slice()) + num_slices = cls._get_num_slices() + total_process_count = processes_per_slice * num_slices + logger.debug("Total process count of %s = %s processes per slice and %s slices", total_process_count, processes_per_slice, num_slices) + return total_process_count @classmethod def get_process_id(cls) -> int: process_id_in_slice = cls._get_process_id_in_slice() - slice_id = int(get_tpu_env_value('MEGASCALE_SLICE_ID')) - processes_per_slice = cls._get_process_count_per_slice() - return process_id_in_slice + slice_id * processes_per_slice + slice_id = cls._get_slice_id() + processes_per_slice = len(cls._get_worker_list_in_slice()) + process_id = process_id_in_slice + slice_id * processes_per_slice + logger.debug("Process ID of %s generated by within-slice id %s and slice id %s", process_id, process_id_in_slice, slice_id) + return process_id - @classmethod - def get_local_process_id(cls) -> int | None: - return None + @staticmethod + def _get_num_slices() -> int: + if has_megascale_address(): + return int(get_tpu_env_value('MEGASCALE_NUM_SLICES')) + else: + return 1 @staticmethod - def _get_process_count_per_slice() -> int: - return len(get_gce_worker_endpoints()) + def _get_slice_id() -> int: + if has_megascale_address(): + return int(get_tpu_env_value('MEGASCALE_SLICE_ID')) + else: + return 0 @staticmethod def _get_process_id_in_slice() -> int: - return int(get_metadata('agent-worker-number')) + """Returns a process ID that is unique within slice.""" + raise NotImplementedError() + + @staticmethod + def _get_worker_list_in_slice() -> list[str]: + """Returns a list of worker endpoints/hostnames within slice.""" + raise NotImplementedError() + +class GceTpuCluster(BaseTpuCluster): + + name: str 
= "gcetpu" -class GkeTpuCluster(MultisliceGceTpuCluster): - # This class handles both single and multislice GKE as the environment - # variables are set the same in both cases. @classmethod def is_env_present(cls) -> bool: - return running_in_cloud_tpu_vm and is_gke_env() + if not running_in_cloud_tpu_vm: + logger.debug("Did not detect cloud TPU VM") + return False + metadata_response, metadata_code = get_metadata('agent-worker-number') + if metadata_code == metadata_response_code_success: + logger.debug("Gce Tpu Cluster detected for Jax Distributed System") + return True + else: + logger.debug("Did not detect Gce Tpu Cluster since agent-worker-number is not set in metadata") + logger.debug("Metadata code: %s", metadata_code) + logger.debug("Metadata response: %s", metadata_response) + return False @staticmethod - def _get_process_count_per_slice() -> int: - tpu_worker_hostnames = str(os.environ.get('TPU_WORKER_HOSTNAMES', None)) - return len(tpu_worker_hostnames.split(',')) + def _get_process_id_in_slice() -> int: + return int(get_metadata('agent-worker-number')[0]) + + @staticmethod + def _get_worker_list_in_slice() -> list[str]: + workers = get_metadata('worker-network-endpoints')[0].split(',') + return [worker.split(':')[2] for worker in workers] + +class GkeTpuCluster(BaseTpuCluster): + + name: str = "gketpu" + + @classmethod + def is_env_present(cls) -> bool: + if running_in_cloud_tpu_vm and os.environ.get("TPU_WORKER_HOSTNAMES") is not None: + logger.debug("Gke Tpu Cluster detected for Jax Distributed System") + return True + else: + if not running_in_cloud_tpu_vm: + logger.debug("Did not detect cloud TPU VM") + else: + logger.debug("Did not detect TPU GKE cluster since TPU_WORKER_HOSTNAMES is not set") + return False @staticmethod def _get_process_id_in_slice() -> int: return int(str(os.environ.get('TPU_WORKER_ID'))) + + @staticmethod + def _get_worker_list_in_slice() -> list[str]: + return str(os.environ.get('TPU_WORKER_HOSTNAMES', None)).split(',') diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py index fb566276a862..4c7df0617403 100644 --- a/jax/_src/clusters/cluster.py +++ b/jax/_src/clusters/cluster.py @@ -31,29 +31,51 @@ class ClusterEnv: """ _cluster_types: list[type[ClusterEnv]] = [] + opt_in_only_method: bool = False # Override this in derived classes if necessary def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) cls._cluster_types.append(cls) + @classmethod # pytype: disable=bad-return-type def auto_detect_unset_distributed_params(cls, coordinator_address: str | None, num_processes: int | None, process_id: int | None, - local_device_ids: Sequence[int] | None + local_device_ids: Sequence[int] | None, + cluster_detection_method: str | None, + initialization_timeout: int | None, ) -> tuple[str | None, int | None, int | None, Sequence[int] | None]: + if all(p is not None for p in (coordinator_address, num_processes, process_id, local_device_ids)): return (coordinator_address, num_processes, process_id, local_device_ids) - env = next((env for env in cls._cluster_types if env.is_env_present()), None) + + # First, we check the spec detection method because it will ignore submitted values + # If if succeeds. 
+ if cluster_detection_method is not None: + env = next( (env for env in cls._cluster_types if env.name == cluster_detection_method), None ) + if env is None: + logger.error(f"Automatic Distributed initialization can not proceed:" + f" {cluster_detection_method} is not supported.") + elif not env.is_env_present(): + logger.error(f"Automatic Distributed initialization can not proceed:" + f" {cluster_detection_method} is supported but not functional in this environment.") + else: + env = next((env for env in cls._cluster_types if env.opt_in_only_method == False and env.is_env_present()), None) + + # Above: I have wrapped the env selection in a conditional to go through + # opt-in methods first (currently only mpi4py) but to check all possible options + # otherwise. Passing no cluster_detection_method results in the default, original behavior. + if env: logger.debug('Initializing distributed JAX environment via %s', env.__name__) if coordinator_address is None: - coordinator_address = env.get_coordinator_address() + coordinator_address = env.get_coordinator_address(timeout_secs=initialization_timeout) if num_processes is None: num_processes = env.get_process_count() if process_id is None: @@ -79,7 +101,7 @@ def is_env_present(cls) -> bool: raise NotImplementedError("ClusterEnv subclasses must implement is_env_present") @classmethod - def get_coordinator_address(cls) -> str: + def get_coordinator_address(cls, timeout_secs: int | None) -> str: """Returns address and port used by JAX to bootstrap. Process id 0 will open a tcp socket at "hostname:port" where diff --git a/jax/_src/clusters/mpi4py_cluster.py b/jax/_src/clusters/mpi4py_cluster.py new file mode 100644 index 000000000000..10793778f745 --- /dev/null +++ b/jax/_src/clusters/mpi4py_cluster.py @@ -0,0 +1,93 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from jax._src import clusters +import socket + +from importlib.util import find_spec + + +class Mpi4pyCluster(clusters.ClusterEnv): + + + name: str = "mpi4py" + opt_in_only_method: bool = True + + @classmethod + def is_env_present(cls) -> bool: + + # Relies on mpi4py: + return find_spec("mpi4py") is not None + + @classmethod + def get_coordinator_address(cls, timeout_secs: int | None) -> str: + + # Using mpi4py, figure out rank 0 and it's hostname. + # Then broadcast the hostname and port. + + + from mpi4py import MPI #type: ignore + # Get the global communicator: + COMM_WORLD = MPI.COMM_WORLD + + # On rank 0, get the hostname: + + if COMM_WORLD.Get_rank() == 0: + # Order all the hostnames, and find unique ones + hostname = socket.gethostname() + + # Apparently, we want to pick a port in an ephemeral range... 
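(Illustrative arithmetic only: the ephemeral-range expression used just below, and already used by the Slurm and OMPI backends, always lands in the top 4096 ports:)

```python
# 2**12 == 4096, so x % 2**12 + (65535 - 2**12 + 1) maps any integer into
# [61440, 65535], the same ephemeral port range picked by SlurmCluster.
base = 65535 - 2**12 + 1            # 61440
assert 0 % 2**12 + base == 61440
assert 4095 % 2**12 + base == 65535
```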
+ port_id = hash(hostname) % 2**12 + (65535 - 2**12 + 1) + + hostname = f'{hostname}:{port_id}' + + else: + hostname = "None" + + + + # Broadcast the host_ip to all ranks: + hostname = COMM_WORLD.bcast(hostname, root=0) + + + return hostname + + + @classmethod + def get_process_count(cls) -> int: + from mpi4py import MPI + return int(MPI.COMM_WORLD.Get_size()) + + @classmethod + def get_process_id(cls) -> int: + from mpi4py import MPI + return int(MPI.COMM_WORLD.Get_rank()) + + @classmethod + def get_local_process_id(cls) -> int | None: + + # Using mpi4py, split the global communicator into sub communicators + # based on hostname. mpi will assign them ranks and that will allow + # a selection of the local process ID. + from mpi4py import MPI + COMM_WORLD = MPI.COMM_WORLD + + # This is the alternative method that is simpler: + new_comm = COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED) + + + # The rank in the new communicator - which is host-local only - IS the local rank: + return int(new_comm.Get_rank()) diff --git a/jax/_src/clusters/ompi_cluster.py b/jax/_src/clusters/ompi_cluster.py index 908af28a027b..151968c1c2bc 100644 --- a/jax/_src/clusters/ompi_cluster.py +++ b/jax/_src/clusters/ompi_cluster.py @@ -25,12 +25,15 @@ _LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK' class OmpiCluster(clusters.ClusterEnv): + + name: str = "ompi" + @classmethod def is_env_present(cls) -> bool: return _ORTE_URI in os.environ @classmethod - def get_coordinator_address(cls) -> str: + def get_coordinator_address(cls, timeout_secs: int | None) -> str: # Examples of orte_uri: # 1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911 # 1314521088.0;tcp6://[fe80::b9b:ac5d:9cf0:b858,2620:10d:c083:150e::3000:2]:43370 diff --git a/jax/_src/clusters/slurm_cluster.py b/jax/_src/clusters/slurm_cluster.py index 5edacb4f5d7a..8cec07601094 100644 --- a/jax/_src/clusters/slurm_cluster.py +++ b/jax/_src/clusters/slurm_cluster.py @@ -25,12 +25,15 @@ _NUM_NODES = 'SLURM_STEP_NUM_NODES' class SlurmCluster(clusters.ClusterEnv): + + name: str = "slurm" + @classmethod def is_env_present(cls) -> bool: return _JOBID_PARAM in os.environ @classmethod - def get_coordinator_address(cls) -> str: + def get_coordinator_address(cls, timeout_secs: int | None) -> str: # Pick port in ephemeral range [(65535 - 2^12 + 1), 65535] port = int(os.environ[_JOBID_PARAM]) % 2**12 + (65535 - 2**12 + 1) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index e60a29e274e9..9c276151741d 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -34,6 +34,7 @@ from jax._src.gfile_cache import GFileCache from jax._src.lib import xla_client from jax._src.lib.mlir import ir +from jax._src.lru_cache import LRUCache logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ # Mutex to protect _cache_initialized and _cache_used. _cache_initialized_mutex = threading.Lock() +_UNSUPPORTED_RUNTIMES: set[str] = set() def set_once_cache_used(f) -> None: """One-time setting of _cache_used. @@ -65,7 +67,20 @@ def set_once_cache_used(f) -> None: def get_file_cache(path: str) -> tuple[CacheInterface, str] | None: """Returns the file cache and the path to the cache.""" - return GFileCache(path), path + + def is_local_filesystem(path: str) -> bool: + return path.startswith("file://") or "://" not in path + + # `LRUCache` currently only supports local filesystem. Therefore, if `path` + # is not on a local filesystem, instead of using `LRUCache`, we + # fallback to the old `GFileCache`, which does not support LRU eviction. 
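(Context for the cache-backend selection above, not part of the patch: with a local cache directory, entries are now served by the LRU-evicting cache, capped by the size option this hunk reads as `config.compilation_cache_max_size`; a non-local URI keeps the old `GFileCache` behaviour. The public flag name below is inferred from that accessor and should be treated as an assumption:)

```python
import jax

# Local directory -> served by the new LRUCache with size-based eviction.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# Assumed flag name, derived from `config.compilation_cache_max_size` (bytes).
jax.config.update("jax_compilation_cache_max_size", 256 * 1024 ** 2)

# A gs:// or other non-local URI would instead fall back to GFileCache.
```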
+ # TODO(ayx): Add cloud storage support for `LRUCache`, so that all these code + # can be removed. + if not is_local_filesystem(path): + return GFileCache(path), path + + max_size = config.compilation_cache_max_size.value + return LRUCache(path, max_size=max_size), path def set_cache_dir(path) -> None: @@ -134,10 +149,13 @@ def _initialize_cache() -> None: logger.debug("Initialized persistent compilation cache at %s", path) -def _get_cache() -> CacheInterface | None: +def _get_cache(backend) -> CacheInterface | None: # TODO(b/289098047): consider making this an API and changing the callers of # get_executable_and_time() and put_executable_and_time() to call get_cache() # and passing the result to them. + if backend.runtime_type in _UNSUPPORTED_RUNTIMES: + logger.debug("_get_cache: Unsupported runtime: %s", backend.runtime_type) + return None if _cache is None: _initialize_cache() # initialization is done at most once; see above return _cache @@ -157,13 +175,25 @@ def decompress_executable(executable): else: return zlib.decompress(executable) + +def is_executable_in_cache(backend, cache_key: str) -> bool: + """Checks if the executable is in the cache.""" + cache = _get_cache(backend) + if cache is None: + return False + + # TODO(patrios): add check cache key method to cache interface. + executable_and_time = cache.get(cache_key) + return executable_and_time is not None + + def get_executable_and_time( cache_key: str, compile_options, backend ) -> tuple[xla_client.LoadedExecutable | None, int | None]: """Returns the cached executable and its compilation time if present, or None otherwise. """ - cache = _get_cache() + cache = _get_cache(backend) if cache is None: logger.debug("get_executable_and_time: cache is disabled/not initialized") return None, None @@ -189,7 +219,7 @@ def put_executable_and_time( """Adds the 'executable' and its compilation time to the cache, possibly evicting older entries. """ - cache = _get_cache() + cache = _get_cache(backend) if cache is None: logger.debug("put_executable_and_time: cache is disabled/not initialized") return diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index c54957556819..438f1f9e5183 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -34,17 +34,16 @@ from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir -from jax._src.xla_bridge import process_count import numpy as np -_DISABLE_MOST_OPTIMIZATIONS = config.DEFINE_bool( +_DISABLE_MOST_OPTIMIZATIONS = config.bool_flag( 'jax_disable_most_optimizations', config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False), 'Try not to do much optimization work. This can be useful if the cost of ' 'optimization is greater than that of running a less-optimized program.') -_COMPILER_DETAILED_LOGGING_MIN_OPS = config.DEFINE_integer( +_COMPILER_DETAILED_LOGGING_MIN_OPS = config.int_flag( "jax_compiler_detailed_logging_min_ops", config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10), help=( @@ -218,9 +217,11 @@ def backend_compile( options: xc.CompileOptions, host_callbacks: Sequence[Any], ) -> xc.LoadedExecutable: - # Convert ir.Module to a string representation, unless the - # back-end expliclity flags the ability to handle a module directly - # (avoiding the overhead of back and forth conversions) + # Convert ir.Module to a string representation, unless the backend + # explicitly flags the ability to handle a module directly (avoiding the + # overhead of back and forth conversions). 
+ # TODO(slebedev): Change the backend.compile() to accept ir.Module. + built_c: Any if getattr(backend, "needs_str_ir", True): built_c = mlir.module_to_bytecode(module) else: @@ -242,6 +243,7 @@ def compile_or_get_cached( devices: np.ndarray, compile_options: xc.CompileOptions, host_callbacks: Sequence[Any], + pgle_profiler: profiler.PGLEProfiler | None = None, ) -> xc.LoadedExecutable: sym_name = computation.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value @@ -252,9 +254,7 @@ def compile_or_get_cached( # Persistent compilation cache only implemented on TPU and GPU and the backend # that supports serialization of executables. # TODO(skye): add warning when initializing cache on unsupported default platform - supported_platforms = ["tpu", "gpu"] - # TODO(b/323256224): Add back support for CPU together with extra fields in a - # cache key with underlying hardware features (xla_extension_version >= 230). + supported_platforms = ["tpu", "gpu", "cpu"] use_compilation_cache = ( config.enable_compilation_cache.value and getattr(backend, "supports_executable_serialization", True) @@ -279,6 +279,50 @@ def compile_or_get_cached( return backend_compile(backend, computation, compile_options, host_callbacks) + is_multi_process = ( + len({device.process_index for device in devices.flatten()}) > 1) + min_device_process_id = ( + min(devices.flatten(), key=lambda device: device.id).process_index) + + # When PGLE is enabled there might be 3 types of situations: + # 1. PGLE profiled module (the one which was recompiled with FDO profile) is + # in the persistent cache. In this case the module should be returned from + # cache and PGLE should be disabled for this module. Is module is stored in + # the persistent cache under the "pgle_profiled_module_key" which calculated + # with replacing FDO profile with flag which identify that module were PGLE + # profiled. + # 2. PGLE profiled module is not in the persistent cache and the module is + # getting built with an FDO profile. In this case we need to share FDO profile + # with other processes and store the result under the + # "pgle_profiled_module_key" so later in case 1 we will be able to find the + # module. + # 3. PGLE profiled module is not in the persistent cache and the module is + # getting compiled to be PGLEd (FDO profile is empty). In this case we need to + # simply return the non-PGLE profiled module from the persistent cache. + if (config.enable_pgle.value + and config.pgle_profiling_runs.value > 0): + fdo_profile = compile_options.executable_build_options.fdo_profile + compile_options.executable_build_options.fdo_profile = b"pgle profiled" + + pgle_profiled_module_key = compilation_cache.get_cache_key( + computation, devices, compile_options, backend) + compile_options.executable_build_options.fdo_profile = fdo_profile + + if _is_executable_in_cache(backend, pgle_profiled_module_key): + # Load PGLE profiled module from the persistent cache. + cache_key = pgle_profiled_module_key + if pgle_profiler is not None: + pgle_profiler.disable() + elif fdo_profile is not None and len(fdo_profile) > 0: + # Store module under PGLE profiled module cache key. 
+ cache_key = pgle_profiled_module_key + if is_multi_process and distributed.global_state.client is not None: + compile_options.executable_build_options.fdo_profile = _share_fdo_profiles( + computation, devices, compile_options, backend, + distributed.global_state.client, + min_device_process_id + ) + cache_retrieval_start = time.monotonic() retrieved_executable, retrieved_compile_time = _cache_read( module_name, cache_key, compile_options, backend) @@ -298,8 +342,8 @@ def compile_or_get_cached( return retrieved_executable elif ( - process_count() > 1 - and config.share_binary_between_hosts.value + config.share_binary_between_hosts.value + and is_multi_process and distributed.global_state.client is not None # Host callbacks are currently baked into the HLO module so we cant share # them. @@ -313,10 +357,11 @@ def compile_or_get_cached( distributed.global_state.client, module_name, cache_key, + min_device_process_id ) elif ( - process_count() > 1 - and config.share_autotune_config_between_hosts.value + config.share_autotune_config_between_hosts.value + and is_multi_process and distributed.global_state.client is not None ): return _compile_and_write_autotune_config( @@ -327,6 +372,7 @@ def compile_or_get_cached( distributed.global_state.client, module_name, cache_key, + min_device_process_id ) else: return _compile_and_write_cache( @@ -338,9 +384,61 @@ def compile_or_get_cached( cache_key, ) +# The process that has the lowest device ID should share FDO profile before +# compilation with other processes. +def _share_fdo_profiles( + computation: ir.Module, + devices: np.ndarray, + compile_options: xc.CompileOptions, + backend: xc.Client, + global_client: lib.xla_extension.DistributedRuntimeClient, + min_process_id +) -> bytes | None: + sym_name = computation.operation.attributes['sym_name'] + module_name = ir.StringAttr(sym_name).value + fdo_profile = compile_options.executable_build_options.fdo_profile + if fdo_profile is None or len(fdo_profile) == 0: + return fdo_profile + + compile_options.executable_build_options.fdo_profile = b"" + profile_key = ( + compilation_cache.get_cache_key( + computation, devices, compile_options, backend + ) + + "_fdo_sync" + ) + if profile_key in _share_fdo_profiles.modules_profiles: + return _share_fdo_profiles.modules_profiles[profile_key] -# The process with id 0 should compile the module and write an autotune config -# to the K-V storage. + share_timeout = config.share_binary_between_hosts_timeout_ms.value + if distributed.global_state.process_id == min_process_id: + logger.debug( + "Sharing FDO profile: %s. For module %s. Process %d.", + fdo_profile, + module_name, + min_process_id, + ) + global_client.key_value_set_bytes(profile_key, fdo_profile) + else: + logger.debug( + "Waiting for FDO profile: %s. For module %s. Should be set by process %d.", + fdo_profile, + module_name, + min_process_id, + ) + fdo_profile = global_client.blocking_key_value_get_bytes( + profile_key, share_timeout + ) + + _share_fdo_profiles.modules_profiles[profile_key] = fdo_profile + return fdo_profile + + +_share_fdo_profiles.modules_profiles = {} + + +# The process with the first_process_id should compile the module and write an +# autotune config to the K-V storage. 
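(The code paths above are driven by plain config states. The option names below are inferred from the `config.*` accessors in this file (`enable_pgle`, `pgle_profiling_runs`, `share_binary_between_hosts`, `share_autotune_config_between_hosts`) plus JAX's usual `jax_` prefix, so treat them as assumptions rather than documented API. A sketch of how a multi-host job might opt in:)

```python
import jax

# Re-compile with a profile-guided (FDO) pass after a few profiling runs.
jax.config.update("jax_enable_pgle", True)
jax.config.update("jax_pgle_profiling_runs", 3)

# Let the process owning the lowest device id compile/autotune once and share
# the result with the other processes through the distributed KV store.
jax.config.update("jax_share_binary_between_hosts", True)
jax.config.update("jax_share_autotune_config_between_hosts", True)
```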
def _compile_and_write_autotune_config( backend: xc.Client, computation: ir.Module, @@ -349,14 +447,24 @@ def _compile_and_write_autotune_config( global_client: lib.xla_extension.DistributedRuntimeClient, module_name: str, cache_key: str, + first_process_id: int ) -> xc.LoadedExecutable: share_timeout = config.share_binary_between_hosts_timeout_ms.value debug_options = compile_options.executable_build_options.debug_options + + if _compile_and_write_autotune_config.autotune_configs_dir is None: + _compile_and_write_autotune_config.autotune_configs_dir = tempfile.mkdtemp() + autotune_tmp_file = os.path.join( _compile_and_write_autotune_config.autotune_configs_dir, cache_key ) if os.path.exists(autotune_tmp_file): + logger.debug( + "Compiling module: %s. Use existing autotune config file: %s", + module_name, + autotune_tmp_file, + ) debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file return _compile_and_write_cache( backend, @@ -367,8 +475,10 @@ def _compile_and_write_autotune_config( cache_key, ) - if distributed.global_state.process_id == 0: + if distributed.global_state.process_id == first_process_id: debug_options.xla_gpu_dump_autotune_results_to = autotune_tmp_file + logger.debug("Process %d compiling and dumping autotune for module: %s", + first_process_id, module_name) executable = _compile_and_write_cache( backend, computation, @@ -377,20 +487,49 @@ def _compile_and_write_autotune_config( module_name, cache_key, ) + + logger.debug( + "Writing autotune config for module %s to %s", + module_name, + autotune_tmp_file, + ) with open(autotune_tmp_file, "rb") as f: autotune_config = f.read() autotune_config = compilation_cache.compress_executable(autotune_config) global_client.key_value_set_bytes(cache_key, autotune_config) + logger.debug( + "Autotune config for module %s with size %d shared by cache_key %s", + module_name, + len(autotune_config), + cache_key, + ) else: + logger.debug( + "Compiling module %s, waiting for config to be shared by cache_key %s" + "from process %d", + module_name, + cache_key, + first_process_id + ) autotune_config = global_client.blocking_key_value_get_bytes( cache_key, share_timeout ) + logger.debug( + "Received autotune config for module %s of size %d", + module_name, + len(autotune_config), + ) autotune_config = compilation_cache.decompress_executable(autotune_config) with open(autotune_tmp_file, "wb") as f: f.write(autotune_config) + logger.debug( + "Compiling module %s, using autotune config from %s", + module_name, + autotune_tmp_file, + ) debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file executable = _compile_and_write_cache( backend, @@ -402,12 +541,10 @@ def _compile_and_write_autotune_config( ) return executable -_compile_and_write_autotune_config.autotune_configs_dir = tempfile.mkdtemp() +_compile_and_write_autotune_config.autotune_configs_dir = None -# The process with id 0 should compile the module and write it to the K-V -# storage. -# TODO: In case when the process with id 0 is not participating in computation -# we need to choose another process to compile the module. +# The process with the first_process_id should compile the module and write it +# to the K-V storage. 
def _compile_and_share_module( backend: xc.Client, computation: ir.Module, @@ -416,6 +553,7 @@ def _compile_and_share_module( global_client: lib.xla_extension.DistributedRuntimeClient, module_name: str, cache_key: str, + first_process_id: int ) -> xc.LoadedExecutable: share_timeout = config.share_binary_between_hosts_timeout_ms.value @@ -424,7 +562,9 @@ def _compile_and_share_module( if cache_key in _compile_and_share_module.modules_cache: return _compile_and_share_module.modules_cache[cache_key] - if distributed.global_state.process_id == 0: + if distributed.global_state.process_id == first_process_id: + logger.debug("Process %d compiling and sharing module: %s", + first_process_id, module_name) executable = _compile_and_write_cache( backend, computation, @@ -439,6 +579,8 @@ def _compile_and_share_module( ) global_client.key_value_set_bytes(cache_key, serialized_executable) else: + logger.debug("Waiting for module: %s from process %d", module_name, + first_process_id) serialized_executable = global_client.blocking_key_value_get_bytes( cache_key, share_timeout ) @@ -472,6 +614,20 @@ def _compile_and_write_cache( ) return executable +def _is_executable_in_cache(backend, cache_key) -> bool: + """Checks if executable is presented in cache on a given key + """ + try: + return compilation_cache.is_executable_in_cache(backend, cache_key) + except Exception as ex: + if config.raise_persistent_cache_errors.value: + raise + warnings.warn( + f"Error reading persistent compilation cache entry for " + f"'{cache_key}': {type(ex).__name__}: {ex}") + return False + + def _cache_read( module_name: str, cache_key: str, compile_options: xc.CompileOptions, backend: xc.Client diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py new file mode 100644 index 000000000000..25b2be78d287 --- /dev/null +++ b/jax/_src/compute_on.py @@ -0,0 +1,55 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +import threading +from contextlib import contextmanager + + +class ComputeOnContext(threading.local): + + def __init__(self): + self.stack = [] + +compute_on_context = ComputeOnContext() + + +@contextmanager +def extend_compute_type(c_type: str): + compute_on_context.stack.append(c_type) + try: + if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1: + raise NotImplementedError( + 'Nesting `compute_on` with different compute types is not supported' + f' yet. Current stack: {compute_on_context.stack}') + yield compute_on_context.stack[-1] + finally: + compute_on_context.stack.pop() + +def current_compute_type() -> str | None: + return compute_on_context.stack[-1] if compute_on_context.stack else None + +def _check_valid(c_type: str): + if c_type not in {'device_host', 'device'}: + raise ValueError('Invalid compute type received. Current supported values ' + f'are `device_host` and `device`. 
Got {c_type}') + +@contextmanager +def compute_on(compute_type: str): + if not isinstance(compute_type, str): + raise TypeError("`compute_on`'s compute_type argument must be a string.") + _check_valid(compute_type) + + with extend_compute_type(compute_type): + yield diff --git a/jax/_src/config.py b/jax/_src/config.py index c4d5117ece13..829ed8185807 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -14,7 +14,7 @@ from __future__ import annotations -from collections.abc import Hashable, Iterator +from collections.abc import Callable, Hashable, Iterator, Sequence import contextlib import functools import itertools @@ -22,7 +22,7 @@ import os import sys import threading -from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar, cast +from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast from jax._src import lib from jax._src.lib import jax_jit @@ -60,24 +60,24 @@ def int_env(varname: str, default: int) -> int: return int(os.getenv(varname, str(default))) -UPGRADE_BOOL_HELP = ( - " This will be enabled by default in future versions of JAX, at which " - "point all uses of the flag will be considered deprecated (following " - "the `API compatibility policy " - "`_).") +class ValueHolder(Protocol[_T]): + """A holder for a configuration value. -UPGRADE_BOOL_EXTRA_DESC = " (transient)" + There are two kinds of value holders: ``Flag``, which is assigned exactly + once and never modified after; and ``State``, which can be changed locally + within a thread via a context manager. + """ + + value: _T + + def _set(self, value: _T) -> None: ... class Config: _HAS_DYNAMIC_ATTRIBUTES = True def __init__(self): - # There are two kinds of value holders: FlagHolders, which hold global - # flags, and StateContextManagers, which hold state that can be changed - # locally within a thread. A value holder needs a `.value` property and a - # `._set()` method. - self._value_holders = {} + self._value_holders: dict[str, ValueHolder] = {} self.meta = {} self.use_absl = False self._contextmanager_flags = set() @@ -113,11 +113,13 @@ def add_option(self, name, holder, opt_type, meta_args, meta_kwargs): def config_with_absl(self): """Registers absl flags for the JAX configs. - E.g., for each JAX config defined using define_bool_state(), this method + E.g., for each JAX config defined using bool_state(), this method registers an absl boolean flag, with the same name. This is the recommended method to call if you use `app.run(main)` and you - need JAX flags. Example: + need JAX flags. + + Examples: ```python from absl import app @@ -210,13 +212,16 @@ def trace_context(): dynamic_shapes.value, numpy_dtype_promotion.value, default_device.value, random_seed_offset.value, threefry_partitionable.value, + threefry_gpu_kernel_lowering.value, softmax_custom_jvp.value, - new_select_transpose.value, enable_memories.value, disable_jit.value, + debug_key_reuse.value, jax_xla_profile_version.value, # Technically this affects jaxpr->stablehlo lowering, not tracing. 
- hlo_source_file_canonicalization_regex.value) + hlo_source_file_canonicalization_regex.value, + pgle_profiling_runs.value, + enable_pgle.value) config = Config() @@ -234,7 +239,8 @@ class _Unset: pass _thread_local_state = threading.local() -class _StateContextManager(Generic[_T]): +class State(Generic[_T]): + __slots__ = ( '_name', '_value', '_update_thread_local_hook', '_update_global_hook', '_validator', '_default_context_manager_value', '__doc__', '__name__', @@ -268,6 +274,8 @@ def __bool__(self) -> NoReturn: type(self).__name__)) def _set(self, value: _T) -> None: + if self._validator: + self._validator(value) self._value = value if self._update_global_hook: self._update_global_hook(value) @@ -315,7 +323,16 @@ def _add_hooks(self, update_global_hook, update_thread_local_hook): update_global_hook(self._value) -def define_bool_state( +UPGRADE_BOOL_HELP = ( + " This will be enabled by default in future versions of JAX, at which " + "point all uses of the flag will be considered deprecated (following " + "the `API compatibility policy " + "`_).") + +UPGRADE_BOOL_EXTRA_DESC = " (transient)" + + +def bool_state( name: str, default: bool, help: str, @@ -324,7 +341,7 @@ def define_bool_state( update_thread_local_hook: Callable[[bool | None], None] | None = None, upgrade: bool = False, extra_description: str = '', -) -> _StateContextManager[bool]: +) -> State[bool]: """Set up thread-local state and return a contextmanager for managing it. This function is a convenience wrapper. It defines a flag, environment @@ -355,9 +372,9 @@ def define_bool_state( Returns: A contextmanager to control the thread-local state value. - Example: + Examples: - enable_foo = config.define_bool_state( + ENABLE_FOO = config.bool_state( name='jax_enable_foo', default=False, help='Enable foo.') @@ -376,7 +393,8 @@ def define_bool_state( an error. """ if not isinstance(default, bool): - raise TypeError(f"Default value must be of type bool, got {default}") + raise TypeError(f"Default value must be of type bool, got {default} " + f"of type {getattr(type(default), '__name__', type(default))}") default = bool_env(name.upper(), default) name = name.lower() if upgrade: @@ -384,7 +402,7 @@ def define_bool_state( extra_description += UPGRADE_BOOL_EXTRA_DESC config._contextmanager_flags.add(name) - s = _StateContextManager[bool]( + s = State[bool]( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, extra_description=extra_description, default_context_manager_value=True) @@ -393,18 +411,18 @@ def define_bool_state( return s -def define_enum_state( +def enum_state( name: str, - enum_values: list[str], + enum_values: Sequence[str], default: str, help: str, *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, -) -> _StateContextManager[str]: +) -> State[str]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -420,7 +438,8 @@ def define_enum_state( A contextmanager to control the thread-local state value. 
""" if not isinstance(default, str): - raise TypeError(f"Default value must be of type str, got {default}") + raise TypeError(f"Default value must be of type str, got {default} " + f"of type {getattr(type(default), '__name__', type(default))}") name = name.lower() default = os.getenv(name.upper(), default) if default not in enum_values: @@ -432,7 +451,7 @@ def validator(new_val): raise ValueError(f"new enum value must be in {enum_values}, " f"got {new_val} of type {type(new_val)}.") - s = _StateContextManager[str]( + s = State[str]( name, default, help, @@ -449,18 +468,18 @@ def validator(new_val): return s -def define_optional_enum_state( +def optional_enum_state( name: str, - enum_values: list[str], + enum_values: Sequence[str], default: str | None, help: str, *, update_global_hook: Callable[[str | None], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, -) -> _StateContextManager[str | None]: +) -> State[str | None]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -476,7 +495,8 @@ def define_optional_enum_state( A contextmanager to control the thread-local state value. """ if default is not None and not isinstance(default, str): - raise TypeError(f"Default value must be of type str or None, got {default}") + raise TypeError(f"Default value must be of type str or None, got {default} " + f"of type {getattr(type(default), '__name__', type(default))}") name = name.lower() default = os.getenv(name.upper(), default) if default is not None and default not in enum_values: @@ -489,7 +509,7 @@ def validate(new_val): raise ValueError(f"new enum value must be None or in {enum_values}, " f"got {new_val} of type {type(new_val)}.") - s = _StateContextManager['str | None']( + s = State['str | None']( name, default, help, update_global_hook, update_thread_local_hook, validate ) @@ -502,17 +522,17 @@ def validate(new_val): return s -def define_int_state( +def int_state( name: str, default: int, help: str, *, update_global_hook: Callable[[int], None] | None = None, update_thread_local_hook: Callable[[int | None], None] | None = None, -) -> _StateContextManager[int]: +) -> State[int]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -526,7 +546,8 @@ def define_int_state( A contextmanager to control the thread-local state value. 
""" if not isinstance(default, int): - raise TypeError(f"Default value must be of type int, got {default}") + raise TypeError(f"Default value must be of type int, got {default} " + f"of type {getattr(type(default), '__name__', type(default))}") name = name.lower() default_env = os.getenv(name.upper()) if default_env is not None: @@ -541,24 +562,24 @@ def validate(new_val): raise ValueError(f'new int config value must be None or of type int, ' f'got {new_val} of type {type(new_val)}') - s = _StateContextManager[int](name, default, help, update_global_hook, - update_thread_local_hook, validate) + s = State[int](name, default, help, update_global_hook, + update_thread_local_hook, validate) config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s -def define_float_state( +def float_state( name: str, default: float, help: str, *, update_global_hook: Callable[[float], None] | None = None, update_thread_local_hook: Callable[[float | None], None] | None = None, -) -> _StateContextManager[float]: +) -> State[float]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -572,7 +593,8 @@ def define_float_state( A contextmanager to control the thread-local state value. """ if not isinstance(default, float): - raise TypeError(f"Default value must be of type float, got {default}") + raise TypeError(f"Default value must be of type float, got {default} " + f"of type {getattr(type(default), '__name__', type(default))}") name = name.lower() default_env = os.getenv(name.upper()) if default_env is not None: @@ -588,24 +610,24 @@ def validate(new_val): f'new float config value must be None or of type float, ' f'got {new_val} of type {type(new_val)}') - s = _StateContextManager[float](name, default, help, update_global_hook, - update_thread_local_hook, validate) + s = State[float](name, default, help, update_global_hook, + update_thread_local_hook, validate) config.add_option(name, s, float, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s -def define_string_state( +def string_state( name: str, default: str, help: str, *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, -) -> _StateContextManager[str]: +) -> State[str]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -624,31 +646,32 @@ def define_string_state( A contextmanager to control the thread-local state value. 
""" if not isinstance(default, str): - raise TypeError(f"Default value must be of type str, got {default}") + raise TypeError(f"Default value must be of type str, got {default} " + f"of type {getattr(type(default), '__name__', type(default))}") def validator(new_val): if not isinstance(new_val, str): - raise ValueError('new string config value must be of type str,' + raise TypeError('new string config value must be of type str,' f' got {new_val} of type {type(new_val)}.') - return define_string_or_object_state( + return string_or_object_state( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, validator=validator) -def define_optional_string_state( +def optional_string_state( name: str, default: str | None, help: str, *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, -) -> _StateContextManager[str | None]: +) -> State[str | None]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -667,20 +690,21 @@ def define_optional_string_state( A contextmanager to control the thread-local state value. """ if default is not None and not isinstance(default, str): - raise TypeError(f"Default value must be of type str or None, got {default}") + raise TypeError(f"Default value must be of type str or None, got {default} " + f"of type {getattr(type(default), '__name__', type(default))}") def validator(new_val): if new_val is not None and not isinstance(new_val, str): raise ValueError('new string config value must be None or of type str,' f' got {new_val} of type {type(new_val)}.') - return define_string_or_object_state( + return string_or_object_state( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, validator=validator) -def define_string_or_object_state( +def string_or_object_state( name: str, default: Any, help: str, @@ -688,11 +712,11 @@ def define_string_or_object_state( update_global_hook: Callable[[Any], None] | None = None, update_thread_local_hook: Callable[[Any], None] | None = None, validator: Callable[[Any], None] | None = None, -) -> _StateContextManager[Any]: +) -> State[Any]: """Set up thread-local state and return a contextmanager for managing it. - Similar to ``define_string_state``, except the context manager will accept - any object, not just a string. Any value passed via commandline flag or + Similar to ``string_state``, except the context manager will accept + any object, not just a string. Any value passed via command line flag or environment variable will be treated as a string. 
Args: @@ -718,7 +742,7 @@ def define_string_or_object_state( default = os.getenv(name.upper(), default) config._contextmanager_flags.add(name) - s = _StateContextManager[Any]( + s = State[Any]( name, default, help, update_global_hook, update_thread_local_hook, validator) setattr(Config, name, property(lambda _: s.value)) @@ -726,7 +750,8 @@ def define_string_or_object_state( return s -class FlagHolder(Generic[_T]): +class Flag(Generic[_T]): + __slots__ = ("_name", "value", "_update_hook") _name: str @@ -751,42 +776,37 @@ def _set(self, value: _T) -> None: self._update_hook(value) -def check_exists(name): - if name not in config._value_holders: - raise AttributeError(f"Unrecognized config option: {name}") - - -def DEFINE_bool(name, default, *args, **kwargs) -> FlagHolder[bool]: +def bool_flag(name, default, *args, **kwargs) -> Flag[bool]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, bool, args, kwargs) return holder -def DEFINE_integer(name, default, *args, **kwargs) -> FlagHolder[int]: +def int_flag(name, default, *args, **kwargs) -> Flag[int]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, int, args, kwargs) return holder -def DEFINE_float(name, default, *args, **kwargs) -> FlagHolder[float]: +def float_flag(name, default, *args, **kwargs) -> Flag[float]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, float, args, kwargs) return holder -def DEFINE_string(name, default, *args, **kwargs) -> FlagHolder[str]: +def string_flag(name, default, *args, **kwargs) -> Flag[str]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, str, args, kwargs) return holder -def DEFINE_enum(name, default, *args, **kwargs) -> FlagHolder[str]: +def enum_flag(name, default, *args, **kwargs) -> Flag[str]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, 'enum', args, kwargs) return holder @@ -804,9 +824,11 @@ class _GlobalExtraJitContext(NamedTuple): dynamic_shapes: bool = False random_seed_offset: int = 0 threefry_partitionable: bool = False + threefry_gpu_kernel_lowering: bool = False softmax_custom_jvp: bool = False - new_select_transpose: bool = False xla_profile_version: int = 0 + pgle_profiling_runs: int = 0 + enable_pgle: bool = False def _update_global_jit_state(**kw): @@ -839,9 +861,11 @@ class _ThreadLocalExtraJitContext(NamedTuple): dynamic_shapes: bool | None = None random_seed_offset: int | None = None threefry_partitionable: bool | None = None + threefry_gpu_kernel_lowering: bool | None = None softmax_custom_jvp: bool | None = None - new_select_transpose: bool | None = None xla_profile_version: int | None = None + pgle_profiling_runs: int | None = None + enable_pgle: bool | None = None class _ThreadLocalStateCache(threading.local): @@ -849,7 +873,7 @@ class _ThreadLocalStateCache(threading.local): The extra_jit_context in jax_jit.thread_local_state() may get updated and thus incurring dispatch overhead for comparing this python object during jit calls. 
- We want to duduplicate the objects that have the same hash/equality to also + We want to deduplicate the objects that have the same hash/equality to also have the same object ID, since the equality check is much faster if the object IDs match. """ @@ -871,7 +895,7 @@ def update_thread_local_jit_state(**kw): # TODO(b/214340779): remove flag when XLA:CPU is improved. -jax2tf_associative_scan_reductions = define_bool_state( +jax2tf_associative_scan_reductions = bool_state( name='jax2tf_associative_scan_reductions', default=False, help=( @@ -886,7 +910,7 @@ def update_thread_local_jit_state(**kw): ) ) -jax2tf_default_native_serialization = define_bool_state( +jax2tf_default_native_serialization = bool_state( name='jax2tf_default_native_serialization', default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', True), help=( @@ -896,22 +920,39 @@ def update_thread_local_jit_state(**kw): ) ) -jax_serialization_version = define_int_state( +jax_serialization_version = int_state( name='jax_serialization_version', - # Note: bump the default serialization version at least one month after + default=int_env('JAX_SERIALIZATION_VERSION', 0), # We use 0 to detect default. + help=( + 'DEPRECATED: use jax_export_calling_convention_version.' + ) +) + +jax_export_calling_convention_version = int_state( + name='jax_export_calling_convention_version', + # Note: bump the default calling convention version at least one month after # we update XlaCallModule to support the new version, so that serialized # modules are forward compatible with deployed versions of XlaCallModule. # Version 9 of XlaCallModule is supported since October 27th, 2023. - default=int_env('JAX_SERIALIZATION_VERSION', 9), + default=int_env('JAX_EXPORT_CALLING_CONVENTION_VERSION', 9), help=( - 'The version number to use for native serialization. This must be ' + 'The calling convention version number to use for exporting. This must be ' 'within the range of versions supported by the tf.XlaCallModule ' 'used in your deployment environment. ' - 'See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.' + 'See https://jax.readthedocs.io/en/latest/export/shape_poly.html#calling-convention-versions.' + ) +) + +export_ignore_forward_compatibility = bool_state( + name='jax_export_ignore_forward_compatibility', + default=bool_env('JAX_EXPORT_IGNORE_FORWARD_COMPATIBILIY', False), + help=( + 'Whether to ignore the forward compatibility lowering rules. ' + 'See https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls.' ) ) -jax_platforms = define_optional_string_state( +jax_platforms = optional_string_state( name='jax_platforms', default=None, help=( @@ -927,13 +968,19 @@ def update_thread_local_jit_state(**kw): 'otherwise.' )) -enable_checks = define_bool_state( +jax_pjrt_client_create_options = optional_string_state( + name='jax_pjrt_client_create_options', + default=None, + help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings ' + 'provided to a device platform pjrt client as extra arguments.')) + +enable_checks = bool_state( name='jax_enable_checks', default=False, help='Turn on invariant checking for JAX internals. Makes things slower.') -enable_key_reuse_checks = define_bool_state( - name='jax_enable_key_reuse_checks', +debug_key_reuse = bool_state( + name='jax_debug_key_reuse', default=False, help=('Turn on experimental key reuse checking. With this configuration enabled,' ' typed PRNG keys (i.e. 
keys created with jax.random.key()) will have their' @@ -941,7 +988,7 @@ def update_thread_local_jit_state(**kw): ' an error. Currently enabling this leads to a small Python overhead on' ' every call to a JIT-compiled function with keys as inputs or outputs.')) -check_tracer_leaks = define_bool_state( +check_tracer_leaks = bool_state( name='jax_check_tracer_leaks', default=False, help=('Turn on checking for leaked tracers as soon as a trace completes. ' @@ -951,7 +998,7 @@ def update_thread_local_jit_state(**kw): 'to disable any debuggers while leak checking is enabled.')) checking_leaks = functools.partial(check_tracer_leaks, True) -debug_nans = define_bool_state( +debug_nans = bool_state( name='jax_debug_nans', default=False, help=('Add nan checks to every operation. When a nan is detected on the ' @@ -959,7 +1006,7 @@ def update_thread_local_jit_state(**kw): 'version in an attempt to more precisely identify the operation ' 'which produced the nan.')) -debug_infs = define_bool_state( +debug_infs = bool_state( name='jax_debug_infs', default=False, help=('Add inf checks to every operation. When an inf is detected on the ' @@ -967,15 +1014,15 @@ def update_thread_local_jit_state(**kw): 'version in an attempt to more precisely identify the operation ' 'which produced the inf.')) -log_compiles = define_bool_state( +log_compiles = bool_state( name='jax_log_compiles', default=False, - help=('Log a message each time every time `jit` or `pmap` compiles an XLA ' + help=('Log a message each time `jit` or `pmap` compiles an XLA ' 'computation. Logging is performed with `logging`. When this ' 'option is set, the log level is WARNING; otherwise the level is ' 'DEBUG.')) -explain_cache_misses = define_bool_state( +explain_cache_misses = bool_state( name='jax_explain_cache_misses', default=False, help=('Each time there is a miss on one of the main caches (e.g. the ' @@ -983,19 +1030,14 @@ def update_thread_local_jit_state(**kw): '`logging`. When this option is set, the log level is WARNING; ' 'otherwise the level is DEBUG.')) -log_checkpoint_residuals = define_bool_state( +log_checkpoint_residuals = bool_state( name='jax_log_checkpoint_residuals', default=False, help=('Log a message every time jax.checkpoint (aka jax.remat) is ' 'partially evaluated (e.g. for autodiff), printing what residuals ' 'are saved.')) -parallel_functions_output_gda = define_bool_state( - name='jax_parallel_functions_output_gda', - default=False, - help='If True, pjit will output GDAs.') - -pmap_shmap_merge = define_bool_state( +pmap_shmap_merge = bool_state( name='jax_pmap_shmap_merge', default=False, upgrade=True, @@ -1007,7 +1049,7 @@ def _update_jax_memories_global(val): def _update_jax_memories_thread_local(val): lib.jax_jit.thread_local_state().enable_memories = val -enable_memories = define_bool_state( +enable_memories = bool_state( 'jax_enable_memories', default=False, upgrade=True, @@ -1016,7 +1058,7 @@ def _update_jax_memories_thread_local(val): help=("If True, will allow fetching memory kinds available on executable " "and annotate Shardings with it.")) -spmd_mode = define_enum_state( +spmd_mode = enum_state( name='jax_spmd_mode', enum_values=['allow_all', 'allow_jit'], default='allow_jit', @@ -1029,14 +1071,14 @@ def _update_jax_memories_thread_local(val): " execute on non-fully addressable `jax.Array`s.")) -distributed_debug = define_bool_state( +distributed_debug = bool_state( name='jax_distributed_debug', default=False, help=('Enable logging useful for debugging multi-process distributed ' 'computations. 
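The debugging states defined here are reachable through the ordinary config mechanism. A small usage sketch, relying only on the documented semantics (`jax_log_compiles` logs each `jit`/`pmap` compilation; `jax_debug_nans` re-runs the offending computation op-by-op and raises when a NaN appears)::

    import jax
    import jax.numpy as jnp

    # Log a message for every XLA compilation triggered by jit or pmap.
    jax.config.update("jax_log_compiles", True)
    # When a NaN appears, re-run the computation op-by-op to pinpoint the op.
    jax.config.update("jax_debug_nans", True)

    @jax.jit
    def f(x):
        return jnp.log(x)        # NaN for negative inputs

    f(2.0)                       # fine
    try:
        f(-1.0)                  # expected to raise FloatingPointError
    except FloatingPointError as err:
        print("caught:", err)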
Logging is performed with `logging` at WARNING ' 'level.')) -random_seed_offset = define_int_state( +random_seed_offset = int_state( name='jax_random_seed_offset', default=0, help=('Offset to all random seeds (e.g. argument to jax.random.key()).'), @@ -1046,7 +1088,7 @@ def _update_jax_memories_thread_local(val): random_seed_offset=val) ) -legacy_prng_key = define_enum_state( +legacy_prng_key = enum_state( name='jax_legacy_prng_key', enum_values=['allow', 'warn', 'error'], default='allow', @@ -1054,21 +1096,21 @@ def _update_jax_memories_thread_local(val): 'jax.random APIs.') ) -enable_custom_prng = define_bool_state( +enable_custom_prng = bool_state( name='jax_enable_custom_prng', default=False, upgrade=True, help=('Enables an internal upgrade that allows one to define custom ' 'pseudo-random number generator implementations.')) -default_prng_impl = define_enum_state( +default_prng_impl = enum_state( name='jax_default_prng_impl', enum_values=['threefry2x32', 'rbg', 'unsafe_rbg'], default='threefry2x32', help=('Select the default PRNG implementation, used when one is not ' 'explicitly provided at seeding time.')) -threefry_partitionable = define_bool_state( +threefry_partitionable = bool_state( name='jax_threefry_partitionable', default=False, upgrade=True, @@ -1083,8 +1125,19 @@ def _update_jax_memories_thread_local(val): update_thread_local_hook=lambda val: update_thread_local_jit_state( threefry_partitionable=val)) -# TODO(mattjj): set default True then remove this flag (or die trying) -softmax_custom_jvp = define_bool_state( +threefry_gpu_kernel_lowering = bool_state( + name='jax_threefry_gpu_kernel_lowering', + default=False, + help=('On GPU, lower threefry PRNG operations to a kernel implementation. ' + 'This makes compile times faster at a potential runtime memory ' + 'cost.'), + update_global_hook=lambda val: _update_global_jit_state( + threefry_gpu_kernel_lowering=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + threefry_gpu_kernel_lowering=val)) + + +softmax_custom_jvp = bool_state( name='jax_softmax_custom_jvp', default=False, upgrade=True, @@ -1097,26 +1150,14 @@ def _update_jax_memories_thread_local(val): softmax_custom_jvp=val)) -# TODO(mattjj): remove this flag -new_select_transpose = define_bool_state( - name='new_select_transpose', - default=True, - upgrade=True, - help=('Change select_n_p transpose rule to specialize on bools'), - update_global_hook=lambda val: _update_global_jit_state( - new_select_transpose=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - new_select_transpose=val)) - - -enable_custom_vjp_by_custom_transpose = define_bool_state( +enable_custom_vjp_by_custom_transpose = bool_state( name='jax_enable_custom_vjp_by_custom_transpose', default=False, upgrade=True, help=('Enables an internal upgrade that implements `jax.custom_vjp` by ' 'reduction to `jax.custom_jvp` and `jax.custom_transpose`.')) -raise_persistent_cache_errors = define_bool_state( +raise_persistent_cache_errors = bool_state( name='jax_raise_persistent_cache_errors', default=False, help=('If true, exceptions raised when reading or writing to the ' @@ -1126,14 +1167,14 @@ def _update_jax_memories_thread_local(val): 'continue. 
Defaults to false so cache bugs or intermittent issues ' 'are non-fatal.')) -persistent_cache_min_compile_time_secs = define_float_state( +persistent_cache_min_compile_time_secs = float_state( name='jax_persistent_cache_min_compile_time_secs', default=1., help=('The minimum compile time of a computation to be written to the ' 'persistent compilation cache. This threshold can be raised to ' 'decrease the number of entries written to the cache.')) -persistent_cache_min_entry_size_bytes = define_int_state( +persistent_cache_min_entry_size_bytes = int_state( name='jax_persistent_cache_min_entry_size_bytes', default=0, help=('The minimum size (in bytes) of an entry that will be cached in the ' @@ -1144,7 +1185,7 @@ def _update_jax_memories_thread_local(val): ' filesystem being used for the cache. ' '* > 0: the actual minimum size desired; no overrides.')) -compilation_cache_include_metadata_in_key = define_bool_state( +compilation_cache_include_metadata_in_key = bool_state( name='jax_compilation_cache_include_metadata_in_key', default=False, help=( @@ -1156,7 +1197,7 @@ def _update_jax_memories_thread_local(val): ), ) -hlo_source_file_canonicalization_regex = define_optional_string_state( +hlo_source_file_canonicalization_regex = optional_string_state( name='jax_hlo_source_file_canonicalization_regex', default=None, help=('Used to canonicalize the source_path metadata of HLO instructions ' @@ -1166,7 +1207,7 @@ def _update_jax_memories_thread_local(val): 'persistent compilation cache, which includes HLO metadata in the ' 'cache key.')) -include_full_tracebacks_in_locations = define_bool_state( +include_full_tracebacks_in_locations = bool_state( name='jax_include_full_tracebacks_in_locations', default=True, help=( @@ -1174,7 +1215,7 @@ def _update_jax_memories_thread_local(val): ), ) -traceback_in_locations_limit = define_int_state( +traceback_in_locations_limit = int_state( name='jax_traceback_in_locations_limit', default=10, help=( @@ -1184,7 +1225,7 @@ def _update_jax_memories_thread_local(val): ), ) -share_autotune_config_between_hosts = define_bool_state( +share_autotune_config_between_hosts = bool_state( name='jax_share_autotune_config_between_hosts', default=False, help=( @@ -1198,7 +1239,7 @@ def _update_jax_memories_thread_local(val): ), ) -share_binary_between_hosts = define_bool_state( +share_binary_between_hosts = bool_state( name='jax_share_binary_between_hosts', default=False, help=( @@ -1207,13 +1248,49 @@ def _update_jax_memories_thread_local(val): ), ) -share_binary_between_hosts_timeout_ms = define_int_state( +share_binary_between_hosts_timeout_ms = int_state( name='jax_share_binary_between_hosts_timeout_ms', default=20 * 60 * 1000, help='Timeout for the compiled module share.', ) -enable_compilation_cache = define_bool_state( +enable_pgle = bool_state( + name='jax_enable_pgle', + default=False, + help=( + 'If set to True and the property jax_pgle_profiling_runs is set to ' + 'greater than 0, the modules will be recompiled after running specified ' + 'number times with collected data provided to the profile guided latency ' + 'estimator.' + ), + update_global_hook=lambda val: _update_global_jit_state(enable_pgle=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + enable_pgle=val), +) + +pgle_profiling_runs = int_state( + name='jax_pgle_profiling_runs', + default=3, + help=( + 'Amount of times module should be profiled before recompilation when ' + 'PGLE is used.' 
+ ), + update_global_hook=lambda val: _update_global_jit_state( + pgle_profiling_runs=val + ), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + pgle_profiling_runs=val + ), +) + +pgle_aggregation_percentile = int_state( + name='jax_pgle_aggregation_percentile', + default=90, + help='Percentile used to aggregate performance data between devices when ' + 'PGLE is used.', +) + +enable_compilation_cache = bool_state( name='jax_enable_compilation_cache', default=True, help=('If set to False, the compilation cache will be disabled regardless ' @@ -1222,7 +1299,7 @@ def _update_jax_memories_thread_local(val): 'set_cache_dir().'), ) -compilation_cache_dir = define_optional_string_state( +compilation_cache_dir = optional_string_state( name='jax_compilation_cache_dir', default=None, help=('Path for the cache. ' @@ -1231,7 +1308,19 @@ def _update_jax_memories_thread_local(val): '2. The value of this flag set in the command line or by default.'), ) -default_dtype_bits = define_enum_state( +compilation_cache_max_size = int_state( + name='jax_compilation_cache_max_size', + default=-1, + help=('The maximum size (in bytes) allowed for the persistent compilation ' + 'cache. When set, the least recently accessed cache entry(s) ' + 'will be deleted once the total cache directory size ' + 'exceeds the specified limit. ' + 'Caching will be disabled if this value is set to 0. A ' + 'special value of -1 indicates no limit, allowing the cache ' + 'size to grow indefinitely.'), +) + +default_dtype_bits = enum_state( name='jax_default_dtype_bits', enum_values=['32', '64'], default='64', @@ -1239,7 +1328,7 @@ def _update_jax_memories_thread_local(val): 'This is a temporary flag that will be used during the process ' 'of deprecating the ``jax_enable_x64`` flag.')) -numpy_dtype_promotion = define_enum_state( +numpy_dtype_promotion = enum_state( name='jax_numpy_dtype_promotion', enum_values=['standard', 'strict'], default='standard', @@ -1258,7 +1347,7 @@ def _update_x64_global(val): def _update_x64_thread_local(val): lib.jax_jit.thread_local_state().enable_x64 = val -enable_x64 = define_bool_state( +enable_x64 = bool_state( name='jax_enable_x64', default=False, help='Enable 64-bit types to be used', @@ -1293,7 +1382,7 @@ def _validate_default_device(val): # TODO(skye): default_device only accepts devices for now. Make it work with # platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]). 
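The new PGLE and compilation-cache options are regular config states, so they can be set with `jax.config.update` (or the matching `JAX_*` environment variables). A hedged sketch of combining them; the cache path and size here are placeholders, not recommendations::

    import jax

    # Persistent compilation cache, capped at roughly 1 GiB on disk.
    jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
    jax.config.update("jax_compilation_cache_max_size", 1024 ** 3)

    # Profile-guided latency estimation: profile the first few runs of each
    # module, then recompile with the collected latency data.
    jax.config.update("jax_enable_pgle", True)
    jax.config.update("jax_pgle_profiling_runs", 3)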
-default_device = define_string_or_object_state( +default_device = string_or_object_state( name='jax_default_device', default=None, help=( @@ -1313,7 +1402,7 @@ def _update_disable_jit_global(val): def _update_disable_jit_thread_local(val): lib.jax_jit.thread_local_state().disable_jit = val -disable_jit = define_bool_state( +disable_jit = bool_state( name='jax_disable_jit', default=False, help=('Disable JIT compilation and just call original Python.'), @@ -1321,7 +1410,7 @@ def _update_disable_jit_thread_local(val): update_thread_local_hook=_update_disable_jit_thread_local) -numpy_rank_promotion = define_enum_state( +numpy_rank_promotion = enum_state( name='jax_numpy_rank_promotion', enum_values=['allow', 'warn', 'raise'], default='allow', @@ -1332,9 +1421,9 @@ def _update_disable_jit_thread_local(val): update_thread_local_hook=lambda val: \ update_thread_local_jit_state(numpy_rank_promotion=val)) -default_matmul_precision = define_optional_enum_state( +default_matmul_precision = optional_enum_state( name='jax_default_matmul_precision', - enum_values=['bfloat16', 'tensorfloat32', 'float32'], + enum_values=['default', 'high', 'highest', 'bfloat16', 'tensorfloat32', 'float32'], default=None, help=('Control the default matmul and conv precision for 32bit inputs.\n\n' @@ -1357,7 +1446,7 @@ def _update_disable_jit_thread_local(val): update_thread_local_hook=lambda val: \ update_thread_local_jit_state(default_matmul_precision=val)) -traceback_filtering = define_enum_state( +traceback_filtering = enum_state( name = 'jax_traceback_filtering', enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames", "auto"], @@ -1378,14 +1467,14 @@ def _update_disable_jit_thread_local(val): # This flag is for internal use. # TODO(tianjianlu): Removes once we always enable cusparse lowering. # TODO(b/262050896): Set to true after bug is fixed -bcoo_cusparse_lowering = define_bool_state( +bcoo_cusparse_lowering = bool_state( name='jax_bcoo_cusparse_lowering', default=False, help=('Enables lowering BCOO ops to cuSparse.')) # TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging # if the intended backend can handle lowering the result -dynamic_shapes = define_bool_state( +dynamic_shapes = bool_state( name='jax_dynamic_shapes', default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')), help=('Enables experimental features for staging out computations with ' @@ -1397,19 +1486,26 @@ def _update_disable_jit_thread_local(val): # This flag is temporary during rollout of the remat barrier. # TODO(parkers): Remove if there are no complaints. 
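Enum-valued states such as `jax_default_matmul_precision` (whose accepted values now also include 'default', 'high' and 'highest') double as context managers on the public `jax` namespace. A usage sketch, assuming only the documented behaviour::

    import jax
    import jax.numpy as jnp

    x = jnp.ones((512, 512), dtype=jnp.float32)

    # Globally: warn whenever implicit rank promotion happens.
    jax.config.update("jax_numpy_rank_promotion", "warn")

    # Locally: force full float32 accumulation for matmuls in this block.
    with jax.default_matmul_precision("float32"):
        y = x @ x

    # 'highest' is now also an accepted spelling, mirroring lax.Precision.
    with jax.default_matmul_precision("highest"):
        z = x @ x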
-remat_opt_barrier = define_bool_state( +remat_opt_barrier = bool_state( name='jax_remat_opt_barrier', default=True, help=('Enables using optimization-barrier op for lowering remat.')) # TODO(sharadmv,mattjj): set default to True, then remove -eager_pmap = define_bool_state( +eager_pmap = bool_state( name='jax_eager_pmap', default=True, upgrade=True, help='Enable eager-mode pmap when jax_disable_jit is activated.') -xla_runtime_errors = define_bool_state( +# TODO(mattjj): remove once we land mutable array plumbing, or face great shame +custom_vjp_disable_shape_check = bool_state( + name='jax_custom_vjp_disable_shape_check', + default=False, + upgrade=True, + help='Disable the check from #19009 to enable some custom_vjp hacks.') + +xla_runtime_errors = bool_state( name='jax_experimental_unsafe_xla_runtime_errors', default=False, help=('Enable XLA runtime errors for jax.experimental.checkify.checks ' @@ -1419,7 +1515,7 @@ def _update_disable_jit_thread_local(val): 'work under pmap/pjit.') ) -jax_xla_profile_version = define_int_state( +jax_xla_profile_version = int_state( name='jax_xla_profile_version', default=0, help=( @@ -1471,7 +1567,7 @@ def _update_transfer_guard(state, key, val): else: assert False, f'Invalid transfer guard level {val}' -transfer_guard_host_to_device = define_optional_enum_state( +transfer_guard_host_to_device = optional_enum_state( name='jax_transfer_guard_host_to_device', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' @@ -1486,7 +1582,7 @@ def _update_transfer_guard(state, key, val): update_thread_local_hook=lambda val: _update_transfer_guard( transfer_guard_lib.thread_local_state(), 'host_to_device', val)) -transfer_guard_device_to_device = define_optional_enum_state( +transfer_guard_device_to_device = optional_enum_state( name='jax_transfer_guard_device_to_device', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' @@ -1501,7 +1597,7 @@ def _update_transfer_guard(state, key, val): update_thread_local_hook=lambda val: _update_transfer_guard( transfer_guard_lib.thread_local_state(), 'device_to_device', val)) -transfer_guard_device_to_host = define_optional_enum_state( +transfer_guard_device_to_host = optional_enum_state( name='jax_transfer_guard_device_to_host', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' @@ -1522,7 +1618,7 @@ def _update_all_transfer_guard_global(val): 'jax_transfer_guard_device_to_host'): config.update(name, val) -_transfer_guard = define_optional_enum_state( +_transfer_guard = optional_enum_state( name='jax_transfer_guard', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' @@ -1566,7 +1662,7 @@ def _update_debug_log_modules(module_names_str: str | None): logging_config.enable_debug_logging(module_name) # Don't define a context manager since this isn't threadsafe. -define_string_state( +string_state( name='jax_debug_log_modules', default='', help=('Comma-separated list of module names (e.g. "jax" or ' @@ -1574,7 +1670,7 @@ def _update_debug_log_modules(module_names_str: str | None): 'for.'), update_global_hook=_update_debug_log_modules) -pmap_no_rank_reduction = define_bool_state( +pmap_no_rank_reduction = bool_state( name='jax_pmap_no_rank_reduction', default=False, help=( diff --git a/jax/_src/core.py b/jax/_src/core.py index c351e6980bf7..c99288747c2d 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -13,12 +13,11 @@ # limitations under the License. 
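The transfer-guard states above hook into `jax.transfer_guard` and its per-direction variants. The sketch below assumes the documented levels, where 'disallow' blocks implicit transfers but still permits explicit `jax.device_put`; the exact exception type is not pinned down here::

    import jax
    import jax.numpy as jnp
    import numpy as np

    x_host = np.arange(4.0)

    with jax.transfer_guard("disallow"):
        x_dev = jax.device_put(x_host)      # explicit transfer: allowed
        try:
            _ = jnp.sin(x_host)             # implicit host-to-device transfer
        except Exception as err:
            print("blocked implicit transfer:", type(err).__name__)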
from __future__ import annotations -import collections # noqa: F401 from collections import Counter, defaultdict, deque, namedtuple -from collections.abc import (Collection, Generator, Hashable, Iterable, - Iterator, Set, Sequence, MutableSet, +from collections.abc import (Callable, Collection, Generator, Hashable, + Iterable, Iterator, Set, Sequence, MutableSet, MutableMapping) -from contextlib import contextmanager +from contextlib import contextmanager, ExitStack from dataclasses import dataclass import functools from functools import partial, partialmethod, total_ordering @@ -29,16 +28,18 @@ import operator import threading import types -from typing import (Any, Callable, ClassVar, Generic, NamedTuple, TypeVar, +from typing import (Any, ClassVar, Generic, NamedTuple, TypeVar, cast, overload, Union) import warnings from weakref import ref import numpy as np +from jax._src import deprecations from jax._src import dtypes from jax._src import config from jax._src import effects +from jax._src import compute_on from jax._src.errors import ( ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError, TracerIntegerConversionError, UnexpectedTracerError) @@ -54,13 +55,14 @@ from jax._src import traceback_util from jax._src.typing import Array, DimSize, Shape from jax._src import typing + traceback_util.register_exclusion(__file__) zip, unsafe_zip = safe_zip, zip map, unsafe_map = safe_map, map -_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.DEFINE_integer( +_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.int_flag( 'jax_tracer_error_num_traceback_frames', config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5), help='Set the number of stack frames in JAX tracer error messages.' @@ -78,7 +80,7 @@ class JaxprDebugInfo(NamedTuple): traced_for: str # e.g. 'jit', 'scan', etc func_src_info: str # e.g. f'{fun.__name__} at {filename}:{lineno}' arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... ) - result_paths: tuple[str | None, ...] # e.g. ('[0]', '[1]', ...) + result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...) class Jaxpr: __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', @@ -148,65 +150,28 @@ def __str__(self): def pretty_print(self, *, source_info=False, print_shapes=True, custom_pp_eqn_rules=True, name_stack=False, print_effects: bool = False, **kwargs): - context = JaxprPpContext() - settings = JaxprPpSettings( - source_info=source_info, - print_shapes=print_shapes, - custom_pp_eqn_rules=custom_pp_eqn_rules, - name_stack=name_stack, - print_effects=print_effects) - - # Compute how many times each jaxpr is used. - names = defaultdict[Jaxpr, str](lambda: "jaxpr") - jaxpr_counts = Counter[Jaxpr]() - s = deque([self]) - while s: - jaxpr = s.popleft() - jaxpr_counts[jaxpr] += 1 - for eqn in jaxpr.eqns: - # TODO(slebedev): Come up with a more elaborate heuristic for name=. - name = eqn.params.get("name") - if name is None: - s.extend(jaxprs_in_params(eqn.params)) - continue - name = name.strip("<>") # -> lambda - for subjaxpr in jaxprs_in_params(eqn.params): - s.append(subjaxpr) - names.setdefault(subjaxpr, name) - - # Pull jaxprs occurring more than once to the top-level, making sure - # that their names are unique. 
- docs = [] - name_counts = Counter[str]() - for jaxpr, c in jaxpr_counts.items(): - if c == 1: - continue - name = names[jaxpr] - if (count := name_counts[name]) > 0: - name_counts[name] += 1 - name += str(count) - name_counts[name] += 1 - else: - name_counts[name] += 1 - docs.append(pp_top_level_jaxpr(name, jaxpr, context, settings)) - context.used_names.add(name) - context.top_level_jaxprs[jaxpr] = name - docs.append(pp_jaxpr(self, context, settings)) - return pp.concat(docs).format(**kwargs) + doc = pp_toplevel_jaxpr( + self, source_info=source_info, print_shapes=print_shapes, + custom_pp_eqn_rules=custom_pp_eqn_rules, name_stack=name_stack, + print_effects=print_effects) + return doc.format(**kwargs) def _repr_pretty_(self, p, cycle): return p.text(self.pretty_print(use_color=True)) - def replace(self, *, constvars=None, invars=None, outvars=None, eqns=None, - effects=None, debug_info=None): - constvars = self.constvars if constvars is None else constvars - invars = self.invars if invars is None else invars - outvars = self.outvars if outvars is None else outvars - eqns = self.eqns if eqns is None else eqns - effects = self.effects if effects is None else effects - debug_info = self.debug_info if debug_info is None else debug_info - return Jaxpr(constvars=constvars, invars=invars, outvars=outvars, eqns=eqns, - effects=effects, debug_info=debug_info) + def replace(self, **kwargs): + jaxpr = Jaxpr( + constvars=kwargs.pop("constvars", self.constvars), + invars=kwargs.pop("invars", self.invars), + outvars=kwargs.pop("outvars", self.outvars), + eqns=kwargs.pop("eqns", self.eqns), + effects=kwargs.pop("effects", self.effects), + debug_info=kwargs.pop("debug_info", self.debug_info), + ) + if kwargs: + raise ValueError(f"Unknown keyword arguments: {kwargs}") + return jaxpr + def join_effects(*effects: Effects) -> Effects: return set().union(*effects) if effects else no_effects @@ -294,13 +259,51 @@ def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args): return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args) -class JaxprEqn(NamedTuple): +class JaxprEqnContext: + + def __init__(self, compute_type: str | None, threefry_partitionable: bool): + self.compute_type = compute_type + self.threefry_partitionable = threefry_partitionable + self._managers = [ + (compute_on.extend_compute_type, self.compute_type), + (config.threefry_partitionable.__call__, self.threefry_partitionable), + ] + + @property + @contextmanager + def manager(self): + with ExitStack() as stack: + for manager, val in self._managers: + stack.enter_context(manager(val)) + yield + + def __repr__(self): + return (f"JaxprEqnContext(compute_type={self.compute_type}," + f"threefry_partitionable={self.threefry_partitionable})") + + +class JaxprEqn: invars: list[Atom] outvars: list[Var] primitive: Primitive params: dict[str, Any] effects: Effects source_info: source_info_util.SourceInfo + ctx: JaxprEqnContext + + # It's slightly faster to use a class with __slots__ than a NamedTuple. 
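`JaxprEqnContext.manager` composes its per-equation context managers with `contextlib.ExitStack` so callers can enter them as a single unit. A standalone illustration of just that composition pattern; the `set_flag`, `store` and `combined` names are invented for the example::

    from contextlib import ExitStack, contextmanager

    @contextmanager
    def set_flag(name, value, store):
        # Temporarily set store[name] to value, restoring the old value on exit.
        prev = store.get(name)
        store[name] = value
        try:
            yield
        finally:
            store[name] = prev

    store = {"compute_type": None, "threefry_partitionable": False}
    managers = [
        (set_flag, "compute_type", "device_host"),
        (set_flag, "threefry_partitionable", True),
    ]

    @contextmanager
    def combined():
        # Enter every manager in order; ExitStack unwinds them in reverse on exit.
        with ExitStack() as stack:
            for cm, name, val in managers:
                stack.enter_context(cm(name, val, store))
            yield

    with combined():
        assert store["threefry_partitionable"] is True
    assert store["threefry_partitionable"] is False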
+ __slots__ = ['invars', 'outvars', 'primitive', 'params', 'effects', + 'source_info', 'ctx'] + + def __init__(self, invars, outvars, primitive, params, effects, source_info, + ctx): + self.invars = invars + self.outvars = outvars + self.primitive = primitive + self.params = params + self.effects = effects + self.source_info = source_info + self.ctx = ctx def __repr__(self): return str(pp_eqn(self, JaxprPpContext(), JaxprPpSettings())).rstrip() @@ -313,8 +316,8 @@ def replace( params: dict[str, Any] | None = None, effects: Effects | None = None, source_info: source_info_util.SourceInfo | None = None, + ctx: JaxprEqnContext | None = None ): - # It is slightly faster to rebuild the tuple directly than to call _replace. return JaxprEqn( self.invars if invars is None else invars, self.outvars if outvars is None else outvars, @@ -322,16 +325,20 @@ def replace( self.params if params is None else params, self.effects if effects is None else effects, self.source_info if source_info is None else source_info, + self.ctx if ctx is None else ctx, ) # TODO(mattjj): call typecheck rules here, so we don't form bad eqns -def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None): +def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, + ctx=None): source_info = source_info or source_info_util.new_source_info() + ctx = ctx or JaxprEqnContext(compute_on.current_compute_type(), + config.threefry_partitionable.value) if config.enable_checks.value: assert all(isinstance(x, (Var, Literal)) for x in invars) assert all(isinstance(v, Var) for v in outvars) - return JaxprEqn(invars, outvars, primitive, params, effects, source_info) + return JaxprEqn(invars, outvars, primitive, params, effects, source_info, ctx) _var_counter = it.count() @@ -422,7 +429,8 @@ def bind(self, *args, **params): return self.bind_with_trace(find_top_trace(args), args, params) def bind_with_trace(self, trace, args, params): - out = trace.process_primitive(self, map(trace.full_raise, args), params) + with pop_level(trace.level): + out = trace.process_primitive(self, map(trace.full_raise, args), params) return map(full_lower, out) if self.multiple_results else full_lower(out) def def_impl(self, impl): @@ -430,11 +438,11 @@ def def_impl(self, impl): return impl def def_abstract_eval(self, abstract_eval): - self.abstract_eval = _effect_free_abstract_eval(abstract_eval) # type: ignore[assignment] + self.abstract_eval = _effect_free_abstract_eval(abstract_eval) return abstract_eval def def_effectful_abstract_eval(self, effectful_abstract_eval): - self.abstract_eval = effectful_abstract_eval # type: ignore[assignment] + self.abstract_eval = effectful_abstract_eval return effectful_abstract_eval def def_custom_bind(self, bind): @@ -486,7 +494,8 @@ def write(v: Var, val: Any) -> None: subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack traceback = eqn.source_info.traceback if propagate_source_info else None - with source_info_util.user_context(traceback, name_stack=name_stack): + with source_info_util.user_context( + traceback, name_stack=name_stack), eqn.ctx.manager: ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params) if eqn.primitive.multiple_results: map(write, eqn.outvars, ans) @@ -672,13 +681,25 @@ class Tracer(typing.Array, metaclass=StrictABCMeta): size = _aval_property('size') shape = _aval_property('shape') + def __hash__(self): + # TODO(jakevdp) finalize this deprecation and set 
__hash__ = None + # Warning added 2024-06-13 + if deprecations.is_accelerated('tracer-hash'): + raise TypeError(f"unhashable type: {type(self)}") + # Use FutureWarning rather than DeprecationWarning because hash is likely + # not called directly by the user, so we want to warn at all stacklevels. + warnings.warn( + f"unhashable type: {type(self)}. Attempting to hash a tracer will lead to an" + " error in a future JAX release.", category=FutureWarning) + return super().__hash__() + def __init__(self, trace: Trace): self._trace = trace def _error_repr(self): if self.aval is None: return f"traced array with aval {self.aval}" - return f"traced array with shape {raise_to_shaped(self.aval).str_short()}." + return f"traced array with shape {raise_to_shaped(self.aval).str_short()}" def __array__(self, *args, **kw): raise TracerArrayConversionError(self) @@ -713,9 +734,13 @@ def sharding(self): # This attribute is part of the jax.Array API, but only defined on concrete arrays. # Raising a ConcretizationTypeError would make sense, but for backward compatibility # we raise an AttributeError so that hasattr() and getattr() work as expected. + try: + orig_msg = self._origin_msg() + except: + orig_msg = '' raise AttributeError(self, f"The 'sharding' attribute is not available on {self._error_repr()}." - f"{self._origin_msg()}") + f"{orig_msg}") @property def addressable_shards(self): @@ -828,14 +853,14 @@ def addressable_data(self, index): @property def block_until_ready(self): - # Raise AttribureError for backward compatibility with hasattr() and getattr() checks. + # Raise AttributeError for backward compatibility with hasattr() and getattr() checks. raise AttributeError(self, f"The 'block_until_ready' method is not available on {self._error_repr()}." f"{self._origin_msg()}") @property def copy_to_host_async(self): - # Raise AttribureError for backward compatibility with hasattr() and getattr() checks. + # Raise AttributeError for backward compatibility with hasattr() and getattr() checks. raise AttributeError(self, f"The 'copy_to_host_async' method is not available on {self._error_repr()}." f"{self._origin_msg()}") @@ -845,11 +870,6 @@ def delete(self): f"The delete() method was called on {self._error_repr()}." f"{self._origin_msg()}") - def device(self): - raise ConcretizationTypeError(self, - f"The device() method was called on {self._error_repr()}." - f"{self._origin_msg()}") - def devices(self): raise ConcretizationTypeError(self, f"The devices() method was called on {self._error_repr()}." 
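Under this change, hashing a `Tracer` emits a `FutureWarning` (and is meant to become a `TypeError` once the 'tracer-hash' deprecation is accelerated). The typical trigger and the usual fix look roughly like the following sketch::

    import jax
    import jax.numpy as jnp
    from functools import partial

    @jax.jit
    def f(x):
        _ = hash(x)      # x is a Tracer here: warns now, an error later
        return x * 2

    f(3)                 # emits FutureWarning: "unhashable type: ...Tracer..."

    # Typical fix: pass such values as static arguments so they stay concrete.
    @partial(jax.jit, static_argnums=0)
    def g(mode, x):
        scale = {"double": 2, "triple": 3}[mode]   # hashing a plain str is fine
        return x * scale

    g("double", jnp.arange(4))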
@@ -906,10 +926,20 @@ def pure(self, x): return x lift = sublift = pure def process_primitive(self, primitive, tracers, params): - return primitive.impl(*tracers, **params) + if config.debug_key_reuse.value: + # Import here to avoid circular imports + from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error + return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params) + else: + return primitive.impl(*tracers, **params) def process_call(self, primitive, f, tracers, params): - return primitive.impl(f, *tracers, **params) + if config.debug_key_reuse.value: + # Import here to avoid circular imports + from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error + return call_impl_with_key_reuse_checks(primitive, primitive.impl, f, *tracers, **params) + else: + return primitive.impl(f, *tracers, **params) process_map = process_call def process_custom_transpose(self, primitive, call, tracers, **_): @@ -1023,13 +1053,8 @@ def copy(self): def _update_thread_local_jit_state(dynamic): - # Copies the MainTrace instance, removing any .debug_info or .jaxpr_stack - # fields that should not be kept alive as part of a cache key. - # TODO(mattjj): split debug_info and jaxpr_stack out of MainTrace. - # TODO(mattjj): add a test that verifies that JIT-ted functions are not kept - # alive by the JIT cache, particularly for nested JIT-ted functions. - copy = MainTrace(dynamic.level, dynamic.trace_type, **dynamic.payload) - config.update_thread_local_jit_state(dynamic_trace_state=copy) + state = (dynamic.level, dynamic.trace_type) + config.update_thread_local_jit_state(dynamic_trace_state=state) # The global state of the tracer is accessed by a thread-local object. @@ -1054,8 +1079,8 @@ def _initialize_jax_jit_thread_local_state(): tls = jax_jit.thread_local_state() if tls.extra_jit_context is None: dynamic = thread_local_state.trace_state.trace_stack.dynamic - copy = MainTrace(dynamic.level, dynamic.trace_type, **dynamic.payload) - config.update_thread_local_jit_state(dynamic_trace_state=copy) + state = (dynamic.level, dynamic.trace_type) + config.update_thread_local_jit_state(dynamic_trace_state=state) jax_jit.set_thread_local_state_initialization_callback( @@ -1071,7 +1096,7 @@ def trace_state_clean() -> bool: def reset_trace_state() -> bool: """Resets the global trace state and returns True if it was already clean.""" if not trace_state_clean(): - thread_local_state.trace_state.__init__() # type: ignore + thread_local_state.trace_state.__init__() return False else: return True @@ -1235,6 +1260,18 @@ def new_base_main(trace_type: type[Trace], leaked_tracers = maybe_find_leaked_tracers(t()) if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers) +@contextmanager +def pop_level(level: int): + if level == 0: + return (yield) + prev, thread_local_state.trace_state.trace_stack.stack = \ + thread_local_state.trace_state.trace_stack.stack, \ + thread_local_state.trace_state.trace_stack.stack[:level] + try: + yield + finally: + thread_local_state.trace_state.trace_stack.stack = prev + @contextmanager def ensure_compile_time_eval(): """Context manager to ensure evaluation at trace/compile time (or error). 
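With `jax_debug_key_reuse` enabled, the eager path shown above routes primitives through the key-reuse checker, so consuming the same typed key twice is flagged. A usage sketch; the concrete exception class is deliberately left unspecified here::

    import jax

    jax.config.update("jax_debug_key_reuse", True)

    key = jax.random.key(0)            # new-style typed PRNG key
    x = jax.random.normal(key, (3,))   # first consumption: fine

    try:
        y = jax.random.normal(key, (3,))   # second consumption of the same key
    except Exception as err:               # exact error class is experimental
        print("key reuse detected:", type(err).__name__)

    # Supported pattern: split before each use.
    k1, k2 = jax.random.split(jax.random.key(0))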
@@ -1277,7 +1314,7 @@ def f(x): @jax.jit def jax_fn(x): with jax.ensure_compile_time_eval(): - y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100) + y = random.randint(random.key(0), (1000,1000), 0, 100) y2 = y @ y x2 = jnp.sum(y2) * x return x2 @@ -1285,7 +1322,7 @@ def jax_fn(x): A similar behavior can often be achieved simply by 'hoisting' the constant expression out of the corresponding staging API:: - y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100) + y = random.randint(random.key(0), (1000,1000), 0, 100) @jax.jit def jax_fn(x): @@ -1333,11 +1370,11 @@ def find_top_trace(xs) -> Trace: top_tracer._assert_live() top_main = top_tracer._trace.main else: - top_main = None # type: ignore + top_main = None dynamic = thread_local_state.trace_state.trace_stack.dynamic top_main = (dynamic if top_main is None or dynamic.level > top_main.level else top_main) - return top_main.with_cur_sublevel() # type: ignore + return top_main.with_cur_sublevel() def get_referent(x: Any) -> Any: return x.get_referent() if isinstance(x, Tracer) else x @@ -1625,6 +1662,70 @@ def shape(self): "UnshapedArray instances to ever be produced.") raise TypeError(msg) +def _canonicalize_dimension(dim: DimSize) -> DimSize: + # Dimensions are most commonly integral (by far), so we check that first. + try: + return operator.index(dim) + except TypeError as e: + type_error = e + if isinstance(dim, Tracer) and config.dynamic_shapes.value: + if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer) + or isinstance(dim.dtype, bint))): + raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}") + return dim + elif (config.dynamic_shapes.value and isinstance(dim, DArray) and + type(dim._aval.dtype) is bint and not dim._aval.shape): + return dim + elif is_dim(dim): + return dim + else: + raise type_error + +def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]: + """Canonicalizes and checks for errors in a user-provided shape value. + + Args: + shape: a Python value that represents a shape. + + Returns: + A tuple of canonical dimension values. + """ + try: + return tuple(unsafe_map(_canonicalize_dimension, shape)) + except TypeError: + pass + raise _invalid_shape_error(shape, context) + +def canonicalize_dim(d: DimSize, context: str="") -> DimSize: + """Canonicalizes and checks for errors in a user-provided shape dimension value. + + Args: + f: a Python value that represents a dimension. + + Returns: + A canonical dimension value. + """ + return canonicalize_shape((d,), context)[0] + +def _invalid_shape_error(shape: Shape, context: str=""): + if config.dynamic_shapes.value: + msg = ("Shapes must be 1D sequences of integer scalars, " + f"got {shape}") + else: + msg = ("Shapes must be 1D sequences of concrete values of integer type, " + f"got {shape}.") + if context: + msg += f" {context}." 
+ if not config.dynamic_shapes.value and any( + isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray) + and not isinstance(get_aval(x), ConcreteArray) for x in shape): + msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to " + "smaller subfunctions.") + for x in shape: + if isinstance(x, Tracer) and hasattr(x, "_origin_msg"): + msg += x._origin_msg() + + return TypeError(msg) class ShapedArray(UnshapedArray): __slots__ = ['shape', 'named_shape'] @@ -1686,6 +1787,7 @@ def join(self, other): def str_short(self, short_dtypes=False): dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name + dt_str = dt_str.replace('void', 'float0') shapestr = ','.join(map(str, self.shape)) if self.named_shape: named_shapestr = ','.join(f'{k}:{v}' for k, v in self.named_shape.items()) @@ -1766,7 +1868,7 @@ def str_short(self, short_dtypes=False) -> str: def primal_dtype_to_tangent_dtype(primal_dtype): if isinstance(primal_dtype, dtypes.ExtendedDType): - return primal_dtype._rules.tangent_dtype(primal_dtype) # type: ignore + return primal_dtype._rules.tangent_dtype(primal_dtype) elif not dtypes.issubdtype(primal_dtype, np.inexact): return dtypes.float0 else: @@ -1918,15 +2020,26 @@ def __init__(self, aval, buf): dtype = property(lambda self: self._aval.dtype) def __getitem__(self, idx): return get_aval(self)._getitem(self, idx) def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x) + def __repr__(self) -> str: return 'Mutable' + repr(self[...]) pytype_aval_mappings[MutableArray] = lambda x: x._aval def mutable_array(init_val): return mutable_array_p.bind(init_val) mutable_array_p = Primitive('mutable_array') +class InternalMutableArrayEffect(effects.Effect): + pass +internal_mutable_array_effect = InternalMutableArrayEffect() +effects.control_flow_allowed_effects.add_type(InternalMutableArrayEffect) + +@mutable_array_p.def_effectful_abstract_eval +def mutable_array_abstract_eval(init_aval): + from jax._src.state.types import AbstractRef # pytype: disable=import-error + return AbstractRef(init_aval), {internal_mutable_array_effect} + @mutable_array_p.def_impl def _mutable_array_impl(init_val): - from jax._src.state.types import AbstractRef # type: ignore[import] + from jax._src.state.types import AbstractRef # pytype: disable=import-error aval = raise_to_shaped(get_aval(init_val)) return MutableArray(AbstractRef(aval), init_val) @@ -1941,9 +2054,18 @@ def str_short(self, short_dtypes=False): return 'Tok' def at_least_vspace(self): return self abstract_token: AbstractToken = AbstractToken() +# Singleton shaped array used by all abstract tokens when shape/dtype is needed. +token_shaped_array: ShapedArray = ShapedArray((0,), np.dtype(np.bool_)) + # Concrete token object -class Token: pass -token: Token = Token() +class Token: + # The underlying data wrapped by the token, could be used to threaded in and + # out of computations to build up data dependency. + _buf: Array + def __init__(self, buf): + self._buf = buf + def block_until_ready(self): + self._buf.block_until_ready() pytype_aval_mappings[Token] = lambda _: abstract_token @@ -2102,71 +2224,6 @@ def dimension_as_value(d: DimSize): if hasattr(d, "dimension_as_value"): return d.dimension_as_value() return operator.index(d) -def _canonicalize_dimension(dim: DimSize) -> DimSize: - # Dimensions are most commonly integral (by far), so we check that first. 
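`canonicalize_shape` and `_invalid_shape_error`, now hoisted to this point in the file, are what produce the "Shapes must be 1D sequences of concrete values of integer type" message, including the `static_argnums` hint, whenever a traced value reaches a shape. A sketch of the failure and the suggested fix::

    import jax
    import jax.numpy as jnp
    from functools import partial

    @jax.jit
    def bad(n):
        return jnp.ones(n)       # n is a Tracer: raises the shape TypeError

    try:
        bad(3)
    except TypeError as err:
        print(err)               # "...try using `static_argnums`..."

    @partial(jax.jit, static_argnums=0)
    def good(n):
        return jnp.ones(n)       # n is a concrete Python int: allowed

    good(3)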
- try: - return operator.index(dim) - except TypeError as e: - type_error = e - if isinstance(dim, Tracer) and config.dynamic_shapes.value: - if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer) - or isinstance(dim.dtype, bint))): - raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}") - return dim - elif (config.dynamic_shapes.value and isinstance(dim, DArray) and - type(dim._aval.dtype) is bint and not dim._aval.shape): - return dim - elif is_dim(dim): - return dim - else: - raise type_error - -def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]: - """Canonicalizes and checks for errors in a user-provided shape value. - - Args: - shape: a Python value that represents a shape. - - Returns: - A tuple of canonical dimension values. - """ - try: - return tuple(unsafe_map(_canonicalize_dimension, shape)) - except TypeError: - pass - raise _invalid_shape_error(shape, context) - -def canonicalize_dim(d: DimSize, context: str="") -> DimSize: - """Canonicalizes and checks for errors in a user-provided shape dimension value. - - Args: - f: a Python value that represents a dimension. - - Returns: - A canonical dimension value. - """ - return canonicalize_shape((d,), context)[0] - -def _invalid_shape_error(shape: Shape, context: str=""): - if config.dynamic_shapes.value: - msg = ("Shapes must be 1D sequences of integer scalars, " - f"got {shape}") - else: - msg = ("Shapes must be 1D sequences of concrete values of integer type, " - f"got {shape}.") - if context: - msg += f" {context}." - if not config.dynamic_shapes.value and any( - isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray) - and not isinstance(get_aval(x), ConcreteArray) for x in shape): - msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to " - "smaller subfunctions.") - for x in shape: - if isinstance(x, Tracer) and hasattr(x, "_origin_msg"): - msg += x._origin_msg() - - return TypeError(msg) - class SomeTracer: __slots__ = () def __repr__(self): return "[dynamic]" @@ -2378,6 +2435,8 @@ def get_bind_params(self, params): closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call') closed_call_p.def_impl(call_impl) +closed_call_p.def_effectful_abstract_eval( + lambda *_, call_jaxpr: (call_jaxpr.out_avals, call_jaxpr.effects)) outfeed_primitives: set[Primitive] = set() @@ -2742,7 +2801,7 @@ def do_subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst): var_map: dict[Var, Var] = {} invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars] # type: ignore[union-attr] constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars] # type: ignore[union-attr] - eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns] # type: ignore[union-attr] + eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns] outvars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars] # type: ignore[union-attr] effects = subst_axis_names_effects(jaxpr.effects, subst) new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, effects) @@ -2821,7 +2880,7 @@ class JaxprTypeError(TypeError): pass def _check_closed_call(_, *in_atoms, call_jaxpr): in_avals = [x.aval for x in in_atoms] - if list(in_avals) != list(call_jaxpr.in_avals): + if not all(map(typecompat, call_jaxpr.in_avals, in_avals)): raise JaxprTypeError("Closed call in_avals mismatch") return call_jaxpr.out_avals, call_jaxpr.effects custom_typechecks[closed_call_p] = _check_closed_call @@ -2860,7 +2919,7 
@@ def ctx_factory(): raise JaxprTypeError(msg) from None # Run key reuse checker after validating jaxpr: - if config.enable_key_reuse_checks.value: + if config.debug_key_reuse.value: # Import here to avoid circular imports from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error check_key_reuse_jaxpr(jaxpr) @@ -2917,6 +2976,8 @@ def write(v: Var, a: AbstractValue) -> None: write(v, v.aval) # Check each eqn. + sentinel = object() + in_idx = {v: i for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))} for eqn_idx, eqn in enumerate(jaxpr.eqns): prim = eqn.primitive try: @@ -2938,6 +2999,9 @@ def write(v: Var, a: AbstractValue) -> None: # Check the computed effect type matches the eqn's annotation, and is # included in the jaxpr's annotation. + if prim is mutable_array_p: + outvar, = eqn.outvars + in_idx[outvar] = None # type: ignore if eqn.effects != eqn_effects: raise JaxprTypeError("Inferred effects do not match equation effects. " f"Equation effects: {eqn.effects}. " @@ -2945,11 +3009,9 @@ def write(v: Var, a: AbstractValue) -> None: for eff in eqn.effects: if isinstance(eff, effects.JaxprInputEffect): eqn_invar = eqn.invars[eff.input_index] - all_vars = [*jaxpr.constvars, *jaxpr.invars] - if eqn_invar not in all_vars: + if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel: raise JaxprTypeError( "Invalid `JaxprInputEffect`: must correspond to a jaxpr invar") - jaxpr_index = all_vars.index(eqn_invar) jaxpr_effect = eff.replace(input_index=jaxpr_index) if jaxpr_effect not in jaxpr.effects: raise JaxprTypeError( @@ -3022,8 +3084,8 @@ def substitute_vars_in_output_ty( result = [] for aval in out_type: if type(aval) is DShapedArray: - shape = [in_atoms[d.val] if type(d) is InDBIdx else # type: ignore - out_binders[d.val] if type(d) is OutDBIdx else # type: ignore + shape = [in_atoms[d.val] if type(d) is InDBIdx else + out_binders[d.val] if type(d) is OutDBIdx else d for d in aval.shape] aval = aval.update(shape=tuple(shape)) result.append(aval) @@ -3116,6 +3178,55 @@ def _check_map(ctx_factory, prim, in_avals, params): # ------------------- Jaxpr printed representation ------------------- +def pp_toplevel_jaxpr(jaxpr_to_print, *, source_info=False, print_shapes=True, + custom_pp_eqn_rules=True, name_stack=False, + print_effects: bool = False) -> pp.Doc: + context = JaxprPpContext() + settings = JaxprPpSettings( + source_info=source_info, + print_shapes=print_shapes, + custom_pp_eqn_rules=custom_pp_eqn_rules, + name_stack=name_stack, + print_effects=print_effects) + + # Compute how many times each jaxpr is used. + names = defaultdict[Jaxpr, str](lambda: "jaxpr") + jaxpr_counts = Counter[Jaxpr]() + s = deque([jaxpr_to_print]) + while s: + jaxpr = s.popleft() + jaxpr_counts[jaxpr] += 1 + for eqn in jaxpr.eqns: + # TODO(slebedev): Come up with a more elaborate heuristic for name=. + name = eqn.params.get("name") + if name is None: + s.extend(jaxprs_in_params(eqn.params)) + continue + name = name.strip("<>") # -> lambda + for subjaxpr in jaxprs_in_params(eqn.params): + s.append(subjaxpr) + names.setdefault(subjaxpr, name) + + # Pull jaxprs occurring more than once to the top-level, making sure + # that their names are unique. 
+ docs = [] + name_counts = Counter[str]() + for jaxpr, c in jaxpr_counts.items(): + if c == 1: + continue + name = names[jaxpr] + if (count := name_counts[name]) > 0: + name_counts[name] += 1 + name += str(count) + name_counts[name] += 1 + else: + name_counts[name] += 1 + docs.append(pp_top_level_jaxpr(name, jaxpr, context, settings)) + context.used_names.add(name) + context.top_level_jaxprs[jaxpr] = name + docs.append(pp_jaxpr(jaxpr_to_print, context, settings)) + return pp.concat(docs) + class JaxprPpSettings(NamedTuple): print_shapes: bool = True @@ -3205,7 +3316,9 @@ def pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings ) -> pp.Doc: rule = (_pp_eqn if not settings.custom_pp_eqn_rules else pp_eqn_rules.get(eqn.primitive, _pp_eqn)) - return rule(eqn, context, settings) # type: ignore[operator] + doc = rule(eqn, context, settings) # type: ignore[operator] + user_frame = source_info_util.user_frame(eqn.source_info) + return doc if user_frame is None else pp.source_map(doc, user_frame) def _pp_eqn(eqn, context, settings, params=None) -> pp.Doc: annotation = (source_info_util.summarize(eqn.source_info) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 9b7bb975dd6f..e5b1f0084d00 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum from functools import partial, reduce import operator -from typing import Optional import json import jax @@ -36,6 +36,49 @@ DType = jnp.dtype PRNGKey = jnp.ndarray +class AttentionLayout(Enum): + BTNH = 0 + BNTH = 1 + +class MaskType(Enum): + NO_MASK = 0 + PADDING = 1 + CAUSAL = 2 + PADDING_CAUSAL = 3 + ALIBI = 4 + +def convert_mask_type_to_string(mask_type: MaskType) -> str: + if mask_type == MaskType.NO_MASK: + return "NO_MASK" + elif mask_type == MaskType.PADDING: + return "PADDING" + elif mask_type == MaskType.CAUSAL: + return "CAUSAL" + elif mask_type == MaskType.PADDING_CAUSAL: + return "PADDING_CAUSAL" + elif mask_type == MaskType.ALIBI: + return "ALIBI" + else: + raise ValueError(f"Unexpected mask type: {mask_type}") + +def has_padding(mask_type: MaskType) -> bool: + return mask_type == MaskType.PADDING or mask_type == MaskType.PADDING_CAUSAL + +def should_export_dbias(bias_shape, query_shape, layout) -> bool: + b_B, b_N, _, _ = bias_shape + if layout == AttentionLayout.BNTH.value: + _, q_N, _, _ = query_shape + else: + _, _, q_N, _ = query_shape + return b_B == 1 and b_N == q_N + +def _normalize_layout(layout: str) -> AttentionLayout: + layout_upper = layout.upper() + if layout_upper in ["BSNH", "BNSH", "BTNH", "BNTH"]: + return AttentionLayout[layout_upper.replace("S", "T")] + else: + raise ValueError(f"Unsupported qkv_layout: {layout}") + def element_type_to_backend_config_type_mapping(dtype): _element_type_to_backend_config_type_mapping = { ir.BF16Type.get(): "BF16", @@ -54,20 +97,19 @@ def create_dot_product_attention_backend_config(batch, fmha_scale, seed, dropout_rate, - is_flash_attention, - is_causal_mask, + mask_type, + layout, is_bwd): - # b q_seq num_heads head_dim -> Q - # b kv_seq num_heads head_dim -> K - # b kv_seq num_heads head_dim -> V - # b num_heads q_seq kv_seq -> P - # b q_seq num_heads head_dim -> O - # bmm1: Q @ K -> P - # bmm2: P @ V -> O - # bmm2Grad1: P @ dO -> dV - # bmm2Grad2: dO @ V -> dP - # bmm1Grad1: dP @ Q -> dK - # bmm1Grad2: dP @ K -> dQ + # Q, 
K, V: query, key, value in shape of BT(S)NH or BNT(S)H + # P: BMM1 output in shape of BNTS + # O: BMM2 output in the same shape with Q + # BMM1: Q @ K -> P + # BMM2: P @ V -> O + # BMM1Grad1: dP @ Q -> dK + # BMM1Grad2: dP @ K -> dQ + # BMM2Grad1: P @ dO -> dV + # BMM2Grad2: dO @ V -> dP + cudnn_fmha_backend_config = { "algorithm": { "algo_id": "0", @@ -97,49 +139,50 @@ def create_dot_product_attention_backend_config(batch, "is_dynamic_dimension": [False, False, False, False], }, "seed": seed, - "is_flash_attention": is_flash_attention, - "is_causal_mask": is_causal_mask, - } - fwd_dot_number = { - "bmm1_dot_dimension_numbers": { - "lhs_contracting_dimensions": ["3"], - "rhs_contracting_dimensions": ["3"], - "lhs_batch_dimensions": ["0", "2"], - "rhs_batch_dimensions": ["0", "2"], - }, - "bmm2_dot_dimension_numbers": { - "lhs_contracting_dimensions": ["3"], - "rhs_contracting_dimensions": ["1"], - "lhs_batch_dimensions": ["0", "1"], - "rhs_batch_dimensions": ["0", "2"], - }, - } - bwd_dot_number = { - "bmm1_grad_gemm1_dot_dimension_numbers": { - "lhs_contracting_dimensions": ["2"], - "rhs_contracting_dimensions": ["1"], - "lhs_batch_dimensions": ["0", "1"], - "rhs_batch_dimensions": ["0", "2"], - }, - "bmm1_grad_gemm2_dot_dimension_numbers": { - "lhs_contracting_dimensions": ["3"], - "rhs_contracting_dimensions": ["1"], - "lhs_batch_dimensions": ["0", "1"], - "rhs_batch_dimensions": ["0", "2"], - }, - "bmm2_grad_gemm1_dot_dimension_numbers": { - "lhs_contracting_dimensions": ["2"], - "rhs_contracting_dimensions": ["1"], - "lhs_batch_dimensions": ["0", "1"], - "rhs_batch_dimensions": ["0", "2"], - }, - "bmm2_grad_gemm2_dot_dimension_numbers": { - "lhs_contracting_dimensions": ["3"], - "rhs_contracting_dimensions": ["3"], - "lhs_batch_dimensions": ["0", "2"], - "rhs_batch_dimensions": ["0", "2"], - }, + "is_flash_attention": True, + "mask_type": convert_mask_type_to_string(mask_type), } + + # We define the contracting and batch dims in the format of + # ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, + # rhs_batch_dims)). 
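The dimension numbers tabulated next are ordinary batched matmuls. Spelled out as einsums over the BTNH layout (B=batch, T/S=query/key sequence length, N=heads, H=head dim), the forward pair corresponds to this reference sketch, with scaling, bias and masking omitted; it is not the cuDNN call itself::

    import jax
    import jax.numpy as jnp

    B, T, S, N, H = 2, 128, 128, 8, 64
    q = jnp.zeros((B, T, N, H))   # queries, BTNH
    k = jnp.zeros((B, S, N, H))   # keys,    BSNH
    v = jnp.zeros((B, S, N, H))   # values,  BSNH

    # BMM1: contract H (dim 3 of both operands), batch over (B, N) -> BNTS scores.
    p = jnp.einsum("btnh,bsnh->bnts", q, k)

    # BMM2: contract S (dim 3 of p, dim 1 of v), batch over (B, N) -> BTNH output.
    o = jnp.einsum("bnts,bsnh->btnh", jax.nn.softmax(p, axis=-1), v)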
+ if layout == AttentionLayout.BNTH.value: + dims = [ + ((3, 3), ((0, 1), (0, 1))), # BMM1: BNTH,BNSH->BNTS + ((3, 2), ((0, 1), (0, 1))), # BMM2: BNTS,BNSH->BNTH + ((2, 2), ((0, 1), (0, 1))), # BMM1_grad_1: BNTS,BNTH->BNSH + ((3, 2), ((0, 1), (0, 1))), # BMM1_grad_2: BNTS,BNSH->BNTH + ((2, 2), ((0, 1), (0, 1))), # BMM2_grad_1: BNTS,BNTH->BNSH + ((3, 3), ((0, 1), (0, 1))), # BMM2_grad_2: BNTH,BNSH->BNTS + ] + else: + dims = [ + ((3, 3), ((0, 2), (0, 2))), # BMM1: BTNH,BSNH->BNTS + ((3, 1), ((0, 1), (0, 2))), # BMM2: BNTS,BSNH->BTNH + ((2, 1), ((0, 1), (0, 2))), # BMM1_grad_1: BNTS,BTNH->BSNH + ((3, 1), ((0, 1), (0, 2))), # BMM1_grad_2: BNTS,BSNH->BTNH + ((2, 1), ((0, 1), (0, 2))), # BMM2_grad_1: BNTS,BTNH->BSNH + ((3, 3), ((0, 2), (0, 2))), # BMM2_grad_2: BTNH,BSNH->BNTS + ] + keys = [ + "bmm1_dot_dimension_numbers", + "bmm2_dot_dimension_numbers", + "bmm1_grad_gemm1_dot_dimension_numbers", + "bmm1_grad_gemm2_dot_dimension_numbers", + "bmm2_grad_gemm1_dot_dimension_numbers", + "bmm2_grad_gemm2_dot_dimension_numbers", + ] + fwd_dot_number = {} + bwd_dot_number = {} + for idx, (key, ((lc, rc), (lb, rb))) in enumerate(zip(keys, dims)): + dims_to_write = fwd_dot_number if idx < 2 else bwd_dot_number + dims_to_write[key] = { + "lhs_contracting_dimensions": [str(lc)], + "rhs_contracting_dimensions": [str(rc)], + "lhs_batch_dimensions": [str(i) for i in lb], + "rhs_batch_dimensions": [str(i) for i in rb], + } + if is_bwd: cudnn_fmha_backend_config = {**cudnn_fmha_backend_config, **bwd_dot_number} else: @@ -156,371 +199,494 @@ def create_dot_product_attention_backend_config(batch, # mapping from (is_bwd, has_dropout, has_mask, has_bias) to custom call name _custom_name_maps = { # fMHA forward call targets. - (False, False, False, False): "__cudnn$fmhaSoftmax", - (False, False, False, True): "__cudnn$fmhaScaleBiasSoftmax", - (False, False, True, False): "__cudnn$fmhaScaleMaskSoftmax", - (False, False, True, True): "__cudnn$fmhaScaleBiasMaskSoftmax", - (False, True, False, False): "__cudnn$fmhaSoftmaxDropout", - (False, True, False, True): "__cudnn$fmhaScaleBiasSoftmaxDropout", - (False, True, True, False): "__cudnn$fmhaScaleMaskSoftmaxDropout", - (False, True, True, True): "__cudnn$fmhaScaleBiasMaskSoftmaxDropout", - # fMHA backward call targets. - (True, False, False, False): "__cudnn$fmhaSoftmaxBackward", - (True, False, False, True): "__cudnn$fmhaScaleBiasSoftmaxBackward", - (True, False, True, False): "__cudnn$fmhaScaleMaskSoftmaxBackward", - (True, False, True, True): "__cudnn$fmhaScaleBiasMaskSoftmaxBackward", - (True, True, False, False): "__cudnn$fmhaSoftmaxDropoutBackward", - (True, True, False, True): "__cudnn$fmhaScaleBiasSoftmaxDropoutBackward", - (True, True, True, False): "__cudnn$fmhaScaleMaskSoftmaxDropoutBackward", - (True, True, True, True): "__cudnn$fmhaScaleBiasMaskSoftmaxDropoutBackward" + (False, False, False): "__cudnn$fmhaSoftmax", + (False, False, True): "__cudnn$fmhaScaleBiasSoftmax", + (False, True, False): "__cudnn$fmhaSoftmaxDropout", + (False, True, True): "__cudnn$fmhaScaleBiasSoftmaxDropout", + # fMHA backward call targets. 
+    (True, False, False): "__cudnn$fmhaSoftmaxBackward",
+    (True, False, True): "__cudnn$fmhaScaleBiasSoftmaxBackward",
+    (True, True, False): "__cudnn$fmhaSoftmaxDropoutBackward",
+    (True, True, True): "__cudnn$fmhaScaleBiasSoftmaxDropoutBackward",
 }
-def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd):
-  return _custom_name_maps[(is_bwd, has_dropout, has_mask, has_bias)]
-
-def check_qkv_layout(query, key, value):
-  assert len(query.shape) == len(key.shape) == len(value.shape) == 4, \
-    "query, key and value should have rank 4."
-
-  # Only support fp16 and bf16 here
-  query_dtype = query.dtype
-  key_dtype = key.dtype
-  value_dtype = value.dtype
-  assert query_dtype == key_dtype == value_dtype and query_dtype in [jnp.float16, jnp.bfloat16], \
-    "query, key and value should have same dtype and should be float16 or bfloat16"
-
-  q_batch, q_seq_len, q_num_heads, q_head_dim = query.shape
-  k_batch, k_seq_len, k_num_heads, k_head_dim = key.shape
-  v_batch, v_seq_len, v_num_heads, v_head_dim = value.shape
-  if not((q_batch == k_batch == v_batch)
-      and (k_seq_len == v_seq_len)
-      and (q_num_heads == k_num_heads == v_num_heads)
-      and (q_head_dim == k_head_dim == v_head_dim)):
-    raise ValueError(
-      "query should have layout [batch, q_seq, num_heads, head_dim], " \
-      "key and value should have layout [batch, kv_seq, num_heads, head_dim].")
-
-def check_is_flash_attention(query, key):
-  batch, q_seq_len, num_heads, head_dim = query.shape
-  _, kv_sqe_len, _, _ = key.shape
-  is_cross_attention = q_seq_len != kv_sqe_len
-  # check if attention pattern is supported by flash attention or fused attention
-  if q_seq_len > 512 and kv_sqe_len > 512 and head_dim in [64, 128]:
-    # check if flash attention is supported
-    is_flash_attention = True
-  elif q_seq_len <= 512 and kv_sqe_len <= 512 and head_dim == 64:
-    # check if regular fused attention is supported
-    is_flash_attention = False
+def get_custom_call_name(has_bias, has_dropout, is_bwd):
+  return _custom_name_maps[(is_bwd, has_dropout, has_bias)]
+
+def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout):
+  def check_eq(a, b, c, msg):
+    if not (a == b == c):
+      raise ValueError(f"{msg} must be the same, got {a}, {b}, {c}")
+
+  q_rank, k_rank, v_rank = len(query.shape), len(key.shape), len(value.shape)
+  if q_rank != 4:
+    raise ValueError(f"Q must have a rank of 4, got {q_rank}")
+  check_eq(q_rank, k_rank, v_rank, "QKV rank")
+
+  q_dtype, k_dtype, v_dtype = query.dtype, key.dtype, value.dtype
+  assert q_dtype in [jnp.float16, jnp.bfloat16], "Q must be fp16 or bf16"
+  check_eq(q_dtype, k_dtype, v_dtype, "QKV dtype")
+
+  if layout == AttentionLayout.BNTH:
+    qB, qN, qT, qH = query.shape
+    kB, kN, kS, kH = key.shape
+    vB, vN, vS, vH = value.shape
   else:
+    assert layout == AttentionLayout.BTNH
+    qB, qT, qN, qH = query.shape
+    kB, kS, kN, kH = key.shape
+    vB, vS, vN, vH = value.shape
+
+  check_eq(qB, kB, vB, "QKV batch")
+  check_eq(qH, kH, vH, "QKV dim_per_head")
+  if kN != vN:
+    raise ValueError(f"KV must have the same number of heads, got {kN} vs {vN}")
+  if kS != vS:
+    raise ValueError(f"KV must have the same seq length, got {kS} vs {vS}")
+
+  # check bias/q_seqlen/kv_seqlen
+  if bias is not None:
+    _, _, bT, bS = bias.shape
+    if bT != qT or bS != vS:
+      raise ValueError(
+        f"Bias must have the same seq length as QKV, got {bT} and {bS}")
+  if q_seqlen is not None:
+    q_seq_dtype = q_seqlen.dtype
+    q_seq_rank = len(q_seqlen.shape)
+    if q_seq_dtype != jnp.int32:
+      raise ValueError(f"q_seqlen must have int32 datatype, got {q_seq_dtype}")
+    if q_seq_rank
!= 1: + raise ValueError(f"q_seqlen must have a rank of 1, got {q_seq_rank}") + q_seq_b = q_seqlen.shape[0] + if q_seq_b != qB: + raise ValueError(f"q_seqlen must have same batch as Q, got {q_seq_b}") + if kv_seqlen is not None: + kv_seq_dtype = kv_seqlen.dtype + kv_seq_rank = len(kv_seqlen.shape) + if kv_seq_dtype != jnp.int32: + raise ValueError( + f"kv_seqlen must have int32 datatype, got {kv_seq_dtype}") + if kv_seq_rank != 1: + raise ValueError(f"kv_seq_rank must have a rank of 1, got {kv_seq_rank}") + kv_seq_b = kv_seqlen.shape[0] + if kv_seq_b != qB: + raise ValueError(f"kv_seqlen must have same batch as Q, got {kv_seq_b}") + +def check_is_flash_attention( + query, key, layout, cudnn_version, has_bias, is_training): + if layout == AttentionLayout.BNTH: + _, _, T, H = query.shape + _, _, S, _ = key.shape + else: + _, T, _, H = query.shape + _, S, _, _ = key.shape + + if not ((H <= 128 and H % 8 == 0) and + (not is_training or not has_bias or T % 2 == 0 and S % 2 == 0)): + # check if flash attention is supported + # for training, for patterns with bias, seqlen should be divisible by 2 raise NotImplementedError( - f"Unsupported sequence length Q {q_seq_len}, KV {kv_sqe_len} and head dim {head_dim}.") - return is_flash_attention, is_cross_attention + f"Unsupported sequence length Q {T}, KV {S} and head dim {H}.") + # check if minimum cudnn version requirement is satisfied + if cudnn_version < 8904: + raise RuntimeError( + "JAX requires cuDNN >= 8.9.4 to use flash cross attention.") + +def check_cudnn_version(): + # check if cuDNN is installed + if cuda_versions is None: + raise RuntimeError("cuDNN is not detected.") + return cuda_versions.cudnn_get_version() -def check_cudnn_version(is_flash_attention, is_cross_attention): - # check if cuDNN is installed and if cuDNN version contraint is satisfied +def check_compute_capability(cc): if cuda_versions is None: raise RuntimeError("cuDNN is not detected.") - elif is_flash_attention: - if not is_cross_attention and cuda_versions.cudnn_get_version() < 8903: - raise RuntimeError("JAX requires cuDNN >= 8.9.3 to use flash attention.") - if is_cross_attention and cuda_versions.cudnn_get_version() < 8904: - raise RuntimeError("JAX requires cuDNN >= 8.9.4 to use flash cross attention.") - elif not is_flash_attention and cuda_versions.cudnn_get_version() < 8901: - raise RuntimeError("JAX requires cuDNN >= 8.9.1 to use fused attention.") - -def _dot_product_attention_fwd(query, key, value, bias, mask, - scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): - output, _ = _dot_product_attention_fwd_p_wrapper.bind( - query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - is_causal_mask=is_causal_mask) + for i in range(jax.device_count()): + compute_cap = cuda_versions.cuda_compute_capability(i) + if compute_cap not in cc: + raise RuntimeError("Require compute capability in " + str(cc)) + +def _dot_product_attention_fwd( + query, key, value, bias, mask, q_seqlen, kv_seqlen, scale, seed, + dropout_rate, variadic_args, mask_type, layout, cudnn_version): + # check if flash attention is supported for this attention pattern + check_is_flash_attention( + query, key, layout, cudnn_version, bias is not None, False) + outputs = _dot_product_attention_fwd_p_wrapper.bind( + query, key, value, bias, mask, q_seqlen, kv_seqlen, scale=scale, + seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, + mask_type=mask_type, layout=layout, 
is_training=False) + output = outputs[0] return output -def _dot_product_attention_fwd_rule(query, key, value, bias, mask, - scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): - output, activation = _dot_product_attention_fwd_p_wrapper.bind( - query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - is_causal_mask=is_causal_mask) - res = (query, key, value, bias, mask, activation, output) - return output, res - -def _dot_product_attention_bwd_rule(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, res, grad_output): - query, key, value, bias, mask, activation, fwd_output = res - grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind( - query, key, value, bias, mask, activation, fwd_output, grad_output, - scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - is_causal_mask=is_causal_mask) - grads = (grad_query, grad_key, grad_value, None, None) +def _dot_product_attention_fwd_rule( + query, key, value, bias, mask, q_seqlen, kv_seqlen, scale, seed, + dropout_rate, variadic_args, mask_type, layout, cudnn_version): + # check if flash attention is supported for this attention pattern + check_is_flash_attention( + query, key, layout, cudnn_version, bias is not None, True) + outputs = _dot_product_attention_fwd_p_wrapper.bind( + query, key, value, bias, mask, q_seqlen, kv_seqlen, scale=scale, + seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, + mask_type=mask_type, layout=layout, is_training=True) + res = (query, key, value, bias, mask, q_seqlen, kv_seqlen, + outputs[1], outputs[0]) + return outputs[0], res + +def _dot_product_attention_bwd_rule( + scale, seed, dropout_rate, variadic_args, mask_type, layout, is_training, + res, grad_output): + (query, key, value, bias, mask, q_seqlen, kv_seqlen, activation, + fwd_output) = res + grads = _dot_product_attention_bwd_p_wrapper.bind( + query, key, value, bias, mask, q_seqlen, kv_seqlen, activation, + fwd_output, grad_output, scale=scale, seed=seed, + dropout_rate=dropout_rate, variadic_args=variadic_args, + mask_type=mask_type, layout=layout, + ) + grads = (*grads,) + (None,) * (7 - len(grads)) return grads -def _dot_product_attention_fwd_impl(query, key, value, bias, mask, - scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): +def _dot_product_attention_fwd_impl( + query, key, value, bias, mask, q_seqlen, kv_seqlen, scale, seed, + dropout_rate, variadic_args, mask_type, layout, is_training): # args: {Q, K, V, mask*, bias*} - output, activation = _dot_product_attention_fwd_p.bind( - query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - is_causal_mask=is_causal_mask) - return output, activation - -def _dot_product_attention_bwd_impl(query, key, value, bias, mask, activation, fwd_output, grad_output, - scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): - grad_query, grad_key, grad_value = _dot_product_attention_bwd_p.bind( - query, key, value, bias, mask, activation, fwd_output, grad_output, - scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - is_causal_mask=is_causal_mask) - grads = (grad_query, grad_key, grad_value) + outputs = _dot_product_attention_fwd_p.bind( + query, 
key, value, bias, mask, q_seqlen, kv_seqlen, scale=scale, + seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, + mask_type=mask_type, layout=layout, is_training=is_training) + return outputs + +def _dot_product_attention_bwd_impl( + query, key, value, bias, mask, q_seqlen, kv_seqlen, activation, fwd_output, + grad_output, scale, seed, dropout_rate, variadic_args, mask_type, layout): + grads = _dot_product_attention_bwd_p.bind( + query, key, value, bias, mask, q_seqlen, kv_seqlen, activation, + fwd_output, grad_output, scale=scale, seed=seed, + dropout_rate=dropout_rate, variadic_args=variadic_args, + mask_type=mask_type, layout=layout) return grads -def _dot_product_attention_fwd_abstract(query, key, value, bias, mask, - *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): +def _dot_product_attention_fwd_abstract( + query, key, value, bias, mask, q_seqlen, kv_seqlen, *, scale, seed, + dropout_rate, variadic_args, mask_type, layout, is_training): query_dtype = dtypes.canonicalize_dtype(query.dtype) - batch, q_seq_len, num_heads, head_dim = query.shape - _, kv_seq_len, _, _ = key.shape - output_shape = (batch, q_seq_len, num_heads, head_dim) - activation_shape = (batch, num_heads, q_seq_len, kv_seq_len) - softmax_stat_shape = (batch, num_heads, q_seq_len) - if q_seq_len > 512: - # is flash attention + if layout == AttentionLayout.BNTH.value: + B, N, T, _ = query.shape + _, _, S, _ = key.shape + else: + B, T, N, _ = query.shape + _, S, _, _ = key.shape + output_shape = query.shape + softmax_stat_shape = (B, N, T) + + if is_training: return ( core.ShapedArray(output_shape, query_dtype), # output core.ShapedArray(softmax_stat_shape, jnp.float32), # softmax_stat ) - return ( - core.ShapedArray(output_shape, query_dtype), # output - core.ShapedArray(activation_shape, query_dtype), # activation - ) + else: + return ( + core.ShapedArray(output_shape, query_dtype), # output + ) -def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activation, fwd_output, grad_output, - *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): +def _dot_product_attention_bwd_abstract( + query, key, value, bias, mask, q_seqlen, kv_seqlen, activation, fwd_output, + grad_output, *, scale, seed, dropout_rate, variadic_args, mask_type, + layout): query_dtype = dtypes.canonicalize_dtype(query.dtype) key_dtype = dtypes.canonicalize_dtype(key.dtype) value_dtype = dtypes.canonicalize_dtype(value.dtype) - return ( - core.ShapedArray( - query.shape, query_dtype - ), # grad query - core.ShapedArray( - key.shape, key_dtype - ), # grad key - core.ShapedArray( - value.shape, value_dtype - ), # part value - ) + _, _, has_dbias = variadic_args + if has_dbias: + # cuDNN supports bias for this case + bias_dtype = dtypes.canonicalize_dtype(bias.dtype) + return ( + core.ShapedArray( + query.shape, query_dtype + ), # grad query + core.ShapedArray( + key.shape, key_dtype + ), # grad key + core.ShapedArray( + value.shape, value_dtype + ), # grad value + core.ShapedArray( + bias.shape, bias_dtype + ), # grad bias + ) + else: + return ( + core.ShapedArray( + query.shape, query_dtype + ), # grad query + core.ShapedArray( + key.shape, key_dtype + ), # grad key + core.ShapedArray( + value.shape, value_dtype + ), # grad value + ) -def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask, - scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): +def _dot_product_attention_fwd_cuda_lowering( + ctx, query, key, value, 
bias, mask, q_seqlen, kv_seqlen, scale, seed, + dropout_rate, variadic_args, mask_type, layout, is_training): query_type = ir.RankedTensorType(query.type) query_shape = query_type.shape key_type = ir.RankedTensorType(key.type) key_shape = key_type.shape - value_type = ir.RankedTensorType(value.type) - value_shape = value_type.shape - - batch, q_seq_len, num_heads, head_dim = query_shape - _, kv_seq_len, _, _ = key_shape - - output_shape = (batch, num_heads, q_seq_len, head_dim) - output_layout = (3, 1, 2, 0) - output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) - activation_shape = (batch, num_heads, q_seq_len, kv_seq_len) - softmax_stat_shape = (batch, num_heads, q_seq_len) - scratch_shape = (0,) - scratch_type = ir.IntegerType.get_unsigned(8) - # get backend config - backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, False) - # {Q, K, V, mask*, bias*} - # {output, scratch, activation*} + + if layout == AttentionLayout.BNTH.value: + B, N, T, H = query_shape + _, _, S, _ = key_shape + output_layout = (3, 2, 1, 0) + output_transpose_perm = mlir.dense_int_array((0, 1, 2, 3)) + else: + B, T, N, H = query_shape + _, S, _, _ = key_shape + output_layout = (3, 1, 2, 0) + output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) + + output_shape = (B, N, T, H) + softmax_stat_shape = (B, N, T) + workspace_shape = (0,) + workspace_type = ir.IntegerType.get_unsigned(8) + backend_config = create_dot_product_attention_backend_config( + B, N, T, S, query_type.element_type, scale, seed, dropout_rate, + mask_type, layout, is_bwd=False, + ) + # {Q, K, V, mask*, bias*, q_seqlen*, kv_seqlen*} + # {output, activation*, workspace} has_dropout = dropout_rate > 0 - has_bias, has_mask = variadic_args + has_bias, has_mask, _ = variadic_args operands = [query, key, value] if has_mask: operands.append(mask) if has_bias: operands.append(bias) - custom_call_name = get_custom_call_name(has_bias, has_mask, has_dropout, False) + if has_padding(mask_type): + operands.append(q_seqlen) + operands.append(kv_seqlen) + custom_call_name = get_custom_call_name(has_bias, has_dropout, False) # create output types and layouts - if is_flash_attention: + if is_training: result_types = [ ir.RankedTensorType.get(output_shape, query_type.element_type), - ir.RankedTensorType.get(scratch_shape, scratch_type), ir.RankedTensorType.get(softmax_stat_shape, ir.F32Type.get()), + ir.RankedTensorType.get(workspace_shape, workspace_type), ] - result_layouts = [output_layout] + default_layouts(scratch_shape, softmax_stat_shape) + result_layouts = [output_layout] + default_layouts(softmax_stat_shape, workspace_shape) else: result_types = [ ir.RankedTensorType.get(output_shape, query_type.element_type), - ir.RankedTensorType.get(scratch_shape, scratch_type), - ir.RankedTensorType.get(activation_shape, query_type.element_type), + ir.RankedTensorType.get(workspace_shape, workspace_type) ] - result_layouts = [output_layout] + default_layouts(scratch_shape, activation_shape) + result_layouts = [output_layout] + default_layouts(workspace_shape) # create custom call here out = mlir.custom_call( custom_call_name, result_types=result_types, operands=operands, backend_config=backend_config, - operand_layouts=default_layouts(*[ir.RankedTensorType(operand.type).shape for operand in operands]), + operand_layouts=default_layouts( + *[ir.RankedTensorType(operand.type).shape for operand in operands]), 
result_layouts=result_layouts, ) - # drop scratch memory - # output should be (batch, q_seq_len, num_heads, head_dim) instead of (batch, num_heads, q_seq_len, head_dim) - return [hlo.transpose(out.results[0], output_transpose_perm), out.results[2]] + # drop workspace memory + # output should be (B, T, N, H) instead of (B, N, T, H) + if is_training: + return [hlo.transpose(out.results[0], output_transpose_perm), out.results[1]] + else: + return [hlo.transpose(out.results[0], output_transpose_perm)] -def _dot_product_attention_bwd_cuda_lowering(ctx, query, key, value, bias, mask, activation, fwd_output, grad_output, - scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): +def _dot_product_attention_bwd_cuda_lowering( + ctx, query, key, value, bias, mask, q_seqlen, kv_seqlen, activation, + fwd_output, grad_output, scale, seed, dropout_rate, variadic_args, + mask_type, layout): query_type = ir.RankedTensorType(query.type) query_shape = query_type.shape key_type = ir.RankedTensorType(key.type) key_shape = key_type.shape value_type = ir.RankedTensorType(value.type) - value_shape = value_type.shape - activation_type = ir.RankedTensorType(activation.type) - activation_shape = activation_type.shape - grad_output_type = ir.RankedTensorType(grad_output.type) - grad_output_shape = grad_output_type.shape - - batch, q_seq_len, num_heads, head_dim = query_shape - _, kv_seq_len, _, _ = key_shape - scratch_shape = (0,) - scratch_type = ir.IntegerType.get_unsigned(8) - - grad_query_shape = (batch, num_heads, q_seq_len, head_dim) - grad_key_shape = (batch, num_heads, kv_seq_len, head_dim) - grad_value_shape = (batch, num_heads, kv_seq_len, head_dim) - softmax_sum_shape = (batch, num_heads, q_seq_len) - grad_layout = (3, 1, 2, 0) - grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) - backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, True) - # {Q, K, V, activation, dO, mask*, bias*, O*} - # {dQ, dK, dV, d_S*, softmax_sum*, d_Q_accum*, scratch, dbias*} + + if layout == AttentionLayout.BNTH.value: + B, q_N, T, H = query_shape + _, k_N, S, _ = key_shape + grad_layout = (3, 2, 1, 0) + grad_transpose_perm = mlir.dense_int_array((0, 1, 2, 3)) + else: + B, T, q_N, H = query_shape + _, S, k_N, _ = key_shape + grad_layout = (3, 1, 2, 0) + grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) + + workspace_shape = (0,) + workspace_type = ir.IntegerType.get_unsigned(8) + + grad_query_shape = (B, q_N, T, H) + grad_key_shape = (B, k_N, S, H) + grad_value_shape = (B, k_N, S, H) + backend_config = create_dot_product_attention_backend_config( + B, q_N, T, S, query_type.element_type, scale, seed, dropout_rate, + mask_type, layout, is_bwd=True, + ) + # {Q, K, V, activation, dO, mask*, bias*, O, q_seqlen*, kv_seqlen*} + # {dQ, dK, dV, dbias*, workspace} has_dropout = dropout_rate > 0 - has_bias, has_mask = variadic_args + has_bias, has_mask, has_dbias = variadic_args # create operands operands = [query, key, value, activation, grad_output] if has_mask: operands.append(mask) - if has_bias and is_flash_attention: + if has_bias: # flash attention requires bias in the bwd for remat operands.append(bias) - if is_flash_attention: - operands.append(fwd_output) + operands.append(fwd_output) + if has_padding(mask_type): + operands.append(q_seqlen) + operands.append(kv_seqlen) # get custom call name - custom_call_name = get_custom_call_name(has_bias, has_mask, 
has_dropout, True) + custom_call_name = get_custom_call_name(has_bias, has_dropout, True) # create output types and layouts - if is_flash_attention: - result_types = [ - ir.RankedTensorType.get(grad_query_shape, query_type.element_type), # grad query - ir.RankedTensorType.get(grad_key_shape, key_type.element_type), # grad key - ir.RankedTensorType.get(grad_value_shape, value_type.element_type), # grad value - ir.RankedTensorType.get(softmax_sum_shape, ir.F32Type.get()), # softmax_sum - ir.RankedTensorType.get(grad_query_shape, ir.F32Type.get()), # d_Q_accum - ir.RankedTensorType.get(scratch_shape, scratch_type), # scratch - ] - result_layouts = [grad_layout, grad_layout, grad_layout] + default_layouts(softmax_sum_shape, grad_query_shape, scratch_shape) - else: - result_types = [ - ir.RankedTensorType.get(grad_query_shape, query_type.element_type), # grad query - ir.RankedTensorType.get(grad_key_shape, key_type.element_type), # grad key - ir.RankedTensorType.get(grad_value_shape, value_type.element_type), # grad value - ir.RankedTensorType.get(activation_shape, activation_type.element_type), # dS - ir.RankedTensorType.get(scratch_shape, scratch_type), # scratch - ] - result_layouts = [grad_layout, grad_layout, grad_layout] + default_layouts(activation_shape, scratch_shape) + # grad_query, grad_key, grad_value + result_types = [ + ir.RankedTensorType.get(grad_query_shape, query_type.element_type), + ir.RankedTensorType.get(grad_key_shape, key_type.element_type), + ir.RankedTensorType.get(grad_value_shape, value_type.element_type), + ] + result_layouts = [grad_layout, grad_layout, grad_layout] + bias_type = ir.RankedTensorType(bias.type) + bias_shape = bias_type.shape + if has_dbias: + # cuDNN supports bias for this case + result_types.append( + ir.RankedTensorType.get(bias_shape, bias_type.element_type)) + result_layouts = result_layouts + default_layouts(bias_shape) + # workspace + result_types.append(ir.RankedTensorType.get(workspace_shape, workspace_type)) + result_layouts = result_layouts + default_layouts(workspace_shape) out = mlir.custom_call( custom_call_name, result_types=result_types, operands=operands, backend_config=backend_config, - operand_layouts=default_layouts(*[ir.RankedTensorType(operand.type).shape for operand in operands]), + operand_layouts=default_layouts( + *[ir.RankedTensorType(operand.type).shape for operand in operands]), result_layouts=result_layouts, ) - # Only keep dQ, dK and dV here - return [hlo.transpose(out.results[0], grad_transpose_perm), + dqkv = (hlo.transpose(out.results[0], grad_transpose_perm), hlo.transpose(out.results[1], grad_transpose_perm), - hlo.transpose(out.results[2], grad_transpose_perm)] + hlo.transpose(out.results[2], grad_transpose_perm)) + # Only keep dQ, dK, dV and dBias here + if has_dbias: + return dqkv + (out.results[3],) + else: + return dqkv # batcher def _check_valid_batch_dims(bdims): for dim in bdims: if dim not in [0, None]: - raise NotImplementedError("Currently only support batch_dim in [0, None], " \ - f"but got {dim=}") + raise NotImplementedError( + f"Currently only support batch_dim in [0, None], but got {dim=}") -def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): +def _dot_product_attention_fwd_batcher( + batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, + mask_type, layout, is_training): _check_valid_batch_dims(batch_dims) - query, key, value, bias, mask = batched_args + query, key, value, bias, mask, 
q_seqlen, kv_seqlen = batched_args query_bdim = batch_dims[0] - out_bdims = query_bdim, query_bdim + if is_training: + out_bdims = query_bdim, query_bdim + else: + out_bdims = (query_bdim,) - *batch_tuple, q_seq_len, num_heads, head_dim = query.shape - *_, kv_seq_len, _, _ = key.shape - new_batch = reduce(operator.mul, batch_tuple) - has_bias, has_mask = variadic_args + if layout == AttentionLayout.BNTH.value: + *Bs, N, T, _ = query.shape + *_, _, S, _ = key.shape + else: + *Bs, T, N, _ = query.shape + *_, S, _, _ = key.shape + B = reduce(operator.mul, Bs) + has_bias, has_mask, _ = variadic_args # reshape to 4D shape - query = jnp.reshape(query, (new_batch, q_seq_len, num_heads, head_dim)) - key = jnp.reshape(key, (new_batch, kv_seq_len, num_heads, head_dim)) - value = jnp.reshape(value, (new_batch, kv_seq_len, num_heads, head_dim)) + query = jnp.reshape(query, (B,) + query.shape[-3:]) + key = jnp.reshape(key, (B,) + key.shape[-3:]) + value = jnp.reshape(value, (B,) + key.shape[-3:]) if has_bias: - bias = jnp.reshape(bias, (new_batch, num_heads, q_seq_len, kv_seq_len)) + bias = jnp.reshape(bias, (B, N, T, S)) if has_mask: - mask = jnp.reshape(mask, (new_batch, num_heads, q_seq_len, kv_seq_len)) + mask = jnp.reshape(mask, (B, N, T, S)) + if has_padding(mask_type): + q_seqlen = jnp.reshape(q_seqlen, (B, )) + kv_seqlen = jnp.reshape(kv_seqlen, (B, )) - output, activation = _dot_product_attention_fwd_p_wrapper.bind( - query, key, value, bias, mask, - scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - is_causal_mask=is_causal_mask) + outputs = _dot_product_attention_fwd_p_wrapper.bind( + query, key, value, bias, mask, q_seqlen, kv_seqlen, scale=scale, + seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, + mask_type=mask_type, layout=layout, is_training=is_training) # reshape to original shape - output = jnp.reshape(output, (*batch_tuple, q_seq_len, num_heads, head_dim)) - if is_flash_attention: - activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len)) + output = outputs[0] + output = jnp.reshape(output, query.shape) + if is_training: + activation = outputs[1] + activation = jnp.reshape(activation, (*Bs, N, T)) + return (output, activation), out_bdims else: - activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len, kv_seq_len)) - return (output, activation), out_bdims + return (output,), out_bdims -def _dot_product_attention_bwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): +def _dot_product_attention_bwd_batcher( + batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, + mask_type, layout): _check_valid_batch_dims(batch_dims) - query, key, value, bias, mask, activation, fwd_output, grad_output = batched_args + query, key, value, bias, mask, q_seqlen, \ + kv_seqlen, activation, fwd_output, grad_output = batched_args query_bdim = batch_dims[0] out_bdims = query_bdim, query_bdim, query_bdim - *batch_tuple, q_seq_len, num_heads, head_dim = query.shape - *_, kv_seq_len, _, _ = key.shape - new_batch = reduce(operator.mul, batch_tuple) - has_bias, has_mask = variadic_args + if layout == AttentionLayout.BNTH.value: + *Bs, N, T, _ = query.shape + *_, _, S, _ = key.shape + else: + *Bs, T, N, _ = query.shape + *_, S, _, _ = key.shape + B = reduce(operator.mul, Bs) + has_bias, has_mask, has_dbias = variadic_args # reshape to 4D shape - query = jnp.reshape(query, (new_batch, q_seq_len, num_heads, 
head_dim)) - key = jnp.reshape(key, (new_batch, kv_seq_len, num_heads, head_dim)) - value = jnp.reshape(value, (new_batch, kv_seq_len, num_heads, head_dim)) + query = jnp.reshape(query, (B,) + query.shape[-3:]) + key = jnp.reshape(key, (B,) + key.shape[-3:]) + value = jnp.reshape(value, (B,) + key.shape[-3:]) if has_bias: - bias = jnp.reshape(bias, (new_batch, num_heads, q_seq_len, kv_seq_len)) + bias = jnp.reshape(bias, (B, N, T, S)) if has_mask: - mask = jnp.reshape(mask, (new_batch, num_heads, q_seq_len, kv_seq_len)) - if is_flash_attention: - activation = jnp.reshape(activation, (new_batch, num_heads, q_seq_len)) - else: - activation = jnp.reshape(activation, (new_batch, num_heads, q_seq_len, kv_seq_len)) - fwd_output = jnp.reshape(fwd_output, (new_batch, q_seq_len, num_heads, head_dim)) - grad_output = jnp.reshape(grad_output, (new_batch, q_seq_len, num_heads, head_dim)) - - grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind( - query, key, value, bias, - mask, activation, fwd_output, grad_output, - scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - is_causal_mask=is_causal_mask) + mask = jnp.reshape(mask, (B, N, T, S)) + if has_padding(mask_type): + q_seqlen = jnp.reshape(q_seqlen, (B, )) + kv_seqlen = jnp.reshape(kv_seqlen, (B, )) + + activation = jnp.reshape(activation, (B, N, T)) + fwd_output = jnp.reshape(fwd_output, (B,) + query.shape[-3:]) + grad_output = jnp.reshape(grad_output, (B,) + query.shape[-3:]) + + grads = _dot_product_attention_bwd_p_wrapper.bind( + query, key, value, bias, mask, q_seqlen, kv_seqlen, activation, + fwd_output, grad_output, scale=scale, seed=seed, + dropout_rate=dropout_rate, variadic_args=variadic_args, + mask_type=mask_type, layout=layout, + ) + grad_query, grad_key, grad_value = grads[:3] # reshape to original shape - grad_query = jnp.reshape(grad_query, (*batch_tuple, q_seq_len, num_heads, head_dim)) - grad_key = jnp.reshape(grad_key, (*batch_tuple, kv_seq_len, num_heads, head_dim)) - grad_value = jnp.reshape(grad_value, (*batch_tuple, kv_seq_len, num_heads, head_dim)) - grads = (grad_query, grad_key, grad_value) + grad_query = jnp.reshape(grad_query, query.shape) + grad_key = jnp.reshape(grad_key, key.shape) + grad_value = jnp.reshape(grad_value, value.shape) + if has_dbias: + grad_bias = grads[3] + grad_bias = jnp.reshape(grad_bias, bias.shape) + return grads + (grad_bias,), out_bdims + (query_bdim,) return grads, out_bdims # custom partitioning @@ -532,58 +698,75 @@ def _get_padded_spec(arg_info): assert len(spec) <= ndim return spec + (None,) * (ndim - len(spec)) -def _check_qkv_bias_mask_spec(query_spec, key_spec, value_spec, bias_spec, mask_spec): +def _check_qkv_bias_mask_spec( + query_spec, key_spec, value_spec, bias_spec, mask_spec): # check qkv spec if not query_spec == key_spec == value_spec: raise ValueError("Query, key and value should have same sharding.") *batch_spec, q_seq_spec, num_head_spec, head_spec = query_spec - if q_seq_spec != None: + if q_seq_spec is not None: raise ValueError("Sharding on sequence dim is not allowed.") - if head_spec != None: + if head_spec is not None: raise ValueError("Sharding on head dim is not allowed.") # check bias and mask spec if bias_spec: *bias_batch_spec, bias_num_head_spec, bias_q_seq_spec, bias_kv_seq_spec = bias_spec if bias_batch_spec != batch_spec or bias_num_head_spec != num_head_spec: - raise ValueError("Query and bias should have same sharding on batch and num_head dim.") - if 
bias_q_seq_spec != None or bias_kv_seq_spec != None: + raise ValueError( + "Query and bias should have same sharding on batch and num_head dim.") + if bias_q_seq_spec is not None or bias_kv_seq_spec is not None: raise ValueError("Sharding on bias sequence dim is not allowed.") if mask_spec: *mask_batch_spec, mask_num_head_spec, mask_q_seq_spec, mask_kv_seq_spec = mask_spec if mask_batch_spec != batch_spec or mask_num_head_spec != num_head_spec: - raise ValueError("Query and mask should have same sharding on batch and num_head dim.") - if mask_q_seq_spec != None or mask_kv_seq_spec != None: + raise ValueError( + "Query and mask should have same sharding on batch and num_head dim.") + if mask_q_seq_spec is not None or mask_kv_seq_spec is not None: raise ValueError("Sharding on mask sequence dim is not allowed.") # fwd custom partition -def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args): +def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training): # only sharding on batch and num_head dim is allowed # (*batch, q_seq, num_head, head) query_spec = _get_padded_spec(arg_shapes[0]) # (*batch, kv_seq, num_head, head) key_spec = _get_padded_spec(arg_shapes[1]) value_spec = _get_padded_spec(arg_shapes[2]) - has_bias, has_mask = variadic_args + has_bias, has_mask, _ = variadic_args bias_spec = _get_padded_spec(arg_shapes[3]) if has_bias else None mask_spec = _get_padded_spec(arg_shapes[4]) if has_mask else None - _check_qkv_bias_mask_spec(query_spec, key_spec, value_spec, bias_spec, mask_spec) + _check_qkv_bias_mask_spec( + query_spec, key_spec, value_spec, bias_spec, mask_spec) # keep out sharding same as query sharding since they have same shape out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec)) - # activation sharding - *batch_spec, q_seq_spec, num_head_spec, head_spec = query_spec - activation_sharding = NamedSharding(mesh, PartitionSpec(*batch_spec, num_head_spec, q_seq_spec, None)) - return (out_sharding, activation_sharding) - -_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10)) -def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape): - return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args) - -def _dot_product_attention_fwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape): + if is_training: + # activation sharding + *batch_spec, q_seq_spec, num_head_spec, _ = query_spec + activation_sharding = NamedSharding( + mesh, PartitionSpec(*batch_spec, num_head_spec, q_seq_spec, None)) + return [out_sharding, activation_sharding] + return [out_sharding] + +_dot_product_attention_fwd_lower = custom_partitioning( + _dot_product_attention_fwd_impl, static_argnums=(7, 8, 9, 10, 11, 12, 13)) + +def _dot_product_attention_fwd_infer_sharding_from_operands( + scale, seed, dropout_rate, variadic_args, mask_type, layout, is_training, + mesh, arg_shapes, result_shape): + return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training) + +def _dot_product_attention_fwd_partition( + scale, seed, dropout_rate, variadic_args, mask_type, layout, is_training, + mesh, arg_shapes, result_shape): # args sharding arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes]) - out_shardings = _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args) - impl = partial(_dot_product_attention_fwd_impl, 
scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask) + out_shardings = _infer_fwd_output_sharding( + mesh, arg_shapes, variadic_args, is_training) + impl = partial( + _dot_product_attention_fwd_impl, scale=scale, seed=seed, + dropout_rate=dropout_rate, variadic_args=variadic_args, + mask_type=mask_type, + layout=layout, is_training=is_training) return mesh, impl, out_shardings, arg_shardings # bwd custom partition @@ -593,34 +776,53 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args): # (*batch, kv_seq, num_head, head) key_spec = _get_padded_spec(arg_shapes[1]) value_spec = _get_padded_spec(arg_shapes[2]) - has_bias, has_mask = variadic_args + has_bias, has_mask, has_dbias = variadic_args bias_spec = _get_padded_spec(arg_shapes[3]) if has_bias else None mask_spec = _get_padded_spec(arg_shapes[4]) if has_mask else None - _check_qkv_bias_mask_spec(query_spec, key_spec, value_spec, bias_spec, mask_spec) + _check_qkv_bias_mask_spec( + query_spec, key_spec, value_spec, bias_spec, mask_spec) # keep grad query sharding same as query sharding grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec)) grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec)) grad_value_sharding = NamedSharding(mesh, PartitionSpec(*key_spec)) - out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding) + out_shardings = [grad_query_sharding, grad_key_sharding, grad_value_sharding] + if has_dbias: + grad_bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + out_shardings = out_shardings + [grad_bias_sharding] return out_shardings -_dot_product_attention_bwd_lower = custom_partitioning(_dot_product_attention_bwd_impl, static_argnums=(8,9,10,11,12,13)) -def _dot_product_attention_bwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape): +_dot_product_attention_bwd_lower = custom_partitioning( + _dot_product_attention_bwd_impl, static_argnums=(10, 11, 12, 13, 14, 15) +) + +def _dot_product_attention_bwd_infer_sharding_from_operands( + scale, seed, dropout_rate, variadic_args, mask_type, layout, mesh, + arg_shapes, result_shape): return _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args) -def _dot_product_attention_bwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape): +def _dot_product_attention_bwd_partition( + scale, seed, dropout_rate, variadic_args, mask_type, layout, mesh, + arg_shapes, result_shape): out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args) # args sharding arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes]) - impl = partial(_dot_product_attention_bwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask) + impl = partial( + _dot_product_attention_bwd_impl, scale=scale, seed=seed, + dropout_rate=dropout_rate, variadic_args=variadic_args, + mask_type=mask_type, + layout=layout, + ) return mesh, impl, out_shardings, arg_shardings # Create dot_product_attention_fwd_p for forward operation. 
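As a point of reference for the partitioning rules enforced above: _check_qkv_bias_mask_spec only permits sharding on the batch and num_heads dimensions of Q/K/V, and rejects sharding on the sequence dim ("Sharding on sequence dim is not allowed.") and on the per-head dim. A minimal illustrative sketch (not part of the patch; the mesh axis name and shapes are invented, and the batch size is assumed to divide evenly across the available devices):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("data",))          # 1-D device mesh, axis name is made up
query = jnp.ones((8, 1024, 16, 64), dtype=jnp.bfloat16)  # BTNH
# Sharding only the batch dimension is accepted by the checks above.
q_sharded = jax.device_put(
    query, NamedSharding(mesh, PartitionSpec("data", None, None, None)))
# PartitionSpec(None, "data", None, None) would shard the sequence dimension
# and be rejected by _check_qkv_bias_mask_spec.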
_dot_product_attention_fwd_p = core.Primitive("dot_product_attention_fwd") _dot_product_attention_fwd_p.multiple_results = True -_dot_product_attention_fwd_p.def_impl(partial(xla.apply_primitive, _dot_product_attention_fwd_p)) -_dot_product_attention_fwd_p.def_abstract_eval(_dot_product_attention_fwd_abstract) +_dot_product_attention_fwd_p.def_impl( + partial(xla.apply_primitive, _dot_product_attention_fwd_p) +) +_dot_product_attention_fwd_p.def_abstract_eval( + _dot_product_attention_fwd_abstract +) mlir.register_lowering( _dot_product_attention_fwd_p, @@ -628,16 +830,24 @@ def _dot_product_attention_bwd_partition(scale, seed, dropout_rate, variadic_arg platform="cuda", ) -_dot_product_attention_fwd_p_wrapper = core.Primitive("dot_product_attention_fwd_wrapper") +_dot_product_attention_fwd_p_wrapper = core.Primitive( + "dot_product_attention_fwd_wrapper" +) _dot_product_attention_fwd_p_wrapper.multiple_results = True _dot_product_attention_fwd_p_wrapper.def_impl(_dot_product_attention_fwd_impl) -_dot_product_attention_fwd_p_wrapper.def_abstract_eval(_dot_product_attention_fwd_abstract) +_dot_product_attention_fwd_p_wrapper.def_abstract_eval( + _dot_product_attention_fwd_abstract +) # Create dot_product_attention_bwd_p for backward operation. _dot_product_attention_bwd_p = core.Primitive("dot_product_attention_bwd") _dot_product_attention_bwd_p.multiple_results = True -_dot_product_attention_bwd_p.def_impl(partial(xla.apply_primitive, _dot_product_attention_bwd_p)) -_dot_product_attention_bwd_p.def_abstract_eval(_dot_product_attention_bwd_abstract) +_dot_product_attention_bwd_p.def_impl( + partial(xla.apply_primitive, _dot_product_attention_bwd_p) +) +_dot_product_attention_bwd_p.def_abstract_eval( + _dot_product_attention_bwd_abstract +) mlir.register_lowering( _dot_product_attention_bwd_p, @@ -645,14 +855,21 @@ def _dot_product_attention_bwd_partition(scale, seed, dropout_rate, variadic_arg platform="cuda", ) -_dot_product_attention_bwd_p_wrapper = core.Primitive("dot_product_attention_bwd_wrapper") +_dot_product_attention_bwd_p_wrapper = core.Primitive( + "dot_product_attention_bwd_wrapper" +) _dot_product_attention_bwd_p_wrapper.multiple_results = True _dot_product_attention_bwd_p_wrapper.def_impl(_dot_product_attention_bwd_impl) -_dot_product_attention_bwd_p_wrapper.def_abstract_eval(_dot_product_attention_bwd_abstract) - +_dot_product_attention_bwd_p_wrapper.def_abstract_eval( + _dot_product_attention_bwd_abstract +) -batching.primitive_batchers[_dot_product_attention_fwd_p_wrapper] = _dot_product_attention_fwd_batcher -batching.primitive_batchers[_dot_product_attention_bwd_p_wrapper] = _dot_product_attention_bwd_batcher +batching.primitive_batchers[ + _dot_product_attention_fwd_p_wrapper +] = _dot_product_attention_fwd_batcher +batching.primitive_batchers[ + _dot_product_attention_bwd_p_wrapper +] = _dot_product_attention_bwd_batcher _dot_product_attention_fwd_lower.def_partition( infer_sharding_from_operands=_dot_product_attention_fwd_infer_sharding_from_operands, @@ -668,27 +885,38 @@ def _dot_product_attention_bwd_partition(scale, seed, dropout_rate, variadic_arg mlir.register_lowering(_dot_product_attention_bwd_p_wrapper, mlir.lower_fun(_dot_product_attention_bwd_lower, multiple_results=True)) -dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p) -dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p_wrapper) -dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p) 
-dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p_wrapper)
+dispatch.prim_requires_devices_during_lowering.add(
+    _dot_product_attention_fwd_p
+)
+dispatch.prim_requires_devices_during_lowering.add(
+    _dot_product_attention_fwd_p_wrapper
+)
+dispatch.prim_requires_devices_during_lowering.add(
+    _dot_product_attention_bwd_p
+)
+dispatch.prim_requires_devices_during_lowering.add(
+    _dot_product_attention_bwd_p_wrapper
+)
-@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10))
+@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13))
 def _dot_product_attention(query: Array,
-                           key: Array,
-                           value: Array,
-                           bias: Array,
-                           mask: Array,
-                           scale: float,
-                           seed: int,
-                           dropout_rate: float,
-                           variadic_args: tuple[bool, ...],
-                           is_flash_attention: bool,
-                           is_causal_mask: bool):
+                           key: Array,
+                           value: Array,
+                           bias: Array,
+                           mask: Array,
+                           q_seqlen: Array,
+                           kv_seqlen: Array,
+                           scale: float,
+                           seed: int,
+                           dropout_rate: float,
+                           variadic_args: tuple[bool, ...],
+                           mask_type: bool,
+                           layout: int,
+                           cudnn_version: int):
   output = _dot_product_attention_fwd(
-      query, key, value, bias, mask,
-      scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
-      is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
+      query, key, value, bias, mask, q_seqlen, kv_seqlen, scale=scale,
+      seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
+      mask_type=mask_type, layout=layout, cudnn_version=cudnn_version)
   return output

 # _dot_product_attention_fwd must have the same func signature as _dot_product_attention
@@ -698,54 +926,78 @@ def _dot_product_attention(query: Array,
 def dot_product_attention(query: Array,
                           key: Array,
                           value: Array,
-                          bias: Optional[Array] = None,
-                          mask: Optional[Array] = None,
+                          bias: Array | None = None,
+                          mask: Array | None = None,
+                          q_seqlen: Array | None = None,
+                          kv_seqlen: Array | None = None,
                           *,
                           scale: float = 1.0,
-                          is_causal_mask: bool = False,
+                          mask_type: MaskType = MaskType.NO_MASK,
                           seed: int = 42,
-                          dropout_rate: float = 0.):
-  """Computes dot-product attention given query, key, and value.
-  This is the core function for applying attention based on
-  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
-  query and key and combines the values using the attention weights.
-  batch seq num_heads, head_dim // but all assume Q, K and V will have same
-  b q_seq num_heads head_dim -> Q
-  b kv_seq num_heads head_dim -> K
-  b kv_seq num_heads head_dim -> V
+                          dropout_rate: float = 0.,
+                          qkv_layout: str = "BTNH"):
+  """Computes dot-product attention given query (Q), key (K), and value (V).
+
+  This function serves as the core operation for applying attention
+  mechanisms as described in the paper [https://arxiv.org/abs/1706.03762].
+  Initially, it determines the attention weights by processing Q and K,
+  subsequently combining the outcomes using V. Throughout this function, we
+  utilize the following uppercase letters to represent specific parameters of
+  the arrays:
+
+    B = batch size
+    S = length of the key/value (source)
+    T = length of the query (target)
+    N = number of attention heads
+    H = dimensions of each attention head.
+
+  The supported layouts for Q, K, V are either BT(S)NH or BNT(S)H, and all
+  three must use the same layout. The output layout remains consistent with
+  Q, defaulting to BT(S)NH.
+
   Args:
-    query: queries for calculating attention with shape of `[batch, q_length,
-      num_heads, qk_depth_per_head]`.
-    key: keys for calculating attention with shape of `[batch, kv_length,
-      num_heads, qk_depth_per_head]`.
-    value: values to be used in attention with shape of `[batch, kv_length,
-      num_heads, v_depth_per_head]`.
-    bias: bias to be added to logits with shape of `[batch, num_heads,
-      q_length, kv_length]`.
-    mask: mask used mask out logits with shape of `[batch, num_heads,
-      q_length, kv_length]`.
-    scale: scale for the query.
-    dropout_rate: dropout rate.
+    query: Queries for attention calculation with a shape of BTNH or BNTH.
+    key: Keys for attention calculation with a shape of BSNH or BNSH.
+    value: Values to be used in attention with a shape of BSNH or BNSH.
+    bias: Bias to be added to logits with a shape of BNTS.
+    mask: Mask used to filter out logits with a shape of BNTS.
+    q_seqlen: Non-padded sequence length of Queries with a shape of B.
+    kv_seqlen: Non-padded sequence length of Keys and Values with a shape of B.
+    scale: Scale for the query.
+    dropout_rate: Dropout rate.
+    qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH,
+      BNSH.
+
   Returns:
-    Output of shape `[batch, q_length, num_heads, v_depth_per_head]`.
+    Output of the same shape as the query.
   """
-  # check if query, key and value layout meets cuDNN layout requirement
-  check_qkv_layout(query, key, value)
-  # check if flash attention is supported for this attention pattern
-  is_flash_attention, is_cross_attention = check_is_flash_attention(query, key)
-  # check if cuDNN is installed and if cuDNN version is sufficient
-  check_cudnn_version(is_flash_attention, is_cross_attention)
-  if mask is not None and is_causal_mask:
-    raise ValueError("can not apply a mask and generate a causal_mask at the same time.")
-  if not is_flash_attention and is_causal_mask:
-    raise ValueError("can only generate a causal_mask with flash attention.")
-  variadic_args = (bias is not None, mask is not None)
+  # check if cuDNN is installed
+  cudnn_version = check_cudnn_version()
+  # only support Ampere and Hopper for now
+  check_compute_capability((80, 90))
+  layout = _normalize_layout(qkv_layout)
+  if bias is not None:
+    # reshape bias to have 4D shape
+    bias = bias.reshape((1,) * (4 - len(bias.shape)) + bias.shape)
+  # check if input shapes and data types are compatible
+  check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout)
+  if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None):
+    raise ValueError("q_seqlen and kv_seqlen are required to generate a padding mask")
+  has_bias = bias is not None
+  has_mask = mask is not None
+  has_dbias = has_bias and \
+    should_export_dbias(bias.shape, query.shape, layout)  # type: ignore[union-attr]
+  variadic_args = (has_bias, has_mask, has_dbias)
   if bias is None:
     bias = jnp.zeros(0, dtype=query.dtype)
   if mask is None:
     mask = jnp.zeros(0, dtype=query.dtype)
+  if q_seqlen is None:
+    q_seqlen = jnp.zeros(0, dtype=query.dtype)
+  if kv_seqlen is None:
+    kv_seqlen = jnp.zeros(0, dtype=query.dtype)
   output = _dot_product_attention(
-      query, key, value, bias, mask,
-      scale, seed, dropout_rate, variadic_args,
-      is_flash_attention, is_causal_mask)
+      query, key, value, bias, mask, q_seqlen, kv_seqlen, scale, seed,
+      dropout_rate, variadic_args, mask_type, layout.value, cudnn_version
+  )
   return output

diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py
index fa74549c01cf..4d41849b75d3 100644
--- a/jax/_src/custom_batching.py
+++ b/jax/_src/custom_batching.py
@@ -14,9 +14,9 @@ from __future__ import annotations
+from collections.abc import Callable
 import functools
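For reference, a usage sketch of the public dot_product_attention wrapper defined in the hunk above (not part of the patch). The import path is an assumption based on the module this diff modifies, the shapes are arbitrary example values, and running it requires a GPU with compute capability 80/90 and a sufficiently new cuDNN, per the checks above:

import jax
import jax.numpy as jnp
# Assumed module path; adjust to wherever dot_product_attention lives in this patch.
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention

B, T, S, N, H = 2, 128, 128, 8, 64
q = jnp.ones((B, T, N, H), dtype=jnp.bfloat16)  # BTNH, the default qkv_layout
k = jnp.ones((B, S, N, H), dtype=jnp.bfloat16)
v = jnp.ones((B, S, N, H), dtype=jnp.bfloat16)

out = jax.jit(
    lambda q, k, v: dot_product_attention(q, k, v, scale=1.0 / H ** 0.5)
)(q, k, v)
assert out.shape == q.shape  # output keeps the query layout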
import operator -from typing import Callable from jax import lax from jax._src import api @@ -53,7 +53,7 @@ class custom_vmap: def __init__(self, fun: Callable): functools.update_wrapper(self, fun) - self.fun = fun # type: ignore[assignment] + self.fun = fun self.vmap_rule = None __getattr__ = custom_api_util.forward_attr diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 11740b878131..46d9fab00455 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -14,11 +14,11 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses from functools import update_wrapper, reduce, partial import inspect -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Generic, TypeVar from jax._src import config from jax._src import core @@ -40,10 +40,12 @@ from jax._src.interpreters import xla from jax._src.interpreters.batching import not_mapped from jax._src.lax import lax -from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_map, - treedef_is_leaf, treedef_tuple, - register_pytree_node_class, tree_leaves) -from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable +from jax._src.tree_util import ( + tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple, + register_pytree_node_class, tree_leaves, tree_flatten_with_path, keystr, + treedef_children) +from jax._src.util import (cache, safe_zip, safe_map, split_list, Unhashable, + unzip2) traceback_util.register_exclusion(__file__) @@ -178,7 +180,7 @@ def defjvp(self, Returns: None. - Example:: + Examples: @jax.custom_jvp def f(x, y): @@ -210,7 +212,7 @@ def defjvps(self, *jvps: Callable[..., ReturnValue] | None): Returns: None. - Example:: + Examples: @jax.custom_jvp def f(x, y): @@ -250,7 +252,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable static_args = [args[i] for i in self.nondiff_argnums] jvp = _add_args(lu.wrap_init(self.jvp), static_args) else: - f_, dyn_args = lu.wrap_init(self.fun), args # type: ignore + f_, dyn_args = lu.wrap_init(self.fun), args jvp = lu.wrap_init(self.jvp) args_flat, in_tree = tree_flatten(dyn_args) flat_fun, out_type1 = _flatten_fun_nokwargs(f_, in_tree) @@ -357,9 +359,9 @@ def bind(self, fun, jvp, *args, symbolic_zeros): fun, self, top_trace and top_trace.level, False) jvp, env_trace_todo2 = process_env_traces( jvp, self, top_trace and top_trace.level, True) - tracers = map(top_trace.full_raise, args) # type: ignore - outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers, # type: ignore - symbolic_zeros=symbolic_zeros) # type: ignore + tracers = map(top_trace.full_raise, args) + outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers, + symbolic_zeros=symbolic_zeros) _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2) return core.apply_todos(env_trace_todo, map(core.full_lower, outs)) @@ -565,7 +567,7 @@ def defvjp(self, Returns: None. 
- Example:: + Examples: @jax.custom_vjp def f(x, y): @@ -727,12 +729,15 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args): py_res = tree_unflatten(res_tree, res) py_cts_out = tree_unflatten(out_tree, cts_out) py_cts_in = yield (py_res, py_cts_out), {} + if isinstance(py_cts_in, list) and len(py_cts_in) == len(treedef_children(in_tree)): + py_cts_in = tuple(py_cts_in) # For each None in py_cts_in, indicating an argument for which the rule # produces no cotangent, we replace it with a pytree with the structure of the # corresponding subtree of in_tree and with leaves of a non-pytree sentinel # object, to be replaced with Nones in the final returned result. zero = object() # non-pytree sentinel to replace Nones in py_cts_in dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves) + keypaths, _ = unzip2(tree_flatten_with_path(dummy)[0]) cts_in_flat = [] def append(x, d): num_leaves = len(tree_flatten(d)[0]) @@ -747,18 +752,51 @@ def append(x, d): tree_map(append, py_cts_in, dummy, is_leaf=lambda x: x is None) except ValueError: _, in_tree2 = tree_flatten(py_cts_in) - msg = ("Custom VJP rule must produce an output with the same container " + msg = ("Custom VJP bwd rule must produce an output with the same container " "(pytree) structure as the args tuple of the primal function, " "and in particular must produce a tuple of length equal to the " - "number of arguments to the primal function, but got VJP output " + "number of arguments to the primal function, but got bwd output " "structure {} for primal input structure {}.") raise TypeError(msg.format(in_tree2, in_tree)) from None - # Ignore any None cotangents, and any corresponding to inputs for which the - # type doesn't equal the tangent type (i.e. float0s) - # TODO(mattjj): change this to check if tangent type represents 0dim vspace - yield [Zero(a.at_least_vspace()) if ct is zero or a != a.at_least_vspace() - else ct for a, ct in zip(in_avals, cts_in_flat)] - + results = [] + for kp, a, ct in zip(keypaths, in_avals, cts_in_flat): + if ct is zero or a != a.at_least_vspace(): + results.append(Zero(a.at_least_vspace())) + elif type(ct) is SymbolicZero: + if not core.typecompat(a.at_least_vspace(), a_ := ct.aval): + msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype " + "that does not match the corresponding input tangent shape/dtype: " + f"at output{keystr(kp)} the SymbolicZero had shape/dtype " + f"{a_.str_short()} while the " + f"corresponding input had shape/dtype {a.str_short()}. 
" + "Consider just returning a None here instead of a SymbolicZero " + "object.") + raise ValueError(msg) + results.append(Zero(ct.aval)) + else: + if (not core.typecompat(a.at_least_vspace(), a_ := core.get_aval(ct)) + and not (_temporary_dtype_exception(a, a_) or + _temporary_shape_exception(a, a_))): + msg = ("Custom VJP bwd rule must produce an output with the same " + "shape/dtypes as the args tuple of the primal function, but at " + f"output{keystr(kp)} the bwd rule produced an output of " + f"shape/dtype {raise_to_shaped(a_).str_short()} corresponding " + f"to an input of shape/dtype {a.str_short()}.") + raise ValueError(msg) + results.append(ct) + yield results + +# TODO(mattjj): remove both these exceptions to cotangent compatibility check +def _temporary_dtype_exception(a, a_) -> bool: + if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray): + return (a.shape == a_.shape and + (dtypes.issubdtype(a_.dtype, dtypes.extended) or + dtypes.issubdtype(a.dtype, dtypes.np.inexact))) + return False + +# TODO(mattjj): remove both these exceptions to cotangent compatibility check +def _temporary_shape_exception(a, a_) -> bool: + return config.custom_vjp_disable_shape_check.value class CustomVJPCallPrimitive(core.CallPrimitive): initial_style: core.Primitive @@ -770,7 +808,7 @@ def bind(self, fun, fwd, bwd, *args, out_trees, symbolic_zeros): fun, self, top_trace and top_trace.level, False) fwd, env_trace_todo2 = process_env_traces_fwd( fwd, top_trace and top_trace.level, out_trees) - tracers = map(top_trace.full_raise, args) # type: ignore + tracers = map(top_trace.full_raise, args) bwd_ = lambda *args: bwd(*args) outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers, out_trees=out_trees, diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 210d40632c37..a4de1b8cc46c 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -14,8 +14,9 @@ from __future__ import annotations +from collections.abc import Callable import functools -from typing import Any, Callable +from typing import Any from jax._src import ad_util from jax._src import api_util @@ -59,7 +60,7 @@ def transformation_with_aux( return fun.wrap(gen, gen_static_args, out_store), out_thunk flatten_fun_nokwargs = transformation_with_aux( - api_util.flatten_fun_nokwargs.args[0]) # type: ignore[has-type] + api_util.flatten_fun_nokwargs.args[0]) ### api @@ -71,7 +72,7 @@ class custom_transpose: def __init__(self, fun: Callable): functools.update_wrapper(self, fun) - self.fun = fun # type: ignore[assignment] + self.fun = fun __getattr__ = custom_api_util.forward_attr diff --git a/jax/_src/debugger/core.py b/jax/_src/debugger/core.py index 0b254d0d1639..f6b0a81baf92 100644 --- a/jax/_src/debugger/core.py +++ b/jax/_src/debugger/core.py @@ -161,7 +161,7 @@ def breakpoint(*, backend: str | None = None, filter_frames: bool = True, debugger and in the absence of other registered debuggers, falls back to the CLI debugger. filter_frames: Whether or not to filter out JAX-internal stack frames from - the traceback. Since some libraries, like Flax, also make user of JAX's + the traceback. Since some libraries, like Flax, also make use of JAX's stack frame filtering system, this option can also affect whether stack frames from libraries are filtered. 
num_frames: The number of frames above the current stack frame to make diff --git a/jax/_src/debugger/web_debugger.py b/jax/_src/debugger/web_debugger.py index d5c97d1903ad..443bfa676715 100644 --- a/jax/_src/debugger/web_debugger.py +++ b/jax/_src/debugger/web_debugger.py @@ -13,6 +13,9 @@ # limitations under the License. from __future__ import annotations +import atexit +import functools +import importlib.util import os from typing import Any import weakref @@ -20,16 +23,22 @@ from jax._src.debugger import cli_debugger from jax._src.debugger import core as debugger_core -web_pdb_version: tuple[int, ...] | None = None -try: + +@functools.cache +def _web_pdb_version() -> tuple[int, ...]: import web_pdb # pytype: disable=import-error - web_pdb_version = tuple(map(int, web_pdb.__version__.split("."))) - WEB_PDB_ENABLED = True -except: - WEB_PDB_ENABLED = False + return tuple(map(int, web_pdb.__version__.split("."))) + + +_web_consoles: dict[tuple[str, int], Any] = {} -_web_consoles: dict[tuple[str, int], web_pdb.WebConsole] = {} +@atexit.register +def _close_debuggers(): + for console in _web_consoles.values(): + console.close() + _web_consoles.clear() + class WebDebugger(cli_debugger.CliDebugger): """A web-based debugger.""" @@ -39,6 +48,7 @@ class WebDebugger(cli_debugger.CliDebugger): def __init__(self, frames: list[debugger_core.DebuggerFrame], thread_id, completekey: str = "tab", host: str = "", port: int = 5555): if (host, port) not in _web_consoles: + import web_pdb # pytype: disable=import-error _web_consoles[host, port] = web_pdb.WebConsole(host, port, self) # Clobber the debugger in the web console _web_console = _web_consoles[host, port] @@ -54,7 +64,7 @@ def get_current_frame_data(self): current_line = None if current_frame.offset is not None: current_line = current_frame.offset + 1 - if web_pdb_version and web_pdb_version < (1, 4, 4): + if _web_pdb_version() < (1, 4, 4): return { 'filename': filename, 'listing': '\n'.join(lines), @@ -74,15 +84,15 @@ def get_current_frame_data(self): def get_globals(self): current_frame = self.current_frame() - globals = "\n".join([f"{key} = {value}" for key, value in - sorted(current_frame.globals.items())]) - return globals + return "\n".join( + f"{key} = {value}" + for key, value in sorted(current_frame.globals.items())) def get_locals(self): current_frame = self.current_frame() - locals = "\n".join([f"{key} = {value}" for key, value in - sorted(current_frame.locals.items())]) - return locals + return "\n".join( + f"{key} = {value}" + for key, value in sorted(current_frame.locals.items())) def run(self): return self.cmdloop() @@ -91,5 +101,6 @@ def run_debugger(frames: list[debugger_core.DebuggerFrame], thread_id: int | None, **kwargs: Any): WebDebugger(frames, thread_id, **kwargs).run() -if WEB_PDB_ENABLED: + +if importlib.util.find_spec("web_pdb") is not None: debugger_core.register_debugger("web", run_debugger, -2) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 84fc677c9f69..7d8b3a914b6d 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -15,15 +15,18 @@ from __future__ import annotations -from collections.abc import Sequence +import importlib.util +from collections.abc import Callable, Sequence import functools +import logging import string import sys -from typing import Any, Callable, Union +from typing import Any, Union import weakref import numpy as np +import jax import jax.numpy as jnp from jax import lax @@ -44,19 +47,7 @@ from jax._src.sharding import Sharding from jax._src.sharding_impls import 
NamedSharding, parse_flatten_op_sharding -# pytype: disable=import-error -try: - import rich - import rich.align - import rich.box - import rich.console - import rich.padding - import rich.style - import rich.table - RICH_ENABLED = True -except: - RICH_ENABLED = False -# pytype: enable=import-error +logger = logging.getLogger(__name__) class DebugEffect(effects.Effect): __str__ = lambda self: "Debug" @@ -86,7 +77,22 @@ class OrderedDebugEffect(effects.Effect): def debug_callback_impl(*args, callback: Callable[..., Any], effect: DebugEffect): del effect - return callback(*args) + try: + cpu_device, *_ = jax.local_devices(backend="cpu") + except RuntimeError as e: + raise RuntimeError( + "jax.debug.callback failed to find a local CPU device to place the" + " inputs on. Make sure \"cpu\" is listed in --jax_platforms or the" + " JAX_PLATFORMS environment variable." + ) from e + args = jax.device_put(args, cpu_device) + with jax.default_device(cpu_device): + try: + callback(*args) + except BaseException: + logger.exception("jax.debug.callback failed") + raise + return () @debug_callback_p.def_effectful_abstract_eval def debug_callback_abstract_eval(*flat_avals, callback: Callable[..., Any], @@ -149,18 +155,19 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): sharding = None def _callback(*flat_args): - return tuple( - debug_callback_p.impl( - *flat_args, effect=effect, callback=callback, **params)) + debug_callback_p.impl( + *flat_args, effect=effect, callback=callback, **params) + return () if effects.ordered_effects.contains(effect): - token = ctx.tokens_in.get(effect)[0] + [token] = ctx.tokens_in.get(effect) result, token, _ = mlir.emit_python_callback( - ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True) + ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, + has_side_effect=True) ctx.set_tokens_out(mlir.TokenSet({effect: (token,)})) else: - result, token, _ = mlir.emit_python_callback( - ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, True, - sharding=sharding) + result, _, _ = mlir.emit_python_callback( + ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, + has_side_effect=True, sharding=sharding) return result mlir.register_lowering(debug_callback_p, debug_callback_lowering, platform="cpu") @@ -200,7 +207,7 @@ def _debug_callback_partial_eval_custom(saveable, unks_in, inst_in, eqn): pe.partial_eval_jaxpr_custom_rules[debug_callback_p] = ( _debug_callback_partial_eval_custom) -def debug_callback(callback: Callable[..., Any], *args: Any, +def debug_callback(callback: Callable[..., None], *args: Any, ordered: bool = False, **kwargs: Any) -> None: """Calls a stageable Python callback. @@ -219,7 +226,7 @@ def debug_callback(callback: Callable[..., Any], *args: Any, of the computation are duplicated or dropped. Args: - callback: A Python callable. Its return value will be ignored. + callback: A Python callable returning None. *args: The positional arguments to the callback. ordered: A keyword only argument used to indicate whether or not the staged out computation will enforce ordering of this callback w.r.t. 
@@ -244,7 +251,7 @@ def debug_callback(callback: Callable[..., Any], *args: Any, def _flat_callback(*flat_args): args, kwargs = tree_util.tree_unflatten(in_tree, flat_args) callback(*args, **kwargs) - return [] + return () debug_callback_p.bind(*flat_args, callback=_flat_callback, effect=effect) class _DebugPrintFormatChecker(string.Formatter): @@ -384,7 +391,7 @@ def _hlo_sharding_callback(hlo_sharding: xc.HloSharding): has_side_effect=ir.BoolAttr.get(True), api_version=mlir.i32_attr(1), called_computations=ir.ArrayAttr.get([]), - backend_config=ir.StringAttr.get(key), + backend_config=ir.StringAttr.get(key), # type: ignore[arg-type] operand_layouts=None, result_layouts=None) return [] @@ -440,8 +447,19 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *, min_width: int = 9, max_width: int = 80, color_map: ColorMap | None = None): """Visualizes a ``Sharding`` using ``rich``.""" - if not RICH_ENABLED: + if not importlib.util.find_spec("rich"): raise ValueError("`visualize_sharding` requires `rich` to be installed.") + + # These imports are local so that they don't affect JAX import times. + # pytype: disable=import-error + import rich.align + import rich.console + import rich.box + import rich.padding + import rich.style + import rich.table + # pytype: enable=import-error + if len(shape) > 2 or len(shape) < 1: raise ValueError( "`visualize_sharding` only works for shapes with 1 and 2 dimensions.") @@ -493,11 +511,11 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *, heights[chunk_idxs] = None widths[chunk_idxs] = horiz_size / shape[0] slices.setdefault(chunk_idxs, set()).add(dev.id) - num_rows = max([a[0] for a in slices.keys()]) + 1 + num_rows = max(a[0] for a in slices.keys()) + 1 if len(list(slices.keys())[0]) == 1: num_cols = 1 else: - num_cols = max([a[1] for a in slices.keys()]) + 1 + num_cols = max(a[1] for a in slices.keys()) + 1 color_iter = make_color_iter(color_map, num_rows, num_cols) table = rich.table.Table(show_header=False, show_lines=not use_color, diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 782ac0ca1ca4..5513b169cca9 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from types import ModuleType import warnings @@ -47,7 +48,7 @@ def deprecation_getattr(module, deprecations): def getattr(name): if name in deprecations: message, fn = deprecations[name] - if fn is None: + if fn is None: # Is the deprecation accelerated? raise AttributeError(message) warnings.warn(message, DeprecationWarning, stacklevel=2) return fn @@ -56,27 +57,56 @@ def getattr(name): return getattr -def accelerate_module_deprecation(module: ModuleType, name: str) -> None: - """Accelerate the deprecation of a module-level attribute""" +def accelerate_getattr_deprecation(module: ModuleType, name: str) -> None: + """Accelerate the deprecation of a module-level attribute. + + Raises an AttributeError instead of a DeprecationWarning upon attribute access. + Used in Google-internal code to implement faster deprecation. + """ message, _ = module._deprecations[name] module._deprecations[name] = (message, None) +# The following mechanism is a separate one, for registering and +# accelerating deprecations that are not imports (for example, deprecations +# of a function argument). 
+# Maps a globally unique string ID to a DeprecationState, which tracks whether +# the deprecation is accelerated. +# The intent is that non-accelerated deprecations will warn, and accelerated +# deprecations will error. + +@dataclass +class DeprecationState: + accelerated: bool = False + +_registered_deprecations: dict[str, DeprecationState] = {} + -_registered_deprecations: dict[tuple[str, str], bool] = {} +def register(deprecation_id: str) -> None: + _registered_deprecations[deprecation_id] = DeprecationState() -def register(module: str, key: str) -> None: - _registered_deprecations[module, key] = False +def unregister(deprecation_id: str) -> None: + if deprecation_id not in _registered_deprecations: + raise ValueError(f"{deprecation_id=!r} not registered.") + _registered_deprecations.pop(deprecation_id) -def unregister(module: str, key: str) -> None: - _registered_deprecations.pop((module, key)) +def accelerate(deprecation_id: str) -> None: + if deprecation_id not in _registered_deprecations: + raise ValueError(f"{deprecation_id=!r} not registered.") + _registered_deprecations[deprecation_id].accelerated = True -def accelerate(module: str, key: str) -> None: - assert (module, key) in _registered_deprecations - _registered_deprecations[module, key] = True +def is_accelerated(deprecation_id: str) -> bool: + if deprecation_id not in _registered_deprecations: + raise ValueError(f"{deprecation_id=!r} not registered.") + return _registered_deprecations[deprecation_id].accelerated -def is_accelerated(module: str, key: str) -> bool: - return _registered_deprecations[module, key] +def warn(deprecation_id: str, message: str, stacklevel: int) -> None: + """Warns about a deprecation, or errors if the deprecation is accelerated.""" + if is_accelerated(deprecation_id): + raise ValueError(message) + else: + warnings.warn(message, category=DeprecationWarning, + stacklevel=stacklevel + 1) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 8ea1a4ba45a6..ec7bb81aff3b 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -16,12 +16,13 @@ from __future__ import annotations import atexit -from collections.abc import Iterator, Sequence +from collections.abc import Callable, Iterator, Sequence import contextlib +import dataclasses from functools import partial import itertools import time -from typing import Any, Callable, NamedTuple +from typing import Any, NamedTuple import logging import threading @@ -39,18 +40,19 @@ from jax._src import xla_bridge as xb from jax._src.interpreters import ad from jax._src.interpreters import batching +from jax._src.abstract_arrays import array_types from jax._src.interpreters import mlir from jax._src.interpreters import xla from jax._src.interpreters import pxla from jax._src import lib from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.monitoring import record_event_duration_secs from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding from jax._src.sharding_impls import ( - PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding, - GSPMDSharding, TransferToMemoryKind) + SingleDeviceSharding, NamedSharding, + GSPMDSharding, TransferToMemoryKind, is_single_device_sharding) +from jax._src.layout import Layout, DeviceLocalLayout JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration" @@ -81,15 +83,11 @@ def apply_primitive(prim, *args, **params): fun = xla_primitive_callable(prim, **params) # TODO(yashkatariya): Investigate adding is_primitive to 
jit and never # triggering the disable jit path instead of messing around with it here. - if xla_extension_version >= 218: - prev = lib.jax_jit.swap_thread_local_state_disable_jit(False) - try: - outs = fun(*args) - finally: - lib.jax_jit.swap_thread_local_state_disable_jit(prev) - else: - with config.disable_jit(False): - outs = fun(*args) + prev = lib.jax_jit.swap_thread_local_state_disable_jit(False) + try: + outs = fun(*args) + finally: + lib.jax_jit.swap_thread_local_state_disable_jit(prev) return outs @util.cache() @@ -107,11 +105,11 @@ def simple_impl(prim): RuntimeToken = Any class RuntimeTokenSet(threading.local): - """See docstring for effect.py module for the calling convention for tokens.""" + """See docstring for effects.py module for the calling convention for tokens.""" # For each ordered effect, the token returned by the last dispatched # computation, sharded over the devices in that computation. - current_tokens: dict[core.Effect, jax.Array] + current_tokens: dict[core.Effect, core.Token] # For each device, the runtime token returned by the last dispatched # computation on that device. @@ -121,15 +119,26 @@ def __init__(self): self.current_tokens = {} self.output_runtime_tokens = {} - def get_token_input(self, eff: core.Effect, - devices: list[Device]) -> jax.Array: + def get_token_input( + self, eff: core.Effect, devices: list[Device] + ) -> core.Token: tok = self.current_tokens.get(eff, np.zeros(0, np.bool_)) + + if isinstance(tok, core.Token): + # The order of devices may change, so we need to reshard if necessary. + # TODO(yueshengys): This might still be buggy in a multi-process SPMD + # scenario. Revise the logic later. A distributed shutdown barrier inside + # the XLA program may be needed. + return jax.device_put(tok, jax.sharding.PositionalSharding(devices)) + + # We only use replicated sharding for the first time when the token for the + # order effect hasn't been created. 
s = jax.sharding.GSPMDSharding.get_replicated(devices) - sharded_tok = pxla.shard_args([s], [tok])[0] + sharded_tok = core.Token(pxla.shard_args([s], [tok])[0]) self.current_tokens[eff] = sharded_tok return sharded_tok - def set_token_result(self, eff: core.Effect, token: jax.Array): + def set_token_result(self, eff: core.Effect, token: core.Token): self.current_tokens[eff] = token def set_output_runtime_token(self, device: Device, token: RuntimeToken): @@ -156,12 +165,6 @@ def wait_for_tokens(): runtime_tokens.block_until_ready() -def is_single_device_sharding(sharding: Sharding) -> bool: - # Special case PmapSharding here because PmapSharding maps away an axis - # and needs to be handled separately.test_pjit_single_device_sharding_add - return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding) - - @contextlib.contextmanager def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None): if _on_exit: @@ -219,7 +222,7 @@ class SourceInfo(NamedTuple): def jaxpr_shardings( jaxpr: core.Jaxpr, -) -> Iterator[tuple[XLACompatibleSharding, SourceInfo]]: +) -> Iterator[tuple[Sharding, SourceInfo]]: from jax._src import pjit from jax.experimental import shard_map @@ -239,10 +242,9 @@ def _names_to_pspec(names): yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info) for names in [*eqn.params['in_names'], *eqn.params['out_names']]) elif eqn.primitive is device_put_p: - s = eqn.params['device'] - if isinstance(s, XLACompatibleSharding) and s.memory_kind is not None: - source_info = SourceInfo(eqn.source_info, eqn.primitive.name) - yield (s, source_info) + source_info = SourceInfo(eqn.source_info, eqn.primitive.name) + yield from ((s, source_info) for s in eqn.params['devices'] + if isinstance(s, Sharding) and s.memory_kind is not None) for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_shardings(subjaxpr) @@ -321,10 +323,6 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None: raise FloatingPointError(f"invalid value (inf) encountered in {name}") -def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool): - result_handler = pxla.global_aval_to_result_handler(aval, s, committed) - return result_handler(pxla.shard_arg(x, s)) - def _override_get_device_assignment(sharding, *args, **kwargs): da = sharding._device_assignment return xb.get_device_backend(da[0]), da @@ -332,7 +330,7 @@ def _override_get_device_assignment(sharding, *args, **kwargs): def _identity_fn(x): return x -def _mcjax_reshard(x, target_sharding): +def _different_device_order_reshard(x, target_sharding): from jax._src import api, array inp_sharding = x.sharding @@ -367,7 +365,8 @@ def _mcjax_reshard(x, target_sharding): new_x = array.make_array_from_single_device_arrays( x.shape, - GSPMDSharding(target_sharding._device_assignment, new_hlo_sharding), + GSPMDSharding(target_sharding._device_assignment, new_hlo_sharding, + memory_kind=target_sharding.memory_kind), x._arrays, ) @@ -380,44 +379,62 @@ def _mcjax_reshard(x, target_sharding): pxla._get_and_check_device_assignment.fn = _orig_get_and_check_device_assignment -def _device_put_impl( - x, - device: Device | Sharding | None = None, - src: Device | Sharding | None = None): - from jax._src import array +@dataclasses.dataclass(frozen=True) +class _DeferredShardArg: + """Deferred call to `pxla.shard_args`. 
- if (isinstance(device, TransferToMemoryKind) or - isinstance(src, TransferToMemoryKind)): - raise ValueError( - "TransferToMemoryKind argument to jax.device_put can only be used" - " inside jax.jit. If you are using device_put outside jax.jit, then" - " please provide a concrete Sharding with memory_kind.") + Per-array impls return this object instead of a result array to indicate a + deferred `shard_args` call. `_batched_device_put_impl` then batches all + `_DeferredShardArg` objects into a single `shard_args` call. + """ - try: - aval = xla.abstractify(x) - except TypeError as err: - raise TypeError( - f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err + x: Any + s: Sharding + aval: core.AbstractValue + committed: bool + + @property + def result_handler(self): + return pxla.global_aval_to_result_handler(self.aval, self.s, self.committed) + + +def _device_put_sharding_impl(x, aval, device): + from jax._src import array + from jax.experimental import multihost_utils if isinstance(device, Sharding): s = device if getattr(x, 'sharding', None) == s and getattr(x, '_committed', False): return x - if (not s.is_fully_addressable and # type: ignore + + if (not s.is_fully_addressable and isinstance(x, array.ArrayImpl) and not x.is_fully_addressable): - # This has to be XLACompatible because _mcjax_reshard will run a - # XLA computation. - assert isinstance(s, XLACompatibleSharding) - return _mcjax_reshard(x, s) - if not s.is_fully_addressable: # type: ignore + assert isinstance(s, Sharding) + return _different_device_order_reshard(x, s) + + if (s.is_fully_addressable and isinstance(x, array.ArrayImpl) and + x.is_fully_addressable and len(s.device_set) > 1 and + s._internal_device_list != x.sharding._internal_device_list and # pytype: disable=attribute-error + s.device_set == x.sharding.device_set): + assert isinstance(s, Sharding) + return _different_device_order_reshard(x, s) + + if not s.is_fully_addressable: + if ((isinstance(x, array.ArrayImpl) and not x._committed) or + type(x) in array_types): + # TODO(yashkatariya): Move this check to `jit`. + multihost_utils.assert_equal( + x, fail_message=( + f"{type(x)} passed to device_put is not the same on each" + " process. Make sure you are passing the same value of" + f" {type(x)} on each process.")) + return api.jit(_identity_fn, out_shardings=s)(x) # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array. raise ValueError( "device_put's second argument must be a Device or a Sharding which" - f" represents addressable devices, but got {s}. You are probably" - " trying to use device_put in multi-controller JAX which is not" - " supported. Please use jax.make_array_from_single_device_arrays API" - " or pass device or Sharding which represents addressable devices.") - return _put_x(x, s, aval, True) + f" represents addressable devices, but got {s}. Please pass device or" + " Sharding which represents addressable devices.") + return _DeferredShardArg(x, s, aval, True) # Only `Device` exists below. `Sharding` instance is handled above. 
if isinstance(x, array.ArrayImpl): @@ -433,37 +450,135 @@ def _device_put_impl( sh = SingleDeviceSharding(pxla._get_default_device() if device is None else device) - return _put_x(x, sh, aval, device is not None) + return _DeferredShardArg(x, sh, aval, device is not None) + + +def _device_put_impl( + x, + *, + device: Device | Sharding | Layout | None, + src: Device | Sharding | Layout | None, +): + if (isinstance(device, TransferToMemoryKind) or + isinstance(src, TransferToMemoryKind)): + raise ValueError( + "TransferToMemoryKind argument to jax.device_put can only be used" + " inside jax.jit. If you are using device_put outside jax.jit, then" + " please provide a concrete Sharding with memory_kind.") + + try: + aval = xla.abstractify(x) + except TypeError as err: + raise TypeError( + f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err + + if isinstance(device, Layout): + l = device + dll = l.device_local_layout + x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None + if dll is None and l.sharding is None: + return _device_put_sharding_impl(x, aval, l.sharding) + if (not isinstance(l.sharding, Sharding) or + not isinstance(dll, (DeviceLocalLayout, type(None)))): + raise ValueError( + "sharding and device_local_layout in `Layout` instance should be" + f" concrete. Got layout: {l} for input {aval.str_short()}") + if getattr(x, 'layout', None) == l and getattr(x, '_committed', False): + return x + if x_dll is None and dll is None: + return _device_put_sharding_impl(x, aval, l.sharding) + return api.jit(_identity_fn, out_shardings=l)(x) + + return _device_put_sharding_impl(x, aval, device) + + +def _batched_device_put_impl( + *xs, + devices: Sequence[Device | Sharding | Layout | None], + srcs: Sequence[Device | Sharding | Layout | None], +): + ys = [] + shard_arg_indices, shard_arg_xs, shard_arg_shardings = [], [], [] + for i, (x, device, src) in enumerate(zip(xs, devices, srcs)): + y = _device_put_impl(x, device=device, src=src) + if isinstance(y, _DeferredShardArg): + shard_arg_indices.append(i) + shard_arg_xs.append(y.x) + shard_arg_shardings.append(y.s) + ys.append(y) + + if shard_arg_xs: + # Batch shard_arg calls. Helps improve efficiency for backends that support + # efficient batch transfer. 
+ shard_arg_results = pxla.shard_args(shard_arg_shardings, shard_arg_xs) + for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results): + assert isinstance(ys[i], _DeferredShardArg) + ys[i] = ys[i].result_handler(shard_arg_result) + + return ys device_put_p = core.Primitive('device_put') -device_put_p.def_impl(_device_put_impl) -device_put_p.def_abstract_eval(lambda x, device=None, src=None: x) - -def device_put_transpose_rule(ct, _, device, src): - return [device_put_p.bind(ct, device=src, src=device)] -ad.deflinear2(device_put_p, device_put_transpose_rule) -batching.defvectorized(device_put_p) - -def _tpu_device_put_lowering(ctx, x, *, device, src): - if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and - device.memory_kind is not None): - aval, = ctx.avals_in - out_aval, = ctx.avals_out - x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval) - if isinstance(device, XLACompatibleSharding): - x = mlir.wrap_with_sharding_op( - ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto()) - return [x] - return [x] -mlir.register_lowering(device_put_p, _tpu_device_put_lowering, platform='tpu') - - -def _common_device_put_lowering(ctx, x, *, device, src): - if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and - device.memory_kind is not None): - raise NotImplementedError( - "Passing memory_kind to device_put via Shardings is not supported on" - f" platforms {ctx.module_context.platforms}") - return [x] +device_put_p.multiple_results = True +device_put_p.def_impl(_batched_device_put_impl) +device_put_p.def_abstract_eval(lambda *xs, devices, srcs: xs) + +def _device_put_transpose(cts, *_, devices, srcs): + results = [None] * len(cts) + dp_args = [] + for i, (ct, device, src) in enumerate(zip(cts, devices, srcs)): + if type(ct) is not ad.Zero: + dp_args.append((i, ct, device, src)) + if dp_args: + indices, args, devices, srcs = list(zip(*dp_args)) + ys = device_put_p.bind(*args, devices=srcs, srcs=devices) + for i, y in zip(indices, ys): + results[i] = y + return results +ad.primitive_jvps[device_put_p] = partial(ad.linear_jvp, device_put_p) +ad.primitive_transposes[device_put_p] = _device_put_transpose + +def _device_put_batcher(batched_args, batch_dims, **params): + mapped_batch_dims = [bd for bd in batch_dims if bd is not batching.not_mapped] + assert not mapped_batch_dims or all( + mapped_batch_dims[0] == bd for bd in mapped_batch_dims[1:] + ), batch_dims + return device_put_p.bind(*batched_args, **params), batch_dims +batching.primitive_batchers[device_put_p] = _device_put_batcher + +def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs): + def lower(x, device, src, aval, out_aval): + if (isinstance(device, (Sharding, TransferToMemoryKind)) and + device.memory_kind is not None): + if isinstance(device, Sharding): + x = mlir.wrap_with_sharding_op( + ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto()) + x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval) + return x + return x + return list(map(lower, xs, devices, srcs, ctx.avals_in, ctx.avals_out)) +mlir.register_lowering( + device_put_p, _tpu_gpu_device_put_lowering, platform='tpu') +mlir.register_lowering( + device_put_p, _tpu_gpu_device_put_lowering, platform='gpu') + + +def _common_device_put_lowering(ctx, *xs, devices, srcs): + for device in devices: + if (isinstance(device, (Sharding, TransferToMemoryKind)) and + device.memory_kind is not None): + raise NotImplementedError( + "Passing memory_kind to device_put via Shardings is not 
supported on" + f" platforms {ctx.module_context.platforms}") + return xs mlir.register_lowering(device_put_p, _common_device_put_lowering) + +def _propagate_mem_kind_dp(*xm, devices=None, srcs=None): + memory_kinds = [] + for device in devices: + if isinstance(device, (Sharding, TransferToMemoryKind)): + memory_kinds.append(device.memory_kind) + else: + memory_kinds.append(None) + return memory_kinds +pxla.memory_kind_propagate_rule[device_put_p] = _propagate_mem_kind_dp diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index af9a80856532..5e8e956cf98b 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -41,15 +41,23 @@ def initialize(self, num_processes: int | None = None, process_id: int | None = None, local_device_ids: int | Sequence[int] | None = None, - initialization_timeout: int = 300): + cluster_detection_method: str | None = None, + initialization_timeout: int = 300, + coordinator_bind_address: str | None = None): coordinator_address = (coordinator_address or os.environ.get('JAX_COORDINATOR_ADDRESS', None)) if isinstance(local_device_ids, int): local_device_ids = [local_device_ids] + (coordinator_address, num_processes, process_id, local_device_ids) = ( clusters.ClusterEnv.auto_detect_unset_distributed_params( - coordinator_address, num_processes, process_id, local_device_ids + coordinator_address, + num_processes, + process_id, + local_device_ids, + cluster_detection_method, + initialization_timeout, ) ) @@ -62,20 +70,43 @@ def initialize(self, self.coordinator_address = coordinator_address + # The default value of [::]:port tells the coordinator to bind to all + # available addresses on the same port as coordinator_address. + default_coordinator_bind_address = '[::]:' + coordinator_address.rsplit(':', 1)[1] + coordinator_bind_address = (coordinator_bind_address or + os.environ.get('JAX_COORDINATOR_BIND_ADDRESS', + default_coordinator_bind_address)) + if coordinator_bind_address is None: + raise ValueError('coordinator_bind_address should be defined.') + if local_device_ids: - visible_devices = ','.join(str(x) for x in local_device_ids) # type: ignore[union-attr] + visible_devices = ','.join(str(x) for x in local_device_ids) logger.info('JAX distributed initialized with visible devices: %s', visible_devices) config.update("jax_cuda_visible_devices", visible_devices) config.update("jax_rocm_visible_devices", visible_devices) self.process_id = process_id + # Emit a warning about PROXY variables if they are in the user's env: + proxy_vars = [ key for key in os.environ.keys() if '_proxy' in key.lower()] + + if len(proxy_vars) > 0: + vars = " ".join(proxy_vars) + ". 
" + warning = ( + f'JAX detected proxy variable(s) in the environment as distributed setup: {vars}' + 'On some systems, this may cause a hang of distributed.initialize and ' + 'you may need to unset these ENV variable(s)' + ) + logger.warning(warning) + if process_id == 0: if self.service is not None: raise RuntimeError('distributed.initialize should only be called once.') - logger.info('Starting JAX distributed service on %s', coordinator_address) + logger.info( + 'Starting JAX distributed service on %s', coordinator_bind_address + ) self.service = xla_extension.get_distributed_runtime_service( - coordinator_address, num_processes) + coordinator_bind_address, num_processes) self.num_processes = num_processes @@ -114,7 +145,9 @@ def initialize(coordinator_address: str | None = None, num_processes: int | None = None, process_id: int | None = None, local_device_ids: int | Sequence[int] | None = None, - initialization_timeout: int = 300): + cluster_detection_method: str | None = None, + initialization_timeout: int = 300, + coordinator_bind_address: str | None = None): """Initializes the JAX distributed system. Calling :func:`~jax.distributed.initialize` prepares JAX for execution on @@ -123,16 +156,27 @@ def initialize(coordinator_address: str | None = None, The JAX distributed system serves a number of roles: - * it allows JAX processes to discover each other and share topology information, - * it performs health checking, ensuring that all processes shut down if any process dies, and - * it is used for distributed checkpointing. + * It allows JAX processes to discover each other and share topology information, + * It performs health checking, ensuring that all processes shut down if any process dies, and + * It is used for distributed checkpointing. If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they will be chosen automatically. + The ``cluster_detection_method`` may be used to choose a specific method for detecting those + distributed arguments. You may pass any of the automatic ``spec_detect_methods`` to this + argument though it is not necessary in the TPU, Slurm, or Open MPI cases. For other MPI + installations, if you have a functional ``mpi4py`` installed, you may pass + ``cluster_detection_method="mpi4py"`` to bootstrap the required arguments. + Otherwise, you must provide the ``coordinator_address``, ``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`. + Please note: on some systems, particularly HPC clusters that only access external networks + through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to + :func:`~jax.distributed.initialize` may timeout. You may need to unset these variables + prior to application launch. + Args: coordinator_address: the IP address of process `0` and a port on which that process should launch a coordinator service. The choice of @@ -149,14 +193,24 @@ def initialize(coordinator_address: str | None = None, local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``. If ``None``, defaults to all local devices being visible to the process except when processes are launched via Slurm and Open MPI on GPUs. In that case, it will default to a single device per process. + cluster_detection_method: An optional string to attempt to autodetect the configuration of the distributed + run. 
Note that "mpi4py" method requires you to have a working ``mpi4py`` install in your environment, + and launch the applicatoin with an MPI-compatible job launcher such as ``mpiexec`` or ``mpirun``. + Legacy auto-detect options (OMPI, Slurm) remain enabled. initialization_timeout: Time period (in seconds) for which connection will be retried. If the initialization takes more than the timeout specified, the initialization will error. Defaults to 300 secs i.e. 5 mins. + coordinator_bind_address: the address and port to which the coordinator service + on process `0` should bind. If this is not specified, the default is to bind to + all available addresses on the same port as ``coordinator_address``. On systems + that have multiple network interfaces per node it may be insufficient to only + have the coordinator service listen on one address/interface. Raises: - RuntimeError: If :func:`~jax.distributed.initialize` is called more than once. + RuntimeError: If :func:`~jax.distributed.initialize` is called more than once + or if called after the backend is already initialized. - Example: + Examples: Suppose there are two GPU processes, and process 0 is the designated coordinator with address ``10.0.0.1:1234``. To initialize the GPU cluster, run the @@ -174,7 +228,8 @@ def initialize(coordinator_address: str | None = None, raise RuntimeError("jax.distributed.initialize() must be called before " "any JAX computations are executed.") global_state.initialize(coordinator_address, num_processes, process_id, - local_device_ids, initialization_timeout) + local_device_ids, cluster_detection_method, + initialization_timeout, coordinator_bind_address) atexit.register(shutdown) diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index 83dc893a9515..386123ae61f0 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -14,17 +14,19 @@ from __future__ import annotations -import enum from typing import Any -import warnings +from jax._src.api import device_put from jax import numpy as jnp from jax._src import array from jax._src import xla_bridge +from jax._src.lax.lax import _array_copy from jax._src.lib import xla_client -from jax._src.lib import xla_extension_version -from jax._src.typing import Array +from jax._src.typing import Array, DLDeviceType +from jax._src.sharding import Sharding +DLPACK_VERSION = (0, 8) +MIN_DLPACK_VERSION = (0, 5) # A set of dtypes that dlpack supports. # Note: Make sure to use a "type", not a dtype instance, when looking up this set @@ -36,62 +38,221 @@ SUPPORTED_DTYPES = frozenset({ jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64, jnp.float16, jnp.bfloat16, jnp.float32, - jnp.float64, jnp.complex64, jnp.complex128}) + jnp.float64, jnp.complex64, jnp.complex128, jnp.bool_}) -if xla_extension_version >= 231: - SUPPORTED_DTYPES = SUPPORTED_DTYPES | frozenset({jnp.bool_}) +def _to_dlpack(x: Array, stream: int | Any | None, + src_device: xla_client.Device | None = None, + device: xla_client.Device | None = None, + copy: bool | None = None): -# Mirror of dlpack.h enum -class DLDeviceType(enum.IntEnum): - kDLCPU = 1 - kDLCUDA = 2 - kDLROCM = 10 - + if src_device is None: + src_device, = x.devices() + if device and (src_device is None or device != src_device): + if copy is not None and not copy: + raise ValueError( + f"Specified {device=} which requires a copy since the source device " + f"is {repr(src_device)}, however copy=False. Set copy=True or " + "copy=None to perform the requested operation." 
+ ) + else: + arr = device_put(x, device) + else: + arr = _array_copy(x) if copy else x + return xla_client._xla.buffer_to_dlpack_managed_tensor( + arr.addressable_data(0), stream=stream + ) -def to_dlpack(x: Array, take_ownership: bool = False, - stream: int | Any | None = None): +def to_dlpack(x: Array, stream: int | Any | None = None, + src_device: xla_client.Device | None = None, + dl_device: tuple[DLDeviceType, int] | None = None, + max_version: tuple[int, int] | None = None, + copy : bool | None = None): """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``. Args: x: a :class:`~jax.Array`, on either CPU or GPU. - take_ownership: Deprecated. It is a no-op to set take_ownership. Will be - deleted in 01/2024. stream: optional platform-dependent stream to wait on until the buffer is ready. This corresponds to the `stream` argument to ``__dlpack__`` documented in https://dmlc.github.io/dlpack/latest/python_spec.html. + src_device: either a CPU or GPU :class:`~jax.Device`. + dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack + format e.g. as produced by ``__dlpack_device__``. + max_version: the maximum DLPack version that the consumer (i.e. caller of + ``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``. + This function is not guaranteed to return a capsule of version + ``max_version``. + copy: a boolean indicating whether or not to copy the input. If + ``copy=True`` then the function must always copy. When + ``copy=False`` then the function must never copy, and must raise an error + when a copy is deemed necessary. If ``copy=None`` then the function must + avoid a copy if possible but may copy if needed. Returns: - A dlpack PyCapsule object. + A DLPack PyCapsule object. Note: - While JAX arrays are always immutable, dlpack buffers cannot be marked as - immutable, and it is possible for processes external to JAX to mutate them - in-place. If a dlpack buffer derived from a JAX array is mutated, it may - lead to undefined behavior when using the associated JAX array. + While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers + cannot be marked as immutable, and it is possible for processes external + to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array + is mutated, it may lead to undefined behavior when using the associated JAX + array. When JAX eventually supports ``DLManagedTensorVersioned`` + (DLPack 1.0), it will be possible to specify that a buffer is read-only. """ if not isinstance(x, array.ArrayImpl): raise TypeError("Argument to to_dlpack must be a jax.Array, " f"got {type(x)}") - assert len(x.devices()) == 1 - if take_ownership: - warnings.warn( - "take_ownership in to_dlpack is deprecated and it is a no-op." + + device = None + dl_device_type, local_hardware_id = dl_device if dl_device else (None, None) + if dl_device_type: + try: + dl_device_platform = { + DLDeviceType.kDLCPU: "cpu", + DLDeviceType.kDLCUDA: "cuda", + DLDeviceType.kDLROCM: "rocm", + }[dl_device_type] + backend = xla_bridge.get_backend(dl_device_platform) + device = backend.device_from_local_hardware_id(local_hardware_id) + except TypeError: + # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html + # recommends using BufferError. 
+ raise BufferError( + "The device specification passed to to_dlpack contains an unsupported " + f"device type (DLDeviceType: {dl_device_type})") + + # As new versions are adopted over time, we can maintain some legacy paths + # for compatability mediated through the max_version parameter. + # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA + # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the + # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0). + if max_version is None or max_version >= DLPACK_VERSION: + # Latest + return _to_dlpack( + x, stream=stream, + src_device=src_device, + device=device, + copy=copy ) - return xla_client._xla.buffer_to_dlpack_managed_tensor( - x.addressable_data(0), stream=stream - ) # type: ignore + elif max_version >= MIN_DLPACK_VERSION: + # Oldest supported + return _to_dlpack( + x, stream=stream, + src_device=src_device, + device=device, + copy=copy + ) + else: + raise BufferError( + f"JAX does not support any version below {MIN_DLPACK_VERSION} but " + f"version ({max_version}) was requested." + ) + +def _place_array(_arr, device, dlpack_device, copy): + if device and dlpack_device != device: + if copy is not None and not copy: + raise ValueError( + f"Specified {device=} which requires a copy since the source device " + f"is {repr(dlpack_device)}, however copy=False. Set copy=True or " + "copy=None to perform the requested operation." + ) + else: + return device_put(_arr, device) + if copy: + return jnp.array(_arr, copy=True) + return _arr + +def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None, + copy: bool | None = None): + preferred_platform = getattr(device, "platform", None) + if device and preferred_platform == "gpu": + preferred_platform = "cuda" if "cuda" in device.client.platform_version else "rocm" + cpu_backend = xla_bridge.get_backend("cpu") + gpu_backend = None + + if preferred_platform in {"cuda", "rocm"}: + try: + gpu_backend = xla_bridge.get_backend(preferred_platform) + except RuntimeError: + raise TypeError( + f"A {str.upper(preferred_platform)} device was specified, however no " + f"{str.upper(preferred_platform)} backend was found." + ) -def from_dlpack(external_array): + if preferred_platform is None: + try: + gpu_backend = xla_bridge.get_backend("cuda") + except RuntimeError: + pass + # Try ROCm if CUDA backend not found + if gpu_backend is None: + try: + gpu_backend = xla_bridge.get_backend("rocm") + except RuntimeError: + pass + + _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer( + dlpack, cpu_backend, gpu_backend)) + dlpack_device, = _arr.devices() + return _place_array(_arr, device, dlpack_device, copy) + +def _from_dlpack(external_array, device: xla_client.Device | None = None, + copy: bool | None = None): + dl_device_type, device_id = external_array.__dlpack_device__() + try: + dl_device_platform = { + DLDeviceType.kDLCPU: "cpu", + DLDeviceType.kDLCUDA: "cuda", + DLDeviceType.kDLROCM: "rocm", + }[dl_device_type] + except TypeError: + # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using + # TypeError. 
+ raise TypeError( + "Array passed to from_dlpack is on unsupported device type " + f"(DLDeviceType: {dl_device_type}, array: {external_array}") + + backend = xla_bridge.get_backend(dl_device_platform) + dlpack_device = backend.device_from_local_hardware_id(device_id) + try: + stream = dlpack_device.get_stream_for_external_ready_events() + except xla_client.XlaRuntimeError as err: + if "UNIMPLEMENTED" in str(err): + stream = None + else: + raise + dlpack = external_array.__dlpack__(stream=stream) + + _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer( + dlpack, dlpack_device, stream)) + return _place_array(_arr, device, dlpack_device, copy) + +def from_dlpack(external_array, + device: xla_client.Device | Sharding | None = None, + copy: bool | None = None): """Returns a :class:`~jax.Array` representation of a DLPack tensor. - The returned :class:`~jax.Array` shares memory with ``external_array``. + The returned :class:`~jax.Array` shares memory with ``external_array`` if no + device transfer or copy was requested. Args: - external_array: an array object that has __dlpack__ and __dlpack_device__ + external_array: An array object that has __dlpack__ and __dlpack_device__ methods, or a DLPack tensor on either CPU or GPU (legacy API). + device: The (optional) :py:class:`Device`, representing the device on which + the returned array should be placed. If given, then the result is committed + to the device. If unspecified, the resulting array will be unpacked onto the + same device it originated from. Setting ``device`` to a device different from + the source of ``external_array`` will require a copy, meaning ``copy`` must be + set to either ``True`` or ``None``. + + copy: An (optional) boolean, controlling whether or not a copy is performed. + If ``copy=True`` then a copy is always performed, even if unpacked onto the + same device. If ``copy=False`` then the copy is never performed and will raise + an error if necessary. When ``copy=None`` then a copy may be performed if + needed for a device transfer. + Returns: A jax.Array @@ -102,49 +263,16 @@ def from_dlpack(external_array): is later modified in-place, it may lead to undefined behavior when using the associated JAX array. """ + if isinstance(device, Sharding): + device_set = device.device_set + if len(device_set) > 1: + raise ValueError( + "from_dlpack can only unpack a dlpack tensor onto a singular device, but " + f"a Sharding with {len(device_set)} devices was provided." + ) + device, = device_set if hasattr(external_array, "__dlpack__"): - dl_device_type, device_id = external_array.__dlpack_device__() - try: - device_platform = { - DLDeviceType.kDLCPU: "cpu", - DLDeviceType.kDLCUDA: "cuda", - DLDeviceType.kDLROCM: "rocm", - }[dl_device_type] - except TypeError: - # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using - # TypeError. 
- raise TypeError( - "Array passed to from_dlpack is on unsupported device type " - f"(DLDeviceType: {dl_device_type}, array: {external_array}") - - backend = xla_bridge.get_backend(device_platform) - device = backend.device_from_local_hardware_id(device_id) - try: - stream = device.get_stream_for_external_ready_events() - except xla_client.XlaRuntimeError as err: # type: ignore - if "UNIMPLEMENTED" in str(err): - stream = None - else: - raise - dlpack = external_array.__dlpack__(stream=stream) - - return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer( - dlpack, device, stream)) - else: - # Legacy path - dlpack = external_array - cpu_backend = xla_bridge.get_backend("cpu") - try: - gpu_backend = xla_bridge.get_backend("cuda") - except RuntimeError: - gpu_backend = None - - # Try ROCm if CUDA backend not found - if gpu_backend is None: - try: - gpu_backend = xla_bridge.get_backend("rocm") - except RuntimeError: - gpu_backend = None + return _from_dlpack(external_array, device, copy) - return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer( - dlpack, cpu_backend, gpu_backend)) + # Legacy path + return _legacy_from_dlpack(external_array, device, copy) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index a79ae5e2b9c9..1d50c5be74b6 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -31,7 +31,7 @@ import numpy as np from jax._src import config -from jax._src.typing import DType, DTypeLike +from jax._src.typing import Array, DType, DTypeLike from jax._src.util import set_module, StrictABC from jax._src import traceback_util @@ -340,22 +340,36 @@ def issubdtype(a: DTypeLike | None, b: DTypeLike | None) -> bool: # don't conform to the standard numpy type hierarchy (e.g. the bfloat16 scalar # type is not a subclass of np.floating) so we must also handle these specially. - # First handle extended dtypes. This is important for performance because - # isinstance(x, extended) is called frequently within JAX internals. - if _issubclass(b, extended): + # We cannot use the cached version directly for all inputs, because some may be + # unhashable (e.g. custom objects with a dtype attribute). The following check is + # fast and covers the majority of calls to this function within JAX library code. + return _issubdtype_cached( + a if isinstance(a, (type, np.dtype, ExtendedDType)) else np.dtype(a), # type: ignore[arg-type] + b if isinstance(b, (type, np.dtype, ExtendedDType)) else np.dtype(b), # type: ignore[arg-type] + ) + + +@functools.lru_cache(512) # don't use util.memoize because there is no X64 dependence. +def _issubdtype_cached(a: type | np.dtype | ExtendedDType, + b: type | np.dtype | ExtendedDType) -> bool: + # First handle extended dtypes, which require their own logic. + a_is_type = isinstance(a, type) + b_is_type = isinstance(b, type) + if b_is_type and _issubclass(b, extended): if isinstance(a, ExtendedDType): return _issubclass(a.type, b) - if _issubclass(a, np.generic): + if a_is_type and _issubclass(a, np.generic): return _issubclass(a, b) return _issubclass(np.dtype(a).type, b) if isinstance(b, ExtendedDType): return isinstance(a, ExtendedDType) and a == b if isinstance(a, ExtendedDType): a = a.type + a_is_type = isinstance(a, type) # For all others, normalize inputs to scalar types. 
- a_sctype = a if _issubclass(a, np.generic) else np.dtype(a).type - b_sctype = b if _issubclass(b, np.generic) else np.dtype(b).type + a_sctype = a if a_is_type and _issubclass(a, np.generic) else np.dtype(a).type + b_sctype = b if b_is_type and _issubclass(b, np.generic) else np.dtype(b).type # Now do special handling of custom float and int types, as they don't conform # to the normal scalar type hierarchy. @@ -421,7 +435,7 @@ def issubdtype(a: DTypeLike | None, b: DTypeLike | None) -> bool: } -def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType, ...]) -> bool: +def isdtype(dtype: DTypeLike, kind: str | DTypeLike | tuple[str | DTypeLike, ...]) -> bool: """Returns a boolean indicating whether a provided dtype is of a specified kind. Args: @@ -444,18 +458,25 @@ def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType, ...]) -> bo True or False """ the_dtype = np.dtype(dtype) - kind_tuple: tuple[DType | str, ...] = kind if isinstance(kind, tuple) else (kind,) + kind_tuple: tuple[str | DTypeLike, ...] = ( + kind if isinstance(kind, tuple) else (kind,) + ) options: set[DType] = set() for kind in kind_tuple: - if isinstance(kind, str): - if kind not in _dtype_kinds: - raise ValueError(f"Unrecognized {kind=} expected one of {list(_dtype_kinds.keys())}") + if isinstance(kind, str) and kind in _dtype_kinds: options.update(_dtype_kinds[kind]) - elif isinstance(kind, np.dtype): - options.add(kind) - else: - # TODO(jakevdp): should we handle scalar types or ScalarMeta here? - raise TypeError(f"Expected kind to be a dtype, string, or tuple; got {kind=}") + continue + try: + _dtype = np.dtype(kind) + except TypeError as e: + if isinstance(kind, str): + raise ValueError( + f"Unrecognized {kind=} expected one of {list(_dtype_kinds.keys())}, " + "or a compatible input for jnp.dtype()") + raise TypeError( + f"Expected kind to be a dtype, string, or tuple; got {kind=}" + ) from e + options.add(_dtype) return the_dtype in options @@ -620,10 +641,12 @@ def promote_types(a: DTypeLike, b: DTypeLike) -> DType: return np.dtype(_least_upper_bound(config.numpy_dtype_promotion.value, a_tp, b_tp)) def is_weakly_typed(x: Any) -> bool: + if type(x) in _weak_types: + return True try: return x.aval.weak_type except AttributeError: - return type(x) in _weak_types + return False def is_python_scalar(x: Any) -> bool: try: @@ -640,11 +663,12 @@ def dtype(x: Any, *, canonicalize: bool = False) -> DType: """Return the dtype object for a value or type, optionally canonicalized based on X64 mode.""" if x is None: raise ValueError(f"Invalid argument to dtype: {x}.") - elif isinstance(x, type) and x in python_scalar_dtypes: + is_type = isinstance(x, type) + if is_type and x in python_scalar_dtypes: dt = python_scalar_dtypes[x] elif type(x) in python_scalar_dtypes: dt = python_scalar_dtypes[type(x)] - elif _issubclass(x, np.generic): + elif is_type and _issubclass(x, np.generic): return np.dtype(x) elif issubdtype(getattr(x, 'dtype', None), extended): dt = x.dtype @@ -717,6 +741,12 @@ def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tupl return (dtype, weak_type) if return_weak_type_flag else dtype # type: ignore[return-value] def check_user_dtype_supported(dtype, fun_name=None): + if isinstance(dtype, Array): + # Deprecation warning added 2024 June 13. 
+ warnings.warn("Passing an array as a dtype argument is deprecated; " + "instead of dtype=arr use dtype=arr.dtype.", + category=DeprecationWarning, stacklevel=3) + return # no further check needed, as array dtypes have already been validated. if issubdtype(dtype, extended): return # Avoid using `dtype in [...]` because of numpy dtype equality overloading. @@ -728,14 +758,14 @@ def check_user_dtype_supported(dtype, fun_name=None): msg = f"JAX only supports number and bool dtypes, got dtype {dtype}" msg += f" in {fun_name}" if fun_name else "" raise TypeError(msg) - if dtype is not None and np_dtype != canonicalize_dtype(dtype): + if dtype is not None and np_dtype != canonicalize_dtype(np_dtype): msg = ("Explicitly requested dtype {} {} is not available, " "and will be truncated to dtype {}. To enable more dtypes, set the " "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell " "environment variable. " "See https://github.com/google/jax#current-gotchas for more.") fun_name = f"requested in {fun_name}" if fun_name else "" - truncated_dtype = canonicalize_dtype(dtype).name + truncated_dtype = canonicalize_dtype(np_dtype).name warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3) def safe_to_cast(input_dtype_or_value: Any, diff --git a/jax/_src/earray.py b/jax/_src/earray.py new file mode 100644 index 000000000000..f4b5e232bc33 --- /dev/null +++ b/jax/_src/earray.py @@ -0,0 +1,112 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math + +from jax._src import api_util +from jax._src import basearray +from jax._src import core +from jax._src import tree_util +from jax._src import sharding_impls +from jax._src.interpreters import pxla +from jax._src.interpreters import xla +from jax._src.util import safe_zip, safe_map + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +# EArray is an Array that can contain extended dtypes. 
+class EArray(basearray.Array): + __slots__ = ['aval', '_data'] + __hash__ = None # type: ignore[assignment] + __array_priority__ = 100 + + def __init__(self, aval, data): + self.aval = aval + self._data = data + + def block_until_ready(self): + _ = self._data.block_until_ready() + return self + + def copy_to_host_async(self): + self._data.copy_to_host_async() + + def copy(self): + return EArray(self.aval, self._data.copy()) + + def __repr__(self): + return 'E' + repr(self._data) + + def __iter__(self): + if self.ndim == 0: raise TypeError('iteration over a 0-d array') + raise NotImplementedError + + # forward to aval + shape = property(lambda self: self.aval.shape) # type: ignore[assignment] + dtype = property(lambda self: self.aval.dtype) # type: ignore[assignment] + + # computed from shape and dtype + ndim = property(lambda self: len(self.aval.shape)) # type: ignore[assignment] + size = property(lambda self: math.prod(self.aval.shape)) # type: ignore[assignment] + itemsize = property(lambda self: self.aval.dtype.itemsize) # type: ignore[assignment] + def __len__(self): + if self.ndim == 0: raise TypeError('len() of unsized object') + return self.shape[0] + + # forward to self._data + devices = property(lambda self: self._data.devices) # type: ignore[assignment] + _committed = property(lambda self: self._data._committed) + is_fully_addressable = property(lambda self: self._data.is_fully_addressable) # type: ignore[assignment] + is_fully_replicated = property(lambda self: self._data.is_fully_replicated) # type: ignore[assignment] + delete = property(lambda self: self._data.delete) # type: ignore[assignment] + is_deleted = property(lambda self: self._data.is_deleted) # type: ignore[assignment] + on_device_size_in_bytes = property(lambda self: self._data.on_device_size_in_bytes) # type: ignore[assignment] + unsafe_buffer_pointer = property(lambda self: self._data.unsafe_buffer_pointer) # type: ignore[assignment] + + # defer to extended dtype rules + @property + def sharding(self): + phys_sharding = self._data.sharding + return sharding_impls.logical_sharding(self.aval, phys_sharding) + + # TODO(mattjj): not implemented below here, need more methods from ArrayImpl + + def addressable_data(self, index: int) -> EArray: + raise NotImplementedError + + @property + def addressable_shards(self): + raise NotImplementedError + + @property + def global_shards(self): + raise NotImplementedError + +# TODO(mattjj): _set_array_base_attributes + +def _earray_shard_arg_handler(xs, shardings): + arrs = [x._data for x in xs] + phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding) + for x, sharding in zip(xs, shardings)] + return pxla.shard_args(phys_shardings, arrs) +pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler + +api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval +core.pytype_aval_mappings[EArray] = lambda x: x.aval +xla.canonicalize_dtype_handlers[EArray] = lambda x: x +tree_util.dispatch_registry.register_node( + EArray, lambda x: ((x._data,), x.aval), lambda a, xs: EArray(a, xs[0])) diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 5594b261abcd..590f68ac0b3b 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -581,7 +581,7 @@ class UnexpectedTracerError(JAXTypeError): code by including information about each stage. Respectively: 1. The name of the transformed function (``side_effecting``) and which - transform kicked of the trace :func:`~jax.jit`). + transform kicked off the trace :func:`~jax.jit`). 2. 
A reconstructed stack trace of where the leaked Tracer was created, which includes where the transformed function was called. (``When the Tracer was created, the final 5 stack frames were...``). @@ -589,7 +589,7 @@ class UnexpectedTracerError(JAXTypeError): the leaked Tracer. 4. The leak location is not included in the error message because it is difficult to pin down! JAX can only tell you what the leaked value - looks like (what shape is has and where it was created) and what + looks like (what shape it has and where it was created) and what boundary it was leaked over (the name of the transformation and the name of the transformed function). 5. The current error's stack trace points to where the value is used. @@ -655,3 +655,29 @@ class UnexpectedTracerError(JAXTypeError): def __init__(self, msg: str): super().__init__(msg) + + +@export +class KeyReuseError(JAXTypeError): + """ + This error occurs when a PRNG key is reused in an unsafe manner. + Key reuse is checked only when `jax_debug_key_reuse` is + set to `True`. + + Here is a simple example of code that would lead to such an error:: + + >>> with jax.debug_key_reuse(True): # doctest: +SKIP + ... key = jax.random.key(0) + ... value = jax.random.uniform(key) + ... new_value = jax.random.uniform(key) + ... + --------------------------------------------------------------------------- + KeyReuseError Traceback (most recent call last) + ... + KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0 + + This sort of key reuse is problematic because the JAX PRNG is stateless, and keys + must be manually split; For more information on this see `Sharp Bits: Random Numbers + `_. + """ + pass diff --git a/jaxlib/cpu/_ducc_fft.pyi b/jax/_src/export/__init__.py similarity index 74% rename from jaxlib/cpu/_ducc_fft.pyi rename to jax/_src/export/__init__.py index 7d5c3071adea..862a661e24b9 100644 --- a/jaxlib/cpu/_ducc_fft.pyi +++ b/jax/_src/export/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The JAX Authors. +# Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -def dynamic_ducc_fft_descriptor(ndims: int, is_double: bool, fft_type: int, axes: list[int], forward: bool) -> bytes: ... -def registrations() -> dict: ... 
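For readers of the `KeyReuseError` docstring added above: the check flags a PRNG key that is consumed twice, and the supported fix is to split the key and draw from fresh subkeys. A minimal sketch (illustrative only, not part of this patch), using the public `jax.random` API and the `jax.debug_key_reuse` context manager shown in that docstring:

    import jax

    with jax.debug_key_reuse(True):            # enable the reuse checker
        key = jax.random.key(0)
        key, subkey = jax.random.split(key)    # consume `subkey`, keep `key` for later splits
        value = jax.random.uniform(subkey)
        key, subkey = jax.random.split(key)    # split again rather than reusing an old key
        new_value = jax.random.uniform(subkey)

With this pattern each random draw sees a key exactly once, so the checker has nothing to report.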
diff --git a/jax/experimental/export/_export.py b/jax/_src/export/_export.py similarity index 56% rename from jax/experimental/export/_export.py rename to jax/_src/export/_export.py index 3c48768c0db4..8581608e3be4 100644 --- a/jax/experimental/export/_export.py +++ b/jax/_src/export/_export.py @@ -17,13 +17,13 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import copy import dataclasses import functools import itertools import re -from typing import Any, Callable, Union +from typing import Any, Union import warnings from absl import logging @@ -47,11 +47,12 @@ from jax._src import pjit from jax._src import sharding_impls from jax._src import source_info_util +from jax._src import stages from jax._src import tree_util from jax._src import util from jax._src import xla_bridge as xb -from jax.experimental.export import _shape_poly +from jax._src.export import shape_poly map = util.safe_map zip = util.safe_zip @@ -59,36 +60,31 @@ DType = Any Shape = jax._src.core.Shape # The values of input and output sharding from the lowering. -LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue] +LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue] +HloSharding = xla_client.HloSharding -# None means unspecified sharding -Sharding = Union[xla_client.HloSharding, None] - -# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions -# for a description of the different versions. -minimum_supported_serialization_version = 6 -maximum_supported_serialization_version = 9 - -_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7 -_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9 +# The minimum and maximum supported calling convention version. +# See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention#calling-conventions-versions +minimum_supported_calling_convention_version = 9 +maximum_supported_calling_convention_version = 9 class DisabledSafetyCheck: - """A safety check should be skipped on (de)serialization. + """A safety check that should be skipped on (de)serialization. Most of these checks are performed on serialization, but some are deferred to deserialization. The list of disabled checks is attached to the serialization, - e.g., as a sequence of string attributes to `jax_export.Exported` or of + e.g., as a sequence of string attributes to `jax.export.Exported` or of `tf.XlaCallModuleOp`. - You can disable more deserialization safety checks by passing - `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`. + When using jax2tf, you can disable more deserialization safety checks + by passing `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`. """ _impl: str @classmethod def platform(cls) -> DisabledSafetyCheck: - """Allows the execution platform to differ from the serialization platform. + """Allows the compilation platform to differ from the export platform. Has effect only on deserialization. """ @@ -106,11 +102,16 @@ def custom_call(cls, target_name: str) -> DisabledSafetyCheck: @classmethod def shape_assertions(cls) -> DisabledSafetyCheck: - """Allows invocations with shapes that do not meet the constraints. + """DEPRECATED: A noop. - Has effect on serialization (to suppress the generation of the assertions) - and also on deserialization (to suppress the checking of the assertions). + Was used previously to allow invocations with shapes that do not meet the + constraints. 
Has no effect anymore, shape assertions cannot be disabled. """ + # TODO(necula): remove this after compatibility period. Was deprecated in + # May 2024. + warnings.warn( + "DisabledSafetyCheck.shape_assertions is deprecated, has no effect anymore", + DeprecationWarning, stacklevel=2) return DisabledSafetyCheck("shape_assertions") def is_custom_call(self) -> str | None: @@ -148,26 +149,37 @@ class Exported: out_tree: a PyTreeDef describing the result of the lowered JAX function. out_avals: the flat tuple of output abstract values. May contain dimension expressions in the shapes, with dimension variables among those in - `in_avals. - in_shardings: the flattened input shardings, as long as `in_avals`. - out_shardings: the flattened output shardings, as long as `out_avals`. + `in_avals`. + in_shardings_hlo: the flattened input shardings, a sequence as long + as `in_avals`. `None` means unspecified sharding. + Note that these do not include the mesh or the actual devices used in + the mesh. See `in_shardings_jax` for a way to turn these + into sharding specification that can be used with JAX APIs. + out_shardings_hlo: the flattened output shardings, a sequence as long + as `out_avals`. `None` means unspecified sharding. + Note that these do not include the mesh or the actual devices used in + the mesh. See `out_shardings_jax` for a way to turn these + into sharding specification that can be used with JAX APIs. nr_devices: the number of devices that the module has been lowered for. - lowering_platforms: a tuple containing at least one of 'tpu', 'cpu', - 'cuda', 'rocm'. See below for the calling convention for when - there are multiple lowering platforms. + platforms: a tuple containing the platforms for which the function should + be exported. The set of platforms in JAX is open-ended; users can + add platforms. JAX built-in platforms are: 'tpu', 'cpu', 'cuda', 'rocm'. + See https://jax.readthedocs.io/en/latest/export/export.html#cross-platform-and-multi-platform-export. ordered_effects: the ordered effects present in the serialized module. - This is present from serialization version 9. See below for the - calling convention in presence of ordered effects. + This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention + for the calling convention in presence of ordered effects. unordered_effects: the unordered effects present in the serialized module. This is present from serialization version 9. mlir_module_serialized: the serialized lowered VHLO module. - mlir_module_serialization_version: a version number for the serialized module. - See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions. + calling_convention_version: a version number for the calling + convention of the exported module. + See more versioning details at https://jax.readthedocs.io/en/latest/export.html#calling-convention-versions. module_kept_var_idx: the sorted indices of the arguments among `in_avals` that must be passed to the module. The other arguments have been dropped - because they are not used. Same length as `in_shardings`. - uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape - polymorphism. This may be because `in_avals` contains dimension + because they are not used. + uses_global_constants: whether the `mlir_module_serialized` uses shape + polymorphism or multi-platform export. 
+ This may be because `in_avals` contains dimension variables, or due to inner calls of Exported modules that have dimension variables or platform index arguments. Such modules need shape refinement before XLA compilation. @@ -180,123 +192,30 @@ class Exported: for each primal output. It returns a tuple with the cotangents corresponding to the flattened primal inputs. - Calling convention for the exported module (for latest supported version): - - The `mlir_module` has a `main` function that takes an optional first - platform index argument if the module supports multiple platforms - (`len(lowering_platforms) > 1`), followed by the token arguments corresponding - to the ordered effects, followed by the kept array - arguments (corresponding to `module_kept_var_idx` and `in_avals`). - The platform index is a i32 or i64 scalar encoding the index of the current - compilation platform into the `lowering_platforms` sequence. - - Inner functions use a different calling convention: an optional - platform index argument, optional dimension variable arguments - (scalar tensors of type i32 or i64), - followed by optional token arguments (in presence of ordered effects), - followed by the regular array arguments. - The dimension arguments correspond to the dimension variables appearing in - the `args_avals`, in sorted order of their names. - - Consider the lowering of a function with one array argument of type "f32[w, - 2 * h]", where "w" and "h" are two dimension variables. - Assume that we use multi-platform lowering, and we have - one ordered effect. The `main` function will be as follows: - - func public main( - platform_index: i32 {jax.global_constant="_platform_index"}, - token_in: token, - arg: f32[?, ?]) { - arg_w = hlo.get_dimension_size(arg, 0) - dim1 = hlo.get_dimension_size(arg, 1) - arg_h = hlo.floordiv(dim1, 2) - call _check_shape_assertions(arg) # See below - token = new_token() - token_out, res = call _wrapped_jax_export_main(platform_index, - arg_h, - arg_w, - token_in, - arg) - return token_out, res - } - - The actual computation is in `_wrapped_jax_export_main`, taking also - the values of `h` and `w` dimension variables. - - The signature of the `_wrapped_jax_export_main` is: - - func private _wrapped_jax_export_main( - platform_index: i32 {jax.global_constant="_platform_index"}, - arg_h: i32 {jax.global_constant="h"}, - arg_w: i32 {jax.global_constant="w"}, - arg_token: stablehlo.token {jax.token=True}, - arg: f32[?, ?]) -> (stablehlo.token, ...) - - Prior to serialization version 9 the calling convention for effects is - different: the `main` function does not take or return a token. Instead - the function creates dummy tokens of type `i1[0]` and passes them to the - `_wrapped_jax_export_main`. The `_wrapped_jax_export_main` - takes dummy tokens of type `i1[0]` and will create internally real - tokens to pass to the inner functions. The inner functions use real - tokens (both before and after serialization version 9) - - Also starting with serialization version 9, function arguments that contain - the platform index or the dimension variable values have a - `jax.global_constant` string attribute whose value is the name of the - global constant, either `_platform_index` or a dimension variable name. - The global constant name may be empty if it is not known. - Some global constant computations use inner functions, e.g., for - `floor_divide`. 
The arguments of such functions have a `jax.global_constant` - attribute for all attributes, meaning that the result of the function is - also a global constant. - - Note that `main` contains a call to `_check_shape_assertions. - JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h` - have values >= 1. We must check these constraints when we invoke the - module. We use a special custom call `@shape_assertion` that takes - a boolean first operand, a string `error_message` attribute that may contain - format specifiers `{0}`, `{1}`, ..., and a variadic number of integer - scalar operands corresponding to the format specifiers. - - func private _check_shape_assertions(arg: f32[?, ?]) { - # Check that w is >= 1 - arg_w = hlo.get_dimension_size(arg, 0) - custom_call @shape_assertion(arg_w >= 1, arg_w, - error_message="Dimension variable 'w' must have integer value >= 1. Found {0}") - # Check that dim1 is even - dim1 = hlo.get_dimension_size(arg, 1) - custom_call @shape_assertion(dim1 % 2 == 0, dim1, - error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}") - # Check that h >= 1 - arg_h = hlo.floordiv(dim1, 2) - custom_call @shape_assertion(arg_h >= 1, arg_h, - error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}") - - If we `call_exported` with this module we perform these checks - statically (in `call_exported_abstract_eval`). + See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export.html#module-calling-convention). """ fun_name: str in_tree: tree_util.PyTreeDef - in_avals: tuple[core.AbstractValue, ...] + in_avals: tuple[core.ShapedArray, ...] out_tree: tree_util.PyTreeDef - out_avals: tuple[core.AbstractValue, ...] + out_avals: tuple[core.ShapedArray, ...] - in_shardings: tuple[Sharding, ...] - out_shardings: tuple[Sharding, ...] + in_shardings_hlo: tuple[HloSharding | None, ...] + out_shardings_hlo: tuple[HloSharding | None, ...] nr_devices: int - lowering_platforms: tuple[str, ...] + platforms: tuple[str, ...] ordered_effects: tuple[effects.Effect, ...] unordered_effects: tuple[effects.Effect, ...] disabled_safety_checks: Sequence[DisabledSafetyCheck] mlir_module_serialized: bytes - mlir_module_serialization_version: int + calling_convention_version: int module_kept_var_idx: tuple[int, ...] - uses_shape_polymorphism: bool + uses_global_constants: bool _get_vjp: Callable[[Exported], Exported] | None - def mlir_module(self) -> ir.Module: + def mlir_module(self) -> str: return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized) def __str__(self): @@ -304,23 +223,146 @@ def __str__(self): # do not want the entire serialized module to end up in locations. return f"Exported(fun_name={self.fun_name}, ...)" + # For backwards compatibility + # TODO(necula): remove after September 2024. + @property + def in_shardings(self): + return self.in_shardings_hlo + @property + def out_shardings(self): + return self.out_shardings_hlo + + def in_shardings_jax( + self, + mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: + """Creates Shardings corresponding to self.in_shardings_hlo. + + The Exported object stores `in_shardings_hlo` as HloShardings, which are + independent of a mesh or set of devices. This method constructs + Sharding that can be used in JAX APIs such as `jax.jit` or + `jax.device_put`. 
+ + Example usage: + >>> from jax import export + >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",)) + >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x), + ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a"))) + ... )(np.arange(jax.device_count())) + >>> exp.in_shardings_hlo + ({devices=[8]<=[8]},) + + # Create a mesh for running the exported object + >>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",)) + >>> + # Put the args and kwargs on the appropriate devices + >>> run_arg = jax.device_put(np.arange(jax.device_count()), + ... exp.in_shardings_jax(run_mesh)[0]) + >>> res = exp.call(run_arg) + >>> res.addressable_shards + [Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]), + Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]), + Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]), + Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]), + Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]), + Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]), + Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]), + Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])] + """ + return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh) + for s in self.in_shardings_hlo) + + def out_shardings_jax( + self, + mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: + """Creates Shardings corresponding to self.out_shardings_hlo. + + See documentation for in_shardings_jax. + """ + return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh) + for s in self.out_shardings_hlo) + + # For backwards compatibility + # TODO(necula): remove after September 2024. + @property + def lowering_platforms(self): + """DEPRECATED.""" + warnings.warn("lowering_platform is deprecated. Use .platforms instead.", + DeprecationWarning, stacklevel=2) + return self.platforms + + # For backwards compatibility + # TODO(necula): remove after September 2024. + @property + def mlir_module_serialization_version(self): + """DEPRECATED.""" + warnings.warn("mlir_module_serialization_version is deprecated. Use .calling_convention_version instead.", + DeprecationWarning, stacklevel=2) + return self.calling_convention_version + + # For backwards compatibility + # TODO(necula): remove after September 2024. + @property + def uses_shape_polymorphism(self): + """DEPRECATED.""" + warnings.warn("uses_shape_polymorphism is deprecated. Use .uses_global_constants instead.", + DeprecationWarning, stacklevel=2) + return self.uses_global_constants + def has_vjp(self) -> bool: + """Returns if this Exported supports VJP.""" return self._get_vjp is not None def vjp(self) -> Exported: """Gets the exported VJP. Returns None if not available, which can happen if the Exported has been - loaded from an external format, without a VJP.""" + loaded from an external format without a VJP. + """ if self._get_vjp is None: raise ValueError("No VJP is available") return self._get_vjp(self) + def serialize(self, + vjp_order: int = 0) -> bytearray: + """Serializes an Exported. + + Args: + vjp_order: The maximum vjp order to include. E.g., the value 2 means that we + serialize the primal functions and two orders of the `vjp` function. This + should allow 2nd order reverse mode differentiation of the deserialized + function. 
i.e., `jax.grad(jax.grad(f)).` + """ + # Lazy load the serialization module, since flatbuffers is an optional + # dependency. + from jax._src.export.serialization import serialize + return serialize(self, vjp_order=vjp_order) + + def call(self, *args, **kwargs): + return call_exported(self)(*args, **kwargs) + + +def deserialize(blob: bytearray) -> Exported: + """Deserializes an Exported. + + Args: + blob: a bytearray obtained from `Exported.serialize`. + """ + # Lazy load the serialization module, since flatbuffers is an optional + # dependency. + from jax._src.export.serialization import deserialize + return deserialize(blob) + -def default_lowering_platform() -> str: +def default_export_platform() -> str: + """Retrieves the default export platform. + + One of: `tpu`, `cpu`, `cuda`, `rocm`. + """ # Canonicalize to turn 'gpu' into 'cuda' or 'rocm' return xb.canonicalize_platform(jax.default_backend()) +default_lowering_platform = default_export_platform + def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]: """Returns the shape and dtype of a jax.Array or a j""" if isinstance(a, jax.ShapeDtypeStruct): @@ -341,20 +383,24 @@ def args_specs( # This was needed in some older jax2tf implementations args = tree_util.tree_map(lambda a: jax.ShapeDtypeStruct(* get_shape_and_dtype(a)), args) - return _shape_poly.symbolic_args_specs(args, polymorphic_shapes) - - + return shape_poly.symbolic_args_specs(args, polymorphic_shapes) -def _keep_main_tokens(serialization_version: int) -> bool: - return serialization_version >= _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS -def export(fun_jax: Callable, - *, - lowering_platforms: Sequence[str] | None = None, - disabled_checks: Sequence[DisabledSafetyCheck] = (), - ) -> Callable[..., Exported]: +# TODO(necula): remove this once we remove jax.experimental.export. +def export_back_compat( + fun_jax: Callable, + *, + lowering_platforms: Sequence[str] | None = None, + disabled_checks: Sequence[DisabledSafetyCheck] = (), + _device_assignment_for_internal_jax2tf_use_only = None, + ) -> Callable[..., Exported]: """Exports native serialization for a JAX function. + Note: this function exists only for internal usage by jax2tf and for + backwards compatibility with jax.experimental.export. Use + `jax.export` instead. + See https://jax.readthedocs.io/en/latest/export.html + Args: fun_jax: the function to lower and serialize. lowering_platforms: @@ -362,8 +408,8 @@ def export(fun_jax: Callable, 'cuda', 'rocm'. If more than one platform is specified, then the lowered code takes an argument specifying the platform. If None, then use the default JAX backend. - The calling convention for multiple platforms is explained in the - `jax_export.Exported` docstring. + The calling convention for multiple platforms is explained + at https://jax.readthedocs.io/en/latest/export.html#module-calling-convention. disabled_checks: the safety checks to disable. See docstring of `DisabledSafetyCheck`. @@ -376,179 +422,278 @@ def export(fun_jax: Callable, def f_jax(*args, **kwargs): ... 
exported = jax_export.export(f_jax)(*args, **kwargs) """ - fun_name = getattr(fun_jax, "__name__", "unknown") - version = config.jax_serialization_version.value - if (version < minimum_supported_serialization_version or - version > maximum_supported_serialization_version): - raise ValueError( - f"The requested jax_serialization version {version} is outside the " - f"range of supported versions [{minimum_supported_serialization_version}" - f"..{maximum_supported_serialization_version}]") def do_export(*args_specs, **kwargs_specs) -> Exported: - if not hasattr(fun_jax, "lower"): + if hasattr(fun_jax, "trace"): + # If we have a pjit or pmap already we do not wrap with another, and we + # allow shardings. + wrapped_fun_jax = fun_jax + else: # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also # convert(f_jax), in which case a "jit" is implied. In that case we raise # an error if the lowered function contains non-replicated sharding annotations. wrapped_fun_jax = jax.jit(fun_jax) - allow_non_replicated_sharding = False - else: - # If we have a pjit or pmap already we do not wrap with another, and we - # allow shardings. - wrapped_fun_jax = fun_jax # type: ignore - allow_non_replicated_sharding = True if lowering_platforms is not None: actual_lowering_platforms = tuple(lowering_platforms) else: - actual_lowering_platforms = (default_lowering_platform(),) + actual_lowering_platforms = (default_export_platform(),) + + # TODO: move to `lower` + symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore] + for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]: + # Static args may have no `shape` attribute. + if not hasattr(aval, "shape"): + continue + for d in aval.shape: + if shape_poly.is_symbolic_dim(d): + if symbolic_scope is None: + symbolic_scope = (d.scope, k_path) + continue + symbolic_scope[0]._check_same_scope( + d, when=f"when exporting {util.fun_name(wrapped_fun_jax)}", + self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", + other_descr=shape_poly.args_kwargs_path_to_str(k_path)) + + traced = wrapped_fun_jax.trace(*args_specs, **kwargs_specs) + lowered = traced.lower( + lowering_platforms=actual_lowering_platforms, + _private_parameters=mlir.LoweringParameters( + for_export=True, + export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value)) + return _export_lowered( + lowered, traced.jaxpr, traced.fun_name, + disabled_checks=disabled_checks, + _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only) + return do_export - # Do not include shape assertions if the version is < 7. 
- enable_shape_assertions = ( - DisabledSafetyCheck.shape_assertions() not in disabled_checks and - version >= _VERSION_START_SUPPORT_SHAPE_ASSERTIONS) # type: ignore - try: - prev_enable_shape_assertions = _shape_poly.thread_local_state.enable_shape_assertions - _shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions - replace_tokens_with_dummy = not _keep_main_tokens(version) - - symbolic_scope: tuple[_shape_poly.SymbolicScope, tree_util.KeyPath] | None = None - for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]: - for d in aval.shape: - if _shape_poly.is_symbolic_dim(d): - if symbolic_scope is None: - symbolic_scope = (d.scope, k_path) - continue - symbolic_scope[0]._check_same_scope( - d, when=f"when exporting {fun_name}", - self_descr=f"current (from {_shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", - other_descr=_shape_poly.args_kwargs_path_to_str(k_path)) - - lowered = wrapped_fun_jax.lower( - *args_specs, **kwargs_specs, - _experimental_lowering_parameters=mlir.LoweringParameters( - platforms=actual_lowering_platforms, - replace_tokens_with_dummy=replace_tokens_with_dummy, - )) +def export( + fun_jit: stages.Wrapped, + *, + platforms: Sequence[str] | None = None, + lowering_platforms: Sequence[str] | None = None, + disabled_checks: Sequence[DisabledSafetyCheck] = (), + ) -> Callable[..., Exported]: + """Exports a JAX function for persistent serialization. - lowering = lowered._lowering # type: ignore - _check_lowering(lowering) - mlir_module = lowering.stablehlo() + Args: + fun_jit: the function to export. Should be the result of `jax.jit`. + platforms: + Optional sequence containing a subset of 'tpu', 'cpu', + 'cuda', 'rocm'. If more than one platform is specified, then + the exported code takes an argument specifying the platform. + If None, then use the default JAX backend. + The calling convention for multiple platforms is explained at + https://jax.readthedocs.io/en/latest/export.html#module-calling-convention. + lowering_platforms: DEPRECATED, use `platforms`. + disabled_checks: the safety checks to disable. See the documentation + of `jax.export.DisabledSafetyCheck`.
- args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals) - if "out_mut" in lowering.compile_args: - if lowering.compile_args["out_mut"]: raise NotImplementedError - if "kept_var_idx" in lowering.compile_args: - module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"])) - else: - # For pmap - module_kept_var_idx = tuple(range(len(args_avals_flat))) - shape_poly_state = lowering.compile_args["shape_poly_state"] - if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) - or lowering.compile_args.get("ordered_effects", [])): - mlir_module = _wrap_main_func( - mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree, - has_platform_index_argument=shape_poly_state.has_platform_index_argument, - module_kept_var_idx=module_kept_var_idx, - serialization_version=version) - finally: - _shape_poly.thread_local_state.enable_shape_assertions = prev_enable_shape_assertions - - with mlir_module.context: - mlir_module_attrs = mlir_module.operation.attributes - mlir_module_attrs["jax.uses_shape_polymorphism"] = ( - mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars)) - - mlir_module_serialized = _module_to_bytecode(mlir_module) - - # Figure out the result types and shapes - if "global_out_avals" in lowering.compile_args: - # This is currently the case for pjit - out_avals_flat = lowering.compile_args["global_out_avals"] - elif "shards" in lowering.compile_args: # for PmapComputation - out_avals_flat = lowering.compile_args["shards"].out_sharded_avals - else: - out_avals_flat = lowered.compile_args["out_avals"] - - # Log and then check the module. - if logging.vlog_is_on(3): - logmsg = (f"version={version} " - f"lowering_platforms={actual_lowering_platforms} " - f"disabled_checks={disabled_checks}") - logging.info("Lowered JAX module: %s\n", logmsg) - if dumped_to := mlir.dump_module_to_file(mlir_module, "export"): - logging.info("Dumped the exported MLIR module to %s", dumped_to) - - _check_module(mlir_module, - allow_non_replicated_sharding=allow_non_replicated_sharding, - disabled_checks=disabled_checks) - - ordered_effects = tuple(lowering.compile_args["ordered_effects"]) - unordered_effects = tuple(lowering.compile_args["unordered_effects"]) - if version < _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - ordered_effects = unordered_effects = () - - nr_devices = len(lowering.compile_args["device_assignment"]) - def export_sharding(s: LoweringSharding, - aval: core.ShapedArray) -> Sharding: - if sharding_impls.is_unspecified(s): - return None - return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] - - all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"], - module_kept_var_idx, - len(args_avals_flat)) - in_shardings = tuple( - export_sharding(s, aval) - for s, aval in zip(all_in_shardings, args_avals_flat)) - out_shardings = tuple( - export_sharding(s, aval) - for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat)) - return Exported( - fun_name=fun_name, - in_tree=lowered.in_tree, - out_tree=lowered.out_tree, - in_avals=tuple(args_avals_flat), - out_avals=tuple(out_avals_flat), - in_shardings=in_shardings, - out_shardings=out_shardings, - nr_devices=nr_devices, + Returns: a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`, + or values with `.shape` and `.dtype` attributes, and returns an + `Exported`. + + Usage: + + >>> from jax import export + >>> exported: export.Exported = export.export(jnp.sin)( + ... 
np.arange(4, dtype=np.float32)) + >>> + >>> # You can inspect the Exported object + >>> exported.in_avals + (ShapedArray(float32[4]),) + >>> blob: bytearray = exported.serialize() + >>> + >>> # The serialized bytes are safe to use in a separate process + >>> rehydrated: export.Exported = export.deserialize(blob) + >>> rehydrated.fun_name + 'sin' + >>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32)) + Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32) + """ + if not isinstance(fun_jit, stages.Wrapped): + raise ValueError( + f"Function to be exported must be the result of `jit` but is: {fun_jit}") + if platforms is not None and lowering_platforms is not None: + raise ValueError("Cannot use both `platforms` and `lowering_platforms`") + if platforms is None and lowering_platforms is not None: + platforms = lowering_platforms + if platforms is not None: + actual_lowering_platforms = tuple(platforms) + else: + actual_lowering_platforms = (default_export_platform(),) + + def do_export(*args_specs, **kwargs_specs) -> Exported: + # TODO: move to `lower` + symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore] + for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]: + # Static args may have no `shape` attribute. + if not hasattr(aval, "shape"): + continue + for d in aval.shape: + if shape_poly.is_symbolic_dim(d): + if symbolic_scope is None: + symbolic_scope = (d.scope, k_path) + continue + symbolic_scope[0]._check_same_scope( + d, when=f"when exporting {util.fun_name(fun_jit)}", + self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", + other_descr=shape_poly.args_kwargs_path_to_str(k_path)) + + traced = fun_jit.trace(*args_specs, **kwargs_specs) + lowered = traced.lower( lowering_platforms=actual_lowering_platforms, - ordered_effects=ordered_effects, - unordered_effects=unordered_effects, - disabled_safety_checks=tuple(disabled_checks), - mlir_module_serialized=mlir_module_serialized, + _private_parameters=mlir.LoweringParameters( + for_export=True, + export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value)) + return _export_lowered( + lowered, traced.jaxpr, traced.fun_name, + disabled_checks=disabled_checks) + return do_export + +def _export_lowered( + lowered: stages.Lowered, + jaxpr: core.ClosedJaxpr, fun_name: str, + disabled_checks: Sequence[DisabledSafetyCheck] = (), + _device_assignment_for_internal_jax2tf_use_only = None, + ) -> Exported: + version = config.jax_export_calling_convention_version.value + if (version < minimum_supported_calling_convention_version or + version > maximum_supported_calling_convention_version): + raise ValueError( + f"The requested export calling convention version {version} is outside the " + f"range of supported versions [{minimum_supported_calling_convention_version}" + f"..{maximum_supported_calling_convention_version}]") + + lowering = lowered._lowering + _check_lowering(lowering) + mlir_module = lowering.stablehlo() + + args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals) + if "mut" in lowering.compile_args: + if lowering.compile_args["mut"]: raise NotImplementedError + if "kept_var_idx" in lowering.compile_args: + module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"])) + else: + # For pmap + module_kept_var_idx = tuple(range(len(args_avals_flat))) + shape_poly_state = lowering.compile_args["shape_poly_state"] + if (not 
all(core.is_constant_shape(a.shape) for a in args_avals_flat) + or lowering.compile_args.get("ordered_effects", [])): + mlir_module = _wrap_main_func( + mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree, + has_platform_index_argument=shape_poly_state.has_platform_index_argument, module_kept_var_idx=module_kept_var_idx, - uses_shape_polymorphism=shape_poly_state.uses_dim_vars, - mlir_module_serialization_version=version, # type: ignore - _get_vjp=lambda exported: _export_native_vjp(fun_jax, exported)) + serialization_version=version) - return do_export + with mlir_module.context: + mlir_module_attrs = mlir_module.operation.attributes + mlir_module_attrs["jax.uses_shape_polymorphism"] = ( + mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars)) + mlir_module_serialized = _module_to_bytecode(mlir_module) + + # Figure out the result types and shapes + if "global_out_avals" in lowering.compile_args: + # This is currently the case for pjit + out_avals_flat = lowering.compile_args["global_out_avals"] + elif "shards" in lowering.compile_args: # for PmapComputation + out_avals_flat = lowering.compile_args["shards"].out_sharded_avals + else: + out_avals_flat = lowered.compile_args["out_avals"] # type: ignore + + # Log and then check the module. + if logging.vlog_is_on(3): + logmsg = (f"fun_name={fun_name} version={version} " + f"lowering_platforms={lowering._platforms} " # type: ignore[unused-ignore,attribute-error] + f"disabled_checks={disabled_checks}") + logging.info("Exported JAX function: %s\n", logmsg) + logging.info(mlir.dump_module_message(mlir_module, "export")) + + _check_module(mlir_module, + disabled_checks=disabled_checks) + + ordered_effects = tuple(lowering.compile_args["ordered_effects"]) + unordered_effects = tuple(lowering.compile_args["unordered_effects"]) + + nr_devices = len(lowering.compile_args["device_assignment"]) + def export_sharding(s: LoweringSharding, + aval: core.ShapedArray) -> HloSharding | None: + if sharding_impls.is_unspecified(s): + return None + return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] + + all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"], + module_kept_var_idx, + len(args_avals_flat)) + in_shardings = tuple( + export_sharding(s, aval) + for s, aval in zip(all_in_shardings, args_avals_flat)) + out_shardings = tuple( + export_sharding(s, aval) + for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat)) + + device_assignment = lowering.compile_args["device_assignment"] + if _device_assignment_for_internal_jax2tf_use_only is not None: + _device_assignment_for_internal_jax2tf_use_only[0] = device_assignment + def _get_exported_vjp(exp_primal: Exported) -> Exported: + # Turn the primal jaxpr into a function, in preparation for exporting + # the VJP. 
Note that jaxpr_as_fun produces a function with flat arguments + assert(jaxpr is not None) # None only when the lowered was created outside JAX + fun_jax = core.jaxpr_as_fun(jaxpr) + + fun_vjp_jax, vjp_in_avals = _get_vjp_fun(fun_jax, + in_tree=exp_primal.in_tree, + in_avals=exp_primal.in_avals, + in_shardings_hlo=exp_primal.in_shardings_hlo, + out_avals=exp_primal.out_avals, + out_shardings_hlo=exp_primal.out_shardings_hlo, + device_assignment=device_assignment, + apply_jit=True, + flat_primal_fun=True) + return export(fun_vjp_jax, # type: ignore[arg-type] + platforms=exp_primal.platforms, + disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals) + + return Exported( + fun_name=fun_name, + in_tree=lowered.in_tree, + out_tree=lowered.out_tree, + in_avals=tuple(args_avals_flat), + out_avals=tuple(out_avals_flat), + in_shardings_hlo=in_shardings, + out_shardings_hlo=out_shardings, + nr_devices=nr_devices, + platforms=lowering._platforms, # type: ignore + ordered_effects=ordered_effects, + unordered_effects=unordered_effects, + disabled_safety_checks=tuple(disabled_checks), + mlir_module_serialized=mlir_module_serialized, + module_kept_var_idx=module_kept_var_idx, + uses_global_constants=shape_poly_state.uses_dim_vars, + calling_convention_version=version, + _get_vjp=_get_exported_vjp) def _module_to_bytecode(module: ir.Module) -> bytes: mlir_str = mlir.module_to_bytecode(module) - if hlo.get_api_version() < 4: - target_version = hlo.get_earliest_forward_compatible_version() - else: - # `target_version` is used to manage situations when a StableHLO producer - # (in this case, jax2tf) and a StableHLO consumer were built using - # different versions of StableHLO. - # - # Each StableHLO version `producer_version` has a compatibility window, - # i.e. range of versions [`consumer_version_min`, `consumer_version_max`], - # where StableHLO portable artifacts serialized by `producer_version` - # can be deserialized by `consumer_version` within the window. - # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md - # for the exact extent of these compatibility guarantees. - # - # `hlo.get_minimum_version()` returns `consumer_version_min` - # for the current version of StableHLO. We are using it here to maximize - # forward compatibility, i.e. to maximize how far into the past we can go - # and still have the payloads produced by `serialize_portable_artifact` - # compatible with potential consumers from the past. - target_version = hlo.get_minimum_version() - module_serialized = xla_client._xla.mlir.serialize_portable_artifact( + # `target_version` is used to manage situations when a StableHLO producer + # (in this case, jax2tf) and a StableHLO consumer were built using + # different versions of StableHLO. + # + # Each StableHLO version `producer_version` has a compatibility window, + # i.e. range of versions [`consumer_version_min`, `consumer_version_max`], + # where StableHLO portable artifacts serialized by `producer_version` + # can be deserialized by `consumer_version` within the window. + # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md + # for the exact extent of these compatibility guarantees. + # + # `hlo.get_minimum_version()` returns `consumer_version_min` + # for the current version of StableHLO. We are using it here to maximize + # forward compatibility, i.e. to maximize how far into the past we can go + # and still have the payloads produced by `serialize_portable_artifact` + # compatible with potential consumers from the past. 
+ target_version = hlo.get_minimum_version() + module_serialized = xla_client._xla.mlir.serialize_portable_artifact( # type: ignore mlir_str, target_version) return module_serialized @@ -564,11 +709,10 @@ def _wrap_main_func( ) -> ir.Module: """Wraps the lowered module with a new "main" handling dimension arguments. - See calling convention documentation for `jax_export.Exported`. + See calling convention documentation https://jax.readthedocs.io/en/latest/export.html#module-calling-convention. Args: - module: the HLO module as obtained from lowering. See the calling convention - for inner functions in `jax_export.Exported`. + module: the HLO module as obtained from lowering. args_avals_flat: the avals for all the arguments of the lowered function, which correspond to the array arguments of the `module`. args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error @@ -581,11 +725,11 @@ def _wrap_main_func( Returns the wrapped module, without dimension and token arguments. """ - dim_vars = _shape_poly.all_dim_vars(args_avals_flat) + dim_vars = shape_poly.all_dim_vars(args_avals_flat) context = mlir.make_ir_context() with context, ir.Location.unknown(context): # Make a copy, do not mutate because it may be cached - wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module)) + wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module)) # type: ignore symbol_table = ir.SymbolTable(wrapped_module.operation) orig_main = symbol_table["main"] orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private") @@ -593,16 +737,10 @@ def _wrap_main_func( orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value def is_token(typ, attrs): - if typ == mlir.token_type()[0]: - return True - # TODO(b/302258959): in older versions we cannot use the token type - try: - return ir.BoolAttr(ir.DictAttr(attrs)["jax.token"]).value - except KeyError: - return False + return (typ == mlir.token_type()[0]) - orig_input_types = orig_main.type.inputs - arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) + orig_input_types = orig_main.type.inputs # type: ignore + arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) # type: ignore # The order of args: platform_index_arg, dim args, token args, array args. 
nr_platform_index_args = 1 if has_platform_index_argument else 0 nr_dim_args = len(dim_vars) @@ -624,8 +762,8 @@ def is_token(typ, attrs): orig_input_types, [nr_platform_index_args, nr_dim_args, nr_token_args]) # The order of results: tokens, array results - orig_output_types = orig_main.type.results - result_attrs = list(ir.ArrayAttr(orig_main.result_attrs)) + orig_output_types = orig_main.type.results # type: ignore + result_attrs = list(ir.ArrayAttr(orig_main.result_attrs)) # type: ignore token_result_idxs = [i for i, (typ, attrs) in enumerate(zip(orig_output_types, result_attrs)) if is_token(typ, attrs)] @@ -633,17 +771,11 @@ def is_token(typ, attrs): assert token_result_idxs == list(range(0, nr_token_results)) nr_array_results = len(orig_output_types) - nr_token_results assert nr_array_results >= 0 - if _keep_main_tokens(serialization_version): - new_main_arg_indices = (tuple(range(0, nr_platform_index_args)) + - tuple(range(nr_platform_index_args + nr_dim_args, - len(orig_input_types)))) - new_main_result_indices = tuple(range(0, len(orig_output_types))) - else: - new_main_arg_indices = ( - tuple(range(0, nr_platform_index_args)) + - tuple(range(nr_platform_index_args + nr_dim_args + nr_token_args, - len(orig_input_types)))) - new_main_result_indices = tuple(range(nr_token_results, len(orig_output_types))) + new_main_arg_indices = ( + *range(nr_platform_index_args), + *range(nr_platform_index_args + nr_dim_args, len(orig_input_types))) + new_main_result_indices = tuple(range(0, len(orig_output_types))) + new_main_input_types = [orig_input_types[idx] for idx in new_main_arg_indices] new_main_output_types = [orig_output_types[idx] for idx in new_main_result_indices] new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types) @@ -681,7 +813,9 @@ def is_token(typ, attrs): keepalives=[], channel_iterator=itertools.count(1), host_callbacks=[], module=wrapped_module, context=context, lowering_parameters=mlir.LoweringParameters( - global_constant_computation=True + global_constant_computation=True, + for_export=True, + export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value, )) ctx = mlir.LoweringRuleContext( module_context=module_context, @@ -690,12 +824,12 @@ def is_token(typ, attrs): tokens_in=mlir.TokenSet(), tokens_out=None) # We compute dim_values from the array arguments. new_main_op_array_args = new_main_op.arguments[-nr_array_args:] - if _shape_poly.all_dim_vars(args_avals_flat): + if shape_poly.all_dim_vars(args_avals_flat): # TODO(necula): handle module_kept_var_idx in presence of shape # polymorphism. For now we ensured upstream that we keep all variables. assert len(set(module_kept_var_idx)) == len(args_avals_flat) dim_values = mlir.lower_fun( - functools.partial(_shape_poly.compute_dim_vars_from_arg_shapes, + functools.partial(shape_poly.compute_dim_vars_from_arg_shapes, args_avals_flat, args_kwargs_tree=args_kwargs_tree), multiple_results=True)(ctx, *new_main_op_array_args) else: @@ -711,11 +845,8 @@ def is_token(typ, attrs): else: orig_main_args.append(arg) # Then the token arguments - if _keep_main_tokens(serialization_version): - orig_main_args.extend( - new_main_op.arguments[nr_platform_index_args: nr_platform_index_args + nr_token_args]) - else: - orig_main_args.extend(list(mlir.dummy_token()) * nr_token_args) + orig_main_args.extend( + new_main_op.arguments[nr_platform_index_args: nr_platform_index_args + nr_token_args]) # Then the array arguments. We insert a ConvertOp as the only use of # an input argument. 
This helps the downstream shape refinement because # it will set the type of input arguments to static shapes, and this @@ -737,7 +868,7 @@ def is_token(typ, attrs): def _check_lowering(lowering) -> None: if not isinstance(lowering, pxla.MeshComputation): - raise NotImplementedError(f"serialization is supported only for pjit. {lowering}") + raise NotImplementedError(f"serialization is supported only for jit. {lowering}") if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]: raise NotImplementedError("serialization of host_callbacks is not yet implemented") @@ -745,13 +876,14 @@ def _check_lowering(lowering) -> None: # safe to add it to the allowed_compile_args if it does not change the semantics # or the calling convention of the lowered module. allowed_compile_args = [ - "backend", "mesh", "global_in_avals", + "backend", "platforms", "mesh", "global_in_avals", "global_out_avals", "in_shardings", "out_shardings", "kept_var_idx", - "out_mut", "spmd_lowering", "auto_spmd_lowering", + "mut", "spmd_lowering", "auto_spmd_lowering", "tuple_args", "ordered_effects", "unordered_effects", "keepalive", "host_callbacks", "pmap_nreps", "committed", "device_assignment", "jaxpr_debug_info", "shape_poly_state", - "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info"] + "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info", + "pgle_profiler"] for compile_arg in lowering.compile_args.keys(): if compile_arg not in allowed_compile_args: raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]") @@ -790,7 +922,8 @@ def _check_lowering(lowering) -> None: # Their backwards compatibility is tested by back_compat_test.py. _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = { "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", - "ducc_fft", "dynamic_ducc_fft", "cu_threefry2x32", + "cu_threefry2x32", "cu_threefry2x32_ffi", + "__gpu$xla.gpu.triton", # Pallas call on GPU # cholesky on CPU "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", # eigh on CPU @@ -809,7 +942,7 @@ def _check_lowering(lowering) -> None: "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd", # qr on GPU "cusolver_geqrf", "cublas_geqrf_batched", - "cusolver_geqrf", "cusolver_orgqr", + "cusolver_orgqr", # qr and svd on TPU "Qr", "ProductOfElementaryHouseholderReflectors", # triangular_solve on CPU @@ -825,7 +958,7 @@ def _check_lowering(lowering) -> None: # lu on TPU "LuDecomposition", # ApproxTopK on TPU - "ApproxTopK", + "ApproxTopK", "stablehlo.dynamic_approx_top_k", "tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True) "tpu_custom_call", # Pallas/TPU kernels # TODO(burmako): maintain backwards compatibility for these, until they @@ -840,17 +973,15 @@ def _check_lowering(lowering) -> None: check_sharding_pattern = re.compile(r"^({replicated}|{unknown shard_as.*}|"")$") def _check_module(mod: ir.Module, *, - allow_non_replicated_sharding: bool, - disabled_checks: Sequence[DisabledSafetyCheck]) -> None: + disabled_checks: Sequence[DisabledSafetyCheck]) -> bool: """Run a number of checks on the module. Args: - allow_non_replicated_sharding: whether the module is allowed to contain - non_replicated sharding annotations. disabled_checks: the safety checks that are disabled. + + Returns True if the module uses non-replicated shardings. 
""" sharding_attr = ir.StringAttr.get("Sharding", mod.context) - shape_assertion_attr = ir.StringAttr.get("shape_assertion", mod.context) allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) for dc in disabled_checks: target = dc.is_custom_call() @@ -861,20 +992,24 @@ def _check_module(mod: ir.Module, *, ir.StringAttr.get(target, mod.context) for target in allowed_custom_call_targets} disallowed_custom_call_ops: list[str] = [] + module_uses_non_replicated_sharding = False def check_sharding(op: ir.Operation, loc: ir.Location): - if not allow_non_replicated_sharding: + try: + sharding = op.attributes["mhlo.sharding"] + except KeyError: + pass + else: + nonlocal module_uses_non_replicated_sharding try: - sharding = op.attributes["mhlo.sharding"] - except KeyError: - pass + sharding_value = ir.StringAttr(sharding).value + except UnicodeDecodeError: + # The mhlo.sharding attribute may be in pretty-printed format, or + # as an encoding of an HloSharding protobuf in some rare situations. + # We handle the latter by conservatively assuming it is non-replicated. + module_uses_non_replicated_sharding = True else: - if not re.match(check_sharding_pattern, ir.StringAttr(sharding).value): - raise ValueError( - "Lowered function does not have a top-level pjit but it has" - f" non-replicated sharding annotations, e.g., {op} at {loc}.\nSee" - " https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning" - " for a discussion." - ) + if not re.match(check_sharding_pattern, sharding_value): + module_uses_non_replicated_sharding = True def check_op(op: ir.Operation): op_name = op.operation.name @@ -887,8 +1022,6 @@ def check_op(op: ir.Operation): disallowed_custom_call_ops.append(f"{op} at {op.location}") if call_target_name_attr == sharding_attr: check_sharding(op, op.location) - elif call_target_name_attr == shape_assertion_attr: - assert (DisabledSafetyCheck.shape_assertions() not in disabled_checks) def walk_operations(op): check_op(op) @@ -903,8 +1036,9 @@ def walk_operations(op): msg = ("Cannot serialize code with custom calls whose targets have no " "compatibility guarantees. Examples are:\n" f"{disallowed_custom_call_ops_str}.\n" - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls") + "See https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls") raise ValueError(msg) + return module_uses_non_replicated_sharding def expand_in_shardings(in_shardings: Sequence[LoweringSharding], module_kept_var_idx: Sequence[int], @@ -920,47 +1054,38 @@ def expand_in_shardings(in_shardings: Sequence[LoweringSharding], all_in_shardings[idx] = in_s return tuple(all_in_shardings) -# TODO(yashkatariya, necula): remove this function once we relax the checks -# in the jit front-end. -def canonical_shardings( - device_assignment: Sequence[jax.Device], - in_shardings: Sequence[Sharding], - out_shardings: Sequence[Sharding] - ) -> tuple[(pxla.UnspecifiedValue | - Sequence[sharding.XLACompatibleSharding]), - (pxla.UnspecifiedValue | - Sequence[sharding.XLACompatibleSharding])]: - """Prepares canonical in_ and out_shardings for a pjit invocation. - - The pjit front-end is picky about what in- and out-shardings it accepts, - e.g., if all are unspecified then the whole sharding should be the - sharding_impls.UNSPECIFIED object, otherwise the unspecified shardings are - replaced with the replicated sharding. 
- - Returns: a pair with the canonicalized input and output shardings. - """ - replicated_s = sharding.GSPMDSharding.get_replicated(device_assignment) - def canonicalize( - ss: Sequence[Sharding]) -> (pxla.UnspecifiedValue | - Sequence[sharding.XLACompatibleSharding]): - if all(s is None for s in ss): - return sharding_impls.UNSPECIFIED - return tuple( - sharding.GSPMDSharding(device_assignment, s) if s is not None else replicated_s - for s in ss) - return (canonicalize(in_shardings), canonicalize(out_shardings)) +def _hlo_sharding_to_xla_compatible_sharding( + hlo_sharding: HloSharding | None, + mesh: sharding.Mesh) -> sharding.Sharding | None: + if hlo_sharding is None: + return None + return sharding_impls._gspmd_to_named_sharding_via_mesh( + _hlo_sharding_to_gspmd_sharding(hlo_sharding, tuple(mesh.devices.flat)), # type: ignore[arg-type] + mesh) + +def _hlo_sharding_to_gspmd_sharding( + hlo_sharding: HloSharding | None, + device_assignment: Sequence[jax.Device]) -> sharding.GSPMDSharding | None: + if hlo_sharding is None: + return None + return sharding.GSPMDSharding(device_assignment, hlo_sharding) def _get_vjp_fun(primal_fun: Callable, *, in_tree: tree_util.PyTreeDef, in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], - in_shardings: tuple[Sharding, ...], - out_shardings: tuple[Sharding, ...], - nr_devices: int, - apply_jit: bool + in_shardings_hlo: tuple[HloSharding | None, ...], + out_shardings_hlo: tuple[HloSharding | None, ...], + device_assignment: Sequence[sharding_impls.Device] | None, + apply_jit: bool, + flat_primal_fun: bool = False, ) -> tuple[Callable, Sequence[core.AbstractValue]]: # Since jax.vjp does not handle kwargs, it is easier to do all the work # here with flattened functions. + # apply_jit=False is only used for backwards compatibility with the graph + # graph serialization. When apply_jit=True, we must pass a device assignment. + # flat_primal_fun=False is used only from jax2tf, and it means that the + # `primal_fun` takes PyTree `*args` and `**kwargs`. def fun_vjp_jax(*args_and_out_cts_flat_jax): # Takes a flat list of primals and output cotangents def flattened_primal_fun_jax(*args_flat): @@ -971,7 +1096,8 @@ def flattened_primal_fun_jax(*args_flat): args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, [len(in_avals)]) - _, pullback_jax = jax.vjp(flattened_primal_fun_jax, *args_flat_jax) + _, pullback_jax = jax.vjp(primal_fun if flat_primal_fun else flattened_primal_fun_jax, + *args_flat_jax) return pullback_jax(out_cts_flat_jax) vjp_in_avals = list( @@ -979,34 +1105,19 @@ def flattened_primal_fun_jax(*args_flat): map(lambda a: a.at_least_vspace(), out_avals))) if apply_jit: - # Prepare a device assignment. For exporting purposes, all it matters - # is the number of devices. 
- device_assignment = jax.devices(jax.default_backend())[:nr_devices] - assert len(device_assignment) == nr_devices - vjp_in_shardings, vjp_out_shardings = canonical_shardings( - device_assignment, - tuple(itertools.chain(in_shardings, out_shardings)), - in_shardings) + assert device_assignment is not None + vjp_in_shardings = tuple( + _hlo_sharding_to_gspmd_sharding(s, device_assignment) + for s in itertools.chain(in_shardings_hlo, out_shardings_hlo)) + vjp_out_shardings = tuple( + _hlo_sharding_to_gspmd_sharding(s, device_assignment) + for s in in_shardings_hlo) return pjit.pjit(fun_vjp_jax, in_shardings=vjp_in_shardings, out_shardings=vjp_out_shardings), vjp_in_avals else: return fun_vjp_jax, vjp_in_avals -def _export_native_vjp(primal_fun, primal: Exported) -> Exported: - # Export the VJP of `primal_fun_jax`. See documentation for Exported.vjp - fun_vjp_jax, vjp_in_avals = _get_vjp_fun(primal_fun, - in_tree=primal.in_tree, - in_avals=primal.in_avals, - in_shardings=primal.in_shardings, - out_avals=primal.out_avals, - out_shardings=primal.out_shardings, - nr_devices=primal.nr_devices, - apply_jit=True) - return export(fun_vjp_jax, - lowering_platforms=primal.lowering_platforms, - disabled_checks=primal.disabled_safety_checks)(*vjp_in_avals) - ### Calling the exported function def call(exported: Exported) -> Callable[..., jax.Array]: @@ -1054,7 +1165,7 @@ def f_imported(*args, **kwargs): f"as when the function '{exported.fun_name}' was exported, but they " "have the following structural differences:\n" + ("\n".join( - f" - {_shape_poly.args_kwargs_path_to_str(path)} is a {thing1} in the invocation and a " + f" - {shape_poly.args_kwargs_path_to_str(path)} is a {thing1} in the invocation and a " f"{thing2} when exported, so {explanation}.\n" for path, thing1, thing2, explanation in tree_util.equality_errors(in_args, exp_in_args)))) @@ -1075,12 +1186,14 @@ def _call_exported_abstract_eval( *in_avals: core.AbstractValue, exported: Exported ) -> tuple[tuple[core.AbstractValue, ...], set[effects.Effect]]: - exported_dim_vars = _shape_poly.all_dim_vars(exported.in_avals) + exported_dim_vars = shape_poly.all_dim_vars(exported.in_avals) assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure # Check that the expected shapes match the actual ones for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)): + if not isinstance(actual_aval, core.ShapedArray): + raise ValueError(f"Expected ShapedArray but got: {actual_aval}") def pp_arg_dim(dim_idx: int | None) -> str: - return _shape_poly.pretty_print_dimension_descriptor(exported.in_tree, + return shape_poly.pretty_print_dimension_descriptor(exported.in_tree, arg_idx, dim_idx) if len(exp_aval.shape) != len(actual_aval.shape): raise ValueError( @@ -1102,11 +1215,11 @@ def pp_arg_dim(dim_idx: int | None) -> str: f"expected {exp_aval.shape} and called with {actual_aval.shape}") # Must express the exported_dim_vars in terms of the shapes in in_avals. - solution, shape_constraints, synth_dim_vars = _shape_poly.solve_dim_vars( + solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars( exported.in_avals, args_kwargs_tree=exported.in_tree) synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx] for (vname, arg_idx, dim_idx) in synth_dim_vars} - synthetic_eval = _shape_poly.CachingShapeEvaluator(**synthetic_env) + synthetic_eval = shape_poly.CachingShapeEvaluator(**synthetic_env) # We discharge all the constraints statically. 
This results in much simpler # composability (because we do not have to worry about the constraints of the # Exported called recursively; we only need to worry about entry-point @@ -1139,28 +1252,40 @@ def _call_exported_impl(*args, exported: Exported): def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, exported: Exported): - if exported.uses_shape_polymorphism: + if exported.uses_global_constants: ctx.module_context.shape_poly_state.uses_dim_vars = True + submodule = ir.Module.parse(exported.mlir_module()) axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.ShardingContext): num_devices = axis_context.num_devices elif isinstance(axis_context, sharding_impls.SPMDAxisContext): num_devices = axis_context.mesh.size + elif isinstance(axis_context, sharding_impls.ReplicaAxisContext): + num_devices = axis_context.axis_env.nreps else: raise NotImplementedError(type(axis_context)) if num_devices != exported.nr_devices: - raise NotImplementedError( - f"Exported module {exported.fun_name} was lowered for " - f"{exported.nr_devices} devices and is called in a context with " - f"{num_devices} devices" - ) + # In some special cases we allow running with a different number of devices + # than the function was exported for. + err_msg = "" + if exported.nr_devices != 1: + err_msg = "the function was exported for more than 1 device." + elif (_check_module(submodule, disabled_checks=()) or + any(s is not None and not s.is_replicated() + for s in exported.in_shardings_hlo + exported.out_shardings_hlo)): + err_msg = "the function contains non-replicated sharding annotations." + if err_msg: + raise ValueError( + f"Function {exported.fun_name} was exported for " + f"{exported.nr_devices} devices and is called in a context with " + f"{num_devices} devices. This is disallowed because: {err_msg}" + ) # Apply in_shardings args = tuple( wrap_with_sharding(ctx, x, x_aval, x_sharding) - for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings)) - submodule = ir.Module.parse(exported.mlir_module()) + for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings_hlo)) symtab = ir.SymbolTable(submodule.operation) # The called function may have been exported with polymorphic shapes and called # now with more refined shapes. We insert hlo.ConvertOp to ensure the module @@ -1176,27 +1301,28 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra # TODO: maybe cache multiple calls fn = mlir.merge_mlir_modules(ctx.module_context.module, f"call_exported_{exported.fun_name}", - submodule) + submodule, + dst_symtab=ctx.module_context.symbol_table) - submodule_args = [] + submodule_args: list[ir.Value] = [] # All the platforms for the current lowering must be among the platforms # for which the callee was lowered. 
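# A hedged illustrative sketch (not from the patch itself) of the export /
# call-back round trip that the _call_exported_lowering rule in this hunk
# serves. It assumes the jax.experimental.export entry points touched by this
# change (`export.export`, `export.call`); `f` and the shapes are made up.
import jax
import jax.numpy as jnp
from jax.experimental import export

def f(x):
  return jnp.sin(x) * 2.0

exp = export.export(jax.jit(f))(jax.ShapeDtypeStruct((8,), jnp.float32))
f_back = export.call(exp)               # an ordinary callable again
y = jax.jit(f_back)(jnp.arange(8.0))    # re-lowered through _call_exported_lowering
print(y.shape, exp.platforms, exp.nr_devices)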
lowering_platforms = ctx.module_context.platforms callee_lowering_platform_index: list[int] = [] for platform in lowering_platforms: - if platform in exported.lowering_platforms: + if platform in exported.platforms: callee_lowering_platform_index.append( - exported.lowering_platforms.index(platform)) + exported.platforms.index(platform)) elif DisabledSafetyCheck.platform() in exported.disabled_safety_checks: callee_lowering_platform_index.append(0) else: raise ValueError( - f"The exported function '{exported.fun_name}' was lowered for " - f"platforms '{exported.lowering_platforms}' but it is used " + f"Function '{exported.fun_name}' was exported for " + f"platforms '{exported.platforms}' but it is used " f"on '{lowering_platforms}'.") - if len(exported.lowering_platforms) > 1: + if len(exported.platforms) > 1: # The exported module takes a platform index argument if len(lowering_platforms) > 1: current_platform_idx = ctx.dim_var_values[0] @@ -1222,10 +1348,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra else: assert len(lowering_platforms) == 1 - if _keep_main_tokens(exported.mlir_module_serialization_version): - ordered_effects = exported.ordered_effects - else: - ordered_effects = () + ordered_effects = exported.ordered_effects for eff in ordered_effects: token_in = ctx.tokens_in.get(eff)[0] submodule_args.append(token_in) @@ -1251,7 +1374,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra # Apply out_shardings results = tuple( wrap_with_sharding(ctx, x, x_aval, x_sharding) - for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings) + for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings_hlo) ) return results @@ -1260,35 +1383,8 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra def wrap_with_sharding(ctx: mlir.LoweringRuleContext, x: ir.Value, x_aval: core.AbstractValue, - x_sharding: Sharding) -> ir.Value: + x_sharding: HloSharding | None) -> ir.Value: if x_sharding is None: return x return mlir.wrap_with_sharding_op( ctx, x, x_aval, x_sharding.to_proto()) - -# TODO(necula): Previously, we had `from jax.experimental.export import export` -# Now we want to simplify the usage, and export the public APIs directly -# from `jax.experimental.export` and now `jax.experimental.export.export` -# refers to the `export` function. Since there may still be users of the -# old API in other packages, we add the old public API as attributes of the -# exported function. We will clean this up after a deprecation period. -def wrap_with_deprecation_warning(f): - msg = (f"You are using function `{f.__name__}` from " - "`jax.experimental.export.export`. You should instead use it directly " - "from `jax.experimental.export`. 
Instead of " - "`from jax.experimental.export import export` you should use " - "`from jax.experimental import export`.") - def wrapped_f(*args, **kwargs): - warnings.warn(msg, DeprecationWarning) - return f(*args, **kwargs) - return wrapped_f - -export.export = wrap_with_deprecation_warning(export) -export.Exported = Exported -export.call_exported = wrap_with_deprecation_warning(call_exported) -export.DisabledSafetyCheck = DisabledSafetyCheck -export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform) -export.symbolic_shape = wrap_with_deprecation_warning(_shape_poly.symbolic_shape) -export.args_specs = wrap_with_deprecation_warning(args_specs) -export.minimum_supported_serialization_version = minimum_supported_serialization_version -export.maximum_supported_serialization_version = maximum_supported_serialization_version diff --git a/jax/experimental/export/serialization.fbs b/jax/_src/export/serialization.fbs similarity index 93% rename from jax/experimental/export/serialization.fbs rename to jax/_src/export/serialization.fbs index e7904954a111..758950adaa8e 100644 --- a/jax/experimental/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -20,7 +20,7 @@ // 3. Add back the licence comment at the start // -namespace jax.experimental.export.serialization; +namespace jax.export.serialization; enum PyTreeDefKind: byte { leaf = 0, @@ -38,7 +38,7 @@ table PyTreeDef { enum AbstractValueKind: byte { shapedArray = 0, - abstractToken = 1, + abstractToken = 1, // unused } enum DType: byte { @@ -119,16 +119,16 @@ table Exported { in_shardings: [Sharding]; out_shardings: [Sharding]; - lowering_platforms: [string]; + platforms: [string]; ordered_effects: [Effect]; unordered_effects: [Effect]; disabled_checks: [DisabledSafetyCheck]; mlir_module_serialized: [byte]; - mlir_module_serialization_version: uint16; + calling_convention_version: uint16; module_kept_var_idx: [uint16]; - uses_shape_polymorphism: bool; + uses_global_constants: bool; vjp: Exported; } diff --git a/jax/experimental/export/_serialization.py b/jax/_src/export/serialization.py similarity index 81% rename from jax/experimental/export/_serialization.py rename to jax/_src/export/serialization.py index b809eee89076..a47b095e4450 100644 --- a/jax/experimental/export/_serialization.py +++ b/jax/_src/export/serialization.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Serialization and deserialization of export.Exported +# Serialization and deserialization of _export.Exported -from typing import Callable, TypeVar -from collections.abc import Sequence +from __future__ import annotations + +from collections.abc import Callable, Sequence from functools import partial +from typing import TypeVar try: import flatbuffers @@ -29,10 +31,10 @@ from jax._src import dtypes from jax._src import effects from jax._src import tree_util +from jax._src.export import serialization_generated as ser_flatbuf +from jax._src.export import _export +from jax._src.export import shape_poly from jax._src.lib import xla_client -from jax.experimental.export import serialization_generated as ser_flatbuf -from jax.experimental.export import _export -from jax.experimental import export import numpy as np @@ -45,8 +47,8 @@ # Version 2, Dec 16th, 2023, adds the f0 dtype. _SERIALIZATION_VERSION = 2 -def serialize(exp: export.Exported, vjp_order: int = 0) -> bytearray: - """Serialize an Exported. 
+def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray: + """Serializes an Exported. Args: exp: the Exported to serialize. @@ -61,14 +63,14 @@ def serialize(exp: export.Exported, vjp_order: int = 0) -> bytearray: return builder.Output() -def deserialize(ser: bytearray) -> export.Exported: - """Deserialize an Exported.""" +def deserialize(ser: bytearray) -> _export.Exported: + """Deserializes an Exported.""" exp = ser_flatbuf.Exported.GetRootAsExported(ser) return _deserialize_exported(exp) def _serialize_exported( - builder: flatbuffers.Builder, exp: export.Exported, vjp_order: int + builder: flatbuffers.Builder, exp: _export.Exported, vjp_order: int ) -> int: # Serialize bottom-up fun_name = builder.CreateString(exp.fun_name) @@ -77,10 +79,10 @@ def _serialize_exported( out_tree = _serialize_pytreedef(builder, exp.out_tree) out_avals = _serialize_array(builder, _serialize_aval, exp.out_avals) in_shardings = _serialize_array( - builder, _serialize_sharding, exp.in_shardings + builder, _serialize_sharding, exp.in_shardings_hlo ) out_shardings = _serialize_array( - builder, _serialize_sharding, exp.out_shardings + builder, _serialize_sharding, exp.out_shardings_hlo ) ordered_effects = _serialize_array( builder, _serialize_effect, exp.ordered_effects @@ -91,8 +93,8 @@ def _serialize_exported( disabled_safety_checks = _serialize_array( builder, _serialize_disabled_safety_check, exp.disabled_safety_checks ) - lowering_platforms = _serialize_array( - builder, lambda b, p: b.CreateString(p), exp.lowering_platforms + platforms = _serialize_array( + builder, lambda b, p: b.CreateString(p), exp.platforms ) mlir_module_serialized = builder.CreateByteVector(exp.mlir_module_serialized) module_kept_var_idx = builder.CreateNumpyVector( @@ -119,17 +121,17 @@ def _serialize_exported( ser_flatbuf.ExportedAddNrDevices(builder, exp.nr_devices) ser_flatbuf.ExportedAddInShardings(builder, in_shardings) ser_flatbuf.ExportedAddOutShardings(builder, out_shardings) - ser_flatbuf.ExportedAddLoweringPlatforms(builder, lowering_platforms) + ser_flatbuf.ExportedAddPlatforms(builder, platforms) ser_flatbuf.ExportedAddOrderedEffects(builder, ordered_effects) ser_flatbuf.ExportedAddUnorderedEffects(builder, unordered_effects) ser_flatbuf.ExportedAddDisabledChecks(builder, disabled_safety_checks) ser_flatbuf.ExportedAddMlirModuleSerialized(builder, mlir_module_serialized) - ser_flatbuf.ExportedAddMlirModuleSerializationVersion( - builder, exp.mlir_module_serialization_version + ser_flatbuf.ExportedAddCallingConventionVersion( + builder, exp.calling_convention_version ) ser_flatbuf.ExportedAddModuleKeptVarIdx(builder, module_kept_var_idx) - ser_flatbuf.ExportedAddUsesShapePolymorphism( - builder, exp.uses_shape_polymorphism + ser_flatbuf.ExportedAddUsesGlobalConstants( + builder, exp.uses_global_constants ) if vjp is not None: ser_flatbuf.ExportedAddVjp(builder, vjp) @@ -148,7 +150,7 @@ def _serialize_array( return builder.EndVector() -def _deserialize_exported(exp: ser_flatbuf.Exported) -> export.Exported: +def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported: serialization_version = exp.SerializationVersion() if serialization_version != _SERIALIZATION_VERSION: raise NotImplementedError( @@ -159,7 +161,7 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> export.Exported: _, in_tree = tree_util.tree_flatten( _deserialize_pytreedef_to_pytree(exp.InTree()) ) - scope = export.SymbolicScope(()) # TODO: serialize the constraints + scope = shape_poly.SymbolicScope(()) # TODO: serialize 
the constraints deser_aval = partial(_deserialize_aval, scope=scope) in_avals = _deserialize_tuple( exp.InAvalsLength, exp.InAvals, deser_aval @@ -177,9 +179,9 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> export.Exported: out_shardings = _deserialize_tuple( exp.OutShardingsLength, exp.OutShardings, _deserialize_sharding ) - lowering_platforms = _deserialize_tuple( - exp.LoweringPlatformsLength, - exp.LoweringPlatforms, + platforms = _deserialize_tuple( + exp.PlatformsLength, + exp.Platforms, lambda v: v.decode("utf-8"), ) ordered_effects = _deserialize_tuple( @@ -195,31 +197,31 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> export.Exported: ) mlir_module_serialized = exp.MlirModuleSerializedAsNumpy().tobytes() - mlir_module_serialization_version = exp.MlirModuleSerializationVersion() + calling_convention_version = exp.CallingConventionVersion() module_kept_var_idx = tuple(exp.ModuleKeptVarIdxAsNumpy().tolist()) - uses_shape_polymorphism = exp.UsesShapePolymorphism() + uses_global_constants = exp.UsesGlobalConstants() _get_vjp = None if vjp := exp.Vjp(): _get_vjp = lambda _: _deserialize_exported(vjp) - return export.Exported( + return _export.Exported( fun_name=fun_name, in_tree=in_tree, in_avals=in_avals, out_tree=out_tree, out_avals=out_avals, nr_devices=nr_devices, - in_shardings=in_shardings, - out_shardings=out_shardings, - lowering_platforms=lowering_platforms, + in_shardings_hlo=in_shardings, + out_shardings_hlo=out_shardings, + platforms=platforms, ordered_effects=ordered_effects, unordered_effects=unordered_effects, disabled_safety_checks=disabled_safety_checks, mlir_module_serialized=mlir_module_serialized, - mlir_module_serialization_version=mlir_module_serialization_version, + calling_convention_version=calling_convention_version, module_kept_var_idx=module_kept_var_idx, - uses_shape_polymorphism=uses_shape_polymorphism, + uses_global_constants=uses_global_constants, _get_vjp=_get_vjp, ) @@ -329,32 +331,28 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): def _serialize_aval( - builder: flatbuffers.Builder, aval: core.AbstractValue + builder: flatbuffers.Builder, aval: core.ShapedArray ) -> int: - aval_type = type(aval) - if aval_type is core.ShapedArray: - aval_kind = ser_flatbuf.AbstractValueKind.shapedArray - shape_offsets = [builder.CreateString(str(d)) for d in aval.shape] - ser_flatbuf.AbstractValueStartShapeVector(builder, len(aval.shape)) - for d in reversed(shape_offsets): - builder.PrependUOffsetTRelative(d) - shape_vector_offset = builder.EndVector() - - ser_flatbuf.AbstractValueStart(builder) - ser_flatbuf.AbstractValueAddKind(builder, aval_kind) - ser_flatbuf.AbstractValueAddShape(builder, shape_vector_offset) - ser_flatbuf.AbstractValueAddDtype(builder, _dtype_to_dtype_kind[aval.dtype]) - return ser_flatbuf.AbstractValueEnd(builder) - else: - raise NotImplementedError(f"serializing AbstractValue: {aval}") + aval_kind = ser_flatbuf.AbstractValueKind.shapedArray + shape_offsets = [builder.CreateString(str(d)) for d in aval.shape] + ser_flatbuf.AbstractValueStartShapeVector(builder, len(aval.shape)) + for d in reversed(shape_offsets): + builder.PrependUOffsetTRelative(d) + shape_vector_offset = builder.EndVector() + + ser_flatbuf.AbstractValueStart(builder) + ser_flatbuf.AbstractValueAddKind(builder, aval_kind) + ser_flatbuf.AbstractValueAddShape(builder, shape_vector_offset) + ser_flatbuf.AbstractValueAddDtype(builder, _dtype_to_dtype_kind[aval.dtype]) + return ser_flatbuf.AbstractValueEnd(builder) def 
_deserialize_aval(aval: ser_flatbuf.AbstractValue, - scope) -> core.AbstractValue: + scope) -> core.ShapedArray: aval_kind = aval.Kind() if aval_kind == ser_flatbuf.AbstractValueKind.shapedArray: dtype = _dtype_kind_to_dtype[aval.Dtype()] - shape = export.symbolic_shape( + shape = shape_poly.symbolic_shape( ",".join( aval.Shape(i).decode("utf-8") for i in range(aval.ShapeLength()) ), @@ -366,14 +364,14 @@ def _deserialize_aval(aval: ser_flatbuf.AbstractValue, def _serialize_sharding( - builder: flatbuffers.Builder, s: _export.Sharding + builder: flatbuffers.Builder, s: _export.HloSharding | None ) -> int: proto = None if s is None: kind = ser_flatbuf.ShardingKind.unspecified else: kind = ser_flatbuf.ShardingKind.hlo_sharding - proto_bytes = s.to_proto().SerializeToString() # type: ignore[union-attr] + proto_bytes = s.to_proto().SerializeToString() proto = builder.CreateByteVector(proto_bytes) ser_flatbuf.ShardingStart(builder) @@ -383,7 +381,7 @@ def _serialize_sharding( return ser_flatbuf.ShardingEnd(builder) -def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.Sharding: +def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.HloSharding | None: kind = s.Kind() if kind == ser_flatbuf.ShardingKind.unspecified: return None @@ -443,16 +441,16 @@ def _deserialize_effect(eff: ser_flatbuf.Effect) -> core.Effect: def _serialize_disabled_safety_check( - builder: flatbuffers.Builder, check: export.DisabledSafetyCheck + builder: flatbuffers.Builder, check: _export.DisabledSafetyCheck ) -> int: custom_call_target_str = check.is_custom_call() custom_call_target = None if custom_call_target_str is not None: kind = ser_flatbuf.DisabledSafetyCheckKind.custom_call custom_call_target = builder.CreateString(custom_call_target_str) - elif check == export.DisabledSafetyCheck.platform(): + elif check == _export.DisabledSafetyCheck.platform(): kind = ser_flatbuf.DisabledSafetyCheckKind.platform - elif check == export.DisabledSafetyCheck.shape_assertions(): + elif check == _export.DisabledSafetyCheck.shape_assertions(): kind = ser_flatbuf.DisabledSafetyCheckKind.shape_assertions else: raise NotImplementedError(f"serializing DisabledSafetyCheck: {check}") @@ -468,14 +466,14 @@ def _serialize_disabled_safety_check( def _deserialize_disabled_safety_check( sc: ser_flatbuf.DisabledSafetyCheck, -) -> export.DisabledSafetyCheck: +) -> _export.DisabledSafetyCheck: kind = sc.Kind() if kind == ser_flatbuf.DisabledSafetyCheckKind.custom_call: - return export.DisabledSafetyCheck.custom_call( + return _export.DisabledSafetyCheck.custom_call( sc.CustomCallTarget().decode("utf-8") ) if kind == ser_flatbuf.DisabledSafetyCheckKind.platform: - return export.DisabledSafetyCheck.platform() + return _export.DisabledSafetyCheck.platform() if kind == ser_flatbuf.DisabledSafetyCheckKind.shape_assertions: - return export.DisabledSafetyCheck.shape_assertions() + return _export.DisabledSafetyCheck.shape_assertions() assert False, kind diff --git a/jax/experimental/export/serialization_generated.py b/jax/_src/export/serialization_generated.py similarity index 96% rename from jax/experimental/export/serialization_generated.py rename to jax/_src/export/serialization_generated.py index 941513667dae..a872d03a9fdd 100644 --- a/jax/experimental/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
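# A minimal, hedged sketch (not from the patch itself) of the serialize /
# deserialize round trip implemented in jax/_src/export/serialization.py above.
# It assumes flatbuffers is installed and that `export.export` is importable
# from jax.experimental.export; the exported function is arbitrary.
import jax
import jax.numpy as jnp
from jax.experimental import export
from jax._src.export.serialization import serialize, deserialize

exp = export.export(jax.jit(jnp.sin))(jax.ShapeDtypeStruct((4,), jnp.float32))
blob = serialize(exp)        # a bytearray; pass vjp_order=1 to also keep one VJP level
restored = deserialize(blob)
assert restored.fun_name == exp.fun_name
assert restored.platforms == exp.platforms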
+# pytype: skip-file # automatically generated by the FlatBuffers compiler, do not modify # namespace: serialization @@ -20,7 +21,7 @@ from flatbuffers.compat import import_numpy np = import_numpy() -class PyTreeDefKind(object): +class PyTreeDefKind: leaf = 0 none = 1 tuple = 2 @@ -28,12 +29,12 @@ class PyTreeDefKind(object): dict = 4 -class AbstractValueKind(object): +class AbstractValueKind: shapedArray = 0 abstractToken = 1 -class DType(object): +class DType: bool = 0 i8 = 1 i16 = 2 @@ -59,18 +60,18 @@ class DType(object): f0 = 22 -class ShardingKind(object): +class ShardingKind: unspecified = 0 hlo_sharding = 1 -class DisabledSafetyCheckKind(object): +class DisabledSafetyCheckKind: platform = 0 custom_call = 1 shape_assertions = 2 -class PyTreeDef(object): +class PyTreeDef: __slots__ = ['_tab'] @classmethod @@ -162,7 +163,7 @@ def PyTreeDefEnd(builder): -class AbstractValue(object): +class AbstractValue: __slots__ = ['_tab'] @classmethod @@ -234,7 +235,7 @@ def AbstractValueEnd(builder): -class Sharding(object): +class Sharding: __slots__ = ['_tab'] @classmethod @@ -303,7 +304,7 @@ def ShardingEnd(builder): -class Effect(object): +class Effect: __slots__ = ['_tab'] @classmethod @@ -339,7 +340,7 @@ def EffectEnd(builder): -class DisabledSafetyCheck(object): +class DisabledSafetyCheck: __slots__ = ['_tab'] @classmethod @@ -385,7 +386,7 @@ def DisabledSafetyCheckEnd(builder): -class Exported(object): +class Exported: __slots__ = ['_tab'] @classmethod @@ -546,7 +547,7 @@ def OutShardingsIsNone(self): return o == 0 # Exported - def LoweringPlatforms(self, j): + def Platforms(self, j): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) if o != 0: a = self._tab.Vector(o) @@ -554,14 +555,14 @@ def LoweringPlatforms(self, j): return "" # Exported - def LoweringPlatformsLength(self): + def PlatformsLength(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) if o != 0: return self._tab.VectorLen(o) return 0 # Exported - def LoweringPlatformsIsNone(self): + def PlatformsIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) return o == 0 @@ -665,7 +666,7 @@ def MlirModuleSerializedIsNone(self): return o == 0 # Exported - def MlirModuleSerializationVersion(self): + def CallingConventionVersion(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(32)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint16Flags, o + self._tab.Pos) @@ -699,7 +700,7 @@ def ModuleKeptVarIdxIsNone(self): return o == 0 # Exported - def UsesShapePolymorphism(self): + def UsesGlobalConstants(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(36)) if o != 0: return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) @@ -757,10 +758,10 @@ def ExportedAddOutShardings(builder, outShardings): def ExportedStartOutShardingsVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def ExportedAddLoweringPlatforms(builder, loweringPlatforms): - builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(loweringPlatforms), 0) +def ExportedAddPlatforms(builder, platforms): + builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(platforms), 0) -def ExportedStartLoweringPlatformsVector(builder, numElems): +def ExportedStartPlatformsVector(builder, numElems): return builder.StartVector(4, numElems, 4) def ExportedAddOrderedEffects(builder, orderedEffects): @@ -787,8 +788,8 @@ def 
ExportedAddMlirModuleSerialized(builder, mlirModuleSerialized): def ExportedStartMlirModuleSerializedVector(builder, numElems): return builder.StartVector(1, numElems, 1) -def ExportedAddMlirModuleSerializationVersion(builder, mlirModuleSerializationVersion): - builder.PrependUint16Slot(14, mlirModuleSerializationVersion, 0) +def ExportedAddCallingConventionVersion(builder, callingConventionVersion): + builder.PrependUint16Slot(14, callingConventionVersion, 0) def ExportedAddModuleKeptVarIdx(builder, moduleKeptVarIdx): builder.PrependUOffsetTRelativeSlot(15, flatbuffers.number_types.UOffsetTFlags.py_type(moduleKeptVarIdx), 0) @@ -796,8 +797,8 @@ def ExportedAddModuleKeptVarIdx(builder, moduleKeptVarIdx): def ExportedStartModuleKeptVarIdxVector(builder, numElems): return builder.StartVector(2, numElems, 2) -def ExportedAddUsesShapePolymorphism(builder, usesShapePolymorphism): - builder.PrependBoolSlot(16, usesShapePolymorphism, 0) +def ExportedAddUsesGlobalConstants(builder, usesGlobalConstants): + builder.PrependBoolSlot(16, usesGlobalConstants, 0) def ExportedAddVjp(builder, vjp): builder.PrependUOffsetTRelativeSlot(17, flatbuffers.number_types.UOffsetTFlags.py_type(vjp), 0) diff --git a/jax/experimental/export/_shape_poly.py b/jax/_src/export/shape_poly.py similarity index 92% rename from jax/experimental/export/_shape_poly.py rename to jax/_src/export/shape_poly.py index 8afc8bf65eb0..d380bc5a2476 100644 --- a/jax/experimental/export/_shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -13,28 +13,13 @@ # limitations under the License. """Shape polymorphism support. -We introduce a set of dimension variables at the top-level of a `jit` function. -They are introduced implicitly by way of specifying for each dimension of each -argument a symbolic dimension expression in terms of some dimension variables. -All dimension variables are assumed to range over integers greater or equal to 1. - -Symbolic dimensions overload some integer operations, such as -add, multiply, divide, equality, etc. The JAX NumPy layer and the LAX layers have been -touched up to be sensitive to handling shapes that contain symbolic dimensions. -This enables many JAX programs to be traced with symbolic dimensions -in some dimensions. A priority has been to enable the batch -dimension in neural network examples to be polymorphic. - -This was built initially for jax2tf, but it is now -independent of TF. The best documentation at the moment is in the -jax2tf.convert docstring, and the -[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). +See documentation at https://jax.readthedocs.io/en/latest/export/shape_poly.html. """ from __future__ import annotations import enum -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Sequence import dataclasses from enum import Enum import functools @@ -42,9 +27,8 @@ import io import copy import operator as op -import threading import tokenize -from typing import Any, Callable, Union, overload +from typing import Any, Union, overload import warnings import numpy as np @@ -87,23 +71,14 @@ class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation): are non-constant, and the result of the operation cannot be represented as a boolean value for all values of the symbolic dimensions involved. 
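# A hedged sketch (not from the patch itself) of the partially-supported
# symbolic comparisons that InconclusiveDimensionOperation above describes.
# It assumes symbolic_shape is reachable via jax.experimental.export, as in
# this change; the dimension variable names are arbitrary.
from jax.experimental import export

a, b = export.symbolic_shape("a, b")
assert a + 1 > a                 # decidable: dimension variables are >= 1
try:
  a >= b                         # undecidable without more information
except Exception as e:           # an InconclusiveDimensionOperation is raised
  print(type(e).__name__)

# With an explicit constraint the same comparison becomes decidable.
a2, b2 = export.symbolic_shape("a, b", constraints=["a >= b"])
assert a2 >= b2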
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison_of_symbolic_dimensions_is_partially_supported +Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported for more details. """ def __init__(self, message: str): error_msg = f"{message}{InconclusiveDimensionOperation._help_msg}" # https://github.com/python/mypy/issues/5887 - super().__init__(error_msg) # type: ignore - -class _ShapePolyThreadLocalState(threading.local): - - def __init__(self): - # TODO(necula): this does not play well with some lowering caches, because - # this state is not part of the cache key. - self.enable_shape_assertions = True - -thread_local_state = _ShapePolyThreadLocalState() + super().__init__(error_msg) class Comparator(Enum): @@ -209,7 +184,7 @@ def _syntactic_cmp(self, other: _DimFactor) -> int: if c := cmp_comparable(self._size, other._size): return c if self.var is not None: return cmp_comparable(self.var, other.var) - if c := cmp_comparable(self.operation, other.operation): return c # type: ignore + if c := cmp_comparable(self.operation, other.operation): return c return cmp_sequence(self.operands, other.operands, lambda s_o, o_o: s_o._syntactic_cmp(o_o)) @@ -240,8 +215,8 @@ def evaluate(self, env: DimVarEnv): return env[self.var] except KeyError: err_msg = ( - f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the used function arguments.\n" - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") + f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the function arguments.\n" + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") raise KeyError(err_msg) else: operand_values = [opnd._evaluate(env) for opnd in self.operands] @@ -379,8 +354,8 @@ def mul(self, other: _DimTerm) -> _DimTerm: """ Returns the product with another term. Example: (n^2*m) * n == n^3 * m. """ - return _DimTerm(_DimExpr._linear_combination_sorted_pairs(self._factors, 0, 1, # type: ignore[arg-type] - other._factors, 0, 1)) # type: ignore[arg-type] + return _DimTerm(_DimExpr._linear_combination_sorted_pairs(self._factors, 0, 1, + other._factors, 0, 1)) def divide(self, divisor: _DimTerm) -> _DimTerm: """ @@ -388,12 +363,12 @@ def divide(self, divisor: _DimTerm) -> _DimTerm: if the result is not a term. For example, (n^3 * m) // n == n^2*m, but n // m fails. """ - new_factors = _DimExpr._linear_combination_sorted_pairs(self._factors, 0, 1, # type: ignore[arg-type] - divisor._factors, 0, -1) # type: ignore[arg-type] + new_factors = _DimExpr._linear_combination_sorted_pairs(self._factors, 0, 1, + divisor._factors, 0, -1) for _, f_exp in new_factors: if f_exp <= 0: raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.") - return _DimTerm(new_factors) # type: ignore + return _DimTerm(new_factors) def evaluate(self, env: DimVarEnv): prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1) @@ -601,15 +576,13 @@ def _get_vars(self) -> set[str]: @staticmethod def _linear_combination_sorted_pairs( e1: SortedTerms, i1: int, f1: int, - e2: SortedTerms, i2: int, f2: int) -> SortedTerms: - ... + e2: SortedTerms, i2: int, f2: int) -> SortedTerms: ... 
# type: ignore[bad-return-type,unused-ignore] @overload @staticmethod def _linear_combination_sorted_pairs( e1: SortedFactors, i1: int, f1: int, - e2: SortedFactors, i2: int, f2: int) -> SortedFactors: - ... + e2: SortedFactors, i2: int, f2: int) -> SortedFactors: ... # type: ignore[bad-return-type,unused-ignore] @staticmethod def _linear_combination_sorted_pairs( @@ -674,7 +647,7 @@ def _eq(self, other: _DimExpr) -> bool: # Here we really ought to raise InconclusiveDimensionOperation, but __eq__ # cannot raise exceptions, because it is used indirectly when hashing. # So, we say that the expressions are disequal, which is really unsound. - # See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported + # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported return False return diff == 0 @@ -851,7 +824,7 @@ def __eq__(self, other: Any) -> bool: # Here we really ought to raise InconclusiveDimensionOperation, but __eq__ # cannot raise exceptions, because it is used indirectly when hashing. # So, we say that the expressions are disequal, which is really unsound. - # See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported + # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported return False return diff == 0 @@ -885,7 +858,7 @@ def _divmod(self, divisor: DimSize) -> tuple[DimSize, int]: # invariant: self = dividend + divisor * quotient # quotient and dividend are changed in the loop; the leading term of # dividend decreases at each iteration. - while is_symbolic_dim(dividend) and not dividend._is_constant: + while is_symbolic_dim(dividend) and not dividend._is_constant: # type: ignore[attribute-error,unused-ignore] mon, count = dividend._leading_term if isinstance(divisor, _DimExpr): dterm, dcount = divisor._leading_term @@ -898,7 +871,7 @@ def _divmod(self, divisor: DimSize) -> tuple[DimSize, int]: q = _DimExpr._from_term(qterm, qcount, self.scope) quotient += q - dividend -= q * divisor # type: ignore[assignment] + dividend -= q * divisor dividend = int(dividend) # type: ignore[assignment] if isinstance(divisor, _DimExpr): @@ -906,7 +879,7 @@ def _divmod(self, divisor: DimSize) -> tuple[DimSize, int]: raise InconclusiveDimensionOperation("") remainder = 0 else: - q, r = divmod(dividend, int(divisor)) # type: ignore + q, r = divmod(dividend, int(divisor)) quotient += q remainder = r @@ -994,7 +967,7 @@ class SymbolicScope: Holds the constraints on symbolic expressions. - See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints) + See [the README](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for more details. 
Args: @@ -1058,9 +1031,9 @@ def _parse_and_process_explicit_constraint(self, c_str: str): if cmp_pos < 0: raise ValueError("Constraint parsing error: must contain one of '==' or '>=' or '<='") e1_str = c_str[:cmp_pos] - e1, = _Parser(e1_str, None, repr(e1_str), self).parse() + e1, = _Parser(e1_str, None, repr(e1_str), self).parse() # type: ignore[name-error,unused-ignore] e2_str = c_str[cmp_pos + 2:] - e2, = _Parser(e2_str, None, repr(e2_str), self).parse() + e2, = _Parser(e2_str, None, repr(e2_str), self).parse() # type: ignore[name-error,unused-ignore] if cmp == Comparator.GEQ and not is_geq: e1, e2 = e2, e1 @@ -1082,7 +1055,7 @@ def _parse_and_process_explicit_constraint(self, c_str: str): raise ValueError("Invalid equality constraint: {e1} == {e2}. " "The left-hand-side must be of the form `term * coefficient`.") - after = _ensure_poly(e2, "parse_constraint", e1.scope) + after = _ensure_poly(e2, "parse_constraint", e1.scope) # type: ignore[name-error,unused-ignore] if before in self._normalization_rules: raise NotImplementedError( f"Found multiple equality constraints with the same left-hand-side: {before}") @@ -1097,7 +1070,7 @@ def _check_same_scope(self, other: _DimExpr, f"Invalid mixing of symbolic scopes {when}.\n" f"Expected {self_descr}scope {self}\n" f"and found for '{other}' ({other_descr}) scope {other.scope}\n" - f"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.") + f"See https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints.") def _clear_caches(self): self._bounds_cache.clear() @@ -1206,6 +1179,8 @@ def _convertible_to_poly(p: DimSize) -> bool: return isinstance(p, _DimExpr) or _convertible_to_int(p) def is_symbolic_dim(p: DimSize) -> bool: + """Checks if a dimension is symbolic. + """ return isinstance(p, _DimExpr) def is_poly_dim(p: DimSize) -> bool: @@ -1311,9 +1286,8 @@ def shape_assertion(assert_what: jax.Array, The format specifiers are sometimes processed with Python's `string::format` method, and sometimes with `llvm::formatv`. """ - if thread_local_state.enable_shape_assertions: - shape_assertion_p.bind(assert_what, *error_message_inputs, - error_message=error_message) + shape_assertion_p.bind(assert_what, *error_message_inputs, + error_message=error_message) # A JAX primitive with no array arguments but with a dimension parameter # that is a DimExpr. The value of the primitive is the value of the dimension, @@ -1324,7 +1298,7 @@ def shape_assertion(assert_what: jax.Array, def dim_as_value_impl(dim: DimSize): raise NotImplementedError( "Evaluation rule for 'dim_as_value' is not implemented. " - "It seems that you are using shape polymorphism outside jax2tf.") + "It seems that you are using shape polymorphism outside jax.export.") dim_as_value_p.def_impl(dim_as_value_impl) def _dim_as_value(dim: DimSize): @@ -1373,26 +1347,36 @@ def symbolic_shape(shape_spec: str | None, scope: SymbolicScope | None = None, like: Sequence[int | None] | None = None ) -> Sequence[DimSize]: - """Constructs a jax.ShapeDtypeStruct with polymorphic shapes. + """Constructs a symbolic shape from a string representation. + + See https://jax.readthedocs.io/en/latest/export/shape_poly.html for examples. Args: shape_spec: a symbolic shape specification. None stands for "...". + A shape specification is the string representation of a tuple (the + parentheses are optional) with comma-separated dimension expressions. 
+ A dimension expression can be either: an integer constant, + a dimension variable (alphanumeric + starting with a letter), e1 + e2, e1 - e2, e1 * e2, floordiv(e1, e2), + mod(e1, e2), max(e1, e2), or min(e1, e2). + constraints: a sequence of constraints on symbolic dimension expressions, of + the form `e1 >= e2` or `e1 <= e2`, or `e1 == e2`. + See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) + for usage. + scope: optionally, you can specify that the parsed symbolic expressions + be created in the given scope. If this is missing, then a new + `SymbolicScope` is created with the given `constraints`. + You cannot specify both a `scope` and `constraints`. + See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) + for usage. like: when `shape_spec` contains placeholders ("_", "..."), use this shape to fill in the placeholders. The dimensions of `like` that are used for filling - must be known (not `None`). If a dimension in `like` is known and + must be not `None`. If a dimension in `like` is not `None` and the corresponding dimension in `shape_spec` is a constant then they must be equal. - scope: optionally, you can specify that the parsed symbolic expressions - be created in a given scope. You cannot specify `constraints` in this case. - constraints: a sequence of constraints on symbolic dimension expressions, of - the form `e1 >= e2` or `e1 <= e2`. This is used to create a new SymbolicScope - shared by all symbolic expressions created. - See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints) - for more details. - Returns: a jax.ShapeDTypeStruct with shapes that may contain symbolic - expressions involving dimension variables. + Returns: a tuple with integers or symbolic expressions involving dimension variables. """ shape_spec_repr = repr(shape_spec) if shape_spec is None: @@ -1411,43 +1395,51 @@ def symbolic_shape(shape_spec: str | None, def symbolic_args_specs( args, # pytree of arguments - polymorphic_shapes, # prefix pytree of strings - symbolic_scope: SymbolicScope | None = None, - symbolic_constraints: Sequence[str] = (), + shapes_specs, # prefix pytree of strings + constraints: Sequence[str] = (), + scope: SymbolicScope | None = None, + symbolic_constraints: Sequence[str] = (), # DEPRECATED on 6/14/24 + symbolic_scope: SymbolicScope | None = None, # DEPRECATED on 6/14/24 ): """Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`. - Note that this function does not ensure that the provided `args` shapes - are compatible with `polymorphic_shapes`. The `.shape` of the `args` are - used only to fill-in placeholders from `polymorphic_shapes`. - - See docstring of `symbolic_shape` and - [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) - for more details. + See the documentation of :func:`jax.export.symbolic_shape` and + the [shape polymorphism documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html) for details. Args: args: a pytree of arguments. These can be jax.Array, or jax.ShapeDTypeSpec. 
- This is used to learn the pytree structure of the arguments, their dtypes, - and to fill-in the actual shapes where the `polymorphic_shapes` contains + They are used to learn the pytree structure of the arguments, their dtypes, + and to fill-in the actual shapes where the `shapes_specs` contains placeholders. Note that only the shape dimensions for which - `polymorphic_shapes` is a placeholder are used from `args`. - The unused dimensions can be `None`, which jax2tf uses when the TF - shapes are not known. - polymorphic_shapes: should be `None` (all arguments have static shapes), - a single string (applies to all arguments), or a pytree matching a prefix + `shapes_specs` is a placeholder are used from `args`. + shapes_specs: should be `None` (all arguments have static shapes), + a single string (see `shape_spec` for :func:`jax.export.symbolic_shape`; + applies to all arguments), or a pytree matching a prefix of the `args`. See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). - symbolic_scope: optionally, you can specify that the parsed symbolic expressions - be created in a given scope. You cannot specify `symbolic_constraints` in this case. - symbolic_constraints: a sequence of constraints on symbolic dimension expressions, of - the form `e1 >= e2` or `e1 <= e2`. This is used to create a new SymbolicScope - shared by all symbolic expressions created. - See more details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. + constraints: as for :func:`jax.export.symbolic_shape`. + scope: as for :func:`jax.export.symbolic_shape`. + symbolic_constraints: DEPRECATED, use `constraints`. + symbolic_scope: DEPRECATED, use `scope`. Returns: a pytree of jax.ShapeDTypeStruct matching the `args` with the shapes - replaced with symbolic dimensions as specified by `polymorphic_shapes`. + replaced with symbolic dimensions as specified by `shapes_specs`. 
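# A hedged usage sketch (not from the patch itself) of symbolic_args_specs with
# the renamed parameters introduced in this hunk (`shapes_specs`, `constraints`).
# It assumes symbolic_args_specs is re-exported from jax.experimental.export;
# the arrays and the batch variable "b" are made up.
import numpy as np
from jax.experimental import export

args = (np.zeros((8, 16), np.float32), np.zeros((8,), np.int32))
specs = export.symbolic_args_specs(args, shapes_specs=("b, _", "b"))
print([s.shape for s in specs])  # the leading dimension of both specs is the symbol b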
""" + if symbolic_constraints: + warnings.warn("symbolic_constraints is deprecated, use constraints", + DeprecationWarning, stacklevel=2) + if constraints: + raise ValueError("Cannot use both symbolic_constraints and constraints") + constraints = symbolic_constraints + if symbolic_scope is not None: + warnings.warn("symbolic_scope is deprecated, use scope", + DeprecationWarning, stacklevel=2) + if scope is not None: + raise ValueError("Cannot use both symbolic_scope and scope") + scope = symbolic_scope + + polymorphic_shapes = shapes_specs args_flat, args_tree = tree_util.tree_flatten(args) shapes_and_dtypes = tuple(map(shape_and_dtype_jax_array, args_flat)) @@ -1467,15 +1459,15 @@ def symbolic_args_specs( e, *_ = tree_util.prefix_errors( polymorphic_shapes_, args, is_leaf=lambda x: x is None) - raise e("jax_export polymorphic_shapes") from None + raise e("export.symbolic_args_specs shapes_specs") from None # Now add in the polymorphic shapes - if symbolic_scope is None: - symbolic_scope = SymbolicScope(symbolic_constraints) - elif symbolic_constraints: - raise ValueError("Cannot have both `symbolic_scope` and `symbolic_constraints`") + if scope is None: + scope = SymbolicScope(constraints) + elif constraints: + raise ValueError("Cannot use both `scope` and `constraints`") args_specs_flat = ( - jax.ShapeDtypeStruct(symbolic_shape(spec, like=s, scope=symbolic_scope), t) + jax.ShapeDtypeStruct(symbolic_shape(spec, like=s, scope=scope), t) for s, t, spec in zip(shapes, dtypes, polymorphic_shapes_flat)) return args_tree.unflatten(args_specs_flat) @@ -1516,7 +1508,7 @@ def add_dim(self, expr: DimSize | None, tok: tokenize.TokenInfo): if core.is_constant_dim(expr) and self.like_shape is not None: like_shape_dim = self.like_shape[len(self.dimensions)] - if expr != like_shape_dim: # type: ignore[operator] + if expr != like_shape_dim: raise self.parse_err(tok, (f"different size {expr} for known dimension; " f"like={self.like_shape}")) @@ -1533,7 +1525,7 @@ def parse_err(self, tok: tokenize.TokenInfo | None, detail: str) -> Exception: def next_tok(self) -> tokenize.TokenInfo: while True: try: - t = next(self.tokstream) + t = next(self.tokstream) # type: ignore[attribute-error,unused-ignore] except StopIteration: raise self.parse_err(None, "unexpected end of string") if t.exact_type not in [tokenize.NEWLINE, tokenize.INDENT, tokenize.DEDENT]: @@ -1619,7 +1611,7 @@ def expr(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: while True: t, tok = self.term(tok) t_sign = - t if next_t_negated else t - acc = acc + t_sign if acc is not None else t_sign # type:ignore [operator] + acc = acc + t_sign if acc is not None else t_sign # type: ignore[operator] if tok.exact_type in self.FOLLOW_EXPR: return acc, tok next_t_negated = (tok.exact_type == tokenize.MINUS) @@ -1640,7 +1632,7 @@ def term(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: acc = acc * f if acc is not None else f # type: ignore[operator] if tok.exact_type in self.FOLLOW_TERM: - return acc, tok + return acc, tok # type: ignore[bad-return-type,unused-ignore] tok = self.consume_token(tok, tokenize.STAR) def factor(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: @@ -1666,7 +1658,7 @@ def factor_unary_op(self, op: str, tok: tokenize.TokenInfo) -> tuple[DimSize, to e1, tok = self.expr(tok) tok = self.consume_token(tok, tokenize.RPAR) return _DimExpr._from_operation(op, e1, - scope=self.scope), tok # type: ignore + scope=self.scope), tok def factor_binary_op(self, op: str, tok) -> 
tuple[DimSize, tokenize.TokenInfo]: tok = self.consume_token(tok, tokenize.LPAR) @@ -1679,7 +1671,7 @@ def factor_binary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]: if op == _DimFactor.MIN: return core.min_dim(e1, e2), tok return _DimExpr._from_operation(op, e1, e2, - scope=self.scope), tok # type: ignore + scope=self.scope), tok def _evaluate_add(v1, v2): @@ -1730,7 +1722,7 @@ def _dimension_size_lowering_rule(ctx, arg, *, dimension): mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule) -def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]: +def all_dim_vars(args_avals: Sequence[core.ShapedArray]) -> Sequence[str]: dim_vars: set[str] = set() for a in args_avals: for d in a.shape: @@ -1869,7 +1861,7 @@ def check_statically(self, eval: CachingShapeEvaluator) -> None: def shape_assertions(self, eval: CachingShapeEvaluator) -> None: """Computes the shape assertions for the set of constraints. - See jax_export._wrap_main_func docstring. + See jax_export.Exported docstring. """ # We want to report the errors in the same order as `check_statically`. # So, we process them in order, in case some fail statically, and we @@ -1924,7 +1916,7 @@ def pretty_print_dimension_descriptor( @util.cache() def solve_dim_vars( - args_avals: Sequence[core.AbstractValue], + args_avals: Sequence[core.ShapedArray], args_kwargs_tree: tree_util.PyTreeDef, ) -> tuple[DimVarEnv, ShapeConstraints, Sequence[tuple[str, int, int]]]: """Solves dimension variables in a called function's avals in terms of actual argument shapes. @@ -1989,7 +1981,7 @@ def solve_dim_vars( def compute_dim_vars_from_arg_shapes( - args_avals: Sequence[core.AbstractValue], + args_avals: Sequence[core.ShapedArray], *actual_args: jax.Array, args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]: """Computes values of dimension variables to unify args_avals with actual arguments. @@ -2030,7 +2022,7 @@ def _solve_dim_equations( " Using the following polymorphic shapes specifications: " + ",".join(f"{arg_name}.shape = {arg_spec}" for arg_name, arg_spec in polymorphic_shape_specs)) + "." - solution_err_msg_trailer_errors = ". Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + solution_err_msg_trailer_errors = ". Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." shape_constraints = ShapeConstraints() # accumulate shape constraints scope: SymbolicScope | None = None @@ -2084,9 +2076,9 @@ def process_one_eqn(eqn: _DimEquation) -> bool: solution_err_msg_trailer_errors])) if not isinstance(var_value, _DimExpr): - assert var_value.dtype == core.dim_value_dtype() + assert var_value.dtype == core.dim_value_dtype() # type: ignore[attribute-error,unused-ignore] shape_env[var] = var_value # type: ignore - solution_error_message_pieces.extend([ + solution_error_message_pieces.extend([ # type: ignore[container-type-mismatch,unused-ignore] f"'{var}' = ", var_value, f" from specification '{eqn.aval_dim_expr}' " f"for dimension {eqn.dim_name} (= ", @@ -2159,6 +2151,6 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv): " Unprocessed specifications: " + ", ".join(f"'{eqn.aval_dim_expr}' for dimension size {eqn.dim_name}" for eqn in eqns) + - ". Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + ". 
Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." ) raise ValueError(err_msg) diff --git a/jax/experimental/export/_shape_poly_decision.py b/jax/_src/export/shape_poly_decision.py similarity index 98% rename from jax/experimental/export/_shape_poly_decision.py rename to jax/_src/export/shape_poly_decision.py index 9e35d82972ee..e325722b0c26 100644 --- a/jax/experimental/export/_shape_poly_decision.py +++ b/jax/_src/export/shape_poly_decision.py @@ -23,8 +23,8 @@ import numpy as np -from jax.experimental.export import _shape_poly -from jax.experimental.export._shape_poly import ( +from jax._src.export import shape_poly +from jax._src.export.shape_poly import ( _DimExpr, _DimTerm, _DimFactor, SymbolicScope, DimSize, @@ -43,7 +43,7 @@ def bounds_decision(e: DimSize, decision = _DecisionByElimination.build(e.scope) return decision.bounds(e, prec, add_implicit_constraints=True) -_shape_poly._bounds_decision = bounds_decision +shape_poly._bounds_decision = bounds_decision class _DecisionByElimination: @@ -183,7 +183,7 @@ def add_to_state(self, lead_t_constraints.add((cmp, lead_t_k, e)) def combine_term_with_existing(self, t: _DimTerm, t_k: int, *, - scope: _shape_poly.SymbolicScope, + scope: shape_poly.SymbolicScope, only_smaller_than_t=True, ) -> Sequence[tuple[Comparator, _DimExpr, @@ -292,7 +292,7 @@ def _bounds_for_sorted_terms(self, prec: BoundsPrecision) -> tuple[float, float]: """The lower and upper bounds of e[i:]. - See comments about soundness and `cmp_with` in the `_shape_poly.bounds_decision`` method. + See comments about soundness and `cmp_with` in the `shape_poly.bounds_decision`` method. Returns (lower-bound, upper-bound) """ if i >= len(e): return (0, 0) diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py new file mode 100644 index 000000000000..428187d57fdb --- /dev/null +++ b/jax/_src/extend/ffi.py @@ -0,0 +1,257 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Iterable, Mapping, Sequence +import ctypes +import functools +import os +from typing import Any + +from jax._src import core +from jax._src import dispatch +from jax._src import dtypes +from jax._src import util +from jax._src.callback import _check_shape_dtype, callback_batching_rule +from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.lib import jaxlib +from jax._src.lib.mlir import ir +from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray +import numpy as np + +map, unsafe_map = util.safe_map, map + + +def pycapsule(funcptr): + """Wrap a ctypes function pointer in a PyCapsule. + + The primary use of this function, and the reason why it lives with in the + ``jax.extend.ffi`` submodule, is to wrap function calls from external + compiled libraries to be registered as XLA custom calls. 
+ + Example usage:: + + import ctypes + import jax + from jax.lib import xla_client + + libfoo = ctypes.cdll.LoadLibrary('./foo.so') + xla_client.register_custom_call_target( + name="bar", + fn=jax.extend.ffi.pycapsule(libfoo.bar), + platform=PLATFORM, + api_version=API_VERSION + ) + + Args: + funcptr: A function pointer loaded from a dynamic library using ``ctypes``. + + Returns: + An opaque ``PyCapsule`` object wrapping ``funcptr``. + """ + destructor = ctypes.CFUNCTYPE(None, ctypes.py_object) + builder = ctypes.pythonapi.PyCapsule_New + builder.restype = ctypes.py_object + builder.argtypes = (ctypes.c_void_p, ctypes.c_char_p, destructor) + return builder(funcptr, None, destructor(0)) + + +def include_dir() -> str: + """Get the path to the directory containing header files bundled with jaxlib""" + jaxlib_dir = os.path.dirname(os.path.abspath(jaxlib.__file__)) + return os.path.join(jaxlib_dir, "include") + + +def ffi_lowering( + call_target_name: str, + *, + operand_layouts: Sequence[Sequence[DimSize]] | None = None, + result_layouts: Sequence[Sequence[DimSize]] | None = None, + backend_config: Mapping[str, ir.Attribute] | None = None, + **lowering_args: Any +) -> mlir.LoweringRule: + """Build a lowering rule for an foreign function interface (FFI) target. + + By default, this lowering rule can use the input and output abstract values to + compute the input and output types and shapes for the custom call, assuming + row-major layouts. + + If keyword arguments are passed to the lowering rule, these are treated as + attributes, and added to `backend_config`. + + Args: + call_target_name: The name of the custom call target. + operand_layouts: A sequence of layouts (dimension orders) for each operand. + By default, the operands are assumed to be row-major. + result_layouts: A sequence of layouts (dimension orders) for each result. + By default, the results are assumed to be row-major. + backend_config: Configuration data for the custom call. Any keyword + arguments passed to the lowering rule will added to this dictionary. + lowering_args: Any other arguments to :func:`mlir.custom_call` will also be + passed through if provided as extra arguments to this function. + """ + + def _lowering( + ctx: mlir.LoweringRuleContext, *operands: ir.Value, **params: Any + ) -> Sequence[ir.Value | Sequence[ir.Value]]: + kwargs = dict(lowering_args) + kwargs.setdefault("api_version", 4) + kwargs["backend_config"] = dict( + backend_config or {}, **{k: _ir_attribute(v) for k, v in params.items()}) + if "result_types" not in kwargs: + kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out] + if operand_layouts is None: + kwargs["operand_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_in) # pytype: disable=attribute-error + if result_layouts is None: + kwargs["result_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_out) + + return mlir.custom_call(call_target_name, operands=operands, **kwargs).results # type: ignore + + return _lowering + + +def _default_layouts(shapes: Iterable[Sequence[DimSize]]) -> list[list[DimSize]]: + return [list(reversed(range(len(shape)))) for shape in shapes] + + +def _ir_attribute(obj: Any) -> ir.Attribute: + # TODO(dfm): Similar functions exist in Pallas and Mosaic GPU. Perhaps these + # could be consolidated into mlir or similar. 
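# A hedged sketch (not from the patch itself) of wiring ffi_lowering, defined
# above, to a JAX primitive, mirroring the ffi_call_p registration later in
# this file. The primitive name "my_prim" and the custom-call target
# "my_custom_target" are hypothetical; the target itself would still have to
# be registered with XLA separately.
import functools

from jax._src import core, dispatch
from jax._src.interpreters import mlir
from jax._src.extend.ffi import ffi_lowering

my_prim = core.Primitive("my_prim")
my_prim.def_impl(functools.partial(dispatch.apply_primitive, my_prim))
my_prim.def_abstract_eval(lambda aval, **params: aval)  # output aval == input aval
mlir.register_lowering(my_prim, ffi_lowering("my_custom_target"))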
+ if isinstance(obj, str): + return ir.StringAttr.get(obj) + elif isinstance(obj, bool): + return ir.BoolAttr.get(obj) + elif isinstance(obj, int): + return mlir.i64_attr(obj) + elif isinstance(obj, float): + return ir.FloatAttr.get_f64(obj) + elif hasattr(obj, "dtype"): + if not (dtypes.is_python_scalar(obj) or np.isscalar(obj)): + raise TypeError("Only scalar attributes are supported") + mlir_type = mlir.dtype_to_ir_type(obj.dtype) + if isinstance(mlir_type, ir.IntegerType): + return ir.IntegerAttr.get(mlir_type, obj) + elif isinstance(mlir_type, ir.FloatType): + return ir.FloatAttr.get(mlir_type, obj) + raise TypeError(f"Unsupported attribute type: {type(obj)}") + + +def ffi_call( + target_name: str, + result_shape_dtypes: DuckTypedArray | Sequence[DuckTypedArray], + *args: ArrayLike, + vectorized: bool = False, + **kwargs: Any, +) -> Array | list[Array]: + """Call a foreign function interface (FFI) target. + + Like :func:`~jax.pure_callback`, the behavior of ``ffi_call`` under + :func:`~jax.vmap` depends on the value of ``vectorized``. When ``vectorized`` + is ``True``, the FFI target is assumed to satisfy: ``ffi_call(xs) == + jnp.stack([ffi_call(x) for x in xs])``. In other words, calling the FFI target + with an extra leading dimension should return the same result as calling it + within a loop and stacking along the zeroth axis. Therefore, the FFI target + will be called directly on batched inputs (where the batch axes are the + leading dimensions). Additionally, the callbacks should return outputs that + have corresponding leading batch axes. If ``vectorized`` is ``False`` (the + default behavior), transforming this ``ffi_call`` under :func:`~jax.vmap` will + result in a :func:`~jax.lax.scan` with the ``ffi_call`` in the body. + + Args: + target_name: the name of the XLA FFI custom call target that was registered + using :func:`~jaxlib.xla_client.register_custom_call_target`. + result_shape_dtypes: an object, or sequence of objects, with ``shape`` and + ``dtype`` attributes which are expected to match the shape and dtype of + the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often + used to define the elements of ``result_shape_dtypes``. + *args: the arguments passed to the custom call. + vectorized: boolean specifying whether the callback function can operate in + a vectorized manner, as described above. + **kwargs: keyword arguments that are passed as named attributes to the + custom call using XLA's FFI interface. + + Returns: + One or more :class:`~jax.Array` objects whose shapes and dtypes match + ``result_shape_dtypes``. + """ + if isinstance(result_shape_dtypes, Sequence): + multiple_results = True + result_types = result_shape_dtypes + else: + multiple_results = False + result_types = (result_shape_dtypes,) + map(_check_shape_dtype, result_types) + result_avals = tuple(core.ShapedArray(x.shape, x.dtype) for x in result_types) + results = ffi_call_p.bind( + *args, + result_avals=result_avals, + vectorized=vectorized, + target_name=target_name, + **kwargs, + ) + if multiple_results: + return results + else: + return results[0] + + +def ffi_call_abstract_eval( + *avals_in, + result_avals: tuple[core.ShapedArray, ...], + target_name: str, + vectorized: bool, + **kwargs: Any, +): + del avals_in, target_name, vectorized, kwargs + return result_avals + + +def ffi_call_jvp(*args, target_name, **kwargs): + del args, kwargs + raise ValueError( + f"The FFI call to `{target_name}` cannot be differentiated. 
+
+
+def ffi_call_transpose(*args, target_name, **kwargs):
+  del args, kwargs
+  raise ValueError(
+      f"The FFI call to `{target_name}` cannot be differentiated. "
+      "You can use `jax.custom_jvp` or `jax.custom_vjp` to add support.")
+
+
+def ffi_call_lowering(
+    ctx: mlir.LoweringRuleContext,
+    *operands: ir.Value,
+    result_avals: tuple[core.ShapedArray, ...],
+    target_name: str,
+    vectorized: bool,
+    **kwargs: Any,
+) -> Sequence[ir.Value]:
+  del result_avals, vectorized
+  return ffi_lowering(target_name)(ctx, *operands, **kwargs)
+
+
+ffi_call_p = core.Primitive("ffi_call")
+ffi_call_p.multiple_results = True
+ffi_call_p.def_impl(functools.partial(dispatch.apply_primitive, ffi_call_p))
+ffi_call_p.def_abstract_eval(ffi_call_abstract_eval)
+ad.primitive_jvps[ffi_call_p] = ffi_call_jvp
+ad.primitive_transposes[ffi_call_p] = ffi_call_transpose
+batching.primitive_batchers[ffi_call_p] = functools.partial(
+    callback_batching_rule, ffi_call_p)
+mlir.register_lowering(ffi_call_p, ffi_call_lowering)
diff --git a/jax/_src/extend/random.py b/jax/_src/extend/random.py
index ffc7f63072b9..df927486dd2f 100644
--- a/jax/_src/extend/random.py
+++ b/jax/_src/extend/random.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable
-from collections.abc import Hashable
+from collections.abc import Callable, Hashable
 
 from jax import Array
 
diff --git a/jax/_src/gfile_cache.py b/jax/_src/gfile_cache.py
index 301f61cc6bdb..989844b10ddb 100644
--- a/jax/_src/gfile_cache.py
+++ b/jax/_src/gfile_cache.py
@@ -17,6 +17,10 @@
 from jax._src import path as pathlib
 from jax._src.compilation_cache_interface import CacheInterface
 
+
+# TODO (ayx): This class will be ultimately removed after `lru_cache.py` is
+# finished. It exists because the current `lru_cache.py` does not support
+# `gs://`.
 class GFileCache(CacheInterface):
 
   def __init__(self, path: str):
diff --git a/jax/_src/image/scale.py b/jax/_src/image/scale.py
index 1ef1db916f73..aa9910555130 100644
--- a/jax/_src/image/scale.py
+++ b/jax/_src/image/scale.py
@@ -14,10 +14,9 @@
 
 from __future__ import annotations
 
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial
 import enum
-from typing import Callable
 
 import numpy as np
 
diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_ducc_fft.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_ducc_fft.py
deleted file mode 100644
index d92a90b36ce6..000000000000
--- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_ducc_fft.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Copyright 2023 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -import datetime -from numpy import array, float32, complex64 - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17 = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['ducc_fft'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(array([[ 0., 1., 2., 3.], - [ 4., 5., 6., 7.], - [ 8., 9., 10., 11.]], dtype=float32),), - expected_outputs=(array([[ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j], - [22.+0.j, -2.+2.j, -2.+0.j, -2.-2.j], - [38.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]], dtype=complex64),), - mlir_module_text=r""" -module @jit_func { - func.func public @main(%arg0: tensor<3x4xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<3x4xcomplex> {jax.result_info = ""}) { - %0 = call @fft(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xcomplex> - return %0 : tensor<3x4xcomplex> - } - func.func private @fft(%arg0: tensor<3x4xf32>) -> tensor<3x4xcomplex> { - %0 = stablehlo.convert %arg0 : (tensor<3x4xf32>) -> tensor<3x4xcomplex> - %1 = stablehlo.constant dense<"0x18000000140024000000000008000C001000140007001800140000000000000154000000380000001C00000010000000000000000000F03F0000000001000000010000000200000004000000000000000100000000000000000000000200000004000000000000000100000000000000000000000200000003000000000000000400000000000000"> : tensor<136xui8> - %2 = stablehlo.custom_call @ducc_fft(%1, %0) {api_version = 2 : i32, operand_layouts = [dense<0> : tensor<1xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<136xui8>, tensor<3x4xcomplex>) -> tensor<3x4xcomplex> - return %2 : tensor<3x4xcomplex> - } -} -""", - mlir_module_serialized=b'ML\xefR\x03MLIRxxx-trunk\x00\x01\x1d\x05\x01\x05\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\x99s\x15\x01?\x07\x0b\x0f\x17\x0b\x0b\x0b\x0b\x0f\x13\x0b33\x0b\x0b\x0f\x0b\x13\x0b\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x035\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0bf\x04\x0b\x0b\x0b\x13/\x0f\x03\x15\x17\x17\x07\x17\x07\x17\x0b\x07\x13\x13\x02\xca\x05\x1f\x05\x13\x1d\x1b\x07\x17\x1d^\x03\x01\x05\x15\x05\x17\x05\x19\x05\x1b\x1d\'\x07\x03\x03\x03\x15\x05\x1d\x03\x0b\tK\x0b?\rW\x03]\x0f_\x03\x0b\tC\x0b?\rC\x03E\x0fc\x05\x1f\x05!\x1d!\x07\x05#\x03\x03%e\x05%\x05\'\x03\x11+g-A/i1G3k5m7G9q\x05)\x05+\x05-\x05/\x051\x053\x055\x057\x03\x03=E\x059#\x0b\x1d;\x03\x03a\x1d=\x03\x01\x1f\x13!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03M\r\x05OQSU\x1d?\x1dA\x1dC\x1dE\x03\x03Y\r\x03[A\x1dG\x1dI\x1dK\r\x01\x1dM\x1f\x07"\x02\x18\x00\x00\x00\x14\x00$\x00\x00\x00\x00\x00\x08\x00\x0c\x00\x10\x00\x14\x00\x07\x00\x18\x00\x14\x00\x00\x00\x00\x00\x00\x01T\x00\x00\x008\x00\x00\x00\x1c\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x05\x1dO\x05\x01\x03\x05oI\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03I)\x05\r\x11\r)\x05\r\x11\x05\t)\x03B\x04\x0f\x13\x11\x03\x03\x03\x01\x03\x05!)\x03\x05\t)\x03\t\t\x04\x8f\x05\x01\x11\x01\x13\x07\x03\x01\t\x03\x11\x01\x17\x05\x03\x05\x0b\x03\x03\x01\r\x07\x05;\x03\x01\x03\x01\x05\x04\x01\x03\x03\x03\x11\x05\x19\x05\x03\t\x13\x03\x03\x01\x07\x06\x1f\x03\x01\x03\x01\t\x03\x11#\x03\x07\x0b\x07\x11)\x03\x01\x05\x05\x03\x05\x04\x05\x03\x07\x06\x03\x01\x05\x01\
x00\xc2\x0eQ\x13\x11\x0f\x0b!\x1b\x1d\x05\x1b\t\x03\x0f\x1f/!!)#\x1f\x19\x91\r\xaf\x83\x82\x04\x13\x1f\x15\x1d\x15\x13\x11\x1f\x19\x17\x15\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00return_v1\x00convert_v1\x00constant_v1\x00custom_call_v1\x00call_v1\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]\x00value\x00jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00fft\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00ducc_fft\x00', - xla_call_module_version=4, -) # End paste - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14 = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['dynamic_ducc_fft'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(array([[ 0., 1., 2., 3.], - [ 4., 5., 6., 7.], - [ 8., 9., 10., 11.]], dtype=float32),), - expected_outputs=(array([[ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j], - [22.+0.j, -2.+2.j, -2.+0.j, -2.-2.j], - [38.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]], dtype=complex64),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<3x4xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<3x4xcomplex> {jax.result_info = ""}) { - %0 = call @fft(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xcomplex> loc(#loc3) - return %0 : tensor<3x4xcomplex> loc(#loc) - } loc(#loc) - func.func private @fft(%arg0: tensor<3x4xf32> loc(unknown)) -> tensor<3x4xcomplex> { - %0 = stablehlo.convert %arg0 : (tensor<3x4xf32>) -> tensor<3x4xcomplex> loc(#loc4) - %1 = stablehlo.constant dense<4> : tensor loc(#loc5) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.convert %1 : (tensor) -> tensor loc(#loc5) - %4 = stablehlo.reshape %3 : (tensor) -> tensor<1xi32> loc(#loc5) - %5 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %6 = stablehlo.reshape %5 : (tensor) -> tensor<1xi32> loc(#loc5) - %7 = stablehlo.concatenate %4, %6, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %8 = stablehlo.constant dense<4> : tensor loc(#loc5) - %9 = stablehlo.constant dense<1> : tensor loc(#loc5) - %10 = stablehlo.convert %8 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %13 = stablehlo.reshape %12 : (tensor) -> tensor<1xi32> loc(#loc5) - %14 = stablehlo.concatenate %11, %13, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %15 = stablehlo.constant dense<4> : tensor loc(#loc5) - %16 = stablehlo.convert %15 : (tensor) -> tensor loc(#loc5) - %17 = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc5) - %18 = stablehlo.reshape %17 : (tensor) -> tensor<1xf64> loc(#loc5) - %19 = 
stablehlo.constant dense<3> : tensor loc(#loc5) - %20 = stablehlo.constant dense<4> : tensor loc(#loc5) - %21 = stablehlo.convert %19 : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.reshape %21 : (tensor) -> tensor<1xi32> loc(#loc5) - %23 = stablehlo.convert %20 : (tensor) -> tensor loc(#loc5) - %24 = stablehlo.reshape %23 : (tensor) -> tensor<1xi32> loc(#loc5) - %25 = stablehlo.concatenate %22, %24, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %26 = stablehlo.constant dense<3> : tensor loc(#loc5) - %27 = stablehlo.constant dense<4> : tensor loc(#loc5) - %28 = stablehlo.convert %26 : (tensor) -> tensor loc(#loc5) - %29 = stablehlo.reshape %28 : (tensor) -> tensor<1xi32> loc(#loc5) - %30 = stablehlo.convert %27 : (tensor) -> tensor loc(#loc5) - %31 = stablehlo.reshape %30 : (tensor) -> tensor<1xi32> loc(#loc5) - %32 = stablehlo.concatenate %29, %31, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %33 = stablehlo.constant dense<[20, 0, 0, 0, 0, 0, 14, 0, 16, 0, 8, 0, 0, 0, 0, 0, 12, 0, 7, 0, 14, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]> : tensor<44xui8> loc(#loc5) - %34 = stablehlo.custom_call @dynamic_ducc_fft(%33, %0, %25, %7, %14, %18, %32) {api_version = 2 : i32, indices_of_shape_operands = dense<6> : tensor<1xi64>, operand_layouts = [dense<0> : tensor<1xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<44xui8>, tensor<3x4xcomplex>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1xf64>, tensor<2xi32>) -> tensor<3x4xcomplex> loc(#loc5) - return %34 : tensor<3x4xcomplex> loc(#loc3) - } loc(#loc3) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":437:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":428:0) -#loc3 = loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft keep_unused=False inline=False]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]"(#loc2)) -#loc5 = loc("jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01!\x05\x01\x03\x01\x03\x05\x03\x11\x07\t\x0b\r\x0f\x11\x13\x15\x03\xd3\x95+\x01U\x0f\x07\x13\x0b\x13\x0b\x0f\x0f\x0b\x0b\x0b\x0b\x0b\x17\x13\x13#\x0b\x0b\x0b33\x0b\x17\x0f\x0b\x0b\x0b\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x03A/\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b//\x0f//\xbf\x0b\x0b\x0b/'\x0f\x01\x03\x0f\x03)\x0f\x0f\x13\x17\x13\x17\x07\x07\x0f\x07\x07\x13\x07\x17\x0b\x13\x07\x13\x13\x13\x02r\x06\x1d5\x1b\x1f\x03\x03\x07}\x05\x17\x03\x037\x81\x05\x19\x1d-/\x11\x01\x05\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x17\x19\xb2\x06\x01\x03\x03\x07\x7f\x03\x03\x07\x85\x03\x07#\x0f%\x0f\x0b'\x05%\x05'\x05)\x03\x0b\x11c\x13W\x15o\x0bu\x17w\x03\x0b\x11[\x13W\x15[\x0b]\x17{\x05+\x17\x19\xd6\x06\x01\x1d3\x1b\x05-\x05/\x051\x03\x03\x07\x83\x03\x03\x07\x87\x03\x13?\x89AYC\x8bE_G\x8dI\x8fK\x91M_O\x93\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\x03S]\x05E\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1d\x1dG\x03\x03y\x1dI\x03\x01\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03e\r\x05gikm\x1dK\x1dM\x1dO\x1dQ\x03\x03q\r\x03sY\x1dS\x1dU\x1dW\r\x01\x1dY\x1f\x03\x11\x04\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x13\x0f\x01\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf0?\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19Y\x14\x00\x00\x00\x00\x00\x0e\x00\x10\x00\x08\x00\x00\x00\x00\x00\x0c\x00\x07\x00\x0e\x00\x00\x00\x00\x00\x00\x01\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x0b\x05\x1d[\x05\x01\x1f%\x11\x06\x00\x00\x00\x00\x00\x00\x00\x03\x0fUaUUUUU\x03\x03a\x01\x02\x02)\x01\x0f)\x01\x11)\x03\x05\x11)\x05\r\x11\x1f)\x03\t\x11)\x05\r\x11\x15\x1d\x1b)\x01\x17\t\x0b)\x03\xb1#\x13\x11\x03\r\x03\t\x03\x15)\x03\x05\x17!)\x03\x05\x0f)\x03\x05\x1b)\x03\t\x1b\x04\xaa\x04\x05\x01\x11\x03!\x07\x03\x01\t\x0b\x11\x03)\x05\x03\x05\x0b\x03\r\x03\x11\x07\rQ\x03\t\x03\x01\r\x04\x03\x03\x03\x0b\x11\r+\x05\x03I\x93\x03\r\x03\x05\x061\x03\t\x03\x01\x03\x03\x01\x05\x03\x03\x03\x03\x01\x1d\x03\x03\x05\x06\x01\x03\x05\x03\x05\x07\x06\x01\x03\x07\x03\t\x05\x06\x01\x03\x05\x03\x07\x07\x06\x01\x03\x07\x03\r\t\x07\x01\t\x03\x0b\x05\x0b\x0f\x03\x03\x01\x05\x03\x03\x03\x03\x01\x1d\x03\x03\x05\x06\x01\x03\x05\x03\x13\x07\x06\x01\x03\x07\x03\x17\x05\x06\x01\x03\x05\x03\x15\x07\x06\x01\x03\x07\x03\x1b\t\x07\x01\t\x03\x0b\x05\x19\x1d\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x13\x03!\x03\x03\x019\x03\x13\x07\x06\x01\x03!\x03%\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x05\x03)\x07\x06\x01\x03\x07\x03-\x05\x06\x01\x03\x05\x03+\x07\x06\x01\x03\x07\x031\t\x07\x01\t\x03\x0b\x05/3\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x05\x037\x07\x06\x01\x03\x07\x03;\x05\x06\x01\x03\x05\x039\x07\x06\x01\x03\x07\x03?\t\x07\x01\t\x03\x0b\x05=A\x03\x03\x01;\x03\x19\x0f\x07\x01=\x03\t\x0fE\x035\x11\x1f'C\r\x04\r\x03G\x06\x03\x01\x05\x01\x00\xc6\x0e]#\x11\x0f\x0b!\x1b\x1d\x05\x1b\t\x03\x0f\x1f/!5!)#\x1f\x19\x15\x91\xaf\xbe\x02\x13%)\x83\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x11\x1f\x17\x17\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00convert_v1\x00reshape_v1\x00concatenate_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) 
resource_env=None donated_invars=(False,) name=fft keep_unused=False inline=False]\x00jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]\x00jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00indices_of_shape_operands\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00fft\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00dynamic_ducc_fft\x00", - xla_call_module_version=6, -) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/cuda_add_one.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/cuda_add_one.py new file mode 100644 index 000000000000..4a1bbe63b9be --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/cuda_add_one.py @@ -0,0 +1,47 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +from numpy import array, float32 + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_05_02 = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['__gpu$xla.gpu.triton'], + serialized_date=datetime.date(2024, 5, 2), + inputs=(array([0., 1., 2., 3., 4., 5., 6., 7.], dtype=float32),), + expected_outputs=(array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":43:13) +#loc3 = loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]"(#loc2)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<8xf32> {mhlo.layout_mode = "default"} loc("x")) -> (tensor<8xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = call @wrapped(%arg0) : (tensor<8xf32>) -> tensor<8xf32> loc(#loc3) + return %0 : tensor<8xf32> loc(#loc) + } loc(#loc) + func.func private @wrapped(%arg0: tensor<8xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]"(#loc2))) -> (tensor<8xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.custom_call @__gpu$xla.gpu.triton(%arg0) {mhlo.backend_config = {debug = false, grid_x = 8 : i32, grid_y = 1 : i32, grid_z = 1 : i32, ir = 
"ML\EFR\0DMLIRgooglex-trunk\00\01-\07\01\05\09\17\01\03\0F\03\0D\13\17\1B\1F#'\05\09+/37\03O1\0B\01-\07\0F\0F\0F\0F\13\13\13\0B\0F\0F\13\0B\0F\0B\0B\0B\0B\1F\0B\0B\13\05\05YY\01\09\0F\07\17\0B\03\035\02\16\02\1F\11\01\05\1D)+\1D#\0F\11\01\01#\01\01\01\03\03\19\1B\17\11U'\05\1D\11\07\00\1D'\0F\01\05\0D\0D\05\1F\11\01\81\0D\05\05!\05#\05%\13\03\10\00\00\E0\0F\05'\05)\17\11U\11#arith.overflow\00#arith.fastmath\00\01\02\02\0B\05\05\09\09\01\01\09!tt.ptr\00\04\D2\02\05\01P\01\01\07\04\AE\02\03\01\05\07P\01\03\07\04\82\02\03+W\05\11\11\00\09B\01\05\03\01\0FB\01\07\03\01\11F\01\09\03\01\05\05\07\0FB\01\07\03\01\11F\01\09\03\01\05\05\0B\0FB\07\05\03\01\13F\07\09\03\01\05\0F\09\0FB\07\07\03\01\11F\07\09\03\01\05\11\13\03\06\07\03\09\05\01\15\05F\07\0B\03\03\03\17\0FB\15\0D\03\03\15F\15\0F\03\03\05\19\1B\0FB\05\05\03\01\13F\05\09\03\01\05\1F\0D\0FB\05\07\03\01\11F\05\09\03\01\05!#\03\06\05\03\09\05\03%\05F\05\0B\03\03\03'\0BD\05\11\05'\1D\0D\00\01\06\03\01\05\01\00\F2\05+\A5\0B\A3\0F\11!\85\0B\0B\0B\13\0F\0D\1F\0B\0B\0F\0F\0D\07\11builtin\00tt\00arith\00module\00addptr\00load\00func\00get_program_id\00store\00return\00constant\00muli\00addi\00addf\00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\00tt.divisibility\00add_one\00public\00/get[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\00/add\00/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\00\08C\13\05\01\01\0B/\1D\01\1FC\03\09\03\03\03[\11\17\07\07'\01\07\01\03\03%\03_\07\17\07\07", name = "add_one", num_stages = 3 : i32, num_warps = 4 : i32}, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<8xf32>) -> tensor<8xf32> loc(#loc4) + return %0 : tensor<8xf32> loc(#loc3) + } loc(#loc3) +} loc(#loc) +#loc = loc(unknown) +#loc4 = loc("jit(func)/jit(main)/jit(wrapped)/pallas_call[name=add_one which_linear=(False,) in_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) debug=False interpret=False grid_mapping=GridMapping(grid=(8,), block_mappings=(BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. let in (a,) }, indexing_mode=), BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. 
let in (a,) }, indexing_mode=)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0) input_output_aliases=() compiler_params={}]"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xaf\x8b\x11\x01G\x07\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x13\x0b\x03E\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x13\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0bK\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f/\x01\x05\x0b\x0f\x03\r\x13\x07\x17\x07\x13\x07\x02\xee\x03\x1f\x1d#\x11\x05\x0f\x11\x03\x05\x05\x11\x05\x13\x05\x15\x05\x17\x17%W\x1b\x03\t\x15\x17\x19\x07\x1b\x07\x05\x1d\x05\x19\x11\x01\x00\x05\x1b\x05\x1d\x05\x1f\x03\x0b\tG\x0bM\r]\x05c\x0fe\x03\x0b\tG\x0bM\rG\x05Q\x0fg\x05!\x05#\x03\x13)i+O-k/S1U3m5Y7S9Y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x055\x1d=\x11\x057\x1dA\x01\x059\x03\x03EQ\x05;\x03\x03[\x1d=\x1d?#\t\x1dA\x1dC\x03\x01\x05\x01\x13\x07\x05\x03\x03\x89\r\x03IK\x03\x03_\r\x05aOIK\x1dE\x1dG\x1dI\x1dK\x0b\x03\x1dM\r\x11oUqsuWwWy{}\x7f\x81\x83\x85\x87\x1dO\x1dQ\x13\x07!\x1dS\x1dU\x1dW\x1dY\x1d[\x1d]\x1d_\x13\x07\r\x1da\x13\x07\x11\x1f\r\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x03!\x0b\x1b\x11\x03\x05\x03\x05\t)\x03\x05\x0f\x13\x04s\x05\x01\x11\x01\x13\x07\x03\x01\t\x03\x11\x01\x1f\x07\x03\x05\x0b\x03\x05?\t\x07\x03C\x03\x05\x03\x01\x05\x04\x01\x03\x03\x03\x11\x03!\x07\x03\x05\x0b\x03\x05\x03\x07\x07;'\x03\x05\x03\x01\x05\x04\x03\x03\x03\x06\x03\x01\x05\x01\x00\x12%c\x15\x17\x11\x0b\xfe\x0c\x07\x0f\x0f\x0f\r+\x11\x0f\x0b!\x11\x03\x11#\x0f\x05\xd2\n\x1f/!)!)#\x1f\x19\x85j\x03\x13%)9\x1f\x15\x1d\x15\x13\x11\x1f\x15\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00return_v1\x00custom_call_v1\x00call_v1\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit(func)/jit(main)/jit(wrapped)/pallas_call[name=add_one which_linear=(False,) in_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) debug=False interpret=False grid_mapping=GridMapping(grid=(8,), block_mappings=(BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. let in (a,) }, indexing_mode=), BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. 
let in (a,) }, indexing_mode=)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0) input_output_aliases=() compiler_params={}]\x00x\x00callee\x00mhlo.layout_mode\x00default\x00\x00wrapped\x00jax.result_info\x00main\x00public\x00private\x00__gpu$xla.gpu.triton\x00debug\x00grid_x\x00grid_y\x00grid_z\x00ir\x00ML\xefR\rMLIRgooglex-trunk\x00\x01-\x07\x01\x05\t\x17\x01\x03\x0f\x03\r\x13\x17\x1b\x1f#'\x05\t+/37\x03O1\x0b\x01-\x07\x0f\x0f\x0f\x0f\x13\x13\x13\x0b\x0f\x0f\x13\x0b\x0f\x0b\x0b\x0b\x0b\x1f\x0b\x0b\x13\x05\x05YY\x01\t\x0f\x07\x17\x0b\x03\x035\x02\x16\x02\x1f\x11\x01\x05\x1d)+\x1d#\x0f\x11\x01\x01#\x01\x01\x01\x03\x03\x19\x1b\x17\x11U'\x05\x1d\x11\x07\x00\x1d'\x0f\x01\x05\r\r\x05\x1f\x11\x01\x81\r\x05\x05!\x05#\x05%\x13\x03\x10\x00\x00\xe0\x0f\x05'\x05)\x17\x11U\x11#arith.overflow\x00#arith.fastmath\x00\x01\x02\x02\x0b\x05\x05\t\t\x01\x01\t!tt.ptr\x00\x04\xd2\x02\x05\x01P\x01\x01\x07\x04\xae\x02\x03\x01\x05\x07P\x01\x03\x07\x04\x82\x02\x03+W\x05\x11\x11\x00\tB\x01\x05\x03\x01\x0fB\x01\x07\x03\x01\x11F\x01\t\x03\x01\x05\x05\x07\x0fB\x01\x07\x03\x01\x11F\x01\t\x03\x01\x05\x05\x0b\x0fB\x07\x05\x03\x01\x13F\x07\t\x03\x01\x05\x0f\t\x0fB\x07\x07\x03\x01\x11F\x07\t\x03\x01\x05\x11\x13\x03\x06\x07\x03\t\x05\x01\x15\x05F\x07\x0b\x03\x03\x03\x17\x0fB\x15\r\x03\x03\x15F\x15\x0f\x03\x03\x05\x19\x1b\x0fB\x05\x05\x03\x01\x13F\x05\t\x03\x01\x05\x1f\r\x0fB\x05\x07\x03\x01\x11F\x05\t\x03\x01\x05!#\x03\x06\x05\x03\t\x05\x03%\x05F\x05\x0b\x03\x03\x03'\x0bD\x05\x11\x05'\x1d\r\x00\x01\x06\x03\x01\x05\x01\x00\xf2\x05+\xa5\x0b\xa3\x0f\x11!\x85\x0b\x0b\x0b\x13\x0f\r\x1f\x0b\x0b\x0f\x0f\r\x07\x11builtin\x00tt\x00arith\x00module\x00addptr\x00load\x00func\x00get_program_id\x00store\x00return\x00constant\x00muli\x00addi\x00addf\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00tt.divisibility\x00add_one\x00public\x00/get[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\x00/add\x00/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\x00\x08C\x13\x05\x01\x01\x0b/\x1d\x01\x1fC\x03\t\x03\x03\x03[\x11\x17\x07\x07'\x01\x07\x01\x03\x03%\x03_\x07\x17\x07\x07\x00name\x00add_one\x00num_stages\x00num_warps\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_approx_top_k.py b/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_approx_top_k.py new file mode 100644 index 000000000000..b676cc8011d3 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_approx_top_k.py @@ -0,0 +1,84 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ruff: noqa + +import datetime +from numpy import array, float32, int32 + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_05_30 = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['shape_assertion', 'stablehlo.dynamic_approx_top_k'], + serialized_date=datetime.date(2024, 5, 30), + inputs=(array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.], + dtype=float32),), + expected_outputs=(array([23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13., 12., 11., + 10., 9., 8., 7., 6., 5., 4.], dtype=float32), array([23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, + 6, 5, 4], dtype=int32)), + mlir_module_text=r""" +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":718:13) +#loc3 = loc("a") +#loc9 = loc("jit(func)/jit(main)/approx_top_k[k=b reduction_dimension=-1 recall_target=0.95 is_max_k=True reduction_input_size_override=-1 aggregate_to_topk=True]"(#loc2)) +module @jit_func attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"} loc(unknown)) -> (tensor {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor) -> tensor loc(#loc4) + %1 = stablehlo.convert %0 : (tensor) -> tensor loc(#loc4) + %2 = stablehlo.convert %1 : tensor loc(#loc5) + %c = stablehlo.constant dense<-4> : tensor loc(#loc) + %3 = stablehlo.add %2, %c : tensor loc(#loc6) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc) + %4 = stablehlo.compare GE, %3, %c_0, SIGNED : (tensor, tensor) -> tensor loc(#loc7) + stablehlo.custom_call @shape_assertion(%4, %3, %1) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'b'. Using the following polymorphic shapes specifications: args[0].shape = (b + 4,). Obtained dimension variables: 'b' = {0} from specification 'b + 4' for dimension args[0].shape[0] (= {1}), . 
Please see https://github.com/googlexjax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.", has_side_effect = true} : (tensor, tensor, tensor) -> () loc(#loc8) + %5:2 = call @_wrapped_jax_export_main(%3, %arg0) : (tensor, tensor) -> (tensor, tensor) loc(#loc) + return %5#0, %5#1 : tensor, tensor loc(#loc) + } loc(#loc) + func.func @top_k_gt_f32_comparator(%arg0: tensor loc("jit(func)/jit(main)/approx_top_k[k=b reduction_dimension=-1 recall_target=0.95 is_max_k=True reduction_input_size_override=-1 aggregate_to_topk=True]"(#loc2)), %arg1: tensor loc("jit(func)/jit(main)/approx_top_k[k=b reduction_dimension=-1 recall_target=0.95 is_max_k=True reduction_input_size_override=-1 aggregate_to_topk=True]"(#loc2)), %arg2: tensor loc("jit(func)/jit(main)/approx_top_k[k=b reduction_dimension=-1 recall_target=0.95 is_max_k=True reduction_input_size_override=-1 aggregate_to_topk=True]"(#loc2)), %arg3: tensor loc("jit(func)/jit(main)/approx_top_k[k=b reduction_dimension=-1 recall_target=0.95 is_max_k=True reduction_input_size_override=-1 aggregate_to_topk=True]"(#loc2))) -> tensor { + %0 = stablehlo.compare GT, %arg0, %arg1 : (tensor, tensor) -> tensor loc(#loc9) + return %0 : tensor loc(#loc9) + } loc(#loc9) + func.func private @_wrapped_jax_export_main(%arg0: tensor {jax.global_constant = "b", mhlo.layout_mode = "default"} loc(unknown), %arg1: tensor {mhlo.layout_mode = "default"} loc("a")) -> (tensor {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.convert %arg0 : tensor loc(#loc10) + %c = stablehlo.constant dense<4> : tensor loc(#loc9) + %1 = stablehlo.add %0, %c : tensor loc(#loc11) + %2 = stablehlo.convert %1 : (tensor) -> tensor loc(#loc9) + %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32> loc(#loc9) + %4 = stablehlo.dynamic_iota %3, dim = 0 : (tensor<1xi32>) -> tensor loc(#loc9) + %c_0 = stablehlo.constant dense<-1> : tensor loc(#loc9) + %cst = stablehlo.constant dense<0xFF800000> : tensor loc(#loc9) + %5 = stablehlo.convert %arg0 : (tensor) -> tensor loc(#loc9) + %6 = stablehlo.reshape %5 : (tensor) -> tensor<1xi32> loc(#loc9) + %7 = stablehlo.convert %arg0 : (tensor) -> tensor loc(#loc9) + %8 = stablehlo.reshape %7 : (tensor) -> tensor<1xi32> loc(#loc9) + %9 = stablehlo.convert %arg0 : (tensor) -> tensor loc(#loc9) + %10:2 = stablehlo.custom_call @stablehlo.dynamic_approx_top_k(%arg1, %4, %cst, %c_0, %9, %6, %8) {called_computations = [@top_k_gt_f32_comparator], indices_of_shape_operands = dense<[5, 6]> : tensor<2xi64>, mhlo.backend_config = {aggregate_to_topk = true, is_fallback = true, recall_target = 0.949999988 : f32, reduction_dim = 0 : i64, reduction_input_size_override = -1 : i64}} : (tensor, tensor, tensor, tensor, tensor, tensor<1xi32>, tensor<1xi32>) -> (tensor, tensor) loc(#loc9) + return %10#0, %10#1 : tensor, tensor loc(#loc) + } loc(#loc) +} loc(#loc) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":738:4) +#loc4 = loc("/dimension_size[dimension=0]"(#loc1)) +#loc5 = loc("/convert_element_type[new_dtype=int64 weak_type=False]"(#loc1)) +#loc6 = loc("/add"(#loc1)) +#loc7 = loc("/ge"(#loc1)) +#loc8 = loc("/shape_assertion[error_message=Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'b'. Using the following polymorphic shapes specifications: args[0].shape = (b + 4,). Obtained dimension variables: 'b' = {0} from specification 'b + 4' for dimension args[0].shape[0] (= {1}), . 
Please see https://github.com/googlexjax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.]"(#loc1)) +#loc10 = loc("jit(func)/jit(main)/convert_element_type[new_dtype=int64 weak_type=False]"(#loc2)) +#loc11 = loc("jit(func)/jit(main)/add"(#loc2)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\'\x05\x01\x03\x01\x03\x05\x03\x17\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x03B\x02\xeb#\x01\x85\x0f\x07\x0b\x17\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f+\x0b\x0f\x0b\x0b\x0b33\x0b3\x0f\x0b\x0f\x0b\x13\x0f\x0b\x13\x0b\x13\x13[\x0b\x0b\x1b\x13\x0b\x0b\x0f\x0b\x13\x0f\x0b\x13\x1b\x0f\x0bS\x0b\x0f\x0b\x13\x0b\x03g\x0b\x0b\x0b\x0b\x0f\x0b\x13\x13\x0b\x0b\x0b\x0f\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x0b\x0b/\x1f\x1f\x0b\x0b\x0f\x0bO3\x0b\x0b\x0b\x1f\x0b\x0b\x0f\x0b\x0b//\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x1f\x0f\x0f3\x0f3\x07\x07\x07\x0f\x13\x1b#\x07\x1f\x13\x02\x0e\x08\x1d?\x13\x1f\x05\x1d\x17\x17\x8a\x0b\t\x05\x1f\x05!\x05#\x05%\x05\'\x17\x17:\x0b\x1b\x11\x03\x05\x05)\x05+\x05-\x05/\x051\x053\x055\x057\x059\x05;\x05=\x1de\x07\x03\t135\x157\x15\t9\x05?\x11\x01\x01\x05A\x05C\x05E\x03\x0b\x0b\x9b\r\x9d\x0f\x93\t\xa7\x11\xa9\x03\x0b\x0b\x85\r\xab\x0f\x85\t\x97\x11\x8b\x05G\x03\x0b\x0b\xad\r\xb5\x0f\x93\t\x99\x11\xb7\x1dE\x03\x05I\x1dI\x13\x05K\x03\x03\x05\xb9\x1dO\x13\x05M\x03\x03S\x8d\x05O\x03\x03\x05\xbb\x03\x03\x05\xbd\x03\x15\x19\xbf\x1b\x8b\x1d\xc1\x1f\xc3!\xc5[\xc7]\xc9#\x85%\x85\'\x85\x05Q\x05S\x03\x05)\xd9+\xdb\x03\x03c\x8d\x05U\x05W\x1di\x07\x05Y\x03\x03\x05\xdd\x1do\x07\x05[\x03\x03\x05\xdf\x03\x05)\xe1+\xe3\x1dw\x07\x05]\x03\x13\x19\xe5\x1b\x8b\x1d\xe7\x1f\x85{\xe9!\x8f#\x85%\x85\'\x85\x05_\x1d\x7f\x07\x05a\x03\x03\x83\x99\x05c\x03\x01\x1de\x1dg\x1di\x13\x0f\x01\x05\x03\r\x03\x87\x89\x03\x05\x9f\xa3\x1dk\x1dm\x1do\x03\x03\x91#\x19\r\x05\x95\xa1\x87\x89\x1dq\r\x05\x95\xa5\x87\x89\x1ds\x1du\x1dw#\x1b\x03\x05\xaf\x91\r\x05\xb1\xb3\x87\x89\x1dy\x1d{#\x1f\x1d}\x1f\x05\x11\x04\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\xff\xff\xff\xff\x1f\x0b\t\x00\x00\x80\xff\x0b\x03\x1d\x7f\x03\x03\x97\x05\x01\x1f!!\x05\x00\x00\x00\x00\x00\x00\x00\x06\x00\x00\x00\x00\x00\x00\x00\r\x0b\xcb\x8f\xcd\x8f\xcf\xd1\xd3\x8d\xd5\xd7\x1d\x81\x1d\x83\x1d\x85\x11\x11\xd0\xcc\xcc\xdc\x0f\x1d\x87\x1d\x89\x13\x0f\x03\t\x01\x07\x07\x1f\x05\x11\xfc\xff\xff\xff\xff\xff\xff\xff\x1f\x05\x11\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x05\x0b\x05\x1d\x8b\x1d\x8d\x01\t\x01\x02\x02)\x01\x0f)\x01\x13)\x03\x00\xff\xff\xff\xff\xff\xff\xff\xff\x11)\x01\x11)\x03\x00\xff\xff\xff\xff\xff\xff\xff\xff\x13\x1d\t\x1b)\x01\x1d)\x03\x05\x13\x11\x03\t\x05\t\r\x11\t\x0b\x0b\x07\x07\x03\x15\x01\x11\x05\x05\t\x05\t\r)\x03\t\x0f\x04\xea\x03\x05\x01\x11\x03/\x07\x03\x01\r\x07\x11\x03;\x07\x03\x15+\x03\t\x03\x15\x07-a\x03\x07\x03\x01\x03\x06-\x03\x05\x03\x03\x03\x06g\x03\x05\x03\x05\x05\x03\x03k\x03\x05\r\x06m\x03\x05\x05\x07\t\x05\x03\x03q\x03\x05\x11\x07us\x03\x15\x05\x0b\r\x0f\x05}y\x07\x0f\x0b\x05\x17\x07\x03\x81\x05\t\r\x05\x0b\x01\x0b\x04\x03\x05\x11\x13\x07\x11\x01=\x07\x03\x0b\x0b\t\x0b\x01\x0b\x01\x07\x01\x07\x01\x11\x07\x01_\x03\x15\x05\x01\x03\x0b\x04\x01\x03\t\x07\x11\x03A\x07\x03#?\x05\x05\x03\tC\x03\x06G\x03\x05\x03\x01\x05\x03\x01K\x03\x05\r\x06M\x03\x05\x05\x05\x07\x03\x06\x01\x03\x07\x03\t\t\x06\x01\x03\x17\x03\x0b\x13\x07\x01Q\x03\r\x03\r\x05\x03\x01U\x03\x07\x05\x03\x01W\x03\x0b\x03\x06\x01\x03\x07\x03\x01\t\x06\x01\x03\x17\x03\x15\x03\x06\x01\x03\x07\x03\x01\t\x06\x01\x03\x17\x03\x19\x03\x06\x01\x03\x07\x03\x01\x0f\x07\x01Y\x05\t\r\x0f\x03\x0
f\x13\x11\x1d\x17\x1b\x0b\x04\x03\x05\x1f!\x06\x03\x01\x05\x01\x00""\x8f\xb2\x06!=\x1d\x1d\x19%?\x11\x05)\x0f\x0b\t\t31!\x03\x11#\x0f2\x07\x1d\t\x0bo;\x15)5\x1f1\x95\x05Z\x02\x13%)9+\x1b\x1f/!!)#\x1f\x19i\x1f\x15\x1d\x15\x13\r\x11-!\x17\x1f\x0f\x15\x17\x11\x19\x17\x0f\x0b\x11builtin\x00vhlo\x00module\x00convert_v1\x00constant_v1\x00func_v1\x00reshape_v1\x00return_v1\x00add_v1\x00custom_call_v1\x00compare_v1\x00dynamic_iota_v1\x00get_dimension_size_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/approx_top_k[k=b reduction_dimension=-1 recall_target=0.95 is_max_k=True reduction_input_size_override=-1 aggregate_to_topk=True]\x00a\x00jit(func)/jit(main)/convert_element_type[new_dtype=int64 weak_type=False]\x00jit(func)/jit(main)/add\x00iota_dimension\x00indices_of_shape_operands\x00mhlo.backend_config\x00dimension\x00/dimension_size[dimension=0]\x00/convert_element_type[new_dtype=int64 weak_type=False]\x00/add\x00/ge\x00error_message\x00/shape_assertion[error_message=Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable \'b\'. Using the following polymorphic shapes specifications: args[0].shape = (b + 4,). Obtained dimension variables: \'b\' = {0} from specification \'b + 4\' for dimension args[0].shape[0] (= {1}), . Please see https://github.com/googlexjax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.]\x00callee\x00mhlo.layout_mode\x00default\x00\x00jax.result_info\x00top_k_gt_f32_comparator\x00_wrapped_jax_export_main\x00[0]\x00[1]\x00main\x00public\x00jax.global_constant\x00b\x00private\x00stablehlo.dynamic_approx_top_k\x00aggregate_to_topk\x00is_fallback\x00recall_target\x00reduction_dim\x00reduction_input_size_override\x00shape_assertion\x00Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable \'b\'. Using the following polymorphic shapes specifications: args[0].shape = (b + 4,). Obtained dimension variables: \'b\' = {0} from specification \'b + 4\' for dimension args[0].shape[0] (= {1}), . Please see https://github.com/googlexjax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index dc1e0409ca26..5a975e3c5a61 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -17,7 +17,9 @@ JAX serialized format, we need to guarantee that custom calls continue to work as before. We test this here. -The tests in this file refer to the test data in ./back_compat_testdata. +The tests in this file refer to the test data in +jax/_src/internal_test_util/export_back_compat_test_data. + There is one test for each version of a custom call target, e.g., `test_ducc_fft` tests the FFT custom calls on CPU. 
Only custom call targets tested here should be listed in @@ -32,11 +34,12 @@ Write the JAX function `func` that exercises the custom call `foo_call` you want, then pick some inputs, and then add this to the new test to get started. +Add the following code to your test file, e.g., `export_back_compat_test.py`. import dataclasses from jax._src.internal_test_util import export_back_compat_test_util as bctu - class BackCompatTest(bctu.CompatTestBase) + class CompatTest(bctu.CompatTestBase) ... def test_foo_call(self): @@ -48,13 +51,13 @@ def func(...): ... The test will fail, but will save to a file the test data you will need. The file name will be printed in the logs. Create a new -file ./back_compat_testdata/foo_call.py and paste the test data that -you will see printed in the logs. +file jax/_src/internal_test_util/export_back_compat_test_data/foo_call.py +and paste the test data that you will see printed in the logs. -Name the literal `data_YYYYY_MM_DD` to include the date of serializaton +Name the literal `data_YYYYY_MM_DD` to include the date of serialization (for readability only). Then add to this file: - from jax.experimental.jax2tf.tests.back_compat_testdata import foo_call + from jax._src.internal_test_util.export_back_compat_test_data import foo_call then update `test_custom_call_coverage`, and then update your `test_foo_call`: @@ -67,13 +70,13 @@ def func(...): ... from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence import dataclasses import datetime import os import re import sys -from typing import Any, Callable +from typing import Any from absl import logging @@ -83,7 +86,7 @@ def func(...): ... import jax from jax import tree_util -from jax.experimental import export +from jax import export from jax.experimental import pjit @@ -300,7 +303,7 @@ def serialize(self, module_str = str(exported.mlir_module()) serialized = exported.mlir_module_serialized - module_version = exported.mlir_module_serialization_version + module_version = exported.calling_convention_version nr_devices = exported.nr_devices return serialized, module_str, module_version, nr_devices @@ -327,19 +330,19 @@ def _get_vjp(_): in_avals=tuple(in_avals), out_tree=out_tree, out_avals=tuple(out_avals), - in_shardings=(None,) * len(in_avals), - out_shardings=(None,) * len(out_avals), - lowering_platforms=(data.platform,), + in_shardings_hlo=(None,) * len(in_avals), + out_shardings_hlo=(None,) * len(out_avals), + platforms=(data.platform,), ordered_effects=(), unordered_effects=(), disabled_safety_checks=(), mlir_module_serialized=data.mlir_module_serialized, - mlir_module_serialization_version=data.xla_call_module_version, + calling_convention_version=data.xla_call_module_version, nr_devices=data.nr_devices, module_kept_var_idx=tuple(range(len(in_avals))), - uses_shape_polymorphism=any(not core.is_constant_shape(a.shape) + uses_global_constants=any(not core.is_constant_shape(a.shape) for a in in_avals), _get_vjp=_get_vjp) # We use pjit in case there are shardings in the exported module. 
- return pjit.pjit(export.call_exported(exported))(*data.inputs) + return pjit.pjit(exported.call)(*data.inputs) diff --git a/jax/_src/internal_test_util/lax_test_util.py b/jax/_src/internal_test_util/lax_test_util.py index 5cc08eb05231..b57b7d0852a9 100644 --- a/jax/_src/internal_test_util/lax_test_util.py +++ b/jax/_src/internal_test_util/lax_test_util.py @@ -23,6 +23,7 @@ import itertools from typing import Union, cast +import jax from jax import lax from jax._src import dtypes from jax._src import test_util @@ -30,8 +31,7 @@ import numpy as np -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index bf099344b47e..2b22944c17b8 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -38,11 +38,11 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence import operator import os from functools import partial -from typing import Any, Callable, NamedTuple, Union +from typing import Any, NamedTuple, Union from absl import testing import numpy as np @@ -62,6 +62,9 @@ from jax._src.lib import xla_client from jax._src import random as jax_random +# mypy generates a lot of false positive due to re-assigned variables. +# mypy: disable-error-code="assignment, no-redef" + # The code in this file relies on the values of some flags that are defined by # jtu. Note that the following can not always be moved to a test file since # then the test file has to import jtu first (to define the flags) which is not @@ -172,9 +175,9 @@ def __init__(self, self.group_name = jtu.sanitize_test_name(group_name) self.name = jtu.sanitize_test_name(name) self.fullname = self.name if self.group_name is None else f"{self.group_name}_{self.name}" - self.fun = fun # type: ignore[assignment] + self.fun = fun self.arg_descriptors = arg_descriptors - self.rng_factory = rng_factory # type: ignore[assignment] + self.rng_factory = rng_factory self.jax_unimplemented = jax_unimplemented self.dtype = dtype self.params = params @@ -651,7 +654,7 @@ def _make_device_put_harness(name, define( "device_put", f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{device=}", - lambda x: dispatch.device_put_p.bind(x, device=_device_fn(), src=None), + lambda x: dispatch.device_put_p.bind(x, devices=[_device_fn()], srcs=[None])[0], [RandArg(shape, dtype)], shape=shape, dtype=dtype, @@ -1819,7 +1822,6 @@ def _fft_rng_factory(dtype): jax_unimplemented=[ Limitation( "unimplemented", - devices=("cpu", "gpu"), dtypes=[np.float16, dtypes.bfloat16], ), ], @@ -2061,18 +2063,17 @@ def _make_slice_harness(name, define( lax.slice_p, f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_{strides=}", - # type: ignore lax.slice, [ - RandArg(shape, dtype), # type: ignore - StaticArg(start_indices), # type: ignore - StaticArg(limit_indices), # type: ignore + RandArg(shape, dtype), + StaticArg(start_indices), + StaticArg(limit_indices), StaticArg(strides) - ], # type: ignore + ], dtype=dtype, - shape=shape, # type: ignore - start_indices=start_indices, # type: ignore - limit_indices=limit_indices) # type: ignore + shape=shape, + start_indices=start_indices, + limit_indices=limit_indices) # Test first all dtypes @@ -2162,17 +2163,16 @@ def 
_make_dynamic_slice_harness(name, define( lax.dynamic_slice_p, f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_enablexla={enable_xla}", - # type: ignore lax.dynamic_slice, [ - RandArg(shape, dtype), # type: ignore + RandArg(shape, dtype), np.array(list(start_indices)), StaticArg(tuple(map(operator.sub, limit_indices, start_indices))) - ], # type: ignore + ], dtype=dtype, - shape=shape, # type: ignore - start_indices=start_indices, # type: ignore - limit_indices=limit_indices, # type: ignore + shape=shape, + start_indices=start_indices, + limit_indices=limit_indices, enable_xla=enable_xla) @@ -2219,19 +2219,19 @@ def _make_dynamic_update_slice_harness(name, define( lax.dynamic_update_slice_p, ( - f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" # type: ignore + f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}" f"_{start_indices=}_{enable_xla=}"), lax.dynamic_update_slice, [ - RandArg(shape, dtype), # type: ignore - RandArg(update_shape, dtype), # type: ignore + RandArg(shape, dtype), + RandArg(update_shape, dtype), np.array(start_indices) - ], # type: ignore + ], dtype=dtype, - shape=shape, # type: ignore - start_indices=start_indices, # type: ignore - update_shape=update_shape, # type: ignore + shape=shape, + start_indices=start_indices, + update_shape=update_shape, enable_xla=enable_xla) @@ -2262,12 +2262,12 @@ def _make_squeeze_harness(name, dtype=np.float32): define( lax.squeeze_p, - f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}", # type: ignore + f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}", lax.squeeze, - [RandArg(shape, dtype), StaticArg(dimensions)], # type: ignore[has-type] + [RandArg(shape, dtype), StaticArg(dimensions)], dtype=dtype, arg_shape=shape, - dimensions=dimensions) # type: ignore[has-type] + dimensions=dimensions) # Test first all dtypes @@ -3313,6 +3313,7 @@ def _make_conv_harness(name, lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation) +key_types: list[tuple[tuple[int, ...], jax.typing.DTypeLike]] key_types = [((4,), np.uint32)] if config.enable_x64.value: key_types.append(((2,), np.uint64)) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index bddbf9c92dc6..a527acb8db90 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -14,12 +14,12 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import contextlib import functools import itertools as it from functools import partial -from typing import Any, Callable +from typing import Any import jax from jax._src import config @@ -238,7 +238,8 @@ def write_primal(v, val): else: cts_in, = map(read_cotangent, eqn.outvars) name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack - with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack): + with source_info_util.user_context( + eqn.source_info.traceback, name_stack=name_stack), eqn.ctx.manager: if eqn.primitive.call_primitive or eqn.primitive.map_primitive: cts_in_avals = [v.aval for v in eqn.outvars] params = dict(eqn.params) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index c58c42d291d2..3a87fffa5116 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -14,10 +14,10 @@ from __future__ import annotations import collections -from 
collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence import dataclasses from functools import partial -from typing import Any, Callable, Union +from typing import Any, Union import numpy as np @@ -200,7 +200,7 @@ def _update_annotation( class Name: def __init__(self, a): self.a = a names = [Name(a) for a, _ in orig_type] - avals = [a.update(shape=tuple(names[d.val] if type(d) is pe.DBIdx else d # type: ignore + avals = [a.update(shape=tuple(names[d.val] if type(d) is pe.DBIdx else d for d in a.shape)) if type(a) is core.DShapedArray else a for a, e in orig_type if e] @@ -245,7 +245,7 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: for i, sz in enumerate(x.aval.elt_ty.shape) if type(sz) is IndexedAxisSize) batch_axis = make_batch_axis(x.data.ndim, 0, [(d+1, ias.lengths)]) - return BatchTracer(trace, x.data, batch_axis) # type: ignore + return BatchTracer(trace, x.data, batch_axis) elif isinstance(spec, int) or spec is None: spec = spec and canonicalize_axis(spec, len(np.shape(x))) return (BatchTracer(trace, x, spec, source_info_util.current()) @@ -322,7 +322,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, assert type(batch_dim) in (NotMapped, int, RaggedAxis) if type(batch_dim) is int: aval = raise_to_shaped(core.get_aval(val)) - assert 0 <= batch_dim < len(aval.shape) # type: ignore + assert 0 <= batch_dim < len(aval.shape) self._trace = trace self.val = val self.batch_dim = batch_dim @@ -338,7 +338,7 @@ def aval(self): elif type(self.batch_dim) is RaggedAxis: new_aval = core.mapped_aval( aval.shape[self.batch_dim.stacked_axis], self.batch_dim.stacked_axis, aval) - shape = list(new_aval.shape) # type: ignore + shape = list(new_aval.shape) # pytype: disable=attribute-error for ragged_axis, segment_lengths in self.batch_dim.ragged_axes: size_tracer = BatchTracer(self._trace, segment_lengths, 0) if self.batch_dim.stacked_axis < ragged_axis: diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 95734493c648..346fe464f28f 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -16,7 +16,8 @@ from __future__ import annotations import collections -from collections.abc import Iterator, Sequence +import contextlib +from collections.abc import Callable, Iterator, Sequence import dataclasses import functools from functools import partial @@ -27,7 +28,7 @@ import re import types import typing -from typing import Any, Callable, NamedTuple, Protocol, Union +from typing import Any, NamedTuple, Protocol, Union, cast as type_cast import warnings import numpy as np @@ -46,16 +47,15 @@ from jax._src import xla_bridge as xb from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla -from jax._src.layout import XLACompatibleLayout, LayoutRequest +from jax._src.layout import AutoLayout, DeviceLocalLayout +from jax._src.sharding import Sharding as JSharding from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import dialects from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir import register_jax_dialects -from jax._src.sharding_impls import XLACompatibleSharding from jax._src.state.types import AbstractRef map, unsafe_map = util.safe_map, map @@ -68,32 +68,30 @@ # mypy implicitly sets this variable to true when type 
checking. MYPY = False -_JAX_DUMP_IR_TO = config.DEFINE_string( +_JAX_DUMP_IR_TO = config.string_flag( 'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''), help="Path to which the IR that is emitted by JAX should be dumped as " "text files. If omitted, JAX will not dump IR. " "Supports the special value 'sponge' to pick the path from the " "environment variable TEST_UNDECLARED_OUTPUTS_DIR.") +_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS = config.string_flag( + 'jax_include_debug_info_in_dumps', + os.getenv('JAX_INCLUDE_DEBUG_INFO_IN_DUMPS', "True"), + help="Determine whether or not to keep debug symbols and location information " + "when dumping IR code. By default, debug information will be preserved in " + "the IR dump. To avoid exposing source code and potentially sensitive " + "information, set to false") lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects # IR Helpers def dense_int_elements(xs) -> ir.DenseIntElementsAttr: - return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)) + return type_cast(ir.DenseIntElementsAttr, + ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))) -def dense_int_array(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr: - # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher - if hlo.get_api_version() < 5: - return dense_int_elements(xs) - return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) - -# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher -def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr: - if hlo.get_api_version() < 6 or xc.mlir_api_version < 55: - return dense_int_elements(xs) - return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) +dense_int_array = ir.DenseI64ArrayAttr.get def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr: a = np.packbits(np.array(xs, np.bool_), bitorder='little') @@ -104,11 +102,8 @@ def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr: return ir.DenseElementsAttr.get( a, type=ir.IntegerType.get_signless(1), shape=[len(xs)]) -def dense_bool_array(xs: Sequence[bool]) -> ir.DenseElementsAttr | ir.DenseBoolArrayAttr: - # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v6 or higher - if hlo.get_api_version() < 6 or xc.mlir_api_version < 55: - return dense_bool_elements(xs) - return ir.DenseBoolArrayAttr.get(xs) +def dense_bool_array(xs: Sequence[bool]) -> ir.DenseBoolArrayAttr: + return ir.DenseBoolArrayAttr.get(xs) # type: ignore def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i) def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i) @@ -126,7 +121,7 @@ def lower_dim(d): return hlo.reshape(int1d, d) ds = map(lower_dim, sizes) if not ds: - return ir_constant(np.array([], np.int32)) + return type_cast(ir.RankedTensorType, ir_constant(np.array([], np.int32))) elif len(ds) == 1: return ds[0] else: @@ -189,7 +184,7 @@ def _array_ir_types(aval: core.ShapedArray | core.DShapedArray aval = core.physical_aval(aval) # type: ignore if not core.is_constant_shape(aval.shape): return _dynamic_array_ir_types(aval) # type: ignore - return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),) + return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),) # type: ignore def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]: dyn_size = ir.ShapedType.get_dynamic_size() @@ -276,7 +271,7 @@ def _numpy_array_constant(x: np.ndarray | np.generic) -> Sequence[ir.Value]: if x.dtype == 
np.bool_: x = np.packbits(x, bitorder='little') # type: ignore x = np.ascontiguousarray(x) - attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) + attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore return (hlo.constant(attr),) @@ -305,16 +300,16 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value """ if val.dtype == dtypes.float0: return _numpy_array_constant(np.zeros(val.shape, dtype=np.bool_)) - elif np.any(np.equal(0, val.strides)) and val.size > 0: + elif 0 in val.strides and val.size > 0: zero_stride_axes, = np.where(np.equal(0, val.strides)) other_axes, = np.where(np.not_equal(0, val.strides)) - collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) # type: ignore - for ax in range(val.ndim))] # type: ignore + collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) # type: ignore + for ax in range(val.ndim))] out = hlo.broadcast_in_dim( ir.RankedTensorType.get( - val.shape, dtype_to_ir_type(collapsed_val.dtype)), + val.shape, dtype_to_ir_type(collapsed_val.dtype)), # type: ignore _numpy_array_constant(collapsed_val)[0], - dense_int_array_v6(other_axes)) + dense_int_array(other_axes)) # type: ignore return (out,) else: return _numpy_array_constant(val) @@ -474,9 +469,12 @@ def dump_module_message(module: ir.Module, stage_name: str) -> str: def _make_string_safe_for_filename(s: str) -> str: return re.sub(r'[^\w.)( -]', '', s) -def module_to_string(module: ir.Module) -> str: +def module_to_string(module: ir.Module, enable_debug_info=None) -> str: output = io.StringIO() - module.operation.print(file=output, enable_debug_info=True) + if enable_debug_info is None: + enable_debug_flag = str.lower(_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS.value) + enable_debug_info = enable_debug_flag not in ('false', '0') + module.operation.print(file=output, enable_debug_info=enable_debug_info) return output.getvalue() def module_to_bytecode(module: ir.Module) -> bytes: @@ -548,13 +546,6 @@ class LoweringParameters: # existing Jax rules. override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None - # The current lowering platforms, a non-empty tuple containing some of - # 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are - # doing multi-platform lowering, otherwise it can specify cross-platform - # lowering. The value None specifies the default lowering platform. - # This is used only in export and jax2tf. - platforms: tuple[str, ...] | None = None - # Signals that the entire computation being lowered operates on global # constants. This will result in adding jax.global_constant attributes # to the arguments of all functions that are created, e.g., floor_divide. @@ -562,17 +553,13 @@ class LoweringParameters: # or multi-platform lowering. global_constant_computation: bool = False - # TODO(b/302258959): in JAX native execution we cannot lower the tokens - # to stablehlo.token for the top-level function, due to runtime limitations. - # Instead, we use dummy bool[0] arrays. This is controlled by setting - # replace_tokens_with_dummy to True (default). However, when exporting StableHLO - # we can use real tokens, because the resulting StableHLO will not be - # executed directly, but will be embedded as an inner function in a larger - # JAX or TensorFlow program. In these cases, replace_tokens_with_dummy must - # be set to False (for serialization versions >= 9). 
- # Once the PJRT is extended to use tokens, we can use tokens even in the - # native execution (and we can remove this parameter). - replace_tokens_with_dummy: bool = True + # Signals that we are lowering for exporting. + + for_export: bool = False + # See usage in https://jax.readthedocs.io/en/latest/export.html#ensuring-forward-and-backward-compatibility + # We have this here to ensure it is reflected in the cache keys + export_ignore_forward_compatibility: bool = False + @dataclasses.dataclass class TracebackCaches: @@ -595,6 +582,8 @@ class ModuleContext: ip: ir.InsertionPoint symbol_table: ir.SymbolTable backend_or_name: str | xb.XlaBackend | None + # The lowering platforms for the module. Can be more than one only when + # exporting. platforms: Sequence[str] axis_context: AxisContext keepalives: list[Any] @@ -629,8 +618,7 @@ def __init__( module: ir.Module | None = None, ip: ir.InsertionPoint | None = None, symbol_table: ir.SymbolTable | None = None, - cached_primitive_lowerings: None | (dict[Any, - func_dialect.FuncOp]) = None, + cached_primitive_lowerings: None | (dict[Any, func_dialect.FuncOp]) = None, traceback_caches: None | TracebackCaches = None, shape_poly_state = None): @@ -695,8 +683,13 @@ class LoweringRuleContext: tokens_in: TokenSet tokens_out: TokenSet | None # Mutable store for output containers axis_size_env: dict[core.Var, ir.Value] | None = None # Dynamic axis sizes - dim_var_values: Sequence[ir.Value] = () # The values for the dimension variables - # in same order as module_context.shape_poly_state.dim_vars + # The values for the dimension variables in same order as + # module_context.shape_poly_state.dim_vars + dim_var_values: Sequence[ir.Value] = () + jaxpr_eqn_ctx: core.JaxprEqnContext | None = None + # Override module_context.platforms if not None. Used during multi-platform + # lowering, when in a scope with a subset of the module_context.platforms. + platforms: Sequence[str] | None = None def set_tokens_out(self, tokens_out: TokenSet): assert self.tokens_out is None, 'Should only set `tokens_out` once.' 
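The new `platforms` field on LoweringRuleContext narrows the module-wide platform list while lowering inside a per-platform branch. A minimal, self-contained sketch of that precedence, using hypothetical stand-in classes rather than the real JAX contexts:

from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass
class FakeModuleContext:          # stand-in for mlir.ModuleContext
    platforms: Sequence[str]

@dataclass
class FakeRuleContext:            # stand-in for mlir.LoweringRuleContext
    module_context: FakeModuleContext
    platforms: Optional[Sequence[str]] = None  # overrides the module list when set

def effective_platforms(ctx: FakeRuleContext) -> Sequence[str]:
    # A per-rule override, if present, wins over the module-wide platform list.
    return ctx.platforms or ctx.module_context.platforms

mod = FakeModuleContext(platforms=("cpu", "cuda"))
assert effective_platforms(FakeRuleContext(mod)) == ("cpu", "cuda")
assert effective_platforms(FakeRuleContext(mod, platforms=("cuda",))) == ("cuda",)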
@@ -740,13 +733,13 @@ def wrap_singleton_ir_values(x: ir.Value | Sequence[ir.Value] def flatten_lowering_ir_args( xs: Sequence[ir.Value | Sequence[ir.Value]] -) -> Sequence[Sequence[ir.Value]]: +) -> Sequence[ir.Value]: return util.flatten(map(wrap_singleton_ir_values, xs)) _module_name_regex = re.compile(r"[^\w.-]") def sharded_aval(aval: core.AbstractValue, - sharding: XLACompatibleSharding | None) -> core.AbstractValue: + sharding: JSharding | None) -> core.AbstractValue: """Returns the new aval sharded based on sharding proto.""" if sharding is None: return aval @@ -771,7 +764,7 @@ def eval_dynamic_shape(ctx: LoweringRuleContext, partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars), multiple_results=True)(ctx, *ctx.dim_var_values) return tuple(operator.index(d) if core.is_constant_dim(d) else d_ir - for d, d_ir in zip(shape, util.flatten(res))) # type: ignore + for d, d_ir in zip(shape, util.flatten(res))) # TODO: replace usage of eval_dynamic_shape_as_vals with eval_dynamic_shape_as_ivals def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext, @@ -819,30 +812,36 @@ class LoweringResult(NamedTuple): _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"] -def _to_logical_op_sharding( - aval: core.AbstractValue, sharding: XLACompatibleSharding | None, -) -> xc.HloSharding | None: +def _to_physical_op_sharding( + aval: core.AbstractValue, sharding: JSharding | None, +) -> xc.OpSharding | None: if sharding is None: return None - assert isinstance(sharding, sharding_impls.XLACompatibleSharding) + assert isinstance(sharding, JSharding) if isinstance(aval, AbstractRef): - return _to_logical_op_sharding(aval.inner_aval, sharding) + return _to_physical_op_sharding(aval.inner_aval, sharding) assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) - return sharding._to_xla_hlo_sharding(aval.ndim) + if dtypes.issubdtype(aval.dtype, dtypes.extended): + sharding = sharding_impls.physical_sharding(aval, sharding) + aval = core.physical_aval(aval) + return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore -def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None: +def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout, + aval: core.AbstractValue) -> str | None: if layout is None: return "default" - if isinstance(layout, LayoutRequest): + if isinstance(layout, AutoLayout): return "auto" - return layout._to_xla_layout() + if aval is core.abstract_token: + return "default" + return layout._to_xla_layout(aval.dtype) # type: ignore -def _get_mem_kind(s: XLACompatibleSharding | None) -> str | None: +def _get_mem_kind(s: JSharding | None) -> str | None: if s is None: return None - assert isinstance(s, sharding_impls.XLACompatibleSharding) + assert isinstance(s, JSharding) return s.memory_kind @@ -857,16 +856,17 @@ def lower_jaxpr_to_module( name_stack: source_info_util.NameStack, donated_args: Sequence[bool], replicated_args: Sequence[bool] | None = None, - arg_shardings: Sequence[XLACompatibleSharding | None] | None = None, - result_shardings: Sequence[XLACompatibleSharding | None] | None = None, - in_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None, - out_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None, + arg_shardings: Sequence[JSharding | None] | None = None, + result_shardings: Sequence[JSharding | None] | None = None, + in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, + out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] 
| None = None, arg_names: Sequence[str | None] | None = None, - result_names: Sequence[str | None] | None = None, + result_names: Sequence[str] | None = None, num_replicas: int = 1, num_partitions: int = 1, all_default_mem_kind: bool = True, input_output_aliases: None | tuple[int | None, ...] = None, + propagated_out_mem_kinds: tuple[None | str, ...] | None = None, lowering_parameters: LoweringParameters, ) -> LoweringResult: """Lowers a top-level jaxpr to an MLIR module. @@ -893,12 +893,13 @@ def lower_jaxpr_to_module( platforms_with_donation = [p for p in platforms if p in _platforms_with_donation] if platforms_with_donation: - if len(platforms_with_donation) != len(platforms): + if len(platforms_with_donation) != len(platforms) and ( + xla_donated_args or any(donated_args)): raise NotImplementedError( "In multi-platform lowering either all or no lowering platforms " f"should support donation. Lowering for {platforms} of which " f"only {platforms_with_donation} support donation") - if num_partitions > 1 and xla_extension_version >= 220 and ( + if num_partitions > 1 and ( result_shardings is None or all(s is None for s in result_shardings)): xla_donated_args = donated_args if xla_donated_args is None: @@ -941,26 +942,13 @@ def lower_jaxpr_to_module( else: dim_vars = () - arg_op_shardings = ( - map(_to_logical_op_sharding, jaxpr.in_avals, arg_shardings) - if arg_shardings is not None else arg_shardings) - result_op_shardings = ( - map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings) - if result_shardings is not None else result_shardings) - - arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None - else in_layouts) - result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None - else out_layouts) - ctx = ModuleContext(backend_or_name=backend_or_name, platforms=platforms, axis_context=axis_context, keepalives=keepalives, channel_iterator=channel_iter, host_callbacks=host_callbacks, lowering_parameters=lowering_parameters, - shape_poly_state=ShapePolyLoweringState( - dim_vars, lowering_parameters.platforms)) + shape_poly_state=ShapePolyLoweringState(dim_vars, platforms)) with ctx.context, ir.Location.unknown(ctx.context): # Remove module name characters that XLA would alter. This ensures that # XLA computation preserves the module name. @@ -969,30 +957,27 @@ def lower_jaxpr_to_module( attrs["sym_name"] = ir.StringAttr.get(module_name) attrs["mhlo.num_replicas"] = i32_attr(num_replicas) attrs["mhlo.num_partitions"] = i32_attr(num_partitions) - replace_tokens_with_dummy = lowering_parameters.replace_tokens_with_dummy lower_jaxpr_to_fun( ctx, "main", jaxpr, ordered_effects, name_stack=name_stack, public=True, - create_tokens=replace_tokens_with_dummy, - replace_tokens_with_dummy=replace_tokens_with_dummy, - num_output_tokens=0, replicated_args=replicated_args, - arg_shardings=arg_op_shardings, - result_shardings=result_op_shardings, + arg_shardings=arg_shardings, + result_shardings=result_shardings, input_output_aliases=input_output_aliases, xla_donated_args=xla_donated_args, arg_names=arg_names, result_names=result_names, arg_memory_kinds=arg_memory_kinds, result_memory_kinds=result_memory_kinds, - arg_layouts=arg_layouts, - result_layouts=result_layouts) + arg_layouts=in_layouts, + result_layouts=out_layouts, + propagated_out_mem_kinds=propagated_out_mem_kinds) try: if not ctx.module.operation.verify(): raise ValueError( - "Cannot lower jaxpr with verifier errors." + + "Cannot lower jaxpr with verifier errors. 
" + dump_module_message(ctx.module, "verification")) except ir.MLIRError as e: msg_lines = ["Cannot lower jaxpr with verifier errors:"] @@ -1003,7 +988,7 @@ def emit_diagnostic_info(d): emit_diagnostic_info(n) for d in e.error_diagnostics: emit_diagnostic_info(d) - raise ValueError("\n".join(msg_lines) + + raise ValueError("\n".join(msg_lines) + "\n" + dump_module_message(ctx.module, "verification")) from e return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks, @@ -1103,15 +1088,6 @@ def update_tokens(self, tokens: TokenSet) -> TokenSet: new_tokens.append((eff, self._tokens[eff])) return TokenSet(new_tokens) -def dummy_token_type() -> Sequence[ir.Type]: - # TODO(b/302258959): For now HLO does not allow hlo.TokenType among - # arguments and results, so we use bool[0] to pass tokens to the - # top-level function only. - return aval_to_ir_types(core.ShapedArray((0,), np.bool_)) - -def dummy_token() -> Sequence[ir.Value]: - return ir_constants(np.zeros(0, np.bool_)) - def lower_jaxpr_to_fun( ctx: ModuleContext, name: str, @@ -1119,23 +1095,21 @@ def lower_jaxpr_to_fun( effects: Sequence[core.Effect], name_stack: source_info_util.NameStack, *, - create_tokens: bool = False, public: bool = False, - replace_tokens_with_dummy: bool = False, replicated_args: Sequence[bool] | None = None, - arg_shardings: Sequence[xc.HloSharding | None] | None = None, - result_shardings: Sequence[xc.HloSharding | None] | None = None, + arg_shardings: Sequence[JSharding | None] | None = None, + result_shardings: Sequence[JSharding | None] | None = None, use_sharding_annotations: bool = True, input_output_aliases: Sequence[int | None] | None = None, xla_donated_args: Sequence[bool] | None = None, - num_output_tokens: int = 0, api_name: str = "jit", arg_names: Sequence[str | None] | None = None, - result_names: Sequence[str | None] | None = None, + result_names: Sequence[str] | None = None, arg_memory_kinds: Sequence[str | None] | None = None, result_memory_kinds: Sequence[str | None] | None = None, - arg_layouts: Sequence[str | None] | None = None, - result_layouts: Sequence[str | None] | None = None, + arg_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, + result_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, + propagated_out_mem_kinds: tuple[None | str, ...] | None = None, ) -> func_dialect.FuncOp: """Lowers jaxpr and its callees to an IR function. @@ -1148,11 +1122,7 @@ def lower_jaxpr_to_fun( jaxpr: the jaxpr to lower. effects: a sequence of `core.Effect`s corresponding to an ordering of tokens that will be created in or used by the lowered function. - create_tokens: if true, the HLO will create tokens and ignore dummy input - tokens. See b/302258959. public: if true, the function's visibility is set to "public". - replace_tokens_with_dummy: if true, token arguments/return values are - replaced with bool arrays of size [0]. See b/302258959. replicated_args: if present, annotates arguments as replicated. arg_shardings: sharding annotations for each argument (optional). result_shardings: sharding annotations for each result (optional). 
@@ -1169,50 +1139,38 @@ def lower_jaxpr_to_fun( Returns: MLIR func op """ - def aval_to_types(aval): - if replace_tokens_with_dummy and aval is core.abstract_token: - aval = core.ShapedArray((), np.dtype(np.bool_)) - return aval_to_ir_types(aval) # The first dimension variable may be the platform index num_dim_vars = len(ctx.shape_poly_state.dim_vars) dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars - dim_var_types = map(aval_to_types, dim_var_avals) + dim_var_types = map(aval_to_ir_types, dim_var_avals) # Function inputs: *dim_var_values, *tokens, *actual_inputs - input_types = map(aval_to_types, jaxpr.in_avals) - output_types = map(aval_to_types, jaxpr.out_avals) + input_types = map(aval_to_ir_types, jaxpr.in_avals) + output_types = map(aval_to_ir_types, jaxpr.out_avals) num_tokens = len(effects) - if create_tokens: - # TODO(b/302258959): Use actual tokens - token_types = [dummy_token_type() for _ in effects] - output_token_types = [dummy_token_type() for _ in range(num_output_tokens)] - else: - # If we aren't creating tokens they will be the initial inputs to the - # MLIR function. - output_token_types = [] - token_types = [token_type() for _ in effects] + token_types = [token_type() for _ in effects] token_avals = [core.abstract_token] * num_tokens # Order of arguments: dim vars, tokens, array inputs input_avals = dim_var_avals + token_avals + jaxpr.in_avals input_types = [*dim_var_types, *token_types, *input_types] - output_avals = [core.abstract_token] * (len(output_token_types) + num_tokens) + jaxpr.out_avals - output_types = [*output_token_types, *token_types, *output_types] + output_avals = [core.abstract_token] * num_tokens + jaxpr.out_avals + output_types = [*token_types, *output_types] if input_output_aliases is not None: token_input_output_aliases = [None] * (num_dim_vars + num_tokens) input_output_aliases = [*token_input_output_aliases, *input_output_aliases] # Update the existing aliases to account for the new output values input_output_aliases = [None if a is None - else a + num_output_tokens + num_tokens - for a in input_output_aliases] # type: ignore + else a + num_tokens + for a in input_output_aliases] if arg_shardings is not None: token_shardings = [None] * (num_dim_vars + num_tokens) arg_shardings = [*token_shardings, *arg_shardings] if result_shardings is not None: - token_shardings = [None] * (num_tokens + num_output_tokens) + token_shardings = [None] * num_tokens result_shardings = [*token_shardings, *result_shardings] if replicated_args is not None: token_replicated_args = [False] * (num_dim_vars + num_tokens) @@ -1221,13 +1179,13 @@ def aval_to_types(aval): token_memory_kinds = [None] * (num_dim_vars + num_tokens) arg_memory_kinds = [*token_memory_kinds, *arg_memory_kinds] if result_memory_kinds is not None: - token_memory_kinds = [None] * (num_tokens + num_output_tokens) + token_memory_kinds = [None] * num_tokens result_memory_kinds = [*token_memory_kinds, *result_memory_kinds] if arg_layouts is not None: token_layouts = [None] * (num_dim_vars + num_tokens) arg_layouts = [*token_layouts, *arg_layouts] if result_layouts is not None: - token_layouts = [None] * (num_tokens + num_output_tokens) + token_layouts = [None] * num_tokens result_layouts = [*token_layouts, *result_layouts] if xla_donated_args is not None: xla_donated_args = [*([False] * (num_dim_vars + num_tokens)), *xla_donated_args] @@ -1242,11 +1200,9 @@ def aval_to_types(aval): ir_arg_shardings = None if arg_shardings is not None: - in_avals = [None] * 
(num_dim_vars + num_tokens) + list(jaxpr.in_avals) ir_arg_shardings = util.flatten( [[_to_physical_op_sharding(a, s)] * len(types) - for a, s, types in zip(in_avals, arg_shardings, input_types)]) - del in_avals + for a, s, types in zip(input_avals, arg_shardings, input_types)]) ir_arg_memory_kinds = None if arg_memory_kinds is not None: @@ -1256,7 +1212,8 @@ def aval_to_types(aval): ir_arg_layouts = None if arg_layouts is not None: ir_arg_layouts = util.flatten( - [[l] * len(types) for l, types in zip(arg_layouts, input_types)]) + [[_to_xla_layout(l, a)] * len(types) + for l, a, types in zip(arg_layouts, input_avals, input_types)]) ir_donated_args = None if xla_donated_args is not None: @@ -1265,25 +1222,39 @@ def aval_to_types(aval): ir_result_shardings = None if result_shardings is not None: - out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals) ir_result_shardings = util.flatten( [[_to_physical_op_sharding(a, s)] * len(types) - for a, s, types in zip(out_avals, result_shardings, output_types)]) - del out_avals + for a, s, types in zip(output_avals, result_shardings, output_types)]) ir_result_memory_kinds = None + custom_call_ir_result_memory_kinds = None if result_memory_kinds is not None: - ir_result_memory_kinds = util.flatten( - [[mk] * len(types) for mk, types in zip(result_memory_kinds, output_types)]) + if propagated_out_mem_kinds is None: + propagated_out_mem_kinds = (None,) * len(result_memory_kinds) + res, custom_call_res = [], [] + for pom, mk, types in zip(propagated_out_mem_kinds, result_memory_kinds, + output_types): + if pom is not None and mk is None: + res.append([pom] * len(types)) + else: + res.append([mk] * len(types)) # type: ignore + # To add the custom call on the output to signal a transfer, only do it + # if memory kind comes from out_shardings on `jit` and result_memory_kinds + # comes from out_shardings on `jit`. 
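Condensed to plain Python, the merge between explicitly requested and propagated output memory kinds works as below; the names and the per-output (rather than per-type) granularity are illustrative only:

def merge_output_memory_kinds(explicit, propagated):
    # explicit: kinds coming from out_shardings on jit (may be None per output)
    # propagated: kinds inferred by propagation (may be None per output)
    final, custom_call_only = [], []
    for mk, pom in zip(explicit, propagated):
        final.append(pom if (mk is None and pom is not None) else mk)
        # Only an explicitly requested kind triggers the transfer custom call.
        custom_call_only.append(mk)
    return final, custom_call_only

assert merge_output_memory_kinds([None, "pinned_host"], ["device", None]) == (
    ["device", "pinned_host"], [None, "pinned_host"])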
+ custom_call_res.append([mk] * len(types)) + ir_result_memory_kinds = util.flatten(res) + custom_call_ir_result_memory_kinds = util.flatten(custom_call_res) ir_result_layouts = None if result_layouts is not None: ir_result_layouts = util.flatten( - [[l] * len(types) for l, types in zip(result_layouts, output_types)]) + [[_to_xla_layout(l, a)] * len(types) + for l, a, types in zip(result_layouts, output_avals, output_types)]) if ( replicated_args is not None or ir_arg_shardings is not None + or ir_arg_memory_kinds is not None or ir_arg_layouts is not None or input_output_aliases is not None or ir_donated_args is not None @@ -1306,6 +1277,11 @@ def aval_to_types(aval): if sharding is not None: attrs["mhlo.sharding"] = get_sharding_attr(sharding) + if ir_arg_memory_kinds is not None: + for attrs, memory_kind in zip(arg_attrs, ir_arg_memory_kinds): + if memory_kind is not None: + attrs["mhlo.memory_kind"] = ir.StringAttr.get(memory_kind) + if ir_arg_layouts is not None: for attrs, layout in zip(arg_attrs, ir_arg_layouts): if layout is not None: @@ -1365,6 +1341,11 @@ def aval_to_types(aval): if sharding is not None: attrs['mhlo.sharding'] = get_sharding_attr(sharding) + if ir_result_memory_kinds is not None: + for attrs, mem_kind in zip(result_attrs, ir_result_memory_kinds): + if mem_kind is not None: + attrs['mhlo.memory_kind'] = ir.StringAttr.get(mem_kind) + if ir_result_layouts is not None: for attrs, layout in zip(result_attrs, ir_result_layouts): if layout is not None: @@ -1400,49 +1381,29 @@ def aval_to_types(aval): if ir_arg_shardings is not None and name == "main": flat_args = [ - a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a) # type: ignore + replicate_trailing_dims(entry_lowering_ctx, o, a) if (a is not core.abstract_token and - dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # type: ignore + dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # pytype: disable=attribute-error for o, s, a in zip(flat_args, ir_arg_shardings, input_avals) ] - if ir_arg_memory_kinds is not None: - flat_args = [ - a if mk is None else wrap_with_memory_kind(a, mk, a_aval) - for a, mk, a_aval in zip(flat_args, ir_arg_memory_kinds, input_avals)] - _, token_args, unflattened_args = util.split_list( util.unflatten(flat_args, map(len, input_types)), [num_dim_vars, num_tokens]) - if create_tokens: - tokens_in = TokenSet.create(effects) + tokens_in = TokenSet(zip(effects, token_args)) + args: list[list[ir.Value]] = unflattened_args + if name is not None: + callee_name_stack = name_stack.extend(util.wrap_name(name, api_name)) else: - tokens_in = TokenSet(zip(effects, token_args)) - args: list[list[ir.Value]] = [] - for aval, arg in zip(jaxpr.in_avals, unflattened_args): - if replace_tokens_with_dummy and aval is core.abstract_token: - args.append([hlo.create_token()]) - else: - args.append(arg) - callee_name_stack = name_stack.extend(util.wrap_name(name, api_name)) + callee_name_stack = name_stack consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts] out_vals, tokens_out = jaxpr_subcomp( ctx, jaxpr.jaxpr, callee_name_stack, tokens_in, consts, *args, dim_var_values=dim_var_values) outs = [] - if create_tokens: - for _ in range(num_output_tokens): - outs.append(dummy_token()) - for _ in effects: - outs.append(dummy_token()) - else: - for eff in effects: - outs.append(tokens_out.get(eff)) - for aval, out in zip(jaxpr.out_avals, out_vals): - if replace_tokens_with_dummy and aval is core.abstract_token: - outs.append(ir_constants(np.zeros((), 
np.bool_))) - else: - outs.append(out) + for eff in effects: + outs.append(wrap_singleton_ir_values(tokens_out.get(eff))) + outs.extend(out_vals) flat_outputs = util.flatten(outs) @@ -1451,16 +1412,19 @@ def aval_to_types(aval): o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s) for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)] - if ir_result_memory_kinds is not None: + # Insert a custom call if output is on host because XLA needs that to do the + # transfer. + if custom_call_ir_result_memory_kinds is not None and name == "main": flat_outputs = [ o if mk is None else wrap_with_memory_kind(o, mk, o_aval) - for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)] + for o, mk, o_aval in zip( + flat_outputs, custom_call_ir_result_memory_kinds, output_avals)] if ir_result_shardings is not None and name == "main": flat_outputs = [ - a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a) # type: ignore + replicate_trailing_dims(entry_lowering_ctx, o, a) if (a is not core.abstract_token and - dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # type: ignore + dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # pytype: disable=attribute-error for o, s, a in zip(flat_outputs, ir_result_shardings, output_avals) ] @@ -1482,13 +1446,17 @@ def wrap_with_memory_kind( return op.result -def _to_physical_op_sharding( - aval: core.AbstractValue | None, sharding: xc.HloSharding | None -) -> xc.OpSharding | None: - if (isinstance(aval, core.ShapedArray) and dtypes.issubdtype(aval.dtype, dtypes.extended) - and sharding is not None): - return aval.dtype._rules.physical_hlo_sharding(aval, sharding).to_proto() - return None if sharding is None else sharding.to_proto() # type: ignore +def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: + # Set the sharding of extended dtypes to be UNCONSTRAINED + # (i.e. XLA will choose) on aval.shape. + # For the trailing dims i.e. the dimension of key_shape on the base_array, + # the sharding is set to be REPLICATED always. + # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2), + # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None). + # The below custom call achieves the sharding like above example. 
+ return wrap_with_sharding_op( + ctx, val, aval, xc.HloSharding.replicate().to_proto(), + unspecified_dims=set(range(aval.ndim))) def _emit_lowering_rule_as_fun(lowering_rule, @@ -1590,7 +1558,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None default_rule = override_rule else: # First the platform-specific rules - for p in ctx.platforms: + for p in _platforms_for_eqn_ctx(eqn.ctx) or ctx.platforms: if eqn.primitive in _platform_specific_lowerings[p]: platform_rules[p] = _platform_specific_lowerings[p][eqn.primitive] elif eqn.primitive in xla._backend_specific_translations[p]: @@ -1609,7 +1577,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None name_stack=source_info.name_stack, avals_in=avals_in, avals_out=map(aval, eqn.outvars), tokens_in=tokens_in, - tokens_out=None, dim_var_values=dim_var_values) + tokens_out=None, jaxpr_eqn_ctx=eqn.ctx, dim_var_values=dim_var_values) if config.dynamic_shapes.value: axis_size_env = {d: read(d)[0] for a in avals_in if type(a) is core.DShapedArray @@ -1652,13 +1620,23 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None return map(read, jaxpr.outvars), tokens +def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None + ) -> tuple[str, ...]: + """Returns platforms to override based on compute type of jaxpr equation.""" + if eqn_ctx is None: + return () + if eqn_ctx.compute_type == 'device_host': + return ('cpu',) + return () + + def lower_per_platform(ctx: LoweringRuleContext, description: str, platform_rules: dict[str, LoweringRule], default_rule: LoweringRule | None, effects: effects_lib.Effects, *rule_args: ir.Value, - **rule_kwargs) -> ir.Value: + **rule_kwargs) -> Sequence[ir.Value]: """Emits code for a primitive for the current lowering platform(s). For example, given @@ -1693,7 +1671,8 @@ def lower_per_platform(ctx: LoweringRuleContext, rule_args: the args of the lowering rules. rule_kwargs: the kwargs of the lowering rules. """ - platforms: Sequence[str] = ctx.module_context.platforms + platforms: Sequence[str] = (_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or + ctx.platforms or ctx.module_context.platforms) # Special case the common case (single-platform lowering) if len(platforms) == 1: rule = platform_rules.get(platforms[0], default_rule) @@ -1725,7 +1704,11 @@ def lower_per_platform(ctx: LoweringRuleContext, assert kept_rules # If there is a single rule left just apply the rule, without conditionals. 
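Stripped of the MLIR emission, of the compute-type override from _platforms_for_eqn_ctx, and of the effects bookkeeping, the rule selection in lower_per_platform reduces to the sketch below (hypothetical rule objects, not real lowering rules):

def select_platform_rules(platforms, platform_rules, default_rule):
    # Map each lowering platform to the rule that handles it, preferring a
    # platform-specific rule and falling back to the default rule.
    kept_rules = []              # unique rules, in first-use order
    platform_to_idx = {}
    for p in platforms:
        rule = platform_rules.get(p, default_rule)
        if rule is None:
            raise NotImplementedError(f"no lowering rule for platform {p}")
        if rule not in kept_rules:
            kept_rules.append(rule)
        platform_to_idx[p] = kept_rules.index(rule)
    return kept_rules, platform_to_idx

cpu_rule, gpu_rule = object(), object()
rules, idx = select_platform_rules(("cpu", "cuda", "rocm"),
                                    {"cpu": cpu_rule}, gpu_rule)
assert rules == [cpu_rule, gpu_rule] and idx == {"cpu": 0, "cuda": 1, "rocm": 1}

With a single surviving rule it is applied directly; with several, a case op dispatches on the platform index, as the surrounding hunk shows.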
if len(kept_rules) == 1: - return kept_rules[0](ctx, *rule_args, **rule_kwargs) + output = kept_rules[0](ctx, *rule_args, **rule_kwargs) + wrapped_out = map(wrap_singleton_ir_values, output) + map(lambda o: wrap_compute_type_in_place(ctx, o.owner), + util.flatten(wrapped_out)) + return output assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules) assert len(ctx.dim_var_values) >= 1, "Must have a platform_index variable" @@ -1750,7 +1733,10 @@ def lower_per_platform(ctx: LoweringRuleContext, index=rule_idx_op, num_branches=len(kept_rules)) for i, rule in enumerate(kept_rules): - inner_ctx = ctx.replace() + platforms_for_this_rule = [p + for p, rule_idx in platform_to_kept_rules_idx.items() + if rule_idx == i] + inner_ctx = ctx.replace(platforms=platforms_for_this_rule) branch = case_op.regions[i].blocks.append() with ir.InsertionPoint(branch): output = rule(inner_ctx, *rule_args, **rule_kwargs) @@ -1759,6 +1745,8 @@ def lower_per_platform(ctx: LoweringRuleContext, except TypeError as e: raise ValueError("Output of translation rule must be iterable: " f"{description}, got output {output}") from e + map(lambda o: wrap_compute_type_in_place(ctx, o.owner), + util.flatten(out_nodes)) if inner_ctx.tokens_out is not None: assert len(ordered_effects) == len(inner_ctx.tokens_out) out_nodes = [inner_ctx.tokens_out.get(eff) @@ -1789,35 +1777,43 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable: The returned function does not use `avals_out`, so callers may pass any value as `avals_out`.""" - def f_lowered(ctx, *args, **params): + def f_lowered(ctx: LoweringRuleContext, *args, **params): f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) wrapped_fun = lu.wrap_init(f, params) + manager = (contextlib.nullcontext() if ctx.jaxpr_eqn_ctx is None else + ctx.jaxpr_eqn_ctx.manager) - if config.dynamic_shapes.value: - # We might be applying this function to arguments with dynamic shapes, - # i.e. there might be Vars in the shape tuples of ctx.avals_in. In that - # case, we need to form a jaxpr with leading binders for those axis size - # arguments (by computing an InputType and using trace_to_jaxpr_dynamic2), - # and we need to call jaxpr_subcomp with these arguments made explicit. - args = (*ctx.axis_size_env.values(), *args) - idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)} - i32_aval = core.ShapedArray((), np.dtype('int32')) - implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env) - explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape)) - if type(a) is core.DShapedArray else a, True) - for a in ctx.avals_in] - wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args)) - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun) - else: - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) - # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out? - - out, tokens = jaxpr_subcomp( - ctx.module_context, jaxpr, ctx.name_stack, ctx.tokens_in, - _ir_consts(consts), *map(wrap_singleton_ir_values, args), - dim_var_values=ctx.dim_var_values) - ctx.set_tokens_out(tokens) - return out + with manager: + if config.dynamic_shapes.value: + # We might be applying this function to arguments with dynamic shapes, + # i.e. there might be Vars in the shape tuples of ctx.avals_in. 
In that + # case, we need to form a jaxpr with leading binders for those axis size + # arguments (by computing an InputType and using trace_to_jaxpr_dynamic2), + # and we need to call jaxpr_subcomp with these arguments made explicit. + assert ctx.axis_size_env is not None + args = (*ctx.axis_size_env.values(), *args) + idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)} + i32_aval = core.ShapedArray((), np.dtype('int32')) + implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env) + explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape)) # type: ignore + if type(a) is core.DShapedArray else a, True) + for a in ctx.avals_in] + wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args)) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun) + else: + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) + # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out? + + if ctx.platforms is not None: + sub_context = ctx.module_context.replace(platforms=ctx.platforms) + else: + sub_context = ctx.module_context + out, tokens = jaxpr_subcomp( + sub_context, jaxpr, ctx.name_stack, ctx.tokens_in, + _ir_consts(consts), *map(wrap_singleton_ir_values, args), + dim_var_values=ctx.dim_var_values) + ctx.set_tokens_out(tokens) + return out return f_lowered @@ -1895,7 +1891,22 @@ def core_call_lowering(ctx: LoweringRuleContext, register_lowering(core.call_p, partial(core_call_lowering, name="core_call")) register_lowering(core.closed_call_p, - partial(core_call_lowering, name="core_closed_call")) + partial(core_call_lowering, name=None)) + +def map_compute_type(c_type): + if c_type == 'device_host': + return 'host' + elif c_type == 'device': + return 'dense' + raise ValueError('Invalid compute type received. 
Current supported values ' + 'are `device_host` and `device`') + +def wrap_compute_type_in_place(ctx, op): + if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.compute_type is not None: + dict_attr = {"_xla_compute_type": ir.StringAttr.get( + map_compute_type(ctx.jaxpr_eqn_ctx.compute_type))} + op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) + def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *, broadcast_dimensions) -> ir.Value: @@ -1913,17 +1924,19 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, else: if not core.is_constant_shape(aval_out.shape): # type: ignore shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape) # type: ignore - return hlo.dynamic_broadcast_in_dim( + out = hlo.dynamic_broadcast_in_dim( aval_to_ir_type(aval_out), op, shape, - dense_int_array_v6(broadcast_dimensions), + dense_int_array(broadcast_dimensions), ) else: assert all(d != ir.ShapedType.get_dynamic_size() for d in aval_out.shape), aval_out # type: ignore - return hlo.broadcast_in_dim( + out = hlo.broadcast_in_dim( aval_to_ir_type(aval_out), op, - dense_int_array_v6(broadcast_dimensions)) + dense_int_array(broadcast_dimensions)) + wrap_compute_type_in_place(ctx, out.owner) + return out def multi_broadcast_in_dim(ctx: LoweringRuleContext, ops: Sequence[ir.Value], @@ -1933,7 +1946,7 @@ def multi_broadcast_in_dim(ctx: LoweringRuleContext, out = [] for op, op_aval in zip(ops, ops_avals): op_aval_shape = op_aval.shape # type: ignore - if core.definitely_equal_shape(op_aval_shape, out_shape): # type: ignore + if core.definitely_equal_shape(op_aval_shape, out_shape): out.append(op) else: assert len(op_aval_shape) <= len(out_shape), (op_aval_shape, out_shape) @@ -2068,7 +2081,7 @@ def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> i def add_jaxvals_lowering(ctx, x, y): if (isinstance(a := ctx.avals_in[0], core.ShapedArray) and dtypes.issubdtype(a.dtype, dtypes.extended)): - return lower_fun(lambda x, y: [a.dtype._rules.add(a.dtype, x, y)])(ctx, x, y) # type: ignore + return lower_fun(lambda x, y: [a.dtype._rules.add(a.dtype, x, y)])(ctx, x, y) return [hlo.add(x, y)] register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering) @@ -2079,9 +2092,8 @@ def compare_hlo(x, y, direction: str, comparison_type: str | None = None): """Creates CompareOp.""" if comparison_type is None: elem_type = ir.RankedTensorType(x.type).element_type - if ir.IntegerType.isinstance(elem_type): - comparison_type = ("UNSIGNED" if ir.IntegerType.is_unsigned(elem_type) - else "SIGNED") + if isinstance(elem_type, ir.IntegerType): + comparison_type = "UNSIGNED" if elem_type.is_unsigned else "SIGNED" else: comparison_type = "FLOAT" @@ -2169,11 +2181,34 @@ def get_sharding_attr(sharding_proto: xc.OpSharding): # The MHLO to HLO conversion supports both, and the proto representation is # more compact. 
if len(sharding_proto.tile_assignment_devices) > 100: - return ir.StringAttr.get(sharding_proto.SerializeToString()) + return ir.StringAttr.get(sharding_proto.SerializeToString()) # type: ignore else: return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto))) +def wrap_with_layout_op(ctx: LoweringRuleContext, + x: ir.Value, + aval_out: core.AbstractValue, + layout: DeviceLocalLayout, + aval_in: core.AbstractValue): + result_type = aval_to_ir_type(aval_out) + out_shape = core.physical_aval(aval_out).shape # type: ignore + if core.is_constant_shape(out_shape): + result_shapes = None + else: + result_shapes = [eval_dynamic_shape_as_tensor(ctx, out_shape)] + + op = custom_call('LayoutConstraint', result_types=[result_type], operands=[x], + api_version=1, + result_shapes=result_shapes, + # Set operand layouts to anything. XLA will ignore it. + operand_layouts=[list(range(aval_in.ndim))], # type: ignore + # TODO(yashkatariya): Figure out how to pass tiling to the + # custom call. + result_layouts=[layout.major_to_minor[::-1]]) + return op.result + + # MLIR lowerings for lax primitives def cache_lowering(f): @@ -2219,7 +2254,8 @@ def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation def merge_mlir_modules(dst_module: ir.Module, sym_name: str, - src_module: ir.Module) -> str: + src_module: ir.Module, + dst_symtab: ir.SymbolTable | None = None) -> str: """ Args: dst_module: the module into which the contents of src_module should be @@ -2230,6 +2266,7 @@ def merge_mlir_modules(dst_module: ir.Module, src_module: the module whose contents are to be alpha-renamed, set to private visibility, and merged into dst_module. src_module must contain exactly one symbol named "main". + dst_symtab: the symbol table of `dst_module` Functions in src_module will be renamed such that they do not collide with functions in dst_module. 
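The renaming promised here amounts to picking fresh symbol names for anything in src_module that clashes with dst_module (the real code does this over MLIR symbol tables and also rewrites all uses). A standalone sketch of one such naming policy, with a hypothetical helper:

def disambiguate(src_names, dst_names):
    # Choose a fresh name for every src symbol that clashes with dst (or with
    # a name already chosen), by appending an increasing numeric suffix.
    used = set(dst_names)
    renamings = {}
    for name in src_names:
        new_name, i = name, 0
        while new_name in used:
            i += 1
            new_name = f"{name}_{i}"
        used.add(new_name)
        renamings[name] = new_name
    return renamings

assert disambiguate(["main", "helper"], ["main", "main_1"]) == {
    "main": "main_2", "helper": "helper"}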
@@ -2243,7 +2280,7 @@ def merge_mlir_modules(dst_module: ir.Module, assert dst_module.context == src_module.context src_symtab = ir.SymbolTable(src_module.operation) - dst_symtab = ir.SymbolTable(dst_module.operation) + dst_symtab = dst_symtab or ir.SymbolTable(dst_module.operation) used_names = set() # Rename all symbols in src_module that clash with names in dst_module, or @@ -2281,6 +2318,7 @@ def merge_mlir_modules(dst_module: ir.Module, for op in src_module.body.operations: dst_module.body.append(op) + dst_symtab.insert(op) return renamings["main"] @@ -2308,7 +2346,8 @@ def fallback(ctx: LoweringRuleContext, *args, **params): ctx.avals_out, **params) xla_module = xla_computation_to_mlir_module(xla_computation) callee_name = merge_mlir_modules( - module_ctx.module, f"xla_fallback_{prim.name}", xla_module) + module_ctx.module, f"xla_fallback_{prim.name}", xla_module, + dst_symtab=module_ctx.symbol_table) output_types = map(aval_to_ir_types, ctx.avals_out) flat_output_types = util.flatten(output_types) output_type = (ir.TupleType.get_tuple(flat_output_types) @@ -2351,7 +2390,8 @@ def send_to_host(channel: int, token: hlo.TokenType, operand: Any, def receive_from_host(channel: int, token: hlo.TokenType, out_aval: core.ShapedArray, name: str, *, - sharding: xc.OpSharding | None = None) -> ir.Value: + sharding: xc.OpSharding | None = None, +) -> tuple[ir.Value, ir.Value]: channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE) recv_op = hlo.RecvOp([aval_to_ir_type(out_aval), hlo.TokenType.get()], token, channel_handle, @@ -2419,7 +2459,7 @@ def _wrapped_callback(*args): # pylint: disable=function-redefined recv_channels.append(channel) ifrt_callback = backend.make_python_callback_from_host_send_and_recv( _wrapped_callback, operand_shapes, result_shapes, send_channels, - recv_channels, pickle_util.dumps) # type: ignore # pylint: disable=missing-parameter + recv_channels, pickle_util.dumps) ctx.module_context.add_host_callback(ifrt_callback) return outputs, token @@ -2436,14 +2476,20 @@ def _aval_to_default_layouts(aval): # Row major order is default for `NumPy`. return [list(range(aval.ndim - 1, -1, -1)) for aval in avals] + def emit_python_callback( - ctx: LoweringRuleContext, callback, token: Any | None, - operands: Sequence[ir.Value], operand_avals: Sequence[core.ShapedArray], + ctx: LoweringRuleContext, + callback, + token: Any | None, + operands: Sequence[ir.Value], + operand_avals: Sequence[core.ShapedArray], result_avals: Sequence[core.ShapedArray], - has_side_effect: bool, *, sharding: xc.OpSharding | None = None, + *, + has_side_effect: bool, + sharding: xc.OpSharding | None = None, operand_layouts: Sequence[Sequence[int] | None] | None = None, result_layouts: Sequence[Sequence[int] | None] | None = None, - ) -> tuple[Sequence[ir.Value], Any, Any]: +) -> tuple[Sequence[ir.Value], Any, Any]: """Emits MLIR that calls back to a provided Python function.""" if len(ctx.module_context.platforms) > 1: raise NotImplementedError("multi-platform lowering for python_callback") @@ -2473,15 +2519,18 @@ def _wrapped_callback(*args): raise RuntimeError( "Mismatched number of outputs from callback. " "Expected: {}, Actual: {}".format(len(result_avals), len(out_vals))) + # Handle Python literals, and custom arrays, e.g., tf.Tensor. 
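The handling referred to in this comment, normalising Python scalars and foreign array objects before checking them against the promised result avals, can be sketched as follows (np.asarray stands in for the dtype canonicalisation the real code also performs; the helper name is hypothetical):

import numpy as np

def check_callback_outputs(out_vals, expected_shapes, expected_dtypes):
    # Normalise to NumPy first, then validate each result against the shape
    # and dtype the callback was declared to return.
    out_vals = tuple(np.asarray(v) for v in out_vals)
    if len(out_vals) != len(expected_shapes):
        raise RuntimeError("Mismatched number of outputs from callback. "
                           f"Expected: {len(expected_shapes)}, Actual: {len(out_vals)}")
    for i, (v, shape, dtype) in enumerate(
            zip(out_vals, expected_shapes, expected_dtypes)):
        if v.shape != shape:
            raise RuntimeError(f"Incorrect output shape for return value #{i}: "
                               f"Expected: {shape}, Actual: {v.shape}")
        if v.dtype != np.dtype(dtype):
            raise RuntimeError(f"Incorrect output dtype for return value #{i}: "
                               f"Expected: {np.dtype(dtype)}, Actual: {v.dtype}")
    return out_vals

check_callback_outputs([np.int32(3), np.ones((2,))], [(), (2,)], [np.int32, np.float64])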
+ out_vals = tuple(xla.canonicalize_dtype(np.asarray(a)) for a in out_vals) for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)): if out_val.shape != out_aval.shape: raise RuntimeError( - f"Incorrect output shape for return value {i}: " - "Expected: {}, Actual: {}".format(out_aval.shape, out_val.shape)) + f"Incorrect output shape for return value #{i}: " + f"Expected: {out_aval.shape}, Actual: {out_val.shape}") if out_val.dtype != out_aval.dtype: raise RuntimeError( - f"Incorrect output dtype for return value {i}: " - "Expected: {}, Actual: {}".format(out_aval.dtype, out_val.dtype)) + f"Incorrect output dtype for return value #{i}: " + f"Expected: {out_aval.dtype}, Actual: {out_val.dtype}") + if platform == "tpu": # On TPU we cannot receive empty arrays. So, we return from the wrapped # callback only the non-empty results, and we will create empty constants @@ -2619,7 +2668,7 @@ def custom_call( if backend_config is None: backend_config_attr = ir.StringAttr.get("") elif isinstance(backend_config, (str, bytes)): - backend_config_attr = ir.StringAttr.get(backend_config) + backend_config_attr = ir.StringAttr.get(backend_config) # type: ignore elif isinstance(backend_config, dict): # TODO(necula): it seems that the CustomCallOp constructor requires that # backend_config_attr be a string attribute, even though in some cases we @@ -2688,8 +2737,8 @@ def custom_call( op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands, attributes=attributes) if isinstance(backend_config, dict): - backend_config_attr = ir.DictAttr.get(backend_config) - op.operation.attributes["mhlo.backend_config"] = backend_config_attr + op.operation.attributes["mhlo.backend_config"] = ir.DictAttr.get( + backend_config) return op @@ -2743,12 +2792,12 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]): rw = hlo.ReduceWindowOp( list(map(aval_to_ir_type, out_avals)), operands, init_values, - dense_int_array_v6(window_dimensions), - window_strides=dense_int_array_v6(window_strides), - base_dilations=dense_int_array_v6(base_dilation), - window_dilations=dense_int_array_v6(window_dilation), + dense_int_array(window_dimensions), + window_strides=dense_int_array(window_strides), + base_dilations=dense_int_array(base_dilation), + window_dilations=dense_int_array(window_dilation), padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64), - shape=(len(padding), 2))) + shape=[len(padding), 2])) reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types)) with ir.InsertionPoint(reducer): hlo.return_(reducer_body(reducer)) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 44738b3df16f..344fa78de46c 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -14,13 +14,13 @@ from __future__ import annotations from collections import namedtuple -from collections.abc import Sequence, Hashable +from collections.abc import Callable, Sequence, Hashable from contextlib import contextmanager, AbstractContextManager from functools import partial import inspect import itertools as it import operator as op -from typing import Any, Callable, NamedTuple, Union +from typing import Any, NamedTuple, Union from weakref import ref import numpy as np @@ -34,6 +34,7 @@ from jax._src import linear_util as lu from jax._src import profiler from jax._src import source_info_util +from jax._src import compute_on from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs, fun_sourceinfo) from 
jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval, @@ -41,10 +42,11 @@ ConcreteArray, Var, DropVar, raise_to_shaped, Atom, JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, - InputType, OutputType, get_referent) + InputType, OutputType, get_referent, JaxprEqnContext) from jax._src.state.types import AbstractRef from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten, - KeyPath, generate_key_paths, keystr) + tree_flatten, tree_structure, KeyPath, generate_key_paths, + keystr) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, as_hashable_function, weakref_lru_cache, subs_list) @@ -77,7 +79,7 @@ def _update_annotation_known( class Name: def __init__(self, a): self.a = a names = [Name(a) for a, _ in orig_type] - avals = [a.update(shape=tuple(names[d.val] if type(d) is DBIdx else d # type: ignore + avals = [a.update(shape=tuple(names[d.val] if type(d) is DBIdx else d for d in a.shape)) if type(a) is DShapedArray else a for a, e in orig_type if e] avals = [a for a, known in zip(avals, in_knowns) if known] @@ -228,7 +230,8 @@ def default_process_primitive(self, primitive, tracers, params): if primitive.multiple_results: out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None) for aval in out_aval] - eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects, source) + eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects, + source) for t in out_tracers: t.recipe = eqn return out_tracers else: @@ -382,7 +385,7 @@ def const_out_axes_thunk(): for a in out_avals] effs = core.filter_named_axis_effects(jaxpr.effects, {params['axis_name']}) src_info = source_info_util.current() - eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), # type: ignore[arg-type] + eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), out_tracers, primitive, staged_params, effs, src_info) for t in out_tracers: t.recipe = eqn @@ -409,7 +412,7 @@ def todo(out): name_stack = self._current_truncated_name_stack() source = source_info_util.current().replace(name_stack=name_stack) eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params, - jaxpr.effects, source) + jaxpr.effects, source) for t in out_tracers: t.recipe = eqn return merge_lists(out_knowns, out_tracers, out_consts) @@ -590,7 +593,7 @@ def trace_to_subjaxpr_nounits_dyn( JaxprTracer(trace, PartialVal.unknown(in_avals[d.val]), ConstVar(constval.shape[i])) assert core.same_referent(constval.shape[i], in_consts_full[d.val]) - shape = [in_consts_full[d.val] if type(d) is DBIdx else d # type: ignore + shape = [in_consts_full[d.val] if type(d) is DBIdx else d for d in aval.shape] aval = aval.update(shape=tuple(shape)) in_consts_full[idx] = JaxprTracer(trace, PartialVal.unknown(aval), @@ -599,7 +602,7 @@ def trace_to_subjaxpr_nounits_dyn( for idx, (aval, explicit) in enumerate(in_type): if not explicit: assert in_consts_full[idx] is not None if isinstance(aval, DShapedArray): - assert all(type(d) is not DBIdx or in_consts_full[d.val] is not None # type: ignore + assert all(type(d) is not DBIdx or in_consts_full[d.val] is not None for d in aval.shape) # Next, build tracers for all unknown inputs, using the in_consts_full list @@ -609,7 +612,7 @@ def trace_to_subjaxpr_nounits_dyn( for aval, explicit in in_type: if explicit and not next(in_knowns_iter): if isinstance(aval, DShapedArray): - shape = [in_consts_full[d.val] if type(d) is DBIdx else d # 
type: ignore + shape = [in_consts_full[d.val] if type(d) is DBIdx else d for d in aval.shape] aval = aval.update(shape=tuple(shape)) tracer = JaxprTracer(trace, PartialVal.unknown(aval), LambdaBinding()) @@ -847,7 +850,7 @@ def trace_to_subjaxpr_nounits_fwd2( out_pvals = [t.pval for t in out_tracers] # Which consts (aka residuals) are just forwarded inputs? Check obj id. - in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] + in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] id_map = {id(c): i for i, c in enumerate(in_consts)} input_fwds: list[int | None] = [id_map.get(id(c)) for c in consts] @@ -875,14 +878,15 @@ class JaxprEqnRecipe(NamedTuple): params: dict[str, Any] effects: core.Effects source_info: source_info_util.SourceInfo + ctx: JaxprEqnContext def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], out_tracers: Sequence[JaxprTracer], primitive: Primitive, params: dict[str, Any], effects: core.Effects, - source_info: source_info_util.SourceInfo - ) -> JaxprEqnRecipe: + source_info: source_info_util.SourceInfo, + ctx: JaxprEqnContext | None = None) -> JaxprEqnRecipe: # TODO(necula): move these checks to core.check_jaxpr, and call in more places if primitive.call_primitive or primitive.map_primitive: assert "call_jaxpr" in params @@ -894,18 +898,22 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], assert ("donated_invars" in params and len(params["donated_invars"]) == len(params["call_jaxpr"].invars)) out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers] + ctx = ctx or JaxprEqnContext(compute_on.current_compute_type(), + config.threefry_partitionable.value) return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers), - out_avals, primitive, params, effects, source_info) + out_avals, primitive, params, effects, source_info, + ctx) def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom], recipe: JaxprEqnRecipe) -> core.JaxprEqn: - (_, in_tracers, out_tracer_refs, out_avals, prim, params, eff, src) = recipe + (_, in_tracers, out_tracer_refs, out_avals, prim, params, eff, src, + ctx) = recipe invars = [getvar(t) for t in in_tracers] out_tracers = [t_ref() for t_ref in out_tracer_refs] - outvars = [DropVar(a) if t is None else getvar(t) # type: ignore + outvars = [DropVar(a) if t is None else getvar(t) for a, t in zip(out_avals, out_tracers)] - return new_jaxpr_eqn(invars, outvars, prim, params, eff, src) + return new_jaxpr_eqn(invars, outvars, prim, params, eff, src, ctx) def tracers_to_jaxpr( in_tracers: Sequence[JaxprTracer], @@ -958,7 +966,7 @@ def type_substitute(aval: AbstractValue) -> AbstractValue: outvars = [DropVar(type_substitute(a)) if rf() is None else newvar(rf()) for a, rf in zip(r.out_avals, r.out_tracer_refs)] eqns.append(new_jaxpr_eqn(in_atoms, outvars, r.primitive, r.params, - r.effects, r.source_info)) + r.effects, r.source_info, r.ctx)) processed_eqn_ids.add(r.eqn_id) elif isinstance(r, LambdaBinding): if not any(t is in_tracer for in_tracer in in_tracers): @@ -971,7 +979,7 @@ def type_substitute(aval: AbstractValue) -> AbstractValue: consts[var] = r.val t_to_var[id(t)] = var elif isinstance(r, FreeVar): - env[newvar(t)] = r.val # type: ignore + env[newvar(t)] = r.val elif isinstance(r, Literal): pass elif r is None: @@ -984,7 +992,7 @@ def type_substitute(aval: AbstractValue) -> AbstractValue: const_vars, const_vals = unzip2(consts.items()) outvars = map(get_atom, out_tracers) # type: ignore[arg-type] jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns) - jaxpr = Jaxpr(const_vars, 
invars, # type: ignore[list-item,arg-type] + jaxpr = Jaxpr(const_vars, invars, # type: ignore[arg-type] outvars, eqns, jaxpr_effects) config.enable_checks.value and core.check_jaxpr(jaxpr) # del getvar # needed to avoid cyclic-reference closure, apparently! @@ -1005,11 +1013,9 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr: @weakref_lru_cache def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr: - """Move n invars to constvars. Like an inverse of convert_constvars_Jaxpr.""" + """Move n invars to constvars. Like an inverse of convert_constvars_jaxpr.""" if n == 0: return jaxpr.replace() # 'return jaxpr' would create cache reference cycle - if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects): - raise NotImplementedError config.enable_checks.value and core.check_jaxpr(jaxpr) constvars, invars = split_list(jaxpr.invars, [n]) dbg = jaxpr.debug_info and jaxpr.debug_info._replace( @@ -1248,21 +1254,24 @@ def has_effects(effects) -> bool: if has_effects(eqn.effects) or isinstance(policy, SaveableType): map(partial(write, False, False), eqn.outvars) elif isinstance(policy, Offloadable): - from jax._src.dispatch import device_put_p, TransferToMemoryKind # type: ignore + # TODO(slebedev): This is a legit error which requires a BUILD fix. + from jax._src.dispatch import device_put_p, TransferToMemoryKind # pytype: disable=import-error resvars = [newvar(v.aval) for v in eqn.outvars] outvars_copy = list[Atom](eqn.outvars) offload_eqn = core.JaxprEqn( outvars_copy, resvars, device_put_p, - dict(device=TransferToMemoryKind(policy.dst), src=None), - set(), source_info_util.new_source_info()) + dict(devices=[TransferToMemoryKind(policy.dst)], srcs=[None]), + set(), source_info_util.new_source_info(), + JaxprEqnContext(None, False)) known_eqns.append(offload_eqn) # resvars are known and available in the backward jaxpr. map(partial(write, False, True), resvars) residuals.update(resvars) reload_eqn = core.JaxprEqn( - resvars, eqn.outvars, device_put_p, # type: ignore - dict(device=TransferToMemoryKind(policy.src), src=None), - set(), source_info_util.new_source_info()) + resvars, eqn.outvars, device_put_p, + dict(devices=[TransferToMemoryKind(policy.src)], srcs=[None]), + set(), source_info_util.new_source_info(), + JaxprEqnContext(None, False)) staged_eqns.append(reload_eqn) # outvars are known and available in the backward jaxpr. 
map(partial(write, False, True), eqn.outvars) @@ -1384,10 +1393,11 @@ def call_partial_eval_custom_rule( residuals = [newvar(res_aval(params_known, var.aval)) for var in jaxpr_staged.invars[:num_res]] eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], - eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info) + eqn.primitive, params_known, jaxpr_known.effects, + eqn.source_info, eqn.ctx) eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged, eqn.primitive, params_staged, - jaxpr_staged.effects, eqn.source_info) + jaxpr_staged.effects, eqn.source_info, eqn.ctx) assert len(eqn_staged.invars) == len(jaxpr_staged.invars) new_inst = [x for x, inst in zip(eqn.invars, inst_in) if type(x) is Var and not inst] @@ -1427,11 +1437,11 @@ def closed_call_partial_eval_custom_rule( eqn_known = new_jaxpr_eqn([*ins_known, *res_ref_binders], [*out_binders_known, *res_val_binders], eqn.primitive, params_known, jaxpr_known.effects, - eqn.source_info) + eqn.source_info, eqn.ctx) eqn_staged = new_jaxpr_eqn([*res_val_vars, *res_ref_binders, *ins_staged], out_binders_staged, eqn.primitive, params_staged, jaxpr_staged.effects, - eqn.source_info) + eqn.source_info, eqn.ctx) assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals) assert len(ins_known) + len(res_ref_binders) == len(jaxpr_known.jaxpr.invars) assert len(ins_staged) + len(res_ref_binders) + len(res_val_vars) == len(jaxpr_staged.jaxpr.invars) @@ -1511,7 +1521,7 @@ def prune_closed_jaxpr_outputs( ) -> ClosedJaxpr: return _prune_closed_jaxpr_outputs(jaxpr, tuple(used_outputs)) -@weakref_lru_cache +@partial(weakref_lru_cache, trace_context_in_key=False) def _prune_closed_jaxpr_outputs( jaxpr: ClosedJaxpr, used_outputs: tuple[bool, ...] ) -> ClosedJaxpr: @@ -1608,7 +1618,7 @@ def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn new_eqn = new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [v for v, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info) + eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info, eqn.ctx) return used_inputs, new_eqn dce_rules[core.call_p] = dce_jaxpr_call_rule @@ -1629,7 +1639,7 @@ def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn new_eqn = new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [v for v, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, closed_jaxpr.effects, eqn.source_info) + eqn.primitive, new_params, closed_jaxpr.effects, eqn.source_info, eqn.ctx) return used_inputs, new_eqn dce_rules[core.closed_call_p] = dce_jaxpr_closed_call_rule @@ -1673,13 +1683,17 @@ def __init__(self, trace, aval, line_info=None): self.aval = aval def full_lower(self): - return self + var = self._trace.frame.tracer_to_var.get(id(self)) + if var is None: return self + val = self._trace.frame.constvar_to_val.get(var) + if val is None: return self + return core.full_lower(val) def _contents(self): return () def _origin_msg(self): - if not self._trace.main.jaxpr_stack: # type: ignore + if not self._trace.main.jaxpr_stack: # If this Tracer has been leaked the jaxpr stack may no longer be # available. So we can't print as much origin information. 
return ("\nThis DynamicJaxprTracer was created on line " @@ -1725,12 +1739,19 @@ def get_referent(self): frame = self._trace.frame val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) return self if val is None else get_referent(val) -api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval") + +def _dynamic_jaxpr_tracer_shaped_abstractify(x): + return core.raise_to_shaped(x.aval) +api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: + sentinel = object() jaxpr_effects = set() - all_vars = [*constvars, *invars] + all_vars = {v: i for i, v in enumerate(it.chain(constvars, invars))} for eqn in eqns: + if eqn.primitive is core.mutable_array_p: + outvar, = eqn.outvars + all_vars[outvar] = None # type: ignore for eff in eqn.effects: if isinstance(eff, effects.JaxprInputEffect): if eff.input_index >= len(eqn.invars): @@ -1740,14 +1761,14 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: "\n Jaxpr: " f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") invar = eqn.invars[eff.input_index] - if invar not in all_vars: + if (input_index := all_vars.get(invar, sentinel)) is sentinel: raise ValueError( f"`JaxprInputEffect` {eff} does not have " f"corresponding input: {invar}." f"\n Equation: {eqn}\n" "\n Jaxpr: " f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") - eff = eff.replace(input_index=all_vars.index(invar)) + eff = eff.replace(input_index=input_index) jaxpr_effects.add(eff) return jaxpr_effects @@ -1783,12 +1804,15 @@ def __init__(self): def add_eqn(self, eqn: core.JaxprEqn): self.eqns.append(eqn) - def to_jaxpr(self, out_tracers: Sequence[Tracer] - ) -> tuple[Jaxpr, list[Any], list[tuple[Any, str]]]: + def to_jaxpr(self, trace: DynamicJaxprTrace, out_tracers: Sequence[Tracer] + ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: # It's not necessary, but we keep the tracer-to-var mapping injective: assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) invars = self.attrs_vars + self.invars - state_outvars = [self.tracer_to_var[id(t)] for t in get_states(self.attrs_tracked)] + state_ans, end_trees = unzip2( + tree_flatten(t) for t in get_states(self.attrs_tracked)) + state_outvars = [self.tracer_to_var[id(trace.full_raise(x))] + for xs in state_ans for x in xs] explicit_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] outvars = state_outvars + explicit_outvars constvars, constvals = unzip2(self.constvar_to_val.items()) @@ -1796,8 +1820,9 @@ def to_jaxpr(self, out_tracers: Sequence[Tracer] jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects) jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore + init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] set_states(self.attrs_tracked, self.attrs_inits) - return jaxpr, list(constvals), self.attrs_tracked + return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) def to_jaxpr2(self, out_tracers): # It's not necessary, but we keep the tracer-to-var mapping injective: @@ -1832,10 +1857,11 @@ def find_progenitors(self, tracer): produced = set(eqn.outvars) & active_vars if produced: active_vars.difference_update(produced) - active_vars.update(eqn.invars) + active_vars.update({v for v in eqn.invars if type(v) is Var}) invar_positions = [i for i, v in 
enumerate(self.invars) if v in active_vars] constvars = active_vars & set(self.constvar_to_val) - const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars] + const_eqns = [eqn for eqn in self.eqns + if {v for v in eqn.invars if type(v) is Var} & constvars] return invar_positions, const_eqns def _const_folding_and_forwarding( @@ -1851,8 +1877,9 @@ def apply_var_sub(a: Atom) -> Atom: # if any inputs are constants and we have a constant-folding rule, apply it has_input_effect = any(isinstance(eff, effects.JaxprInputEffect) for eff in eqn.effects) - if (eqn.primitive in const_fold_rules and any(v in consts for v in eqn.invars) - and not has_input_effect): + if (eqn.primitive in const_fold_rules and + any(v in consts for v in eqn.invars if isinstance(v, Var)) and + not has_input_effect): consts_in = [consts.get(v) if isinstance(v, Var) else None for v in eqn.invars] consts_out, new_eqn = const_fold_rules[eqn.primitive](consts_in, eqn) @@ -1864,7 +1891,6 @@ def apply_var_sub(a: Atom) -> Atom: # if the application trivially maps some inputs to outputs, simplify if eqn.primitive in forwarding_rules and not has_input_effect: fwd_vars, new_eqn = forwarding_rules[eqn.primitive](eqn) - assert (new_eqn is None) == all(v is not None for v in fwd_vars) for v_orig, v_new in zip(eqn.outvars, fwd_vars): if v_new is not None: var_subs[v_orig] = v_new if new_eqn is None: continue @@ -1904,7 +1930,8 @@ def _inline_literals( has_input_effect) if type(c) in core.literalable_types and not np.shape(c) and not e} def lit(a: Atom) -> Literal | None: - return lits.get(a) if isinstance(a, Var) else None + return (a if isinstance(a, Literal) else lits.get(a) if isinstance(a, Var) + else None) newname: Callable[[AbstractValue], Var] = core.gensym() newvars: dict[Var, Var] = {} newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval)) @@ -1916,8 +1943,9 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: return [d for d in aval.shape if isinstance(d, Var)] return [] - used = {v for eqn in jaxpr.eqns for invar in eqn.invars - for v in it.chain([invar], vars_in_shape(invar.aval))} + used = {v for eqn in jaxpr.eqns for atom in eqn.invars + for v in it.chain([atom], vars_in_shape(atom.aval)) + if isinstance(atom, Var)} used |= {v for outvar in jaxpr.outvars for v in it.chain([outvar], vars_in_shape(outvar.aval))} new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)] @@ -1926,7 +1954,7 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: new_invars = [var(v) for v in jaxpr.invars] new_eqns = [] for eqn in jaxpr.eqns: - invars = [lit(v) or var(v) for v in eqn.invars] + invars = [lit(x) or var(x) for x in eqn.invars] outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars] new_eqns.append(eqn.replace(invars=invars, outvars=outvars)) new_outvars = [lit(v) or var(v) for v in jaxpr.outvars] @@ -1938,7 +1966,7 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: class DynamicJaxprTrace(core.Trace): - __slots__ = [] # type: ignore + __slots__ = [] @property def frame(self): @@ -2027,7 +2055,8 @@ def default_process_primitive(self, primitive, tracers, params): out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info) + eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects, + source_info) self.frame.add_eqn(eqn) return out_tracers if primitive.multiple_results 
else out_tracers.pop() @@ -2078,7 +2107,7 @@ def process_map(self, map_primitive, f, tracers, params): reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a) if in_axis is not None else a for a, in_axis in zip(in_avals, params['in_axes'])] - with core.extend_axis_env(axis_name, params["global_axis_size"], None): # type: ignore + with core.extend_axis_env(axis_name, params["global_axis_size"], None): with core.new_sublevel(): jaxpr, reduced_out_avals, consts, () = trace_to_subjaxpr_dynamic( f, self.main, reduced_in_avals, @@ -2158,7 +2187,9 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, def fwd_jaxpr_from_zeros(*zeros): for store in fwd.stores: store and store.reset() fwd_ = _interleave_fun(fwd, zeros) - return trace_to_subjaxpr_dynamic(fwd_, main_(), in_avals)[::2] + jaxpr, _, consts, atr = trace_to_subjaxpr_dynamic(fwd_, main_(), in_avals) + if atr: raise NotImplementedError + return jaxpr, consts out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) @@ -2316,8 +2347,9 @@ def trace_to_jaxpr_dynamic( debug_info: DebugInfo | None = None, *, keep_inputs: list[bool] | None = None, -) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[Any, str]]]: - with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore +) -> tuple[Jaxpr, list[AbstractValue], list[Any], + list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: + with core.new_main(DynamicJaxprTrace, dynamic=True) as main: main.jaxpr_stack = () # type: ignore jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic( fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info) @@ -2332,7 +2364,8 @@ def trace_to_subjaxpr_dynamic( *, keep_inputs: Sequence[bool] | None = None, debug_info: DebugInfo | None = None, -) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[Any, str]]]: +) -> tuple[Jaxpr, list[AbstractValue], list[Any], + list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs frame = JaxprStackFrame() @@ -2343,7 +2376,7 @@ def trace_to_subjaxpr_dynamic( in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] ans = fun.call_wrapped(*in_tracers_) out_tracers = map(trace.full_raise, ans) - jaxpr, consts, attrs_tracked = frame.to_jaxpr(out_tracers) + jaxpr, consts, attrs_tracked = frame.to_jaxpr(trace, out_tracers) del fun, main, trace, frame, in_tracers, out_tracers, ans config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked @@ -2353,7 +2386,7 @@ def trace_to_subjaxpr_dynamic( def trace_to_jaxpr_dynamic2( fun: lu.WrappedFun, debug_info: DebugInfo | None = None ) -> tuple[Jaxpr, OutputType, list[Any]]: - with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore + with core.new_main(DynamicJaxprTrace, dynamic=True) as main: main.jaxpr_stack = () # type: ignore jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) del main, fun @@ -2394,7 +2427,7 @@ def trace_to_jaxpr_final( debug_info: DebugInfo | None = None, keep_inputs: Sequence[bool] | None = None, ) -> tuple[Jaxpr, list[AbstractValue], list[Any]]: - with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore + with core.new_base_main(DynamicJaxprTrace) as main: main.jaxpr_stack = () # type: ignore with core.new_sublevel(): jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic( @@ -2407,7 +2440,7 @@ def trace_to_jaxpr_final( def trace_to_jaxpr_final2( fun: 
lu.WrappedFun, debug_info: DebugInfo | None = None ) -> tuple[Jaxpr, OutputType, list[Any]]: - with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore + with core.new_base_main(DynamicJaxprTrace) as main: main.jaxpr_stack = () # type: ignore with core.new_sublevel(): jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) @@ -2475,7 +2508,8 @@ def _complete_specs( for x, spec in zip(args, partial_specs): for i, name in spec.items(): d = sizes.setdefault(name, x.shape[i]) - if d is not x.shape[i] and d != x.shape[i]: raise TypeError + if d is not x.shape[i] and d != x.shape[i]: + raise TypeError(f"Provided size {d} for {name} does not match prior associated name for {name} : {x.shape[i]}") # Introduce new names as needed for Tracers in shapes. named_tracers: dict[TracerId, AbstractedAxisName] = { @@ -2614,7 +2648,7 @@ def _input_type_to_tracers( def _substitute_tracers_in_aval(a: AbstractValue) -> AbstractValue: if isinstance(a, DShapedArray) and any(type(d) is DBIdx for d in a.shape): - shape = [in_tracers[d.val] if type(d) is DBIdx else d for d in a.shape] # type: ignore + shape = [in_tracers[d.val] if type(d) is DBIdx else d for d in a.shape] return a.update(shape=tuple(shape)) return a @@ -2735,13 +2769,6 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params): return prim.bind(*subfuns, *args, **bind_params) -def _error_staging_mutable_array_p(trace, x): - raise Exception( - "mutable_array constructor can't be staged out, and in particular can't " - "be used under a jax.jit or jax.lax.scan") -custom_staging_rules[core.mutable_array_p] = _error_staging_mutable_array_p - - # TODO(mattjj): the following are deprecated; update callers to _nounits version # See https://github.com/google/jax/pull/9498 @lu.transformation @@ -2771,38 +2798,32 @@ def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer): else: return tracer -def inline_jaxpr_into_trace(trace: DynamicJaxprTrace, jaxpr: Jaxpr, consts, - *args) -> list[Any]: +def inline_jaxpr_into_trace( + trace: DynamicJaxprTrace, jaxpr: Jaxpr, consts: Sequence[Any], + *arg_tracers: DynamicJaxprTracer) -> list[Any]: # This function is conceptually the same thing as just calling eval_jaxpr, - # but doesn't redo abstract evaluation: we know the shapes from the jaxpr. 
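As the comment above notes, the rewritten inline_jaxpr_into_trace below does not re-evaluate the inner jaxpr: it walks the equations once, renames variables through an environment keyed on the inner constvars/invars, and appends the renamed equations to the outer frame. A toy sketch of that renaming loop (literals, effects and source info are ignored; all names are illustrative):

def inline(eqns, constvars, invars, outer_args, fresh, emit):
    env = dict(zip([*constvars, *invars], outer_args))
    for prim, in_vs, out_vs in eqns:
        new_outs = [fresh() for _ in out_vs]
        emit((prim, [env[v] for v in in_vs], new_outs))
        env.update(zip(out_vs, new_outs))
    return env

frame = []
counter = iter(range(100))
env = inline(
    eqns=[("mul", ["x", "c"], ["t"]), ("add", ["t", "x"], ["y"])],
    constvars=["c"], invars=["x"], outer_args=["v0", "v1"],
    fresh=lambda: f"v{2 + next(counter)}", emit=frame.append)
assert frame == [("mul", ["v1", "v0"], ["v2"]), ("add", ["v2", "v1"], ["v3"])]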
- def read(v: Atom) -> Any: - return v.val if isinstance(v, Literal) else env[v] - - def write(v: Var, val: Any) -> None: - if config.enable_checks.value and not config.dynamic_shapes.value: - assert core.typecheck(v.aval, val), (v.aval, val) - env[v] = val + const_tracers = map(trace.new_const, consts) + constvars = map(trace.getvar, const_tracers) + argvars = map(trace.getvar, arg_tracers) + env: dict[Var, Var] = dict(zip([*jaxpr.constvars, *jaxpr.invars], + [*constvars, *argvars])) - env: dict[Var, Any] = {} - map(write, jaxpr.constvars, consts) - map(write, jaxpr.invars, args) - lu = core.last_used(jaxpr) - source_info = source_info_util.current() + src = source_info_util.current() for eqn in jaxpr.eqns: - ins = map(read, eqn.invars) - out_tracers = [DynamicJaxprTracer(trace, a.aval, source_info) - for a in eqn.outvars] - invars = [trace.getvar(trace.full_raise(x)) for x in ins] - outvars = map(trace.makevar, out_tracers) - if eqn.source_info.name_stack: - eqn_source_info = source_info.replace( - name_stack=source_info.name_stack + eqn.source_info.name_stack) - else: - eqn_source_info = source_info - - new_eqn = core.new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params, - eqn.effects, eqn_source_info) - trace.frame.add_eqn(new_eqn) - map(write, eqn.outvars, out_tracers) - core.clean_up_dead_vars(eqn, env, lu) - return map(read, jaxpr.outvars) + invars = [x if isinstance(x, Literal) else env[x] for x in eqn.invars] + outvars = [Var('', v.aval) for v in eqn.outvars] + src_ = (src if not eqn.source_info.name_stack else + src.replace(name_stack=src.name_stack + eqn.source_info.name_stack)) + trace.frame.add_eqn(core.new_jaxpr_eqn(invars, outvars, eqn.primitive, + eqn.params, eqn.effects, src_)) + map(env.setdefault, eqn.outvars, outvars) + + tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars], + [*consts, *arg_tracers])) + def new_tracer(atom): + tracer = DynamicJaxprTracer(trace, atom.aval, src) + trace.frame.tracers.append(tracer) + trace.frame.tracer_to_var[id(tracer)] = env[atom] + return tracer + return [x.val if isinstance(x, Literal) else tracer_env[x] if x in tracer_env + else new_tracer(x) for x in jaxpr.outvars] diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 9fc30fb6530c..378aa2c5f85f 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -17,16 +17,16 @@ import enum from contextlib import contextmanager +import collections from collections import namedtuple -from collections.abc import Sequence, Iterable +from collections.abc import Callable, Sequence, Iterable, Iterator import dataclasses from functools import partial, lru_cache, cached_property import itertools as it import logging import math import threading -from typing import Any, Callable, NamedTuple, TypeVar, Union, cast -from collections.abc import Iterator +from typing import Any, NamedTuple, TypeVar, Union, cast import warnings import numpy as np @@ -60,21 +60,19 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src.interpreters import xla -from jax._src.layout import XLACompatibleLayout, SpecifiedLayout, LayoutRequest +from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec +from jax._src.sharding import Sharding as JSharding from jax._src.sharding_impls 
import ( - ArrayMapping, ArrayMappingOrAutoOrUnspecified, - AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto, + ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED, + UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto, is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources, - SingleDeviceSharding, GSPMDSharding -) -from jax._src.util import (safe_map, safe_zip, partition_list, - wrap_name, tuple_update, tuple_delete, - distributed_debug_log, + SingleDeviceSharding, GSPMDSharding) +from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name, + tuple_update, tuple_delete, distributed_debug_log, unzip2, HashableFunction, weakref_lru_cache) from jax._src.state.types import AbstractRef, RefEffect @@ -109,63 +107,79 @@ class WeakRefList(list): def identity(x): return x -def shard_arg(arg, sharding, canonicalize=True): - if canonicalize: - arg = xla.canonicalize_dtype(arg) - return shard_arg_handlers[type(arg)](arg, sharding) - - @profiler.annotate_function -def shard_args( - shardings: Sequence[sharding_impls.XLACompatibleSharding], args, -) -> Sequence[jax.Array]: - return [shard_arg(arg, shardings[i]) for i, arg in enumerate(args)] - -shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {} +def shard_args(shardings: Sequence[JSharding], args, canonicalize=True) -> Sequence[xc.ArrayImpl]: + # Fast path for one argument. + if len(args) == 1: + arg = args[0] + if canonicalize: + arg = xla.canonicalize_dtype(arg) + return shard_arg_handlers[type(arg)]([arg], shardings) + + # type(arg) -> (indices, args, shardings) + batches = collections.defaultdict(lambda: ([], [], [])) # type: ignore + for i, (arg, sharding) in enumerate(safe_zip(args, shardings)): + if canonicalize: + arg = xla.canonicalize_dtype(arg) + batch = batches[type(arg)] + batch[0].append(i) + batch[1].append(arg) + batch[2].append(sharding) + + # Call `shard_arg_handlers` per batch and build a flat list of arrays returned + # from each call in the same order as `args`. Since `batches` is grouped by + # types, we cannot simply flatten the results and we have to use the original + # indices to put each array back to its original position. + results: list[jax.Array | None] = [None] * len(args) + for t, (indices, a, s) in batches.items(): + outs = shard_arg_handlers[t](a, s) + for i, out in safe_zip(indices, outs): + results[i] = out + + assert all(result is not None for result in results) + return results + + +shard_arg_handlers: dict[Any, Callable[[Sequence[Any], Sequence[Any]], Sequence[Any]]] = {} -@lru_cache(maxsize=1024) -def get_addressable_devices_for_shard_arg( - s: sharding_impls.XLACompatibleSharding) -> tuple[xc.Device, ...]: - return s._addressable_device_assignment - @lru_cache(maxsize=1024) def _get_replicated_slices(num_addressable_devices: int): return ((slice(None),),) * num_addressable_devices -def _shard_token(x, sharding): - devices = get_addressable_devices_for_shard_arg(sharding) - indices = _get_replicated_slices(len(devices)) - zeros = np.zeros((), dtype=np.dtype(np.bool_)) - aval = api_util.shaped_abstractify(zeros) - return batched_device_put(aval, sharding, [zeros for _ in indices], devices) -shard_arg_handlers[core.Token] = _shard_token -def _masked_array_error(x, sharding): +def _masked_array_error(xs, shardings): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. 
" "Use arr.filled() to convert the value to a standard numpy array.") shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error -def _shard_array(x, sharding): - indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) - devices = get_addressable_devices_for_shard_arg(sharding) - if x.dtype == dtypes.float0: - x = np.zeros(x.shape, dtype=np.dtype(bool)) - aval = api_util.shaped_abstractify(x) - return batched_device_put(aval, sharding, [x[i] for i in indices], devices) +def _shard_array(xs, shardings): + results = [] + for x, sharding in safe_zip(xs, shardings): + devices = sharding._addressable_device_assignment + if x.dtype == dtypes.float0: + x = np.zeros(x.shape, dtype=np.dtype(bool)) + aval = api_util.shaped_abstractify(x) + if sharding.is_fully_replicated: + shards = [x] * len(devices) + else: + indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) + shards = [x[i] for i in indices] + results.append(batched_device_put(aval, sharding, shards, devices)) + return results for _t in array_types: shard_arg_handlers[_t] = _shard_array -def _shard_darray(x, sharding): - return shard_arg(x._data, sharding) +def _shard_darray(xs, shardings): + return shard_args(shardings, [x._data for x in xs]) shard_arg_handlers[core.DArray] = _shard_darray -def _shard_mutable_array(x, sharding): - return shard_arg(x._buf, sharding) +def _shard_mutable_array(xs, shardings): + return shard_args(shardings, [x._buf for x in xs]) shard_arg_handlers[core.MutableArray] = _shard_mutable_array def batched_device_put(aval: core.ShapedArray, - sharding: jax.sharding.Sharding, xs: Sequence[Any], + sharding: JSharding, xs: Sequence[Any], devices: Sequence[jax.Device], committed: bool = True): from jax._src import array @@ -176,7 +190,7 @@ def batched_device_put(aval: core.ShapedArray, if len(bufs) == len(xs): return array.ArrayImpl( aval, sharding, bufs, committed=committed, _skip_checks=True) - return xc.batched_device_put(aval, sharding, xs, list(devices), committed) # type: ignore + return xc.batched_device_put(aval, sharding, xs, list(devices), committed) def _shard_aval(size, axis: int, aval): try: @@ -201,7 +215,7 @@ def _shard_abstract_array(size, axis: int, x): def local_aval_to_result_handler( aval: core.AbstractValue, - sharding: sharding_impls.XLACompatibleSharding, + sharding: JSharding, indices: tuple[Index, ...] | None, ) -> Callable[[list[xc.ArrayImpl]], Any]: """Returns a function for handling the raw buffers of a single output aval. 
@@ -567,11 +581,18 @@ def parallel_callable(fun: lu.WrappedFun, donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, *avals): + closed_jaxpr, xc_backend, replicas, shards, pci = get_pmap_jaxpr( + fun, backend_name, axis_name, + axis_size=axis_size, global_axis_size=global_axis_size, + devices=devices, name=fun.__name__, in_axes=in_axes, + out_axes_thunk=out_axes_thunk, avals=avals) pmap_computation = lower_parallel_callable( - fun, backend_name, axis_name, axis_size, global_axis_size, devices, name, - in_axes, out_axes_thunk, donated_invars, + fun, axis_name, axis_size, global_axis_size, devices, name, + in_axes, donated_invars, is_explicit_global_axis_size, avals, - lowering_parameters=mlir.LoweringParameters()) + lowering_platforms=None, lowering_parameters=mlir.LoweringParameters(), + closed_jaxpr=closed_jaxpr, backend=xc_backend, replicas=replicas, + shards=shards, pci=pci) pmap_executable = pmap_computation.compile() return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint]) @@ -595,7 +616,7 @@ def local_devices(self): if d.process_index == xb.process_index(self.backend)] assert len(out) > 0 else: - out = None # type: ignore + out = None return out @cached_property @@ -652,7 +673,7 @@ def stage_parallel_callable( fun = _change_argument_ranks(fun, pci.in_axes, pci.out_axes_thunk) else: fun = orig_fun - with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): # type: ignore + with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): with dispatch.log_elapsed_time( "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec", fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): @@ -675,8 +696,7 @@ def stage_parallel_callable( return jaxpr, consts, replicas, shards -@profiler.annotate_function -def lower_parallel_callable( +def get_pmap_jaxpr( fun: lu.WrappedFun, backend_name: str | None, axis_name: core.AxisName, @@ -686,11 +706,41 @@ def lower_parallel_callable( name: str, in_axes: Iterable[int | None], out_axes_thunk: Callable[[], Sequence[int | None]], + avals: Sequence[core.AbstractValue]): + if devices is not None and backend_name is None: + backend = xb.get_device_backend(devices[0]) + else: + backend = xb.get_backend(backend_name) + + pci = ParallelCallableInfo( + name, backend, axis_name, axis_size, global_axis_size, devices, + in_axes, out_axes_thunk, avals) + jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun) + jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name}) + closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) + return closed_jaxpr, backend, replicas, shards, pci + + +@profiler.annotate_function +def lower_parallel_callable( + fun: lu.WrappedFun, + axis_name: core.AxisName, + axis_size: int, + global_axis_size: int, + devices: Sequence[xc.Device] | None, + name: str, + in_axes: Iterable[int | None], donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, avals: Sequence[core.AbstractValue], *, - lowering_parameters: mlir.LoweringParameters) -> PmapComputation: + lowering_platforms: tuple[str, ...] | None, + lowering_parameters: mlir.LoweringParameters, + closed_jaxpr: core.ClosedJaxpr, + backend: xc.Client, + replicas: ReplicaInfo, + shards: ShardInfo, + pci: ParallelCallableInfo) -> PmapComputation: # Determine global_axis_size for use in AxisEnv. 
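Stepping back to the pxla.py shard_args rewrite above: arguments are grouped by Python type so each shard_arg_handlers entry is invoked once per batch, and the per-batch outputs are scattered back to their original positions. The same pattern in a self-contained form (the handlers dict and toy inputs are made up for illustration):

import collections

def batched_dispatch(args, extras, handlers):
    batches = collections.defaultdict(lambda: ([], [], []))
    for i, (arg, extra) in enumerate(zip(args, extras)):
        idxs, a, e = batches[type(arg)]
        idxs.append(i); a.append(arg); e.append(extra)
    results = [None] * len(args)
    for t, (idxs, a, e) in batches.items():
        # One handler call per type; results land back at their original index.
        for i, out in zip(idxs, handlers[t](a, e)):
            results[i] = out
    assert all(r is not None for r in results)
    return results

handlers = {int: lambda xs, ys: [x + y for x, y in zip(xs, ys)],
            str: lambda xs, ys: [x * y for x, y in zip(xs, ys)]}
assert batched_dispatch([1, "a", 2], [10, 3, 20], handlers) == [11, "aaa", 22]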
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now) # if xb.process_count() > 1 and global_axis_size is None and inner_pmap: @@ -701,10 +751,7 @@ def lower_parallel_callable( f"Specified axis_size {global_axis_size} doesn't match received " f"axis_size {axis_size}.") - if devices is not None and backend_name is None: - backend = xb.get_device_backend(devices[0]) - else: - backend = xb.get_backend(backend_name) + jaxpr = closed_jaxpr.jaxpr no_nested_sharding = False must_run_on_all_devices = False @@ -721,10 +768,6 @@ def lower_parallel_callable( # devices). Nested sharding is ok in this case. must_run_on_all_devices = True - pci = ParallelCallableInfo( - name, backend, axis_name, axis_size, global_axis_size, devices, - in_axes, out_axes_thunk, avals) - jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun) if logger.isEnabledFor(logging.DEBUG): logger.debug("sharded_avals: %s", shards.sharded_avals) logger.debug("global_sharded_avals: %s", shards.global_sharded_avals) @@ -766,13 +809,12 @@ def lower_parallel_callable( axis_env = sharding_impls.AxisEnv( replicas.num_global_replicas, (axis_name,), (global_axis_size,)) name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap')) - jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name}) - closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) replicated_args = [axis is None for axis in in_axes] tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals), backend.platform) module_name = f"pmap_{fun.__name__}" - with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore + platforms = lowering_platforms or (backend.platform,) + with maybe_extend_axis_env(axis_name, global_axis_size, None): ordered_effects = list( effects.ordered_effects.filter_in(closed_jaxpr.effects)) if ordered_effects: @@ -787,7 +829,7 @@ def lower_parallel_callable( closed_jaxpr, ordered_effects=ordered_effects, backend_or_name=backend, - platforms=lowering_parameters.platforms or (backend.platform,), + platforms=platforms, axis_context=sharding_impls.ReplicaAxisContext(axis_env), name_stack=name_stack, donated_args=donated_invars, @@ -798,13 +840,16 @@ def lower_parallel_callable( result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths, num_replicas=replicas.num_global_replicas, lowering_parameters=lowering_parameters) - return PmapComputation(lowering_result.module, pci=pci, replicas=replicas, + return PmapComputation(lowering_result.module, + platforms=platforms, + pci=pci, replicas=replicas, shards=shards, tuple_args=tuple_args, unordered_effects=unordered_effects, ordered_effects=ordered_effects, keepalive=lowering_result.keepalive, host_callbacks=lowering_result.host_callbacks, - jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info) + jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info, + shape_poly_state=lowering_result.shape_poly_state) def _pmap_unmap_shaped_array( @@ -870,9 +915,9 @@ class UnloadedPmapExecutable: compiled: Any backend: xb.XlaBackend local_input_avals: Sequence[core.AbstractValue] - input_shardings: Sequence[sharding_impls.XLACompatibleSharding] + input_shardings: Sequence[JSharding] local_output_avals: Sequence[ShapedArray] - output_shardings: Sequence[sharding_impls.XLACompatibleSharding] + output_shardings: Sequence[JSharding] unordered_effects: list[core.Effect] ordered_effects: list[core.Effect] keepalive: Sequence[Any] @@ -917,7 +962,13 @@ def from_hlo(hlo: ir.Module, host_callbacks: list[Any], keepalive: Any, jaxpr_debug_info: core.JaxprDebugInfo, + platforms: Sequence[str], + 
shape_poly_state: mlir.ShapePolyLoweringState | None = None, compiler_options=None): + del platforms + if shape_poly_state is not None and shape_poly_state.uses_dim_vars: + hlo = mlir.refine_polymorphic_shapes(hlo) + devices = pci.devices if devices is None: if shards.num_global_shards > xb.device_count(pci.backend): @@ -1002,19 +1053,6 @@ def from_hlo(hlo: ir.Module, shards.out_sharded_avals, pci.out_axes)] out_shardings = _get_pmap_sharding(local_device_assignment, out_specs) - if hasattr(pci.backend, "compile_replicated"): - input_indices = [ - sharding_specs.spec_to_indices(aval.shape, spec) - if spec is not None else None - for aval, spec in safe_zip(pci.avals, input_sharding_specs) - ] - handle_outs = local_avals_to_results_handler(local_unmapped_avals, - out_shardings) - return _compile_replicated_pmap_executable_from_hlo( - hlo, pci, input_indices, in_shardings, handle_outs, - compile_options, host_callbacks, bool(unordered_effects), - ordered_effects, jaxpr_debug_info) - with dispatch.log_elapsed_time( "Finished XLA compilation of {fun_name} in {elapsed_time} sec", fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT): @@ -1036,23 +1074,6 @@ def from_hlo(hlo: ir.Module, jaxpr_debug_info=jaxpr_debug_info).load() -def _compile_replicated_pmap_executable_from_hlo( - hlo: ir.Module, pci, input_indices, in_shardings, handle_outs, - compile_options, host_callbacks, has_unordered_effects, ordered_effects, - jaxpr_debug_info): - # Use the standard out_handler. - execute_fun = pci.backend.compile_replicated( - is_trivial=False, name=pci.name, computation=hlo, - compile_options=compile_options, host_callbacks=host_callbacks, - has_unordered_effects=has_unordered_effects, - ordered_effects=ordered_effects, in_avals=pci.avals, - in_indices=input_indices, in_shardings=in_shardings, - kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs) - # TODO(frostig): need `compile_replicated` to give us the XLA executable - return PmapExecutable(None, lambda: execute_fun, None, pci.avals, - jaxpr_debug_info, None) - - class PmapExecutable(stages.XlaExecutable): __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call", "fingerprint", "in_avals", "_jaxpr_debug_info", @@ -1111,9 +1132,8 @@ def __str__(self): class ResultsHandler: - # `out_avals` is the `Array` global avals when using pjit or xmap - # with `config.parallel_functions_output_gda=True`. It is the local one - # otherwise, and also when using `pmap`. + # `out_avals` is the `Array` global avals when using pjit or xmap. It is the + # local one when using `pmap`. 
__slots__ = ("handlers", "out_shardings", "out_avals") def __init__(self, handlers, out_shardings, out_avals): @@ -1127,7 +1147,7 @@ def __call__(self, out_bufs): def local_avals_to_results_handler( unmapped_local_out_avals: Sequence[ShapedArray], - local_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> ResultsHandler: + local_shardings: Sequence[JSharding]) -> ResultsHandler: out_indices = [tuple(s.devices_indices_map(aval.shape).values()) for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)] handlers = [ @@ -1139,7 +1159,7 @@ def local_avals_to_results_handler( def global_avals_to_results_handler( global_out_avals: Sequence[ShapedArray], - shardings: Sequence[sharding_impls.XLACompatibleSharding], + shardings: Sequence[JSharding], committed: bool) -> ResultsHandler: handlers = [ global_aval_to_result_handler(global_aval, s, committed) @@ -1153,14 +1173,15 @@ class ExecuteReplicated: __slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler', 'has_unordered_effects', 'ordered_effects', 'keepalive', 'has_host_callbacks', '_local_devices', 'kept_var_idx', - 'out_mut', '__weakref__'] + 'mut', 'pgle_profiler', '__weakref__'] def __init__(self, xla_executable, name, backend, in_handler: InputsHandler, out_handler: ResultsHandler, unordered_effects: list[core.Effect], ordered_effects: list[core.Effect], keepalive: Any, has_host_callbacks: bool, kept_var_idx: set[int], - out_mut: Sequence[int | None] | None): + mut: MutationData | None, + pgle_profiler: profiler.PGLEProfiler | None = None): self.xla_executable = xla_executable self.name = name self.backend = backend @@ -1172,54 +1193,80 @@ def __init__(self, xla_executable, name, backend, in_handler: InputsHandler, self.keepalive = keepalive self.has_host_callbacks = has_host_callbacks self.kept_var_idx = kept_var_idx - self.out_mut = out_mut + self.mut = mut + self.pgle_profiler = pgle_profiler def _add_tokens_to_inputs(self, input_bufs): if self.ordered_effects: tokens = [ - dispatch.runtime_tokens.get_token_input(eff, self._local_devices) - for eff in self.ordered_effects] + dispatch.runtime_tokens.get_token_input(eff, self._local_devices)._buf + for eff in self.ordered_effects + ] input_bufs = [*tokens, *input_bufs] return input_bufs def _handle_token_bufs(self, token_bufs, sharded_token): # token_bufs: Sequence[Sequence[tokenArray]], for each effect the returned - # token buffer (as a singleton list). + # token buffers. 
# sharded_token: ShardedToken, containing the RuntimeTokens for each device for i, device in enumerate(self._local_devices): dispatch.runtime_tokens.set_output_runtime_token( device, sharded_token.get_token(i)) for eff, token_buf in zip(self.ordered_effects, token_bufs): - dispatch.runtime_tokens.set_token_result(eff, token_buf[0]) + assert len(token_buf) > 0 + if len(token_buf) == 1: + dispatch.runtime_tokens.set_token_result(eff, core.Token(token_buf[0])) + else: + token_devices = [] + for token in token_buf: + assert isinstance(token.sharding, sharding_impls.SingleDeviceSharding) + token_devices.append(token.sharding._device_assignment[0]) + s = sharding_impls.PositionalSharding(token_devices) + global_token_array = jax.make_array_from_single_device_arrays( + (0,), s, token_buf + ) + dispatch.runtime_tokens.set_token_result( + eff, core.Token(global_token_array) + ) @profiler.annotate_function def __call__(self, *args): args = [x for i, x in enumerate(args) if i in self.kept_var_idx] + if self.mut: + args = [*args, *self.mut.in_mut] input_bufs = self.in_handler(args) - if (self.ordered_effects or self.has_unordered_effects - or self.has_host_callbacks): - input_bufs = self._add_tokens_to_inputs(input_bufs) - results = self.xla_executable.execute_sharded( - input_bufs, with_tokens=True - ) - result_token_bufs = results.disassemble_prefix_into_single_device_arrays( - len(self.ordered_effects)) - sharded_runtime_token = results.consume_token() - self._handle_token_bufs(result_token_bufs, sharded_runtime_token) - else: - results = self.xla_executable.execute_sharded(input_bufs) - if dispatch.needs_check_special(): - out_arrays = results.disassemble_into_single_device_arrays() - for arrays in out_arrays: - dispatch.check_special(self.name, arrays) - out = self.out_handler(out_arrays) - else: - out = results.consume_with_handlers(self.out_handler.handlers) - if self.out_mut is None: + with profiler.PGLEProfiler.trace(self.pgle_profiler): + if (self.ordered_effects or self.has_unordered_effects + or self.has_host_callbacks): + input_bufs = self._add_tokens_to_inputs(input_bufs) + results = self.xla_executable.execute_sharded( + input_bufs, with_tokens=True + ) + + result_token_bufs = results.disassemble_prefix_into_single_device_arrays( + len(self.ordered_effects)) + sharded_runtime_token = results.consume_token() + self._handle_token_bufs(result_token_bufs, sharded_runtime_token) + else: + results = self.xla_executable.execute_sharded(input_bufs) + + if dispatch.needs_check_special(): + out_arrays = results.disassemble_into_single_device_arrays() + for arrays in out_arrays: + dispatch.check_special(self.name, arrays) + out = self.out_handler(out_arrays) + else: + out = results.consume_with_handlers(self.out_handler.handlers) + + if (self.pgle_profiler is not None and self.pgle_profiler.is_running() + and len(out) > 0): + out[0].block_until_ready() + + if self.mut is None: return out else: out_ = [] - for i, o in zip(self.out_mut, out): + for i, o in zip(self.mut.out_mut, out): if i is not None: args[i]._buf = o else: @@ -1344,6 +1391,8 @@ def _hlo_shard(aval, axis_env, xs, in_axis): if aval is core.abstract_token: return xs elif isinstance(aval, core.ShapedArray): + if dtypes.issubdtype(aval.dtype, dtypes.extended): + aval = aval.dtype._rules.physical_element_aval(aval.dtype) x, = xs dims = list(aval.shape) zero = mlir.ir_constant(np.zeros((), dtype=np.uint32)) @@ -1444,7 +1493,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, if in_axis is not None else 
mlir.wrap_singleton_ir_values(in_node) for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes)) - with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore + with maybe_extend_axis_env(axis_name, global_axis_size, None): sub_ctx = ctx.module_context.replace( axis_context=sharding_impls.ReplicaAxisContext(new_env)) sharded_outs, _ = mlir.jaxpr_subcomp( @@ -1547,11 +1596,11 @@ def manual_proto( tad_shape.append(math.prod([named_mesh_shape[a] for a in replicated_axes])) tad_shape.append(math.prod([named_mesh_shape[a] for a in manual_axes])) - raw_mesh = np.arange(math.prod(mesh_shape)).reshape(mesh_shape) proto = xc.OpSharding() proto.type = xc.OpSharding.Type.OTHER proto.tile_assignment_dimensions = tad_shape - proto.tile_assignment_devices = list(raw_mesh.transpose(tad_perm).reshape(tad_shape).flat) + proto.iota_reshape_dims = mesh_shape + proto.iota_transpose_perm = tad_perm proto.last_tile_dims = [xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL] return proto @@ -1619,8 +1668,7 @@ class TileManual: def check_if_any_auto( - shardings: Iterable[(sharding_impls.XLACompatibleSharding | - AUTO | UnspecifiedValue)]) -> bool: + shardings: Iterable[(JSharding | AUTO | UnspecifiedValue)]) -> bool: for s in shardings: if is_auto(s): return True @@ -1685,7 +1733,7 @@ class DeviceAssignmentMismatchError(Exception): ShardingInfo = tuple[ - Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO], + Union[JSharding, UnspecifiedValue, AUTO], MismatchType, Union[Any, None], # Any is dispatch.SourceInfo to avoid circular imports ] @@ -1742,7 +1790,7 @@ def _get_and_check_device_assignment( final_device_assignment = first_sharding_info[0] return xb.get_device_backend(final_device_assignment[0]), final_device_assignment -MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue] +MaybeSharding = Union[JSharding, UnspecifiedValue] def prune_unused_inputs( @@ -1756,55 +1804,85 @@ def prune_unused_inputs( @weakref_lru_cache -def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name, +def _dce_jaxpr(closed_jaxpr, api_name, fun_name, keep_unused, donated_invars, auto_spmd_lowering): name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name)) assert isinstance(closed_jaxpr, core.ClosedJaxpr) jaxpr = closed_jaxpr.jaxpr - global_out_avals = closed_jaxpr.out_avals consts = closed_jaxpr.consts + in_avals = closed_jaxpr.in_avals if (keep_unused or auto_spmd_lowering or any(hasattr(a, "shape") and not core.is_constant_shape(a.shape) - for a in global_in_avals)): - kept_var_idx = set(range(len(global_in_avals))) + for a in in_avals)): + kept_var_idx = set(range(len(in_avals))) else: jaxpr, kept_const_idx, kept_var_idx = prune_unused_inputs(jaxpr) consts = [c for i, c in enumerate(consts) if i in kept_const_idx] - global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx) donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx) del kept_const_idx jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) - return (closed_jaxpr, global_in_avals, tuple(global_out_avals), donated_invars, - kept_var_idx, name_stack) + return closed_jaxpr, donated_invars, kept_var_idx, name_stack + +class MutationData(NamedTuple): + in_mut: list[core.MutableArray] + out_mut: list[int | None] @weakref_lru_cache def _discharge_refs( jaxpr: core.ClosedJaxpr -) -> tuple[core.ClosedJaxpr, Sequence[int | None], Sequence[int | None]]: +) -> tuple[core.ClosedJaxpr, 
Sequence[int | None], MutationData]: from jax._src.state.discharge import discharge_state + jaxpr, in_mut = _move_mutable_consts(jaxpr) new_jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts)) count = it.count(len(jaxpr.out_avals)) # new outputs are appended to the end inout_map = {i: next(count) for i, a in enumerate(jaxpr.in_avals) if isinstance(a, AbstractRef)} outin_map = {j: i for i, j in inout_map.items()} inout_aliases = tuple(map(inout_map.get, range(len(new_jaxpr.in_avals)))) - out_mut = tuple(map(outin_map.get, range(len(new_jaxpr.out_avals)))) - return new_jaxpr, inout_aliases, out_mut + out_mut = list(map(outin_map.get, range(len(new_jaxpr.out_avals)))) + return new_jaxpr, inout_aliases, MutationData(in_mut, out_mut) + +@weakref_lru_cache +def _move_mutable_consts( + closed_jaxpr: core.ClosedJaxpr, +) -> tuple[core.ClosedJaxpr, list[core.MutableArray]]: + jaxpr = closed_jaxpr.jaxpr + hoist = [isinstance(c, core.MutableArray) for c in closed_jaxpr.consts] + consts, in_mut = partition_list(hoist, closed_jaxpr.consts) + constvars, mutvars = partition_list(hoist, jaxpr.constvars) + invars = (*jaxpr.invars, *mutvars) + effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns) + jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns, + effects, None) + return core.ClosedJaxpr(jaxpr, consts), in_mut + +@weakref_lru_cache +def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr: + from jax._src.state.discharge import discharge_state + jaxpr_, consts = discharge_state(jaxpr.jaxpr, jaxpr.consts) + jaxpr_._debug_info = jaxpr.jaxpr.debug_info + return core.ClosedJaxpr(jaxpr_, consts) -@dataclasses.dataclass(frozen=True) class SemanticallyEqualShardings: - shardings: tuple[sharding_impls.GSPMDSharding | UnspecifiedValue, ...] + + def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...], + avals: tuple[core.AbstractValue]): + gspmd_shardings = [ + s if is_unspecified_or_auto(s) else to_gspmd_sharding(s, a.ndim) # type: ignore + for s, a in zip(shardings, avals)] + self._gspmd_shardings = gspmd_shardings + self.shardings = shardings + self.avals = avals def __hash__(self): return hash(tuple( - (s._hlo_sharding_hash, s.memory_kind) # type: ignore - if isinstance(s, sharding_impls.GSPMDSharding) else s - for s in self.shardings)) + (s._hlo_sharding_hash, s.memory_kind) + if isinstance(s, GSPMDSharding) else s for s in self._gspmd_shardings)) def __eq__(self, other): if not isinstance(other, SemanticallyEqualShardings): @@ -1812,10 +1890,9 @@ def __eq__(self, other): return all( (op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding) and s.memory_kind == o.memory_kind) - if (isinstance(s, sharding_impls.GSPMDSharding) and - isinstance(o, sharding_impls.GSPMDSharding)) + if (isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding)) else s == o - for s, o in zip(self.shardings, other.shardings) + for s, o in zip(self._gspmd_shardings, other._gspmd_shardings) ) @@ -1843,23 +1920,54 @@ def _raise_warnings_or_errors_for_jit_of_pmap( "extra data movement anyway, so maybe you don't want it after all).") +@lru_cache(maxsize=2048) +def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval + ) -> DeviceLocalLayout | None: + if is_unspecified_or_auto(sharding): + return None + # TODO(yashkatariya): Figure out how layouts work with extended dtypes. 
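A few hunks up, manual_proto stops materializing tile_assignment_devices and instead records the iota description (iota_reshape_dims plus iota_transpose_perm). Both forms denote the same device order; a quick NumPy check with made-up mesh sizes and permutation:

import numpy as np

mesh_shape = (2, 3, 4)   # made-up logical mesh
tad_perm = (1, 0, 2)     # made-up permutation
tad_shape = (3, 2 * 4)   # tile assignment dims, trailing axes merged

# Old form: list every device id explicitly.
explicit = (np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
            .transpose(tad_perm).reshape(tad_shape))

# New form: store only (reshape dims, transpose perm); expanding the iota
# reproduces the same order without carrying the full device list.
iota = np.arange(np.prod(mesh_shape)).reshape(mesh_shape).transpose(tad_perm)
assert list(explicit.flat) == list(iota.flat)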
+ if dtypes.issubdtype(aval.dtype, dtypes.extended): + return None + if not core.is_constant_shape(aval.shape): + return None + shard_shape = sharding.shard_shape(aval.shape) + d = sharding._device_assignment[0] + # If a backend doesn't implement `get_default_layout` return `None` to avoid + # cache misses. This can happen when you have `jit(f, in_shardings=s)`. On + # first call you pass it a sharded array with layout and on second call you + # pass a numpy array. The layouts should be the same to get cache hits. + try: + al = DeviceLocalLayout.from_pjrt_layout( + d.client.get_default_layout(aval.dtype, shard_shape, d)) + except: + return None + # argument does not have `.layout` property. ShapedArray, numpy array, etc + # are some examples. + if arg_layout is None: + return al if jit_in_layout is None else arg_layout # arg_layout is None + # If arg has a `.layout` property, then return device_local_layout as is. + return arg_layout.device_local_layout + + @weakref_lru_cache def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, semantic_out_shardings, in_layouts, out_layouts, num_devices, device_assignment, donated_invars, name_stack, all_default_mem_kind, inout_aliases: None | tuple[None | int, ...], + propagated_out_mem_kinds: tuple[None | str, ...], + platforms: tuple[str, ...], lowering_parameters: mlir.LoweringParameters): jaxpr = closed_jaxpr.jaxpr - in_shardings = semantic_in_shardings.shardings - out_shardings = semantic_out_shardings.shardings + in_shardings = semantic_in_shardings._gspmd_shardings + out_shardings = semantic_out_shardings._gspmd_shardings global_in_avals = closed_jaxpr.in_avals global_out_avals = closed_jaxpr.out_avals log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG if logger.isEnabledFor(log_priority): logger.log(log_priority, - "Compiling %s for with global shapes and types %s. " + "Compiling %s with global shapes and types %s. 
" "Argument mapping: %s.", fun_name, global_in_avals, in_shardings) @@ -1871,8 +1979,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, nreps = dispatch.jaxpr_replicas(jaxpr) _raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, fun_name, jaxpr) - in_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None - out_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None + in_mlir_shardings: list[JSharding | None] | None + out_mlir_shardings: list[JSharding | None] | None axis_ctx: mlir.AxisContext if nreps == 1: @@ -1901,7 +2009,6 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, "The following ordered effects are not supported for " f"more than 1 device: {unsupported_effects}") ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects)) - with dispatch.log_elapsed_time( "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec", fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): @@ -1910,8 +2017,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, closed_jaxpr, ordered_effects=ordered_effects, backend_or_name=backend, - # Optionally, override the lowering platform - platforms=lowering_parameters.platforms or (backend.platform,), + platforms=platforms, axis_context=axis_ctx, name_stack=name_stack, donated_args=donated_invars, @@ -1926,6 +2032,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, num_partitions=num_partitions, all_default_mem_kind=all_default_mem_kind, input_output_aliases=inout_aliases, + propagated_out_mem_kinds=propagated_out_mem_kinds, lowering_parameters=lowering_parameters) tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform) unordered_effects = list( @@ -1937,16 +2044,16 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, @lru_cache(maxsize=2048) def _create_da_object( # pytype: disable=invalid-annotation - device_assignment: tuple[xc.Device, ...]) -> xc.DeviceList: # type: ignore + device_assignment: tuple[xc.Device, ...]) -> xc.DeviceList: return xc.DeviceList(device_assignment) def jaxpr_transfer_mem_kinds( jaxpr: core.Jaxpr) -> Iterator[sharding_impls.TransferToMemoryKind]: for eqn in jaxpr.eqns: - if (eqn.primitive is dispatch.device_put_p and - isinstance(eqn.params['device'], sharding_impls.TransferToMemoryKind)): - yield eqn.params['device'] + if eqn.primitive is dispatch.device_put_p: + yield from (d for d in eqn.params['devices'] + if isinstance(d, sharding_impls.TransferToMemoryKind)) for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_transfer_mem_kinds(subjaxpr) @@ -1963,16 +2070,89 @@ def are_all_shardings_default_mem_kind(da_object, shardings): return False return True -MaybeLayout = Sequence[Union[XLACompatibleLayout, LayoutRequest, None]] +memory_kind_propagate_rule: dict[Any, Any] = {} + +@weakref_lru_cache +def get_out_memory_kinds_via_propagation(closed_jaxpr: core.ClosedJaxpr, + in_shardings=None) -> tuple[None | str]: + env = {} # type: ignore + jaxpr = closed_jaxpr.jaxpr + + def read(var): + if type(var) is core.Literal: + return None + return env[var] + + def write(var, val): + env[var] = val + + def _default_rule(prim, num_outvars, *_, **__): + return [None] * num_outvars if prim.multiple_results else None + + if in_shardings is None: + invar_mem_kind = [None] * len(jaxpr.invars) + else: + invar_mem_kind = [None if is_unspecified_or_auto(s) else s.memory_kind + for s in in_shardings] + safe_map(write, 
jaxpr.invars, invar_mem_kind) + safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars)) + + for eqn in jaxpr.eqns: + in_mem_kinds = safe_map(read, eqn.invars) + rule = memory_kind_propagate_rule.get( + eqn.primitive, partial(_default_rule, eqn.primitive, len(eqn.outvars))) + out_mem_kinds = rule(*in_mem_kinds, **eqn.params) + if not eqn.primitive.multiple_results: + out_mem_kinds = [out_mem_kinds] + safe_map(write, eqn.outvars, out_mem_kinds) + return tuple(safe_map(read, jaxpr.outvars)) + + +MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]] class AllArgsInfo(NamedTuple): - """Avals, shardings, layouts and debug_info for all arguments prior to DCE.""" + """Avals and debug_info for all arguments prior to DCE.""" in_avals: Sequence[core.ShapedArray] - in_shardings: Any debug_info: core.JaxprDebugInfo | None +@lru_cache(maxsize=2048) +def to_gspmd_sharding(s: JSharding, ndim: int) -> GSPMDSharding: + if isinstance(s, GSPMDSharding): + return s + return GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim), + memory_kind=s.memory_kind, + _device_list=getattr(s, '_internal_device_list', None)) + + +# Dummy function which is a no-op in OSS since enhanced barrier is switched on +# in OSS. +def spmd_mode_check(da_object, inline): + return + + +def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts, + donated_invars, out_shardings, out_layouts): + if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects): + closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr) + in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut.in_mut) + in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut) + donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut) + out_layouts_ = iter(zip(out_shardings, out_layouts)) + out_shardings, out_layouts = unzip2( + next(out_layouts_) if i is None else (in_shardings[i], in_layouts[i]) + for i in mut.out_mut) + assert next(out_layouts_, None) is None + else: + inout_aliases = mut = None + if any(isinstance(e, core.InternalMutableArrayEffect) for e in closed_jaxpr.effects): + closed_jaxpr = _discharge_internal_refs(closed_jaxpr) + + return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts, + donated_invars, out_shardings, out_layouts) + + @profiler.annotate_function def lower_sharding_computation( closed_jaxpr: core.ClosedJaxpr, @@ -1980,15 +2160,16 @@ def lower_sharding_computation( fun_name: str, in_shardings: Sequence[MaybeSharding], out_shardings: Sequence[MaybeSharding], + in_layouts: MaybeLayout, + out_layouts: MaybeLayout, donated_invars: Sequence[bool], - global_in_avals: Sequence[core.ShapedArray], *, keep_unused: bool, inline: bool, - devices_from_context: Sequence[xc.Device] | None = None, + devices_from_context: Sequence[xc.Device] | None, + lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, - in_layouts: MaybeLayout, - out_layouts: MaybeLayout, + pgle_profiler: profiler.PGLEProfiler | None, ) -> MeshComputation: """Lowers a computation to XLA. It can take arbitrary shardings as input. @@ -1999,31 +2180,25 @@ def lower_sharding_computation( """ # 1. 
Trace to jaxpr and preprocess/verify it auto_spmd_lowering = check_if_any_auto( - it.chain.from_iterable([in_shardings, out_shardings])) # type: ignore + it.chain.from_iterable([in_shardings, out_shardings])) - all_args_info = AllArgsInfo(global_in_avals, in_shardings, - closed_jaxpr.jaxpr.debug_info) + all_args_info = AllArgsInfo(closed_jaxpr.in_avals, closed_jaxpr.jaxpr.debug_info) - (closed_jaxpr, global_in_avals, global_out_avals, donated_invars, - kept_var_idx, name_stack) = _dce_jaxpr( - closed_jaxpr, global_in_avals, api_name, fun_name, keep_unused, - donated_invars, auto_spmd_lowering) + closed_jaxpr, donated_invars, kept_var_idx, name_stack = _dce_jaxpr( + closed_jaxpr, api_name, fun_name, keep_unused, donated_invars, + auto_spmd_lowering) in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx) in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx) - if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects): - closed_jaxpr, inout_aliases, out_mut = _discharge_refs(closed_jaxpr) - if out_mut: - out_layouts_ = iter(zip(out_shardings, out_layouts)) - out_shardings, out_layouts = unzip2( - next(out_layouts_) if i is None else (in_shardings[i], in_layouts[i]) - for i in out_mut) - assert next(out_layouts_, None) is None - global_out_avals = closed_jaxpr.out_avals - else: - inout_aliases = out_mut = None + (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts, + donated_invars, out_shardings, out_layouts) = _discharge_refs_jaxpr( + closed_jaxpr, in_shardings, in_layouts, donated_invars, out_shardings, + out_layouts) jaxpr = closed_jaxpr.jaxpr + global_in_avals = closed_jaxpr.in_avals + global_out_avals = closed_jaxpr.out_avals + assert len(out_shardings) == len(out_layouts) == len(global_out_avals), ( len(out_shardings), len(out_layouts), len(global_out_avals)) @@ -2038,6 +2213,7 @@ def lower_sharding_computation( for js, source_info in util.stable_unique(jaxpr_sharding))), devices_from_context) + platforms = lowering_platforms or (backend.platform,) # TODO(yashkatariya): Enable this when offload APIs are stable. # transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr)) @@ -2048,33 +2224,27 @@ def lower_sharding_computation( any(not is_unspecified(js) for js, _ in jaxpr_sharding) or any(not is_unspecified(o) for o in out_shardings)) - gs = GSPMDSharding.get_replicated(device_assignment) - if xla_extension_version < 241 or hasattr(backend, "compile_replicated"): - in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings) - da_object = _create_da_object(tuple(device_assignment)) all_default_mem_kind = are_all_shardings_default_mem_kind( da_object, - it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding])) # type: ignore - - if not da_object.is_fully_addressable: # type: ignore - if inline and config.spmd_mode.value != 'allow_all': - raise RuntimeError( - "Running operations on `Array`s that are not fully addressable by this " - "process (i.e. `Array`s with data sharded across multiple devices and " - "processes.) is dangerous. It’s very important that all processes run " - "the same cross-process computations in the same order otherwise it " - "can lead to hangs. " - "If you’re not already familiar with JAX’s multi-process " - "programming model, please read " - "https://jax.readthedocs.io/en/latest/multi_process.html. 
" - "To fix this error, run your `jitted` computation inside " - "`with jax.spmd_mode('allow_all'):` context manager.") + it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding])) + + # TODO(yashkatariya): Remove this when XLA can propagate memory kinds or when + # JAX puts memory kinds in the types of jaxpr. + if not all_default_mem_kind: + propagated_out_mem_kinds = get_out_memory_kinds_via_propagation( + closed_jaxpr, in_shardings) + else: + propagated_out_mem_kinds = (None,) * len(global_out_avals) + + spmd_mode_check(da_object, inline) # 2. Build up the HLO - semantic_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore - semantic_out_shardings = SemanticallyEqualShardings(out_shardings) # type: ignore + semantic_in_shardings = SemanticallyEqualShardings( + in_shardings, global_in_avals) # type: ignore + semantic_out_shardings = SemanticallyEqualShardings( + out_shardings, global_out_avals) # type: ignore prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr) (module, keepalive, host_callbacks, unordered_effects, ordered_effects, @@ -2083,6 +2253,7 @@ def lower_sharding_computation( semantic_out_shardings, in_layouts, out_layouts, len(da_object), tuple(da_object) if prim_requires_devices else None, donated_invars, name_stack, all_default_mem_kind, inout_aliases, + propagated_out_mem_kinds, platforms, lowering_parameters=lowering_parameters) # backend and device_assignment is passed through to MeshExecutable because @@ -2094,6 +2265,7 @@ def lower_sharding_computation( str(name_stack), module, donated_invars, + platforms, global_in_avals=global_in_avals, global_out_avals=global_out_avals, in_shardings=in_shardings, @@ -2106,7 +2278,7 @@ def lower_sharding_computation( host_callbacks=host_callbacks, keepalive=keepalive, kept_var_idx=kept_var_idx, - out_mut=out_mut, + mut=mut, backend=backend, device_assignment=da_object, committed=committed, @@ -2115,16 +2287,17 @@ def lower_sharding_computation( pmap_nreps=nreps, shape_poly_state=shape_poly_state, all_default_mem_kind=all_default_mem_kind, - all_args_info=all_args_info) + all_args_info=all_args_info, + pgle_profiler=pgle_profiler) def _to_logical_sharding( aval: core.AbstractValue, sharding: MaybeSharding | AUTO -) -> sharding_impls.XLACompatibleSharding | None: +) -> JSharding | None: if is_unspecified(sharding) or is_auto(sharding): return None elif isinstance(aval, (ShapedArray, DShapedArray, AbstractRef)): - assert isinstance(sharding, sharding_impls.XLACompatibleSharding) + assert isinstance(sharding, JSharding) return sharding elif isinstance(aval, core.AbstractToken): return None @@ -2145,9 +2318,11 @@ def lower_mesh_computation( spmd_lowering: bool, global_in_avals: Sequence[core.ShapedArray], tiling_method: TilingMethod | None, + lowering_platforms: tuple[str, ...] 
| None, lowering_parameters: mlir.LoweringParameters) -> MeshComputation: assert not mesh.empty backend = xb.get_device_backend(mesh.devices.flat[0]) + platforms = lowering_platforms or (backend.platform,) name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name)) global_axis_sizes = mesh.shape @@ -2168,7 +2343,7 @@ def lower_mesh_computation( if isinstance(tiling_method, TileVectorize): tiling_transform = vtile_by_mesh elif isinstance(tiling_method, TileManual): - tiling_transform = lambda f, *args: vtile_manual(f, tiling_method.manual_axes, *args) # type: ignore + tiling_transform = lambda f, *args: vtile_manual(f, tiling_method.manual_axes, *args) manual_axes = tiling_method.manual_axes else: raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}") @@ -2201,8 +2376,6 @@ def lower_mesh_computation( out_jaxpr_avals = fun_or_jaxpr.out_avals consts = fun_or_jaxpr.consts - all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info) - assert len(out_shardings) == len(out_jaxpr_avals) if spmd_lowering: global_out_avals = out_jaxpr_avals @@ -2218,8 +2391,8 @@ def lower_mesh_computation( # 2. Build up the HLO tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform) - in_partitions: list[sharding_impls.XLACompatibleSharding | None] | None - out_partitions: list[sharding_impls.XLACompatibleSharding | None] | None + in_partitions: list[JSharding | None] | None + out_partitions: list[JSharding | None] | None axis_ctx: mlir.AxisContext if spmd_lowering: in_partitions = map(_to_logical_sharding, global_in_avals, in_shardings) @@ -2258,7 +2431,7 @@ def lower_mesh_computation( closed_jaxpr, ordered_effects=ordered_effects, backend_or_name=backend, - platforms=lowering_parameters.platforms or (backend.platform,), + platforms=platforms, axis_context=axis_ctx, name_stack=name_stack, donated_args=donated_invars, @@ -2275,6 +2448,7 @@ def lower_mesh_computation( str(name_stack), lowering_result.module, donated_invars, + platforms, global_in_avals=global_in_avals, global_out_avals=global_out_avals, in_shardings=in_shardings, @@ -2293,17 +2467,20 @@ def lower_mesh_computation( in_layouts=(None,) * len(global_in_avals), out_layouts=(None,) * len(global_out_avals), shape_poly_state=lowering_result.shape_poly_state, - all_args_info=all_args_info) + all_args_info=None, + pgle_profiler=None) class MeshComputation(stages.XlaLowering): - _hlo: ir.Module | None + _hlo: ir.Module _executable: MeshExecutable | None - def __init__(self, name: str, hlo: ir.Module | None, - donated_invars: Sequence[bool], **compile_args): + def __init__(self, name: str, hlo: ir.Module, + donated_invars: Sequence[bool], platforms: Sequence[str], + **compile_args): self._name = name self._hlo = hlo self._donated_invars = donated_invars + self._platforms = platforms self.compile_args = compile_args self._executable = None @@ -2327,43 +2504,11 @@ def cost_analysis(self) -> dict[str, float]: if xb.using_pjrt_c_api(backend): raise NotImplementedError( "Lowered.cost_analysis not implemented on platform " - f"'{backend.platform}'. Use compile().cost_analysis() for " # type: ignore + f"'{backend.platform}'. 
Use compile().cost_analysis() for " "post-compilation cost estimates.") return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module()) -if xla_extension_version < 229: - def _get_input_indices( - avals: Sequence[ShapedArray], - shardings: Sequence[sharding_impls.XLACompatibleSharding], - da_object: xc.DeviceList | Sequence[xc.Device], # type: ignore - ) -> Sequence[tuple[Index | None, ...]]: - - input_indices = [] - if not isinstance(da_object, xc.DeviceList): - da_object = _create_da_object(tuple(da_object)) - num_addressable_devices = len(da_object.addressable_device_list) - - def _get_replicated_slices(num_addressable_devices: int, ndim: int | None): - if ndim is None: - return ((slice(None),),) * num_addressable_devices - else: - return ((slice(None),) * ndim,) * num_addressable_devices - - for aval, sharding in zip(avals, shardings): - if aval is core.abstract_token: - index = _get_replicated_slices(num_addressable_devices, None) - else: - if sharding.is_fully_replicated: - index = _get_replicated_slices(num_addressable_devices, aval.ndim) - else: - index = tuple( - sharding.addressable_devices_indices_map(aval.shape).values()) # type: ignore - input_indices.append(index) - - return input_indices - - def get_out_shardings_from_executable( xla_executable, device_assignment: Sequence[xc.Device], @@ -2461,7 +2606,7 @@ def _get_mesh_pspec_shardings_from_executable( _orig_out_sharding_handlers = {} -_ShardingT = TypeVar("_ShardingT", bound=sharding_impls.XLACompatibleSharding) +_ShardingT = TypeVar("_ShardingT", bound=JSharding) def _register_out_sharding_handler( @@ -2470,24 +2615,10 @@ def _register_out_sharding_handler( ) -> None: _orig_out_sharding_handlers[sharding_cls] = handler - -def _gspmd_to_named_sharding_via_mesh( - out_s: sharding_impls.GSPMDSharding, - mesh: Mesh) -> sharding_impls.NamedSharding: - parsed_pspec = sharding_impls.parse_flatten_op_sharding( - out_s._hlo_sharding, mesh)[0] - return create_mesh_pspec_sharding( - mesh, parsed_pspec.get_partition_spec(), parsed_pspec, - out_s.memory_kind) - def _gspmd_to_named_sharding( out_s: sharding_impls.GSPMDSharding, orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding: - parsed_pspec = sharding_impls.parse_flatten_op_sharding( - out_s._hlo_sharding, orig_in_s.mesh)[0] - return create_mesh_pspec_sharding( - orig_in_s.mesh, parsed_pspec.get_partition_spec(), parsed_pspec, - out_s.memory_kind) + return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh) _register_out_sharding_handler( sharding_impls.NamedSharding, _gspmd_to_named_sharding) @@ -2518,49 +2649,49 @@ def _get_out_sharding_from_orig_sharding( out = [] orig_handler = _orig_out_sharding_handlers[type(orig_in_s)] for o, out_aval in safe_zip(out_shardings, out_avals): - if isinstance(o, sharding_impls.GSPMDSharding): - try: - # Only return the same input sharding object if the OpShardings and - # in_aval.ndim and out_aval.ndim match. This is because if OpSharding is - # replicated then, it doesn't encode the ndim in it. The devices - # will be the same at this point because those checks happen before. 
- if (orig_aval is not None and out_aval is not None and - out_aval.ndim == orig_aval.ndim - and sharding_impls.are_op_shardings_equal( - o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim)) - and o.memory_kind == orig_in_s.memory_kind): - out.append(orig_in_s) - else: + if (isinstance(o, sharding_impls.GSPMDSharding) and + out_aval is not core.abstract_token): + # Only return the same input sharding object if the OpShardings and + # in_aval.ndim and out_aval.ndim match. This is because if OpSharding is + # replicated then, it doesn't encode the ndim in it. The devices + # will be the same at this point because those checks happen before. + if (orig_aval is not None and out_aval is not None and + out_aval.ndim == orig_aval.ndim + and sharding_impls.are_op_shardings_equal( + o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim)) + and o.memory_kind == orig_in_s.memory_kind): + out.append(orig_in_s) + else: + try: out.append(orig_handler(o, orig_in_s)) - except: - out.append(o) + except: + out.append(o) else: out.append(o) return out -def maybe_get_orig_out_sharding( - in_shardings, out_shardings, in_avals, out_avals): - if all(hasattr(o, '_original_sharding') for o in out_shardings): - return [o._original_sharding for o in out_shardings] +def maybe_recover_user_shardings( + old_shardings, new_shardings, old_avals, new_avals): + if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings): + return new_shardings orig_in_s = None orig_aval = None - for i, aval in safe_zip(in_shardings, in_avals): - oi = getattr(i, '_original_sharding', None) + for oi, aval in safe_zip(old_shardings, old_avals): if type(oi) in _orig_out_sharding_handlers: orig_in_s = oi orig_aval = aval break if orig_in_s is not None: return _get_out_sharding_from_orig_sharding( - out_shardings, out_avals, orig_in_s, orig_aval) + new_shardings, new_avals, orig_in_s, orig_aval) - return out_shardings + return new_shardings def _get_layouts_from_executable( xla_executable, in_layouts, out_layouts, num_ordered_effects -) -> tuple[Sequence[SpecifiedLayout | None], Sequence[SpecifiedLayout | None]]: +) -> tuple[Sequence[DeviceLocalLayout | None], Sequence[DeviceLocalLayout | None]]: try: in_layouts_xla = xla_executable.get_parameter_layouts() out_layouts_xla = xla_executable.get_output_layouts() @@ -2573,67 +2704,54 @@ def _get_layouts_from_executable( new_in_layouts = [] for x, i in safe_zip(in_layouts_xla, in_layouts): - x = SpecifiedLayout(x) - if isinstance(i, SpecifiedLayout): + x = DeviceLocalLayout.from_pjrt_layout(x) + if isinstance(i, DeviceLocalLayout): if i != x: raise AssertionError( - f"Unexpected XLA layout override: (XLA) {x} != {i} (User sharding)") + f"Unexpected XLA layout override: (XLA) {x} != {i} (User input" + " layout)") new_in_layouts.append(i) else: new_in_layouts.append(x) new_out_layouts = [] for x, o in safe_zip(out_layouts_xla, out_layouts): - x = SpecifiedLayout(x) - if isinstance(o, SpecifiedLayout): + x = DeviceLocalLayout.from_pjrt_layout(x) + if isinstance(o, DeviceLocalLayout): if o != x: raise AssertionError( - f"Unexpected XLA layout override: (XLA) {x} != {o} (User sharding)") + f"Unexpected XLA layout override: (XLA) {x} != {o} (User output" + " layout)") new_out_layouts.append(o) else: new_out_layouts.append(x) - assert all(isinstance(i, SpecifiedLayout) for i in new_in_layouts) - assert all(isinstance(o, SpecifiedLayout) for o in new_out_layouts) - return new_in_layouts, new_out_layouts # type: ignore + assert all(isinstance(i, DeviceLocalLayout) for i 
in new_in_layouts) + assert all(isinstance(o, DeviceLocalLayout) for o in new_out_layouts) + return new_in_layouts, new_out_layouts def get_logical_mesh_ids(mesh_shape): return np.arange(math.prod(mesh_shape)).reshape(mesh_shape) -@weakref_lru_cache -def _cached_compilation(computation, name, mesh, spmd_lowering, - tuple_args, auto_spmd_lowering, allow_prop_to_inputs, - allow_prop_to_outputs, host_callbacks, backend, - da, pmap_nreps, compiler_options_keys, - compiler_options_values): - # TODO(phawkins): One would normally just write: - # dev = np.array(device_assignment) - # The formulation below is substantially faster if there are many devices. - # If we were to optimize __getattr__ on xc.Device we might not need this - # workaround. - dev = np.vectorize(lambda i: da[i], otypes=[object])( - np.arange(len(da)) - ) +def create_compile_options( + computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, + allow_prop_to_inputs, allow_prop_to_outputs, backend, + np_dev, pmap_nreps, compiler_options): if pmap_nreps > 1: num_replicas, num_partitions = pmap_nreps, 1 elif spmd_lowering: - num_replicas, num_partitions = 1, dev.size + num_replicas, num_partitions = 1, np_dev.size else: - num_replicas, num_partitions = dev.size, 1 + num_replicas, num_partitions = np_dev.size, 1 if pmap_nreps > 1: # In `jit` device_assignment is set to None when num_replicas > 1. Do # the same thing here too. xla_device_assignment = None else: - xla_device_assignment = dev.reshape((num_replicas, num_partitions)) - - if compiler_options_keys is None: - compiler_options = None - else: - compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values)) + xla_device_assignment = np_dev.reshape((num_replicas, num_partitions)) fdo_profile = (None if compiler_options is None else compiler_options.pop("fdo_profile", None)) @@ -2658,19 +2776,39 @@ def _cached_compilation(computation, name, mesh, spmd_lowering, get_logical_mesh_ids(list(mesh.shape.values())) .reshape(-1)) compile_options.parameter_is_tupled_arguments = tuple_args - if xla_extension_version >= 241: - opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs) + opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs) opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs) + return compile_options + + +@weakref_lru_cache +def _cached_compilation(computation, name, mesh, spmd_lowering, + tuple_args, auto_spmd_lowering, allow_prop_to_inputs, + allow_prop_to_outputs, host_callbacks, backend, + da, pmap_nreps, compiler_options_keys, + compiler_options_values, + pgle_profiler): + # One would normally just write: dev = np.array(device_assignment) + # The formulation below is substantially faster if there are many devices. 
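# Editor's sketch (not part of the patch): the object-array construction used
# just below can be reproduced standalone. `devices` is a hypothetical
# stand-in for the real device assignment `da`; np.vectorize with
# otypes=[object] builds the object array directly instead of letting
# np.array probe every element's attributes, which is slow for xc.Device.
import numpy as np
devices = [object() for _ in range(8)]  # stand-in for xc.Device objects
dev_example = np.vectorize(lambda i: devices[i], otypes=[object])(np.arange(len(devices)))
assert dev_example.shape == (8,) and dev_example.dtype == object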
+ dev = np.vectorize(lambda i: da[i], otypes=[object])(np.arange(len(da))) + + if compiler_options_keys is None: + compiler_options = None + else: + compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values)) - if hasattr(backend, "compile_replicated"): - return None, compile_options + compile_options = create_compile_options( + computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, + allow_prop_to_inputs, allow_prop_to_outputs, backend, + dev, pmap_nreps, compiler_options) with dispatch.log_elapsed_time( "Finished XLA compilation of {fun_name} in {elapsed_time} sec", fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT): xla_executable = compiler.compile_or_get_cached( - backend, computation, dev, compile_options, host_callbacks) - return xla_executable, compile_options + backend, computation, dev, compile_options, host_callbacks, + pgle_profiler) + return xla_executable def _maybe_get_and_check_in_shardings( @@ -2685,9 +2823,9 @@ def _maybe_get_and_check_in_shardings( If in_sharding is unspecified, then the sharding returned by XLA is returned. """ - in_shardings_xla = _get_in_shardings_from_xla( # type: ignore + in_shardings_xla = _get_in_shardings_from_xla( xla_executable, device_assignment, len(global_in_avals), - num_ordered_effects) # type: ignore + num_ordered_effects) if in_shardings_xla is None: return in_shardings @@ -2697,25 +2835,23 @@ def _maybe_get_and_check_in_shardings( if is_unspecified(orig): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): - xla_s = aval.dtype._rules.logical_sharding(aval, xla_s) + xla_s = sharding_impls.logical_sharding(aval, xla_s) new_in_shardings.append(xla_s) else: - # TODO(yashkatariya): Remove the if branch for abstract_token once - # choosing input shardings by XLA is enabled again. - if aval is core.abstract_token: - new_in_shardings.append(orig) - else: - xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore - orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore - # MANUAL HloSharding comes from other partitioning frameworks. - if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and - not xla_hlo_s.is_manual() and - (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or - xla_s.memory_kind != orig.memory_kind)): # type: ignore - raise AssertionError( - f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " - "(User sharding)") - new_in_shardings.append(orig) + xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) + orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # pytype: disable=attribute-error + # MANUAL HloSharding comes from other partitioning frameworks. 
+ if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and + not xla_hlo_s.is_manual() and + (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s))): + raise AssertionError( + f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " + "(User sharding)") + new_in_shardings.append(orig) + + new_in_shardings = maybe_recover_user_shardings( + in_shardings, new_in_shardings, global_in_avals, global_in_avals) + return new_in_shardings @@ -2723,9 +2859,9 @@ def _maybe_get_and_check_out_shardings( xla_executable, out_shardings, device_assignment, global_out_avals, num_ordered_effects, all_default_mem_kind ): - out_shardings_xla = get_out_shardings_from_executable( # type: ignore + out_shardings_xla = get_out_shardings_from_executable( xla_executable, device_assignment, len(global_out_avals), - num_ordered_effects, all_default_mem_kind) # type: ignore + num_ordered_effects, all_default_mem_kind) if out_shardings_xla is None: return out_shardings @@ -2735,16 +2871,16 @@ def _maybe_get_and_check_out_shardings( if is_unspecified(orig): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): - xla_s = aval.dtype._rules.logical_sharding(aval, xla_s) + xla_s = sharding_impls.logical_sharding(aval, xla_s) new_out_shardings.append(xla_s) else: - xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore - orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore + xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) + orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # pytype: disable=attribute-error # MANUAL HloSharding comes from other partitioning frameworks. if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and not xla_hlo_s.is_manual() and (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or - xla_s.memory_kind != orig.memory_kind)): # type: ignore + xla_s.memory_kind != orig.memory_kind)): # pytype: disable=attribute-error raise AssertionError( f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " "(User sharding)") @@ -2762,12 +2898,12 @@ def finalize_out_shardings(out_shardings, device_assignment): @dataclasses.dataclass class UnloadedMeshExecutable: xla_executable: Any - device_assignment: xc.DeviceList | Sequence[xc.Device] # type: ignore + device_assignment: xc.DeviceList | Sequence[xc.Device] backend: xb.XlaBackend input_avals: Sequence[ShapedArray] - input_shardings: Sequence[sharding_impls.XLACompatibleSharding] + input_shardings: Sequence[JSharding] output_avals: Sequence[ShapedArray] - output_shardings: Sequence[sharding_impls.XLACompatibleSharding] + output_shardings: Sequence[JSharding] committed: bool name: str unordered_effects: list[core.Effect] @@ -2775,27 +2911,23 @@ class UnloadedMeshExecutable: keepalive: Sequence[Any] host_callbacks: Sequence[Any] kept_var_idx: set[int] - out_mut: Sequence[None | int] | None + mut: MutationData | None auto_spmd_lowering: bool - in_layouts: Sequence[SpecifiedLayout | None] - out_layouts: Sequence[SpecifiedLayout | None] + in_layouts: Sequence[DeviceLocalLayout | None] + out_layouts: Sequence[DeviceLocalLayout | None] all_args_info: AllArgsInfo | None + pgle_profiler: profiler.PGLEProfiler | None def build_unsafe_call(self): - if xla_extension_version >= 229: - handle_args = InputsHandler(self.input_shardings) - else: - input_indices = _get_input_indices(self.input_avals, self.input_shardings, - self.device_assignment) - handle_args = InputsHandler( - self.input_shardings, self.xla_executable.local_devices(), input_indices) + handle_args = 
InputsHandler(self.input_shardings) handle_outs = global_avals_to_results_handler( - self.output_avals, self.output_shardings, self.committed) # type: ignore # arg-type + self.output_avals, self.output_shardings, self.committed) - unsafe_call = ExecuteReplicated( # type: ignore # assignment + unsafe_call = ExecuteReplicated( self.xla_executable, self.name, self.backend, handle_args, handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive, - bool(self.host_callbacks), self.kept_var_idx, self.out_mut) + bool(self.host_callbacks), self.kept_var_idx, self.mut, + self.pgle_profiler) return unsafe_call def load(self) -> MeshExecutable: @@ -2806,15 +2938,13 @@ def load(self) -> MeshExecutable: self.in_layouts, self.out_layouts, self.all_args_info, self) - # May return a MeshExecutable in the compile_replicated case. @staticmethod def from_hlo(name: str, hlo: ir.Module, global_in_avals: Sequence[ShapedArray], global_out_avals: Sequence[ShapedArray], - in_shardings: Sequence[sharding_impls.XLACompatibleSharding | AUTO], - out_shardings: Sequence[(sharding_impls.XLACompatibleSharding | AUTO | - UnspecifiedValue)], + in_shardings: Sequence[JSharding | AUTO], + out_shardings: Sequence[(JSharding | AUTO | UnspecifiedValue)], spmd_lowering: bool, tuple_args: bool, auto_spmd_lowering: bool, @@ -2824,16 +2954,17 @@ def from_hlo(name: str, keepalive: Any, kept_var_idx: set[int], backend: xb.XlaBackend, - device_assignment: xc.DeviceList | Sequence[xc.Device], # type: ignore + device_assignment: xc.DeviceList | Sequence[xc.Device], committed: bool, in_layouts: MaybeLayout, out_layouts: MaybeLayout, pmap_nreps: int = 1, - out_mut: Sequence[None | int] | None = None, + mut: MutationData | None = None, shape_poly_state: mlir.ShapePolyLoweringState | None = None, all_default_mem_kind: bool = True, all_args_info: AllArgsInfo | None = None, compiler_options=None, + pgle_profiler: profiler.PGLEProfiler | None = None ) -> MeshExecutable: if shape_poly_state is not None and shape_poly_state.uses_dim_vars: hlo = mlir.refine_polymorphic_shapes(hlo) @@ -2847,8 +2978,10 @@ def from_hlo(name: str, da = _create_da_object(tuple(device_assignment)) del device_assignment - allow_prop_to_inputs = tuple(is_unspecified(i) for i in in_shardings) - allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings) + allow_prop_to_inputs = tuple(is_unspecified(i) or is_auto(i) + for i in in_shardings) + allow_prop_to_outputs = tuple(is_unspecified(o) or is_auto(o) + for o in out_shardings) mesh = None if auto_spmd_lowering: @@ -2857,37 +2990,26 @@ def from_hlo(name: str, mesh = i.mesh # type: ignore break - xla_executable, compile_options = _cached_compilation( + xla_executable = _cached_compilation( hlo, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps, - compiler_options_keys, compiler_options_values) - - if hasattr(backend, "compile_replicated"): - semantics_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore - semantics_out_shardings = SemanticallyEqualShardings(out_shardings) # type: ignore - return _compile_replicated_mesh_executable_from_hlo( - hlo, name, tuple(global_in_avals), tuple(global_out_avals), - semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering, - compile_options, tuple(host_callbacks), bool(unordered_effects), - tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed, - pmap_nreps) + compiler_options_keys, compiler_options_values, 
pgle_profiler) if auto_spmd_lowering: assert mesh is not None in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable( xla_executable, mesh) - in_shardings = [x if is_auto(i) else getattr(i, '_original_sharding', i) # type: ignore + in_shardings = [x if is_auto(i) else i for x, i in safe_zip(in_shardings_xla, in_shardings)] out_shardings = [x if is_auto(o) else o for x, o in safe_zip(out_shardings_xla, out_shardings)] else: if pmap_nreps == 1: assert mesh is None - if xla_extension_version >= 241: - in_shardings = _maybe_get_and_check_in_shardings( - xla_executable, in_shardings, tuple(da), global_in_avals, - len(ordered_effects)) + in_shardings = _maybe_get_and_check_in_shardings( + xla_executable, in_shardings, tuple(da), global_in_avals, + len(ordered_effects)) out_shardings = _maybe_get_and_check_out_shardings( xla_executable, out_shardings, tuple(da), global_out_avals, len(ordered_effects), all_default_mem_kind) @@ -2895,21 +3017,17 @@ def from_hlo(name: str, in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap( xla_executable.local_devices(), len(in_shardings), len(out_shardings)) - if xla_extension_version >= 217: - in_layouts, out_layouts = _get_layouts_from_executable( - xla_executable, in_layouts, out_layouts, len(ordered_effects)) - else: - assert all(i is None for i in in_layouts) - assert all(o is None for o in out_layouts) + in_layouts, out_layouts = _get_layouts_from_executable( + xla_executable, in_layouts, out_layouts, len(ordered_effects)) - out_shardings = maybe_get_orig_out_sharding( + out_shardings = maybe_recover_user_shardings( in_shardings, out_shardings, global_in_avals, global_out_avals) out_shardings = finalize_out_shardings(out_shardings, da) return UnloadedMeshExecutable( xla_executable=xla_executable, - device_assignment=da, # type: ignore + device_assignment=da, backend=backend, input_avals=global_in_avals, input_shardings=in_shardings, # type: ignore @@ -2922,18 +3040,19 @@ def from_hlo(name: str, keepalive=keepalive, host_callbacks=host_callbacks, kept_var_idx=kept_var_idx, - out_mut=out_mut, + mut=mut, auto_spmd_lowering=auto_spmd_lowering, - in_layouts=in_layouts, # type: ignore - out_layouts=out_layouts, # type: ignore - all_args_info=all_args_info).load() # type: ignore + in_layouts=in_layouts, + out_layouts=out_layouts, + all_args_info=all_args_info, + pgle_profiler=pgle_profiler).load() class MeshExecutableFastpathData(NamedTuple): xla_executable: xc.LoadedExecutable out_pytree_def: Any - in_shardings: Sequence[sharding_impls.XLACompatibleSharding] - out_shardings: Sequence[sharding_impls.XLACompatibleSharding] + in_shardings: Sequence[JSharding] + out_shardings: Sequence[JSharding] out_avals: Sequence[ShapedArray] out_committed: Sequence[bool] kept_var_bitvec: Iterable[bool] @@ -2991,36 +3110,37 @@ def xla_extension_executable(self): return self.xla_executable def call(self, *args): + args_after_dce = [a for i, a in enumerate(args) if i in self._kept_var_idx] if self._all_args_info is None: - kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx] + kept_args = args_after_dce ref_avals = self.in_avals - in_shardings = self._in_shardings debug_info = None else: kept_args = args ref_avals = self._all_args_info.in_avals - iter_in_shardings = iter(self._in_shardings) - in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s - for i, s in enumerate(self._all_args_info.in_shardings)] debug_info = self._all_args_info.debug_info - arg_avals = map(xla.abstractify, kept_args) - 
check_arg_avals_for_call(ref_avals, arg_avals, debug_info) + all_arg_avals = map(xla.abstractify, kept_args) + check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info) # Check the GDA sharding and the input sharding. - check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info) + check_array_xla_sharding_layout_match( + args_after_dce, self._in_shardings, self._in_layouts, debug_info, + self._kept_var_idx) return self.unsafe_call(*args) # pylint: disable=not-callable - def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]: + def input_shardings(self) -> Sequence[JSharding]: return self._in_shardings - def output_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]: + def output_shardings(self) -> Sequence[JSharding]: return self._out_shardings def input_layouts(self): - return self._in_layouts + return [Layout(l, s) + for l, s in safe_zip(self._in_layouts, self._in_shardings)] def output_layouts(self): - return self._out_layouts + return [Layout(l, s) + for l, s in safe_zip(self._out_layouts, self._out_shardings)] def create_cpp_call(self, no_kwargs, in_tree, out_tree): if not (isinstance(self.unsafe_call, ExecuteReplicated) and @@ -3041,7 +3161,7 @@ def aot_cache_miss(*args, **kwargs): kept_var_bitvec = [i in self._kept_var_idx for i in range(len(args_flat))] in_shardings = [ - a.dtype._rules.physical_sharding(a, s) + sharding_impls.physical_sharding(a, s) if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) else s for s, a in zip(self._in_shardings, self.in_avals) @@ -3053,21 +3173,11 @@ def aot_cache_miss(*args, **kwargs): self.unsafe_call.in_handler.input_indices) else: fastpath_data = None - return outs, fastpath_data - - if xla_extension_version >= 226: - return xc._xla.pjit( - self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, - shard_arg if xla_extension_version >= 229 else temp_shard_arg) # type: ignore - else: - return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [], # type: ignore - tree_util.dispatch_registry) - + return outs, fastpath_data, False # Do not remove cache entry -# TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24 -def temp_shard_arg(arg, devices, arg_indices, sharding, canonicalize=True): - return shard_arg(arg, sharding) + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], [], + tree_util.dispatch_registry, lambda x, s: shard_args([s], [x])[0]) def check_arg_avals_for_call(ref_avals, arg_avals, @@ -3114,102 +3224,91 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings): return in_shardings, out_shardings, committed, tuple(local_devices) -@weakref_lru_cache -def _compile_replicated_mesh_executable_from_hlo( - computation, name, global_in_avals, global_out_avals, semantics_in_shardings, - semantics_out_shardings, auto_spmd_lowering, compile_options, - host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx, - backend, da, committed, pmap_nreps): - assert not auto_spmd_lowering - in_shardings = semantics_in_shardings.shardings - out_shardings = semantics_out_shardings.shardings - - kept_var_idx = set(kept_var_idx) - # Will compute out_handler with executable information. 
- unsafe_call = backend.compile_replicated( - is_trivial=False, name=name, computation=computation, - compile_options=compile_options, host_callbacks=host_callbacks, - has_unordered_effects=has_unordered_effects, - device_assignment=da, ordered_effects=ordered_effects, - in_avals=global_in_avals, - in_shardings=in_shardings, kept_var_idx=kept_var_idx, - out_avals=global_out_avals, out_shardings=out_shardings, - committed=committed, pmap_nreps=pmap_nreps) - xla_executable = None - return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals, - global_out_avals, in_shardings, out_shardings, - auto_spmd_lowering, kept_var_idx, - (None,) * len(global_in_avals), - (None,) * len(global_out_avals)) - - -@lru_cache -def create_mesh_pspec_sharding( - mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None, - memory_kind: str | None = None) -> sharding_impls.NamedSharding: - if pspec is None: - pspec, parsed_pspec = PartitionSpec(), None - return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec, - memory_kind=memory_kind) +create_mesh_pspec_sharding = sharding_impls.create_mesh_pspec_sharding def check_device_backend_on_shardings(shardings) -> bool: for i in shardings: if is_unspecified(i) or is_auto(i): continue - if hasattr(i, '_original_sharding') and getattr( - i._original_sharding, '_device_backend', False): + if getattr(i, '_device_backend', False): return True return False -def check_gda_or_array_xla_sharding_match( - args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding], - jaxpr_debug_info: core.JaxprDebugInfo | None) -> None: +def check_array_xla_sharding_layout_match( + args_after_dce, + in_xla_shardings: Sequence[JSharding], + in_xla_layouts: Sequence[DeviceLocalLayout], + jaxpr_debug_info: core.JaxprDebugInfo | None, + kept_var_idx: set[int]) -> None: from jax._src.array import ArrayImpl - arg_names = ([''] * len(args) if jaxpr_debug_info is None else - jaxpr_debug_info.arg_names) + # jaxpr_debug_info.arg_names are before DCE, so need to DCE them. + arg_names = ( + [""] * len(args_after_dce) if jaxpr_debug_info is None + else [a for i, a in enumerate(jaxpr_debug_info.arg_names) # type: ignore + if i in kept_var_idx] + ) errors = [] num_errors = 5 - for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names): + for arg, xs, xl, name in safe_zip( + args_after_dce, in_xla_shardings, in_xla_layouts, arg_names): if not isinstance(arg, ArrayImpl): continue if is_unspecified_or_auto(xs): continue db_xs = check_device_backend_on_shardings([xs]) - if not db_xs: - xs = getattr(xs, '_original_sharding', xs) # Raise memory kind mismatch error even if the arg is uncommitted. 
if arg.sharding.memory_kind != xs.memory_kind: errors.append( - "Got input sharding(s) that compiled object was called with: " + ("Got input sharding(s) that compiled object was called with: " f"{arg.sharding} and sharding(s) the computation was compiled " - f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}") + f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}", + 'sharding')) if (not db_xs and arg._committed and not op_shardings.are_op_shardings_equal( arg.sharding._to_xla_hlo_sharding(arg.ndim), xs._to_xla_hlo_sharding(arg.ndim))): errors.append( - "Got input sharding(s) that compiled object was called with: " + ("Got input sharding(s) that compiled object was called with: " f"{arg.sharding} and sharding(s) the computation was compiled " - f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}") + f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}", + 'sharding')) + + if (not db_xs and arg._committed and + arg.layout.device_local_layout is not None and xl is not None and + arg.layout.device_local_layout != xl): + errors.append( + ("Got input layout(s) that compiled object was called with: " + f"{arg.layout.device_local_layout} and layout(s) the computation was " + f"compiled with: {xl} for arg {name} with " + f"shape: {arg.aval.str_short()}", + 'layout')) if errors: - str_errors = '\n'.join(errors[:num_errors]) + first_errors, error_kinds = unzip2(errors[:num_errors]) + str_errors = '\n'.join(first_errors) + if all(k == 'sharding' for k in error_kinds): + kind_str = r'sharding(s)' + elif all(k == 'layout' for k in error_kinds): + kind_str = 'layout(s)' + else: + kind_str = 'sharding(s) and layout(s)' num_mismatch_str = ( f'the {len(errors)} mismatches' if len(errors) < num_errors else f"{num_errors} mismatches out of {len(errors)}") raise ValueError( - "Compiled object called with input sharding(s) does not match the " - "sharding(s) the computation was compiled with. " + f"Compiled object called with input {kind_str} does " + f"not match the {kind_str} the computation was " + "compiled with. 
" f"Here are {num_mismatch_str}:\n{str_errors}") def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified: - parsed_pspec, _, _ = sharding_impls.prepare_axis_resources( + parsed_pspec = sharding_impls.prepare_axis_resources( pspec, "pspec to array_mapping") return _get_array_mapping(parsed_pspec) @@ -3276,12 +3375,3 @@ def _check_aval(aval, what_thunk): def maybe_extend_axis_env(*args, **kwargs): with core.extend_axis_env(*args, **kwargs): yield - - -def device_put(x, devices: Sequence[xc.ArrayImpl], - replicate: bool=False) -> list[xc.ArrayImpl]: - """Call device_put on a sequence of devices and return a flat sequence of buffers.""" - if replicate: - return [jax.device_put(x, device) for device in devices] - else: - return [jax.device_put(val, device) for val, device in safe_zip(x, devices)] diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 3b3c7f74515a..2db877d3f970 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -17,12 +17,12 @@ from __future__ import annotations from collections import defaultdict -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools from functools import partial import itertools as it -from typing import Any, Callable, Protocol, Union +from typing import Any, Protocol, Union import numpy as np @@ -92,7 +92,7 @@ def sharding_to_proto(sharding: SpatialSharding): proto = xc.OpSharding() if isinstance(sharding, tuple) and not isinstance(sharding[0], int): assert all(s is None or isinstance(s, tuple) for s in sharding) - return tuple_sharding_proto(list(map(sharding_to_proto, sharding))) # type: ignore + return tuple_sharding_proto(list(map(sharding_to_proto, sharding))) if sharding is None: proto.type = xc.OpSharding.Type.REPLICATED @@ -110,8 +110,6 @@ def tuple_sharding_proto(elems): return proto - - ### handlers # JAX abstract values -> XLA shapes @@ -132,6 +130,10 @@ def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]: # IR constants +class InvalidInputException(Exception): + pass + + # TODO(mattjj): try to remove this canonicalize_dtype stuff def canonicalize_dtype(x): typ = type(x) @@ -142,8 +144,8 @@ def canonicalize_dtype(x): if handler: return handler(x) if hasattr(x, '__jax_array__'): return canonicalize_dtype(x.__jax_array__()) - raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid " - "JAX type.") + raise InvalidInputException( + f"Argument '{x}' of type {type(x)} is not a valid JAX type.") def _canonicalize_masked_array_dtype(x): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. 
" diff --git a/jax/_src/jaxpr_util.py b/jax/_src/jaxpr_util.py index b9a7763f596d..3f3f677b069d 100644 --- a/jax/_src/jaxpr_util.py +++ b/jax/_src/jaxpr_util.py @@ -17,11 +17,12 @@ from __future__ import annotations from collections import Counter, defaultdict +from collections.abc import Callable import gzip import itertools import json import types -from typing import Any, Callable, Union +from typing import Any, Union from jax._src import core from jax._src import util @@ -152,7 +153,7 @@ def _pprof_profile( else: raw_frames = zip(*tb.raw_frames()) frames = [loc[(code, lasti)] for code, lasti in raw_frames - if source_info_util.is_user_filename(code.co_filename)] # type: ignore + if source_info_util.is_user_filename(code.co_filename)] samples.append({ "location_id": frames, "value": [count], diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 683c4754fcbf..f950cfeada92 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -229,11 +229,19 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, if not dtypes.issubdtype(operand.dtype, np.floating): raise ValueError('operand must be a floating type') reduction_input_size = dims[reduction_dimension] - dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize( - reduction_input_size, len(dims), k, recall_target, aggregate_to_topk, - reduction_input_size_override)[0] - return (operand.update( - shape=dims, dtype=operand.dtype, weak_type=operand.weak_type), + if aggregate_to_topk: + dims[reduction_dimension] = k + elif core.is_constant_shape((reduction_input_size, k)): + dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize( + reduction_input_size, len(dims), k, recall_target, aggregate_to_topk, + reduction_input_size_override)[0] + else: + raise NotImplementedError( + "approx_top_k with aggregate_to_topk=False not yet implemented when " + f"either the `k` ({k}) or the " + f" reduction dimension size ({reduction_input_size}) are symbolic") + return (operand.update(shape=dims, dtype=operand.dtype, + weak_type=operand.weak_type), operand.update(shape=dims, dtype=np.dtype(np.int32))) @@ -254,30 +262,6 @@ def _comparator_builder(op_type, is_max_k): def _get_init_val_literal(op_type, is_max_k): return np.array(-np.inf if is_max_k else np.inf, dtype=op_type) -def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k, - reduction_dimension, recall_target, is_max_k, - reduction_input_size_override, - aggregate_to_topk): - c = ctx.builder - op_shape = c.get_shape(operand) - if not op_shape.is_array(): - raise ValueError(f'operand must be an array, but was {op_shape}') - op_dims = op_shape.dimensions() - op_type = op_shape.element_type() - if reduction_dimension < 0: - reduction_dimension = len(op_dims) + reduction_dimension - comparator = _comparator_builder(op_type, is_max_k) - init_val_literal = _get_init_val_literal(op_type, is_max_k) - iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims), - reduction_dimension) - init_val = xc.ops.Constant(c, init_val_literal) - init_arg = xc.ops.Constant(c, np.int32(-1)) - out = xc.ops.ApproxTopK(c, [operand, iota], [init_val, init_arg], k, - reduction_dimension, comparator, recall_target, - aggregate_to_topk, reduction_input_size_override) - return xla.xla_destructure(c, out) - - def _comparator_builder_mlir(ctx, op_type, is_max_k): scalar = ir.RankedTensorType.get([], op_type) index = ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)) @@ -326,7 +310,6 @@ def _approx_top_k_lowering(ctx, operand, *, k, init_val = 
mlir.ir_constant(init_val_array.reshape(())) backend_config = { - "top_k" : mlir.i64_attr(k), "reduction_dim" : mlir.i64_attr(reduction_dimension), "recall_target" : mlir.ir.FloatAttr.get(recall_type, recall_target), "aggregate_to_topk" : mlir.ir.BoolAttr.get(aggregate_to_topk), @@ -342,13 +325,24 @@ def _approx_top_k_lowering(ctx, operand, *, k, mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape)) for aval_out in ctx.avals_out] - out = mlir.custom_call( - "ApproxTopK", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=[operand, iota, init_val, init_arg], - called_computations=[comparator.name.value], - backend_config=backend_config, - result_shapes=result_shapes) + if core.is_constant_dim(k): + backend_config["top_k"] = mlir.i64_attr(k) + out = mlir.custom_call( + "ApproxTopK", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=[operand, iota, init_val, init_arg], + called_computations=[comparator.name.value], + backend_config=backend_config, + result_shapes=result_shapes) + else: + k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,)) + out = mlir.custom_call( + "stablehlo.dynamic_approx_top_k", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=[operand, iota, init_val, init_arg, k_value], + called_computations=[comparator.name.value], + backend_config=backend_config, + result_shapes=result_shapes) return out.results diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index f75175b79bb6..b613193876b6 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -15,10 +15,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import os from functools import partial -from typing import Any, Callable +from typing import Any from jax._src import core from jax._src import linear_util as lu diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 0193c7203e1a..f87c6cdc4e28 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -15,14 +15,15 @@ from __future__ import annotations import collections -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools from functools import partial import inspect import itertools import operator -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar +import jax from jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util from jax._src import config @@ -69,7 +70,7 @@ @api_boundary def switch(index, branches: Sequence[Callable], *operands, operand=_no_operand_sentinel): - """Apply exactly one of ``branches`` given by ``index``. + """Apply exactly one of the ``branches`` given by ``index``. If ``index`` is out of bounds, it is clamped to within bounds. @@ -153,14 +154,12 @@ def switch(index, branches, *operands): # of those effects (even if they don't have an explicit data dependence). 
index = core.raise_as_much_as_possible(index) - linear = (False,) * (len(consts) + len(ops)) - out = cond_p.bind( - index, *consts, *ops, branches=tuple(jaxprs), linear=linear) + out = cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs)) return tree_unflatten(out_trees[0], out) def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, - operand=_no_operand_sentinel, linear=None): + operand=_no_operand_sentinel): """Conditionally apply ``true_fun`` or ``false_fun``. Wraps XLA's `Conditional @@ -234,12 +233,6 @@ def cond(pred, true_fun, false_fun, *operands): return false_fun(*operands) ops, ops_tree = tree_flatten(operands) - if linear is None: - linear_ops = [False] * len(ops) - else: - linear_ops, ops_tree2 = tree_flatten(linear) - if ops_tree != ops_tree2: - raise TypeError('linear tree and operand tree mismatch') ops_avals = tuple(map(_abstractify, ops)) jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( @@ -247,6 +240,7 @@ def cond(pred, true_fun, false_fun, *operands): if any(isinstance(op_aval, AbstractRef) for op_aval in ops_avals): raise ValueError("Cannot pass `Ref`s into `cond`.") true_jaxpr, false_jaxpr = jaxprs + out_tree, false_out_tree = out_trees if any(isinstance(out_aval, AbstractRef) for out_aval in true_jaxpr.out_avals + false_jaxpr.out_avals): @@ -255,6 +249,14 @@ def cond(pred, true_fun, false_fun, *operands): _check_tree_and_avals("true_fun and false_fun output", out_tree, true_jaxpr.out_avals, false_out_tree, false_jaxpr.out_avals) + # prune passthrough outputs + true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr) + false_fwds = pe._jaxpr_forwarding(false_jaxpr.jaxpr) + in_fwd = [i if i == j else None for i, j in zip(true_fwds, false_fwds)] + keep = [f is None for f in in_fwd] + true_jaxpr = pe.prune_closed_jaxpr_outputs(true_jaxpr, keep) + false_jaxpr = pe.prune_closed_jaxpr_outputs(false_jaxpr, keep) + joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: @@ -269,10 +271,19 @@ def cond(pred, true_fun, false_fun, *operands): false_jaxpr = replace_jaxpr_effects(false_jaxpr, joined_effects) true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects) - linear = [False] * len(consts) + linear_ops - out = cond_p.bind( - index, *consts, *ops, - branches=(false_jaxpr, true_jaxpr), linear=tuple(linear)) + out = cond_p.bind(index, *consts, *ops, branches=(false_jaxpr, true_jaxpr)) + num_consts = len(consts) + out_ = iter(out) + + def _cast_to_array(x): + _copy = isinstance(x, np.bool_) + return jax.numpy.asarray(x, copy=_copy) + + out = [ + next(out_) if fwd is None else _cast_to_array(ops[fwd - num_consts]) + for fwd in in_fwd + ] + assert next(out_, None) is None return tree_unflatten(out_tree, out) @api_boundary @@ -347,7 +358,7 @@ def _bcast_select_n(pred, *cases): return lax.select_n(pred, *cases) def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, - dims, branches, linear): + dims, branches): index, *ops = args index_dim, *op_dims = dims # TODO(sharadmv): clean this up by adding a specific blocklist @@ -405,11 +416,10 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, for jaxpr in branches) out_dims = [0 if b else batching.not_mapped for b in out_bat] - out = cond_p.bind( - index, *ops, branches=branches_batched, linear=linear) + out = cond_p.bind(index, *ops, branches=branches_batched) return out, out_dims -def _cond_jvp(primals, tangents,
branches, linear): +def _cond_jvp(primals, tangents, branches): nonzeros = [type(t) is not ad_util.Zero for t in tangents] index_nz, *ops_nz = nonzeros @@ -426,17 +436,14 @@ def _cond_jvp(primals, tangents, branches, linear): _, *ops_dot = tangents ops_dot = _prune_zeros(ops_dot) - ops_lin = tuple(linear) - linear_jvp = ops_lin + (True,) * len(ops_dot) - out = cond_p.bind( - index, *ops, *ops_dot, branches=branches_jvp, linear=linear_jvp) + out = cond_p.bind(index, *ops, *ops_dot, branches=branches_jvp) out_primals, out_tangents = split_list(out, [len(out_nz)]) out_tangents_iter = iter(out_tangents) out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) for p, nz in zip(out_primals, out_nz)] return out_primals, out_tangents -def _cond_partial_eval(trace, *tracers, branches, linear): +def _cond_partial_eval(trace, *tracers, branches): in_unknowns = [t.pval[0] is not None for t in tracers] index_uk, *ops_uk = in_unknowns if any(isinstance(eff, RefEffect) for branch in branches for eff in @@ -447,7 +454,7 @@ def _cond_partial_eval(trace, *tracers, branches, linear): if index_uk: # When the branch index is unknown, we stage out the whole cond. # TODO(mattjj): remove this path when old remat is removed - params = dict(branches=branches, linear=linear) + params = dict(branches=branches) return trace.default_process_primitive(cond_p, tracers, params) branches_out_uks = [] @@ -477,9 +484,7 @@ def _cond_partial_eval(trace, *tracers, branches, linear): for j in branches_known[1:]) in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()] - linear_known = [l for l, uk in zip(linear, ops_uk) if not uk] - out_consts_res = cond_p.bind(*in_consts, branches=branches_known, - linear=tuple(linear_known)) + out_consts_res = cond_p.bind(*in_consts, branches=branches_known) out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res]) index_tracer = trace.instantiate_const(tracers[0]) @@ -488,9 +493,7 @@ def _cond_partial_eval(trace, *tracers, branches, linear): res_tracers = map(trace.new_instantiated_const, res) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) for aval in branches_unknown[0].out_avals] - linear_unknown = ([False] * num_res + - [l for l, uk in zip(linear, in_unknowns[1:]) if uk]) - params = dict(branches=branches_unknown, linear=tuple(linear_unknown)) + params = dict(branches=branches_unknown) name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] source = source_info_util.current().replace(name_stack=name_stack) eqn = pe.new_eqn_recipe( @@ -559,8 +562,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): # Build the known eqn. ins_known, _ = partition_list(unks_in, eqn.invars) # includes index invar out_binders_known, _ = partition_list(unks_out, eqn.outvars) - linear_known = [l for l, uk in zip(eqn.params['linear'], ops_uk) if not uk] - params_known = dict(branches=branches_known, linear=tuple(linear_known)) + params_known = dict(branches=branches_known) effects_known = _join_cond_effects(branches_known) eqn_known = pe.new_jaxpr_eqn( ins_known, [*out_binders_known, *res_binders], cond_p, params_known, @@ -568,8 +570,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): # Build the staged eqn. 
_, out_binders_staged = partition_list(inst_out, eqn.outvars) - linear_staged = [False] * len(res_binders) + list(eqn.params['linear']) - params_staged = dict(branches=branches_staged, linear=tuple(linear_staged)) + params_staged = dict(branches=branches_staged) effects_staged = _join_cond_effects(branches_staged) eqn_staged = pe.new_jaxpr_eqn( [eqn.invars[0], *res_binders, *eqn.invars[1:]], out_binders_staged, @@ -678,9 +679,7 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn, for closed_jaxpr, jaxpr in zip(closed_branches, dce_branches_)] # Finally, update parameters and form the new eqn. - dce_linear = [l for l, used in zip(eqn.params['linear'], used_inputs) if used] - new_params = dict(eqn.params, branches=tuple(dce_branches), - linear=tuple(dce_linear)) + new_params = dict(eqn.params, branches=tuple(dce_branches)) new_effects = core.join_effects(*(b.effects for b in dce_branches)) new_effects = _join_cond_effects(dce_branches_) new_eqn = pe.new_jaxpr_eqn( @@ -710,9 +709,9 @@ def transposed(*args): return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals) -def _cond_transpose(cts, *args, branches, linear): - del linear # could use for error checking, but see #14026 +def _cond_transpose(cts, *args, branches): index, *ops = args + assert type(index) is not ad.UndefinedPrimal linear = [type(x) is ad.UndefinedPrimal for x in ops] in_avals = map(raise_to_shaped, branches[0].in_avals) num_res = len(ops) - sum(linear) @@ -730,10 +729,8 @@ def _cond_transpose(cts, *args, branches, linear): res = ops[:num_res] cts = map(ad.instantiate_zeros, cts) - linear_trans = (False,) * num_res + (True,) * len(cts) - out = cond_p.bind( - index, *res, *cts, branches=branches_trans, linear=linear_trans) + out = cond_p.bind(index, *res, *cts, branches=branches_trans) assert all(map(core.typecheck, lin_in_avals, out)) out_iter = iter(out) @@ -747,7 +744,7 @@ def _cond_axis_substitution(params, subst, traverse): branches = tuple(core.subst_axis_names_jaxpr(jaxpr, subst) for jaxpr in params['branches']) return dict(params, branches=branches) -def _cond_typecheck(bind_time, *in_atoms, branches, linear): +def _cond_typecheck(bind_time, *in_atoms, branches): if not bind_time: _, *in_atoms = in_atoms avals = [x.aval for x in in_atoms] @@ -755,14 +752,9 @@ def _cond_typecheck(bind_time, *in_atoms, branches, linear): tc(branches, 'branches', 'tuple of ClosedJaxpr', type(branches) is tuple and all(type(x) is core.ClosedJaxpr for x in branches)) - tc(linear, 'linear', 'tuple of bool', - type(linear) is tuple and all(type(x) is bool for x in linear)) if len(branches) == 0: raise core.JaxprTypeError('cond requires at least one branch function') - if len(linear) + 1 != len(avals): - raise core.JaxprTypeError(f'cond given {len(linear)} linear flags for ' - f'{len(avals) - 1} non-predicate operands') jaxpr0 = branches[0] jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals) @@ -806,14 +798,14 @@ def _cond_typecheck(bind_time, *in_atoms, branches, linear): f'called with operands of type {_avals_short(op_avals)}') return jaxpr0.out_avals, joined_effects -def cond_bind(*args, branches, linear): +def cond_bind(*args, branches): if config.enable_checks.value: avals = map(core.get_aval, args) in_atoms = [core.Var('', a) for a in avals] # dummies - _cond_typecheck(True, *in_atoms, branches=branches, linear=linear) + _cond_typecheck(True, *in_atoms, branches=branches) for jaxpr in branches: core.check_jaxpr(jaxpr.jaxpr) - return core.AxisPrimitive.bind(cond_p, *args, branches=branches, linear=linear) + return 
core.AxisPrimitive.bind(cond_p, *args, branches=branches) cond_p = core.AxisPrimitive('cond') cond_p.multiple_results = True @@ -831,8 +823,7 @@ def cond_bind(*args, branches, linear): pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom pe.dce_rules[cond_p] = _cond_dce_rule -def _cond_lowering(ctx, index, *args, branches, linear): - del linear # Unused. +def _cond_lowering(ctx, index, *args, branches): joined_effects = core.join_effects(*(branch.effects for branch in branches)) ordered_effects = list(effects.ordered_effects.filter_in(joined_effects)) num_tokens = len(ordered_effects) @@ -868,11 +859,11 @@ def _cond_lowering(ctx, index, *args, branches, linear): mlir.register_lowering(cond_p, _cond_lowering) @register_discharge_rule(cond_p) -def _cond_state_discharge_rule(in_avals, out_avals, *args, branches, linear): +def _cond_state_discharge_rule(in_avals, out_avals, *args, branches): discharged_branches = tuple( core.ClosedJaxpr(discharge_state(branch.jaxpr, ())[0], ()) for branch in branches) - out_vals = cond_p.bind(*args, branches=discharged_branches, linear=linear) + out_vals = cond_p.bind(*args, branches=discharged_branches) out_vals, out_ref_vals = util.split_list( out_vals, [len(out_avals)]) ref_val_iter = iter(out_ref_vals) @@ -953,7 +944,19 @@ def other_platforms_code(*args): ... # Use a switch, to get the proper transformation rules for free. Since # platform index has no dependence on the input data, it won't be vectorized # under vmap. - return switch(platform_index, branches, *args) + # If the switch and the platform_index_p above are in the same compilation + # unit then constant-folding will remove the unnecessary branches. However, + # if we run in eager mode the switch below cannot be constant-folded and + # the compilation may fail if some of the branches contain custom calls not + # recognized on the compilation platform. Detect eager mode and keep only the + # needed branch. + try: + platform_index_concrete = core.concrete_or_error(operator.index, platform_index) + except core.ConcretizationTypeError: + return switch(platform_index, branches, *args) + else: + assert 0 <= platform_index_concrete < len(branches) + return branches[platform_index_concrete](*args) # A primitive to compute the index of a platform into a list of platforms. 
# Args: @@ -975,7 +978,9 @@ def _platform_index_lowering(ctx: mlir.LoweringRuleContext, *, platforms: Sequence[Sequence[str]], has_default: bool): - def lower_constant(ctx: mlir.LoweringRuleContext, *, i: int) -> mlir.ir.Value: + def lower_constant( + ctx: mlir.LoweringRuleContext, *, i: int + ) -> Sequence[ir.Value]: return mlir.ir_constants(np.int32(i)) platform_rules: dict[str, mlir.LoweringRule] = {} for i, ps in enumerate(platforms): diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 76cdcf0b4855..936656b0e7df 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -15,10 +15,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools import operator -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Generic, TypeVar import jax.numpy as jnp from jax import lax @@ -177,7 +177,7 @@ def wrapped_body(i, refs): def scan(f: Callable[[Carry, X], tuple[Carry, Y]], init: Carry, - xs: X, + xs: X | None = None, length: int | None = None, reverse: bool = False, unroll: int = 1) -> tuple[Carry, Y]: @@ -441,7 +441,7 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, else: raise Exception("Invalid fixpoint") del out_unknowns # redundant since it's the same as `in_unknowns` - tracers = tuple(trace.instantiate_const(t) if uk else t # type: ignore + tracers = tuple(trace.instantiate_const(t) if uk else t for t, uk in zip(tracers, in_unknowns)) # We use `partial_eval_jaxpr_custom` here because it won't remove effectful diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index e939f2170f8a..0c704b84475b 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -14,55 +14,69 @@ """Module for the loop primitives.""" from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial import inspect import itertools import operator -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar +import weakref import jax -import weakref -from jax._src import config -from jax._src import core -from jax._src import linear_util as lu -from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped -from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf, - tree_map, tree_flatten_with_path, keystr) -from jax._src.api_util import shaped_abstractify -from jax._src.tree_util import equality_errors from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import api +from jax._src import config +from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import effects +from jax._src import linear_util as lu from jax._src import source_info_util +from jax._src import state from jax._src import util +from jax._src.api_util import shaped_abstractify +from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import pxla from jax._src.interpreters import xla from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lax import windowed_reductions +from jax._src.lax.control_flow.common import ( + _abstractify, _avals_short, 
_check_tree_and_avals, _initial_style_jaxpr, + _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros, + _typecheck_param) from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src import state -from jax._src.state import discharge as state_discharge from jax._src.numpy.ufuncs import logaddexp +from jax._src.state import discharge as state_discharge from jax._src.traceback_util import api_boundary +from jax._src.tree_util import equality_errors from jax._src.typing import Array -from jax._src.util import (partition_list, safe_map, safe_zip, split_list, - unzip2, weakref_lru_cache, merge_lists) +from jax._src.util import ( + merge_lists, + partition_list, + safe_map, + safe_zip, + split_list, + split_list_checked, + unzip2, + weakref_lru_cache, +) +from jax.tree_util import ( + keystr, + tree_flatten, + tree_flatten_with_path, + tree_map, + tree_unflatten, + treedef_is_leaf, +) import numpy as np -from jax._src.lax.control_flow.common import ( - _abstractify, _avals_short, _check_tree_and_avals, _initial_style_jaxpr, - _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros, - _typecheck_param) - _map = safe_map zip = safe_zip @@ -104,10 +118,11 @@ def _promote_weak_typed_inputs(in_vals, in_avals, out_avals): @api_boundary def scan(f: Callable[[Carry, X], tuple[Carry, Y]], init: Carry, - xs: X, + xs: X | None = None, length: int | None = None, reverse: bool = False, - unroll: int | bool = 1) -> tuple[Carry, Y]: + unroll: int | bool = 1, + _split_transpose: bool = False) -> tuple[Carry, Y]: """Scan a function over leading array axes while carrying along state. The `Haskell-like type signature`_ in brief is @@ -183,6 +198,11 @@ def scan(f, init, xs, length=None): the loop. If a boolean is provided, it will determine if the loop is competely unrolled (i.e. `unroll=True`) or left completely unrolled (i.e. `unroll=False`). + _split_transpose: experimental optional bool specifying whether to further + split the transpose into a scan (computing activation gradients), and a + map (computing gradients corresponding to the array arguments). Enabling + this may increase memory requirements, and so is an experimental feature + that may evolve or even be rolled back. 
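# Illustrative usage sketch (not part of the patch) for the experimental
# _split_transpose flag documented above: it is a drop-in keyword on lax.scan
# that only changes how the backward pass is scheduled (scan + map), not the
# values it produces.
import jax
import jax.numpy as jnp

def loss(xs):
  carry, ys = jax.lax.scan(lambda c, x: (c + x, c + x), 0.0, xs,
                           _split_transpose=True)
  return ys.sum()

grads = jax.grad(loss)(jnp.arange(4.0))  # gradients of a cumulative sum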
Returns: A pair of type ``(c, [b])`` where the first element represents the final @@ -227,7 +247,7 @@ def scan(f, init, xs, length=None): ys = [] maybe_reversed = reversed if reverse else lambda x: x for i in maybe_reversed(range(length)): - xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat] + xs_slice = [slicing.index_in_dim(x, i, keepdims=False) for x in xs_flat] carry, y = f(carry, tree_unflatten(xs_tree, xs_slice)) ys.append(y) stack = lambda *ys: jax.numpy.stack(ys) @@ -273,6 +293,8 @@ def _create_jaxpr(init): if isinstance(unroll, bool): unroll = max(length, 1) if unroll else 1 + if unroll < 1: + raise ValueError("`unroll` must be a `bool` or a positive `int`.") if attrs_tracked: in_state = _get_states(attrs_tracked) in_carry, in_ext = split_list(in_flat, [num_carry]) @@ -282,20 +304,29 @@ def _create_jaxpr(init): reverse=reverse, length=length, jaxpr=jaxpr, num_consts=len(consts), num_carry=num_carry, linear=(False,) * (len(consts) + len(in_flat)), - unroll=unroll) + unroll=unroll, + _split_transpose=_split_transpose) if attrs_tracked: out_state, out = split_list(out, [len(attrs_tracked)]) _set_states(attrs_tracked, out_state) return tree_unflatten(out_tree, out) def _set_states(attrs_tracked, vals): - from jax.experimental.attrs import jax_setattr # type: ignore - for ((obj, attr), val) in zip(attrs_tracked, vals): + from jax.experimental.attrs import jax_setattr + valss = split_list_checked(vals, [td.num_leaves for _, td, _ in attrs_tracked]) + for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss): + val = tree_unflatten(treedef, leaves) jax_setattr(obj, attr, val) def _get_states(attrs_tracked): - from jax.experimental.attrs import jax_getattr # type: ignore - return [jax_getattr(obj, attr) for (obj, attr) in attrs_tracked] + from jax.experimental.attrs import jax_getattr + vals = [] + for treedef, _, (obj, attr) in attrs_tracked: + tree = jax_getattr(obj, attr) + leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + return vals def _check_scan_carry_type(body_fun, in_carry, out_carry_tree, out_avals): try: @@ -361,177 +392,91 @@ def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str: 'the shapes do not match' * shape_mismatch) return '' - -def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear, - f_impl, x_avals, y_avals): - consts, init, xs = split_list(args, [num_consts, num_carry]) - - carry = init - ys = [] - - for i in range(length): - i_ = length - i - 1 if reverse else i - x = _map(partial(_index_array, i_), x_avals, xs) - out = f_impl(*consts, *carry, *x) - carry, y = split_list(out, [num_carry]) - ys.append(y) - - ys = list(reversed(ys)) if reverse else ys - ys = list(zip(*ys)) - ys = _map(_stack, y_avals, ys) - return (*carry, *ys) - -def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear, - f_impl, x_avals, y_avals): - consts, init, xs = split_list(args, [num_consts, num_carry]) - - def cond_fun(vals): - i, *_ = vals - return i < length - - def body_fun(vals): - [i], carry, ys = split_list(vals, [1, num_carry]) - i_ = length - i - 1 if reverse else i - # TODO(jakevdp)[key-reuse]: this key reuse logic is not quite right, - # because the scan body may consume any keys within it. 
- # Import here to avoid circular imports - from jax.experimental import key_reuse - xs_unconsumed = _map(key_reuse.reuse_key, xs) - x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed) - out_flat = f_impl(*consts, *carry, *x) - carry_out, y_updates = split_list(out_flat, [num_carry]) - ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates) - return [i + 1] + carry_out + ys_out - - # TODO(jakevdp)[key-reuse]: mark xs consumed here if f_impl consumes them. - - ys_init = _map(partial(_empty_array, length), y_avals) - if length == 0: - return init + ys_init - else: - init_val = [lax._const(length, 0)] + init + ys_init - _, *outs = while_loop(cond_fun, body_fun, init_val) - return outs - -def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry, - linear, block_length, f_impl, x_avals, y_avals): - consts, init, xs = split_list(args, [num_consts, num_carry]) - - num_blocks, rem = divmod(length, block_length) - assert rem == 0 - - partition = partial(_partition_leading, num_blocks, block_length) - xs_block = _map(partition, x_avals, xs) - - prepend_aval = partial(_prepend_dim_to_aval, block_length) - x_block_avals = _map(prepend_aval, x_avals) - y_block_avals = _map(prepend_aval, y_avals) - - f_impl_block = partial( - _scan_impl_unrolled, reverse=reverse, length=block_length, - num_consts=num_consts, num_carry=num_carry, linear=linear, - f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) - - outs = _scan_impl_loop( - *consts, *init, *xs_block, reverse=reverse, length=num_blocks, - num_consts=num_consts, num_carry=num_carry, linear=linear, - f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals) - - carry, ys_blocks = split_list(outs, [num_carry]) - combine = partial(_combine_leading, num_blocks, block_length) - ys = _map(combine, y_avals, ys_blocks) - return (*carry, *ys) - +# TODO(mattjj): re-land #19819 version? simpler, but caused ~1 perf regression. 
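# Illustrative sketch (not part of the patch) of how the reworked _scan_impl
# below consumes `unroll`: the tracing path normalizes a bool to `length` or 1,
# and the implementation splits the trip count into full unrolled blocks plus a
# remainder handled outside the while_loop.
def unroll_plan(length, unroll):
  if isinstance(unroll, bool):
    unroll = max(length, 1) if unroll else 1
  if unroll < 1:
    raise ValueError("`unroll` must be a `bool` or a positive `int`.")
  num_trips, remainder = divmod(length, unroll)
  return unroll, num_trips, remainder

assert unroll_plan(10, 4) == (4, 2, 2)       # two blocks of 4, then 2 leftover steps
assert unroll_plan(10, True) == (10, 1, 0)   # one fully unrolled block
assert unroll_plan(10, False) == (1, 10, 0)  # plain loop, one step per trip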
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, - unroll): - _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry]) + unroll, _split_transpose): + del _split_transpose + consts, carry, xs_ = split_list(args, [num_consts, num_carry]) _, y_avals = split_list(jaxpr.out_avals, [num_carry]) - f_impl = core.jaxpr_as_fun(jaxpr) - + num_trips, remainder = divmod(length, unroll) if unroll == 1: - return _scan_impl_loop( - *args, reverse=reverse, length=length, num_consts=num_consts, - num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals, - y_avals=y_avals) - - consts, init, xs = split_list(args, [num_consts, num_carry]) - num_blocks, rem = divmod(length, unroll) - length_div = num_blocks * unroll - - if rem > 0: - if reverse: - split = partial(_split_leading_dim, rem) - xs_rem, xs = unzip2(_map(split, x_avals, xs)) - else: - split = partial(_split_leading_dim, length_div) - xs, xs_rem = unzip2(_map(split, x_avals, xs)) - - outs = _scan_impl_block_unrolled( - *consts, *init, *xs, reverse=reverse, length=length_div, - num_consts=num_consts, num_carry=num_carry, linear=linear, - block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) - - carry, ys = split_list(outs, [num_carry]) - - if rem > 0: - outs = _scan_impl_unrolled( - *consts, *carry, *xs_rem, reverse=reverse, length=rem, - num_consts=num_consts, num_carry=num_carry, linear=linear, - f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) - carry, ys_rem = split_list(outs, [num_carry]) - if reverse: - ys = _map(_concatenate, y_avals, ys_rem, ys) - else: - ys = _map(_concatenate, y_avals, ys, ys_rem) - - return (*carry, *ys) - -def _stack(aval, vals): - vals = [lax.expand_dims(x, (0,)) for x in vals] - return lax.concatenate(vals, 0) - -def _concatenate(aval, x1, x2): - return lax.concatenate([x1, x2], 0) - -def _split_leading_dim(i, aval, x): - assert x.ndim >= 1 - return (slicing.slice_in_dim(x, 0, i), - slicing.slice_in_dim(x, i, x.shape[0])) - -def _dynamic_index_array(i, aval, x): - return slicing.dynamic_index_in_dim(x, i, keepdims=False) - -def _index_array(i, aval, x): - return slicing.index_in_dim(x, i, keepdims=False) - -def _empty_array(sz, aval): - return lax.broadcast(lax.empty(aval.dtype), (sz, *aval.shape)) - -def _update_array(i, aval, xs, x): - return slicing.dynamic_update_index_in_dim(xs, x, i, 0) - -def _partition_leading(sz0, sz1, aval, x): - assert x.ndim >= 1 - assert x.shape[0] == sz0 * sz1 - return lax.reshape(x, (sz0, sz1, *x.shape[1:])) - -def _combine_leading(sz0, sz1, aval, x): - assert x.ndim >= 2 - assert x.shape[0] == sz0 - assert x.shape[1] == sz1 - return lax.collapse(x, 0, 2) + xss = xs_ + yss = _map(partial(_empty_array, (length,)), y_avals) + else: + if remainder: + if not reverse: + xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_)) + else: + xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_)) + xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_] + yss = _map(partial(_empty_array, (num_trips, unroll)), y_avals) + + def cond_fun(while_carry): + i, _, _ = while_carry + return i < num_trips + def body_fun(while_carry): + i_, carry, yss = while_carry + i = num_trips - i_ - 1 if reverse else i_ + xs = [slicing.dynamic_index_in_dim(xs, i, keepdims=False) for xs in xss] + carry, ys = inner(unroll, carry, xs) + yss = [slicing.dynamic_update_index_in_dim(ys, upd, i, 0) + for ys, upd in zip(yss, ys)] + return i_ + 1, carry, yss + def inner(n, carry, xs): + ys = [] + if unroll == 1: + carry_y = 
eval_jaxpr_p.bind(*consts, *carry, *xs, jaxpr=jaxpr) + return split_list(carry_y, [num_carry]) + for i_ in range(n): + i = n - i_ - 1 if reverse else i_ + x = [slicing.index_in_dim(x, i, keepdims=False) for x in xs] + carry_y = eval_jaxpr_p.bind(*consts, *carry, *x, jaxpr=jaxpr) + carry, y = split_list(carry_y, [num_carry]) + ys.append(y) + ys = list(reversed(ys)) if reverse else ys + return carry, _map(jax.numpy.stack, zip(*ys)) + + if num_trips: + i = lax._const(num_trips, 0) + _, carry, yss = jax.lax.while_loop(cond_fun, body_fun, (i, carry, yss)) + if unroll != 1: + ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss] + else: + ys = yss + if remainder: + carry, ys_rem = inner(remainder, carry, xs_rem) + ys = _map(_concat, ys, ys_rem) if not reverse else _map(_concat, ys_rem, ys) + return [*carry, *ys] + +def _split_leading(sz, x): + return (slicing.slice_in_dim(x, 0, sz), + slicing.slice_in_dim(x, sz, x.shape[0])) + +def _concat(a, b): return lax.concatenate([a, b], 0) + +def _empty_array(prefix, aval): + return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape)) + +eval_jaxpr_p = core.Primitive('eval_jaxpr') +eval_jaxpr_p.multiple_results = True +def _stage_jaxpr(trace, *tracers, jaxpr): + params = dict(call_jaxpr=jaxpr) + return trace.default_process_primitive(core.closed_call_p, tracers, params) +pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr +@eval_jaxpr_p.def_effectful_abstract_eval # abstract eval only used for jax2tf +def _stage_jaxpr_abstract_eval(*_, jaxpr): return jaxpr.out_avals, jaxpr.effects def _prepend_dim_to_aval(sz, aval): return core.unmapped_aval(sz, core.no_axis_name, 0, aval) def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, - linear, unroll): + linear, unroll, _split_transpose): carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) return carry_avals + ys_avals, jaxpr.effects def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, - linear, unroll): + linear, unroll, _split_transpose): num_xs = len(jaxpr.in_avals) - num_carry - num_consts num_ys = len(jaxpr.out_avals) - num_carry nonzeros = [type(t) is not ad_util.Zero for t in tangents] @@ -577,7 +522,8 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged, num_consts=num_consts + len(consts_dot), num_carry=num_carry + len(init_dot), - linear=jaxpr_jvp_linear, unroll=unroll) + linear=jaxpr_jvp_linear, unroll=unroll, + _split_transpose=_split_transpose) carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys]) primals_out = carry + ys @@ -587,7 +533,7 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, return primals_out, tangents_out def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, - jaxpr, linear, unroll): + jaxpr, linear, unroll, _split_transpose): num_ys = len(jaxpr.out_avals) - num_carry unknowns = [not t.pval.is_known() for t in tracers] const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry]) @@ -685,7 +631,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, out_known = scan_p.bind( *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known, num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk), - linear=tuple(linear_known), unroll=unroll) + linear=tuple(linear_known), unroll=unroll, + 
_split_transpose=_split_transpose) del linear_known # Complete the known output by filling in forwarded values using fwds_known. out_known_iter = iter(out_known) @@ -721,7 +668,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, dict(reverse=reverse, length=length, unroll=unroll, jaxpr=jaxpr_unknown, linear=linear_unknown, num_consts=len(intensive_res) + sum(const_uk), - num_carry=sum(carry_uk)), + num_carry=sum(carry_uk), + _split_transpose=_split_transpose), jaxpr_unknown.effects, source) for t in out_tracers: t.recipe = eqn @@ -730,17 +678,15 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, def _maybe_put(x): if isinstance(x, np.ndarray): - return dispatch._put_x( - x, - jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]), - shaped_abstractify(x), - False, - ) + aval = shaped_abstractify(x) + s = jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]) + result_handler = pxla.global_aval_to_result_handler(aval, s, False) + return result_handler(pxla.shard_args([s], [x])) else: return x def _scan_transpose(cts, *args, reverse, length, num_consts, - num_carry, jaxpr, linear, unroll): + num_carry, jaxpr, linear, unroll, _split_transpose): # we've only implemented transposing scans with specific lin/nonlin patterns consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry]) num_ires = len(consts_lin) - sum(consts_lin) @@ -774,12 +720,103 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) + [False] * num_eres) in_state = _get_states(attrs_tracked) - outs = scan_p.bind( - *ires, *in_state, *ct_consts, *ct_carry, *ct_ys, *eres, - reverse=not reverse, length=length, jaxpr=jaxpr_trans, - num_consts=num_ires, - num_carry=num_consts-num_ires+num_carry+len(attrs_tracked), - linear=tuple(linear_trans), unroll=unroll) + + transpose_inputs = *ires, *in_state, *ct_consts, *ct_carry, *ct_ys, *eres + transpose_num_out_carry = num_consts-num_ires+num_carry+len(attrs_tracked) + + if not _split_transpose: + outs = scan_p.bind( + *transpose_inputs, + reverse=not reverse, length=length, jaxpr=jaxpr_trans, + num_consts=num_ires, + num_carry=transpose_num_out_carry, + linear=tuple(linear_trans), unroll=unroll, + _split_transpose=False) + else: + inst_mask = [False] * transpose_num_out_carry + [True] * ( + len(jaxpr_trans.out_avals) - transpose_num_out_carry) + + unknowns_mask = [False] * (len(transpose_inputs) - len(eres)) + [ + True + ] * len(eres) + + # The residuals may contain original parameters (e.g. forwarded extensive + # array arguments) and residuals from the primal. Hence we iterate and + # update all values of the mask that we've set to True (i.e. 'unknown') to + # see if we should actually push them to the known computation in order to + # perform the scan (known) - map (unknown) split. The test effectively is + # done by comparing the output masks. + # + # TODO(dvytin): improve performance by doing backwards abstract eval. + # + # For example, a mask arising from a relu() is an extensive residual, yet + # only really used in the backpropagation scan, not in the unknown map. But + # an intermediate activation of a matmul will be used only in the map part. 
+ # If we were to erroneously push the relu mask to the unknown part, then, + # in the output, the partial evaluator will also pull the loop-carried state + # to the unknown, and that is something we can test by comparing the output + # mask of pe against our intended inst mask. + for index in range(len(jaxpr_trans.in_avals)): + if unknowns_mask[index]: + mask_for_dependence = [False]*len(jaxpr_trans.in_avals) + mask_for_dependence[index] = True # try moving this to unknown + _, _, outs_for_dependence, _ = pe.partial_eval_jaxpr_nounits( + jaxpr_trans, mask_for_dependence, inst_mask) + if inst_mask != outs_for_dependence: + unknowns_mask[index] = False + + jaxpr_known_body, jaxpr_unknown_body, outs_mask, res_avals = ( + pe.partial_eval_jaxpr_nounits(jaxpr_trans, unknowns_mask, inst_mask) + ) + + num_knowns = len(outs_mask) - sum(outs_mask) + + linear_list = list(linear_trans) + known_linear = [ + l for mask, l in zip(unknowns_mask, linear_list) if not mask + ] + unknown_linear = [l for mask, l in zip(unknowns_mask, linear_list) if mask] + unknown_linear = [False] * len(res_avals) + unknown_linear + + known_args = [ + arg for mask, arg in zip(unknowns_mask, transpose_inputs) if not mask + ] + unknown_args = [ + arg for mask, arg in zip(unknowns_mask, transpose_inputs) if mask + ] + # 1. Apply the known scan. + knowns_and_residual = scan_p.bind( + *known_args, + reverse=not reverse, + length=length, + num_consts=num_ires, + num_carry=transpose_num_out_carry, + jaxpr=jaxpr_known_body, + linear=tuple(known_linear), + unroll=unroll, + _split_transpose=False, # Just generate the loop now. + ) + known_results, residuals = split_list(knowns_and_residual, [num_knowns]) + + # 2. Apply the unknown map to residuals and unknown arguments. + unknown_results = scan_p.bind( + *residuals, *unknown_args, + reverse=reverse, # Keep reverse as is for better scheduling. + length=length, + num_consts=0, + num_carry=0, + jaxpr=jaxpr_unknown_body, + linear=tuple(unknown_linear), + unroll=unroll, + _split_transpose=False, # Just generate the loop now. 
+ ) + known_results_iter = iter(known_results) + unknown_results_iter = iter(unknown_results) + outs = [ + next(known_results_iter) if not mask else next(unknown_results_iter) + for mask in outs_mask + ] + out_state, outs = split_list(outs, [len(attrs_tracked)]) _set_states(attrs_tracked, out_state) ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry]) @@ -831,7 +868,8 @@ def transposed(*res1_cbar_bbar_res2): def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, dims, reverse, length, - jaxpr, num_consts, num_carry, linear, unroll): + jaxpr, num_consts, num_carry, linear, unroll, + _split_transpose): num_ys = len(jaxpr.out_avals) - num_carry orig_batched = [d is not batching.not_mapped for d in dims] const_batched, init_batched, xs_batched = split_list(orig_batched, [num_consts, num_carry]) @@ -872,7 +910,8 @@ def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, outs = scan_p.bind( *new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched, - num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll) + num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll, + _split_transpose=_split_transpose) carry_bdims = [0 if b else batching.not_mapped for b in carry_batched] ys_bdims = [1 if b else batching.not_mapped for b in ys_batched] return outs, carry_bdims + ys_bdims @@ -1031,7 +1070,8 @@ def known(*ins_known): return eqn_known, eqn_staged, unks_out, inst_out, new_vars def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, - num_carry, jaxpr, linear, unroll): + num_carry, jaxpr, linear, unroll, _split_transpose): + del _split_transpose if not bind_time: _, *in_atoms = in_atoms avals = [x.aval for x in in_atoms] @@ -1091,61 +1131,81 @@ def _scan_pp_rule(eqn, context, settings): del printed_params['num_consts'] if not printed_params['reverse']: del printed_params['reverse'] + if not printed_params['_split_transpose']: + del printed_params['_split_transpose'] return core._pp_eqn(eqn.replace(params=printed_params), context, settings) def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts, - num_carry, linear, unroll, reverse, length): - jaxpr, consts = jaxpr.jaxpr, jaxpr.consts + num_carry, linear, unroll, reverse, length, + _split_transpose): + # We're shuffling parameters between three signatures for the scan body: + # jaxpr : (n_consts, n_carry, n_xs) -> (n_carry, n_ys) + # discharged : (n_consts, n_carry, n_xs) -> (n_carry, n_ys, n_ref_consts, n_ref_xs) + # wrapped : (n_val_consts, (n_ref_consts, n_carry), (n_val_xs, n_ref_xs)) + # -> ((n_ref_consts, n_carry), (n_ys, n_ref_xs)) + # where we partition consts and xs between ref and non-ref versions: + # n_carry = (n_val_consts, n_ref_consts) + # n_xs = (n_val_xs, n_ref_xs) + + # avals from jaxpr (i.e. 
rank-reduced) rather than from caller + jaxpr, in_avals, out_avals, consts = jaxpr.jaxpr, jaxpr.in_avals, jaxpr.out_avals, jaxpr.consts if consts: raise NotImplementedError - consts, carry, xs = split_list(args, [num_consts, num_carry]) - consts_linear, carry_linear, xs_linear = split_list( - linear, [num_consts, num_carry]) - consts_avals, carry_avals, xs_avals = split_list(in_avals, - [num_consts, num_carry]) - is_ref = [isinstance(a, state.AbstractRef) for a in consts_avals] - remaining_const_avals, in_ref_avals = partition_list(is_ref, consts_avals) - remaining_consts, in_refs = partition_list(is_ref, consts) - remaining_consts_linear, in_refs_linear = partition_list(is_ref, consts_linear) - num_refs = sum(is_ref) - num_extensive_in = len(in_avals) - num_carry - num_consts - num_extensive_out = len(out_avals) - num_carry - num_remaining_consts = num_consts - num_refs + n_consts = num_consts + n_carry = num_carry + n_xs = len(in_avals) - n_consts - n_carry + n_ys = len(out_avals) - n_carry + consts_avals, carry_avals, xs_avals = split_list_checked(in_avals, + [n_consts, n_carry, n_xs]) + is_ref_const = [isinstance(a, state.AbstractRef) for a in consts_avals] + assert not any(isinstance(a, state.AbstractRef) for a in carry_avals) + is_ref_xs = [isinstance(a, state.AbstractRef) for a in xs_avals] + n_ref_consts = sum(is_ref_const) + n_val_consts = n_consts - n_ref_consts + n_ref_xs = sum(is_ref_xs) + n_val_xs = n_xs - n_ref_xs discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ()) if discharged_consts: raise NotImplementedError("Discharged jaxpr has consts. If you see this, " "please open an issue at " "https://github.com/google/jax/issues") - # The discharged jaxpr will have output refs stashed at the end - def wrapped(*refs_and_args): - consts, refs, carry, xs = split_list(refs_and_args, [num_remaining_consts, - num_refs, - num_carry]) - consts_with_refs = merge_lists(is_ref, consts, refs) - outs_and_refs = core.eval_jaxpr(discharged_jaxpr, (), *consts_with_refs, - *carry, *xs) - carry, ys, out_refs = split_list(outs_and_refs, [num_carry, - num_extensive_out]) - assert len(out_refs) == num_refs - return [*out_refs, *carry, *ys] - new_in_avals = [*remaining_const_avals, *[a.inner_aval for a in in_ref_avals], - *carry_avals, - *[core.mapped_aval(length, 0, a) for a in xs_avals]] - new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), new_in_avals) - new_linear = (*remaining_consts_linear, *in_refs_linear, - *carry_linear, *xs_linear) - all_out = scan_p.bind(*remaining_consts, *in_refs, *carry, *xs, + def wrapped(*wrapped_args): + val_consts, ref_consts_in, carry_in, val_xs, ref_xs_in = split_list_checked(wrapped_args, + [n_val_consts, n_ref_consts, n_carry, n_val_xs, n_ref_xs]) + consts = merge_lists(is_ref_const, val_consts, ref_consts_in) + xs = merge_lists(is_ref_xs, val_xs, ref_xs_in) + outs = core.eval_jaxpr(discharged_jaxpr, (), *consts, *carry_in, *xs) + carry_out, ys, ref_consts_out, ref_xs_out = split_list_checked(outs, + [n_carry, n_ys, n_ref_consts, n_ref_xs]) + return [*ref_consts_out, *carry_out, *ys, *ref_xs_out] + + def arrange_jaxpr_args_for_wrapped(args): + consts, carry_in, xs = split_list_checked(args, [n_consts, n_carry, n_xs]) + val_consts, ref_consts_in = partition_list(is_ref_const, consts) + val_xs, ref_xs_in = partition_list(is_ref_xs, xs) + return *val_consts, *ref_consts_in, *carry_in, *val_xs, *ref_xs_in + + args_for_wrapped = arrange_jaxpr_args_for_wrapped(args) + linear_for_wrapped = 
arrange_jaxpr_args_for_wrapped(linear) + avals_for_wrapped = arrange_jaxpr_args_for_wrapped(in_avals) + avals_for_wrapped_no_refs = [aval.inner_aval if isinstance(aval, state.AbstractRef) else aval + for aval in avals_for_wrapped] + new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), avals_for_wrapped_no_refs) + all_out = scan_p.bind(*args_for_wrapped, jaxpr=core.ClosedJaxpr(new_jaxpr, ()), length=length, - num_consts=num_remaining_consts, - num_carry=num_refs + num_carry, + num_consts=n_val_consts, + num_carry=n_ref_consts + n_carry, unroll=unroll, reverse=reverse, - linear=new_linear) - refs_out, carry_out, ys_out = split_list(all_out, [num_refs, num_carry]) - new_invals = [*merge_lists(is_ref, [None] * num_remaining_consts, refs_out), - *[None] * num_carry, *[None] * num_extensive_in] - assert len(new_invals) == len(in_avals) - return new_invals, [*carry_out, *ys_out] + linear=linear_for_wrapped, _split_transpose=_split_transpose) + ref_consts_out, carry_out, ys, ref_xs_out = split_list_checked(all_out, + [n_ref_consts, n_carry, n_ys, n_ref_xs]) + refs_out_matching_in_avals = [ + *merge_lists(is_ref_const, [None] * n_val_consts, ref_consts_out), + *[None] * n_carry, + *merge_lists(is_ref_xs, [None] * n_val_xs, ref_xs_out)] + assert len(refs_out_matching_in_avals) == len(in_avals) + return refs_out_matching_in_avals, [*carry_out, *ys] def scan_bind(*args, **params): if config.enable_checks.value: @@ -1176,6 +1236,11 @@ def scan_bind(*args, **params): # TODO(mattjj,frostig): un-comment this pp rule # core.pp_eqn_rules[scan_p] = _scan_pp_rule +def _propagate_mem_kind_scan(*xm, reverse, length, num_consts, num_carry, jaxpr, + linear, unroll, _split_transpose): + return pxla.get_out_memory_kinds_via_propagation(jaxpr) +pxla.memory_kind_propagate_rule[scan_p] = _propagate_mem_kind_scan + ### while_loop @api_boundary @@ -1322,10 +1387,8 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, dims, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): from jax._src.callback import _IOEffect, _OrderedIOEffect - if any(eff in branch.effects for eff in [_IOEffect, _OrderedIOEffect] - for branch in [body_jaxpr, cond_jaxpr]): - raise NotImplementedError( - "IO effect not supported in vmap-of-while.") + if any(_OrderedIOEffect in fn.effects for fn in [body_jaxpr, cond_jaxpr]): + raise Exception("Ordered IO effects not supported in vmap.") orig_batched = [d is not batching.not_mapped for d in dims] cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts]) @@ -1356,6 +1419,9 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, if pred_bat: # If the predicate is batched, we have to batch *all* of the carry # regardless of if the body needs it. + if any(_IOEffect in fn.effects for fn in [body_jaxpr, cond_jaxpr]): + raise Exception("Unordered IO effects not supported in while_loop " + "with batched predicate") carry_bat = [True] * len(carry_bat) carry_dims = [0] * len(carry_bat) body_jaxpr_batched, _ = batching.batch_jaxpr_axes( @@ -1482,7 +1548,7 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: # carry to unknown. We need one last iteration to prepare the jaxpr. 
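# Illustrative sketch (not part of the patch) of the fixpoint idiom used in the
# loop below (and in the scan partial-eval rule above): re-run partial
# evaluation with the current carry "unknown" flags, widening them with the
# outputs, until they stabilize; the bound 1 + len(carry_uk) guarantees
# termination since each non-final pass marks at least one more carry unknown.
# `partial_eval_step` is a hypothetical stand-in for the pe call.
def carry_fixpoint(partial_eval_step, carry_uk):
  for _ in range(1 + len(carry_uk)):
    carry_out_uk = partial_eval_step(carry_uk)
    if carry_out_uk == carry_uk:
      return carry_uk
    carry_uk = [a or b for a, b in zip(carry_uk, carry_out_uk)]
  raise AssertionError("Fixpoint not reached")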
carry_uk = carry_init_uk for _ in range(1 + len(carry_uk)): - body_jaxpr_known, _, carry_out_uk, body_res_avals = pe.partial_eval_jaxpr_nounits( # type: ignore + body_jaxpr_known, _, carry_out_uk, body_res_avals = pe.partial_eval_jaxpr_nounits( body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk) if carry_out_uk == carry_uk: break @@ -1491,7 +1557,7 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: else: assert False, "Fixpoint not reached" - cond_jaxpr_known, _, cond_uk, _ = pe.partial_eval_jaxpr_nounits( # type: ignore + cond_jaxpr_known, _, cond_uk, _ = pe.partial_eval_jaxpr_nounits( cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False) if cond_uk[0] or all(not uk for uk in unknowns) or all(unknowns): @@ -2046,18 +2112,20 @@ def map(f, xs): return ys def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm): - """Calls RBG in a loop and stacks the results.""" - key, = batched_args + keys, = batched_args bd, = batch_dims if bd is batching.not_mapped: - return lax.rng_bit_generator_p.bind(key, shape=shape, dtype=dtype, + return lax.rng_bit_generator_p.bind(keys, shape=shape, dtype=dtype, algorithm=algorithm), (None, None) - key = batching.moveaxis(key, bd, 0) - map_body = lambda k: lax.rng_bit_generator_p.bind(k, shape=shape, dtype=dtype, algorithm=algorithm) - stacked_keys, stacked_bits = map(map_body, key) - return (stacked_keys, stacked_bits), (0, 0) + keys = batching.moveaxis(keys, bd, 0) + batch_size = keys.shape[0] + key = keys[0] + new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape), + dtype=dtype, algorithm=algorithm) + new_keys = jax.lax.dynamic_update_index_in_dim(keys, new_key, 0, axis=0) + return (new_keys, bits), (0, 0) -batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule # type: ignore +batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule ### associative_scan @@ -2101,7 +2169,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0): Example 2: partial products of an array of matrices - >>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2)) + >>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2)) >>> partial_prods = lax.associative_scan(jnp.matmul, mats) >>> partial_prods.shape (4, 2, 2) @@ -2301,16 +2369,13 @@ def register_lowering(fn, platform=None): mlir.cache_lowering(mlir.lower_fun(fn, multiple_results=False)), platform=platform) - # Default for platforms not treated specially below. - register_lowering(partial(associative_scan, reduce_fn)) - # On GPU, we choose between window reduction and associative scan - # based on the input size. - for platform in ['cuda', 'rocm']: - register_lowering( - partial(cumred_gpu_impl, reduce_window_fn, reduce_fn), platform) - # On TPU, an implementation using reduce_window is handled specially by the - # compiler and is efficient. On other backends, it is O(n^2). - register_lowering(partial(cumred_reduce_window_impl, reduce_window_fn), 'tpu') + # For jax-metal, until reduce_window legalization is better supported. + register_lowering(partial(associative_scan, reduce_fn), 'METAL') + # In XLA, there's a rewriter for an O(N^2) reduce-window implementation. 
+ register_lowering( + partial(cumred_reduce_window_impl, reduce_window_fn) + ) + return reducer_p cumsum_p = _cumulative_reduction_primitive("cumsum", lax.add, windowed_reductions._reduce_window_sum) diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 734037415dea..4d55907f6b37 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -49,7 +49,7 @@ def _split_root_args(args, const_lengths): @api_boundary def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False): - """Differentiably solve for a roots of a function. + """Differentiably solve for the roots of a function. This is a low-level routine, mostly intended for internal use in JAX. Gradients of custom_root() are defined with respect to closed-over variables diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 423c44773dad..2b2ad5bbb515 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -141,10 +141,10 @@ def conv_general_dilated( rhs_dilation = (1,) * (rhs.ndim - 2) if isinstance(padding, str): lhs_perm, rhs_perm, _ = dnums - rhs_shape = np.take(rhs.shape, rhs_perm)[2:] # type: ignore[index] + rhs_shape = np.take(rhs.shape, rhs_perm)[2:] effective_rhs_shape = [core.dilate_dim(k, r) for k, r in zip(rhs_shape, rhs_dilation)] padding = lax.padtype_to_pads( - np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape, # type: ignore[index] + np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape, window_strides, padding) else: try: @@ -328,7 +328,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], raise ValueError('No 4+ dimensional dimension_number defaults.') dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers) k_shape = np.take(rhs.shape, dn.rhs_spec) - k_sdims = k_shape[2:] # type: ignore[index] + k_sdims = k_shape[2:] # Calculate correct output shape given padding and strides. 
pads: str | Sequence[tuple[int, int]] if isinstance(padding, str) and padding in {'SAME', 'VALID'}: @@ -411,7 +411,7 @@ def _conv_general_dilated_shape_rule( rhs_trans = lax._dilate_shape(np.take(rhs.shape, rhs_perm), rhs_dilation) out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding, batch_group_count) - return tuple(np.take(out_trans, np.argsort(out_perm))) # type: ignore[arg-type] + return tuple(np.take(out_trans, np.argsort(out_perm))) def _conv_general_dilated_dtype_rule( @@ -719,10 +719,10 @@ def _conv_general_dilated_lower( dimension_numbers=dnums, feature_group_count=mlir.i64_attr(feature_group_count), batch_group_count=mlir.i64_attr(batch_group_count), - window_strides=mlir.dense_int_array_v6(window_strides), + window_strides=mlir.dense_int_array(window_strides), padding=mlir.dense_int_elements(padding), - lhs_dilation=mlir.dense_int_array_v6(lhs_dilation), - rhs_dilation=mlir.dense_int_array_v6(rhs_dilation), + lhs_dilation=mlir.dense_int_array(lhs_dilation), + rhs_dilation=mlir.dense_int_array(rhs_dilation), window_reversal=window_reversal, precision_config=lax.precision_attr(precision)) ] @@ -744,9 +744,9 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]): dimension_numbers=dnums, feature_group_count=mlir.i64_attr(feature_group_count), batch_group_count=mlir.i64_attr(batch_group_count), - window_strides=mlir.dense_int_array_v6(window_strides), - lhs_dilation=mlir.dense_int_array_v6(lhs_dilation), - rhs_dilation=mlir.dense_int_array_v6(rhs_dilation), + window_strides=mlir.dense_int_array(window_strides), + lhs_dilation=mlir.dense_int_array(lhs_dilation), + rhs_dilation=mlir.dense_int_array(rhs_dilation), window_reversal=window_reversal, precision_config=lax.precision_attr(precision)) ] diff --git a/jax/_src/lax/eigh.py b/jax/_src/lax/eigh.py index 58adb90fa943..fc66b0f2e7ee 100644 --- a/jax/_src/lax/eigh.py +++ b/jax/_src/lax/eigh.py @@ -554,6 +554,7 @@ def eigh( if compute_slice: eig_vals = eig_vals[subset_by_index[0] : subset_by_index[1]] eig_vecs = eig_vecs[:, subset_by_index[0] : subset_by_index[1]] + return eig_vals, eig_vecs n = N if n is None else n with jax.default_matmul_precision(precision): diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 3cefd45b7e85..a1cce3500df1 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -30,8 +30,6 @@ from jax._src.interpreters import mlir from jax._src.lib.mlir.dialects import hlo from jax._src.lib import xla_client -from jax._src.lib import xla_extension_version -from jax._src.lib import ducc_fft from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact __all__ = [ @@ -122,76 +120,6 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths): ] -def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths): - x_aval, = ctx.avals_in - - in_shape = x_aval.shape - dtype = x_aval.dtype - out_aval, = ctx.avals_out - out_shape = out_aval.shape - - forward = fft_type in (xla_client.FftType.FFT, xla_client.FftType.RFFT) - ndims = len(in_shape) - assert len(fft_lengths) >= 1 - assert len(fft_lengths) <= ndims, (fft_lengths, ndims) - assert len(in_shape) == len(out_shape) == ndims - - # PocketFft does not allow size 0 dimensions. 
- if 0 in in_shape or 0 in out_shape: - if fft_type == xla_client.FftType.RFFT: - assert dtype in (np.float32, np.float64), dtype - out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128) - - elif fft_type == xla_client.FftType.IRFFT: - assert np.issubdtype(dtype, np.complexfloating), dtype - out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64) - - else: - assert np.issubdtype(dtype, np.complexfloating), dtype - out_dtype = dtype - - zero = mlir.ir_constant(np.array(0, dtype=out_dtype)) - return [ - mlir.broadcast_in_dim(ctx, zero, out_aval, broadcast_dimensions=[])] - - strides_in = [] - stride = 1 - for d in reversed(in_shape): - strides_in.append(stride) - stride *= d - strides_in = mlir.shape_tensor( - mlir.eval_dynamic_shape(ctx, tuple(reversed(strides_in)))) - - strides_out = [] - stride = 1 - for d in reversed(out_shape): - strides_out.append(stride) - stride *= d - strides_out = mlir.shape_tensor( - mlir.eval_dynamic_shape(ctx, tuple(reversed(strides_out)))) - - # scale = 1. if forward else (1. / np.prod(fft_lengths)) as a f64[1] tensor - double_type = mlir.ir.RankedTensorType.get((), mlir.ir.F64Type.get()) - size_fft_length_prod = np.prod(fft_lengths) if fft_lengths else 1 - size_fft_lengths, = mlir.eval_dynamic_shape_as_vals(ctx, (size_fft_length_prod,)) - size_fft_lengths = hlo.ConvertOp(double_type, size_fft_lengths) - one = mlir.ir_constant(np.float64(1.)) - scale = one if forward else hlo.DivOp(one, size_fft_lengths) - scale = hlo.ReshapeOp( - mlir.ir.RankedTensorType.get((1,), mlir.ir.F64Type.get()), - scale).result - - in_shape = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, in_shape)) - out_shape = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, out_shape)) - in_shape = in_shape if fft_type != xla_client.FftType.IRFFT else out_shape - - result_type = mlir.aval_to_ir_type(out_aval) - return [ducc_fft.dynamic_ducc_fft_hlo( - result_type, x, - input_dtype=x_aval.dtype, ndims=ndims, input_shape=in_shape, - strides_in=strides_in, strides_out=strides_out, scale=scale, - fft_type=fft_type, fft_lengths=fft_lengths, result_shape=out_shape)] - def _naive_rfft(x, fft_lengths): y = fft(x, xla_client.FftType.FFT, fft_lengths) n = fft_lengths[-1] @@ -253,8 +181,3 @@ def _fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths): mlir.register_lowering(fft_p, _fft_lowering) ad.deflinear2(fft_p, _fft_transpose_rule) batching.primitive_batchers[fft_p] = _fft_batching_rule - -# TODO(phawkins): when jaxlib 0.4.21 is the minimum, use XLA's FFT lowering -# always on CPU. At that point, we can also delete the DUCC FFT kernel from JAX. 
-if xla_extension_version < 211: - mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu') diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index bc321e3306cf..ed6a677db64f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -15,14 +15,14 @@ from __future__ import annotations import builtins -from collections.abc import Sequence +from collections.abc import Callable, Sequence import enum import functools from functools import partial import itertools import math import operator -from typing import Any, Callable, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING +from typing import Any, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING import warnings import numpy as np @@ -44,6 +44,7 @@ from jax._src import linear_util as lu from jax._src import pretty_printer as pp from jax._src import source_info_util +from jax._src import state from jax._src import util from jax._src.abstract_arrays import array_types from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray, @@ -66,7 +67,7 @@ from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo from jax._src.sharding_impls import PmapSharding -from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike, Shape +from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis, split_list, NumpyComplexWarning) @@ -84,9 +85,9 @@ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip -def _clip_int_to_valid_range(val: int, dtype) -> int: +def _clip_int_to_valid_range(val: DimSize, dtype) -> int: info = np.iinfo(dtype) - return builtins.max(info.min, builtins.min(int(val), info.max)) + return core.max_dim(info.min, core.min_dim(val, info.max)) def _validate_shapes(shapes: Sequence[Shape]): def _check_static_shape(shape: Shape): @@ -518,7 +519,7 @@ def convert_element_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: def _convert_element_type(operand: ArrayLike, new_dtype: DTypeLike | None = None, weak_type: bool = False): if hasattr(operand, '__jax_array__'): - operand = operand.__jax_array__() # type: ignore + operand = operand.__jax_array__() if (dtypes.issubdtype(new_dtype, dtypes.extended) or dtypes.issubdtype(getattr(operand, 'dtype', None), dtypes.extended)): @@ -624,14 +625,14 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: _precision_strings: dict[Any, Precision] = {} -# TODO(b/328046715): pytype appears unable to handle overriding __new__ in an -# enum class. Doing this crashes Pytype. For now, just write an explicit type -# for type checkers. +# TODO(b/333851820): pytype does not properly handle _missing_ in enums. +# We work around that by defining `Precision` as a normal class. if TYPE_CHECKING: + class Precision: - DEFAULT: Precision - HIGH: Precision - HIGHEST: Precision + DEFAULT: ClassVar[Precision] + HIGH: ClassVar[Precision] + HIGHEST: ClassVar[Precision] def __new__(cls, value: Precision | int | str | None) -> Precision: raise NotImplementedError @@ -645,42 +646,44 @@ def value(self) -> int: raise NotImplementedError else: + class Precision(enum.Enum): - """Precision enum for lax functions + """Precision enum for lax matrix multiply related functions. - The `precision` argument to JAX functions generally controls the tradeoff - between speed and accuracy for array computations on accelerator backends, - (i.e. TPU and GPU). 
Members are: + The device-dependent `precision` argument to JAX functions generally + controls the tradeoff between speed and accuracy for array computations on + accelerator backends, (i.e. TPU and GPU). Has no impact on CPU backends. + This only has an effect on float32 computations, and does not affect the + input/output datatypes. Members are: DEFAULT: - Fastest mode, but least accurate. Performs computations in bfloat16. - Aliases: ``'default'``, ``'fastest'``, ``'bfloat16'``. + Fastest mode, but least accurate. On TPU: performs float32 computations in + bfloat16. On GPU: uses tensorfloat32 if available (e.g. on A100 and H100 + GPUs), otherwise standard float32 (e.g. on V100 GPUs). Aliases: + ``'default'``, ``'fastest'``. HIGH: - Slower but more accurate. Performs float32 computations in 3 bfloat16 - passes, or using tensorfloat32 where available. Aliases: ``'high'``, - ``'bfloat16_3x'``, ``'tensorfloat32'``. + Slower but more accurate. On TPU: performs float32 computations in 3 + bfloat16 passes. On GPU: uses tensorfloat32 where available, otherwise + float32. Aliases: ``'high'``.. HIGHEST: - Slowest but most accurate. Performs computations in float32 or float64 - as applicable. Aliases: ``'highest'``, ``'float32'``. + Slowest but most accurate. On TPU: performs float32 computations in 6 + bfloat16. Aliases: ``'highest'``. On GPU: uses float32. """ + DEFAULT = 0 HIGH = 1 HIGHEST = 2 + @classmethod + def _missing_(cls, value: object) -> Precision | None: + return _precision_strings.get(value) + def __repr__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" + return f'{self.__class__.__name__}.{self.name}' def __str__(self) -> str: return self.name - # You can't define __new__ on an enum class directly, but you can monkey-patch - # it after the fact. Another way to do this might be using a metaclass. - def _precision_new(cls, value: Precision | int | str | None) -> Precision: - return super(Precision, cls).__new__(cls, _precision_strings.get(value, value)) - - Precision.__new__ = _precision_new - - _precision_strings['highest'] = Precision.HIGHEST _precision_strings['float32'] = Precision.HIGHEST @@ -785,6 +788,33 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN precision=canonicalize_precision(precision), preferred_element_type=preferred_element_type) + +def ragged_dot( + lhs: Array, + rhs: Array, + group_sizes: Array, + precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None, + group_offset: Array | None = None, + ) -> Array: + """Ragged matrix multiplication. + + Args: + lhs: (m, k) shaped array. + rhs: (g, k, n) shaped array. + group_sizes: (g,) shaped array with integer element type, where g denotes number of groups. The ith element indicates the size of ith group. + precision: Optional. Consistent with precision argument for :func:`jax.lax.dot`. + preferred_element_type: Optional. Consistent with precision argument for :func:`jax.lax.dot`. + group_offset: Optional. (1,) shaped array that indicates the group in group_sizes to start computing from. If not specified, defaults to [0]. + + Results: + (m, n) shaped array with preferred_element_type element type. 
+ """ + return ragged_dot_p.bind(lhs, rhs, group_sizes, + precision=canonicalize_precision(precision), + preferred_element_type=preferred_element_type, group_offset=group_offset) + + def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array: """Broadcasts an array, adding new leading dimensions @@ -1206,12 +1236,24 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]: k: integer specifying the number of top entries. Returns: - values: array containing the top k values along the last axis. - indices: array containing the indices corresponding to values. + A tuple ``(values, indices)`` where + + - ``values`` is an array containing the top k values along the last axis. + - ``indices`` is an array containing the indices corresponding to values. See also: - - :func:`jax.lax.approx_max_k` - - :func:`jax.lax.approx_min_k` + - :func:`jax.lax.approx_max_k` + - :func:`jax.lax.approx_min_k` + + Examples: + Find the largest three values, and their indices, within an array: + + >>> x = jnp.array([9., 3., 6., 4., 10.]) + >>> values, indices = jax.lax.top_k(x, 3) + >>> values + Array([10., 9., 6.], dtype=float32) + >>> indices + Array([4, 0, 2], dtype=int32) """ if core.is_constant_dim(k): k = int(k) @@ -1233,8 +1275,8 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, dtype: the type of the output array, or `None`. If not `None`, `fill_value` will be cast to `dtype`. sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. + note, sharding will currently be ignored in jitted mode, this might change + in the future. """ shape = canonicalize_shape(shape) if np.shape(fill_value): @@ -1251,15 +1293,12 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, # if needed? 
if (sharding is not None and not isinstance(sharding, PmapSharding) and isinstance(fill_value, array.ArrayImpl)): - broadcast_shape = sharding.shard_shape(shape) shard = broadcast(fill_value, broadcast_shape) return array.make_array_from_callback(shape, sharding, lambda _: shard) return broadcast(fill_value, shape) - - def zeros_like_shaped_array(aval: ShapedArray) -> Array: assert isinstance(aval, ShapedArray) if dtypes.issubdtype(aval.dtype, dtypes.extended): @@ -1272,6 +1311,12 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array: ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array +def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.MutableArray: + val = ad_util.zeros_like_aval(aval.inner_aval) + return core.mutable_array(val) + +ad_util.aval_zeros_likers[state.AbstractRef] = zeros_like_abstract_ref # type: ignore + def iota(dtype: DTypeLike, size: int) -> Array: """Wraps XLA's `Iota `_ @@ -1290,7 +1335,7 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int) -> Array: return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape), dimension=dimension) -def _eye(dtype: DTypeLike, shape: Shape, offset: int) -> Array: +def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array: """Like numpy.eye, create a 2D array with ones on a diagonal.""" offset = _clip_int_to_valid_range(offset, np.int32) dtype = dtypes.canonicalize_dtype(dtype) @@ -1302,7 +1347,7 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array: """This utility function exists for creating Kronecker delta arrays.""" axes = map(int, axes) dtype = dtypes.canonicalize_dtype(dtype) - base_shape = tuple(np.take(shape, axes)) # type: ignore[arg-type] + base_shape = tuple(np.take(shape, axes)) iotas = [broadcasted_iota(np.uint32, base_shape, i) for i in range(len(base_shape))] eyes = [eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])] @@ -1310,11 +1355,12 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array: new_dtype=dtype, weak_type=False) return broadcast_in_dim(result, shape, axes) -def _tri(dtype: DTypeLike, shape: Shape, offset: int) -> Array: +def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array: """Like numpy.tri, create a 2D array with ones below a diagonal.""" offset = _clip_int_to_valid_range(offset, np.int32) dtype = dtypes.canonicalize_dtype(dtype) - bool_tri = ge(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)), + bool_tri = ge(add(broadcasted_iota(np.int32, shape, 0), + asarray(core.dimension_as_value(offset)).astype(np.int32)), broadcasted_iota(np.int32, shape, 1)) return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False) @@ -1402,7 +1448,7 @@ def full_like(x: ArrayLike | DuckTypedArray, If not specified, the output will have the same sharding as the input, with a few exceptions/limitations in particular: 1. Sharding is not available during tracing, thus this will rely on jit. - 2. If x is weakly typed or uncomitted, will use default sharding. + 2. If x is weakly typed or uncommitted, will use default sharding. 3. Shape is not None and is different from x.shape, default will be used. Returns: @@ -1415,17 +1461,22 @@ def full_like(x: ArrayLike | DuckTypedArray, if dtypes.issubdtype(dtype, dtypes.extended): return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr] + # If `x` has a sharding but no `_committed` attribute + # (in case of ShapeDtypeStruct), default it to True. 
use_x_sharding = ( - sharding is None and - isinstance(x, array.ArrayImpl) and - not weak_type and x._committed and - # NB: consider reusng x.sharding for mismatched shapes - # if x is replicated or single device. - fill_shape == x.shape) + sharding is None + # Tracer have special logic in handling sharding and even + # though hasattr(x, 'sharding') returns False, it is very slow. + # This bypasses the check. + and not isinstance(x, core.Tracer) + and hasattr(x, 'sharding') + and getattr(x, '_committed', True) + and not weak_type + and fill_shape == np.shape(x) # type: ignore[arg-type] + ) if use_x_sharding: - assert isinstance(x, array.ArrayImpl) # makes pytype happy. # TODO(yashkatariya): Use shard_alike in tracing_mode once it is supported. - sharding = x.sharding + sharding = x.sharding # type: ignore val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type), sharding=sharding) return val @@ -1560,7 +1611,7 @@ def zeros_like_array(x: ArrayLike) -> Array: def _add_arrays(x, y): if (isinstance(a := core.get_aval(x), ShapedArray) and dtypes.issubdtype(a.dtype, dtypes.extended)): - return dtype._rules.add(dtype, x, y) # type: ignore + return dtype._rules.add(dtype, x, y) # pytype: disable=attribute-error return add(x, y) for t in itertools.chain( @@ -1726,8 +1777,8 @@ def broadcast_hlo( for aval, arg in zip(avals, args): if aval.shape != aval_out.shape: assert len(aval.shape) <= len(aval_out.shape), (aval, aval_out) - dims = mlir.dense_int_array_v6( - range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape))) + dims = mlir.dense_int_array( + list(range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape)))) if any(isinstance(d, ir.Value) for d in aval_out.shape): arg = hlo.dynamic_broadcast_in_dim( mlir.aval_to_ir_type(aval_out), arg, @@ -1740,7 +1791,7 @@ def broadcast_hlo( return out def _nary_lower_hlo(op: Callable, ctx, - *args: ir.Value | Sequence[ir.Value], + *args: ir.Value, explicit_type=False, **params) -> Sequence[ir.Value]: """Lowers an elementwise operator to its MLIR equivalent. 
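# Illustrative sketch (not part of the patch) of the behaviour targeted by the
# full_like change above: with no explicit sharding, a committed, strongly
# typed input of matching shape donates its sharding to the output. Assumes the
# local devices form a usable 1-D mesh; run eagerly, since sharding is not
# consulted under tracing.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('i',))
n = jax.device_count()
x = jax.device_put(jnp.arange(8.0 * n), NamedSharding(mesh, P('i')))
y = jax.lax.full_like(x, 0.0)
assert y.sharding == x.sharding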
@@ -2424,9 +2475,9 @@ def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type): def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type): if (operand.dtype != new_dtype and ((dtypes.issubdtype(operand.dtype, dtypes.extended) and - not operand.dtype._rules.convert_from(operand.dtype, new_dtype)) or # type: ignore + not operand.dtype._rules.convert_from(operand.dtype, new_dtype)) or (dtypes.issubdtype(new_dtype, dtypes.extended) and - not new_dtype._rules.convert_to(operand.dtype, new_dtype)))): # type: ignore + not new_dtype._rules.convert_to(operand.dtype, new_dtype)))): raise ValueError( f"Cannot convert_element_type from {dtype_to_string(operand.dtype)} " f"to {dtype_to_string(new_dtype)}") @@ -2729,7 +2780,7 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, else: ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept) dims = ((ans_y, y_kept), (ans_batch, y_batch)) - x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) # type: ignore[arg-type] + x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) x_bar = transpose(dot_general(g, y, dims, precision=precision, preferred_element_type=preferred_element_type), @@ -2878,6 +2929,10 @@ def precision_attr(precision: Precision) -> ir.ArrayAttr: def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, precision, preferred_element_type: np.dtype | None, platform: str = "default"): + def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): + fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, + dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) + return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes del preferred_element_type # Implied by the output aval lhs_aval, rhs_aval = ctx.avals_in lhs_dtype, rhs_dtype = lhs_aval.dtype, rhs_aval.dtype @@ -2900,21 +2955,13 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, core.ShapedArray(rhs_aval.shape, aval_out.dtype)) lhs_dtype = rhs_dtype = aval_out.dtype else: # cpu and gpu - lhs = mlir.convert_hlo(ctx, lhs, lhs_aval, - core.ShapedArray(lhs_aval.shape, aval_out.dtype)) - rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, - core.ShapedArray(rhs_aval.shape, aval_out.dtype)) - lhs_dtype = rhs_dtype = aval_out.dtype - - # TODO(b/195364460): Work around slow XLA/CPU implementation of float16 matmul - if platform == "cpu": - if lhs_dtype == np.float16: - lhs = mlir.convert_hlo(ctx, lhs, lhs_aval, - core.ShapedArray(lhs_aval.shape, np.float32)) - - if rhs_dtype == np.float16: - rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, - core.ShapedArray(rhs_aval.shape, np.float32)) + # Do not convert mixed fp8 types to output type. 
+ if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype): + lhs = mlir.convert_hlo(ctx, lhs, lhs_aval, + core.ShapedArray(lhs_aval.shape, aval_out.dtype)) + rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, + core.ShapedArray(rhs_aval.shape, aval_out.dtype)) + lhs_dtype = rhs_dtype = aval_out.dtype dot_dnums = hlo.DotDimensionNumbers.get( @@ -2939,6 +2986,62 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, platform=platform) +def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape: + m, k = lhs.shape + group_count, rk, n = rhs.shape + if k != rk: + raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.") + num_groups = group_sizes.shape[0] + if group_count != num_groups: + raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.") + return (m, n) + +# DotDimensionNumbers used in the dot_general call for ragged_dot(). +_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (([2, 0], [1, 0]), ([], [])) + +def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array, + precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype: + if not dtypes.issubdtype(group_sizes.dtype, np.integer): + raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.") + # defer the output dtype to dot_general, which is part of the _ragged_dot_impl. + return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, precision=precision, preferred_element_type=preferred_element_type) + +ragged_dot_p = standard_primitive(_ragged_dot_shape_rule, + _ragged_dot_dtype_rule, 'ragged_dot') +ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p)) + +def _ragged_dot_impl( + lhs: Array, + rhs: Array, + group_sizes: Array, + precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None, + group_offset: Array | None = None, + ) -> Array: + if group_offset is not None: + raise NotImplementedError("Unimplemented group_offset support.") + shape = (rhs.shape[0], lhs.shape[0], lhs.shape[1]) + lhs = broadcast_in_dim(lhs, shape, [1, 2]) + iota = broadcasted_iota(group_sizes.dtype, shape, 1) + group_ends = jax.lax.cumsum(group_sizes) + group_starts = concatenate( + [_zeros(group_sizes)[:1], group_ends[:-1]], dimension=0, + ) + group_ends = broadcast_in_dim(group_ends, shape, (0,)) + group_starts = broadcast_in_dim(group_starts, shape, (0,)) + mask = bitwise_and(group_starts <= iota, iota < group_ends) + lhs = select(mask, lhs, _zeros(lhs)) + return dot_general( + lhs, + rhs, + dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, + precision=precision, + preferred_element_type=preferred_element_type, + ) + +mlir.register_lowering(ragged_dot_p, mlir.lower_fun(_ragged_dot_impl, multiple_results=False)) + + def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions): _check_shapelike('broadcast_in_dim', 'shape', shape) _check_shapelike('broadcast_in_dim', 'broadcast_dimensions', @@ -3590,7 +3693,7 @@ def _transpose_lower(ctx, x, *, permutation): transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype, 'transpose') ad.deflinear2(transpose_p, - lambda t, _, permutation: [transpose(t, np.argsort(permutation))]) # type: ignore[arg-type] + lambda t, _, permutation: [transpose(t, np.argsort(permutation))]) batching.primitive_batchers[transpose_p] = _transpose_batch_rule mlir.register_lowering(transpose_p, _transpose_lower) pe.def_trivial_padding(transpose_p) @@ -3628,8 
+3731,7 @@ def _select_transpose_rule(t, which, *cases): for c in cases] else: zeros = full_like(t, 0) - if (dtypes.dtype(which) == np.dtype(np.bool_) and - config.new_select_transpose.value): + if dtypes.dtype(which) == np.dtype(np.bool_): ct0 = select(which, zeros, t) if ad.is_undefined_primal(cases[0]) else None ct1 = select(which, t, zeros) if ad.is_undefined_primal(cases[1]) else None return (None, ct0, ct1) @@ -3868,7 +3970,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions): operands, init_values = util.split_list(values, [len(values) // 2]) init_value_avals = ctx.avals_in[len(values) // 2:] op = hlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands, init_values, mlir.dense_int_array_v6(dimensions)) + operands, init_values, mlir.dense_int_array(dimensions)) ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals] reducer = op.regions[0].blocks.append(*(ir_types + ir_types)) with ir.InsertionPoint(reducer): @@ -4079,7 +4181,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes): dtype = aval_out.dtype op = hlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x], mlir.ir_constants(unit_factory(aval_out.dtype)), - mlir.dense_int_array_v6(axes)) + mlir.dense_int_array(axes)) scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype)) reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer_region): @@ -4555,7 +4657,7 @@ def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm): return (key.weak_type, False) RandomAlgorithm = xops.RandomAlgorithm -RandomAlgorithm.__str__ = lambda algorithm: algorithm.name # type: ignore[assignment] +RandomAlgorithm.__str__ = lambda algorithm: algorithm.name # type: ignore[method-assign] def _rng_algorithm(algorithm: RandomAlgorithm): if algorithm == RandomAlgorithm.RNG_THREE_FRY: @@ -4694,7 +4796,9 @@ def _copy_impl(prim, *args, **kwargs): ad.deflinear(copy_p, lambda t: [copy_p.bind(t)]) pe.def_trivial_padding(copy_p) batching.defvectorized(copy_p) - +def _propagate_mem_kind_copy(in_mem_kind): + return in_mem_kind +pxla.memory_kind_propagate_rule[copy_p] = _propagate_mem_kind_copy def rng_bit_generator(key, shape, dtype=np.uint32, algorithm=RandomAlgorithm.RNG_DEFAULT): @@ -5004,7 +5108,7 @@ def remaining(original, *removed_lists): def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precision] | None: - """Turns an API precision specification, into a pair of enumeration values. + """Turns an API precision specification into a pair of enumeration values. The API can take the precision as a string, or int, and either as a single value to apply to both operands, or as a sequence of two values. 
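[editor's note] The `ragged_dot` primitive registered a little earlier in this file deserves a usage sketch. The public wrapper is not shown in this hunk, so the `jax.lax.ragged_dot` call below is an assumption about how the primitive is exposed; shapes follow the shape rule above: `lhs` is `[m, k]`, `rhs` is `[g, k, n]`, and `group_sizes` is `[g]` with entries summing to `m`.

```python
import jax
import jax.numpy as jnp

m, k, n, g = 16, 8, 4, 3
lhs = jnp.ones((m, k), jnp.float32)
rhs = jnp.ones((g, k, n), jnp.float32)
group_sizes = jnp.array([10, 2, 4], jnp.int32)   # rows 0-9 use rhs[0], 10-11 use rhs[1], 12-15 use rhs[2]

out = jax.lax.ragged_dot(lhs, rhs, group_sizes)  # assumed wrapper binding ragged_dot_p
assert out.shape == (m, n)
```

Each contiguous block of `group_sizes[i]` rows of `lhs` is contracted against `rhs[i]`, which is what the masked `dot_general` in `_ragged_dot_impl` computes.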
@@ -5111,18 +5215,6 @@ def handler(bufs): return core.DArray(aval, phys_handler(bufs)) return handler - @staticmethod - def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: - return hlo_sharding - - @staticmethod - def logical_sharding(aval, phys_sharding): - return phys_sharding - - @staticmethod - def physical_sharding(aval, sharding): - return sharding - @staticmethod def convert_from(bint_dtype, other_dtype) -> bool: return other_dtype in (np.dtype('int32'), np.dtype('int64')) @@ -5131,12 +5223,5 @@ def convert_from(bint_dtype, other_dtype) -> bool: def convert_to(other_dtype, bint_dtype) -> bool: return other_dtype in (np.dtype('int32'), np.dtype('int64')) - @staticmethod - def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: - return val - - @staticmethod - def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval): - pass core.bint._rules = BIntRules diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 80162a204980..d31bba99171c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -14,12 +14,11 @@ from __future__ import annotations -import inspect +from collections.abc import Callable import functools from functools import partial import math -from typing import cast, Any, Callable, Literal, TypeVar, overload -import warnings +from typing import Any, Literal, TypeVar, overload import numpy as np @@ -62,51 +61,6 @@ # traceables -# TODO(phawkins): remove backward compatibility shim after 2022/08/11. -def _warn_on_positional_kwargs(f: TFun) -> TFun: - """Decorator used for backward compatibility of keyword-only arguments. - - Some functions were changed to mark their keyword arguments as keyword-only. - This decorator allows existing code to keep working temporarily, while issuing - a warning if a now keyword-only parameter is passed positionally.""" - sig = inspect.signature(f) - pos_names = [name for name, p in sig.parameters.items() - if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD] - kwarg_names = [name for name, p in sig.parameters.items() - if p.kind == inspect.Parameter.KEYWORD_ONLY] - - # This decorator assumes that all arguments to `f` are either - # positional-or-keyword or keyword-only. - assert len(pos_names) + len(kwarg_names) == len(sig.parameters) - - @functools.wraps(f) - def wrapped(*args, **kwargs): - if len(args) < len(pos_names): - a = pos_names[len(args)] - raise TypeError(f"{f.__name__} missing required positional argument: {a}") - - pos_args = args[:len(pos_names)] - extra_kwargs = args[len(pos_names):] - - if len(extra_kwargs) > len(kwarg_names): - raise TypeError(f"{f.__name__} takes at most {len(sig.parameters)} " - f" arguments but {len(args)} were given.") - - for name, value in zip(kwarg_names, extra_kwargs): - if name in kwargs: - raise TypeError(f"{f.__name__} got multiple values for argument: " - f"{name}") - - warnings.warn(f"Argument {name} to {f.__name__} is now a keyword-only " - "argument. Support for passing it positionally will be " - "removed in an upcoming JAX release.", - DeprecationWarning) - kwargs[name] = value - return f(*pos_args, **kwargs) - - return cast(TFun, wrapped) - -@_warn_on_positional_kwargs def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array: """Cholesky decomposition. 
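[editor's note] With the `_warn_on_positional_kwargs` shim deleted above, the keyword-only markers in these linalg wrappers are now enforced by Python itself rather than by a deprecation warning. A minimal sketch of the behavioural change (the exception text is Python's, not JAX's):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.eye(3)
q, r = lax.linalg.qr(x, full_matrices=False)   # keyword form keeps working

try:
    lax.linalg.qr(x, False)                    # previously a DeprecationWarning via the shim
except TypeError as e:
    print("positional full_matrices is now rejected:", e)
```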
@@ -136,7 +90,7 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array: x = symmetrize(x) return jnp.tril(cholesky_p.bind(x)) -@_warn_on_positional_kwargs + def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, compute_right_eigenvectors: bool = True) -> list[Array]: """Eigendecomposition of a general matrix. @@ -162,7 +116,6 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, compute_right_eigenvectors=compute_right_eigenvectors) -@_warn_on_positional_kwargs def eigh( x: Array, *, @@ -216,6 +169,20 @@ def eigh( return v, w +def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array: + """Given a Cholesky decomposition A = R.T @ R and a vector w, + computes the Cholesky decomposition of A + w @ w.T in O(N^2) time. + + Args: + r_matrix: An upper-triangular matrix (R) such that A = R.T @ R. + w_vector: A vector (w) for rank-1 update. + + Returns: + A new R' matrix being the Cholesky decomposition of A + w @ w.T. + """ + return cholesky_update_p.bind(r_matrix, w_vector) + + def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array: """Converts the pivots (row swaps) returned by LU to a permutation. @@ -267,7 +234,7 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]: lu, pivots, permutation = lu_p.bind(x) return lu, pivots, permutation -@_warn_on_positional_kwargs + def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]: """QR decomposition. @@ -333,7 +300,6 @@ def svd( # TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD. -@_warn_on_positional_kwargs def svd( x: ArrayLike, *, @@ -361,7 +327,6 @@ def svd( return s -@_warn_on_positional_kwargs def triangular_solve(a: ArrayLike, b: ArrayLike, *, left_side: bool = False, lower: bool = False, transpose_a: bool = False, conjugate_a: bool = False, @@ -516,6 +481,81 @@ def _cholesky_cpu_lowering(ctx, operand): mlir.register_lowering( cholesky_p, _cholesky_cpu_lowering, platform='cpu') +# Cholesky update + +def _cholesky_update_abstract_eval(r_matrix, w_vector): + r_dtype = dtypes.canonicalize_dtype(r_matrix.dtype) + w_dtype = dtypes.canonicalize_dtype(w_vector.dtype) + if not (r_dtype == w_dtype and r_dtype in (np.float32, np.float64)): + raise NotImplementedError( + "Rank-1 Cholesky update is only implemented for float32 and float64.") + if not (r_matrix.ndim == 2 and w_vector.ndim == 1 + and r_matrix.shape[-2] == r_matrix.shape[-1] + and r_matrix.shape[-2] == w_vector.shape[-1]): + raise ValueError( + "Rank-1 update to Cholesky decomposition takes a square matrix " + "and a vector as inputs. Got shapes {}, {} instead".format( + r_matrix.shape, w_vector.shape)) + return ShapedArray(r_matrix.shape, r_matrix.dtype) + +def _cholesky_update_cuda_lowering_rule(ctx, r_matrix, w_vector): + r_matrix_aval, _ = ctx.avals_in + try: + [platform] = ctx.module_context.platforms + except ValueError: + raise ValueError( + "Can only lower cholesky_update on a single platform." + ) from None + if platform != "cuda": + raise NotImplementedError( + "Can only lower fast cholesky_update on CUDA." 
+ ) + return gpu_linalg.cuda_cholesky_update( + r_matrix, w_vector, r_matrix_aval.dtype) + + +def _cholesky_update_jax_fn(R, z): + def _drotg(x, y): + """Get coefs for Givens rotation in a numerically stable way.""" + def _drotg_nonzero(x, y): + abs_x = jax.numpy.abs(x) + abs_y = jax.numpy.abs(y) + denominator = jnp.where(abs_x > abs_y, abs_x, abs_y) + x /= denominator + y /= denominator + rh = 1 / jax.numpy.sqrt(x ** 2 + y ** 2) + return x * rh, -y * rh + one_and_zero = ( + jnp.array(1., dtype=x.dtype), + jnp.array(0., dtype=x.dtype), + ) + return jax.lax.cond(y == 0, lambda x, y: one_and_zero, _drotg_nonzero, x, y) + + def _drot( + first_vector: jax.Array, second_vector: jax.Array, + c_coef: float, s_coef: float) -> tuple[jax.Array, jax.Array]: + return ( + c_coef * first_vector - s_coef * second_vector, + c_coef * second_vector + s_coef * first_vector) + n = z.shape[0] + for k in range(n): + c, s = _drotg(R[k, k], z[k]) + row_k, z = _drot(R[k, :], z, c, s) + R = R.at[k, :].set(row_k) + return R + +cholesky_update_p = Primitive('cholesky_update') +cholesky_update_p.multiple_results = False +cholesky_update_p.def_abstract_eval(_cholesky_update_abstract_eval) +cholesky_update_p.def_impl(partial(dispatch.apply_primitive, cholesky_update_p)) + +mlir.register_lowering( + cholesky_update_p, _cholesky_update_cuda_lowering_rule, platform='cuda') + +mlir.register_lowering( + cholesky_update_p, + mlir.lower_fun(_cholesky_update_jax_fn, multiple_results=False)) + # Asymmetric eigendecomposition def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors): @@ -1550,7 +1590,7 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *, a_out, taus = batched_geqrf_impl(a_aval.dtype, a) else: if platform in ["cuda", "rocm"]: - a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a) # type: ignore + a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a) else: a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a, @@ -1664,7 +1704,7 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *, raise NotImplementedError( "Shape polymorphism for native serialization for householder_product " f"on GPU is not implemented; b/261671778; {a_aval.shape}") - a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus) # type: ignore + a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus) else: a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape) @@ -2208,7 +2248,6 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array: # Schur Decomposition -@_warn_on_positional_kwargs def schur(x: ArrayLike, *, compute_schur_vectors: bool = True, sort_eig_vals: bool = False, diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index e76641890598..47386cb4a5f0 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -79,7 +79,7 @@ def psum(x, axis_name, *, axis_index_groups=None): >>> print(y) [0. 
0.16666667 0.33333334 0.5 ] - Suppose we want to perform ``psum`` among two groups, one with ``device0`` and ``device1``, the other with `device2` and `device3`, + Suppose we want to perform ``psum`` among two groups, one with ``device0`` and ``device1``, the other with ``device2`` and ``device3``, >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x) >>> print(y) @@ -247,6 +247,36 @@ def _canonicalize_axis_index_groups(axis_index_groups): return return tuple(map(tuple, axis_index_groups)) + +def pbroadcast(x, axis_name, source): + """Perform a collective broadcast and replicate from ``source``. + + This is equivalent to + ``` + def pbroadcast(x, axis_name, source): + masked = jnp.where(axis_index(axis_name) == source, x, zeros_like(x)) + return psum(masked, axis_name) + ``` + but implemented in a hardware optimized way. + + If ``x`` is a pytree then the result is equivalent to mapping this function to + each leaf in the tree. + + This function is an analog of the CollectiveBroadcast HLO. + + Args: + x: array(s) with a mapped axis named ``axis_name``. + axis_name: hashable Python object used to name a pmapped axis (see the + :func:`jax.pmap` documentation for more details). + source: int, representing which index into ``axis_name`` that should be copied. + + Returns: + Array(s) with ``x`` being copied from the ``source`` index slice of ``axis_name``. + """ + return tree_util.tree_map( + partial(pbroadcast_p.bind, axis_name=axis_name, source=source), x) + + def ppermute(x, axis_name, perm): """Perform a collective permutation according to the permutation ``perm``. @@ -483,9 +513,9 @@ def xeinsum(spec: str, *operands): xs = list(operands) for idx, (in_subs, in_named) in enumerate(safe_zip(all_in_subs, all_in_named)): # if a subscript axis appears only in one input and not the output, reduce! - other_named = set().union( # type: ignore + other_named = set().union( *[named for i, named in enumerate(all_in_named) if i != idx]) - other_subs = set().union( # type: ignore + other_subs = set().union( *[subs for i, subs in enumerate(all_in_subs) if i != idx]) subs_reduce = list(set(in_subs) - {*out_subs, *other_subs}) @@ -697,7 +727,7 @@ def _replica_groups(axis_env, axis_name, axis_index_groups): return replica_groups def _replica_groups_hlo(replica_groups: Sequence[Sequence[int]] - ) -> ir.DenseIntElementsAttr: + ) -> ir.DenseElementsAttr: # Uneven replica groups are padded with -1. 
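[editor's note] The new `pbroadcast` collective defined above can be exercised with `pmap`; this sketch assumes it is exported alongside the other collectives (e.g. as `jax.lax.pbroadcast`) and that at least one mapped device is available.

```python
import jax
import jax.numpy as jnp
from jax import lax

n = jax.local_device_count()
x = jnp.arange(n, dtype=jnp.float32)

# Every participant along axis 'i' receives the value held at index 0.
y = jax.pmap(lambda v: lax.pbroadcast(v, 'i', source=0), axis_name='i')(x)
# y == [x[0]] * n
```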
groups = np.array(list(itertools.zip_longest(*replica_groups, fillvalue=-1)), dtype=np.int64).T @@ -832,7 +862,7 @@ def pos_reduce(x): assert not pos_axes size = len(axis_index_groups[0]) else: - size = math.prod([core.axis_frame(name).size for name in named_axes]) # type: ignore + size = math.prod([core.axis_frame(name).size for name in named_axes]) return tuple(lax._const(x, size) * pos_reduce(x) for x in args) return core.AxisPrimitive.bind( psum_p, *args, axes=axes, axis_index_groups=axis_index_groups) @@ -927,6 +957,43 @@ def _collective_batcher(prim, args, dims, **params): batching.axis_primitive_batchers[ppermute_p] = _ppermute_batcher core.axis_substitution_rules[ppermute_p] = partial(_subst_all_names_in_param, 'axis_name') +def _pbroadcast_transpose_rule(t, x, source, axis_name): + is_source = axis_index(axis_name) == source + tsum = psum(t, axis_name) + return [lax_numpy.where(is_source, tsum, lax_numpy.zeros_like(t))] + +def _pbroadcast_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, source): + (v,), (d,) = vals_in, dims_in + if not isinstance(axis_name, (tuple, list)): + axis_name = (axis_name,) + remaining_axes = tuple(axis for axis in axis_name if axis != frame_name) + if remaining_axes: + raise NotImplementedError("pbroadcast batcher only supports a single axis") + assert axis_name[0] == frame_name, "pbroadcast batcher called with a wrong axis!" + assert source >= 0 and source < axis_size, "collective broadcast doesn't fit in the axis size!" + if axis_size == 1 and remaining_axes: + return pbroadcast_p.bind(v, source=source, axis_name=remaining_axes), d + if d is batching.not_mapped: + return v, d + return lax_numpy.take(v, [source] * axis_size, d), d + +def _pbroadcast_lowering(ctx, x, *, axis_name, source): + replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, None) + def source_to_front(group): + return [group[source]] + list(group[:source]) + list(group[source + 1:]) + replica_groups = [source_to_front(group) for group in replica_groups] + channel = ctx.module_context.new_channel() + return hlo.CollectiveBroadcastOp( + x, replica_groups=_replica_groups_hlo(replica_groups)).results + +pbroadcast_p = core.AxisPrimitive('pbroadcast') +pbroadcast_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x)) +ad.deflinear2(pbroadcast_p, _pbroadcast_transpose_rule) +mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering) +batching.primitive_batchers[pbroadcast_p] = partial(_collective_batcher, pbroadcast_p) +batching.axis_primitive_batchers[pbroadcast_p] = _pbroadcast_batcher +core.axis_substitution_rules[pbroadcast_p] = partial(_subst_all_names_in_param, 'axis_name') + def _moveaxis(src, dst, x): perm = [i for i in range(x.ndim) if i != src] @@ -1204,7 +1271,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension] x = hlo.broadcast_in_dim( mlir.aval_to_ir_type(x_aval.update(shape=new_shape)), x, - mlir.dense_int_array_v6(broadcast_dimensions)) + mlir.dense_int_array(broadcast_dimensions)) replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, axis_index_groups) if is_spmd: @@ -1473,10 +1540,10 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, When ``False`` (the default value), the size of dimension in ``scatter_dimension`` must match the size of axis ``axis_name`` (or the group size if ``axis_index_groups`` is given). 
After scattering the - all-reduce result along ``scatter_dimension``, the output is sequeezed by + all-reduce result along ``scatter_dimension``, the output is squeezed by removing ``scatter_dimension``, so the result has lower rank than the input. When ``True``, the size of dimension in ``scatter_dimension`` must - be dividible by the size of axis ``axis_name`` (or the group size if + be divisible by the size of axis ``axis_name`` (or the group size if ``axis_index_groups`` is given), and the ``scatter_dimension`` axis is preserved (so the result has the same rank as the input). diff --git a/jax/_src/lax/qdwh.py b/jax/_src/lax/qdwh.py index 92c41db22224..bac3ea957955 100644 --- a/jax/_src/lax/qdwh.py +++ b/jax/_src/lax/qdwh.py @@ -71,19 +71,17 @@ def _use_qr(u, m, n, params): m, n: the dynamic shape of the matrix, where m <= M and n <= N. params: the QDWH parameters. """ - a, b, c = params + a_minus_e_by_sqrt_c, sqrt_c, e = params M, N = u.shape - y = _dynamic_concat(jnp.sqrt(c) * u, jnp.eye(N, dtype=jnp.dtype(u)), m) + y = _dynamic_concat(sqrt_c * u, jnp.eye(N, dtype=jnp.dtype(u)), m) q, _ = lax_linalg.qr(y, full_matrices=False) # q1 = q[:m, :] q1 = _mask(lax.slice(q, (0, 0), (M, N)), (m, n)) # q2 = (q[m:, :]).T.conj() q2 = lax.dynamic_slice_in_dim(q, m, N, axis=0) q2 = _mask(q2, (n, n)).T.conj() - e = b / c - u = (e * u + (a - e) / jnp.sqrt(c) * jnp.einsum('ij,jk->ik', q1, q2)) - return u + return e * u + a_minus_e_by_sqrt_c * (q1 @ q2) def _use_cholesky(u, m, n, params): @@ -94,7 +92,7 @@ def _use_cholesky(u, m, n, params): m, n: the dynamic shape of the matrix, where m <= M and n <= N. params: the QDWH parameters. """ - a, b, c = params + a_minus_e, c, e = params _, N = u.shape x = c * (u.T.conj() @ u) + jnp.eye(N, dtype=jnp.dtype(u)) # Pads the lower-right corner with the identity matrix to prevent the Cholesky @@ -111,11 +109,10 @@ def _use_cholesky(u, m, n, params): z = lax_linalg.triangular_solve(y, z, left_side=True, lower=True, transpose_a=True, conjugate_a=True).T.conj() - e = b / c - u = e * u + (a - e) * z - return u + return e * u + a_minus_e * z -def _qdwh(x, m, n, is_hermitian, max_iterations, eps): + +def _qdwh(x, m, n, max_iterations, eps): """QR-based dynamically weighted Halley iteration for polar decomposition.""" # Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of @@ -123,89 +120,134 @@ def _qdwh(x, m, n, is_hermitian, max_iterations, eps): # the smallest singular value of x. if eps is None: eps = float(jnp.finfo(x.dtype).eps) - alpha = (jnp.sqrt(jnp.linalg.norm(x, ord=1)) * - jnp.sqrt(jnp.linalg.norm(x, ord=jnp.inf))).astype(x.dtype) - l = eps + one_norm = jnp.linalg.norm(x, ord=1) + inf_norm = jnp.linalg.norm(x, ord=jnp.inf) + alpha_inverse = lax.rsqrt(one_norm) * lax.rsqrt(inf_norm) + alpha_inverse = jnp.where(one_norm == 0, 1, alpha_inverse) + u = x * alpha_inverse.astype(x.dtype) - u = x / alpha + l = eps # Iteration tolerances. 
tol_l = 10.0 * eps / 2.0 tol_norm = jnp.cbrt(tol_l) - def cond_fun(state): - _, _, _, is_unconverged, is_not_max_iteration = state - return jnp.logical_and(is_unconverged, is_not_max_iteration) - - def body_fun(state): - u, l, iter_idx, _, _ = state + def get_qr_params(a, b, c): + e = b / c + a_minus_e = a - e + sqrt_c = c ** (1 / 2) + return (a_minus_e / sqrt_c, sqrt_c, e) + + def get_chol_params(a, b, c): + e = b / c + a_minus_e = a - e + return (a_minus_e, c, e) + + CHOLESKY_CUTOFF = 100 + + qr_coefs = [] + chol_coefs = [] + k = 0 + while l + tol_l < 1 and k < max_iterations: + k += 1 + l2 = l * l + dd = (4 * (1 / l2 - 1) / l2) ** (1 / 3) + sqd = (1.0 + dd) ** (1 / 2) + a = sqd + (2 - dd + 2 * (2 - l2) / (l2 * sqd)) ** (1 / 2) + b = (a - 1) ** 2 / 4 + c = a + b - 1 + l = l * (a + b * l2) / (1 + c * l2) + if c > CHOLESKY_CUTOFF: + qr_coefs.append(get_qr_params(a, b, c)) + else: + chol_coefs.append(get_chol_params(a, b, c)) + + def iteration(k, state, update_fn, coefs, test_convergence): + u, _ = state + + if coefs is None: + # As l → 1, the coefficients a, b, c → 3, 1, 3, which is Halley's method. + params = get_chol_params(3, 1, 3) + else: + params = lax.dynamic_index_in_dim(coefs, k, keepdims=False) u_prev = u + u = update_fn(u, m, n, params) + + is_not_converged = True + if test_convergence: + is_not_converged = jnp.linalg.norm(u - u_prev) > tol_norm + return u, is_not_converged + + def iterate(u, coefs, **kwargs): + if not coefs: + return u, True + coefs = jnp.array(coefs).astype(x.dtype) + body = functools.partial(iteration, coefs=coefs, **kwargs) + return lax.fori_loop(0, len(coefs), body, (u, True)) + + u, _ = iterate( + u, coefs=qr_coefs, update_fn=_use_qr, test_convergence=False + ) + u, is_not_converged = iterate( + u, coefs=chol_coefs, update_fn=_use_cholesky, test_convergence=True + ) + + # If l has converged but u still has not, continue with Halley's method + # (coef = None) until convergence. + def cond_fun(state): + k, _, is_not_converged = state + return jnp.logical_and(is_not_converged, k < max_iterations) - # Computes parameters. - l2 = l**2 - dd = jnp.cbrt(4.0 * (1.0 / l2 - 1.0) / l2) - sqd = jnp.sqrt(1.0 + dd) - a = (sqd + jnp.sqrt(8.0 - 4.0 * dd + 8.0 * (2.0 - l2) / (l2 * sqd)) / 2) - a = jnp.real(a) - b = (a - 1.0)**2 / 4.0 - c = a + b - 1.0 - - # Updates l. - l = l * (a + b * l2) / (1.0 + c * l2) - - # Uses QR or Cholesky decomposition. - def true_fn(u): - return _use_qr(u, m, n, params=(a, b, c)) - - def false_fn(u): - return _use_cholesky(u, m, n, params=(a, b, c)) - - u = jax.lax.cond(c > 100, true_fn, false_fn, operand=(u)) - - if is_hermitian: - u = (u + u.T.conj()) / 2.0 - - # Checks convergence. 
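[editor's note] The restructuring above moves the dynamically weighted Halley coefficients out of the traced loop: the `l` recurrence involves only scalars, so the `(a, b, c)` triples can be precomputed in plain Python and only the matrix updates stay on device. A standalone sketch of that host-side recurrence, with an arbitrary `eps` and an explicit iteration cap standing in for `max_iterations`:

```python
eps = 1e-6                      # illustrative; the real code uses the dtype's machine epsilon
l, tol_l = eps, 10.0 * eps / 2.0
CHOLESKY_CUTOFF = 100
coefs = []
while l + tol_l < 1 and len(coefs) < 10:   # cap plays the role of max_iterations
    l2 = l * l
    dd = (4 * (1 / l2 - 1) / l2) ** (1 / 3)
    sqd = (1.0 + dd) ** (1 / 2)
    a = sqd + (2 - dd + 2 * (2 - l2) / (l2 * sqd)) ** (1 / 2)
    b = (a - 1) ** 2 / 4
    c = a + b - 1
    l = l * (a + b * l2) / (1 + c * l2)
    coefs.append(('qr' if c > CHOLESKY_CUTOFF else 'cholesky', a, b, c))
# Early steps have large c and use the QR update; later steps switch to Cholesky,
# and as l -> 1 the coefficients approach Halley's method, (a, b, c) = (3, 1, 3).
```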
- iterating_l = jnp.abs(1.0 - l) > tol_l - iterating_u = jnp.linalg.norm(u-u_prev) > tol_norm - is_unconverged = jnp.logical_or(iterating_l, iterating_u) - - is_not_max_iteration = iter_idx < max_iterations - - return u, l, iter_idx + 1, is_unconverged, is_not_max_iteration - - iter_idx = 1 - is_unconverged = True - is_not_max_iteration = True - u, _, num_iters, is_unconverged, _ = jax.lax.while_loop( - cond_fun=cond_fun, body_fun=body_fun, - init_val=(u, l, iter_idx, is_unconverged, is_not_max_iteration)) + def body_fun(state): + k, u, is_not_converged = state + u, is_not_converged = iteration( + k, + (u, is_not_converged), + coefs=None, + update_fn=_use_cholesky, + test_convergence=True, + ) + return k + 1, u, is_not_converged + + k = len(qr_coefs) + len(chol_coefs) + num_iters, u, is_not_converged = lax.while_loop( + cond_fun, body_fun, (k, u, is_not_converged) + ) # Applies Newton-Schulz refinement for better accuracy. u = 1.5 * u - 0.5 * u @ (u.T.conj() @ u) h = u.T.conj() @ x - h = (h + h.T.conj()) / 2.0 + h = (h + h.T.conj()) / 2 # Converged within the maximum number of iterations. - is_converged = jnp.logical_not(is_unconverged) + is_converged = jnp.logical_not(is_not_converged) - return u, h, num_iters - 1, is_converged + return u, h, num_iters, is_converged # TODO: Add pivoting. -@functools.partial(jax.jit, static_argnames=('is_hermitian',)) -def qdwh(x, *, is_hermitian=False, max_iterations=None, eps=None, - dynamic_shape: tuple[int, int] | None = None): +@functools.partial( + jax.jit, static_argnames=('is_hermitian', 'max_iterations', 'eps') +) +def qdwh( + x, + *, + is_hermitian: bool = False, + max_iterations: int | None = None, + eps: float | None = None, + dynamic_shape: tuple[int, int] | None = None, +): """QR-based dynamically weighted Halley iteration for polar decomposition. Args: - x: A full-rank matrix, with shape `M x N`. The matrix may be - padded up to that size from a smaller true shape (``dynamic_shape``). - is_hermitian: True if `x` is Hermitian. Default to `False`. - eps: The final result will satisfy - ``|x_k - x_k-1| < |x_k| * (4*eps)**(1/3)`` where `x_k` is the iterate. + x: A full-rank matrix, with shape `M x N`. The matrix may be padded up to + that size from a smaller true shape (``dynamic_shape``). + is_hermitian: True if `x` is Hermitian. Default to `False`. This parameter + is currently unused, but exists for backward compatibility. + eps: The final result will satisfy ``|x_k - x_k-1| < |x_k| * + (4*eps)**(1/3)`` where `x_k` is the iterate. max_iterations: Iterations will terminate after this many steps even if the above is unsatisfied. dynamic_shape: the unpadded shape as an ``(m, n)`` tuple; optional. @@ -216,12 +258,17 @@ def qdwh(x, *, is_hermitian=False, max_iterations=None, eps=None, and `is_converged`, whose value is `True` when the convergence is achieved within the maximum number of iterations. """ + # TODO: Possibly take advantage of Hermitian inputs to speed up the QDWH step. 
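[editor's note] For a user-level view of the reworked routine, the polar decomposition wrapper in `jax.scipy.linalg` is the usual entry point to `qdwh`; the call below assumes the current `polar(..., method='qdwh')` signature and is only a smoke-test style sketch.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
u, p = jax.scipy.linalg.polar(a, method='qdwh')   # u unitary, p positive semidefinite
print(jnp.allclose(u @ p, a, atol=1e-5))
```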
is_hermitian = core.concrete_or_error( bool, is_hermitian, 'The `is_hermitian` argument must be statically ' 'specified to use `qdwh` within JAX transformations.') if max_iterations is None: max_iterations = 10 + else: + max_iterations = core.concrete_or_error( + int, max_iterations, 'The `max_iterations` argument must be statically ' + 'specified to use `qdwh` within JAX transformations.') M, N = x.shape if M < N: @@ -233,8 +280,6 @@ def qdwh(x, *, is_hermitian=False, max_iterations=None, eps=None, m, n = M, N with jax.default_matmul_precision('float32'): - u, h, num_iters, is_converged = _qdwh(x, m, n, is_hermitian, max_iterations, - eps) - + u, h, num_iters, is_converged = _qdwh(x, m, n, max_iterations, eps) return u, h, num_iters, is_converged diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 48689bbb0795..b2bd30b3d364 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -14,12 +14,12 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import enum import operator from functools import partial import math -from typing import Callable, NamedTuple +from typing import NamedTuple import weakref import numpy as np @@ -232,7 +232,7 @@ class GatherDimensionNumbers(NamedTuple): in the output of the gather. Must be a tuple of integers in ascending order. start_index_map: for each dimension in `start_indices`, gives the - corresponding dimension in `operand` that is to be sliced. Must be a + corresponding dimension in the `operand` that is to be sliced. Must be a tuple of integers with size equal to `start_indices.shape[-1]`. Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is @@ -261,8 +261,8 @@ class GatherScatterMode(enum.Enum): will be discarded. PROMISE_IN_BOUNDS: The user promises that indices are in bounds. No additional checking will be - performed. In practice, with the current XLA implementation this means - that, out-of-bounds gathers will be clamped but out-of-bounds scatters will + performed. In practice, with the current XLA implementation this means + that out-of-bounds gathers will be clamped but out-of-bounds scatters will be discarded. Gradients will not be correct if indices are out-of-bounds. """ CLIP = enum.auto() @@ -370,7 +370,7 @@ class ScatterDimensionNumbers(NamedTuple): are the mirror image of `collapsed_slice_dims` in the case of `gather`. scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives the corresponding dimension in `operand`. Must be a sequence of integers - with size equal to indices.shape[-1]. + with size equal to `scatter_indices.shape[-1]`. 
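[editor's note] Since the `GatherDimensionNumbers` and `ScatterDimensionNumbers` docstrings are touched in the slicing hunks here, a worked example of the gather side may help; this is a plain use of the existing public API (no new functionality), selecting whole rows of a 2-D operand.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.arange(12.0).reshape(3, 4)
start_indices = jnp.array([[0], [2]])            # trailing dim is the implicit index vector
dnums = lax.GatherDimensionNumbers(
    offset_dims=(1,),            # the gathered row becomes output dimension 1
    collapsed_slice_dims=(0,),   # the size-1 slice over operand dim 0 is dropped
    start_index_map=(0,),        # start_indices[..., 0] indexes operand dimension 0
)
rows = lax.gather(operand, start_indices, dnums, slice_sizes=(1, 4),
                  mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)
# rows == operand[jnp.array([0, 2])], shape (2, 4)
```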
Unlike XLA's `ScatterDimensionNumbers` structure, `index_vector_dim` is implicit; there is always an index vector dimension and it must always be the @@ -1821,10 +1821,13 @@ def _gather_lower(ctx, operand, indices, *, assert mode in (GatherScatterMode.PROMISE_IN_BOUNDS, GatherScatterMode.CLIP), mode dnums = hlo.GatherDimensionNumbers.get( - collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims), - index_vector_dim=len(ctx.avals_in[1].shape) - 1, - offset_dims=list(dimension_numbers.offset_dims), - start_index_map=list(dimension_numbers.start_index_map)) + collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims), + operand_batching_dims=[], + start_indices_batching_dims=[], + index_vector_dim=len(ctx.avals_in[1].shape) - 1, + offset_dims=list(dimension_numbers.offset_dims), + start_index_map=list(dimension_numbers.start_index_map), + ) if not core.is_constant_shape(slice_sizes): slice_sizes = mlir.eval_dynamic_shape_as_tensor(ctx, slice_sizes) # TODO(burmako): Fix overly conservative type inference of DynamicGatherOp. @@ -1845,7 +1848,7 @@ def _gather_lower(ctx, operand, indices, *, operand, indices, dnums, - mlir.dense_int_array_v6(slice_sizes), + mlir.dense_int_array(slice_sizes), indices_are_sorted=ir.BoolAttr.get(indices_are_sorted))] mlir.register_lowering(gather_p, _gather_lower) @@ -2475,10 +2478,13 @@ def _scatter_lower(ctx, operand, indices, updates, *, dnums = dimension_numbers scatter_dnums = hlo.ScatterDimensionNumbers.get( - update_window_dims=list(dnums.update_window_dims), - inserted_window_dims=list(dnums.inserted_window_dims), - scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), - index_vector_dim=len(ctx.avals_in[1].shape) - 1) + update_window_dims=list(dnums.update_window_dims), + inserted_window_dims=list(dnums.inserted_window_dims), + input_batching_dims=[], + scatter_indices_batching_dims=[], + scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), + index_vector_dim=len(ctx.avals_in[1].shape) - 1, + ) result = mlir.aval_to_ir_types(aval_out) operand = [operand] updates = [updates] @@ -2532,10 +2538,13 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates, aval_out, = ctx.avals_out dnums = dimension_numbers scatter_dnums = hlo.ScatterDimensionNumbers.get( - update_window_dims=list(dnums.update_window_dims), - inserted_window_dims=list(dnums.inserted_window_dims), - scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), - index_vector_dim=len(ctx.avals_in[1].shape) - 1) + update_window_dims=list(dnums.update_window_dims), + inserted_window_dims=list(dnums.inserted_window_dims), + input_batching_dims=[], + scatter_indices_batching_dims=[], + scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), + index_vector_dim=len(ctx.avals_in[1].shape) - 1, + ) real_dtype = _real_dtype(aval_out.dtype) operand_type_part = mlir.aval_to_ir_types( core.ShapedArray(aval_out.shape, real_dtype)) diff --git a/jax/_src/lax/svd.py b/jax/_src/lax/svd.py index cbe229bc62bf..77ff4297e137 100644 --- a/jax/_src/lax/svd.py +++ b/jax/_src/lax/svd.py @@ -15,7 +15,7 @@ """A JIT-compatible library for QDWH-based singular value decomposition. QDWH is short for QR-based dynamically weighted Halley iteration. The Halley -iteration implemented through QR decmopositions is numerically stable and does +iteration implemented through QR decompositions is numerically stable and does not require solving a linear system involving the iteration matrix or computing its inversion. 
This is desirable for multicore and heterogeneous computing systems. @@ -46,58 +46,6 @@ import jax.numpy as jnp -@functools.partial(jax.jit, static_argnums=(2, 3, 4)) -def _constant_svd( - a: Any, - return_nan: bool, - full_matrices: bool, - compute_uv: bool = True, - subset_by_index: tuple[int, int] | None = None, -) -> Any | Sequence[Any]: - """SVD on matrix of all zeros.""" - m, n = a.shape - k = min(m, n) - if subset_by_index is not None: - k = min(k, subset_by_index[1] - subset_by_index[0]) - - s = jnp.where( - return_nan, - jnp.full(shape=(k,), fill_value=jnp.nan, dtype=a.real.dtype), - jnp.zeros(shape=(k,), dtype=a.real.dtype), - ) - if compute_uv: - fill_value = ( - jnp.nan + 1j * jnp.nan - if jnp.issubdtype(a.dtype, jnp.complexfloating) - else jnp.nan - ) - if full_matrices: - u = jnp.where( - return_nan, - jnp.full((m, m), fill_value, dtype=a.dtype), - jnp.eye(m, m, dtype=a.dtype), - ) - vh = jnp.where( - return_nan, - jnp.full((n, n), fill_value, dtype=a.dtype), - jnp.eye(n, n, dtype=a.dtype), - ) - else: - u = jnp.where( - return_nan, - jnp.full((m, k), fill_value, dtype=a.dtype), - jnp.eye(m, k, dtype=a.dtype), - ) - vh = jnp.where( - return_nan, - jnp.full((k, n), fill_value, dtype=a.dtype), - jnp.eye(k, n, dtype=a.dtype), - ) - return (u, s, vh) - else: - return s - - @functools.partial(jax.jit, static_argnums=(1, 2, 3, 4)) def _svd_tall_and_square_input( a: Any, @@ -111,7 +59,7 @@ def _svd_tall_and_square_input( Args: a: A matrix of shape `m x n` with `m >= n`. hermitian: True if `a` is Hermitian. - compute_uv: Whether to compute also `u` and `v` in addition to `s`. + compute_uv: Whether to also compute `u` and `v` in addition to `s`. max_iterations: The predefined maximum number of iterations of QDWH. Returns: @@ -121,25 +69,29 @@ def _svd_tall_and_square_input( `a = (u * s) @ v.T.conj()`. For `compute_uv=False`, only `s` is returned. """ - u, h, _, _ = lax.linalg.qdwh(a, is_hermitian=hermitian, - max_iterations=max_iterations) + u_p, h, _, _ = lax.linalg.qdwh( + a, is_hermitian=hermitian, max_iterations=max_iterations + ) # TODO: Uses `eigvals_only=True` if `compute_uv=False`. - v, s = lax.linalg.eigh(h, subset_by_index=subset_by_index) + v, s = lax.linalg.eigh( + h, subset_by_index=subset_by_index, sort_eigenvalues=False + ) + # Singular values are non-negative by definition. But eigh could return small # negative values, so we clamp them to zero. s = jnp.maximum(s, 0.0) - # Flips the singular values in descending order. - s_out = jnp.flip(s) + # Sort or reorder singular values to be in descending order. + sort_idx = jnp.argsort(s, descending=True) + s_out = s[sort_idx] if not compute_uv: return s_out # Reorders eigenvectors. - v_out = jnp.fliplr(v) - - u_out = u @ v_out + v_out = v[:, sort_idx] + u_out = u_p @ v_out # Makes correction if computed `u` from qdwh is not unitary. # Section 5.5 of Nakatsukasa, Yuji, and Nicholas J. Higham. "Stable and @@ -148,89 +100,16 @@ def _svd_tall_and_square_input( # 35, no. 3 (2013): A1325-A1349. 
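[editor's note] A small aside on the reordering change just above: the old code relied on `eigh` returning ascending eigenvalues and flipped them, while the new code requests an explicit descending sort and permutes the eigenvector columns with the same indices. A toy sketch of that permutation, using made-up values:

```python
import jax.numpy as jnp

s = jnp.array([0.1, 2.0, 0.5])              # singular values in whatever order eigh produced
v = jnp.eye(3)                              # stand-in for the eigenvector matrix
order = jnp.argsort(s, descending=True)     # explicit descending sort replaces jnp.flip
s_out, v_out = s[order], v[:, order]        # consistent permutation keeps u @ diag(s) @ vh intact
```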
def correct_rank_deficiency(u_out): u_out, r = lax.linalg.qr(u_out, full_matrices=False) - u_out = u_out @ jnp.diag(lax.sign(jnp.diag(r))) + u_out = u_out @ jnp.diag(jnp.where(jnp.diag(r) >= 0, 1, -1)) return u_out eps = float(jnp.finfo(a.dtype).eps) - u_out = lax.cond(s[0] < a.shape[1] * eps * s_out[0], - correct_rank_deficiency, - lambda u_out: u_out, - operand=(u_out)) - + do_correction = s_out[-1] <= a.shape[1] * eps * s_out[0] + cond_f = lambda args: args[1] + body_f = lambda args: (correct_rank_deficiency(args[0]), False) + u_out, _ = lax.while_loop(cond_f, body_f, (u_out, do_correction)) return (u_out, s_out, v_out) - -@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5)) -def _qdwh_svd( - a: Any, - full_matrices: bool, - compute_uv: bool = True, - hermitian: bool = False, - max_iterations: int = 10, - subset_by_index: tuple[int, int] | None = None, -) -> Any | Sequence[Any]: - """Singular value decomposition. - - Args: - a: A matrix of shape `m x n`. - full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`, - respectively. If False, the shapes are `m x k` and `k x n`, respectively, - where `k = min(m, n)`. - compute_uv: Whether to compute also `u` and `v` in addition to `s`. - hermitian: True if `a` is Hermitian. - max_iterations: The predefined maximum number of iterations of QDWH. - - Returns: - A 3-tuple (`u`, `s`, `vh`), where `u` and `vh` are unitary matrices, - `s` is vector of length `k` containing the singular values in the - non-increasing order, and `k = min(m, n)`. The shapes of `u` and `vh` - depend on the value of `full_matrices`. For `compute_uv=False`, - only `s` is returned. - """ - m, n = a.shape - - is_flip = False - if m < n: - a = a.T.conj() - m, n = a.shape - is_flip = True - - reduce_to_square = False - if full_matrices: - q_full, a_full = lax.linalg.qr(a, full_matrices=True) - q = q_full[:, :n] - u_out_null = q_full[:, n:] - a = a_full[:n, :] - reduce_to_square = True - else: - # The constant `1.15` comes from Yuji Nakatsukasa's implementation - # https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav - if m > 1.15 * n: - q, a = lax.linalg.qr(a, full_matrices=False) - reduce_to_square = True - - if not compute_uv: - with jax.default_matmul_precision('float32'): - return _svd_tall_and_square_input( - a, hermitian, compute_uv, max_iterations, subset_by_index - ) - - with jax.default_matmul_precision('float32'): - u_out, s_out, v_out = _svd_tall_and_square_input( - a, hermitian, compute_uv, max_iterations, subset_by_index - ) - if reduce_to_square: - u_out = q @ u_out - - if full_matrices: - u_out = jnp.hstack((u_out, u_out_null)) - - if is_flip: - return(v_out, s_out, u_out.T.conj()) - - return (u_out, s_out, v_out.T.conj()) - - @functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5)) def svd( a: Any, @@ -247,11 +126,11 @@ def svd( full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`, respectively. If False, the shapes are `m x k` and `k x n`, respectively, where `k = min(m, n)`. - compute_uv: Whether to compute also `u` and `v` in addition to `s`. + compute_uv: Whether to also compute `u` and `v` in addition to `s`. hermitian: True if `a` is Hermitian. max_iterations: The predefined maximum number of iterations of QDWH. subset_by_index: Optional 2-tuple [start, end] indicating the range of - indices of singular componenets to compute. For example, if + indices of singular components to compute. 
For example, if ``subset_by_index`` = [0,2], then ``svd`` computes the two largest singular values (and their singular vectors if `compute_uv` is true. @@ -309,29 +188,56 @@ def svd( # subset_by_index accordingly. subset_by_index = (rank - subset_by_index[1], rank - subset_by_index[0]) - # QDWH algorithm fails at zero-matrix `A` and produces all NaNs, which can - # be seen from a dynamically weighted Halley (DWH) iteration: - # X_{k+1} = X_k(a_k I + b_k {X_k}^H X_k)(I + c_k {X_k}^H X_k)^{−1} and - # X_0 = A/alpha, where alpha = ||A||_2, the triplet (a_k, b_k, c_k) are - # weighting parameters, and X_k denotes the k^{th} iterate. - all_zero = jnp.all(a == 0) - non_finite = jnp.logical_not(jnp.all(jnp.isfinite(a))) - return lax.cond( - jnp.logical_or(all_zero, non_finite), - functools.partial( - _constant_svd, - return_nan=non_finite, - full_matrices=full_matrices, - compute_uv=compute_uv, - subset_by_index=subset_by_index, - ), - functools.partial( - _qdwh_svd, - full_matrices=full_matrices, - compute_uv=compute_uv, - hermitian=hermitian, - max_iterations=max_iterations, - subset_by_index=subset_by_index, - ), - operand=(a), + m, n = a.shape + is_flip = False + if m < n: + a = a.T.conj() + m, n = a.shape + is_flip = True + + reduce_to_square = False + if full_matrices: + q_full, a_full = lax.linalg.qr(a, full_matrices=True) + q = q_full[:, :n] + u_out_null = q_full[:, n:] + a = a_full[:n, :] + reduce_to_square = True + else: + # The constant `1.15` comes from Yuji Nakatsukasa's implementation + # https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav + if m > 1.15 * n: + q, a = lax.linalg.qr(a, full_matrices=False) + reduce_to_square = True + + if not compute_uv: + with jax.default_matmul_precision('float32'): + return _svd_tall_and_square_input( + a, hermitian, compute_uv, max_iterations, subset_by_index + ) + + with jax.default_matmul_precision('float32'): + u_out, s_out, v_out = _svd_tall_and_square_input( + a, hermitian, compute_uv, max_iterations, subset_by_index + ) + if reduce_to_square: + u_out = q @ u_out + + if full_matrices: + u_out = jnp.hstack((u_out, u_out_null)) + + is_finite = jnp.all(jnp.isfinite(a)) + cond_f = lambda args: jnp.logical_not(args[0]) + body_f = lambda args: ( + jnp.array(True), + jnp.full_like(u_out, jnp.nan), + jnp.full_like(s_out, jnp.nan), + jnp.full_like(v_out, jnp.nan), + ) + _, u_out, s_out, v_out = lax.while_loop( + cond_f, body_f, (is_finite, u_out, s_out, v_out) ) + + if is_flip: + return (v_out, s_out, u_out.T.conj()) + + return (u_out, s_out, v_out.T.conj()) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index af85a04d7bb2..096fce7deb3a 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -14,41 +14,46 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial -from typing import Callable import warnings -import numpy as np - from jax import tree_util - -from jax._src import ad_util from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import util -from jax._src.core import ShapedArray, ConcreteArray +from jax._src.core import ConcreteArray, ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.lax import lax from jax._src.lax import convolution +from jax._src.lax import 
lax from jax._src.lax import slicing from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.ufuncs import logaddexp from jax._src.typing import Array +import numpy as np +from jax._src.core import ClosedJaxpr +from jax._src.core import jaxpr_as_fun +from jax._src.interpreters.ad import jvp_jaxpr +from jax._src import ad_util map = util.safe_map zip = util.safe_zip -def reduce_window(operand, init_value, computation: Callable, - window_dimensions: core.Shape, window_strides: Sequence[int], - padding: str | Sequence[tuple[int, int]], - base_dilation: Sequence[int] | None = None, - window_dilation: Sequence[int] | None = None) -> Array: +def _reduce_window( + operand, + init_value, + computation, + window_dimensions: core.Shape, + window_strides: Sequence[int], + padding: str | Sequence[tuple[int, int]], + base_dilation: Sequence[int] | None = None, + window_dilation: Sequence[int] | None = None, +): """Wraps XLA's `ReduceWindowWithGeneralPadding `_ operator. @@ -56,13 +61,18 @@ def reduce_window(operand, init_value, computation: Callable, flat_operands, operand_tree = tree_util.tree_flatten(operand) flat_init_values, init_value_tree = tree_util.tree_flatten(init_value) if operand_tree != init_value_tree: - raise ValueError('Operands must have the same tree structure as ' - f'init_values: {operand_tree} vs. {init_value_tree}') - if len(flat_operands) == 0: - raise ValueError('reduce_window must have at least one operand.') + raise ValueError( + "Operands must have the same tree structure as " + f"init_values: {operand_tree} vs. {init_value_tree}" + ) if len(flat_operands) != len(flat_init_values): - raise ValueError('Must have same total number of operands as init_values: ' - f' {len(flat_operands)} vs. {len(flat_init_values)}') + raise ValueError( + "Must have same total number of operands as init_values: " + f" {len(flat_operands)} vs. {len(flat_init_values)}" + ) + + if len(flat_operands) == 0: + raise ValueError("reduce_window must have at least one operand.") if isinstance(padding, str): dilated_window_dims = ( window_dimensions if window_dilation is None else @@ -82,21 +92,52 @@ def reduce_window(operand, init_value, computation: Callable, else: flat_init_avals = map(lax._abstractify, flat_init_values) jaxpr, out_tree = lax._variadic_reduction_jaxpr( - computation, tuple(flat_init_avals), init_value_tree) + computation, tuple(flat_init_avals), init_value_tree + ) if operand_tree != out_tree: raise ValueError( 'reduce_window output must have the same tree structure as the operands' f' {operand_tree} vs. 
{out_tree}') out_flat = reduce_window_p.bind( - *flat_operands, *flat_init_values, jaxpr=jaxpr.jaxpr, - consts=tuple(jaxpr.consts), window_dimensions=tuple(window_dimensions), - window_strides=tuple(window_strides), padding=padding, + *flat_operands, + *flat_init_values, + jaxpr=jaxpr.jaxpr, + consts=tuple(jaxpr.consts), + window_dimensions=tuple(window_dimensions), + window_strides=tuple(window_strides), + padding=padding, base_dilation=tuple(base_dilation), - window_dilation=tuple(window_dilation)) + window_dilation=tuple(window_dilation), + ) return tree_util.tree_unflatten(out_tree, out_flat) -def _get_monoid_window_reducer(monoid_op: Callable, - xs: Sequence[Array]) -> Callable | None: + + +def reduce_window( + operand, + init_value, + computation: Callable, + window_dimensions: core.Shape, + window_strides: Sequence[int], + padding: str | Sequence[tuple[int, int]], + base_dilation: Sequence[int] | None = None, + window_dilation: Sequence[int] | None = None, +) -> Array: + return _reduce_window( + operand, + init_value, + computation, + window_dimensions, + window_strides, + padding, + base_dilation, + window_dilation, + ) + + +def _get_monoid_window_reducer( + monoid_op, xs: Sequence[Array] +) -> Callable | None: if len(xs) != 1: return None x, = xs @@ -112,6 +153,7 @@ def _get_monoid_window_reducer(monoid_op: Callable, and _reduce_window_min) return None + def _reduce_window_sum(operand: Array, window_dimensions: core.Shape, window_strides: Sequence[int], padding: Sequence[tuple[int, int]], @@ -260,10 +302,19 @@ def _select_and_gather_add(tangents: Array, operand: Array, def _reduce_window_abstract_eval_rule( - *avals, jaxpr, consts, window_dimensions, window_strides, padding, - base_dilation, window_dilation): + *avals, + jaxpr, + consts, + window_dimensions, + window_strides, + padding, + base_dilation, + window_dilation, +): operand_avals, init_val_avals = util.split_list(avals, [len(avals) // 2]) - if any(o.dtype != iv.dtype for o, iv in zip(operand_avals, init_val_avals)): + if any( + o.dtype != iv.dtype for o, iv in zip(operand_avals, init_val_avals) + ): msg = ("reduce_window got inconsistent dtypes for operands and init_values:" " got operand dtypes {} and init_value dtypes {}.") raise TypeError(msg.format([o.dtype for o in operand_avals], @@ -273,13 +324,28 @@ def _reduce_window_abstract_eval_rule( "have shapes {}.") raise TypeError(msg.format([v.shape for v in init_val_avals])) out_shape = _common_reduce_window_shape_rule( - operand_avals[0], window_dimensions, window_strides, padding, - base_dilation, window_dilation) + operand_avals[0], + window_dimensions, + window_strides, + padding, + base_dilation, + window_dilation, + ) return tuple(ShapedArray(out_shape, op.dtype) for op in operand_avals) + def _generic_reduce_window_batch_rule( - batched_args, batch_dims, *, jaxpr, consts, window_dimensions, - window_strides, padding, base_dilation, window_dilation): + batched_args, + batch_dims, + *, + jaxpr, + consts, + window_dimensions, + window_strides, + padding, + base_dilation, + window_dilation, +): num_operands = len(batched_args) // 2 operands, init_values = util.split_list(batched_args, [num_operands]) operand_bdims, init_value_bdims = util.split_list(batch_dims, [num_operands]) @@ -306,14 +372,68 @@ def _generic_reduce_window_batch_rule( reduce_window_p = core.Primitive('reduce_window') + + +def reduce_window_jvp( + primals, + tangents, + window_dimensions, + window_strides, + padding, + base_dilation, + window_dilation, + jaxpr, + consts, +): + + reduction_jaxpr = 
jaxpr + + n = len(primals) // 2 # number of primal operands + operand, init_value = util.split_list(primals, [n]) + operand_tangent, init_value_tangent = util.split_list(tangents, [n]) + if not all(isinstance(t, ad.Zero) for t in init_value_tangent): + raise TypeError("reduce_window jvp does not support non-zero init_value_tangent.") + + init_value_tangent = map(ad_util.instantiate, init_value_tangent) + c_reduction_jaxpr = ClosedJaxpr(reduction_jaxpr, consts) + jvp_reduction = jvp_jaxpr(c_reduction_jaxpr, (True,) * len(tangents), [False] * len(init_value_tangent))[0] + + def wrapper(left, right): + pl, tl = util.split_list(left, [n]) + pr, tr = util.split_list(right, [n]) + return jaxpr_as_fun(jvp_reduction)(*pl, *pr, *tl, *tr) + + jvp_primals_tangents = _reduce_window( + operand=[*operand, *operand_tangent], + init_value=[*init_value, *init_value_tangent], + computation=wrapper, + window_dimensions=window_dimensions, + window_strides=window_strides, + padding=padding, + base_dilation=base_dilation, + window_dilation=window_dilation, + ) + primals, tangents = util.split_list(jvp_primals_tangents, [len(jvp_primals_tangents) // 2]) + return [*primals], [*tangents] + +ad.primitive_jvps[reduce_window_p] = reduce_window_jvp reduce_window_p.multiple_results = True reduce_window_p.def_impl(partial(dispatch.apply_primitive, reduce_window_p)) reduce_window_p.def_abstract_eval(_reduce_window_abstract_eval_rule) batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule -def _generic_reduce_window_lower(ctx, *args, jaxpr, consts, - window_dimensions, window_strides, padding, - base_dilation, window_dilation): + +def _generic_reduce_window_lower( + ctx, + *args, + jaxpr, + consts, + window_dimensions, + window_strides, + padding, + base_dilation, + window_dilation, +): operands, init_values = util.split_list(args, [len(args) // 2]) _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)]) @@ -330,11 +450,15 @@ def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]: reducer_name="generic_reduce_window_reducer", reducer_body=reducer_body, operands=operands, - init_values=init_values, init_values_avals=init_value_avals, + init_values=init_values, + init_values_avals=init_value_avals, out_avals=ctx.avals_out, - window_dimensions=window_dimensions, window_strides=window_strides, - base_dilation=base_dilation, window_dilation=window_dilation, - padding=padding) + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilation=base_dilation, + window_dilation=window_dilation, + padding=padding, + ) mlir.register_lowering(reduce_window_p, _generic_reduce_window_lower) @@ -402,9 +526,14 @@ def _reduce_window_chooser_jvp_rule(prim, g, operand, *, window_dimensions, window_dilation) -def _common_reduce_window_shape_rule(operand, window_dimensions, - window_strides, padding, base_dilation, - window_dilation): +def _common_reduce_window_shape_rule( + operand, + window_dimensions, + window_strides, + padding, + base_dilation, + window_dilation, +): lax._check_shapelike("reduce_window", "window_dimensions", window_dimensions, non_zero_shape=True) lax._check_shapelike("reduce_window", "window_strides", window_strides, @@ -412,8 +541,10 @@ def _common_reduce_window_shape_rule(operand, window_dimensions, lax._check_shapelike("reduce_window", "base_dilation", base_dilation) lax._check_shapelike("reduce_window", "window_dilation", window_dilation) if operand.ndim != len(window_dimensions): - msg = ("reduce_window got the wrong number of window_dimensions for " - 
"operand: got operand shape {} with window_dimensions {}.") + msg = ( + "reduce_window got the wrong number of window_dimensions for " + "operand: got operand shape {} with window_dimensions {}." + ) raise TypeError(msg.format(operand.shape, window_dimensions)) if len(window_strides) != len(window_dimensions): msg = ("reduce_window got inconsistent window_strides and " @@ -443,6 +574,7 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, operand_padded = tuple(d + pl + ph for d, (pl, ph) in zip(operand_shape, padding)) return tuple(map(core.stride_dim, operand_padded, window_dimensions, window_strides)) + reduce_window_max_p = lax.standard_primitive( _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_max') ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule, @@ -463,24 +595,36 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, def _reduce_window_lower( reduce_op, - init_value, ctx, operand, *, - window_dimensions, window_strides, padding, base_dilation, - window_dilation): + init_value, + ctx, + operand, + *, + window_dimensions, + window_strides, + padding, + base_dilation, + window_dilation, +): operand_aval, = ctx.avals_in scalar_aval = operand_aval.update(shape=()) - return mlir.reduce_window(ctx, + return mlir.reduce_window( + ctx, reducer_name=f"reduce_window_{scalar_aval.dtype}_reducer", reducer_body=lambda reducer: [reduce_op(*reducer.arguments)], operands=[operand], - init_values=[mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), - scalar_aval)], + init_values=[ + mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), scalar_aval) + ], init_values_avals=[scalar_aval], out_avals=ctx.avals_out, window_dimensions=window_dimensions, - window_strides=window_strides, base_dilation=base_dilation, - window_dilation=window_dilation, padding=padding) + window_strides=window_strides, + base_dilation=base_dilation, + window_dilation=window_dilation, + padding=padding, + ) mlir.register_lowering(reduce_window_sum_p, partial( @@ -520,8 +664,8 @@ def _select_and_scatter_lower( operand, source, init_value, - window_dimensions=mlir.dense_int_array_v6(window_dimensions), - window_strides=mlir.dense_int_array_v6(window_strides), + window_dimensions=mlir.dense_int_array(window_dimensions), + window_strides=mlir.dense_int_array(window_strides), padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64), shape=(len(padding), 2))) select = op.select.blocks.append(scalar_type, scalar_type) @@ -757,7 +901,9 @@ def snd(t, t_aval): double_word_out_aval = out_aval.update(dtype=double_word_dtype) def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]: - x, y = reducer.arguments + x: ir.Value + y: ir.Value + x, y = reducer.arguments # type: ignore assert select_prim is lax.ge_p or select_prim is lax.le_p cmp_op = "GE" if select_prim is lax.ge_p else "LE" out = hlo.SelectOp(mlir.compare_hlo(fst(x), fst(y), cmp_op), x, y) @@ -871,6 +1017,7 @@ def _select_and_gather_add_batching_rule( _select_and_gather_add_using_variadic_reducewindow, multiple_results=False)) + # TODO(b/183233858): use variadic reducewindow on GPU, when implemented. 
mlir.register_lowering( select_and_gather_add_p, diff --git a/jax/_src/lax_reference.py b/jax/_src/lax_reference.py index 8fc1756422d7..81209f6ed34a 100644 --- a/jax/_src/lax_reference.py +++ b/jax/_src/lax_reference.py @@ -239,6 +239,35 @@ def dot_general(lhs, rhs, dimension_numbers): dtype=dtype) return out.astype(dtypes.bfloat16) if lhs.dtype == dtypes.bfloat16 else out +def ragged_dot( + lhs, + rhs, + group_sizes, +): + """Reference ragged dot implementation.""" + m, lk = lhs.shape + group_count, rk, n = rhs.shape + assert lk == rk + assert group_count == group_sizes.shape[0] + assert lhs.dtype == rhs.dtype + + out = np.zeros((m, n), dtype=lhs.dtype) + result_iota = np.expand_dims(np.arange(out.shape[0]), list(range(1, out.ndim))) + start = 0 + for i, size in enumerate(group_sizes): + out += np.where( + np.logical_and(start <= result_iota, result_iota < (start + size)), + np.einsum( + "nk,km->nm", + lhs, + rhs[i, :, :], + dtype=np.float32 if lhs.dtype == dtypes.bfloat16 else None, + ), + np.zeros(out.shape, dtype=out.dtype), + ) + start += size + return out.astype(dtypes.bfloat16) if lhs.dtype == dtypes.bfloat16 else out + def broadcast(operand, sizes): return np.broadcast_to(operand, sizes + np.shape(operand)) diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 7b558794d5a2..3b4424345d00 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -14,46 +14,139 @@ from __future__ import annotations +from typing import Union + +import numpy as np +from jax._src.dtypes import iinfo, issubdtype +from jax._src.sharding import Sharding +from jax._src.sharding_impls import AUTO as AutoSharding, is_auto from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version -# TODO(yashkatariya): Revist the 3 class hierarchy after ifrt::Layout lands. -class Layout: - pass +class AutoLayout: + def __repr__(self): + return "AUTO" -class XLACompatibleLayout(Layout): - def _to_xla_layout(self) -> str: - raise NotImplementedError("Subclasses should implement this method.") +if xla_extension_version >= 274: + class DeviceLocalLayout: + major_to_minor: tuple[int, ...] + tiling: tuple[tuple[int, ...], ...] | None + AUTO = AutoLayout() -class SpecifiedLayout(XLACompatibleLayout): - layout: xc.Layout + def __init__(self, major_to_minor: tuple[int, ...], + tiling: tuple[tuple[int, ...], ...] 
| None = None): + self.major_to_minor = tuple(major_to_minor) + self.tiling = None if tiling is None else tuple(map(tuple, tiling)) - def __init__(self, layout: xc.Layout): - self._layout = layout - self._layout_str = self._layout.to_string() - self._minor_to_major = self._layout.minor_to_major() + @staticmethod + def from_pjrt_layout(pjrt_layout: xc.PjRtLayout): + xla_layout = pjrt_layout._xla_layout() + return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types + xla_layout.tiling()) - def __repr__(self): - return f'SpecifiedLayout({self._layout_str})' + def __repr__(self): + return (f'DeviceLocalLayout(major_to_minor={self.major_to_minor},' + f' tiling={self.tiling})') - def __hash__(self): - return hash(self._layout) + def __hash__(self): + return hash((self.major_to_minor, self.tiling)) - def __eq__(self, other): - if not isinstance(other, SpecifiedLayout): - return False - return self._layout == other._layout + def __eq__(self, other): + if not isinstance(other, DeviceLocalLayout): + return False + return (self.major_to_minor == other.major_to_minor and + self.tiling == other.tiling) + + def _to_xla_layout(self, dtype) -> str: + if self.tiling is None: + xla_layout = xc.Layout(self.major_to_minor[::-1]) + else: + if issubdtype(dtype, np.integer): + sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0 + else: + sub_byte_size = 0 + xla_layout = xc.Layout(self.major_to_minor[::-1], self.tiling, # type: ignore + sub_byte_size) + return str(xla_layout) +else: + class DeviceLocalLayout: # type: ignore + layout: xc.PjRtLayout + + AUTO = AutoLayout() + + def __init__(self, layout: xc.PjRtLayout): + self._layout = layout + self._layout_str = str(self._layout) + + @staticmethod + def from_pjrt_layout(pjrt_layout: xc.PjRtLayout): + return DeviceLocalLayout(pjrt_layout) # type: ignore - def _to_xla_layout(self) -> str: - return self._layout_str + def __repr__(self): + return f'DeviceLocalLayout({self._layout_str})' + def __hash__(self): + return hash(self._layout) -class LayoutRequest: + def __eq__(self, other): + if not isinstance(other, DeviceLocalLayout): + return False + return self._layout == other._layout + + def _to_xla_layout(self, dtype) -> str: + return self._layout_str + + +LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] # pytype: disable=invalid-annotation +ShardingOptions = Union[Sharding, None, AutoSharding] + + +class Layout: + __slots__ = ['device_local_layout', 'sharding'] + + def __init__(self, device_local_layout: LayoutOptions = None, + sharding: ShardingOptions = None): + # If layout is concrete and sharding is not, error. + if (isinstance(device_local_layout, DeviceLocalLayout) and + (sharding is None or is_auto(sharding))): + raise ValueError( + 'Sharding has to be concrete when layout is of type' + f' {type(device_local_layout)}. Please pass a' + ' `jax.sharding.NamedSharding`, `jax.sharding.PositionalSharding` or' + ' `jax.sharding.SingleDeviceSharding` to the sharding argument. Got' + f' sharding {sharding}' + ) + if not isinstance( + device_local_layout, (DeviceLocalLayout, type(None), AutoLayout)): + raise TypeError( + 'Invalid value received for the device_local_layout argument.' + ' Expected values are `None`, `DeviceLocalLayout.AUTO` or an' + f' instance of `DeviceLocalLayout`. Got {device_local_layout} of' + f' type {type(device_local_layout)}' + ) + if not isinstance( + sharding, (Sharding, type(None), AutoSharding)): + raise TypeError( + 'Invalid value received for the sharding argument. 
Expected values' + ' are `None`, `pjit.AUTO` or an instance of `jax.Sharding`. Got' + f' {sharding} of type {type(sharding)}') + + self.device_local_layout = device_local_layout + self.sharding = sharding def __repr__(self): - return "Request a layout from the compiler" + return (f'Layout(device_local_layout={self.device_local_layout},' + f' sharding={self.sharding})') + + def __hash__(self): + return hash((self.device_local_layout, self.sharding)) -AUTO = LayoutRequest() + def __eq__(self, other): + if not isinstance(other, Layout): + return False + return (self.device_local_layout == other.device_local_layout and + self.sharding == other.sharding) diff --git a/jax/_src/lazy_loader.py b/jax/_src/lazy_loader.py index 6041c77c65c0..cf6e68e49c81 100644 --- a/jax/_src/lazy_loader.py +++ b/jax/_src/lazy_loader.py @@ -14,9 +14,9 @@ """A LazyLoader class.""" -from collections.abc import Sequence +from collections.abc import Callable, Sequence import importlib -from typing import Any, Callable +from typing import Any def attach(package_name: str, submodules: Sequence[str]) -> tuple[ diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index e6c6f7e3f23d..c0d88759dcc0 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -31,6 +31,7 @@ py_library_providing_imports_info( "__init__.py", "mlir/__init__.py", "mlir/dialects/__init__.py", + "mosaic_gpu.py", "triton.py", ], lib_rule = pytype_strict_library, diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 2347cc63c81c..b2bcc53a53f8 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -87,8 +87,6 @@ def _parse_version(v: str) -> tuple[int, ...]: import jaxlib.xla_client as xla_client import jaxlib.lapack as lapack -import jaxlib.ducc_fft as ducc_fft - xla_extension = xla_client._xla pytree = xla_client._xla.pytree jax_jit = xla_client._xla.jax_jit @@ -102,7 +100,10 @@ def _xla_gc_callback(*args): try: import jaxlib.cuda._versions as cuda_versions # pytype: disable=import-error except ImportError: - cuda_versions = None + try: + import jax_cuda12_plugin._versions as cuda_versions # pytype: disable=import-error + except ImportError: + cuda_versions = None import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error @@ -132,11 +133,6 @@ def _cuda_path() -> str | None: # both of the things XLA looks for in the cuda path, namely bin/ptxas and # nvvm/libdevice/libdevice.10.bc path = _jaxlib_path.parent / "nvidia" / "cuda_nvcc" - if path.is_dir(): - return str(path) - # Failing that, we use the copy of libdevice.10.bc we include with jaxlib and - # hope that the user has ptxas in their PATH. 
- path = _jaxlib_path / "cuda" if path.is_dir(): return str(path) return None @@ -144,3 +140,5 @@ def _cuda_path() -> str | None: cuda_path = _cuda_path() transfer_guard_lib = xla_client._xla.transfer_guard_lib + +Device = xla_client._xla.Device diff --git a/jax/_src/lib/mlir/__init__.py b/jax/_src/lib/mlir/__init__.py index 215e3bb697e1..5fc9dff3ac49 100644 --- a/jax/_src/lib/mlir/__init__.py +++ b/jax/_src/lib/mlir/__init__.py @@ -16,9 +16,4 @@ import jaxlib.mlir.ir as ir import jaxlib.mlir.passmanager as passmanager - -# TODO(phawkins): make this unconditional after jaxlib 0.4.22 is the minimum -try: - from jaxlib.mlir._mlir_libs import register_jax_dialects # type: ignore -except ImportError: - register_jax_dialects = None +from jaxlib.mlir._mlir_libs import register_jax_dialects diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index ae47aacc9cb2..d46d65a6c97a 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -23,6 +23,15 @@ import jaxlib.mlir.dialects.scf as scf import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor import jaxlib.mlir.dialects.vector as vector +try: + # pytype: disable=import-error + import jaxlib.mlir.dialects.gpu as gpu + import jaxlib.mlir.dialects.nvgpu as nvgpu + import jaxlib.mlir.dialects.nvvm as nvvm + import jaxlib.mlir.dialects.llvm as llvm + # pytype: enable=import-error +except ImportError: + pass from jax._src import lib diff --git a/jax/_src/lib/mosaic_gpu.py b/jax/_src/lib/mosaic_gpu.py new file mode 100644 index 000000000000..494112093029 --- /dev/null +++ b/jax/_src/lib/mosaic_gpu.py @@ -0,0 +1,23 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +try: + try: + from jaxlib.mosaic.gpu import _mosaic_gpu_ext # pytype: disable=import-error + except ImportError: + from jax_cuda12_plugin import _mosaic_gpu_ext # pytype: disable=import-error +except ImportError as e: + raise ModuleNotFoundError("Failed to import the Mosaic GPU bindings") from e diff --git a/jax/_src/lib/triton.py b/jax/_src/lib/triton.py index dee3eb3febf3..c0a5202e9dbc 100644 --- a/jax/_src/lib/triton.py +++ b/jax/_src/lib/triton.py @@ -12,26 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -# ruff: noqa - -import sys - -# TODO(slebedev): Update the message to recommend jaxlib 0.4.25. -_ERROR = ( - "Cannot import the Triton bindings. You may need a newer version of" - " jaxlib. Try installing a nightly wheel following instructions in" - " https://jax.readthedocs.io/en/latest/installation.html#nightly-installation" -) - -try: - from jaxlib.triton import dialect # pytype: disable=import-error -except TypeError: - from jaxlib import version - - if sys.version_info[:2] == (3, 9) and version.__version_info__ < (0, 4, 25): - # Triton MLIR bindings are known to be broken on Python 3.9 in jaxlib - # prior to 0.4.25. 
- raise ModuleNotFoundError(_ERROR) from None - raise -except ImportError as e: - raise ModuleNotFoundError(_ERROR) from e +from jaxlib.triton import dialect # noqa: F401 # pytype: disable=import-error diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 8f431d160153..bc4cc242f055 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -63,15 +63,16 @@ def trans1(static_arg, *dynamic_args, **kwargs): """ from __future__ import annotations +from collections.abc import Callable from functools import partial -from typing import Any, Callable, NamedTuple +from typing import Any, NamedTuple import weakref from jax._src import config from jax._src import core from jax._src import traceback_util from jax._src.tree_util import tree_map -from jax._src.util import curry +from jax._src.util import curry, cache_clearing_funs traceback_util.register_exclusion(__file__) @@ -246,8 +247,9 @@ def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: return fun.wrap(gen, gen_static_args, None) @curry -def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args, - use_eq_store=False) -> tuple[WrappedFun, Any]: +def transformation_with_aux( + gen, fun: WrappedFun, *gen_static_args, use_eq_store: bool = False +) -> tuple[WrappedFun, Callable[[], Any]]: """Adds one more transformation with auxiliary output to a WrappedFun.""" out_store = Store() if not use_eq_store else EqualStore() out_thunk = lambda: out_store.val @@ -359,17 +361,9 @@ def _evict_function(f): memoized_fun.cache_clear = fun_caches.clear # type: ignore memoized_fun.evict_function = _evict_function # type: ignore - cache_clearing_funs.add(memoized_fun.cache_clear) - return memoized_fun -cache_clearing_funs = weakref.WeakSet() # type: ignore - -def clear_all_caches(): - global cache_clearing_funs - for clear in cache_clearing_funs: - clear() @partial(partial, tree_map) def _copy_main_traces(x): diff --git a/jax/_src/lru_cache.py b/jax/_src/lru_cache.py new file mode 100644 index 000000000000..3b1f9df07210 --- /dev/null +++ b/jax/_src/lru_cache.py @@ -0,0 +1,184 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import heapq +import logging +import pathlib +import warnings + +from jax._src.compilation_cache_interface import CacheInterface + + +try: + import filelock +except ImportError: + filelock = None + + +logger = logging.getLogger(__name__) + + +class LRUCache(CacheInterface): + """Bounded cache with least-recently-used (LRU) eviction policy. + + This implementation includes cache reading, writing and eviction + based on the LRU policy. + + Notably, when ``max_size`` is set to -1, the cache eviction + is disabled, and the LRU cache functions as a normal cache + without any size limitations. + """ + + def __init__(self, path: str, *, max_size: int, lock_timeout_secs: float | None = 10): + """Args: + + path: The path to the cache directory. + max_size: The maximum size of the cache in bytes. 
Caching will be + disabled if this value is set to ``0``. A special value of ``-1`` + indicates no limit, allowing the cache size to grow indefinitely. + lock_timeout_secs: (optional) The timeout for acquiring a file lock. + """ + # TODO(ayx): add support for cloud other filesystems such as GCS + if not self._is_local_filesystem(path): + raise NotImplementedError("LRUCache only supports local filesystem at this time.") + + self.path = pathlib.Path(path) + self.path.mkdir(parents=True, exist_ok=True) + + # TODO(ayx): having a `self._path` is required by the base class + # `CacheInterface`, but the base class can be removed after `LRUCache` + # and the original `GFileCache` are unified + self._path = self.path + + self.eviction_enabled = max_size != -1 # no eviction if `max_size` is set to -1 + + if self.eviction_enabled: + if filelock is None: + raise RuntimeError("Please install filelock package to set `jax_compilation_cache_max_size`") + + self.max_size = max_size + self.lock_timeout_secs = lock_timeout_secs + + self.lock_path = self.path / ".lockfile" + self.lock = filelock.FileLock(self.lock_path) + + def get(self, key: str) -> bytes | None: + """Retrieves the cached value for the given key. + + Args: + key: The key for which the cache value is retrieved. + + Returns: + The cached data as bytes if available; ``None`` otherwise. + """ + if not key: + raise ValueError("key cannot be empty") + + file = self.path / key + + if self.eviction_enabled: + self.lock.acquire(timeout=self.lock_timeout_secs) + + try: + if not file.exists(): + logger.debug(f"Cache miss for key: {key!r}") + return None + + logger.debug(f"Cache hit for key: {key!r}") + file.touch() # update mtime + return file.read_bytes() + + finally: + if self.eviction_enabled: + self.lock.release() + + def put(self, key: str, val: bytes) -> None: + """Adds a new entry to the cache. + + If a cache item with the same key already exists, no action + will be taken, even if the value is different. + + Args: + key: The key under which the data will be stored. + val: The data to be stored. + """ + if not key: + raise ValueError("key cannot be empty") + + # prevent adding entries that exceed the maximum size limit of the cache + if self.eviction_enabled and len(val) > self.max_size: + msg = (f"Cache value for key {key!r} of size {len(val)} bytes exceeds " + f"the maximum cache size of {self.max_size} bytes") + warnings.warn(msg) + return + + file = self.path / key + + if self.eviction_enabled: + self.lock.acquire(timeout=self.lock_timeout_secs) + + try: + if file.exists(): + return + + self._evict_if_needed(additional_size=len(val)) + file.write_bytes(val) + + finally: + if self.eviction_enabled: + self.lock.release() + + def _evict_if_needed(self, *, additional_size: int = 0) -> None: + """Evicts the least recently used items from the cache if necessary + to ensure the cache does not exceed its maximum size. + + Args: + additional_size: The size of the new entry being added to the cache. + This is included to account for the new entry when checking if + eviction is needed. 
+ """ + if not self.eviction_enabled: + return + + # a priority queue, each element is a tuple `(file_mtime, file, file_size)` + h: list[tuple[int, pathlib.Path, int]] = [] + dir_size = 0 + for file in self.path.iterdir(): + if file.is_file() and file != self.lock_path: + file_size = file.stat().st_size + file_mtime = file.stat().st_mtime_ns + + dir_size += file_size + heapq.heappush(h, (file_mtime, file, file_size)) + + target_size = self.max_size - additional_size + # evict files until the directory size is less than or equal + # to `target_size` + while dir_size > target_size: + file_mtime, file, file_size = heapq.heappop(h) + msg = (f"Evicting cache file {file.name}: file size {file_size} bytes, " + f"target cache size {target_size} bytes") + logger.debug(msg) + file.unlink() + dir_size -= file_size + + # See comments in `jax.src.compilation_cache.get_file_cache()` for details. + # TODO(ayx): This function has a duplicate in that place, and there is + # redundancy here. However, this code is temporary, and once the issue + # is fixed, this code can be removed. + @staticmethod + def _is_local_filesystem(path: str) -> bool: + return path.startswith("file://") or "://" not in path diff --git a/jax/_src/maps.py b/jax/_src/maps.py index bde12f3fcdaf..4b574775dc70 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -15,12 +15,12 @@ from __future__ import annotations from collections import OrderedDict, abc -from collections.abc import Iterable, Sequence, Mapping +from collections.abc import Callable, Iterable, Sequence, Mapping import contextlib from functools import wraps, partial, partialmethod, lru_cache import itertools as it import math -from typing import Callable, Any, NamedTuple, Union +from typing import Any, NamedTuple, Union, cast as type_cast import numpy as np @@ -62,7 +62,7 @@ from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3, as_hashable_function, distributed_debug_log, tuple_insert, moveaxis, split_list, wrap_name, - merge_lists, partition_list) + merge_lists, partition_list, fun_name) source_info_util.register_exclusion(__file__) traceback_util.register_exclusion(__file__) @@ -116,7 +116,7 @@ class SerialLoop: jointly over chunks of multiple axes (with the usual requirement that they do not coincide in a named shape of any value in the program). - Example:: + Examples: # Processes `x` in a vectorized way, but in 20 micro-batches. xmap(f, in_axes=['i'], out_axes=[i], axis_resources={'i': SerialLoop(20)})(x) @@ -161,7 +161,7 @@ def serial_loop(name: ResourceAxisName, length: int): name: Name of the loop in the resource environment. length: Number of iterations. - Example:: + Examples: >>> x = jnp.linspace(0, jnp.pi, 4) ... @@ -285,6 +285,13 @@ def xmap(fun: Callable, backend: str | None = None) -> stages.Wrapped: """Assign a positional signature to a program that uses named array axes. + .. warning:: + xmap is deprecated and will be removed in a future release. Use + :py:func:`~jax.shard_map` or :py:func:`~jax.vmap` with the + ``spmd_axis_name`` argument for expressing SPMD device-parallel + computations. Please file an issue on https://github.com/google/jax/issues + if neither are suitable for your use case. + .. warning:: This is an experimental feature and the details can change at any time. Use at your own risk! 
@@ -527,7 +534,7 @@ def infer_params(*args): args_flat, in_tree = tree_flatten(args) fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) if donate_argnums: - donated_invars = donation_vector(donate_argnums, (), args, {}) + donated_invars = donation_vector(donate_argnums, (), in_tree, kws=False) else: donated_invars = (False,) * len(args_flat) in_axes_flat = _flatten_axes("xmap in_axes", in_tree, in_axes, tupled_args=True) @@ -570,7 +577,7 @@ def infer_params(*args): in_axes_flat, args_flat) params = dict( - name=getattr(fun, '__name__', ''), + name=fun_name(fun), in_axes=tuple(in_axes_flat), out_axes_thunk=out_axes_thunk, donated_invars=donated_invars, @@ -620,21 +627,21 @@ def lower(*args, **kwargs): in_tree = treedef_tuple([in_tree, tree_flatten({})[1]]) in_avals = in_tree.unflatten(avals_flat) return stages.Lowered.from_flat_info( - computation, in_tree, in_avals, donate_argnums, out_tree(), # type: ignore - no_kwargs=True) + computation, in_tree, in_avals, donate_argnums, out_tree()) fun_mapped.lower = lower - return fun_mapped + return type_cast(stages.Wrapped, fun_mapped) def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk): in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args] - xmap_callable = make_xmap_callable( + computation = make_xmap_callable( fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk, - mlir.LoweringParameters(), *in_avals).compile().unsafe_call + mlir.LoweringParameters(), *in_avals) + xmap_callable = computation.compile().unsafe_call distributed_debug_log(("Running xmapped function", name), ("python function", fun.f), ("mesh", resource_env.physical_mesh), @@ -700,16 +707,17 @@ def make_xmap_callable(fun: lu.WrappedFun, f, 'xmap', name, mesh, in_shardings, out_shardings, donated_invars, use_spmd_lowering, in_avals, - tiling_method=tiling_method, + tiling_method=tiling_method, lowering_platforms=None, lowering_parameters=lowering_parameters) else: jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals) return pxla.lower_sharding_computation( core.ClosedJaxpr(jaxpr, consts), 'jit', name, (UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals), - donated_invars, in_avals, keep_unused=True, inline=False, - devices_from_context=None, lowering_parameters=lowering_parameters, - in_layouts=(None,) * len(in_avals), out_layouts=(None,) * len(out_avals)) + (None,) * len(in_avals), (None,) * len(out_avals), + donated_invars, keep_unused=True, inline=False, + devices_from_context=None, lowering_platforms=None, + lowering_parameters=lowering_parameters, pgle_profiler=None) class EvaluationPlan(NamedTuple): @@ -1808,9 +1816,10 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None): [tmpvar], [outvar], sharding_constraint_p, dict(resource_env=resource_env, sharding=gspmd_sharding, + layout=None, unconstrained_dims=unconstrained_dims), set(), - eqn.source_info)) + eqn.source_info, eqn.ctx)) return jaxpr.replace(eqns=new_eqns) def _flatten_axes(what, tree, axes, tupled_args): @@ -1832,7 +1841,7 @@ class NoQuotesStr(str): def _thread_local_flag_unsupported(_): raise RuntimeError("thread-local xmap flags not supported!") def _clear_compilation_cache(_): - make_xmap_callable.cache_clear() # type: ignore + make_xmap_callable.cache_clear() # pytype: disable=attribute-error def _ensure_spmd_and(f): 
def update(v): @@ -1842,7 +1851,7 @@ def update(v): return update -SPMD_LOWERING = config.define_bool_state( +SPMD_LOWERING = config.bool_state( name="experimental_xmap_spmd_lowering", default=False, help=("When set, multi-device xmap computations will be compiled through " @@ -1850,7 +1859,7 @@ def update(v): "Not supported on CPU!"), update_global_hook=_clear_compilation_cache, update_thread_local_hook=_thread_local_flag_unsupported) -SPMD_LOWERING_MANUAL = config.define_bool_state( +SPMD_LOWERING_MANUAL = config.bool_state( name="experimental_xmap_spmd_lowering_manual", default=False, help=("When set, multi-device xmap computations will be compiled using " @@ -1859,7 +1868,7 @@ def update(v): "Requires experimental_xmap_spmd_lowering!"), update_global_hook=_ensure_spmd_and(_clear_compilation_cache), update_thread_local_hook=_thread_local_flag_unsupported) -_ENSURE_FIXED_SHARDING = config.define_bool_state( +_ENSURE_FIXED_SHARDING = config.bool_state( name="experimental_xmap_ensure_fixed_sharding", default=False, help=("When set and `experimental_xmap_spmd_lowering` is enabled, the lowering will " diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 7339bd6cf7c1..32138678561f 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -91,7 +91,7 @@ def __repr__(self): return f"ResourceEnv(mesh=Mesh({mesh_repr}), {self.loops!r})" -@functools.lru_cache(maxsize=128) +@util.cache(max_size=128, trace_context_in_key=False) def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh: if global_mesh.empty: return global_mesh @@ -144,7 +144,7 @@ class Mesh(contextlib.ContextDecorator): dimensions of the ``devices`` argument. Its length should match the rank of ``devices``. - Example: + Examples: >>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 8e232fcc6d8e..822fb548ed90 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -20,6 +20,7 @@ import operator import numpy as np from typing import Any +import warnings import jax import jax.numpy as jnp @@ -35,6 +36,12 @@ from jax._src.ops.special import logsumexp as _logsumexp +class Unspecified: + def __repr__(self): + return "_UNSPECIFIED" +_UNSPECIFIED = Unspecified() + + # activations @custom_jvp @@ -62,7 +69,7 @@ def relu(x: ArrayLike) -> Array: Returns: An array. - Example: + Examples: >>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.])) Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32) @@ -110,6 +117,32 @@ def softplus(x: ArrayLike) -> Array: """ return jnp.logaddexp(x, 0) +@jax.jit +def sparse_plus(x: ArrayLike) -> Array: + r"""Sparse plus function. + + Computes the function: + + .. math:: + + \mathrm{sparse\_plus}(x) = \begin{cases} + 0, & x \leq -1\\ + \frac{1}{4}(x+1)^2, & -1 < x < 1 \\ + x, & 1 \leq x + \end{cases} + + This is the twin function of the softplus activation ensuring a zero output + for inputs less than -1 and a linear output for inputs greater than 1, + while remaining smooth, convex, monotonic by an adequate definition between + -1 and 1. + + Args: + x: input (float) + """ + numpy_util.check_arraylike("sparse_plus", x) + x = jnp.asarray(x) + return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4)) + @jax.jit def soft_sign(x: ArrayLike) -> Array: r"""Soft-sign activation function. @@ -147,9 +180,41 @@ def sigmoid(x: ArrayLike) -> Array: """ return lax.logistic(x) +@jax.jit +def sparse_sigmoid(x: ArrayLike) -> Array: + r"""Sparse sigmoid activation function. 
+ + Computes the function: + + .. math:: + + \mathrm{sparse\_sigmoid}(x) = \begin{cases} + 0, & x \leq -1\\ + \frac{1}{2}(x+1), & -1 < x < 1 \\ + 1, & 1 \leq x + \end{cases} + + This is the twin function of the ``sigmoid`` activation ensuring a zero output + for inputs less than -1, a 1 output for inputs greater than 1, and a linear + output for inputs between -1 and 1. It is the derivative of ``sparse_plus``. + + For more information, see `Learning with Fenchel-Young Losses (section 6.2) + `_. + + Args: + x : input array + + Returns: + An array. + + See also: + :func:`sigmoid` + """ + return 0.5 * jnp.clip(x + 1.0, 0.0, 2.0) + @jax.jit def silu(x: ArrayLike) -> Array: - r"""SiLU (a.k.a. swish) activation function. + r"""SiLU (aka swish) activation function. Computes the element-wise function: @@ -173,6 +238,29 @@ def silu(x: ArrayLike) -> Array: swish = silu +@jax.jit +def mish(x: ArrayLike) -> Array: + r"""Mish activation function. + + Computes the element-wise function: + + .. math:: + \mathrm{mish}(x) = x \cdot \mathrm{tanh}(\mathrm{softplus}(x)) + + For more information, see + `Mish: A Self Regularized Non-Monotonic Activation Function + `_. + + Args: + x : input array + + Returns: + An array. + """ + numpy_util.check_arraylike("mish", x) + x_arr = jnp.asarray(x) + return x_arr * jnp.tanh(softplus(x_arr)) + @jax.jit def log_sigmoid(x: ArrayLike) -> Array: r"""Log-sigmoid activation function. @@ -288,7 +376,7 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array: For more information, see `Continuously Differentiable Exponential Linear Units - `_. + `_. Args: x : input array @@ -316,7 +404,7 @@ def selu(x: ArrayLike) -> Array: For more information, see `Self-Normalizing Neural Networks - `_. + `_. Args: x : input array @@ -405,7 +493,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Array: def log_softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, where: ArrayLike | None = None, - initial: ArrayLike | None = None) -> Array: + initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array: r"""Log-Softmax function. Computes the logarithm of the :code:`softmax` function, which rescales @@ -420,19 +508,26 @@ def log_softmax(x: ArrayLike, axis: the axis or axes along which the :code:`log_softmax` should be computed. Either an integer or a tuple of integers. where: Elements to include in the :code:`log_softmax`. - initial: The minimum value used to shift the input array. Must be present - when :code:`where` is not None. Returns: An array. + Note: + If any input values are ``+inf``, the result will be all ``NaN``: this reflects the + fact that ``inf / inf`` is not well-defined in the context of floating-point math. + See also: :func:`softmax` """ + if initial is not _UNSPECIFIED: + # Added 2024-4-10 + warnings.warn("The initial argument to log_softmax is deprecated, and no longer has any effect.", + DeprecationWarning, stacklevel=2) + del initial numpy_util.check_arraylike("log_softmax", x) x_arr = jnp.asarray(x) - x_max = jnp.max(x_arr, axis, where=where, initial=initial, keepdims=True) - x_safe = x_arr if where is None else jnp.where(where, x_arr, initial) + x_max = jnp.max(x_arr, axis, where=where, initial=-jnp.inf, keepdims=True) + x_safe = x_arr if where is None else jnp.where(where, x_arr, -jnp.inf) shifted = x_safe - lax.stop_gradient(x_max) shifted_logsumexp = jnp.log( jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True)) @@ -447,7 +542,7 @@ def log_softmax(x: ArrayLike, def softmax(x: ArrayLike, axis: int | tuple[int, ...] 
| None = -1, where: ArrayLike | None = None, - initial: ArrayLike | None = None) -> Array: + initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array: r"""Softmax function. Computes the function which rescales elements to the range :math:`[0, 1]` @@ -462,22 +557,29 @@ def softmax(x: ArrayLike, softmax output summed across these dimensions should sum to :math:`1`. Either an integer or a tuple of integers. where: Elements to include in the :code:`softmax`. - initial: The minimum value used to shift the input array. Must be present - when :code:`where` is not None. Returns: An array. + Note: + If any input values are ``+inf``, the result will be all ``NaN``: this reflects the + fact that ``inf / inf`` is not well-defined in the context of floating-point math. + See also: :func:`log_softmax` """ + if initial is not _UNSPECIFIED: + # Added 2024-4-10 + warnings.warn("The initial argument to softmax is deprecated, and no longer has any effect.", + DeprecationWarning, stacklevel=2) + del initial if config.softmax_custom_jvp.value: # mypy is confused by the `functools.partial` application in the definition # of `_softmax` and incorrectly concludes that `_softmax` returns # `ReturnValue` -- the unsubstituted type parameter of `custom_jvp`. - return _softmax(x, axis, where, initial) # type: ignore[return-value] + return _softmax(x, axis, where) else: - return _softmax_deprecated(x, axis, where, initial) + return _softmax_deprecated(x, axis, where) # TODO(mattjj): replace softmax with _softmax when deprecation flag is removed @partial(jax.custom_jvp, nondiff_argnums=(1,)) @@ -485,7 +587,7 @@ def _softmax( x: ArrayLike, axis: int | tuple[int, ...] | None = -1, where: ArrayLike | None = None, - initial: ArrayLike | None = None) -> Array: + initial: ArrayLike | None = -jnp.inf) -> Array: x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True) x_safe = x if where is None else jnp.where(where, x, initial) unnormalized = jnp.exp(x_safe - x_max) @@ -504,7 +606,7 @@ def _softmax_deprecated( x: ArrayLike, axis: int | tuple[int, ...] | None = -1, where: ArrayLike | None = None, - initial: ArrayLike | None = None) -> Array: + initial: ArrayLike | None = -jnp.inf) -> Array: x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True) x_safe = x if where is None else jnp.where(where, x, initial) unnormalized = jnp.exp(x_safe - lax.stop_gradient(x_max)) diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 5468fd663181..cf245f7927be 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -62,7 +62,7 @@ def zeros(key: KeyArray, The ``key`` argument is ignored. >>> import jax, jax.numpy as jnp - >>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32) + >>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32) """ @@ -77,7 +77,7 @@ def ones(key: KeyArray, The ``key`` argument is ignored. 
>>> import jax, jax.numpy as jnp - >>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32) + >>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], [1., 1.], [1., 1.]], dtype=float32) @@ -96,7 +96,7 @@ def constant(value: ArrayLike, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.constant(-7) - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[-7., -7., -7.], [-7., -7., -7.]], dtype=float32) """ @@ -122,7 +122,7 @@ def uniform(scale: RealNumeric = 1e-2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.uniform(10.0) - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[7.298188 , 8.691938 , 8.7230015], [2.0818567, 1.8662417, 5.5022564]], dtype=float32) """ @@ -148,7 +148,7 @@ def normal(stddev: RealNumeric = 1e-2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.normal(5.0) - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 3.0613258 , 5.6129413 , 5.6866574 ], [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32) """ @@ -372,11 +372,11 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2, Returns: An initializer. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_uniform() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.50350785, 0.8088631 , 0.81566876], [-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32) @@ -410,11 +410,11 @@ def glorot_normal(in_axis: int | Sequence[int] = -2, Returns: An initializer. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_normal() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.41770416, 0.75262755, 0.7619329 ], [-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32) @@ -448,11 +448,11 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2, Returns: An initializer. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.lecun_uniform() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.56293887, 0.90433645, 0.9119454 ], [-0.71479625, -0.7676109 , 0.12302713]], dtype=float32) @@ -484,11 +484,11 @@ def lecun_normal(in_axis: int | Sequence[int] = -2, Returns: An initializer. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.lecun_normal() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.46700746, 0.8414632 , 0.8518669 ], [-0.61677957, -0.67402434, 0.09683388]], dtype=float32) @@ -520,11 +520,11 @@ def he_uniform(in_axis: int | Sequence[int] = -2, Returns: An initializer. 
- Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_uniform() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.79611576, 1.2789248 , 1.2896855 ], [-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32) @@ -558,11 +558,11 @@ def he_normal(in_axis: int | Sequence[int] = -2, Returns: An initializer. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_normal() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.6604483 , 1.1900088 , 1.2047218 ], [-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32) @@ -591,11 +591,11 @@ def orthogonal(scale: RealNumeric = 1.0, Returns: An orthogonal initializer. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.orthogonal() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01], [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32) """ @@ -634,11 +634,11 @@ def delta_orthogonal( A `delta orthogonal initializer`_. The shape passed to the initializer must be 3D, 4D, or 5D. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.delta_orthogonal() - >>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32) # doctest: +SKIP Array([[[ 0. , 0. , 0. ], [ 0. , 0. , 0. ], [ 0. , 0. , 0. ]], diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index fb1e52bd1be9..1d27c4b3aa28 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -31,11 +31,13 @@ import numpy as np import jax from jax import lax +from jax.sharding import Sharding from jax._src import core from jax._src import dtypes from jax._src.api_util import _ensure_index_tuple from jax._src.array import ArrayImpl from jax._src.lax import lax as lax_internal +from jax._src.lib import xla_client as xc from jax._src.numpy import lax_numpy from jax._src.numpy import reductions from jax._src.numpy import ufuncs @@ -55,7 +57,7 @@ # functions, which can themselves handle instances from any of these classes. -def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: +def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: """Copy the array and cast to a specified dtype. This is implemented via :func:`jax.lax.convert_element_type`, which may @@ -63,7 +65,7 @@ def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent. """ - return lax_numpy.astype(arr, dtype) + return lax_numpy.astype(arr, dtype, copy=copy, device=device) def _nbytes(arr: ArrayLike) -> int: @@ -84,12 +86,11 @@ def _itemsize(arr: ArrayLike) -> int: def _clip(number: ArrayLike, - min: ArrayLike | None = None, max: ArrayLike | None = None, - out: None = None) -> Array: + min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: """Return an array whose values are limited to a specified range. 
Refer to :func:`jax.numpy.clip` for full documentation.""" - return lax_numpy.clip(number, a_min=min, a_max=max, out=out) + return lax_numpy.clip(number, min=min, max=max) def _transpose(a: Array, *args: Any) -> Array: @@ -303,11 +304,13 @@ def __array_module__(self, types): def _compress_method(a: ArrayLike, condition: ArrayLike, - axis: int | None = None, out: None = None) -> Array: + axis: int | None = None, *, out: None = None, + size: int | None = None, fill_value: ArrayLike = 0) -> Array: """Return selected slices of this array along given axis. Refer to :func:`jax.numpy.compress` for full documentation.""" - return lax_numpy.jaxcompress(condition, a, axis, out) + return lax_numpy.compress(condition, a, axis=axis, out=out, + size=size, fill_value=fill_value) @core.stash_axis_env() diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3d621b6e3cb0..d8723a01d42e 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -27,13 +27,14 @@ import builtins import collections -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial +import importlib import math import operator import types -from typing import (overload, Any, Callable, Literal, NamedTuple, Protocol, - TypeVar, Union) +from typing import (cast, overload, Any, Literal, NamedTuple, + Protocol, TypeVar, Union) from textwrap import dedent as _dedent import warnings @@ -61,17 +62,30 @@ from jax._src.lax.lax import (_array_copy, _sort_lt_comparator, _sort_le_comparator, PrecisionLike) from jax._src.lax import lax as lax_internal -from jax._src.lib import xla_client as xc, xla_extension_version +from jax._src.lib import xla_client as xc from jax._src.numpy import reductions from jax._src.numpy import ufuncs from jax._src.numpy import util from jax._src.numpy.vectorize import vectorize -from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape +from jax._src.typing import ( + Array, ArrayLike, DeprecatedArg, DimSize, DuckTypedArray, + DType, DTypeLike, Shape, StaticScalar, +) from jax._src.util import (unzip2, subvals, safe_zip, ceil_of_ratio, partition_list, canonicalize_axis as _canonicalize_axis, NumpyComplexWarning) +for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']: + try: + cuda_plugin_extension = importlib.import_module( + f'{pkg_name}.cuda_plugin_extension' + ) + except ImportError: + cuda_plugin_extension = None # type: ignore + else: + break + newaxis = None T = TypeVar('T') @@ -81,9 +95,9 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape: if (not isinstance(shape, (tuple, list)) and (getattr(shape, 'ndim', None) == 0 or ndim(shape) == 0)): - return core.canonicalize_shape((shape,), context) # type: ignore + return core.canonicalize_shape((shape,), context) else: - return core.canonicalize_shape(shape, context) # type: ignore + return core.canonicalize_shape(shape, context) # Common docstring additions: @@ -113,8 +127,37 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape: printoptions = np.printoptions set_printoptions = np.set_printoptions -@util.implements(np.iscomplexobj) def iscomplexobj(x: Any) -> bool: + """Check if the input is a complex number or an array containing complex elements. + + JAX implementation of :func:`numpy.iscomplexobj`. + + The function evaluates based on input type rather than value. + Inputs with zero imaginary parts are still considered complex. + + Args: + x: input object to check. 
+ + Returns: + True if ``x`` is a complex number or an array containing at least one complex element, + False otherwise. + + See Also: + - :func:`jax.numpy.isrealobj` + - :func:`jax.numpy.iscomplex` + + Examples: + >>> jnp.iscomplexobj(True) + False + >>> jnp.iscomplexobj(0) + False + >>> jnp.iscomplexobj(jnp.array([1, 2])) + False + >>> jnp.iscomplexobj(1+2j) + True + >>> jnp.iscomplexobj(jnp.array([0, 1+2j])) + True + """ if x is None: return False try: @@ -333,6 +376,8 @@ def result_type(*args: Any) -> DType: @jit def trunc(x: ArrayLike) -> Array: util.check_arraylike('trunc', x) + if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + return lax_internal.asarray(x) return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x)) @@ -383,23 +428,161 @@ def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike, return result[0, 0, out_order] -@util.implements(np.convolve, lax_description=_PRECISION_DOC, - extra_params=_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: + r"""Convolution of two one dimensional arrays. + + JAX implementation of :func:`numpy.convolve`. + + Convolution of one dimensional arrays is defined as: + + .. math:: + + c_k = \sum_j a_{k - j} v_j + + Args: + a: left-hand input to the convolution. Must have ``a.ndim == 1``. + v: right-hand input to the convolution. Must have ``v.ndim == 1``. + mode: controls the size of the output. Available operations are: + + * ``"full"``: (default) output the full convolution of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``a``. + * ``"valid"``: return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + precision: Specify the precision of the computation. Refer to + :class:`jax.lax.Precision` for a description of available values. + + preferred_element_type: A datatype, indicating to accumulate results to and + return a result with that datatype. Default is ``None``, which means the + default accumulation type for the input types. + + Returns: + Array containing the convolved result. + + See Also: + - :func:`jax.scipy.signal.convolve`: ND convolution + - :func:`jax.numpy.correlate`: 1D correlation + + Examples: + A few 1D convolution examples: + + >>> x = jnp.array([1, 2, 3, 2, 1]) + >>> y = jnp.array([4, 1, 2]) + + ``jax.numpy.convolve``, by default, returns full convolution using implicit + zero-padding at the edges: + + >>> jnp.convolve(x, y) + Array([ 4., 9., 16., 15., 12., 5., 2.], dtype=float32) + + Specifying ``mode = 'same'`` returns a centered convolution the same size + as the first input: + + >>> jnp.convolve(x, y, mode='same') + Array([ 9., 16., 15., 12., 5.], dtype=float32) + + Specifying ``mode = 'valid'`` returns only the portion where the two arrays + fully overlap: + + >>> jnp.convolve(x, y, mode='valid') + Array([16., 15., 12.], dtype=float32) + + For complex-valued inputs: + + >>> x1 = jnp.array([3+1j, 2, 4-3j]) + >>> y1 = jnp.array([1, 2-3j, 4+5j]) + >>> jnp.convolve(x1, y1) + Array([ 3. +1.j, 11. -7.j, 15.+10.j, 7. -8.j, 31. 
+8.j], dtype=complex64) + """ util.check_arraylike("convolve", a, v) return _conv(asarray(a), asarray(v), mode=mode, op='convolve', precision=precision, preferred_element_type=preferred_element_type) -@util.implements(np.correlate, lax_description=_PRECISION_DOC, - extra_params=_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: + r"""Correlation of two one dimensional arrays. + + JAX implementation of :func:`numpy.correlate`. + + Correlation of one dimensional arrays is defined as: + + .. math:: + + c_k = \sum_j a_{k + j} \overline{v_j} + + where :math:`\overline{v_j}` is the complex conjugate of :math:`v_j`. + + Args: + a: left-hand input to the correlation. Must have ``a.ndim == 1``. + v: right-hand input to the correlation. Must have ``v.ndim == 1``. + mode: controls the size of the output. Available operations are: + + * ``"full"``: output the full correlation of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``a``. + * ``"valid"``: (default) return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + precision: Specify the precision of the computation. Refer to + :class:`jax.lax.Precision` for a description of available values. + + preferred_element_type: A datatype, indicating to accumulate results to and + return a result with that datatype. Default is ``None``, which means the + default accumulation type for the input types. + + Returns: + Array containing the cross-correlation result. + + See Also: + - :func:`jax.scipy.signal.correlate`: ND correlation + - :func:`jax.numpy.convolve`: 1D convolution + + Examples: + >>> x = jnp.array([1, 2, 3, 2, 1]) + >>> y = jnp.array([4, 5, 6]) + + Since default ``mode = 'valid'``, ``jax.numpy.correlate`` returns only the + portion of correlation where the two arrays fully overlap: + + >>> jnp.correlate(x, y) + Array([32., 35., 28.], dtype=float32) + + Specifying ``mode = 'full'`` returns full correlation using implicit + zero-padding at the edges. + + >>> jnp.correlate(x, y, mode='full') + Array([ 6., 17., 32., 35., 28., 13., 4.], dtype=float32) + + Specifying ``mode = 'same'`` returns a centered correlation the same size + as the first input: + + >>> jnp.correlate(x, y, mode='same') + Array([17., 32., 35., 28., 13.], dtype=float32) + + If both the inputs arrays are real-valued and symmetric then the result will + also be symmetric and will be equal to the result of ``jax.numpy.convolve``. + + >>> x1 = jnp.array([1, 2, 3, 2, 1]) + >>> y1 = jnp.array([4, 5, 4]) + >>> jnp.correlate(x1, y1, mode='full') + Array([ 4., 13., 26., 31., 26., 13., 4.], dtype=float32) + >>> jnp.convolve(x1, y1, mode='full') + Array([ 4., 13., 26., 31., 26., 13., 4.], dtype=float32) + + For complex-valued inputs: + + >>> x2 = jnp.array([3+1j, 2, 2-3j]) + >>> y2 = jnp.array([4, 2-5j, 1]) + >>> jnp.correlate(x2, y2, mode='full') + Array([ 3. +1.j, 3.+17.j, 18.+11.j, 27. +4.j, 8.-12.j], dtype=complex64) + """ util.check_arraylike("correlate", a, v) return _conv(asarray(a), asarray(v), mode=mode, op='correlate', precision=precision, preferred_element_type=preferred_element_type) @@ -538,8 +721,78 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, view of the input. 
""" -@util.implements(np.transpose, lax_description=_ARRAY_VIEW_DOC) def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: + """Return a transposed version of an N-dimensional array. + + JAX implementation of :func:`numpy.transpose`, implemented in terms of + :func:`jax.lax.transpose`. + + Args: + a: input array + axes: optionally specify the permutation using a length-`a.ndim` sequence of integers + ``i`` satisfying ``0 <= i < a.ndim``. Defaults to ``range(a.ndim)[::-1]``, i.e + reverses the order of all axes. + + Returns: + transposed copy of the array. + + See Also: + - :func:`jax.Array.transpose`: equivalent function via an :class:`~jax.Array` method. + - :attr:`jax.Array.T`: equivalent function via an :class:`~jax.Array` property. + - :func:`jax.numpy.matrix_transpose`: transpose the last two axes of an array. This is + suitable for working with batched 2D matrices. + - :func:`jax.numpy.swapaxes`: swap any two axes in an array. + - :func:`jax.numpy.moveaxis`: move an axis to another position in the array. + + Note: + Unlike :func:`numpy.transpose`, :func:`jax.numpy.transpose` will return a copy rather + than a view of the input array. However, under JIT, the compiler will optimize-away + such copies when possible, so this doesn't have performance impacts in practice. + + Examples: + For a 1D array, the transpose is the identity: + + >>> x = jnp.array([1, 2, 3, 4]) + >>> jnp.transpose(x) + Array([1, 2, 3, 4], dtype=int32) + + For a 2D array, the transpose is a matrix transpose: + + >>> x = jnp.array([[1, 2], + ... [3, 4]]) + >>> jnp.transpose(x) + Array([[1, 3], + [2, 4]], dtype=int32) + + For an N-dimensional array, the transpose reverses the order of the axes: + + >>> x = jnp.zeros(shape=(3, 4, 5)) + >>> jnp.transpose(x).shape + (5, 4, 3) + + The ``axes`` argument can be specified to change this default behavior: + + >>> jnp.transpose(x, (0, 2, 1)).shape + (3, 5, 4) + + Since swapping the last two axes is a common operation, it can be done + via its own API, :func:`jax.numpy.matrix_transpose`: + + >>> jnp.matrix_transpose(x).shape + (3, 5, 4) + + For convenience, transposes may also be performed using the :meth:`jax.Array.transpose` + method or the :attr:`jax.Array.T` property: + + >>> x = jnp.array([[1, 2], + ... [3, 4]]) + >>> x.transpose() + Array([[1, 3], + [2, 4]], dtype=int32) + >>> x.T + Array([[1, 3], + [2, 4]], dtype=int32) + """ util.check_arraylike("transpose", a) axes_ = list(range(ndim(a))[::-1]) if axes is None else axes axes_ = [_canonicalize_axis(i, ndim(a)) for i in axes_] @@ -552,19 +805,50 @@ def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: return lax.transpose(a, axes) -@util.implements(getattr(np, 'matrix_transpose', None)) def matrix_transpose(x: ArrayLike, /) -> Array: - """Transposes the last two dimensions of x. + """Transpose the last two dimensions of an array. + + JAX implementation of :func:`numpy.matrix_transpose`, implemented in terms of + :func:`jax.lax.transpose`. - Parameters - ---------- - x : array_like - Input array. Must have ``x.ndim >= 2``. + Args: + x: input array, Must have ``x.ndim >= 2`` - Returns - ------- - xT : Array - Transposed array. + Returns: + matrix-transposed copy of the array. + + See Also: + - :attr:`jax.Array.mT`: same operation accessed via an :func:`~jax.Array` property. + - :func:`jax.numpy.transpose`: general multi-axis transpose + + Note: + Unlike :func:`numpy.matrix_transpose`, :func:`jax.numpy.matrix_transpose` will return a + copy rather than a view of the input array. 
However, under JIT, the compiler will + optimize-away such copies when possible, so this doesn't have performance impacts in practice. + + Examples: + Here is a 2x2x2 matrix representing a batched 2x2 matrix: + + >>> x = jnp.array([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> jnp.matrix_transpose(x) + Array([[[1, 3], + [2, 4]], + + [[5, 7], + [6, 8]]], dtype=int32) + + For convenience, you can perform the same transpose via the :attr:`~jax.Array.mT` + property of :class:`jax.Array`: + + >>> x.mT + Array([[[1, 3], + [2, 4]], + + [[5, 7], + [6, 8]]], dtype=int32) """ util.check_arraylike("matrix_transpose", x) ndim = np.ndim(x) @@ -601,8 +885,62 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: return flip(transpose(m, perm), ax2) -@util.implements(np.flip, lax_description=_ARRAY_VIEW_DOC) def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: + """Reverse the order of elements of an array along the given axis. + + JAX implementation of :func:`numpy.flip`. + + Args: + m: Array. + axis: integer or sequence of integers. Specifies along which axis or axes + should the array elements be reversed. Default is ``None``, which flips + along all axes. + + Returns: + An array with the elements in reverse order along ``axis``. + + See Also: + - :func:`jax.numpy.fliplr`: reverse the order along axis 1 (left/right) + - :func:`jax.numpy.flipud`: reverse the order along axis 0 (up/down) + + Examples: + >>> x1 = jnp.array([[1, 2], + ... [3, 4]]) + >>> jnp.flip(x1) + Array([[4, 3], + [2, 1]], dtype=int32) + + If ``axis`` is specified with an integer, then ``jax.numpy.flip`` reverses + the array along that particular axis only. + + >>> jnp.flip(x1, axis=1) + Array([[2, 1], + [4, 3]], dtype=int32) + + >>> x2 = jnp.arange(1, 9).reshape(2, 2, 2) + >>> x2 + Array([[[1, 2], + [3, 4]], + + [[5, 6], + [7, 8]]], dtype=int32) + >>> jnp.flip(x2) + Array([[[8, 7], + [6, 5]], + + [[4, 3], + [2, 1]]], dtype=int32) + + When ``axis`` is specified with a sequence of integers, then + ``jax.numpy.flip`` reverses the array along the specified axes. + + >>> jnp.flip(x2, axis=[1, 2]) + Array([[[4, 3], + [2, 1]], + + [[8, 7], + [6, 5]]], dtype=int32) + """ util.check_arraylike("flip", m) return _flip(asarray(m), reductions._ensure_optional_axes(axis)) @@ -614,32 +952,143 @@ def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array: return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis]) -@util.implements(np.fliplr, lax_description=_ARRAY_VIEW_DOC) def fliplr(m: ArrayLike) -> Array: + """Reverse the order of elements of an array along axis 1. + + JAX implementation of :func:`numpy.fliplr`. + + Args: + m: Array with at least two dimensions. + + Returns: + An array with the elements in reverse order along axis 1. + + See Also: + - :func:`jax.numpy.flip`: reverse the order along the given axis + - :func:`jax.numpy.flipud`: reverse the order along axis 0 + + Examples: + >>> x = jnp.array([[1, 2], + ... [3, 4]]) + >>> jnp.fliplr(x) + Array([[2, 1], + [4, 3]], dtype=int32) + """ util.check_arraylike("fliplr", m) return _flip(asarray(m), 1) -@util.implements(np.flipud, lax_description=_ARRAY_VIEW_DOC) def flipud(m: ArrayLike) -> Array: + """Reverse the order of elements of an array along axis 0. + + JAX implementation of :func:`numpy.flipud`. + + Args: + m: Array with at least one dimension. + + Returns: + An array with the elements in reverse order along axis 0. 
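Since ``fliplr`` and ``flipud`` are implemented as ``_flip`` along a fixed axis, a short sketch of the expected equivalence with ``flip`` (axis 1 and axis 0 respectively):

    import jax.numpy as jnp

    x = jnp.arange(6).reshape(2, 3)
    # fliplr/flipud are fixed-axis cases of flip.
    assert (jnp.fliplr(x) == jnp.flip(x, axis=1)).all()
    assert (jnp.flipud(x) == jnp.flip(x, axis=0)).all()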
+ + See Also: + - :func:`jax.numpy.flip`: reverse the order along the given axis + - :func:`jax.numpy.fliplr`: reverse the order along axis 1 + + Examples: + >>> x = jnp.array([[1, 2], + ... [3, 4]]) + >>> jnp.flipud(x) + Array([[3, 4], + [1, 2]], dtype=int32) + """ util.check_arraylike("flipud", m) return _flip(asarray(m), 0) -@util.implements(np.iscomplex) @jit def iscomplex(x: ArrayLike) -> Array: + """Return boolean array showing where the input is complex. + + JAX implementation of :func:`numpy.iscomplex`. + + Args: + x: Input array to check. + + Returns: + A new array containing boolean values indicating complex elements. + + See Also: + - :func:`jax.numpy.iscomplexobj` + - :func:`jax.numpy.isrealobj` + + Examples: + >>> jnp.iscomplex(jnp.array([True, 0, 1, 2j, 1+2j])) + Array([False, False, False, True, True], dtype=bool) + """ i = ufuncs.imag(x) return lax.ne(i, _lax_const(i, 0)) -@util.implements(np.isreal) @jit def isreal(x: ArrayLike) -> Array: + """Return boolean array showing where the input is real. + + JAX implementation of :func:`numpy.isreal`. + + Args: + x: input array to check. + + Returns: + A new array containing boolean values indicating real elements. + + See Also: + - :func:`jax.numpy.iscomplex` + - :func:`jax.numpy.isrealobj` + + Examples: + >>> jnp.isreal(jnp.array([False, 0j, 1, 2.1, 1+2j])) + Array([ True, True, True, True, False], dtype=bool) + """ i = ufuncs.imag(x) return lax.eq(i, _lax_const(i, 0)) -@util.implements(np.angle) + @partial(jit, static_argnames=['deg']) def angle(z: ArrayLike, deg: bool = False) -> Array: + """Return the angle of a complex valued number or array. + + JAX implementation of :func:`numpy.angle`. + + Args: + z: A complex number or an array of complex numbers. + deg: Boolean. If ``True``, returns the result in degrees else returns + in radians. Default is ``False``. + + Returns: + An array of counterclockwise angle of each element of ``z``, with the same + shape as ``z`` of dtype float. + + Examples: + + If ``z`` is a number + + >>> z1 = 2+3j + >>> jnp.angle(z1) + Array(0.98279375, dtype=float32, weak_type=True) + + If ``z`` is an array + + >>> z2 = jnp.array([[1+3j, 2-5j], + ... [4-3j, 3+2j]]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.angle(z2)) + [[ 1.25 -1.19] + [-0.64 0.59]] + + If ``deg=True``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.angle(z2, deg=True)) + [[ 71.57 -68.2 ] + [-36.87 33.69]] + """ re = ufuncs.real(z) im = ufuncs.imag(z) dtype = _dtype(re) @@ -753,7 +1202,7 @@ def gradient_along_axis(a, h, axis): if len(axis_tuple) == 0: return [] - if min([s for i, s in enumerate(a.shape) if i in axis_tuple]) < 2: + if min(s for i, s in enumerate(a.shape) if i in axis_tuple) < 2: raise ValueError("Shape of array too small to calculate " "a numerical gradient, " "at least 2 elements are required.") @@ -773,35 +1222,246 @@ def gradient_along_axis(a, h, axis): return a_grad[0] if len(axis_tuple) == 1 else a_grad -@util.implements(np.isrealobj) def isrealobj(x: Any) -> bool: + """Check if the input is not a complex number or an array containing complex elements. + + JAX implementation of :func:`numpy.isrealobj`. + + The function evaluates based on input type rather than value. + Inputs with zero imaginary parts are still considered complex. + + Args: + x: input object to check. + + Returns: + False if ``x`` is a complex number or an array containing at least one complex element, + True otherwise. 
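A small sketch of the type-based versus value-based distinction described here, using a complex dtype whose values all have zero imaginary parts:

    import jax.numpy as jnp

    z = jnp.array([1 + 0j, 2 + 0j])
    print(jnp.isrealobj(z))  # False: the dtype is complex, regardless of values
    print(jnp.isreal(z))     # elementwise True: every element has zero imaginary part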
+ + See Also: + - :func:`jax.numpy.iscomplexobj` + - :func:`jax.numpy.isreal` + + Examples: + >>> jnp.isrealobj(0) + True + >>> jnp.isrealobj(1.2) + True + >>> jnp.isrealobj(jnp.array([1, 2])) + True + >>> jnp.isrealobj(1+2j) + False + >>> jnp.isrealobj(jnp.array([0, 1+2j])) + False + """ return not iscomplexobj(x) -@util.implements(np.reshape, lax_description=_ARRAY_VIEW_DOC) -def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array: +def reshape( + a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *, + newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg()) -> Array: + """Return a reshaped copy of an array. + + JAX implementation of :func:`numpy.reshape`, implemented in terms of + :func:`jax.lax.reshape`. + + Args: + a: input array to reshape + shape: integer or sequence of integers giving the new shape, which must match the + size of the input array. If any single dimension is given size ``-1``, it will be + replaced with a value such that the output has the correct size. + order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major + (fortran-style, ``"F"``) or row-major (C-style, ``"C"``) order; default is ``"C"``. + JAX does not support ``order="A"``. + + Returns: + reshaped copy of input array with the specified shape. + + Notes: + Unlike :func:`numpy.reshape`, :func:`jax.numpy.reshape` will return a copy rather + than a view of the input array. However, under JIT, the compiler will optimize-away + such copies when possible, so this doesn't have performance impacts in practice. + + See Also: + - :meth:`jax.Array.reshape`: equivalent functionality via an array method. + - :func:`jax.numpy.ravel`: flatten an array into a 1D shape. + - :func:`jax.numpy.squeeze`: remove one or more length-1 axes from an array's shape. + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.reshape(x, 6) + Array([1, 2, 3, 4, 5, 6], dtype=int32) + >>> jnp.reshape(x, (3, 2)) + Array([[1, 2], + [3, 4], + [5, 6]], dtype=int32) + + You can use ``-1`` to automatically compute a shape that is consistent with + the input size: + + >>> jnp.reshape(x, -1) # -1 is inferred to be 6 + Array([1, 2, 3, 4, 5, 6], dtype=int32) + >>> jnp.reshape(x, (-1, 2)) # -1 is inferred to be 3 + Array([[1, 2], + [3, 4], + [5, 6]], dtype=int32) + + The default ordering of axes in the reshape is C-style row-major ordering. + To use Fortran-style column-major ordering, specify ``order='F'``: + + >>> jnp.reshape(x, 6, order='F') + Array([1, 4, 2, 5, 3, 6], dtype=int32) + >>> jnp.reshape(x, (3, 2), order='F') + Array([[1, 5], + [4, 3], + [2, 6]], dtype=int32) + + For convenience, this functionality is also available via the + :meth:`jax.Array.reshape` method: + + >>> x.reshape(3, 2) + Array([[1, 2], + [3, 4], + [5, 6]], dtype=int32) + """ __tracebackhide__ = True util.check_arraylike("reshape", a) + + # TODO(micky774): deprecated 2024-5-9, remove after deprecation expires. + if not isinstance(newshape, DeprecatedArg): + if shape is not None: + raise ValueError( + "jnp.reshape received both `shape` and `newshape` arguments. Note that " + "using `newshape` is deprecated, please only use `shape` instead." + ) + warnings.warn( + "The newshape argument of jax.numpy.reshape is deprecated and setting it " + "will soon raise an error. 
To avoid an error in the future, and to " + "suppress this warning, please use the shape argument instead.", + DeprecationWarning, stacklevel=2) + shape = newshape + del newshape + elif shape is None: + raise TypeError( + "jnp.shape requires passing a `shape` argument, but none was given." + ) try: # forward to method for ndarrays - return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr] + return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr] except AttributeError: pass - return asarray(a).reshape(newshape, order=order) + return asarray(a).reshape(shape, order=order) -@util.implements(np.ravel, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('order',), inline=True) def ravel(a: ArrayLike, order: str = "C") -> Array: + """Flatten array into a 1-dimensional shape. + + JAX implementation of :func:`numpy.ravel`, implemented in terms of + :func:`jax.lax.reshape`. + + ``ravel(arr, order=order)`` is equivalent to ``reshape(arr, -1, order=order)``. + + Args: + a: array to be flattened. + order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major + (fortran-style, ``"F"``) or row-major (C-style, ``"C"``) order; default is ``"C"``. + JAX does not support `order="A"` or `order="K"`. + + Returns: + flattened copy of input array. + + Notes: + Unlike :func:`numpy.ravel`, :func:`jax.numpy.ravel` will return a copy rather + than a view of the input array. However, under JIT, the compiler will optimize-away + such copies when possible, so this doesn't have performance impacts in practice. + + See Also: + - :meth:`jax.Array.ravel`: equivalent functionality via an array method. + - :func:`jax.numpy.reshape`: general array reshape. + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + + By default, ravel in C-style, row-major order + + >>> jnp.ravel(x) + Array([1, 2, 3, 4, 5, 6], dtype=int32) + + Optionally ravel in Fortran-style, column-major: + + >>> jnp.ravel(x, order='F') + Array([1, 4, 2, 5, 3, 6], dtype=int32) + + For convenience, the same functionality is available via the :meth:`jax.Array.ravel` + method: + + >>> x.ravel() + Array([1, 2, 3, 4, 5, 6], dtype=int32) + """ util.check_arraylike("ravel", a) if order == "K": raise NotImplementedError("Ravel not implemented for order='K'.") return reshape(a, (size(a),), order) -@util.implements(np.ravel_multi_index) def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], mode: str = 'raise', order: str = 'C') -> Array: + """Convert multi-dimensional indices into flat indices. + + JAX implementation of :func:`numpy.ravel_multi_index` + + Args: + multi_index: sequence of integer arrays containing indices in each dimension. + dims: sequence of integer sizes; must have ``len(dims) == len(multi_index)`` + mode: how to handle out-of bound indices. Options are + + - ``"raise"`` (default): raise a ValueError. This mode is incompatible + with :func:`~jax.jit` or other JAX transformations. + - ``"clip"``: clip out-of-bound indices to valid range. + - ``"wrap"``: wrap out-of-bound indices to valid range. + + order: ``"C"`` (default) or ``"F"``, specify whether to assume C-style + row-major order or Fortran-style column-major order. + + Returns: + array of flattened indices + + See also: + :func:`jax.numpy.unravel_index`: inverse of this function. + + Examples: + Define a 2-dimensional array and a sequence of indices of even values: + + >>> x = jnp.array([[2., 3., 4.], + ... 
[5., 6., 7.]]) + >>> indices = jnp.where(x % 2 == 0) + >>> indices + (Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32)) + >>> x[indices] + Array([2., 4., 6.], dtype=float32) + + Compute the flattened indices: + + >>> indices_flat = jnp.ravel_multi_index(indices, x.shape) + >>> indices_flat + Array([0, 2, 4], dtype=int32) + + These flattened indices can be used to extract the same values from the + flattened ``x`` array: + + >>> x_flat = x.ravel() + >>> x_flat + Array([2., 3., 4., 5., 6., 7.], dtype=float32) + >>> x_flat[indices_flat] + Array([2., 4., 6.], dtype=float32) + + The original indices can be recovered with :func:`~jax.numpy.unravel_index`: + + >>> jnp.unravel_index(indices_flat, x.shape) + (Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32)) + """ assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}" dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims) util.check_arraylike("ravel_multi_index", *multi_index) @@ -837,13 +1497,48 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], return result -_UNRAVEL_INDEX_DOC = """\ -Unlike numpy's implementation of unravel_index, negative indices are accepted -and out-of-bounds indices are clipped into the valid range. -""" - -@util.implements(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC) def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: + """Convert flat indices into multi-dimensional indices. + + JAX implementation of :func:`numpy.unravel_index`. The JAX version differs in + its treatment of out-of-bound indices: unlike NumPy, negative indices are + supported, and out-of-bound indices are clipped to the nearest valid value. + + Args: + indices: integer array of flat indices + shape: shape of multidimensional array to index into + + Returns: + Tuple of unraveled indices + + See also: + :func:`jax.numpy.ravel_multi_index`: Inverse of this function. + + Examples: + Start with a 1D array values and indices: + + >>> x = jnp.array([2., 3., 4., 5., 6., 7.]) + >>> indices = jnp.array([1, 3, 5]) + >>> print(x[indices]) + [3. 5. 7.] + + Now if ``x`` is reshaped, ``unravel_indices`` can be used to convert + the flat indices into a tuple of indices that access the same entries: + + >>> shape = (2, 3) + >>> x_2D = x.reshape(shape) + >>> indices_2D = jnp.unravel_index(indices, shape) + >>> indices_2D + (Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32)) + >>> print(x_2D[indices_2D]) + [3. 5. 7.] + + The inverse function, ``ravel_multi_index``, can be used to obtain the + original indices: + + >>> jnp.ravel_multi_index(indices_2D, shape) + Array([1, 3, 5], dtype=int32) + """ util.check_arraylike("unravel_index", indices) indices_arr = asarray(indices) # Note: we do not convert shape to an array, because it may be passed as a @@ -851,10 +1546,12 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: try: shape = list(shape) except TypeError: - shape = [shape] + # TODO: Consider warning here since shape is supposed to be a sequence, so + # this should not happen. 
+ shape = cast(list[Any], [shape]) if any(ndim(s) != 0 for s in shape): raise ValueError("unravel_index: shape should be a scalar or 1D sequence.") - out_indices = [0] * len(shape) + out_indices: list[ArrayLike] = [0] * len(shape) for i, s in reversed(list(enumerate(shape))): indices_arr, out_indices[i] = ufuncs.divmod(indices_arr, s) oob_pos = indices_arr > 0 @@ -882,8 +1579,63 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array: return reshape(arr, new_shape) -@util.implements(np.squeeze, lax_description=_ARRAY_VIEW_DOC) + def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: + """Remove one or more length-1 axes from array + + JAX implementation of :func:`numpy.sqeeze`, implemented via :func:`jax.lax.squeeze`. + + Args: + a: input array + axis: integer or sequence of integers specifying axes to remove. If any specified + axis does not have a length of 1, an error is raised. If not specified, squeeze + all length-1 axes in ``a``. + + Returns: + copy of ``a`` with length-1 axes removed. + + Notes: + Unlike :func:`numpy.squeeze`, :func:`jax.numpy.squeeze` will return a copy rather + than a view of the input array. However, under JIT, the compiler will optimize-away + such copies when possible, so this doesn't have performance impacts in practice. + + See Also: + - :func:`jax.numpy.expand_dims`: the inverse of ``squeeze``: add dimensions of length 1. + - :meth:`jax.Array.squeeze`: equivalent functionality via an array method. + - :func:`jax.lax.squeeze`: equivalent XLA API. + - :func:`jax.numpy.ravel`: flatten an array into a 1D shape. + - :func:`jax.numpy.reshape`: general array reshape. + + Examples: + >>> x = jnp.array([[[0]], [[1]], [[2]]]) + >>> x.shape + (3, 1, 1) + + Squeeze all length-1 dimensions: + + >>> jnp.squeeze(x) + Array([0, 1, 2], dtype=int32) + >>> _.shape + (3,) + + Equivalent while specifying the axes explicitly: + + >>> jnp.squeeze(x, axis=(1, 2)) + Array([0, 1, 2], dtype=int32) + + Attempting to squeeze a non-unit axis results in an error: + + >>> jnp.squeeze(x, axis=0) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,) + + For convenience, this functionality is also available via the + :meth:`jax.Array.squeeze` method: + + >>> x.squeeze() + Array([0, 1, 2], dtype=int32) + """ util.check_arraylike("squeeze", a) return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None) @@ -898,25 +1650,171 @@ def _squeeze(a: Array, axis: tuple[int, ...]) -> Array: return lax.squeeze(a, axis) -@util.implements(np.expand_dims) def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: + """Insert dimensions of length 1 into array + + JAX implementation of :func:`numpy.expand_dims`, implemented via + :func:`jax.lax.expand_dims`. + + Args: + a: input array + axis: integer or sequence of integers specifying positions of axes to add. + + Returns: + Copy of ``a`` with added dimensions. + + Notes: + Unlike :func:`numpy.expand_dims`, :func:`jax.numpy.expand_dims` will return a copy + rather than a view of the input array. However, under JIT, the compiler will optimize + away such copies when possible, so this doesn't have performance impacts in practice. + + See Also: + - :func:`jax.numpy.squeeze`: inverse of this operation, i.e. remove length-1 dimensions. + - :func:`jax.lax.expand_dims`: XLA version of this functionality. 
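A brief sketch of the inverse relationship with ``squeeze`` noted in the See Also entries above (shapes chosen purely for illustration):

    import jax.numpy as jnp

    x = jnp.ones((3,))
    y = jnp.expand_dims(x, (0, 2))  # shape (1, 3, 1)
    # Squeezing the axes that were just added recovers the original shape.
    assert jnp.squeeze(y, axis=(0, 2)).shape == x.shape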
+ + Examples: + >>> x = jnp.array([1, 2, 3]) + >>> x.shape + (3,) + + Expand the leading dimension: + + >>> jnp.expand_dims(x, 0) + Array([[1, 2, 3]], dtype=int32) + >>> _.shape + (1, 3) + + Expand the trailing dimension: + + >>> jnp.expand_dims(x, 1) + Array([[1], + [2], + [3]], dtype=int32) + >>> _.shape + (3, 1) + + Expand multiple dimensions: + + >>> jnp.expand_dims(x, (0, 1, 3)) + Array([[[[1], + [2], + [3]]]], dtype=int32) + >>> _.shape + (1, 1, 3, 1) + + Dimensions can also be expanded more succinctly by indexing with ``None``: + + >>> x[None] # equivalent to jnp.expand_dims(x, 0) + Array([[1, 2, 3]], dtype=int32) + >>> x[:, None] # equivalent to jnp.expand_dims(x, 1) + Array([[1], + [2], + [3]], dtype=int32) + >>> x[None, None, :, None] # equivalent to jnp.expand_dims(x, (0, 1, 3)) + Array([[[[1], + [2], + [3]]]], dtype=int32) + """ util.check_arraylike("expand_dims", a) axis = _ensure_index_tuple(axis) return lax.expand_dims(a, axis) -@util.implements(np.swapaxes, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('axis1', 'axis2'), inline=True) def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: + """Swap two axes of an array. + + JAX implementation of :func:`numpy.swapaxes`, implemented in terms of + :func:`jax.lax.transpose`. + + Args: + a: input array + axis1: index of first axis + axis2: index of second axis + + Returns: + Copy of ``a`` with specified axes swapped. + + Notes: + Unlike :func:`numpy.swapaxes`, :func:`jax.numpy.swapaxes` will return a copy rather + than a view of the input array. However, under JIT, the compiler will optimize away + such copies when possible, so this doesn't have performance impacts in practice. + + See Also: + - :func:`jax.numpy.moveaxis`: move a single axis of an array. + - :func:`jax.numpy.rollaxis`: older API for ``moveaxis``. + - :func:`jax.lax.transpose`: more general axes permutations. + - :meth:`jax.Array.swapaxes`: same functionality via an array method. + + Examples: + >>> a = jnp.ones((2, 3, 4, 5)) + >>> jnp.swapaxes(a, 1, 3).shape + (2, 5, 4, 3) + + Equivalent output via the ``swapaxes`` array method: + + >>> a.swapaxes(1, 3).shape + (2, 5, 4, 3) + + Equivalent output via :func:`~jax.numpy.transpose`: + + >>> a.transpose(0, 3, 2, 1).shape + (2, 5, 4, 3) + """ util.check_arraylike("swapaxes", a) perm = np.arange(ndim(a)) perm[axis1], perm[axis2] = perm[axis2], perm[axis1] return lax.transpose(a, list(perm)) -@util.implements(np.moveaxis, lax_description=_ARRAY_VIEW_DOC) def moveaxis(a: ArrayLike, source: int | Sequence[int], destination: int | Sequence[int]) -> Array: + """Move an array axis to a new position + + JAX implementation of :func:`numpy.moveaxis`, implemented in terms of + :func:`jax.lax.transpose`. + + Args: + a: input array + source: index or indices of the axes to move. + destination: index or indices of the axes destinations + + Returns: + Copy of ``a`` with axes moved from ``source`` to ``destination``. + + Notes: + Unlike :func:`numpy.moveaxis`, :func:`jax.numpy.moveaxis` will return a copy rather + than a view of the input array. However, under JIT, the compiler will optimize away + such copies when possible, so this doesn't have performance impacts in practice. + + See also: + - :func:`jax.numpy.swapaxes`: swap two axes. + - :func:`jax.numpy.rollaxis`: older API for moving an axis. + - :func:`jax.numpy.transpose`: general axes permutation. 
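A brief sketch of how ``moveaxis`` relates to an explicit ``transpose`` permutation, with an illustrative shape:

    import jax.numpy as jnp

    a = jnp.ones((2, 3, 4))
    # Moving axis 0 to the end matches the explicit permutation (1, 2, 0).
    assert jnp.moveaxis(a, 0, -1).shape == jnp.transpose(a, (1, 2, 0)).shape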
+ + Examples: + >>> a = jnp.ones((2, 3, 4, 5)) + + Move axis ``1`` to the end of the array: + + >>> jnp.moveaxis(a, 1, -1).shape + (2, 4, 5, 3) + + Move the last axis to position 1: + + >>> jnp.moveaxis(a, -1, 1).shape + (2, 5, 3, 4) + + Move multiple axes: + + >>> jnp.moveaxis(a, (0, 1), (-1, -2)).shape + (4, 5, 3, 2) + + This can also be accomplished via :func:`~jax.numpy.transpose`: + + >>> a.transpose(2, 3, 1, 0).shape + (4, 5, 3, 2) + """ util.check_arraylike("moveaxis", a) return _moveaxis(asarray(a), _ensure_index_tuple(source), _ensure_index_tuple(destination)) @@ -1058,7 +1956,7 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, return jitted_interp(x, xp, fp, left, right, period) -@overload # type: ignore[no-overload-impl] +@overload def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None @@ -1076,69 +1974,80 @@ def where(condition: ArrayLike, x: ArrayLike | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None ) -> Array | tuple[Array, ...]: ... -_DEPRECATED_WHERE_ARG = object() -@util.implements(np.where, # type: ignore[no-redef] - lax_description=_dedent(""" - At present, JAX does not support JIT-compilation of the single-argument form - of :py:func:`jax.numpy.where` because its output shape is data-dependent. The - three-argument form does not have a data-dependent shape and can be JIT-compiled - successfully. Alternatively, you can use the optional ``size`` keyword to - statically specify the expected size of the output.\n\n - - Special care is needed when the ``x`` or ``y`` input to - :py:func:`jax.numpy.where` could have a value of NaN. - Specifically, when a gradient is taken - with :py:func:`jax.grad` (reverse-mode differentiation), a NaN in either - ``x`` or ``y`` will propagate into the gradient, regardless of the value - of ``condition``. More information on this behavior and workarounds - is available in the JAX FAQ: - https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where"""), - extra_params=_dedent(""" - size : int, optional - Only referenced when ``x`` and ``y`` are ``None``. If specified, the indices of the first - ``size`` elements of the result will be returned. If there are fewer elements than ``size`` - indicates, the return value will be padded with ``fill_value``. - fill_value : array_like, optional - When ``size`` is specified and there are fewer than the indicated number of elements, the - remaining elements will be filled with ``fill_value``, which defaults to zero.""")) -def where( - acondition = None, if_true = None, if_false = None, /, *, - size=None, fill_value=None, - # Deprecated keyword-only names. - condition = _DEPRECATED_WHERE_ARG, x = _DEPRECATED_WHERE_ARG, - y = _DEPRECATED_WHERE_ARG -) -> Array | tuple[Array, ...]: - if (condition is not _DEPRECATED_WHERE_ARG or x is not _DEPRECATED_WHERE_ARG - or y is not _DEPRECATED_WHERE_ARG): - # TODO(phawkins): deprecated Nov 17 2023, remove after deprecation expires. 
- warnings.warn( - "Passing condition, x, or y to jax.numpy.where via keyword arguments " - "is deprecated.", - DeprecationWarning, - stacklevel=2, - ) - if condition is not _DEPRECATED_WHERE_ARG: - if acondition is not None: - raise ValueError("condition should be a positional-only argument") - acondition = condition - if x is not _DEPRECATED_WHERE_ARG: - if if_true is not None: - raise ValueError("x should be a positional-only argument") - if_true = x - if y is not _DEPRECATED_WHERE_ARG: - if if_false is not None: - raise ValueError("y should be a positional-only argument") - if_false = y - - if if_true is None and if_false is None: - util.check_arraylike("where", acondition) - return nonzero(acondition, size=size, fill_value=fill_value) +def where(condition, x=None, y=None, /, *, size=None, fill_value=None): + """Select elements from two arrays based on a condition. + + JAX implementation of :func:`numpy.where`. + + .. note:: + when only ``condition`` is provided, ``jnp.where(condition)`` is equivalent + to ``jnp.nonzero(condition)``. For that case, refer to the documentation of + :func:`jax.numpy.nonzero`. The docstring below focuses on the case where + ``x`` and ``y`` are specified. + + The three-term version of ``jnp.where`` lowers to :func:`jax.lax.select`. + + Args: + condition: boolean array. Must be broadcast-compatible with ``x`` and ``y`` when + they are specified. + x: arraylike. Should be broadcast-compatible with ``condition`` and ``y``, and + typecast-compatible with ``y``. + y: arraylike. Should be broadcast-compatible with ``condition`` and ``x``, and + typecast-compatible with ``x``. + size: integer, only referenced when ``x`` and ``y`` are ``None``. For details, + see :func:`jax.numpy.nonzero`. + fill_value: only referenced when ``x`` and ``y`` are ``None``. For details, + see :func:`jax.numpy.nonzero`. + + Returns: + An array of dtype ``jnp.result_type(x, y)`` with values drawn from ``x`` where ``condition`` + is True, and from ``y`` where condition is ``False. If ``x`` and ``y`` are ``None``, the + function behaves differently; see `:func:`jax.numpy.nonzero` for a description of the return + type. + + See Also: + - :func:`jax.numpy.nonzero` + - :func:`jax.numpy.argwhere` + - :func:`jax.lax.select` + + Notes: + Special care is needed when the ``x`` or ``y`` input to :func:`jax.numpy.where` could + have a value of NaN. Specifically, when a gradient is taken with :func:`jax.grad` + (reverse-mode differentiation), a NaN in either ``x`` or ``y`` will propagate into the + gradient, regardless of the value of ``condition``. More information on this behavior + and workarounds is available in the `JAX FAQ + `_. 
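A minimal sketch of the NaN-safe gradient pattern referenced in the Notes above (the "double where" workaround from the JAX FAQ); ``safe_log`` is an illustrative name, not an existing API:

    import jax
    import jax.numpy as jnp

    def safe_log(x):
      # Sanitize the argument first so the untaken branch never produces a NaN
      # that could leak into the gradient of the outer where.
      safe_x = jnp.where(x > 0, x, 1.0)
      return jnp.where(x > 0, jnp.log(safe_x), 0.0)

    print(jax.grad(safe_log)(0.0))  # 0.0 rather than nan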
+ + Examples: + When ``x`` and ``y`` are not provided, ``where`` behaves equivalently to + :func:`jax.numpy.nonzero`: + + >>> x = jnp.arange(10) + >>> jnp.where(x > 4) + (Array([5, 6, 7, 8, 9], dtype=int32),) + >>> jnp.nonzero(x > 4) + (Array([5, 6, 7, 8, 9], dtype=int32),) + + When ``x`` and ``y`` are provided, ``where`` selects between them based on + the specified condition: + + >>> jnp.where(x > 4, x, 0) + Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32) + """ + if x is None and y is None: + util.check_arraylike("where", condition) + return nonzero(condition, size=size, fill_value=fill_value) else: - util.check_arraylike("where", acondition, if_true, if_false) + util.check_arraylike("where", condition, x, y) if size is not None or fill_value is not None: - raise ValueError("size and fill_value arguments cannot be used in three-term where function.") - return util._where(acondition, if_true, if_false) + raise ValueError("size and fill_value arguments cannot be used in " + "three-term where function.") + if x is None or y is None: + raise ValueError("Either both or neither of the x and y arguments " + "should be provided to jax.numpy.where, got " + f"{x} and {y}.") + return util._where(condition, x, y) @util.implements(np.select) @@ -1152,26 +2061,76 @@ def select( raise ValueError(msg.format(len(condlist), len(choicelist))) if len(condlist) == 0: raise ValueError("condlist must be non-empty") - choices = util.promote_dtypes(default, *choicelist) - choicelist = choices[1:] - output = choices[0] - for cond, choice in zip(condlist[::-1], choicelist[::-1]): - output = where(cond, choice, output) - return output - + # Put the default at front with condition False because + # argmax returns zero for an array of False values. + choicelist = util.promote_dtypes(default, *choicelist) + conditions = stack(broadcast_arrays(False, *condlist)) + idx = argmax(conditions.astype(bool), axis=0) + return lax.select_n(*broadcast_arrays(idx, *choicelist)) -@util.implements(np.bincount, lax_description="""\ -Jax adds the optional `length` parameter which specifies the output length, and -defaults to ``x.max() + 1``. It must be specified for bincount to be compiled -with non-static operands. Values larger than the specified length will be discarded. -If `length` is specified, `minlength` will be ignored. -Additionally, while ``np.bincount`` raises an error if the input array contains -negative values, ``jax.numpy.bincount`` clips negative values to zero. -""") def bincount(x: ArrayLike, weights: ArrayLike | None = None, minlength: int = 0, *, length: int | None = None ) -> Array: + """Count the number of occurrences of each value in an integer array. + + JAX implementation of :func:`numpy.bincount`. + + For an array of positive integers ``x``, this function returns an array ``counts`` + of size ``x.max() + 1``, such that ``counts[i]`` contains the number of occurrences + of the value ``i`` in ``x``. + + The JAX version has a few differences from the NumPy version: + + - In NumPy, passing an array ``x`` with negative entries will result in an error. + In JAX, negative values are clipped to zero. + - JAX adds an optional ``length`` parameter which can be used to statically specify + the length of the output array so that this function can be used with transformations + like :func:`jax.jit`. In this case, items larger than `length + 1` will be dropped. + + Args: + x : N-dimensional array of positive integers + weights: optional array of weights associated with ``x``. 
If not specified, the + weight for each entry will be ``1``. + minlength: the minimum length of the output counts array. + length: the length of the output counts array. Must be specified statically for + ``bincount`` to be used with :func:`jax.jit` and other JAX transformations. + + Returns: + An array of counts or summed weights reflecting the number of occurrences of values + in ``x``. + + See Also: + - :func:`jax.numpy.histogram` + - :func:`jax.numpy.digitize` + - :func:`jax.numpy.unique_counts` + + Examples: + Basic bincount: + + >>> x = jnp.array([1, 1, 2, 3, 3, 3]) + >>> jnp.bincount(x) + Array([0, 2, 1, 3], dtype=int32) + + Weighted bincount: + + >>> weights = jnp.array([1, 2, 3, 4, 5, 6]) + >>> jnp.bincount(x, weights) + Array([ 0, 3, 3, 15], dtype=int32) + + Specifying a static ``length`` makes this jit-compatible: + + >>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length']) + >>> jit_bincount(x, length=5) + Array([0, 2, 1, 3, 0], dtype=int32) + + Any negative numbers are clipped to the first bin, and numbers beyond the + specified ``length`` are dropped: + + >>> x = jnp.array([-1, -1, 1, 3, 10]) + >>> jnp.bincount(x, length=5) + Array([2, 1, 0, 1, 0], dtype=int32) + """ util.check_arraylike("bincount", x) if not issubdtype(_dtype(x), integer): raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}") @@ -1286,20 +2245,63 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array axis: int = 0) -> list[Array]: return _split("array_split", ary, indices_or_sections, axis=axis) -@util.implements(np.clip, skip_params=['out']) + +_DEPRECATED_CLIP_ARG = DeprecatedArg() +@util.implements( + np.clip, + skip_params=['a', 'a_min'], + extra_params=_dedent(""" + x : array_like + Array containing elements to clip. + min : array_like, optional + Minimum value. If ``None``, clipping is not performed on the + corresponding edge. The value of ``min`` is broadcast against x. + max : array_like, optional + Maximum value. If ``None``, clipping is not performed on the + corresponding edge. The value of ``max`` is broadcast against x. +""") +) @jit -def clip(a: ArrayLike, a_min: ArrayLike | None = None, - a_max: ArrayLike | None = None, out: None = None) -> Array: - util.check_arraylike("clip", a) - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.clip is not supported.") - if a_min is None and a_max is None: - raise ValueError("At most one of a_min and a_max may be None") - if a_min is not None: - a = ufuncs.maximum(a_min, a) - if a_max is not None: - a = ufuncs.minimum(a_max, a) - return asarray(a) +def clip( + x: ArrayLike | None = None, # Default to preserve backwards compatability + /, + min: ArrayLike | None = None, + max: ArrayLike | None = None, + *, + a: ArrayLike | DeprecatedArg = _DEPRECATED_CLIP_ARG, + a_min: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG, + a_max: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG +) -> Array: + # TODO(micky774): deprecated 2024-4-2, remove after deprecation expires. + x = a if not isinstance(a, DeprecatedArg) else x + if x is None: + raise ValueError("No input was provided to the clip function.") + min = a_min if not isinstance(a_min, DeprecatedArg) else min + max = a_max if not isinstance(a_max, DeprecatedArg) else max + if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)): + warnings.warn( + "Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is " + "deprecated. 
Please use 'x', 'min', and 'max' respectively instead.", + DeprecationWarning, + stacklevel=2, + ) + + util.check_arraylike("clip", x) + if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)): + # TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires. + warnings.warn( + "Clip received a complex value either through the input or the min/max " + "keywords. Complex values have no ordering and cannot be clipped. " + "Attempting to clip using complex numbers is deprecated and will soon " + "raise a ValueError. Please convert to a real value or array by taking " + "the real or imaginary components via jax.numpy.real/imag respectively.", + DeprecationWarning, stacklevel=2, + ) + if min is not None: + x = ufuncs.maximum(min, x) + if max is not None: + x = ufuncs.minimum(max, x) + return asarray(x) @util.implements(np.around, skip_params=['out']) @partial(jit, static_argnames=('decimals',)) @@ -1378,33 +2380,93 @@ def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, return reductions.all(isclose(a, b, rtol, atol, equal_nan)) -_NONZERO_DOC = """\ -Because the size of the output of ``nonzero`` is data-dependent, the function is not -typically compatible with JIT. The JAX version adds the optional ``size`` argument which -must be specified statically for ``jnp.nonzero`` to be used within some of JAX's -transformations. -""" -_NONZERO_EXTRA_PARAMS = """ -size : int, optional - If specified, the indices of the first ``size`` True elements will be returned. If there are - fewer unique elements than ``size`` indicates, the return value will be padded with ``fill_value``. -fill_value : array_like, optional - When ``size`` is specified and there are fewer than the indicated number of elements, the - remaining elements will be filled with ``fill_value``, which defaults to zero. -""" - -@util.implements(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS) def nonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None ) -> tuple[Array, ...]: + """Return indices of nonzero elements of an array. + + JAX implementation of :func:`numpy.nonzero`. + + Because the size of the output of ``nonzero`` is data-dependent, the function + is not compatible with JIT and other transformations. The JAX version adds + the optional ``size`` argument which must be specified statically for + ``jnp.nonzero`` to be used within JAX's transformations. + + Args: + a: N-dimensional array. + size: optional static integer specifying the number of nonzero entries to + return. If there are more nonzero elements than the specified ``size``, + then indices will be truncated at the end. If there are fewer nonzero + elements than the specified size, then indices will be padded with + ``fill_value``, which defaults to zero. + fill_value: optional padding value when ``size`` is specified. Defaults to 0. + + Returns: + Tuple of JAX Arrays of length ``a.ndim``, containing the indices of each + nonzero value. + + See also: + - :func:`jax.numpy flatnonzero` + - :func:`jax.numpy.where` + + Examples: + + One-dimensional array returns a length-1 tuple of indices: + + >>> x = jnp.array([0, 5, 0, 6, 0, 7]) + >>> jnp.nonzero(x) + (Array([1, 3, 5], dtype=int32),) + + Two-dimensional array returns a length-2 tuple of indices: + + >>> x = jnp.array([[0, 5, 0], + ... 
[6, 0, 7]]) + >>> jnp.nonzero(x) + (Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32)) + + In either case, the resulting tuple of indices can be used directly to extract + the nonzero values: + + >>> indices = jnp.nonzero(x) + >>> x[indices] + Array([5, 6, 7], dtype=int32) + + The output of ``nonzero`` has a dynamic shape, because the number of returned + indices depends on the contents of the input array. As such, it is incompatible + with JIT and other JAX transformations: + + >>> x = jnp.array([0, 5, 0, 6, 0, 7]) + >>> jax.jit(jnp.nonzero)(x) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]. + The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations. + + This can be addressed by passing a static ``size`` parameter to specify the + desired output shape: + + >>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size') + >>> nonzero_jit(x, size=3) + (Array([1, 3, 5], dtype=int32),) + + If ``size`` does not match the true size, the result will be either truncated or padded: + + >>> nonzero_jit(x, size=2) # size < 3: indices are truncated + (Array([1, 3], dtype=int32),) + >>> nonzero_jit(x, size=5) # size > 3: indices are padded with zeros. + (Array([1, 3, 5, 0, 0], dtype=int32),) + + You can specify a custom fill value for the padding using the ``fill_value`` argument: + + >>> nonzero_jit(x, size=5, fill_value=len(x)) + (Array([1, 3, 5, 6, 6], dtype=int32),) + """ util.check_arraylike("nonzero", a) arr = asarray(a) del a if ndim(arr) == 0: - # Added 2023 Dec 6 - warnings.warn("Calling nonzero on 0d arrays is deprecated. Use `atleast_1d(arr).nonzero()", - DeprecationWarning, stacklevel=2) - arr = atleast_1d(arr) + raise ValueError("Calling nonzero on 0d arrays is not allowed. " + "Use jnp.atleast_1d(scalar).nonzero() instead.") mask = arr if arr.dtype == bool else (arr != 0) calculated_size = mask.sum() if size is None else size calculated_size = core.concrete_dim_or_error(calculated_size, @@ -1425,9 +2487,49 @@ def nonzero(a: ArrayLike, *, size: int | None = None, out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value_tup, out)) return out -@util.implements(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS) + def flatnonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) -> Array: + """Return indices of nonzero elements in a flattened array + + JAX implementation of :func:`numpy.flatnonzero`. + + ``jnp.flatnonzero(x)`` is equivalent to ``nonzero(ravel(a))[0]``. For a full + discussion of the parameters to this function, refer to :func:`jax.numpy.nonzero`. + + Args: + a: N-dimensional array. + size: optional static integer specifying the number of nonzero entries to + return. See :func:`jax.numpy.nonzero` for more discussion of this parameter. + fill_value: optional padding value when ``size`` is specified. Defaults to 0. + See :func:`jax.numpy.nonzero` for more discussion of this parameter. + + Returns: + Array containing the indices of each nonzero value in the flattened array. + + See Also: + - :func:`jax.numpy.nonzero` + - :func:`jax.numpy.where` + + Examples: + >>> x = jnp.array([[0, 5, 0], + ... 
[6, 0, 8]]) + >>> jnp.flatnonzero(x) + Array([1, 3, 5], dtype=int32) + + This is equivalent to calling :func:`~jax.numpy.nonzero` on the flattened + array, and extracting the first entry in the resulting tuple: + + >>> jnp.nonzero(x.ravel())[0] + Array([1, 3, 5], dtype=int32) + + The returned indices can be used to extract nonzero entries from the + flattened array: + + >>> indices = jnp.flatnonzero(x) + >>> x.ravel()[indices] + Array([5, 6, 8], dtype=int32) + """ return nonzero(ravel(a), size=size, fill_value=fill_value)[0] @@ -1795,7 +2897,7 @@ def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray], 'symmetric': ['reflect_type'], } try: - unsupported_kwargs = set(kwargs) - set(allowed_kwargs[mode]) # type: ignore[call-overload] + unsupported_kwargs = set(kwargs) - set(allowed_kwargs[mode]) except KeyError: msg = "Unimplemented padding mode '{}' for np.pad." raise NotImplementedError(msg.format(mode)) @@ -1834,6 +2936,18 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], new_arrays.append(expand_dims(a, axis)) return concatenate(new_arrays, axis=axis, dtype=dtype) +@util.implements(getattr(np, 'unstack', None)) +@partial(jit, static_argnames="axis") +def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: + util.check_arraylike("unstack", x) + x = asarray(x) + if x.ndim == 0: + raise ValueError( + "Unstack requires arrays with rank > 0, however a scalar array was " + "passed." + ) + return tuple(moveaxis(x, axis, 0)) + @util.implements(np.tile) def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: util.check_arraylike("tile", A) @@ -1842,7 +2956,7 @@ def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: except TypeError: reps_tup: tuple[DimSize, ...] = (reps,) else: - reps_tup = tuple(reps) # type: ignore[assignment,arg-type] + reps_tup = tuple(reps) # type: ignore[arg-type] reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep for rep in reps_tup) A_shape = (1,) * (len(reps_tup) - ndim(A)) + shape(A) @@ -2089,7 +3203,7 @@ def _supports_buffer_protocol(obj): https://jax.readthedocs.io/en/latest/faq.html). 
""" -deprecations.register(__name__, "array-none") +deprecations.register("jax-numpy-array-none") @util.implements(np.array, lax_description=_ARRAY_DOC) def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, @@ -2133,17 +3247,21 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if hasattr(object, '__jax_array__'): object = object.__jax_array__() elif hasattr(object, '__cuda_array_interface__'): - if xla_extension_version >= 237: - cai = object.__cuda_array_interface__ - backend = xla_bridge.get_backend("cuda") - object = xc._xla.cuda_array_interface_to_buffer(cai, backend) + cai = object.__cuda_array_interface__ + backend = xla_bridge.get_backend("cuda") + if cuda_plugin_extension is None: + device_id = None + else: + device_id = cuda_plugin_extension.get_device_ordinal(cai["data"][0]) + object = xc._xla.cuda_array_interface_to_buffer( + cai=cai, gpu_backend=backend, device_id=device_id) object = tree_map(lambda leaf: leaf.__jax_array__() if hasattr(leaf, "__jax_array__") else leaf, object) leaves = tree_leaves(object, is_leaf=lambda x: x is None) if any(leaf is None for leaf in leaves): # Added Nov 16 2023 - if deprecations.is_accelerated(__name__, "array-none"): + if deprecations.is_accelerated("jax-numpy-array-none"): raise TypeError("None is not a valid value for jnp.array") warnings.warn( "None encountered in jnp.array(); this is currently treated as NaN. " @@ -2179,7 +3297,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if object: out = stack([asarray(elt, dtype=dtype) for elt in object]) else: - out = np.array([], dtype=dtype) # type: ignore[arg-type] + out = np.array([], dtype=dtype) elif _supports_buffer_protocol(object): object = memoryview(object) # TODO(jakevdp): update this once we support NumPy 2.0 semantics for the copy arg. @@ -2209,12 +3327,42 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: In particular, the details of float-to-int and int-to-float casts are implementation dependent. """) -def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array: - del copy # unused in JAX +def astype(x: ArrayLike, dtype: DTypeLike | None, + /, *, copy: bool = False, + device: xc.Device | Sharding | None = None) -> Array: + util.check_arraylike("astype", x) + x_arr = asarray(x) + if dtype is None: dtype = dtypes.canonicalize_dtype(float_) dtypes.check_user_dtype_supported(dtype, "astype") - return lax.convert_element_type(x, dtype) + if issubdtype(x_arr.dtype, complexfloating): + if dtypes.isdtype(dtype, ("integral", "real floating")): + warnings.warn( + "Casting from complex to real dtypes will soon raise a ValueError. " + "Please first use jnp.real or jnp.imag to take the real/imaginary " + "component of your input.", + DeprecationWarning, stacklevel=2 + ) + elif np.dtype(dtype) == bool: + # convert_element_type(complex, bool) has the wrong semantics. + x_arr = (x_arr != _lax_const(x_arr, 0)) + + # We offer a more specific warning than the usual ComplexWarning so we prefer + # to issue our warning. 
+ with warnings.catch_warnings(): + warnings.simplefilter("ignore", ComplexWarning) + return _place_array( + lax.convert_element_type(x_arr, dtype), + device=device, copy=copy, + ) + +def _place_array(x, device=None, copy=None): + # TODO(micky774): Implement in future PRs as we formalize device placement + # semantics + if copy: + return _array_copy(x) + return x @util.implements(np.asarray, lax_description=_ARRAY_DOC) @@ -2232,7 +3380,7 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, dtypes.check_user_dtype_supported(dtype, "asarray") if dtype is not None: dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] - return array(a, dtype=dtype, copy=bool(copy), order=order) # type: ignore + return array(a, dtype=dtype, copy=bool(copy), order=order) @util.implements(np.copy, lax_description=_ARRAY_DOC) @@ -2241,11 +3389,40 @@ def copy(a: ArrayLike, order: str | None = None) -> Array: return array(a, copy=True, order=order) -@util.implements(np.zeros_like) def zeros_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of zeros with the same shape and dtype as an array. + + JAX implementation of :func:`numpy.zeros_like`. + + Args: + a: Array-like object with ``shape`` and ``dtype`` attributes. + shape: optionally override the shape of the created array. + dtype: optionally override the dtype of the created array. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.zeros` + - :func:`jax.numpy.empty_like` + - :func:`jax.numpy.ones_like` + - :func:`jax.numpy.full_like` + + Examples: + >>> x = jnp.arange(4) + >>> jnp.zeros_like(x) + Array([0, 0, 0, 0], dtype=int32) + >>> jnp.zeros_like(x, dtype=bool) + Array([False, False, False, False], dtype=bool) + >>> jnp.zeros_like(x, shape=(2, 3)) + Array([[0, 0, 0], + [0, 0, 0]], dtype=int32) + """ if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing util.check_arraylike("zeros_like", a) dtypes.check_user_dtype_supported(dtype, "zeros_like") @@ -2254,11 +3431,40 @@ def zeros_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 0, dtype, shape, sharding=_normalize_to_sharding(device)) -@util.implements(np.ones_like) def ones_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an array of ones with the same shape and dtype as an array. + + JAX implementation of :func:`numpy.ones_like`. + + Args: + a: Array-like object with ``shape`` and ``dtype`` attributes. + shape: optionally override the shape of the created array. + dtype: optionally override the dtype of the created array. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. 
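A short sketch of the ``device`` argument described above, assuming at least one local device is available:

    import jax
    import jax.numpy as jnp

    x = jnp.arange(4)
    y = jnp.ones_like(x, device=jax.devices()[0])
    print(y.sharding)  # expected to be a single-device sharding on that device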
+ + See also: + - :func:`jax.numpy.empty` + - :func:`jax.numpy.zeros_like` + - :func:`jax.numpy.ones_like` + - :func:`jax.numpy.full_like` + + Examples: + >>> x = jnp.arange(4) + >>> jnp.ones_like(x) + Array([1, 1, 1, 1], dtype=int32) + >>> jnp.ones_like(x, dtype=bool) + Array([ True, True, True, True], dtype=bool) + >>> jnp.ones_like(x, shape=(2, 3)) + Array([[1, 1, 1], + [1, 1, 1]], dtype=int32) + """ if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing util.check_arraylike("ones_like", a) dtypes.check_user_dtype_supported(dtype, "ones_like") @@ -2267,22 +3473,48 @@ def ones_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 1, dtype, shape, sharding=_normalize_to_sharding(device)) -@util.implements(np.empty_like, lax_description="""\ -Because XLA cannot create uninitialized arrays, the JAX version will -return an array initialized with zeros.""") def empty_like(prototype: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an empty array with the same shape and dtype as an array. + + JAX implementation of :func:`numpy.empty_like`. Because XLA cannot create + an un-initialized array, :func:`jax.numpy.empty` will always return an + array full of zeros. + + Args: + a: Array-like object with ``shape`` and ``dtype`` attributes. + shape: optionally override the shape of the created array. + dtype: optionally override the dtype of the created array. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.empty` + - :func:`jax.numpy.zeros_like` + - :func:`jax.numpy.ones_like` + - :func:`jax.numpy.full_like` + + Examples: + >>> x = jnp.arange(4) + >>> jnp.empty_like(x) + Array([0, 0, 0, 0], dtype=int32) + >>> jnp.empty_like(x, dtype=bool) + Array([False, False, False, False], dtype=bool) + >>> jnp.empty_like(x, shape=(2, 3)) + Array([[0, 0, 0], + [0, 0, 0]], dtype=int32) + """ if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing util.check_arraylike("empty_like", prototype) dtypes.check_user_dtype_supported(dtype, "empty_like") return zeros_like(prototype, dtype=dtype, shape=shape, device=device) -def _maybe_device_put(arr: Array, device: xc.Device | Sharding | None) -> Array: - return arr if device is None else jax.device_put(arr, device) - def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | None: if isinstance(device, xc.Device): return SingleDeviceSharding(device) @@ -2290,10 +3522,43 @@ def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | No return device -@util.implements(np.full) def full(shape: Any, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of a specified value. + + JAX implementation of :func:`numpy.full`. + + Args: + shape: int or sequence of ints specifying the shape of the created array. + fill_value: scalar or array with which to fill the created array. + dtype: optional dtype for the created array; defaults to the dtype of the + fill value. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. 
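A small sketch of the dtype defaulting described above, where the result dtype follows the fill value (standard 32-bit defaults assumed, i.e. ``jax_enable_x64`` off):

    import jax.numpy as jnp

    print(jnp.full(3, 2).dtype)     # int32
    print(jnp.full(3, 2.0).dtype)   # float32
    print(jnp.full(3, True).dtype)  # bool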
+ + See also: + - :func:`jax.numpy.full_like` + - :func:`jax.numpy.empty` + - :func:`jax.numpy.zeros` + - :func:`jax.numpy.ones` + + Examples: + >>> jnp.full(4, 2, dtype=float) + Array([2., 2., 2., 2.], dtype=float32) + >>> jnp.full((2, 3), 0, dtype=bool) + Array([[False, False, False], + [False, False, False]], dtype=bool) + + `fill_value` may also be an array that is broadcast to the specified shape: + + >>> jnp.full((2, 3), fill_value=jnp.arange(3)) + Array([[0, 1, 2], + [0, 1, 2]], dtype=int32) + """ dtypes.check_user_dtype_supported(dtype, "full") util.check_arraylike("full", fill_value) @@ -2301,14 +3566,50 @@ def full(shape: Any, fill_value: ArrayLike, shape = canonicalize_shape(shape) return lax.full(shape, fill_value, dtype, sharding=_normalize_to_sharding(device)) else: - return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device) + return jax.device_put( + broadcast_to(asarray(fill_value, dtype=dtype), shape), device) -@util.implements(np.full_like) def full_like(a: ArrayLike | DuckTypedArray, fill_value: ArrayLike, dtype: DTypeLike | None = None, shape: Any = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of a specified value with the same shape and dtype as an array. + + JAX implementation of :func:`numpy.full_like`. + + Args: + a: Array-like object with ``shape`` and ``dtype`` attributes. + fill_value: scalar or array with which to fill the created array. + shape: optionally override the shape of the created array. + dtype: optionally override the dtype of the created array. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.full` + - :func:`jax.numpy.empty_like` + - :func:`jax.numpy.zeros_like` + - :func:`jax.numpy.ones_like` + + Examples: + >>> x = jnp.arange(4.0) + >>> jnp.full_like(x, 2) + Array([2., 2., 2., 2.], dtype=float32) + >>> jnp.full_like(x, 0, shape=(2, 3)) + Array([[0., 0., 0.], + [0., 0., 0.]], dtype=float32) + + `fill_value` may also be an array that is broadcast to the specified shape: + + >>> x = jnp.arange(6).reshape(2, 3) + >>> jnp.full_like(x, fill_value=jnp.array([[1], [2]])) + Array([[1, 1, 1], + [2, 2, 2]], dtype=int32) + """ if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing util.check_arraylike("full_like", 0, fill_value) else: @@ -2320,13 +3621,39 @@ def full_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, fill_value, dtype, shape, sharding=_normalize_to_sharding(device)) else: shape = np.shape(a) if shape is None else shape # type: ignore[arg-type] - dtype = result_type(a) if dtype is None else dtype # type: ignore[arg-type] - return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device) + dtype = result_type(a) if dtype is None else dtype + return jax.device_put( + broadcast_to(asarray(fill_value, dtype=dtype), shape), device) -@util.implements(np.zeros) def zeros(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of zeros. + + JAX implementation of :func:`numpy.zeros`. + + Args: + shape: int or sequence of ints specifying the shape of the created array. + dtype: optional dtype for the created array; defaults to floating point. 
+ device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.zeros_like` + - :func:`jax.numpy.empty` + - :func:`jax.numpy.ones` + - :func:`jax.numpy.full` + + Examples: + >>> jnp.zeros(4) + Array([0., 0., 0., 0.], dtype=float32) + >>> jnp.zeros((2, 3), dtype=bool) + Array([[False, False, False], + [False, False, False]], dtype=bool) + """ if isinstance(shape, types.GeneratorType): raise TypeError("expected sequence object with len >= 0 or a single integer") if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m) @@ -2334,9 +3661,35 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *, shape = canonicalize_shape(shape) return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) -@util.implements(np.ones) + def ones(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of ones. + + JAX implementation of :func:`numpy.ones`. + + Args: + shape: int or sequence of ints specifying the shape of the created array. + dtype: optional dtype for the created array; defaults to floating point. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.ones_like` + - :func:`jax.numpy.empty` + - :func:`jax.numpy.zeros` + - :func:`jax.numpy.full` + + Examples: + >>> jnp.ones(4) + Array([1., 1., 1., 1.], dtype=float32) + >>> jnp.ones((2, 3), dtype=bool) + Array([[ True, True, True], + [ True, True, True]], dtype=bool) + """ if isinstance(shape, types.GeneratorType): raise TypeError("expected sequence object with len >= 0 or a single integer") if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m) @@ -2344,11 +3697,37 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *, dtypes.check_user_dtype_supported(dtype, "ones") return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) -@util.implements(np.empty, lax_description="""\ -Because XLA cannot create uninitialized arrays, the JAX version will -return an array initialized with zeros.""") + def empty(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an empty array. + + JAX implementation of :func:`numpy.empty`. Because XLA cannot create an + un-initialized array, :func:`jax.numpy.empty` will always return an array + full of zeros. + + Args: + shape: int or sequence of ints specifying the shape of the created array. + dtype: optional dtype for the created array; defaults to floating point. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. 
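To make the ``device`` argument shared by these constructors concrete, a small sketch (assuming at least one device is available; the ``device`` keyword is the one added in this patch):

import jax
import jax.numpy as jnp

dev = jax.devices()[0]
x = jnp.zeros((2, 3), dtype=jnp.int32, device=dev)   # committed to `dev`
y = jnp.full((2, 3), 7.0, device=dev)
print(x.dtype, y.sharding)                           # int32, a single-device sharding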
+ + See also: + - :func:`jax.numpy.empty_like` + - :func:`jax.numpy.zeros` + - :func:`jax.numpy.ones` + - :func:`jax.numpy.full` + + Examples: + >>> jnp.empty(4) + Array([0., 0., 0., 0.], dtype=float32) + >>> jnp.empty((2, 3), dtype=bool) + Array([[False, False, False], + [False, False, False]], dtype=bool) + """ if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m) dtypes.check_user_dtype_supported(dtype, "empty") return zeros(shape, dtype, device=device) @@ -2362,8 +3741,38 @@ def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore "with a single tuple argument for the shape?") -@util.implements(np.array_equal) def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: + """Check if two arrays are element-wise equal. + + JAX implementation of :func:`numpy.array_equal`. + + Args: + a1: first input array to compare. + a2: second input array to compare. + equal_nan: Boolean. If ``True``, NaNs in ``a1`` will be considered + equal to NaNs in ``a2``. Default is ``False``. + + Returns: + Boolean scalar array indicating whether the input arrays are element-wise equal. + + See Also: + - :func:`jax.numpy.allclose` + - :func:`jax.numpy.array_equiv` + + Examples: + >>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3])) + Array(True, dtype=bool) + >>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2])) + Array(False, dtype=bool) + >>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 4])) + Array(False, dtype=bool) + >>> jnp.array_equal(jnp.array([1, 2, float('nan')]), + ... jnp.array([1, 2, float('nan')])) + Array(False, dtype=bool) + >>> jnp.array_equal(jnp.array([1, 2, float('nan')]), + ... jnp.array([1, 2, float('nan')]), equal_nan=True) + Array(True, dtype=bool) + """ a1, a2 = asarray(a1), asarray(a2) if shape(a1) != shape(a2): return bool_(False) @@ -2436,9 +3845,10 @@ def fromiter(*args, **kwargs): is later modified in-place, it may lead to undefined behavior when using the associated JAX array. """) -def from_dlpack(x: Any) -> Array: +def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, + copy: bool | None = None) -> Array: from jax.dlpack import from_dlpack # pylint: disable=g-import-not-at-top - return from_dlpack(x.__dlpack__()) + return from_dlpack(x, device=device, copy=copy) @util.implements(np.fromfunction) def fromfunction(function: Callable[..., Array], shape: Any, @@ -2455,35 +3865,197 @@ def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: s return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep)) -@util.implements(np.eye) -def eye(N: DimSize, M: DimSize | None = None, k: int = 0, +def eye(N: DimSize, M: DimSize | None = None, + k: int | ArrayLike = 0, + dtype: DTypeLike | None = None, + *, device: xc.Device | Sharding | None = None) -> Array: + """Create a square or rectangular identity matrix + + JAX implementation of :func:`numpy.eye`. + + Args: + N: integer specifying the first dimension of the array. + M: optional integer specifying the second dimension of the array; + defaults to the same value as ``N``. + k: optional integer specifying the offset of the diagonal. Use positive + values for upper diagonals, and negative values for lower diagonals. + Default is zero. + dtype: optional dtype; defaults to floating point. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. 
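Because the rewritten helper below builds the matrix from two ``broadcasted_iota`` calls rather than requiring ``operator.index(k)``, the offset ``k`` can also be a traced scalar; a small sketch under that assumption:

import jax
import jax.numpy as jnp

@jax.jit
def banded(k):
  # `k` is traced here; a static Python int works just as well.
  return jnp.eye(4, k=k, dtype=jnp.int32)

print(banded(1))    # ones on the first superdiagonal
print(banded(-1))   # ones on the first subdiagonal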
+ + Returns: + Identity array of shape ``(N, M)``, or ``(N, N)`` if ``M`` is not specified. + + See also: + :func:`jax.numpy.identity`: Simpler API for generating square identity matrices. + + Examples: + A simple 3x3 identity matrix: + + >>> jnp.eye(3) + Array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]], dtype=float32) + + Integer identity matrices with offset diagonals: + + >>> jnp.eye(3, k=1, dtype=int) + Array([[0, 1, 0], + [0, 0, 1], + [0, 0, 0]], dtype=int32) + >>> jnp.eye(3, k=-1, dtype=int) + Array([[0, 0, 0], + [1, 0, 0], + [0, 1, 0]], dtype=int32) + + Non-square identity matrix: + + >>> jnp.eye(3, 5, k=1) + Array([[0., 1., 0., 0., 0.], + [0., 0., 1., 0., 0.], + [0., 0., 0., 1., 0.]], dtype=float32) + """ + # TODO(vfdev-5): optimize putting the array directly on the device specified + # instead of putting it on default device and then on the specific device + output = _eye(N, M=M, k=k, dtype=dtype) + if device is not None: + return jax.device_put(output, device=device) + return output + + +def _eye(N: DimSize, M: DimSize | None = None, + k: int | ArrayLike = 0, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "eye") + if isinstance(k, int): + k = lax_internal._clip_int_to_valid_range(k, np.int32) + util.check_arraylike("eye", k) + offset = asarray(k) + if not (offset.shape == () and dtypes.issubdtype(offset.dtype, np.integer)): + raise ValueError(f"k must be a scalar integer; got {k}") N_int = core.canonicalize_dim(N, "'N' argument of jnp.eye()") M_int = N_int if M is None else core.canonicalize_dim(M, "'M' argument of jnp.eye()") if N_int < 0 or M_int < 0: raise ValueError(f"negative dimensions are not allowed, got {N} and {M}") - k = operator.index(k) - return lax_internal._eye(_jnp_dtype(dtype), (N_int, M_int), k) + i = lax.broadcasted_iota(offset.dtype, (N_int, M_int), 0) + j = lax.broadcasted_iota(offset.dtype, (N_int, M_int), 1) + return (i + offset == j).astype(dtype) -@util.implements(np.identity) def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: + """Create a square identity matrix + + JAX implementation of :func:`numpy.identity`. + + Args: + n: integer specifying the size of each array dimension. + dtype: optional dtype; defaults to floating point. + + Returns: + Identity array of shape ``(n, n)``. + + See also: + :func:`jax.numpy.eye`: non-square and/or offset identity matrices. + + Examples: + A simple 3x3 identity matrix: + + >>> jnp.identity(3) + Array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]], dtype=float32) + + A 2x2 integer identity matrix: + + >>> jnp.identity(2, dtype=int) + Array([[1, 0], + [0, 1]], dtype=int32) + """ dtypes.check_user_dtype_supported(dtype, "identity") return eye(n, dtype=dtype) -@util.implements(np.arange,lax_description= """ -.. note:: - - Using ``arange`` with the ``step`` argument can lead to precision errors, - especially with lower-precision data types like ``fp8`` and ``bf16``. - For more details, see the docstring of :func:`numpy.arange`. - To avoid precision errors, consider using an expression like - ``(jnp.arange(-600, 600) * .01).astype(jnp.bfloat16)`` to generate a sequence in a higher precision - and then convert it to the desired lower precision. -""") def arange(start: DimSize, stop: DimSize | None = None, + step: DimSize | None = None, dtype: DTypeLike | None = None, + *, device: xc.Device | Sharding | None = None) -> Array: + """Create an array of evenly-spaced values. + + JAX implementation of :func:`numpy.arange`, implemented in terms of + :func:`jax.lax.iota`. 
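For reference, the underlying primitive is a one-liner (sketch):

from jax import lax
import jax.numpy as jnp

print(lax.iota(jnp.int32, 4))   # [0 1 2 3] -- the sequence arange builds on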
+ + Similar to Python's :func:`range` function, this can be called with a few + different positional signatures: + + - ``jnp.arange(stop)``: generate values from 0 to ``stop``, stepping by 1. + - ``jnp.arange(start, stop)``: generate values from ``start`` to ``stop``, + stepping by 1. + - ``jnp.arange(start, stop, step)``: generate values from ``start`` to ``stop``, + stepping by ``step``. + + Like with Python's :func:`range` function, the starting value is inclusive, + and the stop value is exclusive. + + Args: + start: start of the interval, inclusive. + stop: optional end of the interval, exclusive. If not specified, then + ``(start, stop) = (0, start)`` + step: optional step size for the interval. Default = 1. + dtype: optional dtype for the returned array; if not specified it will + be determined via type promotion of `start`, `stop`, and `step`. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of evenly-spaced values from ``start`` to ``stop``, separated by ``step``. + + Note: + Using ``arange`` with a floating-point ``step`` argument can lead to unexpected + results due to accumulation of floating-point errors, especially with + lower-precision data types like ``float8_*`` and ``bfloat16``. + To avoid precision errors, consider generating a range of integers, and scaling + it to the desired range. For example, instead of this:: + + jnp.arange(-1, 1, 0.01, dtype='bfloat16') + + it can be more accurate to generate a sequence of integers, and scale them:: + + (jnp.arange(-100, 100) * 0.01).astype('bfloat16') + + Examples: + Single-argument version specifies only the ``stop`` value: + + >>> jnp.arange(4) + Array([0, 1, 2, 3], dtype=int32) + + Passing a floating-point ``stop`` value leads to a floating-point result: + + >>> jnp.arange(4.0) + Array([0., 1., 2., 3.], dtype=float32) + + Two-argument version specifies ``start`` and ``stop``, with ``step=1``: + + >>> jnp.arange(1, 6) + Array([1, 2, 3, 4, 5], dtype=int32) + + Three-argument version specifies ``start``, ``stop``, and ``step``: + + >>> jnp.arange(0, 2, 0.5) + Array([0. , 0.5, 1. , 1.5], dtype=float32) + + See Also: + - :func:`jax.numpy.linspace`: generate a fixed number of evenly-spaced values. + - :func:`jax.lax.iota`: directly generate integer sequences in XLA. 
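As a concrete sketch of the integer-then-scale workaround from the Note above (assuming bfloat16 is available on the backend):

import jax.numpy as jnp

naive  = jnp.arange(-1, 1, 0.01, dtype='bfloat16')           # step rounding accumulates
scaled = (jnp.arange(-100, 100) * 0.01).astype('bfloat16')   # usually closer to the exact grid
print(naive[:4], scaled[:4])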
+ """ + # TODO(vfdev-5): optimize putting the array directly on the device specified + # instead of putting it on default device and then on the specific device + output = _arange(start, stop=stop, step=step, dtype=dtype) + if device is not None: + return jax.device_put(output, device=device) + return output + + +def _arange(start: DimSize, stop: DimSize | None = None, step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "arange") if not config.dynamic_shapes.value: @@ -2679,13 +4251,11 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool util.check_arraylike("geomspace", start, stop) start = asarray(start, dtype=computation_dtype) stop = asarray(stop, dtype=computation_dtype) - # follow the numpy geomspace convention for negative and complex endpoints - signflip = 1 - (1 - ufuncs.sign(ufuncs.real(start))) * (1 - ufuncs.sign(ufuncs.real(stop))) // 2 - signflip = signflip.astype(computation_dtype) - res = signflip * logspace(ufuncs.log10(signflip * start), - ufuncs.log10(signflip * stop), num, - endpoint=endpoint, base=10.0, - dtype=computation_dtype, axis=0) + + sign = ufuncs.sign(start) + res = sign * logspace(ufuncs.log10(start / sign), ufuncs.log10(stop / sign), + num, endpoint=endpoint, base=10.0, + dtype=computation_dtype, axis=0) if axis != 0: res = moveaxis(res, 0, axis) return lax.convert_element_type(res, dtype) @@ -2728,9 +4298,39 @@ def _i0_jvp(primals, tangents): primal_out, tangent_out = jax.jvp(i0.fun, primals, tangents) return primal_out, where(primals[0] == 0, 0.0, tangent_out) - -@util.implements(np.ix_) def ix_(*args: ArrayLike) -> tuple[Array, ...]: + """Return a multi-dimensional grid (open mesh) from N one-dimensional sequences. + + JAX implementation of :func:`numpy.ix_`. + + Args: + *args: N one-dimensional arrays + + Returns: + Tuple of Jax arrays forming an open mesh, each with N dimensions. + + See Also: + - :obj:`jax.numpy.ogrid` + - :obj:`jax.numpy.mgrid` + - :func:`jax.numpy.meshgrid` + + Examples: + >>> rows = jnp.array([0, 2]) + >>> cols = jnp.array([1, 3]) + >>> open_mesh = jnp.ix_(rows, cols) + >>> open_mesh + (Array([[0], + [2]], dtype=int32), Array([[1, 3]], dtype=int32)) + >>> [grid.shape for grid in open_mesh] + [(2, 1), (1, 2)] + >>> x = jnp.array([[10, 20, 30, 40], + ... [50, 60, 70, 80], + ... [90, 100, 110, 120], + ... [130, 140, 150, 160]]) + >>> x[open_mesh] + Array([[ 20, 40], + [100, 120]], dtype=int32) + """ util.check_arraylike("ix", *args) n = len(args) output = [] @@ -2871,6 +4471,27 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, return take(a, gather_indices, axis=axis) +@util.implements(getattr(np, "trapezoid", getattr(np, "trapz", None))) +@partial(jit, static_argnames=('axis',)) +def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, + axis: int = -1) -> Array: + # TODO(phawkins): remove this annotation after fixing jnp types. 
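+ # Composite trapezoidal rule: when `x` is given, the sample spacing is its
+ # first difference along `axis` (moved to the trailing axis); otherwise the
+ # uniform spacing `dx` is used. The reduction below computes
+ # 0.5 * sum(dx * (y[..., 1:] + y[..., :-1])) over the trailing axis.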
+ dx_array: Array + if x is None: + util.check_arraylike('trapezoid', y) + y_arr, = util.promote_dtypes_inexact(y) + dx_array = asarray(dx) + else: + util.check_arraylike('trapezoid', y, x) + y_arr, x_arr = util.promote_dtypes_inexact(y, x) + if x_arr.ndim == 1: + dx_array = diff(x_arr) + else: + dx_array = moveaxis(diff(x_arr, axis=axis), axis, -1) + y_arr = moveaxis(y_arr, axis, -1) + return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1) + + @util.implements(np.tri) def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "tri") @@ -2904,8 +4525,8 @@ def triu(m: ArrayLike, k: int = 0) -> Array: @util.implements(np.trace, skip_params=['out']) -@partial(jit, static_argnames=('offset', 'axis1', 'axis2', 'dtype')) -def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1, +@partial(jit, static_argnames=('axis1', 'axis2', 'dtype')) +def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, dtype: DTypeLike | None = None, out: None = None) -> Array: util.check_arraylike("trace", a) if out is not None: @@ -2913,13 +4534,6 @@ def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1, dtypes.check_user_dtype_supported(dtype, "trace") a_shape = shape(a) - if dtype is None: - dtype = _dtype(a) - if issubdtype(dtype, integer): - default_int = dtypes.canonicalize_dtype(int) - if iinfo(dtype).bits < iinfo(default_int).bits: - dtype = default_int - a = moveaxis(a, (axis1, axis2), (-2, -1)) # Mask out the diagonal and reduce. @@ -3118,29 +4732,60 @@ def trim_zeros_tol(filt, tol, trim='fb'): end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 return filt[start:len(filt) - end] - -@util.implements(np.append) @partial(jit, static_argnames=('axis',)) def append( arr: ArrayLike, values: ArrayLike, axis: int | None = None ) -> Array: + """Return a new array with values appended to the end of the original array. + + JAX implementation of :func:`numpy.append`. + + Args: + arr: original array. + values: values to be appended to the array. The ``values`` must have + the same number of dimensions as ``arr``, and all dimensions must + match except in the specified axis. + axis: axis along which to append values. If None (default), both ``arr`` + and ``values`` will be flattened before appending. + + Returns: + A new array with values appended to ``arr``. + + See also: + - :func:`jax.numpy.insert` + - :func:`jax.numpy.delete` + + Examples: + >>> a = jnp.array([1, 2, 3]) + >>> b = jnp.array([4, 5, 6]) + >>> jnp.append(a, b) + Array([1, 2, 3, 4, 5, 6], dtype=int32) + + Appending along a specific axis: + + >>> a = jnp.array([[1, 2], + ... [3, 4]]) + >>> b = jnp.array([[5, 6]]) + >>> jnp.append(a, b, axis=0) + Array([[1, 2], + [3, 4], + [5, 6]], dtype=int32) + + Appending along a trailing axis: + + >>> a = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> b = jnp.array([[7], [8]]) + >>> jnp.append(a, b, axis=1) + Array([[1, 2, 3, 7], + [4, 5, 6, 8]], dtype=int32) + """ if axis is None: return concatenate([ravel(arr), ravel(values)], 0) else: return concatenate([arr, values], axis=axis) -@util.implements(np.delete, - lax_description=_dedent(""" - delete() usually requires the index specification to be static. 
If the index - is an integer array that is guaranteed to contain unique entries, you may - specify ``assume_unique_indices=True`` to perform the operation in a - manner that does not require static indices."""), - extra_params=_dedent(""" - assume_unique_indices : int, optional (default=False) - In case of array-like integer (not boolean) indices, assume the indices are unique, - and perform the deletion in a way that is compatible with JIT and other JAX - transformations.""")) def delete( arr: ArrayLike, obj: ArrayLike | slice, @@ -3148,6 +4793,67 @@ def delete( *, assume_unique_indices: bool = False, ) -> Array: + """Delete entry or entries from an array. + + JAX implementation of :func:`numpy.delete`. + + Args: + arr: array from which entries will be deleted. + obj: index, indices, or slice to be deleted. + axis: axis along which entries will be deleted. + assume_unique_indices: In case of array-like integer (not boolean) indices, + assume the indices are unique, and perform the deletion in a way that is + compatible with JIT and other JAX transformations. + + Returns: + Copy of ``arr`` with specified indices deleted. + + Note: + ``delete()`` usually requires the index specification to be static. If the + index is an integer array that is guaranteed to contain unique entries, you + may specify ``assume_unique_indices=True`` to perform the operation in a + manner that does not require static indices. + + Examples: + Delete entries from a 1D array: + + >>> a = jnp.array([4, 5, 6, 7, 8, 9]) + >>> jnp.delete(a, 2) + Array([4, 5, 7, 8, 9], dtype=int32) + >>> jnp.delete(a, slice(1, 4)) # delete a[1:4] + Array([4, 8, 9], dtype=int32) + >>> jnp.delete(a, slice(None, None, 2)) # delete a[::2] + Array([5, 7, 9], dtype=int32) + + Delete entries from a 2D array along a specified axis: + + >>> a2 = jnp.array([[4, 5, 6], + ... [7, 8, 9]]) + >>> jnp.delete(a2, 1, axis=1) + Array([[4, 6], + [7, 9]], dtype=int32) + + Delete multiple entries via a sequence of indices: + + >>> indices = jnp.array([0, 1, 3]) + >>> jnp.delete(a, indices) + Array([6, 8, 9], dtype=int32) + + This will fail under :func:`~jax.jit` and other transformations, because + the output shape cannot be known with the possibility of duplicate indices: + + >>> jax.jit(jnp.delete)(a, indices) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[3]. + + If you can ensure that the indices are unique, pass ``assume_unique_indices`` + to allow this to be executed under JIT: + + >>> jit_delete = jax.jit(jnp.delete, static_argnames=['assume_unique_indices']) + >>> jit_delete(a, indices, assume_unique_indices=True) + Array([6, 8, 9], dtype=int32) + """ util.check_arraylike("delete", arr) if axis is None: arr = ravel(arr) @@ -3258,8 +4964,7 @@ def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, def apply_along_axis( func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs ) -> Array: - # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. 
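To make the JIT behaviour from the ``delete()`` docstring above concrete, a short sketch (values match the docstring's example array):

import jax
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])

# A static Python index is fine under jit, since the output shape is known at trace time:
print(jax.jit(lambda arr: jnp.delete(arr, 2))(a))                         # [4 5 7 8 9]

# Traced index arrays require assume_unique_indices=True:
jit_delete = jax.jit(jnp.delete, static_argnames=['assume_unique_indices'])
print(jit_delete(a, jnp.array([0, 1, 3]), assume_unique_indices=True))    # [6 8 9]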
- util.check_arraylike("apply_along_axis", arr, emit_warning=True) + util.check_arraylike("apply_along_axis", arr) num_dims = ndim(arr) axis = _canonicalize_axis(axis, num_dims) func = lambda arr: func1d(arr, *args, **kwargs) @@ -3273,8 +4978,7 @@ def apply_along_axis( @util.implements(np.apply_over_axes) def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, axes: Sequence[int]) -> Array: - # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. - util.check_arraylike("apply_over_axes", a, emit_warning=True) + util.check_arraylike("apply_over_axes", a) a_arr = asarray(a) for axis in axes: b = func(a_arr, axis) @@ -3289,20 +4993,71 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, ### Tensor contraction operations - -_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION = """ -preferred_element_type : dtype, optional - If specified, accumulate results and return a result of the given data type. - If not specified, the accumulation dtype is determined from the type promotion - rules of the input array dtypes. -""" - -@util.implements(np.dot, lax_description=_PRECISION_DOC, - extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def dot(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: + """Compute the dot product of two arrays. + + JAX implementation of :func:`numpy.dot`. + + This differs from :func:`jax.numpy.matmul` in two respects: + + - if either ``a`` or ``b`` is a scalar, the result of ``dot`` is equivalent to + :func:`jax.numpy.multiply`, while the result of ``matmul`` is an error. + - if ``a`` and ``b`` have more than 2 dimensions, the batch indices are + stacked rather than broadcast. + + Args: + a: first input array, of shape ``(..., N)``. + b: second input array. Must have shape ``(N,)`` or ``(..., N, M)``. + In the multi-dimensional case, leading dimensions must be broadcast-compatible + with the leading dimensions of ``a``. + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two + such values indicating precision of ``a`` and ``b``. + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + + Returns: + array containing the dot product of the inputs, with batch dimensions of + ``a`` and ``b`` stacked rather than broadcast. + + See also: + - :func:`jax.numpy.matmul`: broadcasted batched matmul. + - :func:`jax.lax.dot_general`: general batched matrix multiplication. + + Examples: + For scalar inputs, ``dot`` computes the element-wise product: + + >>> x = jnp.array([1, 2, 3]) + >>> jnp.dot(x, 2) + Array([2, 4, 6], dtype=int32) + + For vector or matrix inputs, ``dot`` computes the vector or matrix product: + + >>> M = jnp.array([[2, 3, 4], + ... [5, 6, 7], + ... [8, 9, 0]]) + >>> jnp.dot(M, x) + Array([20, 38, 26], dtype=int32) + >>> jnp.dot(M, M) + Array([[ 51, 60, 29], + [ 96, 114, 62], + [ 61, 78, 95]], dtype=int32) + + For higher-dimensional matrix products, batch dimensions are stacked, whereas + in :func:`~jax.numpy.matmul` they are broadcast. 
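Another knob shared by ``dot`` and ``matmul`` is ``preferred_element_type``, which requests a wider accumulation/result dtype; a minimal sketch assuming bfloat16 inputs:

import jax.numpy as jnp

a = jnp.ones((2, 3), dtype=jnp.bfloat16)
b = jnp.ones((3, 4), dtype=jnp.bfloat16)
low  = jnp.dot(a, b)                                      # result dtype bfloat16
high = jnp.dot(a, b, preferred_element_type=jnp.float32)  # accumulate/return in float32
print(low.dtype, high.dtype)

The stacking-versus-broadcasting contrast described above is illustrated next.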
For example: + + >>> a = jnp.zeros((3, 2, 4)) + >>> b = jnp.zeros((3, 4, 1)) + >>> jnp.dot(a, b).shape + (3, 2, 3, 1) + >>> jnp.matmul(a, b).shape + (3, 2, 1) + """ util.check_arraylike("dot", a, b) dtypes.check_user_dtype_supported(preferred_element_type, "dot") a, b = asarray(a), asarray(b) @@ -3330,14 +5085,64 @@ def dot(a: ArrayLike, b: ArrayLike, *, return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) -@util.implements(np.matmul, module='numpy', lax_description=_PRECISION_DOC, - extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def matmul(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, ) -> Array: - """Matrix Multiply.""" + """Perform a matrix multiplication. + + JAX implementation of :func:`numpy.matmul`. + + Args: + a: first input array, of shape ``(..., N)``. + b: second input array. Must have shape ``(N,)`` or ``(..., N, M)``. + In the multi-dimensional case, leading dimensions must be broadcast-compatible + with the leading dimensions of ``a``. + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two + such values indicating precision of ``a`` and ``b``. + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + + Returns: + array containing the matrix product of the inputs. Shape is ``a.shape[:-1]`` + if ``b.ndim == 1``, otherwise the shape is ``(..., M)``, where leading + dimensions of ``a`` and ``b`` are broadcast together. + + See Also: + - :func:`jax.numpy.linalg.vecdot`: batched vector product. + - :func:`jax.numpy.linalg.tensordot`: batched tensor product. + - :func:`jax.lax.dot_general`: general N-dimensional batched dot product. + + Examples: + Vector dot products: + + >>> a = jnp.array([1, 2, 3]) + >>> b = jnp.array([4, 5, 6]) + >>> jnp.matmul(a, b) + Array(32, dtype=int32) + + Matrix dot product: + + >>> a = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> b = jnp.array([[1, 2], + ... [3, 4], + ... [5, 6]]) + >>> jnp.matmul(a, b) + Array([[22, 28], + [49, 64]], dtype=int32) + + For convenience, in all cases you can do the same computation using + the ``@`` operator: + + >>> a @ b + Array([[22, 28], + [49, 64]], dtype=int32) + """ util.check_arraylike("matmul", a, b) dtypes.check_user_dtype_supported(preferred_element_type, "matmul") a, b = asarray(a), asarray(b) @@ -3403,14 +5208,47 @@ def matmul(a: ArrayLike, b: ArrayLike, *, return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) -@util.implements(np.vdot, lax_description=_PRECISION_DOC, - extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def vdot( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, ) -> Array: + """Perform a conjugate multiplication of two 1D vectors. + + JAX implementation of :func:`numpy.vdot`. + + Args: + a: first input array, if not 1D it will be flattened. + b: second input array, if not 1D it will be flattened. Must have ``a.size == b.size``. 
+ precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two + such values indicating precision of ``a`` and ``b``. + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + + Returns: + Scalar array (shape ``()``) containing the conjugate vector product of the inputs. + + See Also: + - :func:`jax.numpy.vecdot`: batched vector product. + - :func:`jax.numpy.matmul`: general matrix multiplication. + - :func:`jax.lax.dot_general`: general N-dimensional batched dot product. + + Examples: + >>> x = jnp.array([1j, 2j, 3j]) + >>> y = jnp.array([1., 2., 3.]) + >>> jnp.vdot(x, y) + Array(0.-14.j, dtype=complex64) + + Note the difference between this and :func:`~jax.numpy.dot`, which does not + conjugate the first input when complex: + + >>> jnp.dot(x, y) + Array(0.+14.j, dtype=complex64) + """ util.check_arraylike("vdot", a, b) if issubdtype(_dtype(a), complexfloating): a = ufuncs.conj(a) @@ -3418,11 +5256,51 @@ def vdot( preferred_element_type=preferred_element_type) -@util.implements(getattr(np, "vecdot", None), lax_description=_PRECISION_DOC, - extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: + """Perform a conjugate multiplication of two batched vectors. + + JAX implementation of :func:`numpy.vecdot`. + + Args: + a: left-hand side array. + b: right-hand side array. Size of ``b[axis]`` must match size of ``a[axis]``, + and remaining dimensions must be broadcast-compatible. + axis: axis along which to compute the dot product (default: -1) + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two + such values indicating precision of ``a`` and ``b``. + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + + Returns: + array containing the conjugate dot product of ``a`` and ``b`` along ``axis``. + The non-contracted dimensions are broadcast together. + + See Also: + - :func:`jax.numpy.vdot`: flattened vector product. + - :func:`jax.numpy.matmul`: general matrix multiplication. + - :func:`jax.lax.dot_general`: general N-dimensional batched dot product. + + Examples: + Vector conjugate-dot product of two 1D arrays: + + >>> a = jnp.array([1j, 2j, 3j]) + >>> b = jnp.array([4., 5., 6.]) + >>> jnp.linalg.vecdot(a, b) + Array(0.-32.j, dtype=complex64) + + Batched vector dot product of two 2D arrays: + + >>> a = jnp.array([[1, 2, 3], + ... 
[4, 5, 6]]) + >>> b = jnp.array([[2, 3, 4]]) + >>> jnp.linalg.vecdot(a, b, axis=-1) + Array([20, 47], dtype=int32) + """ util.check_arraylike("jnp.vecdot", x1, x2) x1_arr, x2_arr = asarray(x1), asarray(x2) if x1_arr.shape[axis] != x2_arr.shape[axis]: @@ -3433,12 +5311,81 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, signature="(n),(n)->()")(x1_arr, x2_arr) -@util.implements(np.tensordot, lax_description=_PRECISION_DOC, - extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) def tensordot(a: ArrayLike, b: ArrayLike, axes: int | Sequence[int] | Sequence[Sequence[int]] = 2, *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: + """Compute the tensor dot product of two N-dimensional arrays. + + JAX implementation of :func:`numpy.linalg.tensordot`. + + Args: + a: N-dimensional array + b: M-dimensional array + axes: integer or tuple of sequences of integers. If an integer `k`, then + sum over the last `k` axes of ``a`` and the first `k` axes of ``b``, + in order. If a tuple, then ``axes[0]`` specifies the axes of ``a`` and + ``axes[1]`` specifies the axes of ``b``. + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two + such values indicating precision of ``a`` and ``b``. + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + + Returns: + array containing the tensor dot product of the inputs + + See also: + - :func:`jax.numpy.einsum`: NumPy API for more general tensor contractions. + - :func:`jax.lax.dot_general`: XLA API for more general tensor contractions. + + Examples: + >>> x1 = jnp.arange(24.).reshape(2, 3, 4) + >>> x2 = jnp.ones((3, 4, 5)) + >>> jnp.tensordot(x1, x2) + Array([[ 66., 66., 66., 66., 66.], + [210., 210., 210., 210., 210.]], dtype=float32) + + Equivalent result when specifying the axes as explicit sequences: + + >>> jnp.tensordot(x1, x2, axes=([1, 2], [0, 1])) + Array([[ 66., 66., 66., 66., 66.], + [210., 210., 210., 210., 210.]], dtype=float32) + + Equivalent result via :func:`~jax.numpy.einsum`: + + >>> jnp.einsum('ijk,jkm->im', x1, x2) + Array([[ 66., 66., 66., 66., 66.], + [210., 210., 210., 210., 210.]], dtype=float32) + + Setting ``axes=1`` for two-dimensional inputs is equivalent to a matrix + multiplication: + + >>> x1 = jnp.array([[1, 2], + ... [3, 4]]) + >>> x2 = jnp.array([[1, 2, 3], + ... 
[4, 5, 6]]) + >>> jnp.linalg.tensordot(x1, x2, axes=1) + Array([[ 9, 12, 15], + [19, 26, 33]], dtype=int32) + >>> x1 @ x2 + Array([[ 9, 12, 15], + [19, 26, 33]], dtype=int32) + + Setting ``axes=0`` for one-dimensional inputs is equivalent to + :func:`~jax.numpy.outer`: + + >>> x1 = jnp.array([1, 2]) + >>> x2 = jnp.array([1, 2, 3]) + >>> jnp.linalg.tensordot(x1, x2, axes=0) + Array([[1, 2, 3], + [2, 4, 6]], dtype=int32) + >>> jnp.outer(x1, x2) + Array([[1, 2, 3], + [2, 4, 6]], dtype=int32) + """ util.check_arraylike("tensordot", a, b) dtypes.check_user_dtype_supported(preferred_element_type, "tensordot") a, b = asarray(a), asarray(b) @@ -3479,21 +5426,17 @@ def tensordot(a: ArrayLike, b: ArrayLike, return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) -_EINSUM_DOC = _PRECISION_DOC + """\ -A tuple ``precision`` does not necessarily map to multiple arguments of ``einsum()``; -rather, the specified ``precision`` is forwarded to each ``dot_general`` call used in -the implementation. - -:func:`jax.numpy.einsum` also differs from :func:`numpy.einsum` in that the ``optimize`` -keyword defaults to ``"optimal"`` rather than ``False``. -""" +class Unoptimized(opt_einsum.paths.PathOptimizer): + """Unoptimized path for einsum.""" + def __call__(self, inputs, *args, **kwargs): + return [(0, 1)] * (len(inputs) - 1) @overload def einsum( subscript: str, /, *operands: ArrayLike, out: None = None, - optimize: str = "optimal", + optimize: str | bool | list[tuple[int, ...]] = "optimal", precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, @@ -3505,27 +5448,234 @@ def einsum( axes: Sequence[Any], /, *operands: ArrayLike | Sequence[Any], out: None = None, - optimize: str = "optimal", + optimize: str | bool | list[tuple[int, ...]] = "optimal", precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, ) -> Array: ... -@util.implements(np.einsum, lax_description=_EINSUM_DOC, skip_params=['out']) def einsum( subscripts, /, *operands, out: None = None, - optimize: str = "optimal", + optimize: str | bool | list[tuple[int, ...]] = "optimal", precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, ) -> Array: + """Einstein summation + + JAX implementation of :func:`numpy.einsum`. + + ``einsum`` is a powerful and generic API for computing various reductions, + inner products, outer products, axis reorderings, and combinations thereof + across one or more input arrays. It has a somewhat complicated overloaded API; + the arguments below reflect the most common calling convention. The Examples + section below demonstrates some of the alternative calling conventions. + + Args: + subscripts: string containing axes names separated by commas. + *operands: sequence of one or more arrays corresponding to the subscripts. + optimize: specify how to optimize the order of computation. In JAX this defaults + to ``"optimal"`` which produces optimized expressions via the opt_einsum_ + package. Other options are ``True`` (same as ``"optimal"``), ``False`` + (unoptimized), or any string supported by ``opt_einsum``, which + includes ``"auto"``, ``"greedy"``, ``"eager"``, and others. It may also + be a pre-computed path (see :func:`~jax.numpy.einsum_path`). 
+ precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``). + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + out: unsupported by JAX + _dot_general: optionally override the ``dot_general`` callable used by ``einsum``. + This parameter is experimental, and may be removed without warning at any time. + + Returns: + array containing the result of the einstein summation. + + See also: + :func:`jax.numpy.einsum_path` + + Examples: + The mechanics of ``einsum`` are perhaps best demonstrated by example. Here we + show how to use ``einsum`` to compute a number of quantities from one or more + arrays. For more discussion and examples of ``einsum``, see the documentation + of :func:`numpy.einsum`. + + >>> M = jnp.arange(16).reshape(4, 4) + >>> x = jnp.arange(4) + >>> y = jnp.array([5, 4, 3, 2]) + + **Vector product** + + >>> jnp.einsum('i,i', x, y) + Array(16, dtype=int32) + >>> jnp.vecdot(x, y) + Array(16, dtype=int32) + + Here are some alternative ``einsum`` calling conventions to compute the same + result: + + >>> jnp.einsum('i,i->', x, y) # explicit form + Array(16, dtype=int32) + >>> jnp.einsum(x, (0,), y, (0,)) # implicit form via indices + Array(16, dtype=int32) + >>> jnp.einsum(x, (0,), y, (0,), ()) # explicit form via indices + Array(16, dtype=int32) + + **Matrix product** + + >>> jnp.einsum('ij,j->i', M, x) # explicit form + Array([14, 38, 62, 86], dtype=int32) + >>> jnp.matmul(M, x) + Array([14, 38, 62, 86], dtype=int32) + + Here are some alternative ``einsum`` calling conventions to compute the same + result: + + >>> jnp.einsum('ij,j', M, x) # implicit form + Array([14, 38, 62, 86], dtype=int32) + >>> jnp.einsum(M, (0, 1), x, (1,), (0,)) # explicit form via indices + Array([14, 38, 62, 86], dtype=int32) + >>> jnp.einsum(M, (0, 1), x, (1,)) # implicit form via indices + Array([14, 38, 62, 86], dtype=int32) + + **Outer product** + + >>> jnp.einsum("i,j->ij", x, y) + Array([[ 0, 0, 0, 0], + [ 5, 4, 3, 2], + [10, 8, 6, 4], + [15, 12, 9, 6]], dtype=int32) + >>> jnp.outer(x, y) + Array([[ 0, 0, 0, 0], + [ 5, 4, 3, 2], + [10, 8, 6, 4], + [15, 12, 9, 6]], dtype=int32) + + Some other ways of computing outer products: + + >>> jnp.einsum("i,j", x, y) # implicit form + Array([[ 0, 0, 0, 0], + [ 5, 4, 3, 2], + [10, 8, 6, 4], + [15, 12, 9, 6]], dtype=int32) + >>> jnp.einsum(x, (0,), y, (1,), (0, 1)) # explicit form via indices + Array([[ 0, 0, 0, 0], + [ 5, 4, 3, 2], + [10, 8, 6, 4], + [15, 12, 9, 6]], dtype=int32) + >>> jnp.einsum(x, (0,), y, (1,)) # implicit form via indices + Array([[ 0, 0, 0, 0], + [ 5, 4, 3, 2], + [10, 8, 6, 4], + [15, 12, 9, 6]], dtype=int32) + + **1D array sum** + + >>> jnp.einsum("i->", x) # requires explicit form + Array(6, dtype=int32) + >>> jnp.einsum(x, (0,), ()) # explicit form via indices + Array(6, dtype=int32) + >>> jnp.sum(x) + Array(6, dtype=int32) + + **Sum along an axis** + + >>> jnp.einsum("...j->...", M) # requires explicit form + Array([ 6, 22, 38, 54], dtype=int32) + >>> jnp.einsum(M, (..., 0), (...,)) # explicit form via indices + Array([ 6, 22, 38, 54], dtype=int32) + >>> M.sum(-1) + Array([ 6, 22, 38, 54], dtype=int32) + + **Matrix transpose** + + >>> y = jnp.array([[1, 2, 3], + ... 
[4, 5, 6]]) + >>> jnp.einsum("ij->ji", y) # explicit form + Array([[1, 4], + [2, 5], + [3, 6]], dtype=int32) + >>> jnp.einsum("ji", y) # implicit form + Array([[1, 4], + [2, 5], + [3, 6]], dtype=int32) + >>> jnp.einsum(y, (1, 0)) # implicit form via indices + Array([[1, 4], + [2, 5], + [3, 6]], dtype=int32) + >>> jnp.einsum(y, (0, 1), (1, 0)) # explicit form via indices + Array([[1, 4], + [2, 5], + [3, 6]], dtype=int32) + >>> jnp.transpose(y) + Array([[1, 4], + [2, 5], + [3, 6]], dtype=int32) + + **Matrix diagonal** + + >>> jnp.einsum("ii->i", M) + Array([ 0, 5, 10, 15], dtype=int32) + >>> jnp.diagonal(M) + Array([ 0, 5, 10, 15], dtype=int32) + + **Matrix trace** + + >>> jnp.einsum("ii", M) + Array(30, dtype=int32) + >>> jnp.trace(M) + Array(30, dtype=int32) + + **Tensor products** + + >>> x = jnp.arange(30).reshape(2, 3, 5) + >>> y = jnp.arange(60).reshape(3, 4, 5) + >>> jnp.einsum('ijk,jlk->il', x, y) # explicit form + Array([[ 3340, 3865, 4390, 4915], + [ 8290, 9940, 11590, 13240]], dtype=int32) + >>> jnp.tensordot(x, y, axes=[(1, 2), (0, 2)]) + Array([[ 3340, 3865, 4390, 4915], + [ 8290, 9940, 11590, 13240]], dtype=int32) + >>> jnp.einsum('ijk,jlk', x, y) # implicit form + Array([[ 3340, 3865, 4390, 4915], + [ 8290, 9940, 11590, 13240]], dtype=int32) + >>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3)) # explicit form via indices + Array([[ 3340, 3865, 4390, 4915], + [ 8290, 9940, 11590, 13240]], dtype=int32) + >>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2)) # implicit form via indices + Array([[ 3340, 3865, 4390, 4915], + [ 8290, 9940, 11590, 13240]], dtype=int32) + + **Chained dot products** + + >>> w = jnp.arange(5, 9).reshape(2, 2) + >>> x = jnp.arange(6).reshape(2, 3) + >>> y = jnp.arange(-2, 4).reshape(3, 2) + >>> z = jnp.array([[2, 4, 6], [3, 5, 7]]) + >>> jnp.einsum('ij,jk,kl,lm->im', w, x, y, z) + Array([[ 481, 831, 1181], + [ 651, 1125, 1599]], dtype=int32) + >>> jnp.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4)) # implicit, via indices + Array([[ 481, 831, 1181], + [ 651, 1125, 1599]], dtype=int32) + >>> w @ x @ y @ z # direct chain of matmuls + Array([[ 481, 831, 1181], + [ 651, 1125, 1599]], dtype=int32) + >>> jnp.linalg.multi_dot([w, x, y, z]) + Array([[ 481, 831, 1181], + [ 651, 1125, 1599]], dtype=int32) + + .. _opt_einsum: https://github.com/dgasmith/opt_einsum + """ operands = (subscripts, *operands) if out is not None: raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.") spec = operands[0] if isinstance(operands[0], str) else None - optimize = 'optimal' if optimize is True else optimize + path_type = 'optimal' if optimize is True else Unoptimized() if optimize is False else optimize # Allow handling of shape polymorphism non_constant_dim_types = { @@ -3539,14 +5689,14 @@ def einsum( contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler) # using einsum_call=True here is an internal api for opt_einsum... 
sorry operands, contractions = contract_path( - *operands, einsum_call=True, use_blas=True, optimize=optimize) + *operands, einsum_call=True, use_blas=True, optimize=path_type) contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) einsum = jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True) if spec is not None: einsum = jax.named_call(einsum, name=spec) - return einsum(operands, contractions, precision, # type: ignore[operator] + return einsum(operands, contractions, precision, preferred_element_type, _dot_general) @@ -3563,9 +5713,77 @@ def _default_poly_einsum_handler(*operands, **kwargs): contract_operands = [operands[mapping[id(d)]] for d in out_dummies] return contract_operands, contractions -@util.implements(np.einsum_path) -def einsum_path(subscripts, *operands, optimize='greedy'): - # using einsum_call=True here is an internal api for opt_einsum +@overload +def einsum_path( + subscripts: str, /, + *operands: ArrayLike, + optimize: bool | str | list[tuple[int, ...]] = ..., +) -> tuple[list[tuple[int, ...]], Any]: ... + +@overload +def einsum_path( + arr: ArrayLike, + axes: Sequence[Any], /, + *operands: ArrayLike | Sequence[Any], + optimize: bool | str | list[tuple[int, ...]] = ..., +) -> tuple[list[tuple[int, ...]], Any]: ... + +def einsum_path( + subscripts, /, + *operands, + optimize: bool | str | list[tuple[int, ...]] = 'auto' + ) -> tuple[list[tuple[int, ...]], Any]: + """Evaluates the optimal contraction path without evaluating the einsum. + + JAX implementation of :func:`numpy.einsum_path`. This function calls into + the opt_einsum_ package, and makes use of its optimization routines. + + Args: + subscripts: string containing axes names separated by commas. + *operands: sequence of one or more arrays corresponding to the subscripts. + optimize: specify how to optimize the order of computation. In JAX this defaults + to ``"auto"``. Other options are ``True`` (same as ``"optimize"``), ``False`` + (unoptimized), or any string supported by ``opt_einsum``, which + includes ``"optimize"``,, ``"greedy"``, ``"eager"``, and others. + + Returns: + A tuple containing the path that may be passed to :func:`~jax.numpy.einsum`, and a + printable object representing this optimal path. + + Examples: + >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3) + >>> x = jax.random.randint(key1, minval=-5, maxval=5, shape=(2, 3)) + >>> y = jax.random.randint(key2, minval=-5, maxval=5, shape=(3, 100)) + >>> z = jax.random.randint(key3, minval=-5, maxval=5, shape=(100, 5)) + >>> path, path_info = jnp.einsum_path("ij,jk,kl", x, y, z, optimize="optimal") + >>> print(path) + [(1, 2), (0, 1)] + >>> print(path_info) + Complete contraction: ij,jk,kl->il + Naive scaling: 4 + Optimized scaling: 3 + Naive FLOP count: 9.000e+3 + Optimized FLOP count: 3.060e+3 + Theoretical speedup: 2.941e+0 + Largest intermediate: 1.500e+1 elements + -------------------------------------------------------------------------------- + scaling BLAS current remaining + -------------------------------------------------------------------------------- + 3 GEMM kl,jk->lj ij,lj->il + 3 GEMM lj,ij->il il->il + + Use the computed path in :func:`~jax.numpy.einsum`: + + >>> jnp.einsum("ij,jk,kl", x, y, z, optimize=path) + Array([[-539, 216, 95, 592, 209], + [ 527, 76, 285, -436, -529]], dtype=int32) + + .. 
_opt_einsum: https://github.com/dgasmith/opt_einsum + """ + if optimize is True: + optimize = 'optimal' + elif optimize is False: + optimize = Unoptimized() return opt_einsum.contract_path(subscripts, *operands, optimize=optimize) def _removechars(s, chars): @@ -3722,15 +5940,54 @@ def filter_singleton_dims(operand, names, other_shape, other_names): return lax_internal._convert_element_type(operands[0], preferred_element_type, output_weak_type) -@util.implements(np.inner, lax_description=_PRECISION_DOC, - extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def inner( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, preferred_element_type: DType | None = None, ) -> Array: - # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. - util.check_arraylike("inner", a, b, emit_warning=True) + """Compute the inner product of two arrays. + + JAX implementation of :func:`numpy.inner`. + + Unlike :func:`jax.numpy.matmul` or :func:`jax.numpy.dot`, this always performs + a contraction along the last dimension of each input. + + Args: + a: array of shape ``(..., N)`` + b: array of shape ``(..., N)`` + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two + such values indicating precision of ``a`` and ``b``. + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + + Returns: + array of shape ``(*a.shape[:-1], *b.shape[:-1])`` containing the batched vector + product of the inputs. + + See also: + - :func:`jax.numpy.vecdot`: conjugate multiplication along a specified axis. + - :func:`jax.numpy.tensordot`: general tensor multiplication. + - :func:`jax.numpy.matmul`: general batched matrix & vector multiplication. + + Examples: + For 1D inputs, this implements standard (non-conjugate) vector multiplication: + + >>> a = jnp.array([1j, 3j, 4j]) + >>> b = jnp.array([4., 2., 5.]) + >>> jnp.inner(a, b) + Array(0.+30.j, dtype=complex64) + + For multi-dimensional inputs, batch dimensions are stacked rather than broadcast: + + >>> a = jnp.ones((2, 3)) + >>> b = jnp.ones((5, 3)) + >>> jnp.inner(a, b).shape + (2, 5) + """ + util.check_arraylike("inner", a, b) if ndim(a) == 0 or ndim(b) == 0: a = asarray(a, dtype=preferred_element_type) b = asarray(b, dtype=preferred_element_type) @@ -3744,8 +6001,7 @@ def inner( def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: if out is not None: raise NotImplementedError("The 'out' argument to jnp.outer is not supported.") - # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. - util.check_arraylike("outer", a, b, emit_warning=True) + util.check_arraylike("outer", a, b) a, b = util.promote_dtypes(a, b) return ravel(a)[:, None] * ravel(b)[None, :] @@ -3754,8 +6010,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, axis: int | None = None): # TODO(jakevdp): NumPy 2.0 deprecates 2D inputs. Follow suit here. - # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. 
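Stepping back to the contraction APIs documented above, a small sketch tying ``einsum`` to ``einsum_path`` (shapes are arbitrary; ``optimize=False`` now routes through the ``Unoptimized`` path defined earlier):

import jax.numpy as jnp

x, y, z = jnp.ones((2, 3)), jnp.ones((3, 40)), jnp.ones((40, 5))
path, info = jnp.einsum_path('ij,jk,kl', x, y, z, optimize='optimal')
opt   = jnp.einsum('ij,jk,kl', x, y, z, optimize=path)    # reuse the precomputed path
unopt = jnp.einsum('ij,jk,kl', x, y, z, optimize=False)   # single left-to-right contraction
assert opt.shape == unopt.shape == (2, 5)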
- util.check_arraylike("cross", a, b, emit_warning=True) + util.check_arraylike("cross", a, b) if axis is not None: axisa = axis axisb = axis @@ -3782,8 +6037,7 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, @util.implements(np.kron) @jit def kron(a: ArrayLike, b: ArrayLike) -> Array: - # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. - util.check_arraylike("kron", a, b, emit_warning=True) + util.check_arraylike("kron", a, b) a, b = util.promote_dtypes(a, b) if ndim(a) < ndim(b): a = expand_dims(a, range(ndim(b) - ndim(a))) @@ -3818,36 +6072,66 @@ def vander( ### Misc -_ARGWHERE_DOC = """\ -Because the size of the output of ``argwhere`` is data-dependent, the function is not -typically compatible with JIT. The JAX version adds the optional ``size`` argument, which -specifies the size of the leading dimension of the output - it must be specified statically -for ``jnp.argwhere`` to be compiled with non-static operands. If ``size`` is specified, -the indices of the first ``size`` True elements will be returned; if there are fewer -nonzero elements than `size` indicates, the index arrays will be zero-padded. -""" - - -@util.implements(np.argwhere, - lax_description=_dedent(""" - Because the size of the output of ``argwhere`` is data-dependent, the function is not - typically compatible with JIT. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.argwhere`` to be used within some of JAX's - transformations."""), - extra_params=_dedent(""" - size : int, optional - If specified, the indices of the first ``size`` True elements will be returned. If there - are fewer results than ``size`` indicates, the return value will be padded with ``fill_value``. - fill_value : array_like, optional - When ``size`` is specified and there are fewer than the indicated number of elements, the - remaining elements will be filled with ``fill_value``, which defaults to zero.""")) def argwhere( a: ArrayLike, *, size: int | None = None, fill_value: ArrayLike | None = None, ) -> Array: - result = transpose(vstack(nonzero(a, size=size, fill_value=fill_value))) + """Find the indices of nonzero array elements + + JAX implementation of :func:`numpy.argwhere`. + + ``jnp.argwhere(x)`` is essentially equivalent to ``jnp.column_stack(jnp.nonzero(x))`` + with special handling for zero-dimensional (i.e. scalar) inputs. + + Because the size of the output of ``argwhere`` is data-dependent, the function is not + typically compatible with JIT. The JAX version adds the optional ``size`` argument, which + specifies the size of the leading dimension of the output - it must be specified statically + for ``jnp.argwhere`` to be compiled with non-static operands. See :func:`jax.numpy.nonzero` + for a full discussion of ``size`` and its semantics. + + Args: + a: array for which to find nonzero elements + size: optional integer specifying statically the number of expected nonzero elements. + This must be specified in order to use ``argwhere`` within JAX transformations like + :func:`jax.jit`. See :func:`jax.numpy.nonzero` for more information. + fill_value: optional array specifying the fill value when ``size`` is specified. + See :func:`jax.numpy.nonzero` for more information. + + Returns: + a two-dimensional array of shape ``[size, x.ndim]``. If ``size`` is not specified as + an argument, it is equal to the number of nonzero elements in ``x``. 
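A short sketch of the static ``size`` argument that makes ``argwhere`` usable under :func:`jax.jit` (padding semantics as described for ``jnp.nonzero``):

import jax
import jax.numpy as jnp

x = jnp.array([[1, 0, 2],
               [0, 3, 0]])
idx = jax.jit(lambda x: jnp.argwhere(x, size=4, fill_value=0))(x)
print(idx.shape)   # (4, 2): three real hits plus one zero-padded row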
+ + See Also: + - :func:`jax.numpy.where` + - :func:`jax.numpy.nonzero` + + Examples: + Two-dimensional array: + + >>> x = jnp.array([[1, 0, 2], + ... [0, 3, 0]]) + >>> jnp.argwhere(x) + Array([[0, 0], + [0, 2], + [1, 1]], dtype=int32) + + Equivalent computation using :func:`jax.numpy.column_stack` and :func:`jax.numpy.nonzero`: + + >>> jnp.column_stack(jnp.nonzero(x)) + Array([[0, 0], + [0, 2], + [1, 1]], dtype=int32) + + Special case for zero-dimensional (i.e. scalar) inputs: + + >>> jnp.argwhere(1) + Array([], shape=(1, 0), dtype=int32) + >>> jnp.argwhere(0) + Array([], shape=(0, 0), dtype=int32) + """ + result = transpose(vstack(nonzero(atleast_1d(a), size=size, fill_value=fill_value))) if ndim(a) == 0: return result[:0].reshape(result.shape[0], 0) return result.reshape(result.shape[0], ndim(a)) @@ -3951,30 +6235,29 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False): @util.implements(np.sort, extra_params=""" -kind : deprecated; specify sort algorithm using stable=True or stable=False -order : not supported stable : bool, default=True Specify whether to use a stable sort. descending : bool, default=False Specify whether to do a descending sort. - """) +kind : deprecated; specify sort algorithm using stable=True or stable=False +order : not supported +""") @partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def sort( a: ArrayLike, axis: int | None = -1, - kind: str | None = None, - order: None = None, *, + *, + kind: None = None, + order: None = None, stable: bool = True, descending: bool = False, ) -> Array: util.check_arraylike("sort", a) if kind is not None: - # Deprecated 2024-01-05 - warnings.warn("The 'kind' argument to sort has no effect, and is deprecated. " - "Use stable=True or stable=False to specify sort stability.", - category=DeprecationWarning, stacklevel=2) + raise TypeError("'kind' argument to sort is not supported. Use" + " stable=True or stable=False to specify sort stability.") if order is not None: - raise ValueError("'order' argument to sort is not supported.") + raise TypeError("'order' argument to sort is not supported.") if axis is None: arr = ravel(a) axis = 0 @@ -3996,8 +6279,7 @@ def sort_complex(a: ArrayLike) -> Array: @partial(jit, static_argnames=('axis',)) def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array: key_tuple = tuple(keys) - # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. - util.check_arraylike("lexsort", *key_tuple, emit_warning=True) + util.check_arraylike("lexsort", *key_tuple) key_arrays = tuple(asarray(k) for k in key_tuple) if len(key_arrays) == 0: raise TypeError("need sequence of keys with len > 0 in lexsort") @@ -4012,31 +6294,30 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A @util.implements(np.argsort, extra_params=""" -kind : deprecated; specify sort algorithm using stable=True or stable=False -order : not supported stable : bool, default=True Specify whether to use a stable sort. descending : bool, default=False Specify whether to do a descending sort. 
+kind : deprecated; specify sort algorithm using stable=True or stable=False +order : not supported """) @partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def argsort( a: ArrayLike, axis: int | None = -1, - kind: str | None = None, + *, + kind: None = None, order: None = None, - *, stable: bool = True, + stable: bool = True, descending: bool = False, ) -> Array: util.check_arraylike("argsort", a) arr = asarray(a) if kind is not None: - # Deprecated 2024-01-05 - warnings.warn("The 'kind' argument to argsort has no effect, and is deprecated. " - "Use stable=True or stable=False to specify sort stability.", - category=DeprecationWarning, stacklevel=2) + raise TypeError("'kind' argument to argsort is not supported. Use" + " stable=True or stable=False to specify sort stability.") if order is not None: - raise ValueError("'order' argument to argsort is not supported.") + raise TypeError("'order' argument to argsort is not supported.") if axis is None: arr = ravel(arr) axis = 0 @@ -4055,17 +6336,57 @@ def argsort( return lax.rev(indices, dimensions=[dimension]) if descending else indices -@util.implements(np.partition, lax_description=""" -The JAX version requires the ``kth`` argument to be a static integer rather than -a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If -you're only accessing the top or bottom k values of the output, it may be more -efficient to call :func:`jax.lax.top_k` directly. - -The JAX version differs from the NumPy version in the treatment of NaN entries; -NaNs which have the negative bit set are sorted to the beginning of the array. -""") @partial(jit, static_argnames=['kth', 'axis']) def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: + """Returns a partially-sorted copy of an array. + + JAX implementation of :func:`numpy.partition`. The JAX version differs from + NumPy in the treatment of NaN entries: NaNs which have the negative bit set + are sorted to the beginning of the array. + + Args: + a: array to be partitioned. + kth: static integer index about which to partition the array. + axis: static integer axis along which to partition the array; default is -1. + + Returns: + A copy of ``a`` partitioned at the ``kth`` value along ``axis``. The entries + before ``kth`` are values smaller than ``take(a, kth, axis)``, and entries + after ``kth`` are values larger than ``take(a, kth, axis)`` + + Note: + The JAX version requires the ``kth`` argument to be a static integer rather than + a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If + you're only accessing the top or bottom k values of the output, it may be more + efficient to call :func:`jax.lax.top_k` directly. + + See Also: + - :func:`jax.numpy.sort`: full sort + - :func:`jax.numpy.argpartition`: indirect partial sort + - :func:`jax.lax.top_k`: directly find the top k entries + - :func:`jax.lax.approx_max_k`: compute the approximate top k entries + - :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries + + Examples: + >>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3]) + >>> kth = 4 + >>> x_partitioned = jnp.partition(x, kth) + >>> x_partitioned + Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32) + + The result is a partially-sorted copy of the input.
All values before ``kth`` + are smaller than the pivot value, and all values after ``kth`` are larger + than the pivot value: + + >>> smallest_values = x_partitioned[:kth] + >>> pivot_value = x_partitioned[kth] + >>> largest_values = x_partitioned[kth + 1:] + >>> print(smallest_values, pivot_value, largest_values) + [1 2 3 3] 4 [9 8 7 6 5] + + Notice that among ``smallest_values`` and ``largest_values``, the returned + order is arbitrary and implementation-dependent. + """ # TODO(jakevdp): handle NaN values like numpy. util.check_arraylike("partition", a) arr = asarray(a) @@ -4081,17 +6402,58 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: return swapaxes(out, -1, axis) -@util.implements(np.argpartition, lax_description=""" -The JAX version requires the ``kth`` argument to be a static integer rather than -a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If -you're only accessing the top or bottom k values of the output, it may be more -efficient to call :func:`jax.lax.top_k` directly. - -The JAX version differs from the NumPy version in the treatment of NaN entries; -NaNs which have the negative bit set are sorted to the beginning of the array. -""") @partial(jit, static_argnames=['kth', 'axis']) def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: + """Returns indices that partially sort an array. + + JAX implementation of :func:`numpy.argpartition`. The JAX version differs from + NumPy in the treatment of NaN entries: NaNs which have the negative bit set are + sorted to the beginning of the array. + + Args: + a: array to be partitioned. + kth: static integer index about which to partition the array. + axis: static integer axis along which to partition the array; default is -1. + + Returns: + Indices which partition ``a`` at the ``kth`` value along ``axis``. The entries + before ``kth`` are indices of values smaller than ``take(a, kth, axis)``, and + entries after ``kth`` are indices of values larger than ``take(a, kth, axis)`` + + Note: + The JAX version requires the ``kth`` argument to be a static integer rather than + a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If + you're only accessing the top or bottom k values of the output, it may be more + efficient to call :func:`jax.lax.top_k` directly. + + See Also: + - :func:`jax.numpy.partition`: direct partial sort + - :func:`jax.numpy.argsort`: full indirect sort + - :func:`jax.lax.top_k`: directly find the top k entries + - :func:`jax.lax.approx_max_k`: compute the approximate top k entries + - :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries + + Examples: + >>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3]) + >>> kth = 4 + >>> idx = jnp.argpartition(x, kth) + >>> idx + Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32) + + The result is a sequence of indices that partially sort the input. All indices + before ``kth`` are of values smaller than the pivot value, and all indices + after ``kth`` are of values larger than the pivot value: + + >>> x_partitioned = x[idx] + >>> smallest_values = x_partitioned[:kth] + >>> pivot_value = x_partitioned[kth] + >>> largest_values = x_partitioned[kth + 1:] + >>> print(smallest_values, pivot_value, largest_values) + [1 2 3 3] 4 [6 8 9 7 5] + + Notice that among ``smallest_values`` and ``largest_values``, the returned + order is arbitrary and implementation-dependent. + """ # TODO(jakevdp): handle NaN values like numpy.
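As the notes above point out, when only the largest or smallest ``k`` entries are needed it can be cheaper to call :func:`jax.lax.top_k` directly instead of partitioning the full array. A minimal sketch using the same input as the examples above:

import jax
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])

values, indices = jax.lax.top_k(x, 3)   # the three largest entries and their locations
print(values)    # [9 8 7]
print(indices)   # [5 1 6]

# The smallest k entries can be obtained the same way by negating the input;
# tied values make the corresponding indices implementation-dependent.
neg_values, _ = jax.lax.top_k(-x, 3)
print(-neg_values)   # [1 2 3]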
util.check_arraylike("partition", a) arr = asarray(a) @@ -4157,9 +6519,58 @@ def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], return _roll_static(arr, shift, axis) -@util.implements(np.rollaxis, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('axis', 'start')) def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: + """Roll the specified axis to a given position. + + JAX implementation of :func:`numpy.rollaxis`. + + This function exists for compatibility with NumPy, but in most cases the newer + :func:`jax.numpy.moveaxis` instead, because the meaning of its arguments is + more intuitive. + + Args: + a: input array. + axis: index of the axis to roll forward. + start: index toward which the axis will be rolled (default = 0). After + normalizing negative axes, if ``start <= axis``, the axis is rolled to + the ``start`` index; if ``start > axis``, the axis is rolled until the + position before ``start``. + + Returns: + Copy of ``a`` with rolled axis. + + Notes: + Unlike :func:`numpy.rollaxis`, :func:`jax.numpy.rollaxis` will return a copy rather + than a view of the input array. However, under JIT, the compiler will optimize away + such copies when possible, so this doesn't have performance impacts in practice. + + See also: + - :func:`jax.numpy.moveaxis`: newer API with clearer semantics than ``rollaxis``; + this should be preferred to ``rollaxis`` in most cases. + - :func:`jax.numpy.swapaxes`: swap two axes. + - :func:`jax.numpy.transpose`: general permutation of axes. + + Examples: + >>> a = jnp.ones((2, 3, 4, 5)) + + Roll axis 2 to the start of the array: + + >>> jnp.rollaxis(a, 2).shape + (4, 2, 3, 5) + + Roll axis 1 to the end of the array: + + >>> jnp.rollaxis(a, 1, a.ndim).shape + (2, 4, 5, 3) + + Equivalent of these two with :func:`~jax.numpy.moveaxis` + + >>> jnp.moveaxis(a, 2, 0).shape + (4, 2, 3, 5) + >>> jnp.moveaxis(a, 1, -1).shape + (2, 4, 5, 3) + """ util.check_arraylike("rollaxis", a) start = core.concrete_or_error(operator.index, start, "'start' argument of jnp.rollaxis()") a_ndim = ndim(a) @@ -4235,27 +6646,6 @@ def unpackbits( return swapaxes(unpacked, axis, -1) -@util.implements(np.take, skip_params=['out'], - lax_description=""" -By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound -index semantics can be specified via the ``mode`` parameter (see below). -""", - extra_params=""" -mode : string, default="fill" - Out-of-bounds indexing mode. The default mode="fill" returns invalid values - (e.g. NaN) for out-of bounds indices (see also ``fill_value`` below). - For more discussion of mode options, see :attr:`jax.numpy.ndarray.at`. -fill_value : optional - The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored - otherwise. Defaults to NaN for inexact types, the largest negative value for - signed types, the largest positive value for unsigned types, and True for booleans. -unique_indices : bool, default=False - If True, the implementation will assume that the indices are unique, - which can result in more efficient execution on some backends. -indices_are_sorted : bool, default=False - If True, the implementation will assume that the indices are sorted in - ascending order, which can lead to more efficient execution on some backends. 
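The ``mode`` and ``fill_value`` parameters described above work together: under the default ``mode="fill"``, an explicit ``fill_value`` replaces the type-dependent sentinel (NaN for floats) returned for out-of-bounds indices. A minimal sketch of that behavior for ``take``; the ``fill_value`` keyword newly added to ``take_along_axis`` in this patch is expected to behave analogously:

import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])
indices = jnp.array([2, 0])   # index 2 is out of bounds along axis 0

print(jnp.take(x, indices, axis=0))                  # NaN row for the out-of-bounds index
print(jnp.take(x, indices, axis=0, fill_value=0.))   # [[0. 0. 0.]
                                                     #  [1. 2. 3.]]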
-""") def take( a: ArrayLike, indices: ArrayLike, @@ -4264,8 +6654,80 @@ def take( mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False, - fill_value: ArrayLike | None = None, + fill_value: StaticScalar | None = None, ) -> Array: + """Take elements from an array. + + JAX implementation of :func:`numpy.take`, implemented in terms of + :func:`jax.lax.gather`. JAX's behavior differs from NumPy in the case + of out-of-bound indices; see the ``mode`` parameter below. + + Args: + a: array from which to take values. + indices: N-dimensional array of integer indices of values to take from the array. + axis: the axis along which to take values. If not specified, the array will + be flattened before indexing is applied. + mode: Out-of-bounds indexing mode, either ``"fill"`` or ``"clip"``. The default + ``mode="fill"`` returns invalid values (e.g. NaN) for out-of bounds indices; + the ``fill_value`` argument gives control over this value. For more discussion + of ``mode`` options, see :attr:`jax.numpy.ndarray.at`. + fill_value: The fill value to return for out-of-bounds slices when mode is 'fill'. + Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for + signed types, the largest positive value for unsigned types, and True for booleans. + unique_indices: If True, the implementation will assume that the indices are unique, + which can result in more efficient execution on some backends. If set to True and + indices are not unique, the output is undefined. + indices_are_sorted : If True, the implementation will assume that the indices are + sorted in ascending order, which can lead to more efficient execution on some + backends. If set to True and indices are not sorted, the output is undefined. + + Returns: + Array of values extracted from ``a``. + + See also: + - :attr:`jax.numpy.ndarray.at`: take values via indexing syntax. + - :func:`jax.numpy.take_along_axis`: take values along an axis + + Examples: + >>> x = jnp.array([[1., 2., 3.], + ... [4., 5., 6.]]) + >>> indices = jnp.array([2, 0]) + + Passing no axis results in indexing into the flattened array: + + >>> jnp.take(x, indices) + Array([3., 1.], dtype=float32) + >>> x.ravel()[indices] # equivalent indexing syntax + Array([3., 1.], dtype=float32) + + Passing an axis results ind applying the index to every subarray along the axis: + + >>> jnp.take(x, indices, axis=1) + Array([[3., 1.], + [6., 4.]], dtype=float32) + >>> x[:, indices] # equivalent indexing syntax + Array([[3., 1.], + [6., 4.]], dtype=float32) + + Out-of-bound indices fill with invalid values. 
For float inputs, this is `NaN`: + + >>> jnp.take(x, indices, axis=0) + Array([[nan, nan, nan], + [ 1., 2., 3.]], dtype=float32) + >>> x.at[indices].get(mode='fill', fill_value=jnp.nan) # equivalent indexing syntax + Array([[nan, nan, nan], + [ 1., 2., 3.]], dtype=float32) + + This default out-of-bound behavior can be adjusted using the ``mode`` parameter, for + example, we can instead clip to the last valid value: + + >>> jnp.take(x, indices, axis=0, mode='clip') + Array([[4., 5., 6.], + [1., 2., 3.]], dtype=float32) + >>> x.at[indices].get(mode='clip') # equivalent indexing syntax + Array([[4., 5., 6.], + [1., 2., 3.]], dtype=float32) + """ return _take(a, indices, None if axis is None else operator.index(axis), out, mode, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, fill_value=fill_value) @@ -4341,24 +6803,85 @@ def _normalize_index(index, axis_size): return lax.select(index < 0, lax.add(index, axis_size_val), index) -TAKE_ALONG_AXIS_DOC = """ -Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes -an optional ``mode`` parameter controlling how out-of-bounds indices should be -handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``). -See :attr:`jax.numpy.ndarray.at` for further discussion of out-of-bounds -indexing in JAX. -""" - - -@util.implements(np.take_along_axis, update_doc=False, - lax_description=TAKE_ALONG_AXIS_DOC) -@partial(jit, static_argnames=('axis', 'mode')) +@partial(jit, static_argnames=('axis', 'mode', 'fill_value')) def take_along_axis( arr: ArrayLike, indices: ArrayLike, axis: int | None, mode: str | lax.GatherScatterMode | None = None, + fill_value: StaticScalar | None = None, ) -> Array: + """Take elements from an array. + + JAX implementation of :func:`numpy.take_along_axis`, implemented in + terms of :func:`jax.lax.gather`. JAX's behavior differs from NumPy + in the case of out-of-bound indices; see the ``mode`` parameter below. + + Args: + a: array from which to take values. + indices: array of integer indices. If ``axis`` is ``None``, must be one-dimensional. + If ``axis`` is not None, must have ``a.ndim == indices.ndim``, and ``a`` must be + broadcast-compatible with ``indices`` along dimensions other than ``axis``. + axis: the axis along which to take values. If not specified, the array will + be flattened before indexing is applied. + mode: Out-of-bounds indexing mode, either ``"fill"`` or ``"clip"``. The default + ``mode="fill"`` returns invalid values (e.g. NaN) for out-of bounds indices. + For more discussion of ``mode`` options, see :attr:`jax.numpy.ndarray.at`. + + Returns: + Array of values extracted from ``a``. + + See also: + - :attr:`jax.numpy.ndarray.at`: take values via indexing syntax. + - :func:`jax.numpy.take`: take the same indices along every axis slice. + + Examples: + >>> x = jnp.array([[1., 2., 3.], + ... [4., 5., 6.]]) + >>> indices = jnp.array([[0, 2], + ... [1, 0]]) + >>> jnp.take_along_axis(x, indices, axis=1) + Array([[1., 3.], + [5., 4.]], dtype=float32) + >>> x[jnp.arange(2)[:, None], indices] # equivalent via indexing syntax + Array([[1., 3.], + [5., 4.]], dtype=float32) + + Out-of-bound indices fill with invalid values. For float inputs, this is `NaN`: + + >>> indices = jnp.array([[1, 0, 2]]) + >>> jnp.take_along_axis(x, indices, axis=0) + Array([[ 4., 2., nan]], dtype=float32) + >>> x.at[indices, jnp.arange(3)].get( + ... 
mode='fill', fill_value=jnp.nan) # equivalent via indexing syntax + Array([[ 4., 2., nan]], dtype=float32) + + ``take_along_axis`` is helpful for extracting values from multi-dimensional + argsorts and arg reductions. For, here we compute :func:`~jax.numpy.argsort` + indices along an axis, and use ``take_along_axis`` to construct the sorted + array: + + >>> x = jnp.array([[5, 3, 4], + ... [2, 7, 6]]) + >>> indices = jnp.argsort(x, axis=1) + >>> indices + Array([[1, 2, 0], + [0, 2, 1]], dtype=int32) + >>> jnp.take_along_axis(x, indices, axis=1) + Array([[3, 4, 5], + [2, 6, 7]], dtype=int32) + + Similarly, we can use :func:`~jax.numpy.argmin` with ``keepdims=True`` and + use ``take_along_axis`` to extract the minimum value: + + >>> idx = jnp.argmin(x, axis=1, keepdims=True) + >>> idx + Array([[1], + [0]], dtype=int32) + >>> jnp.take_along_axis(x, idx, axis=1) + Array([[3], + [2]], dtype=int32) + """ util.check_arraylike("take_along_axis", arr, indices) a = asarray(arr) index_dtype = dtypes.dtype(indices) @@ -4370,8 +6893,9 @@ def take_along_axis( if ndim(indices) != 1: msg = "take_along_axis indices must be 1D if axis=None, got shape {}" raise ValueError(msg.format(idx_shape)) - return take_along_axis(a.ravel(), indices, 0) - rank = ndim(arr) + a = a.ravel() + axis = 0 + rank = a.ndim if rank != ndim(indices): msg = "indices and arr must have the same number of dimensions; {} vs. {}" raise ValueError(msg.format(ndim(indices), a.ndim)) @@ -4439,7 +6963,7 @@ def replace(tup, val): collapsed_slice_dims=tuple(collapsed_slice_dims), start_index_map=tuple(start_index_map)) return lax.gather(a, gather_indices_arr, dnums, tuple(slice_sizes), - mode="fill" if mode is None else mode) + mode="fill" if mode is None else mode, fill_value=fill_value) ### Indexing @@ -4488,6 +7012,10 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> for elt in (i.start, i.stop, i.step)): return None + if any(i is Ellipsis for i in idx): + # Remove ellipses and add trailing `slice(None)`. + idx = _canonicalize_tuple_index(arr.ndim, idx=idx) + simple_revs = {i for i, ind in enumerate(idx) if _is_simple_reverse_slice(ind)} int_indices = {i for i, (ind, size) in enumerate(zip(idx, arr.shape)) if _is_valid_integer_index_for_slice(ind, size, mode)} @@ -4861,8 +7389,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], "with type {} at position {}, indexer value {}") raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i)) - msg = "Indexing mode not yet supported. Open a feature request!\n{}" - raise IndexError(msg.format(idx)) + raise IndexError("Indexing mode not yet supported. Got unsupported indexer " + f"at position {idx_pos}: {i!r}") if len(gather_indices) == 0: gather_indices_array: ArrayLike = np.zeros((0,), dtype=index_dtype) @@ -5172,15 +7700,134 @@ def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: ufuncs.multiply(x1, ufuncs.floor_divide(x2, d))) -@util.implements(np.extract) -def extract(condition: ArrayLike, arr: ArrayLike) -> Array: - return compress(ravel(condition), ravel(arr)) +def extract(condition: ArrayLike, arr: ArrayLike, + *, size: int | None = None, fill_value: ArrayLike = 0) -> Array: + """Return the elements of an array that satisfy a condition. + + JAX implementation of :func:`numpy.extract`. + + Args: + condition: array of conditions. Will be converted to boolean and flattened to 1D. + arr: array of values to extract. Will be flattened to 1D. + size: optional static size for output. 
Must be specified in order for ``extract`` + to be compatible with JAX transformations like :func:`~jax.jit` or :func:`~jax.vmap`. + fill_value: if ``size`` is specified, fill padded entries with this value (default: 0). + + Returns: + 1D array of extracted entries. If ``size`` is specified, the result will have shape + ``(size,)`` and be right-padded with ``fill_value``. If ``size`` is not specified, + the output shape will depend on the number of True entries in ``condition``. + + Notes: + This function does not require strict shape agreement between ``condition`` and ``arr``. + If ``condition.size > arr.size``, then ``condition`` will be truncated, and if + ``arr.size > condition.size``, then ``arr`` will be truncated. + + See also: + :func:`jax.numpy.compress`: multi-dimensional version of ``extract``. + + Examples: + Extract values from a 1D array: + + >>> x = jnp.array([1, 2, 3, 4, 5, 6]) + >>> mask = (x % 2 == 0) + >>> jnp.extract(mask, x) + Array([2, 4, 6], dtype=int32) + + In the simplest case, this is equivalent to boolean indexing: + + >>> x[mask] + Array([2, 4, 6], dtype=int32) + + For use with JAX transformations, you can pass the ``size`` argument to + specify a static shape for the output, along with an optional ``fill_value`` + that defaults to zero: + + >>> jnp.extract(mask, x, size=len(x), fill_value=0) + Array([2, 4, 6, 0, 0, 0], dtype=int32) + + Notice that unlike with boolean indexing, ``extract`` does not require strict + agreement between the sizes of the array and condition, and will effectively + truncate both to the minimum size: + + >>> short_mask = jnp.array([False, True]) + >>> jnp.extract(short_mask, x) + Array([2], dtype=int32) + >>> long_mask = jnp.array([True, False, True, False, False, False, False, False]) + >>> jnp.extract(long_mask, x) + Array([1, 3], dtype=int32) + """ + util.check_arraylike("extract", condition, arr, fill_value) + return compress(ravel(condition), ravel(arr), size=size, fill_value=fill_value) -@util.implements(np.compress, skip_params=['out']) def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, - out: None = None) -> Array: - util.check_arraylike("compress", condition, a) + *, size: int | None = None, fill_value: ArrayLike = 0, out: None = None) -> Array: + """Compress an array along a given axis using a boolean condition. + + JAX implementation of :func:`numpy.compress`. + + Args: + condition: 1-dimensional array of conditions. Will be converted to boolean. + a: N-dimensional array of values. + axis: axis along which to compress. If None (default) then ``a`` will be + flattened, and axis will be set to 0. + size: optional static size for output. Must be specified in order for ``compress`` + to be compatible with JAX transformations like :func:`~jax.jit` or :func:`~jax.vmap`. + fill_value: if ``size`` is specified, fill padded entries with this value (default: 0). + out: not implemented by JAX. + + Returns: + An array of dimension ``a.ndim``, compressed along the specified axis. + + See also: + - :func:`jax.numpy.extract`: 1D version of ``compress``. + - :meth:`jax.Array.compress`: equivalent functionality as an array method. + + Notes: + This function does not require strict shape agreement between ``condition`` and ``a``. + If ``condition.size > a.shape[axis]``, then ``condition`` will be truncated, and if + ``a.shape[axis] > condition.size``, then ``a`` will be truncated. + + Examples: + Compressing along the rows of a 2D array: + + >>> a = jnp.array([[1, 2, 3, 4], + ... [5, 6, 7, 8],
[9, 10, 11, 12]]) + >>> condition = jnp.array([True, False, True]) + >>> jnp.compress(condition, a, axis=0) + Array([[ 1, 2, 3, 4], + [ 9, 10, 11, 12]], dtype=int32) + + For convenience, you can equivalently use the :meth:`~jax.Array.compress` + method of JAX arrays: + + >>> a.compress(condition, axis=0) + Array([[ 1, 2, 3, 4], + [ 9, 10, 11, 12]], dtype=int32) + + Note that the condition need not match the shape of the specified axis; + here we compress the columns with the length-3 condition. Values beyond + the size of the condition are ignored: + + >>> jnp.compress(condition, a, axis=1) + Array([[ 1, 3], + [ 5, 7], + [ 9, 11]], dtype=int32) + + The optional ``size`` argument lets you specify a static output size so + that the output is statically-shaped, and so this function can be used + with transformations like :func:`~jax.jit` and :func:`~jax.vmap`: + + >>> f = lambda c, a: jnp.extract(c, a, size=len(a), fill_value=0) + >>> mask = (a % 3 == 0) + >>> jax.vmap(f)(mask, a) + Array([[ 3, 0, 0, 0], + [ 6, 0, 0, 0], + [ 9, 12, 0, 0]], dtype=int32) + """ + util.check_arraylike("compress", condition, a, fill_value) condition_arr = asarray(condition).astype(bool) if out is not None: raise NotImplementedError("The 'out' argument to jnp.compress is not supported.") @@ -5192,10 +7839,20 @@ def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, else: arr = moveaxis(a, axis, 0) condition_arr, extra = condition_arr[:arr.shape[0]], condition_arr[arr.shape[0]:] - if reductions.any(extra): - raise ValueError("condition contains entries that are out of bounds") arr = arr[:condition_arr.shape[0]] - return moveaxis(arr[condition_arr], 0, axis) + + if size is None: + if reductions.any(extra): + raise ValueError("condition contains entries that are out of bounds") + result = arr[condition_arr] + elif not 0 <= size <= arr.shape[0]: + raise ValueError("size must be positive and not greater than the size of the array axis;" + f" got {size=} for a.shape[axis]={arr.shape[0]}") + else: + mask = expand_dims(condition_arr, range(1, arr.ndim)) + arr = where(mask, arr, array(fill_value, dtype=arr.dtype)) + result = arr[argsort(condition_arr, stable=True, descending=True)][:size] + return moveaxis(result, 0, axis) @util.implements(np.cov) @@ -5321,19 +7978,71 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt return comparisons.sum(dtype=dtype, axis=0) -@util.implements(np.searchsorted, skip_params=['sorter'], - extra_params=_dedent(""" - method : str - One of 'scan' (default), 'scan_unrolled', 'sort' or 'compare_all'. Controls the method used by the - implementation: 'scan' tends to be more performant on CPU (particularly when ``a`` is - very large), 'scan_unrolled' is more performant on GPU at the expense of additional compile time, - 'sort' is often more performant on accelerator backends like GPU and TPU - (particularly when ``v`` is very large), and 'compare_all' can be most performant - when ``a`` is very small.""")) -@partial(jit, static_argnames=('side', 'sorter', 'method')) +@partial(jit, static_argnames=('side', 'method')) def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', - sorter: None = None, *, method: str = 'scan') -> Array: - util.check_arraylike("searchsorted", a, v) + sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: + """Perform a binary search within a sorted array. + + JAX implementation of :func:`numpy.searchsorted`. 
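The ``size`` argument added to ``extract`` and ``compress`` above also makes them usable under ``jit``, not only ``vmap``. A minimal sketch, relying on the padding semantics documented there:

import jax
import jax.numpy as jnp

x = jnp.arange(1, 7)   # [1 2 3 4 5 6]

@jax.jit
def keep_even(x):
  # A static ``size`` fixes the output shape; unused slots take ``fill_value``.
  return jnp.compress(x % 2 == 0, x, size=3, fill_value=-1)

print(keep_even(x))    # [2 4 6]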
+ + This will return the indices within a sorted array ``a`` where values in ``v`` + can be inserted to maintain its sort order. + + Args: + a: one-dimensional array, assumed to be in sorted order unless ``sorter`` is specified. + v: N-dimensional array of query values + side: ``'left'`` (default) or ``'right'``; specifies whether insertion indices will be + to the left or the right in case of ties. + sorter: optional array of indices specifying the sort order of ``a``. If specified, + then the algorithm assumes that ``a[sorter]`` is in sorted order. + method: one of ``'scan'`` (default), ``'scan_unrolled'``, ``'sort'`` or ``'compare_all'``. + See *Note* below. + + Returns: + Array of insertion indices of shape ``v.shape``. + + Note: + The ``method`` argument controls the algorithm used to compute the insertion indices. + + - ``'scan'`` (the default) tends to be more performant on CPU, particularly when ``a`` is + very large. + - ``'scan_unrolled'`` is more performant on GPU at the expense of additional compile time. + - ``'sort'`` is often more performant on accelerator backends like GPU and TPU, particularly + when ``v`` is very large. + - ``'compare_all'`` tends to be the most performant when ``a`` is very small. + + Examples: + Searching for a single value: + + >>> a = jnp.array([1, 2, 2, 3, 4, 5, 5]) + >>> jnp.searchsorted(a, 2) + Array(1, dtype=int32) + >>> jnp.searchsorted(a, 2, side='right') + Array(3, dtype=int32) + + Searching for a batch of values: + + >>> vals = jnp.array([0, 3, 8, 1.5, 2]) + >>> jnp.searchsorted(a, vals) + Array([0, 3, 7, 1, 1], dtype=int32) + + Optionally, the ``sorter`` argument can be used to find insertion indices into + an array sorted via :func:`jax.numpy.argsort`: + + >>> a = jnp.array([4, 3, 5, 1, 2]) + >>> sorter = jnp.argsort(a) + >>> jnp.searchsorted(a, vals, sorter=sorter) + Array([0, 2, 5, 1, 1], dtype=int32) + + The result is equivalent to passing the sorted array: + + >>> jnp.searchsorted(jnp.sort(a), vals) + Array([0, 2, 5, 1, 1], dtype=int32) + """ + if sorter is None: + util.check_arraylike("searchsorted", a, v) + else: + util.check_arraylike("searchsorted", a, v, sorter) if side not in ['left', 'right']: raise ValueError(f"{side!r} is an invalid value for keyword 'side'. " "Expected one of ['left', 'right'].") @@ -5341,11 +8050,11 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', raise ValueError( f"{method!r} is an invalid value for keyword 'method'. 
" "Expected one of ['sort', 'scan', 'scan_unrolled', 'compare_all'].") - if sorter is not None: - raise NotImplementedError("sorter is not implemented") if ndim(a) != 1: raise ValueError("a should be 1-dimensional") a, v = util.promote_dtypes(a, v) + if sorter is not None: + a = a[sorter] dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64 if len(a) == 0: return zeros_like(v, dtype=dtype) @@ -5366,7 +8075,7 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array: if bins_arr.ndim != 1: raise ValueError(f"digitize: bins must be a 1-dimensional array; got {bins=}") if bins_arr.shape[0] == 0: - return zeros(x, dtype=dtypes.canonicalize_dtype(int_)) + return zeros_like(x, dtype=int32) side = 'right' if not right else 'left' return where( bins_arr[-1] >= bins_arr[0], diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 2ce86faa1d7a..7d411208525e 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -16,10 +16,11 @@ from collections.abc import Sequence from functools import partial +import itertools +import math import warnings import numpy as np -import textwrap import operator from typing import Literal, NamedTuple, cast, overload @@ -28,12 +29,13 @@ from jax import lax from jax._src.lax import lax as lax_internal +from jax._src.lax.lax import PrecisionLike from jax._src.lax import linalg as lax_linalg from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import reductions, ufuncs -from jax._src.numpy.util import implements, promote_dtypes_inexact, check_arraylike +from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike from jax._src.util import canonicalize_axis -from jax._src.typing import ArrayLike, Array +from jax._src.typing import ArrayLike, Array, DTypeLike, DeprecatedArg class EighResult(NamedTuple): @@ -64,9 +66,67 @@ def _H(x: ArrayLike) -> Array: def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 -@implements(np.linalg.cholesky) @partial(jit, static_argnames=['upper']) def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: + """Compute the Cholesky decomposition of a matrix. + + JAX implementation of :func:`numpy.linalg.cholesky`. + + The Cholesky decomposition of a matrix `A` is: + + .. math:: + + A = U^HU + + or + + .. math:: + + A = LL^H + + where `U` is an upper-triangular matrix and `L` is a lower-triangular matrix, and + :math:`X^H` is the Hermitian transpose of `X`. + + Args: + a: input array, representing a (batched) positive-definite hermitian matrix. + Must have shape ``(..., N, N)``. + upper: if True, compute the upper Cholesky decomposition `L`. if False + (default), compute the lower Cholesky decomposition `U`. + + Returns: + array of shape ``(..., N, N)`` representing the Cholesky decomposition + of the input. If the input is not Hermitian positive-definite, The result + will contain NaN entries. + + + See also: + - :func:`jax.scipy.linalg.cholesky`: SciPy-style Cholesky API + - :func:`jax.lax.linalg.cholesky`: XLA-style Cholesky API + + Examples: + A small real Hermitian positive-definite matrix: + + >>> x = jnp.array([[2., 1.], + ... [1., 2.]]) + + Lower Cholesky factorization: + + >>> jnp.linalg.cholesky(x) + Array([[1.4142135 , 0. ], + [0.70710677, 1.2247449 ]], dtype=float32) + + Upper Cholesky factorization: + + >>> jnp.linalg.cholesky(x, upper=True) + Array([[1.4142135 , 0.70710677], + [0. 
, 1.2247449 ]], dtype=float32) + + Reconstructing ``x`` from its factorization: + + >>> L = jnp.linalg.cholesky(x) + >>> jnp.allclose(x, L @ L.T) + Array(True, dtype=bool) + """ check_arraylike("jnp.linalg.cholesky", a) a, = promote_dtypes_inexact(jnp.asarray(a)) L = lax_linalg.cholesky(a) @@ -130,7 +190,6 @@ def svd( ... -@implements(np.linalg.svd) @partial( jit, static_argnames=( @@ -147,6 +206,75 @@ def svd( hermitian: bool = False, subset_by_index: tuple[int, int] | None = None, ) -> Array | SVDResult: + r"""Compute the singular value decomposition. + + JAX implementation of :func:`numpy.linalg.svd`, implemented in terms of + :func:`jax.lax.linalg.svd`. + + The SVD of a matrix `A` is given by + + .. math:: + + A = U\Sigma V^H + + - :math:`U` contains the left singular vectors and satisfies :math:`U^HU=I` + - :math:`V` contains the right singular vectors and satisfies :math:`V^HV=I` + - :math:`\Sigma` is a diagonal matrix of singular values. + + Args: + a: input array, of shape ``(..., N, M)`` + full_matrices: if True (default) compute the full matrices; i.e. ``u`` and ``vh`` have + shape ``(..., N, N)`` and ``(..., M, M)``. If False, then the shapes are + ``(..., N, K)`` and ``(..., K, M)`` with ``K = min(N, M)``. + compute_uv: if True (default), return the full SVD ``(u, s, vh)``. If False then return + only the singular values ``s``. + hermitian: if True, assume the matrix is hermitian, which allows for a more efficient + implementation (default=False) + subset_by_index: (TPU-only) Optional 2-tuple [start, end] indicating the range of + indices of singular values to compute. For example, if ``[n-2, n]`` then + ``svd`` computes the two largest singular values and their singular vectors. + Only compatible with ``full_matrices=False``. + + Returns: + A tuple of arrays ``(u, s, vh)`` if ``compute_uv`` is True, otherwise the array ``s``. + + - ``u``: left singular vectors of shape ``(..., N, N)`` if ``full_matrices`` is True + or ``(..., N, K)`` otherwise. + - ``s``: singular values of shape ``(..., K)`` + - ``vh``: conjugate-transposed right singular vectors of shape ``(..., M, M)`` + if ``full_matrices`` is True or ``(..., K, M)`` otherwise. + + where ``K = min(N, M)``. + + See also: + - :func:`jax.scipy.linalg.svd`: SciPy-style SVD API + - :func:`jax.lax.linalg.svd`: XLA-style SVD API + + Examples: + Consider the SVD of a small real-valued array: + + >>> x = jnp.array([[1., 2., 3.], + ... [6., 5., 4.]]) + >>> u, s, vt = jnp.linalg.svd(x, full_matrices=False) + >>> s # doctest: +SKIP + Array([9.361919 , 1.8315067], dtype=float32) + + The singular vectors are in the columns of ``u`` and ``v = vt.T``. These vectors are + orthonormal, which can be demonstrated by comparing the matrix product with the + identity matrix: + + >>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5) + Array(True, dtype=bool) + >>> v = vt.T + >>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5) + Array(True, dtype=bool) + + Given the SVD, ``x`` can be reconstructed via matrix multiplication: + + >>> x_reconstructed = u @ jnp.diag(s) @ vt + >>> jnp.allclose(x_reconstructed, x) + Array(True, dtype=bool) + """ check_arraylike("jnp.linalg.svd", a) a, = promote_dtypes_inexact(jnp.asarray(a)) if hermitian: @@ -182,9 +310,51 @@ def svd( ) -@implements(np.linalg.matrix_power) @partial(jit, static_argnames=('n',)) def matrix_power(a: ArrayLike, n: int) -> Array: + """Raise a square matrix to an integer power. + + JAX implementation of :func:`numpy.linalg.matrix_power`, implemented via + repeated squarings. 
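The ``full_matrices`` and ``compute_uv`` options documented for ``svd`` above affect only which arrays are returned and their shapes. A minimal sketch of those shapes for a 2x3 input:

import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [6., 5., 4.]])          # (N, M) = (2, 3), so K = min(N, M) = 2

u, s, vh = jnp.linalg.svd(x, full_matrices=True)
print(u.shape, s.shape, vh.shape)      # (2, 2) (2,) (3, 3)

u, s, vh = jnp.linalg.svd(x, full_matrices=False)
print(u.shape, s.shape, vh.shape)      # (2, 2) (2,) (2, 3)

s = jnp.linalg.svd(x, compute_uv=False)
print(s.shape)                         # (2,)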
+ + Args: + a: array of shape ``(..., M, M)`` to be raised to the power `n`. + n: the integer exponent to which the matrix should be raised. + + Returns: + Array of shape ``(..., M, M)`` containing ``a`` raised to the power ``n``. + + Examples: + >>> a = jnp.array([[1., 2.], + ... [3., 4.]]) + >>> jnp.linalg.matrix_power(a, 3) + Array([[ 37., 54.], + [ 81., 118.]], dtype=float32) + >>> a @ a @ a # equivalent evaluated directly + Array([[ 37., 54.], + [ 81., 118.]], dtype=float32) + + This also supports zero powers: + + >>> jnp.linalg.matrix_power(a, 0) + Array([[1., 0.], + [0., 1.]], dtype=float32) + + and also supports negative powers: + + >>> with jnp.printoptions(precision=3): + ... jnp.linalg.matrix_power(a, -2) + Array([[ 5.5 , -2.5 ], + [-3.75, 1.75]], dtype=float32) + + Negative powers are equivalent to matmul of the inverse: + + >>> inv_a = jnp.linalg.inv(a) + >>> with jnp.printoptions(precision=3): + ... inv_a @ inv_a + Array([[ 5.5 , -2.5 ], + [-3.75, 1.75]], dtype=float32) + """ check_arraylike("jnp.linalg.matrix_power", a) arr, = promote_dtypes_inexact(jnp.asarray(a)) @@ -221,18 +391,62 @@ def matrix_power(a: ArrayLike, n: int) -> Array: return result -@implements(np.linalg.matrix_rank) @jit -def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array: +def matrix_rank( + M: ArrayLike, rtol: ArrayLike | None = None, *, + tol: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array: + """Compute the rank of a matrix. + + JAX implementation of :func:`numpy.linalg.matrix_rank`. + + The rank is calculated via the Singular Value Decomposition (SVD), and determined + by the number of singular values greater than the specified tolerance. + + Args: + M: array of shape ``(..., N, K)`` whose rank is to be computed. + rtol: optional array of shape ``(...)`` specifying the tolerance. Singular values + smaller than ``rtol * largest_singular_value`` are considered to be zero. If + ``rtol`` is None (the default), a reasonable default is chosen based on the + floating point precision of the input. + + Returns: + array of shape ``M.shape[:-2]`` giving the matrix rank. + + Notes: + The rank calculation may be inaccurate for matrices with very small singular + values or those that are numerically ill-conditioned. Consider adjusting the + ``rtol`` parameter or using a more specialized rank computation method in such cases. + + Examples: + >>> a = jnp.array([[1, 2], + ... [3, 4]]) + >>> jnp.linalg.matrix_rank(a) + Array(2, dtype=int32) + + >>> b = jnp.array([[1, 0], # Rank-deficient matrix + ... [0, 0]]) + >>> jnp.linalg.matrix_rank(b) + Array(1, dtype=int32) + """ check_arraylike("jnp.linalg.matrix_rank", M) + # TODO(micky774): deprecated 2024-5-14, remove after deprecation expires. + if not isinstance(tol, DeprecatedArg): + rtol = tol + del tol + warnings.warn( + "The tol argument for linalg.matrix_rank is deprecated; using it will soon raise " + "an error.
To prepare for future releases, and suppress this warning, " + "please use rtol instead.", + DeprecationWarning, stacklevel=2 + ) M, = promote_dtypes_inexact(jnp.asarray(M)) if M.ndim < 2: return (M != 0).any().astype(jnp.int32) S = svd(M, full_matrices=False, compute_uv=False) - if tol is None: - tol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps - tol = jnp.expand_dims(tol, np.ndim(tol)) - return reductions.sum(S > tol, axis=-1) + if rtol is None: + rtol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps + rtol = jnp.expand_dims(rtol, np.ndim(rtol)) + return reductions.sum(S > rtol, axis=-1) @custom_jvp @@ -278,23 +492,44 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]: sign_taus = reductions.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype) return sign_diag * sign_taus, log_abs_det -@implements( - np.linalg.slogdet, - extra_params=textwrap.dedent(""" - method: string, optional - One of ``lu`` or ``qr``, specifying whether the determinant should be - computed using an LU decomposition or a QR decomposition. Defaults to - LU decomposition if ``None``. - """)) + @partial(jit, static_argnames=('method',)) def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: + """ + Compute the sign and (natural) logarithm of the determinant of an array. + + JAX implementation of :func:`numpy.linalg.slogdet`. + + Args: + a: array of shape ``(..., M, M)`` for which to compute the sign and log determinant. + method: the method to use for determinant computation. Options are + + - ``'lu'`` (default): use the LU decomposition. + - ``'qr'``: use the QR decomposition. + + Returns: + A tuple of arrays ``(sign, logabsdet)``, each of shape ``a.shape[:-2]`` + + - ``sign`` is the sign of the determinant. + - ``logabsdet`` is the natural log of the determinant's absolute value. + + See also: + :func:`jax.numpy.linalg.det`: direct computation of determinant + + Examples: + >>> a = jnp.array([[1, 2], + ... [3, 4]]) + >>> sign, logabsdet = jnp.linalg.slogdet(a) + >>> sign # -1 indicates negative determinant + Array(-1., dtype=float32) + >>> jnp.exp(logabsdet) # Absolute value of determinant + Array(2., dtype=float32) + """ check_arraylike("jnp.linalg.slogdet", a) a, = promote_dtypes_inexact(jnp.asarray(a)) a_shape = jnp.shape(a) if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]: - msg = "Argument to slogdet() must have shape [..., n, n], got {}" - raise ValueError(msg.format(a_shape)) - + raise ValueError(f"Argument to slogdet() must have shape [..., n, n], got {a_shape}") if method is None or method == "lu": return SlogdetResult(*_slogdet_lu(a)) elif method == "qr": @@ -424,9 +659,28 @@ def _det_3x3(a: Array) -> Array: @custom_jvp -@implements(np.linalg.det) @jit def det(a: ArrayLike) -> Array: + """ + Compute the determinant of an array. + + JAX implementation of :func:`numpy.linalg.det`. + + Args: + a: array of shape ``(..., M, M)`` for which to compute the determinant. + + Returns: + An array of determinants of shape ``a.shape[:-2]``. + + See also: + :func:`jax.scipy.linalg.det`: Scipy-style API for determinant. + + Examples: + >>> a = jnp.array([[1, 2], + ...
[3, 4]]) + >>> jnp.linalg.det(a) + Array(-2., dtype=float32) + """ check_arraylike("jnp.linalg.det", a) a, = promote_dtypes_inexact(jnp.asarray(a)) a_shape = jnp.shape(a) @@ -450,34 +704,127 @@ def _det_jvp(primals, tangents): return y, jnp.trace(z, axis1=-1, axis2=-2) -@implements(np.linalg.eig, lax_description=""" -This differs from :func:`numpy.linalg.eig` in that the return type of -:func:`jax.numpy.linalg.eig` is always ``complex64`` for 32-bit input, -and ``complex128`` for 64-bit input. - -At present, non-symmetric eigendecomposition is only implemented on the CPU -backend. However eigendecomposition for symmetric/Hermitian matrices is -implemented more widely (see :func:`jax.numpy.linalg.eigh`). -""") def eig(a: ArrayLike) -> tuple[Array, Array]: + """ + Compute the eigenvalues and eigenvectors of a square array. + + JAX implementation of :func:`numpy.linalg.eig`. + + Args: + a: array of shape ``(..., M, M)`` for which to compute the eigenvalues and vectors. + + Returns: + A tuple ``(eigenvalues, eigenvectors)`` with + + - ``eigenvalues``: an array of shape ``(..., M)`` containing the eigenvalues. + - ``eigenvectors``: an array of shape ``(..., M, M)``, where column ``v[:, i]`` is the + eigenvector corresponding to the eigenvalue ``w[i]``. + + Notes: + - This differs from :func:`numpy.linalg.eig` in that the return type of + :func:`jax.numpy.linalg.eig` is always complex64 for 32-bit input, and complex128 + for 64-bit input. + - At present, non-symmetric eigendecomposition is only implemented on the CPU backend. + + See also: + - :func:`jax.numpy.linalg.eigh`: eigenvectors and eigenvalues of a Hermitian matrix. + - :func:`jax.numpy.linalg.eigvals`: compute eigenvalues only. + + Examples: + >>> a = jnp.array([[1., 2.], + ... [2., 1.]]) + >>> w, v = jnp.linalg.eig(a) + >>> with jax.numpy.printoptions(precision=4): + ... w + Array([ 3.+0.j, -1.+0.j], dtype=complex64) + >>> v + Array([[ 0.70710677+0.j, -0.70710677+0.j], + [ 0.70710677+0.j, 0.70710677+0.j]], dtype=complex64) + """ check_arraylike("jnp.linalg.eig", a) a, = promote_dtypes_inexact(jnp.asarray(a)) w, v = lax_linalg.eig(a, compute_left_eigenvectors=False) return w, v -@implements(np.linalg.eigvals) @jit def eigvals(a: ArrayLike) -> Array: + """ + Compute the eigenvalues of a general matrix. + + JAX implementation of :func:`numpy.linalg.eigvals`. + + Args: + a: array of shape ``(..., M, M)`` for which to compute the eigenvalues. + + Returns: + An array of shape ``(..., M)`` containing the eigenvalues. + + See also: + - :func:`jax.numpy.linalg.eig`: computes eigenvalues eigenvectors of a general matrix. + - :func:`jax.numpy.linalg.eigh`: computes eigenvalues eigenvectors of a Hermitian matrix. + + Notes: + - This differs from :func:`numpy.linalg.eigvals` in that the return type of + :func:`jax.numpy.linalg.eigvals` is always complex64 for 32-bit input, and + complex128 for 64-bit input. + - At present, non-symmetric eigendecomposition is only implemented on the CPU backend. + + Examples: + >>> a = jnp.array([[1., 2.], + ... [2., 1.]]) + >>> w = jnp.linalg.eigvals(a) + >>> with jnp.printoptions(precision=2): + ... 
w + Array([ 3.+0.j, -1.+0.j], dtype=complex64) + """ check_arraylike("jnp.linalg.eigvals", a) + a, = promote_dtypes_inexact(jnp.asarray(a)) return lax_linalg.eig(a, compute_left_eigenvectors=False, compute_right_eigenvectors=False)[0] -@implements(np.linalg.eigh) @partial(jit, static_argnames=('UPLO', 'symmetrize_input')) def eigh(a: ArrayLike, UPLO: str | None = None, symmetrize_input: bool = True) -> EighResult: + """ + Compute the eigenvalues and eigenvectors of a Hermitian matrix. + + JAX implementation of :func:`numpy.linalg.eigh`. + + Args: + a: array of shape ``(..., M, M)``, containing the Hermitian (if complex) + or symmetric (if real) matrix. + UPLO: specifies whether the calculation is done with the lower triangular + part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). + symmetrize_input: if True (default) then input is symmetrized, which leads + to better behavior under automatic differentiation. + + Returns: + A namedtuple ``(eigenvalues, eigenvectors)`` where + + - ``eigenvalues``: an array of shape ``(..., M)`` containing the eigenvalues, + sorted in ascending order. + - ``eigenvectors``: an array of shape ``(..., M, M)``, where column ``v[:, i]`` is the + normalized eigenvector corresponding to the eigenvalue ``w[i]``. + + See also: + - :func:`jax.numpy.linalg.eig`: general eigenvalue decomposition. + - :func:`jax.numpy.linalg.eigvalsh`: compute eigenvalues only. + - :func:`jax.scipy.linalg.eigh`: SciPy API for Hermitian eigendecomposition. + - :func:`jax.lax.linalg.eigh`: XLA API for Hermitian eigendecomposition. + + Examples: + >>> a = jnp.array([[1, -2j], + ... [2j, 1]]) + >>> w, v = jnp.linalg.eigh(a) + >>> w + Array([-1., 3.], dtype=float32) + >>> with jnp.printoptions(precision=3): + ... v + Array([[-0.707+0.j , -0.707+0.j ], + [ 0. +0.707j, 0. -0.707j]], dtype=complex64) + """ check_arraylike("jnp.linalg.eigh", a) if UPLO is None or UPLO == "L": lower = True @@ -492,56 +839,134 @@ def eigh(a: ArrayLike, UPLO: str | None = None, return EighResult(w, v) -@implements(np.linalg.eigvalsh) @partial(jit, static_argnames=('UPLO',)) def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: + """ + Compute the eigenvalues of a Hermitian matrix. + + JAX implementation of :func:`numpy.linalg.eigvalsh`. + + Args: + a: array of shape ``(..., M, M)``, containing the Hermitian (if complex) + or symmetric (if real) matrix. + UPLO: specifies whether the calculation is done with the lower triangular + part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). + + Returns: + An array of shape ``(..., M)`` containing the eigenvalues, sorted in + ascending order. + + See also: + - :func:`jax.numpy.linalg.eig`: general eigenvalue decomposition. + - :func:`jax.numpy.linalg.eigh`: computes eigenvalues and eigenvectors of a + Hermitian matrix. + + Examples: + >>> a = jnp.array([[1, -2j], + ... [2j, 1]]) + >>> w = jnp.linalg.eigvalsh(a) + >>> w + Array([-1., 3.], dtype=float32) + """ check_arraylike("jnp.linalg.eigvalsh", a) + a, = promote_dtypes_inexact(jnp.asarray(a)) w, _ = eigh(a, UPLO) return w +# TODO(micky774): deprecated 2024-5-14, remove wrapper after deprecation expires. +def pinv(a: ArrayLike, rtol: ArrayLike | None = None, + hermitian: bool = False, *, + rcond: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array: + """Compute the (Moore-Penrose) pseudo-inverse of a matrix. + + JAX implementation of :func:`numpy.linalg.pinv`. + + Args: + a: array of shape ``(..., M, N)`` containing matrices to pseudo-invert. 
+ rtol: float or array_like of shape ``a.shape[:-2]``. Cutoff for small singular + values; singular values smaller than ``rtol * largest_singular_value`` are + treated as zero. The default is determined based on the floating point + precision of the dtype. + hermitian: if True, then the input is assumed to be Hermitian, and a more + efficient algorithm is used (default: False) + + Returns: + An array of shape ``(..., N, M)`` containing the pseudo-inverse of ``a``. + + See also: + - :func:`jax.numpy.linalg.inv`: multiplicative inverse of a square matrix. + + Notes: + :func:`jax.numpy.linalg.pinv` differs from :func:`numpy.linalg.pinv` in the + default value of ``rcond``: in NumPy, the default is ``1e-15``. In JAX, the + default is ``10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps``. + + Examples: + >>> a = jnp.array([[1, 2], + ... [3, 4], + ... [5, 6]]) + >>> a_pinv = jnp.linalg.pinv(a) + >>> a_pinv # doctest: +SKIP + Array([[-1.333332 , -0.33333257, 0.6666657 ], + [ 1.0833322 , 0.33333272, -0.41666582]], dtype=float32) + + The pseudo-inverse operates as a multiplicative inverse so long as the + output is not rank-deficient: + + >>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4) + Array(True, dtype=bool) + """ + if not isinstance(rcond, DeprecatedArg): + rtol = rcond + del rcond + warnings.warn( + "The rcond argument for linalg.pinv is deprecated; using it will soon " + "raise an error. To prepare for future releases, and suppress this " + "warning, please use rtol instead.", + DeprecationWarning, stacklevel=2 + ) + + return _pinv(a, rtol, hermitian) + + @partial(custom_jvp, nondiff_argnums=(1, 2)) -@implements(np.linalg.pinv, lax_description=textwrap.dedent("""\ - It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the - default `rcond` is `1e-15`. Here the default is - `10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`. - """)) -@partial(jit, static_argnames=('hermitian',)) -def pinv(a: ArrayLike, rcond: ArrayLike | None = None, - hermitian: bool = False) -> Array: +@partial(jit, static_argnames=('hermitian')) +def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False) -> Array: # Uses same algorithm as # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979 check_arraylike("jnp.linalg.pinv", a) - arr = jnp.asarray(a) + arr, = promote_dtypes_inexact(jnp.asarray(a)) m, n = arr.shape[-2:] if m == 0 or n == 0: return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype) arr = ufuncs.conj(arr) - if rcond is None: + if rtol is None: max_rows_cols = max(arr.shape[-2:]) - rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps) - rcond = jnp.asarray(rcond) + rtol = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps) + rtol = jnp.asarray(rtol) u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian) - # Singular values less than or equal to ``rcond * largest_singular_value`` + # Singular values less than or equal to ``rtol * largest_singular_value`` # are set to zero.
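A common use of the pseudo-inverse documented above is solving least-squares problems. A minimal sketch (with the default ``rtol``), reusing the matrix from the docstring example; here ``b`` happens to lie in the column space of ``A``, so the residual is zero:

import jax.numpy as jnp

A = jnp.array([[1., 2.],
               [3., 4.],
               [5., 6.]])
b = jnp.array([1., 2., 3.])

x = jnp.linalg.pinv(A) @ b    # least-squares solution of A @ x = b
print(jnp.allclose(x, jnp.array([0., 0.5]), atol=1e-5))   # True
print(jnp.allclose(A @ x, b, atol=1e-5))                  # True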
- rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1)) - cutoff = rcond * s[..., 0:1] + rtol = lax.expand_dims(rtol[..., jnp.newaxis], range(s.ndim - rtol.ndim - 1)) + cutoff = rtol * s[..., 0:1] s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype) res = jnp.matmul(vh.mT, ufuncs.divide(u.mT, s[..., jnp.newaxis]), precision=lax.Precision.HIGHEST) return lax.convert_element_type(res, arr.dtype) -@pinv.defjvp +@_pinv.defjvp @jax.default_matmul_precision("float32") -def _pinv_jvp(rcond, hermitian, primals, tangents): +def _pinv_jvp(rtol, hermitian, primals, tangents): # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432. # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative) a, = primals # m x n a_dot, = tangents - p = pinv(a, rcond=rcond, hermitian=hermitian) # n x m + p = pinv(a, rtol=rtol, hermitian=hermitian) # n x m if hermitian: # svd(..., hermitian=True) symmetrizes its input, and the JVP must match. a = _symmetrize(a) @@ -561,9 +986,57 @@ def _pinv_jvp(rcond, hermitian, primals, tangents): return p, p_dot -@implements(np.linalg.inv) @jit def inv(a: ArrayLike) -> Array: + """Return the inverse of a square matrix + + JAX implementation of :func:`numpy.linalg.inv`. + + Args: + a: array of shape ``(..., N, N)`` specifying square array(s) to be inverted. + + Returns: + Array of shape ``(..., N, N)`` containing the inverse of the input. + + Notes: + In most cases, explicitly computing the inverse of a matrix is ill-advised. For + example, to compute ``x = inv(A) @ b``, it is more performant and numerically + precise to use a direct solve, such as :func:`jax.scipy.linalg.solve`. + + See Also: + - :func:`jax.scipy.linalg.inv`: SciPy-style API for matrix inverse + - :func:`jax.numpy.linalg.solve`: direct linear solver + + Examples: + Compute the inverse of a 3x3 matrix + + >>> a = jnp.array([[1., 2., 3.], + ... [2., 4., 2.], + ... [3., 2., 1.]]) + >>> a_inv = jnp.linalg.inv(a) + >>> a_inv # doctest: +SKIP + Array([[ 0. , -0.25 , 0.5 ], + [-0.25 , 0.5 , -0.25000003], + [ 0.5 , -0.25 , 0. ]], dtype=float32) + + Check that multiplying with the inverse gives the identity: + + >>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5) + Array(True, dtype=bool) + + Multiply the inverse by a vector ``b``, to find a solution to ``a @ x = b`` + + >>> b = jnp.array([1., 4., 2.]) + >>> a_inv @ b + Array([ 0. , 1.25, -0.5 ], dtype=float32) + + Note, however, that explicitly computing the inverse in such a case can lead + to poor performance and loss of precision as the size of the problem grows. + Instead, you should use a direct solver like :func:`jax.numpy.linalg.solve`: + + >>> jnp.linalg.solve(a, b) + Array([ 0. , 1.25, -0.5 ], dtype=float32) + """ check_arraylike("jnp.linalg.inv", a) arr = jnp.asarray(a) if arr.ndim < 2 or arr.shape[-1] != arr.shape[-2]: @@ -573,11 +1046,74 @@ def inv(a: ArrayLike) -> Array: arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2])) -@implements(np.linalg.norm) @partial(jit, static_argnames=('ord', 'axis', 'keepdims')) def norm(x: ArrayLike, ord: int | str | None = None, axis: None | tuple[int, ...] | int = None, keepdims: bool = False) -> Array: + """Compute the norm of a matrix or vector. + + JAX implementation of :func:`numpy.linalg.norm`. + + Args: + x: N-dimensional array for which the norm will be computed. 
+ ord: specify the kind of norm to take. Default is Frobenius norm for matrices, + and the 2-norm for vectors. For other options, see Notes below. + axis: integer or sequence of integers specifying the axes over which the norm + will be computed. Defaults to all axes of ``x``. + keepdims: if True, the output array will have the same number of dimensions as + the input, with the size of reduced axes replaced by ``1`` (default: False). + + Returns: + array containing the specified norm of x. + + Notes: + The flavor of norm computed depends on the value of ``ord`` and the number of + axes being reduced. + + For **vector norms** (i.e. a single axis reduction): + + - ``ord=None`` (default) computes the 2-norm + - ``ord=inf`` computes ``max(abs(x))`` + - ``ord=-inf`` computes min(abs(x))`` + - ``ord=0`` computes ``sum(x!=0)`` + - for other numerical values, computes ``sum(abs(x) ** ord)**(1/ord)`` + + For **matrix norms** (i.e. two axes reductions): + + - ``ord='fro'`` or ``ord=None`` (default) computes the Frobenius norm + - ``ord='nuc'`` computes the nuclear norm, or the sum of the singular values + - ``ord=1`` computes ``max(abs(x).sum(0))`` + - ``ord=-1`` computes ``min(abs(x).sum(0))`` + - ``ord=2`` computes the 2-norm, i.e. the largest singular value + - ``ord=-2`` computes the smallest singular value + + Examples: + Vector norms: + + >>> x = jnp.array([3., 4., 12.]) + >>> jnp.linalg.norm(x) + Array(13., dtype=float32) + >>> jnp.linalg.norm(x, ord=1) + Array(19., dtype=float32) + >>> jnp.linalg.norm(x, ord=0) + Array(3., dtype=float32) + + Matrix norms: + + >>> x = jnp.array([[1., 2., 3.], + ... [4., 5., 7.]]) + >>> jnp.linalg.norm(x) # Frobenius norm + Array(10.198039, dtype=float32) + >>> jnp.linalg.norm(x, ord='nuc') # nuclear norm + Array(10.762535, dtype=float32) + >>> jnp.linalg.norm(x, ord=1) # 1-norm + Array(10., dtype=float32) + + Batched vector norm: + + >>> jnp.linalg.norm(x, axis=1) + Array([3.7416575, 9.486833 ], dtype=float32) + """ check_arraylike("jnp.linalg.norm", x) x, = promote_dtypes_inexact(jnp.asarray(x)) x_shape = jnp.shape(x) @@ -675,9 +1211,72 @@ def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ... @overload def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: ... -@implements(np.linalg.qr) @partial(jit, static_argnames=('mode',)) def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: + """Compute the QR decomposition of an array + + JAX implementation of :func:`numpy.linalg.qr`. + + The QR decomposition of a matrix `A` is given by + + .. math:: + + A = QR + + Where `Q` is a unitary matrix (i.e. :math:`Q^HQ=I`) and `R` is an upper-triangular + matrix. + + Args: + a: array of shape (..., M, N) + mode: Computational mode. Supported values are: + + - ``"reduced"`` (default): return `Q` of shape ``(..., M, K)`` and `R` of shape + ``(..., K, N)``, where ``K = min(M, N)``. + - ``"complete"``: return `Q` of shape ``(..., M, M)`` and `R` of shape ``(..., M, N)``. + - ``"raw"``: return lapack-internal representations of shape ``(..., M, N)`` and ``(..., K)``. + - ``"r"``: return `R` only. + + Returns: + A tuple ``(Q, R)`` (if ``mode`` is not ``"r"``) otherwise an array ``R``, + where: + + - ``Q`` is an orthogonal matrix of shape ``(..., M, K)`` (if ``mode`` is ``"reduced"``) + or ``(..., M, M)`` (if ``mode`` is ``"complete"``). + - ``R`` is an upper-triangular matrix of shape ``(..., M, N)`` (if ``mode`` is + ``"r"`` or ``"complete"``) or ``(..., K, N)`` (if ``mode`` is ``"reduced"``) + + with ``K = min(M, N)``. 
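The mode-dependent output shapes listed above can be checked directly. A minimal sketch for a tall 4x3 input, where K = min(M, N) = 3:

import jax.numpy as jnp

a = jnp.arange(12.).reshape(4, 3)      # M = 4, N = 3

q, r = jnp.linalg.qr(a, mode="reduced")
print(q.shape, r.shape)                # (4, 3) (3, 3)

q, r = jnp.linalg.qr(a, mode="complete")
print(q.shape, r.shape)                # (4, 4) (4, 3)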
+ + See also: + - :func:`jax.scipy.linalg.qr`: SciPy-style QR decomposition API + - :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API + + Examples: + Compute the QR decomposition of a matrix: + + >>> a = jnp.array([[1., 2., 3., 4.], + ... [5., 4., 2., 1.], + ... [6., 3., 1., 5.]]) + >>> Q, R = jnp.linalg.qr(a) + >>> Q # doctest: +SKIP + Array([[-0.12700021, -0.7581426 , -0.6396022 ], + [-0.63500065, -0.43322435, 0.63960224], + [-0.7620008 , 0.48737738, -0.42640156]], dtype=float32) + >>> R # doctest: +SKIP + Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ], + [ 0. , -1.7870499, -2.6534991, -1.028908 ], + [ 0. , 0. , -1.0660033, -4.050814 ]], dtype=float32) + + Check that ``Q`` is orthonormal: + + >>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5) + Array(True, dtype=bool) + + Reconstruct the input: + + >>> jnp.allclose(Q @ R, a) + Array(True, dtype=bool) + """ check_arraylike("jnp.linalg.qr", a) a, = promote_dtypes_inexact(jnp.asarray(a)) if mode == "raw": @@ -695,9 +1294,44 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: return QRResult(q, r) -@implements(np.linalg.solve) @jit def solve(a: ArrayLike, b: ArrayLike) -> Array: + """Solve a linear system of equations + + JAX implementation of :func:`numpy.linalg.solve`. + + This solves a (batched) linear system of equations ``a @ x = b`` + for ``x`` given ``a`` and ``b``. + + Args: + a: array of shape ``(..., N, N)``. + b: array of shape ``(N,)`` (for 1-dimensional right-hand-side) or + ``(..., N, M)`` (for batched 2-dimensional right-hand-side). + + Returns: + An array containing the result of the linear solve. The result has shape ``(..., N)`` + if ``b`` is of shape ``(N,)``, and has shape ``(..., N, M)`` otherwise. + + See also: + - :func:`jax.scipy.linalg.solve`: SciPy-style API for solving linear systems. + - :func:`jax.lax.custom_linear_solve`: matrix-free linear solver. + + Examples: + A simple 3x3 linear system: + + >>> A = jnp.array([[1., 2., 3.], + ... [2., 4., 2.], + ... [3., 2., 1.]]) + >>> b = jnp.array([14., 16., 10.]) + >>> x = jnp.linalg.solve(A, b) + >>> x + Array([1., 2., 3.], dtype=float32) + + Confirming that the result solves the system: + + >>> jnp.allclose(A @ x, b) + Array(True, dtype=bool) + """ check_arraylike("jnp.linalg.solve", a, b) a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b)) @@ -762,29 +1396,82 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *, _jit_lstsq = jit(partial(_lstsq, numpy_resid=False)) -@implements(np.linalg.lstsq, lax_description=textwrap.dedent("""\ - It has two important differences: - - 1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future - the default will be `None`. Here, the default rcond is `None`. - 2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or over-determined - solutions. Here, the residuals are returned in all cases, to make the function - compatible with jit. The non-jit compatible numpy behavior can be recovered by - passing numpy_resid=True. - The lstsq function does not currently have a custom JVP rule, so the gradient is - poorly behaved for some inputs, particularly for low-rank `a`. - """)) def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *, numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]: + """ + Return the least-squares solution to a linear equation. + + JAX implementation of :func:`numpy.linalg.lstsq`. + + Args: + a: array of shape ``(M, N)`` representing the coefficient matrix. 
+ b: array of shape ``(M,)`` or ``(M, K)`` representing the right-hand side. + rcond: Cut-off ratio for small singular values. Singular values smaller than + ``rcond * largest_singular_value`` are treated as zero. If None (default), + the optimal value will be used to reduce floating point errors. + numpy_resid: If True, compute and return residuals in the same way as NumPy's + `linalg.lstsq`. This is necessary if you want to precisely replicate NumPy's + behavior. If False (default), a more efficient method is used to compute residuals. + + Returns: + Tuple of arrays ``(x, resid, rank, s)`` where + + - ``x`` is a shape ``(N,)`` or ``(N, K)`` array containing the least-squares solution. + - ``resid`` is the sum of squared residual of shape ``()`` or ``(K,)``. + - ``rank`` is the rank of the matrix ``a``. + - ``s`` is the singular values of the matrix ``a``. + + Examples: + >>> a = jnp.array([[1, 2], + ... [3, 4]]) + >>> b = jnp.array([5, 6]) + >>> x, _, _, _ = jnp.linalg.lstsq(a, b) + >>> with jnp.printoptions(precision=3): + ... print(x) + [-4. 4.5] + """ check_arraylike("jnp.linalg.lstsq", a, b) if numpy_resid: return _lstsq(a, b, rcond, numpy_resid=True) return _jit_lstsq(a, b, rcond) -@implements(getattr(np.linalg, "cross", None)) def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): + r"""Compute the cross-product of two 3D vectors + + JAX implementation of :func:`numpy.linalg.cross` + + Args: + x1: N-dimensional array, with ``x1.shape[axis] == 3`` + x2: N-dimensional array, with ``x2.shape[axis] == 3``, and other axes + broadcast-compatible with ``x1``. + axis: axis along which to take the cross product (default: -1). + + Returns: + array containing the result of the cross-product + + See Also: + :func:`jax.numpy.cross`: more flexible cross-product API. + + Examples: + + Showing that :math:`\hat{x} \times \hat{y} = \hat{z}`: + + >>> x = jnp.array([1., 0., 0.]) + >>> y = jnp.array([0., 1., 0.]) + >>> jnp.linalg.cross(x, y) + Array([0., 0., 1.], dtype=float32) + + Cross product of :math:`\hat{x}` with all three standard unit vectors, + via broadcasting: + + >>> xyz = jnp.eye(3) + >>> jnp.linalg.cross(x, xyz, axis=-1) + Array([[ 0., 0., 0.], + [ 0., 0., 1.], + [ 0., -1., 0.]], dtype=float32) + """ check_arraylike("jnp.linalg.outer", x1, x2) x1, x2 = jnp.asarray(x1), jnp.asarray(x2) if x1.shape[axis] != 3 or x2.shape[axis] != 3: @@ -795,8 +1482,29 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): return jnp.cross(x1, x2, axis=axis) -@implements(getattr(np.linalg, "outer", None)) def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Compute the outer product of two 1-dimensional arrays. + + JAX implementation of :func:`numpy.linalg.outer`. + + Args: + x1: array + x2: array + + Returns: + array containing the outer product of ``x1`` and ``x2`` + + See also: + :func:`jax.numpy.outer`: similar function in the main :mod:`jax.numpy` module. + + Examples: + >>> x1 = jnp.array([1, 2, 3]) + >>> x2 = jnp.array([4, 5, 6]) + >>> jnp.linalg.outer(x1, x2) + Array([[ 4, 5, 6], + [ 8, 10, 12], + [12, 15, 18]], dtype=int32) + """ check_arraylike("jnp.linalg.outer", x1, x2) x1, x2 = jnp.asarray(x1), jnp.asarray(x2) if x1.ndim != 1 or x2.ndim != 1: @@ -804,18 +1512,83 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array: return x1[:, None] * x2[None, :] -@implements(getattr(np.linalg, "matrix_norm", None)) def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array: - """ - Computes the matrix norm of a matrix (or a stack of matrices) x. 
+ """Compute the norm of a matrix or stack of matrices. + + JAX implementation of :func:`numpy.linalg.matrix_norm` + + Args: + x: array of shape ``(..., M, N)`` for which to take the norm. + keepdims: if True, keep the reduced dimensions in the output. + ord: A string or int specifying the type of norm; default is the Frobenius norm. + See :func:`numpy.linalg.norm` for details on available options. + + Returns: + array containing the norm of ``x``. Has shape ``x.shape[:-2]`` if ``keepdims`` is + False, or shape ``(..., 1, 1)`` if ``keepdims`` is True. + + See also: + - :func:`jax.numpy.linalg.vector_norm`: Norm of a vector or stack of vectors. + - :func:`jax.numpy.linalg.norm`: More general matrix or vector norm. + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6], + ... [7, 8, 9]]) + >>> jnp.linalg.matrix_norm(x) + Array(16.881943, dtype=float32) """ check_arraylike('jnp.linalg.matrix_norm', x) return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1)) -@implements(getattr(np.linalg, "matrix_transpose", None)) def matrix_transpose(x: ArrayLike, /) -> Array: - """Transposes a matrix (or a stack of matrices) x.""" + """Transpose a matrix or stack of matrices. + + JAX implementation of :func:`numpy.linalg.matrix_transpose`. + + Args: + x: array of shape ``(..., M, N)`` + + Returns: + array of shape ``(..., N, M)`` containing the matrix transpose of ``x``. + + See also: + :func:`jax.numpy.transpose`: more general transpose operation. + + Examples: + Transpose of a single matrix: + + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.linalg.matrix_transpose(x) + Array([[1, 4], + [2, 5], + [3, 6]], dtype=int32) + + Transpose of a stack of matrices: + + >>> x = jnp.array([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> jnp.linalg.matrix_transpose(x) + Array([[[1, 3], + [2, 4]], + + [[5, 7], + [6, 8]]], dtype=int32) + + For convenience, the same computation can be done via the + :attr:`~jax.Array.mT` property of JAX array objects: + + >>> x.mT + Array([[[1, 3], + [2, 4]], + + [[5, 7], + [6, 8]]], dtype=int32) + """ check_arraylike('jnp.linalg.matrix_transpose', x) x_arr = jnp.asarray(x) ndim = x_arr.ndim @@ -824,10 +1597,41 @@ def matrix_transpose(x: ArrayLike, /) -> Array: return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2)) -@implements(getattr(np.linalg, "vector_norm", None)) def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False, ord: int | str = 2) -> Array: - """Computes the vector norm of a vector (or batch of vectors) x.""" + """Compute the vector norm of a vector or batch of vectors. + + JAX implementation of :func:`numpy.linalg.vector_norm`. + + Args: + x: N-dimensional array for which to take the norm. + axis: optional axis along which to compute the vector norm. If None (default) + then ``x`` is flattened and the norm is taken over all values. + keepdims: if True, keep the reduced dimensions in the output. + ord: A string or int specifying the type of norm; default is the 2-norm. + See :func:`numpy.linalg.norm` for details on available options. + + Returns: + array containing the norm of ``x``. + + See also: + - :func:`jax.numpy.linalg.matrix_norm`: Norm of a matrix or stack of matrices. + - :func:`jax.numpy.linalg.norm`: More general matrix or vector norm. + + Examples: + Norm of a single vector: + + >>> x = jnp.array([1., 2., 3.]) + >>> jnp.linalg.vector_norm(x) + Array(3.7416575, dtype=float32) + + Norm of a batch of vectors: + + >>> x = jnp.array([[1., 2., 3.], + ... 
[4., 5., 7.]]) + >>> jnp.linalg.vector_norm(x, axis=1) + Array([3.7416575, 9.486833 ], dtype=float32) + """ check_arraylike('jnp.linalg.vector_norm', x) if axis is None: result = norm(jnp.ravel(x), ord=ord) @@ -837,31 +1641,542 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa return norm(x, axis=axis, keepdims=keepdims, ord=ord) -@implements(getattr(np.linalg, "vecdot", None)) -def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array: - return jnp.vecdot(x1, x2, axis=axis) +def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, + precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None) -> Array: + """Compute the (batched) vector conjugate dot product of two arrays. + + JAX implementation of :func:`numpy.linalg.vecdot`. + + Args: + x1: left-hand side array. + x2: right-hand side array. Size of ``x2[axis]`` must match size of ``x1[axis]``, + and remaining dimensions must be broadcast-compatible. + axis: axis along which to compute the dot product (default: -1) + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two + such values indicating precision of ``x1`` and ``x2``. + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + + Returns: + array containing the conjugate dot product of ``x1`` and ``x2`` along ``axis``. + The non-contracted dimensions are broadcast together. + + See also: + - :func:`jax.numpy.vecdot`: similar API in the ``jax.numpy`` namespace. + - :func:`jax.numpy.linalg.matmul`: matrix multiplication. + - :func:`jax.numpy.linalg.tensordot`: general tensor dot product. + + Examples: + Vector dot product of two 1D arrays: + + >>> x1 = jnp.array([1, 2, 3]) + >>> x2 = jnp.array([4, 5, 6]) + >>> jnp.linalg.vecdot(x1, x2) + Array(32, dtype=int32) + + Batched vector dot product of two 2D arrays: + + >>> x1 = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> x2 = jnp.array([[2, 3, 4]]) + >>> jnp.linalg.vecdot(x1, x2, axis=-1) + Array([20, 47], dtype=int32) + """ + check_arraylike('jnp.linalg.vecdot', x1, x2) + return jnp.vecdot(x1, x2, axis=axis, precision=precision, + preferred_element_type=preferred_element_type) + +def matmul(x1: ArrayLike, x2: ArrayLike, /, *, + precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None) -> Array: + """Perform a matrix multiplication. -@implements(getattr(np.linalg, "matmul", None)) -def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array: + JAX implementation of :func:`numpy.linalg.matmul`. + + Args: + x1: first input array, of shape ``(..., N)``. + x2: second input array. Must have shape ``(N,)`` or ``(..., N, M)``. + In the multi-dimensional case, leading dimensions must be broadcast-compatible + with the leading dimensions of ``x1``. + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two + such values indicating precision of ``x1`` and ``x2``. + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. 
+ + Returns: + array containing the matrix product of the inputs. Shape is ``x1.shape[:-1]`` + if ``x2.ndim == 1``, otherwise the shape is ``(..., M)``. + + See Also: + :func:`jax.numpy.matmul`: NumPy API for this function. + :func:`jax.numpy.linalg.vecdot`: batched vector product. + :func:`jax.numpy.linalg.tensordot`: batched tensor product. + + Examples: + Vector dot products: + + >>> x1 = jnp.array([1, 2, 3]) + >>> x2 = jnp.array([4, 5, 6]) + >>> jnp.linalg.matmul(x1, x2) + Array(32, dtype=int32) + + Matrix dot product: + + >>> x1 = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> x2 = jnp.array([[1, 2], + ... [3, 4], + ... [5, 6]]) + >>> jnp.linalg.matmul(x1, x2) + Array([[22, 28], + [49, 64]], dtype=int32) + + For convenience, in all cases you can do the same computation using + the ``@`` operator: + + >>> x1 @ x2 + Array([[22, 28], + [49, 64]], dtype=int32) + """ check_arraylike('jnp.linalg.matmul', x1, x2) - return jnp.matmul(x1, x2) + return jnp.matmul(x1, x2, precision=precision, + preferred_element_type=preferred_element_type) -@implements(getattr(np.linalg, "tensordot", None)) def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, - axes: int | tuple[Sequence[int], Sequence[int]] = 2) -> Array: + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None) -> Array: + """Compute the tensor dot product of two N-dimensional arrays. + + JAX implementation of :func:`numpy.linalg.tensordot`. + + Args: + x1: N-dimensional array + x2: M-dimensional array + axes: integer or tuple of sequences of integers. If an integer `k`, then + sum over the last `k` axes of ``x1`` and the first `k` axes of ``x2``, + in order. If a tuple, then ``axes[0]`` specifies the axes of ``x1`` and + ``axes[1]`` specifies the axes of ``x2``. + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two + such values indicating precision of ``x1`` and ``x2``. + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + + Returns: + array containing the tensor dot product of the inputs + + See also: + - :func:`jax.numpy.tensordot`: equivalent API in the :mod:`jax.numpy` namespace. + - :func:`jax.numpy.einsum`: NumPy API for more general tensor contractions. + - :func:`jax.lax.dot_general`: XLA API for more general tensor contractions. + + Examples: + >>> x1 = jnp.arange(24.).reshape(2, 3, 4) + >>> x2 = jnp.ones((3, 4, 5)) + >>> jnp.linalg.tensordot(x1, x2) + Array([[ 66., 66., 66., 66., 66.], + [210., 210., 210., 210., 210.]], dtype=float32) + + Equivalent result when specifying the axes as explicit sequences: + + >>> jnp.linalg.tensordot(x1, x2, axes=([1, 2], [0, 1])) + Array([[ 66., 66., 66., 66., 66.], + [210., 210., 210., 210., 210.]], dtype=float32) + + Equivalent result via :func:`~jax.numpy.einsum`: + + >>> jnp.einsum('ijk,jkm->im', x1, x2) + Array([[ 66., 66., 66., 66., 66.], + [210., 210., 210., 210., 210.]], dtype=float32) + + Setting ``axes=1`` for two-dimensional inputs is equivalent to a matrix + multiplication: + + >>> x1 = jnp.array([[1, 2], + ... [3, 4]]) + >>> x2 = jnp.array([[1, 2, 3], + ... 
[4, 5, 6]]) + >>> jnp.linalg.tensordot(x1, x2, axes=1) + Array([[ 9, 12, 15], + [19, 26, 33]], dtype=int32) + >>> x1 @ x2 + Array([[ 9, 12, 15], + [19, 26, 33]], dtype=int32) + + Setting ``axes=0`` for one-dimensional inputs is equivalent to + :func:`jax.numpy.linalg.outer`: + + >>> x1 = jnp.array([1, 2]) + >>> x2 = jnp.array([1, 2, 3]) + >>> jnp.linalg.tensordot(x1, x2, axes=0) + Array([[1, 2, 3], + [2, 4, 6]], dtype=int32) + >>> jnp.linalg.outer(x1, x2) + Array([[1, 2, 3], + [2, 4, 6]], dtype=int32) + """ check_arraylike('jnp.linalg.tensordot', x1, x2) - return jnp.tensordot(x1, x2, axes=axes) + return jnp.tensordot(x1, x2, axes=axes, precision=precision, + preferred_element_type=preferred_element_type) -@implements(getattr(np.linalg, "svdvals", None)) def svdvals(x: ArrayLike, /) -> Array: + """Compute the singular values of a matrix. + + JAX implementation of :func:`numpy.linalg.svdvals`. + + Args: + x: array of shape ``(..., M, N)`` for which singular values will be computed. + + Returns: + array of singular values of shape ``(..., K)`` with ``K = min(M, N)``. + + See also: + :func:`jax.numpy.linalg.svd`: compute singular values and singular vectors + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.linalg.svdvals(x) + Array([9.508031 , 0.7728694], dtype=float32) + """ check_arraylike('jnp.linalg.svdvals', x) return svd(x, compute_uv=False, hermitian=False) -@implements(getattr(np.linalg, "diagonal", None)) def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array: + """Extract the diagonal of an matrix or stack of matrices. + + JAX implementation of :func:`numpy.linalg.diagonal`. + + Args: + x: array of shape ``(..., M, N)`` from which the diagonal will be extracted. + offset: positive or negative offset from the main diagonal. + + Returns: + Array of shape ``(..., K)`` where ``K`` is the length of the specified diagonal. + + See Also: + - :func:`jax.numpy.diagonal`: more general functionality for extracting diagonals. + - :func:`jax.numpy.diag`: create a diagonal matrix from values. + + Examples: + Diagonals of a single matrix: + + >>> x = jnp.array([[1, 2, 3, 4], + ... [5, 6, 7, 8], + ... [9, 10, 11, 12]]) + >>> jnp.linalg.diagonal(x) + Array([ 1, 6, 11], dtype=int32) + >>> jnp.linalg.diagonal(x, offset=1) + Array([ 2, 7, 12], dtype=int32) + >>> jnp.linalg.diagonal(x, offset=-1) + Array([ 5, 10], dtype=int32) + + Batched diagonals: + + >>> x = jnp.arange(24).reshape(2, 3, 4) + >>> jnp.linalg.diagonal(x) + Array([[ 0, 5, 10], + [12, 17, 22]], dtype=int32) + """ check_arraylike('jnp.linalg.diagonal', x) return jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1) + + +def tensorinv(a: ArrayLike, ind: int = 2) -> Array: + """Compute the tensor inverse of an array. + + JAX implementation of :func:`numpy.linalg.tensorinv`. + + This computes the inverse of the :func:`~jax.numpy.linalg.tensordot` + operation with the same ``ind`` value. + + Args: + a: array to be inverted. Must have ``prod(a.shape[:ind]) == prod(a.shape[ind:])`` + ind: positive integer specifying the number of indices in the tensor product. + + Returns: + array of shape ``(*a.shape[ind:], *a.shape[:ind])`` containing the + tensor inverse of ``a``. 
+ + See also: + - :func:`jax.numpy.linalg.tensordot` + - :func:`jax.numpy.linalg.tensorsolve` + + Examples: + >>> key = jax.random.key(1337) + >>> x = jax.random.normal(key, shape=(2, 2, 4)) + >>> xinv = jnp.linalg.tensorinv(x, 2) + >>> xinv_x = jnp.linalg.tensordot(xinv, x, axes=2) + >>> jnp.allclose(xinv_x, jnp.eye(4), atol=1E-4) + Array(True, dtype=bool) + """ + check_arraylike("tensorinv", a) + arr = jnp.asarray(a) + ind = operator.index(ind) + if ind <= 0: + raise ValueError(f"ind must be a positive integer; got {ind=}") + contracting_shape, batch_shape = arr.shape[:ind], arr.shape[ind:] + flatshape = (math.prod(contracting_shape), math.prod(batch_shape)) + if flatshape[0] != flatshape[1]: + raise ValueError("tensorinv is only possible when the product of the first" + " `ind` dimensions equals that of the remaining dimensions." + f" got {arr.shape=} with {ind=}.") + return inv(arr.reshape(flatshape)).reshape(*batch_shape, *contracting_shape) + + +def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None) -> Array: + """Solve the tensor equation a x = b for x. + + JAX implementation of :func:`numpy.linalg.tensorsolve`. + + Args: + a: input array. After reordering via ``axes`` (see below), shape must be + ``(*b.shape, *x.shape)``. + b: right-hand-side array. + axes: optional tuple specifying axes of ``a`` that should be moved to the end + + Returns: + array x such that after reordering of axes of ``a``, ``tensordot(a, x, x.ndim)`` + is equivalent to ``b``. + + See also: + - :func:`jax.numpy.linalg.tensordot` + - :func:`jax.numpy.linalg.tensorinv` + + Examples: + >>> key1, key2 = jax.random.split(jax.random.key(8675309)) + >>> a = jax.random.normal(key1, shape=(2, 2, 4)) + >>> b = jax.random.normal(key2, shape=(2, 2)) + >>> x = jnp.linalg.tensorsolve(a, b) + >>> x.shape + (4,) + + Now show that ``x`` can be used to reconstruct ``b`` using + :func:`~jax.numpy.linalg.tensordot`: + + >>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim) + >>> jnp.allclose(b, b_reconstructed) + Array(True, dtype=bool) + """ + check_arraylike("tensorsolve", a, b) + a_arr, b_arr = jnp.asarray(a), jnp.asarray(b) + if axes is not None: + a_arr = jnp.moveaxis(a_arr, axes, len(axes) * (a_arr.ndim - 1,)) + out_shape = a_arr.shape[b_arr.ndim:] + if a_arr.shape[:b_arr.ndim] != b_arr.shape: + raise ValueError("After moving axes to end, leading shape of a must match shape of b." + f" got a.shape={a_arr.shape}, b.shape={b_arr.shape}") + if b_arr.size != math.prod(out_shape): + raise ValueError("Input arrays must have prod(a.shape[:b.ndim]) == prod(a.shape[b.ndim:]);" + f" got a.shape={a_arr.shape}, b.ndim={b_arr.ndim}.") + a_arr = a_arr.reshape(b_arr.size, math.prod(out_shape)) + return solve(a_arr, b_arr.ravel()).reshape(out_shape) + + +def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -> Array: + """Efficiently compute matrix products between a sequence of arrays. + + JAX implementation of :func:`numpy.linalg.multi_dot`. + + JAX internally uses the opt_einsum library to compute the most efficient + operation order. + + Args: + arrays: sequence of arrays. All must be two-dimensional, except the first + and last which may be one-dimensional. + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``). 
+ + Returns: + an array representing the equivalent of ``reduce(jnp.matmul, arrays)``, but + evaluated in the optimal order. + + This function exists because the cost of computing sequences of matmul operations + can differ vastly depending on the order in which the operations are evaluated. + For a single matmul, the number of floating point operations (flops) required to + compute a matrix product can be approximated this way: + + >>> def approx_flops(x, y): + ... # for 2D x and y, with x.shape[1] == y.shape[0] + ... return 2 * x.shape[0] * x.shape[1] * y.shape[1] + + Suppose we have three matrices that we'd like to multiply in sequence: + + >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3) + >>> x = jax.random.normal(key1, shape=(200, 5)) + >>> y = jax.random.normal(key2, shape=(5, 100)) + >>> z = jax.random.normal(key3, shape=(100, 10)) + + Because of associativity of matrix products, there are two orders in which we might + evaluate the product ``x @ y @ z``, and both produce equivalent outputs up to floating + point precision: + + >>> result1 = (x @ y) @ z + >>> result2 = x @ (y @ z) + >>> jnp.allclose(result1, result2, atol=1E-4) + Array(True, dtype=bool) + + But the computational cost of these differ greatly: + + >>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z)) + (x @ y) @ z flops: 600000 + >>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z)) + x @ (y @ z) flops: 30000 + + The second approach is about 20x more efficient in terms of estimated flops! + + ``multi_dot`` is a function that will automatically choose the fastest + computational path for such problems: + + >>> result3 = jnp.linalg.multi_dot([x, y, z]) + >>> jnp.allclose(result1, result3, atol=1E-4) + Array(True, dtype=bool) + + We can use JAX's :ref:`ahead-of-time-lowering` tools to estimate the total flops + of each approach, and confirm that ``multi_dot`` is choosing the more efficient + option: + + >>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops'] + 600000.0 + >>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops'] + 30000.0 + >>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops'] + 30000.0 + """ + check_arraylike('jnp.linalg.multi_dot', *arrays) + arrs: list[Array] = list(map(jnp.asarray, arrays)) + if len(arrs) < 2: + raise ValueError(f"multi_dot requires at least two arrays; got len(arrays)={len(arrs)}") + if not (arrs[0].ndim in (1, 2) and arrs[-1].ndim in (1, 2) and + all(a.ndim == 2 for a in arrs[1:-1])): + raise ValueError("multi_dot: input arrays must all be two-dimensional, except for" + " the first and last array which may be 1 or 2 dimensional." + f" Got array shapes {[a.shape for a in arrs]}") + if any(a.shape[-1] != b.shape[0] for a, b in zip(arrs[:-1], arrs[1:])): + raise ValueError("multi_dot: last dimension of each array must match first dimension" + f" of following array. Got array shapes {[a.shape for a in arrs]}") + einsum_axes: list[tuple[int, ...]] = [(i, i+1) for i in range(len(arrs))] + if arrs[0].ndim == 1: + einsum_axes[0] = einsum_axes[0][1:] + if arrs[-1].ndim == 1: + einsum_axes[-1] = einsum_axes[-1][:1] + return jnp.einsum(*itertools.chain(*zip(arrs, einsum_axes)), # type: ignore[call-overload] + optimize='optimal', precision=precision) + + +@partial(jit, static_argnames=['p']) +def cond(x: ArrayLike, p=None): + """Compute the condition number of a matrix. + + JAX implementation of :func:`numpy.linalg.cond`. 
+ + The condition number is defined as ``norm(x, p) * norm(inv(x), p)``. For ``p = 2`` + (the default), the condition number is the ratio of the largest to the smallest + singular value. + + Args: + x: array of shape ``(..., M, N)`` for which to compute the condition number. + p: the order of the norm to use. One of ``{None, 1, -1, 2, -2, inf, -inf, 'fro'}``; + see :func:`jax.numpy.linalg.norm` for the meaning of these. The default is ``p = None``, + which is equivalent to ``p = 2``. If not in ``{None, 2, -2}`` then ``x`` must be square, + i.e. ``M = N``. + + Returns: + array of shape ``x.shape[:-2]`` containing the condition number. + + See also: + :func:`jax.numpy.linalg.norm` + + Examples: + + Well-conditioned matrix: + + >>> x = jnp.array([[1, 2], + ... [2, 1]]) + >>> jnp.linalg.cond(x) + Array(3., dtype=float32) + + Ill-conditioned matrix: + + >>> x = jnp.array([[1, 2], + ... [0, 0]]) + >>> jnp.linalg.cond(x) + Array(inf, dtype=float32) + """ + check_arraylike("cond", x) + arr = jnp.asarray(x) + if arr.ndim < 2: + raise ValueError(f"jnp.linalg.cond: input array must be at least 2D; got {arr.shape=}") + if arr.shape[-1] == 0 or arr.shape[-2] == 0: + raise ValueError(f"jnp.linalg.cond: input array must not be empty; got {arr.shape=}") + if p is None or p == 2: + s = svdvals(x) + return s[..., 0] / s[..., -1] + elif p == -2: + s = svdvals(x) + r = s[..., -1] / s[..., 0] + else: + if arr.shape[-2] != arr.shape[-1]: + raise ValueError(f"jnp.linalg.cond: for {p=}, array must be square; got {arr.shape=}") + r = norm(x, ord=p, axis=(-2, -1)) * norm(inv(x), ord=p, axis=(-2, -1)) + # Convert NaNs to infs where original array has no NaNs. + return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r) + + +def trace(x: ArrayLike, /, *, + offset: int = 0, dtype: DTypeLike | None = None) -> Array: + """Compute the trace of a matrix. + + JAX implementation of :func:`numpy.linalg.trace`. + + Args: + x: array of shape ``(..., M, N)`` and whose innermost two + dimensions form MxN matrices for which to take the trace. + offset: positive or negative offset from the main diagonal + (default: 0). + dtype: data type of the returned array (default: ``None``). If ``None``, + then output dtype will match the dtype of ``x``, promoted to default + precision in the case of integer types. + + Returns: + array of batched traces with shape ``x.shape[:-2]`` + + See also: + - :func:`jax.numpy.trace`: similar API in the ``jax.numpy`` namespace. + + Examples: + Trace of a single matrix: + + >>> x = jnp.array([[1, 2, 3, 4], + ... [5, 6, 7, 8], + ... 
[9, 10, 11, 12]]) + >>> jnp.linalg.trace(x) + Array(18, dtype=int32) + >>> jnp.linalg.trace(x, offset=1) + Array(21, dtype=int32) + >>> jnp.linalg.trace(x, offset=-1, dtype="float32") + Array(15., dtype=float32) + + Batched traces: + + >>> x = jnp.arange(24).reshape(2, 3, 4) + >>> jnp.linalg.trace(x) + Array([15, 51], dtype=int32) + """ + check_arraylike('jnp.linalg.trace', x) + return jnp.trace(x, offset=offset, axis1=-2, axis2=-1, dtype=dtype) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index dc2ff57c0d28..45595c4387a2 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -46,7 +46,7 @@ def _roots_no_zeros(p: Array) -> Array: @jit -def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array: +def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array: # Avoid lapack errors when p is all zero p = _where(len(p) == num_leading_zeros, 1.0, p) # Roll any leading zeros to the end & compute the roots @@ -57,35 +57,49 @@ def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array: return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan)) -@implements(np.roots, lax_description="""\ -Unlike the numpy version of this function, the JAX version returns the roots in -a complex array regardless of the values of the roots. Additionally, the jax -version of this function adds the ``strip_zeros`` function which must be set to -False for the function to be compatible with JIT and other JAX transformations. -With ``strip_zeros=False``, if your coefficients have leading zeros, the -roots will be padded with NaN values: - ->>> coeffs = jnp.array([0, 1, 2]) - -# The default behavior matches numpy and strips leading zeros: ->>> jnp.roots(coeffs) -Array([-2.+0.j], dtype=complex64) - -# With strip_zeros=False, extra roots are set to NaN: ->>> jnp.roots(coeffs, strip_zeros=False) -Array([-2. +0.j, nan+nanj], dtype=complex64) -""", -extra_params=""" -strip_zeros : bool, default=True - If set to True, then leading zeros in the coefficients will be stripped, similar - to :func:`numpy.roots`. If set to False, leading zeros will not be stripped, and - undefined roots will be represented by NaN values in the function output. - ``strip_zeros`` must be set to ``False`` for the function to be compatible with - :func:`jax.jit` and other JAX transformations. -""") def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: + r"""Returns the roots of a polynomial given the coefficients ``p``. + + JAX implementations of :func:`numpy.roots`. + + Args: + p: Array of polynomial coefficients having rank-1. + strip_zeros : bool, default=True. If True, then leading zeros in the + coefficients will be stripped, similar to :func:`numpy.roots`. If set to + False, leading zeros will not be stripped, and undefined roots will be + represented by NaN values in the function output. ``strip_zeros`` must be + set to ``False`` for the function to be compatible with :func:`jax.jit` and + other JAX transformations. + + Returns: + An array containing the roots of the polynomial. + + Note: + Unlike ``np.roots`` of this function, the ``jnp.roots`` returns the roots + in a complex array regardless of the values of the roots. + + See Also: + - :func:`jax.numpy.poly`: Finds the polynomial coefficients of the given + sequence of roots. + - :func:`jax.numpy.polyfit`: Least squares polynomial fit to data. + - :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values. 
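+
+    As a quick sanity check (an illustrative sketch; the loose tolerance simply
+    absorbs float32 rounding error), the polynomial should evaluate to roughly
+    zero at each computed root:
+
+    >>> p = jnp.array([1., -6., 11., -6.])
+    >>> jnp.allclose(jnp.polyval(p, jnp.roots(p)), 0, atol=1e-4)
+    Array(True, dtype=bool)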
+
+  Examples:
+    >>> coeffs = jnp.array([0, 1, 2])
+
+    The default behavior matches numpy and strips leading zeros:
+
+    >>> jnp.roots(coeffs)
+    Array([-2.+0.j], dtype=complex64)
+
+    With ``strip_zeros=False``, extra roots are set to NaN:
+
+    >>> jnp.roots(coeffs, strip_zeros=False)
+    Array([-2. +0.j, nan+nanj], dtype=complex64)
+  """
   check_arraylike("roots", p)
-  p_arr = atleast_1d(*promote_dtypes_inexact(p))
+  p_arr = atleast_1d(promote_dtypes_inexact(p)[0])
+  del p
   if p_arr.ndim != 1:
     raise ValueError("Input must be a rank-1 array.")
   if p_arr.size < 2:
@@ -96,57 +110,155 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
     num_leading_zeros = core.concrete_or_error(int, num_leading_zeros,
       "The error occurred in the jnp.roots() function. To use this within a "
       "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
-      "will be result in some returned roots being set to NaN.")
+      "will result in some returned roots being set to NaN.")
     return _roots_no_zeros(p_arr[num_leading_zeros:])
   else:
     return _roots_with_zeros(p_arr, num_leading_zeros)
 
 
-_POLYFIT_DOC = """\
-Unlike NumPy's implementation of polyfit, :py:func:`jax.numpy.polyfit` will not warn on rank reduction, which indicates an ill conditioned matrix
-Also, it works best on rcond <= 10e-3 values.
-"""
-@implements(np.polyfit, lax_description=_POLYFIT_DOC)
 @partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
-def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None,
-            full: bool = False, w: Array | None = None, cov: bool = False
+def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
+            full: bool = False, w: ArrayLike | None = None, cov: bool = False
             ) -> Array | tuple[Array, ...]:
-  check_arraylike("polyfit", x, y)
+  r"""Least squares polynomial fit to data.
+
+  JAX implementation of :func:`numpy.polyfit`.
+
+  Given a set of data points ``(x, y)`` and degree of polynomial ``deg``, the
+  function finds a polynomial equation of the form:
+
+  .. math::
+
+    y = p(x) = p[0] x^{deg} + p[1] x^{deg - 1} + ... + p[deg]
+
+  Args:
+    x: Array of data points of shape ``(M,)``.
+    y: Array of data points of shape ``(M,)`` or ``(M, K)``.
+    deg: Degree of the fitting polynomial. It must be specified statically.
+    rcond: Relative condition number of the fit. Default value is ``len(x) * eps``.
+      It must be specified statically.
+    full: Switch that controls the return value. Default is ``False`` which
+      restricts the return value to the array of polynomial coefficients ``p``.
+      If ``True``, the function returns a tuple ``(p, resids, rank, s, rcond)``.
+      It must be specified statically.
+    w: Array of weights of shape ``(M,)``. If None, all data points are considered
+      to have equal weight. If not None, the weight :math:`w_i` is applied to the
+      unsquared residual of :math:`y_i - \widehat{y}_i` at :math:`x_i`, where
+      :math:`\widehat{y}_i` is the fitted value of :math:`y_i`. Default is None.
+    cov: Boolean or string. If ``True``, returns the covariance matrix scaled
+      by ``resids/(M-deg-1)`` along with the polynomial coefficients. If
+      ``cov='unscaled'``, returns the unscaled version of the covariance matrix.
+      Default is ``False``. ``cov`` is ignored if ``full=True``. It must be
+      specified statically.
+
+  Returns:
+    - An array of polynomial coefficients ``p`` if ``full=False`` and ``cov=False``.
+
+    - A tuple of arrays ``(p, resids, rank, s, rcond)`` if ``full=True``, where
+
+      - ``p`` is an array of shape ``(deg + 1,)`` or ``(deg + 1, K)`` containing the
+        polynomial coefficients.
+      - ``resids`` is the sum of squared residuals of shape ``()`` or ``(K,)``.
+      - ``rank`` is the rank of the matrix ``x``.
+      - ``s`` is the array of singular values of the matrix ``x``.
+      - ``rcond`` is the value of ``rcond`` used for the fit.
+    - A tuple of arrays ``(p, C)`` if ``full=False`` and ``cov=True``, where
+
+      - ``p`` is an array of shape ``(deg + 1,)`` or ``(deg + 1, K)`` containing the
+        polynomial coefficients.
+      - ``C`` is the covariance matrix of polynomial coefficients of shape
+        ``(deg + 1, deg + 1)`` or ``(deg + 1, deg + 1, 1)``.
+
+  Note:
+    Unlike :func:`numpy.polyfit`, :func:`jax.numpy.polyfit` will not warn on
+    rank reduction, which indicates an ill-conditioned matrix.
+
+  See Also:
+    - :func:`jax.numpy.poly`: Finds the polynomial coefficients of the given
+      sequence of roots.
+    - :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values.
+    - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
+      coefficients.
+
+  Examples:
+    >>> x = jnp.array([3., 6., 9., 4.])
+    >>> y = jnp.array([[0, 1, 2],
+    ...                [2, 5, 7],
+    ...                [8, 4, 9],
+    ...                [1, 6, 3]])
+    >>> p = jnp.polyfit(x, y, 2)
+    >>> with jnp.printoptions(precision=2, suppress=True):
+    ...   print(p)
+    [[ 0.2  -0.35 -0.14]
+     [-1.17  4.47  2.96]
+     [ 1.95 -8.21 -5.93]]
+
+    If ``full=True``, returns a tuple of arrays as follows:
+
+    >>> p, resids, rank, s, rcond = jnp.polyfit(x, y, 2, full=True)
+    >>> with jnp.printoptions(precision=2, suppress=True):
+    ...   print("Polynomial Coefficients:", "\n", p, "\n",
+    ...         "Residuals:", resids, "\n",
+    ...         "Rank:", rank, "\n",
+    ...         "s:", s, "\n",
+    ...         "rcond:", rcond)
+    Polynomial Coefficients:
+     [[ 0.2  -0.35 -0.14]
+     [-1.17  4.47  2.96]
+     [ 1.95 -8.21 -5.93]]
+     Residuals: [0.37 5.94 0.61]
+     Rank: 3
+     s: [1.67 0.47 0.04]
+     rcond: 4.7683716e-07
+
+    If ``cov=True`` and ``full=False``, returns a tuple of arrays containing the
+    polynomial coefficients and the covariance matrix.
+ + >>> p, C = jnp.polyfit(x, y, 2, cov=True) + >>> p.shape, C.shape + ((3, 3), (3, 3, 1)) + """ + if w is None: + check_arraylike("polyfit", x, y) + else: + check_arraylike("polyfit", x, y, w) deg = core.concrete_or_error(int, deg, "deg must be int") order = deg + 1 # check arguments + x_arr, y_arr = asarray(x), asarray(y) + del x, y if deg < 0: raise ValueError("expected deg >= 0") - if x.ndim != 1: + if x_arr.ndim != 1: raise TypeError("expected 1D vector for x") - if x.size == 0: + if x_arr.size == 0: raise TypeError("expected non-empty vector for x") - if y.ndim < 1 or y.ndim > 2: + if y_arr.ndim < 1 or y_arr.ndim > 2: raise TypeError("expected 1D or 2D array for y") - if x.shape[0] != y.shape[0]: + if x_arr.shape[0] != y_arr.shape[0]: raise TypeError("expected x and y to have same length") # set rcond if rcond is None: - rcond = len(x) * finfo(x.dtype).eps + rcond = len(x_arr) * finfo(x_arr.dtype).eps rcond = core.concrete_or_error(float, rcond, "rcond must be float") # set up least squares equation for powers of x - lhs = vander(x, order) - rhs = y + lhs = vander(x_arr, order) + rhs = y_arr # apply weighting if w is not None: - check_arraylike("polyfit", w) w, = promote_dtypes_inexact(w) - if w.ndim != 1: + w_arr = asarray(w) + if w_arr.ndim != 1: raise TypeError("expected a 1-d array for weights") - if w.shape[0] != y.shape[0]: + if w_arr.shape[0] != y_arr.shape[0]: raise TypeError("expected w and y to have the same length") - lhs *= w[:, np.newaxis] + lhs *= w_arr[:, np.newaxis] if rhs.ndim == 2: - rhs *= w[:, np.newaxis] + rhs *= w_arr[:, np.newaxis] else: - rhs *= w + rhs *= w_arr # scale lhs to improve condition number and solve scale = sqrt((lhs*lhs).sum(axis=0)) @@ -162,12 +274,12 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, if cov == "unscaled": fac = 1 else: - if len(x) <= order: + if len(x_arr) <= order: raise ValueError("the number of data points must exceed order " "to scale the covariance matrix") - fac = resids / (len(x) - order) + fac = resids / (len(x_arr) - order) fac = fac[0] #making np.array() of shape (1,) to int - if y.ndim == 1: + if y_arr.ndim == 1: return c, Vbase * fac else: return c, Vbase[:, :, np.newaxis] * fac @@ -175,80 +287,214 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, return c -_POLY_DOC = """\ -This differs from np.poly when an integer array is given. -np.poly returns a result with dtype float64 in this case. -jax returns a result with an inexact type, but not necessarily -float64. +@jit +def poly(seq_of_zeros: ArrayLike) -> Array: + r"""Returns the coefficients of a polynomial for the given sequence of roots. -This also differs from np.poly when the input array strictly -contains pairs of complex conjugates, e.g. [1j, -1j, 1-1j, 1+1j]. -np.poly returns an array with a real dtype in such cases. -jax returns an array with a complex dtype in such cases. -""" + JAX implementation of :func:`numpy.poly`. -@implements(np.poly, lax_description=_POLY_DOC) -@jit -def poly(seq_of_zeros: Array) -> Array: + Args: + seq_of_zeros: A scalar or an array of roots of the polynomial of shape ``(M,)`` + or ``(M, M)``. + + Returns: + An array containing the coefficients of the polynomial. The dtype of the + output is always promoted to inexact. + + Note: + + :func:`jax.numpy.poly` differs from :func:`numpy.poly`: + + - When the input is a scalar, ``np.poly`` raises a ``TypeError``, whereas + ``jnp.poly`` treats scalars the same as length-1 arrays. 
+ - For complex-valued or square-shaped inputs, ``jnp.poly`` always returns + complex coefficients, whereas ``np.poly`` may return real or complex + depending on their values. + + See also: + - :func:`jax.numpy.polyfit`: Least squares polynomial fit. + - :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values. + - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given + coefficients. + + Example: + + Scalar inputs: + + >>> jnp.poly(1) + Array([ 1., -1.], dtype=float32) + + Input array with integer values: + + >>> x = jnp.array([1, 2, 3]) + >>> jnp.poly(x) + Array([ 1., -6., 11., -6.], dtype=float32) + + Input array with complex conjugates: + + >>> x = jnp.array([2, 1+2j, 1-2j]) + >>> jnp.poly(x) + Array([ 1.+0.j, -4.+0.j, 9.+0.j, -10.+0.j], dtype=complex64) + + Input array as square matrix with real valued inputs: + + >>> x = jnp.array([[2, 1, 5], + ... [3, 4, 7], + ... [1, 3, 5]]) + >>> jnp.round(jnp.poly(x)) + Array([ 1.+0.j, -11.-0.j, 9.+0.j, -15.+0.j], dtype=complex64) + """ check_arraylike('poly', seq_of_zeros) seq_of_zeros, = promote_dtypes_inexact(seq_of_zeros) - seq_of_zeros = atleast_1d(seq_of_zeros) + seq_of_zeros_arr = atleast_1d(seq_of_zeros) + del seq_of_zeros - sh = seq_of_zeros.shape + sh = seq_of_zeros_arr.shape if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0: # import at runtime to avoid circular import from jax._src.numpy import linalg - seq_of_zeros = linalg.eigvals(seq_of_zeros) + seq_of_zeros_arr = linalg.eigvals(seq_of_zeros_arr) - if seq_of_zeros.ndim != 1: + if seq_of_zeros_arr.ndim != 1: raise ValueError("input must be 1d or non-empty square 2d array.") - dt = seq_of_zeros.dtype - if len(seq_of_zeros) == 0: + dt = seq_of_zeros_arr.dtype + if len(seq_of_zeros_arr) == 0: return ones((), dtype=dt) a = ones((1,), dtype=dt) - for k in range(len(seq_of_zeros)): - a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full') + for k in range(len(seq_of_zeros_arr)): + a = convolve(a, array([1, -seq_of_zeros_arr[k]], dtype=dt), mode='full') return a -@implements(np.polyval, lax_description="""\ -The ``unroll`` parameter is JAX specific. It does not effect correctness but can -have a major impact on performance for evaluating high-order polynomials. The -parameter controls the number of unrolled steps with ``lax.scan`` inside the -``polyval`` implementation. Consider setting ``unroll=128`` (or even higher) to -improve runtime performance on accelerators, at the cost of increased -compilation time. -""") @partial(jit, static_argnames=['unroll']) -def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array: +def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: + r"""Evaluates the polynomial at specific values. + + JAX implementations of :func:`numpy.polyval`. + + For the 1D-polynomial coefficients ``p`` of length ``M``, the function returns + the value: + + .. math:: + + p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1} + + Args: + p: An array of polynomial coefficients of shape ``(M,)``. + x: A number or an array of numbers. + unroll: A number used to control the number of unrolled steps with + ``lax.scan``. It must be specified statically. + + Returns: + An array of same shape as ``x``. + + Note: + + The ``unroll`` parameter is JAX specific. It does not affect correctness but + can have a major impact on performance for evaluating high-order polynomials. + The parameter controls the number of unrolled steps with ``lax.scan`` inside + the ``jnp.polyval`` implementation. 
Consider setting ``unroll=128`` (or even + higher) to improve runtime performance on accelerators, at the cost of + increased compilation time. + + See also: + - :func:`jax.numpy.polyfit`: Least squares polynomial fit. + - :func:`jax.numpy.poly`: Finds the coefficients of a polynomial with given + roots. + - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given + coefficients. + + Example: + >>> p = jnp.array([2, 5, 1]) + >>> jnp.polyval(p, 3) + Array(34., dtype=float32) + + If ``x`` is a 2D array, ``polyval`` returns 2D-array with same shape as + that of ``x``: + + >>> x = jnp.array([[2, 1, 5], + ... [3, 4, 7], + ... [1, 3, 5]]) + >>> jnp.polyval(p, x) + Array([[ 19., 8., 76.], + [ 34., 53., 134.], + [ 8., 34., 76.]], dtype=float32) + """ check_arraylike("polyval", p, x) - p, x = promote_dtypes_inexact(p, x) - shape = lax.broadcast_shapes(p.shape[1:], x.shape) - y = lax.full_like(x, 0, shape=shape, dtype=x.dtype) - y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll) + p_arr, x_arr = promote_dtypes_inexact(p, x) + del p, x + shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape) + y = lax.full_like(x_arr, 0, shape=shape, dtype=x_arr.dtype) + y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll) return y @implements(np.polyadd) @jit -def polyadd(a1: Array, a2: Array) -> Array: +def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: check_arraylike("polyadd", a1, a2) - a1, a2 = promote_dtypes(a1, a2) - if a2.shape[0] <= a1.shape[0]: - return a1.at[-a2.shape[0]:].add(a2) + a1_arr, a2_arr = promote_dtypes(a1, a2) + del a1, a2 + if a2_arr.shape[0] <= a1_arr.shape[0]: + return a1_arr.at[-a2_arr.shape[0]:].add(a2_arr) else: - return a2.at[-a1.shape[0]:].add(a1) + return a2_arr.at[-a1_arr.shape[0]:].add(a1_arr) -@implements(np.polyint) @partial(jit, static_argnames=('m',)) -def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: +def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array: + r"""Returns the coefficients of the integration of specified order of a polynomial. + + JAX implementation of :func:`numpy.polyint`. + + Args: + p: An array of polynomial coefficients. + m: Order of integration. Default is 1. It must be specified statically. + k: Scalar or array of ``m`` integration constant (s). + + Returns: + An array of coefficients of integrated polynomial. + + See also: + - :func:`jax.numpy.polyder`: Computes the coefficients of the derivative of + a polynomial. + - :func:`jax.numpy.polyval`: Evaluates a polynomial at specific values. + + Examples: + + The first order integration of the polynomial :math:`12 x^2 + 12 x + 6` is + :math:`4 x^3 + 6 x^2 + 6 x`. + + >>> p = jnp.array([12, 12, 6]) + >>> jnp.polyint(p) + Array([4., 6., 6., 0.], dtype=float32) + + Since the constant ``k`` is not provided, the result included ``0`` at the end. + If the constant ``k`` is provided: + + >>> jnp.polyint(p, k=4) + Array([4., 6., 6., 4.], dtype=float32) + + and the second order integration is :math:`x^4 + 2 x^3 + 3 x`: + + >>> jnp.polyint(p, m=2) + Array([1., 2., 3., 0., 0.], dtype=float32) + + When ``m>=2``, the constants ``k`` should be provided as an array having + ``m`` elements. 
The second order integration of the polynomial + :math:`12 x^2 + 12 x + 6` with the constants ``k=[4, 5]`` is + :math:`x^4 + 2 x^3 + 3 x^2 + 4 x + 5`: + + >>> jnp.polyint(p, m=2, k=jnp.array([4, 5])) + Array([1., 2., 3., 4., 5.], dtype=float32) + """ m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint") k = 0 if k is None else k check_arraylike("polyint", p, k) - p, k_arr = promote_dtypes_inexact(p, k) + p_arr, k_arr = promote_dtypes_inexact(p, k) + del p, k if m < 0: raise ValueError("Order of integral must be positive (see polyder)") k_arr = atleast_1d(k_arr) @@ -257,27 +503,62 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: if k_arr.shape != (m,): raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.") if m == 0: - return p + return p_arr else: - grid = (arange(len(p) + m, dtype=p.dtype)[np.newaxis] - - arange(m, dtype=p.dtype)[:, np.newaxis]) + grid = (arange(len(p_arr) + m, dtype=p_arr.dtype)[np.newaxis] + - arange(m, dtype=p_arr.dtype)[:, np.newaxis]) coeff = maximum(1, grid).prod(0)[::-1] - return true_divide(concatenate((p, k_arr)), coeff) + return true_divide(concatenate((p_arr, k_arr)), coeff) -@implements(np.polyder) @partial(jit, static_argnames=('m',)) -def polyder(p: Array, m: int = 1) -> Array: +def polyder(p: ArrayLike, m: int = 1) -> Array: + r"""Returns the coefficients of the derivative of specified order of a polynomial. + + JAX implementation of :func:`numpy.polyder`. + + Args: + p: Array of polynomials coefficients. + m: Order of differentiation (positive integer). Default is 1. It must be + specified statically. + + Returns: + An array of polynomial coefficients representing the derivative. + + Note: + :func:`jax.numpy.polyder` differs from :func:`numpy.polyder` when an integer + array is given. NumPy returns the result with dtype ``int`` whereas JAX + returns the result with dtype ``float``. + + See also: + - :func:`jax.numpy.polyint`: Computes the integral of polynomial. + - :func:`jax.numpy.polyval`: Evaluates a polynomial at specific values. 
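+
+    As an illustrative cross-check of the relationship with
+    :func:`jax.numpy.polyint` (differentiating an antiderivative recovers the
+    original coefficients, up to floating-point rounding):
+
+    >>> p = jnp.array([2., -5., 3., -1.])
+    >>> jnp.allclose(jnp.polyder(jnp.polyint(p)), p)
+    Array(True, dtype=bool)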
+ + Examples: + + The first order derivative of the polynomial :math:`2 x^3 - 5 x^2 + 3 x - 1` + is :math:`6 x^2 - 10 x +3`: + + >>> p = jnp.array([2, -5, 3, -1]) + >>> jnp.polyder(p) + Array([ 6., -10., 3.], dtype=float32) + + and its second order derivative is :math:`12 x - 10`: + + >>> jnp.polyder(p, m=2) + Array([ 12., -10.], dtype=float32) + """ check_arraylike("polyder", p) m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder") - p, = promote_dtypes_inexact(p) + p_arr, = promote_dtypes_inexact(p) + del p if m < 0: raise ValueError("Order of derivative must be positive") if m == 0: - return p - coeff = (arange(m, len(p), dtype=p.dtype)[np.newaxis] - - arange(m, dtype=p.dtype)[:, np.newaxis]).prod(0) - return p[:-m] * coeff[::-1] + return p_arr + coeff = (arange(m, len(p_arr), dtype=p_arr.dtype)[np.newaxis] + - arange(m, dtype=p_arr.dtype)[:, np.newaxis]).prod(0) + return p_arr[:-m] * coeff[::-1] _LEADING_ZEROS_DOC = """\ @@ -292,6 +573,7 @@ def polyder(p: Array, m: int = 1) -> Array: def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array: check_arraylike("polymul", a1, a2) a1_arr, a2_arr = promote_dtypes_inexact(a1, a2) + del a1, a2 if trim_leading_zeros and (len(a1_arr) > 1 or len(a2_arr) > 1): a1_arr, a2_arr = trim_zeros(a1_arr, trim='f'), trim_zeros(a2_arr, trim='f') if len(a1_arr) == 0: @@ -304,6 +586,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) - def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]: check_arraylike("polydiv", u, v) u_arr, v_arr = promote_dtypes_inexact(u, v) + del u, v m = len(u_arr) - 1 n = len(v_arr) - 1 scale = 1. / v_arr[0] @@ -319,7 +602,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> @implements(np.polysub) @jit -def polysub(a1: Array, a2: Array) -> Array: +def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: check_arraylike("polysub", a1, a2) a1, a2 = promote_dtypes(a1, a2) return polyadd(a1, -a2) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index b1e80b952a96..2e27ec229474 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -15,11 +15,11 @@ from __future__ import annotations import builtins -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial import math import operator -from typing import overload, Any, Callable, Literal, Protocol, Union +from typing import overload, Any, Literal, Protocol, Union import warnings import numpy as np @@ -33,7 +33,7 @@ _broadcast_to, check_arraylike, _complex_elem_type, promote_dtypes_inexact, promote_dtypes_numeric, _where, implements) from jax._src.lax import lax as lax_internal -from jax._src.typing import Array, ArrayLike, DType, DTypeLike +from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg from jax._src.util import ( canonicalize_axis as _canonicalize_axis, maybe_named_axis, NumpyComplexWarning) @@ -65,6 +65,20 @@ def _upcast_f16(dtype: DTypeLike) -> DType: return np.dtype('float32') return np.dtype(dtype) +def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike: + # Note: NumPy always promotes to 64-bit; jax instead promotes to the + # default dtype as defined by dtypes.int_ or dtypes.uint. 
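+  # Illustrative examples (assuming the default 32-bit mode, where dtypes.int_
+  # is int32 and dtypes.uint is uint32):
+  #   bool  -> int32
+  #   int8  -> int32, uint8 -> uint32
+  #   int64 -> int64 (already at least as wide as the default integer type)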
+ if dtypes.issubdtype(dtype, np.bool_): + return dtypes.int_ + elif dtypes.issubdtype(dtype, np.unsignedinteger): + if np.iinfo(dtype).bits < np.iinfo(dtypes.uint).bits: + return dtypes.uint + elif dtypes.issubdtype(dtype, np.integer): + if np.iinfo(dtype).bits < np.iinfo(dtypes.int_).bits: + return dtypes.int_ + return dtype + + ReductionOp = Callable[[Any, Any], Any] def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike, @@ -103,16 +117,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: result_dtype = dtype or dtypes.dtype(a) if dtype is None and promote_integers: - # Note: NumPy always promotes to 64-bit; jax instead promotes to the - # default dtype as defined by dtypes.int_ or dtypes.uint. - if dtypes.issubdtype(result_dtype, np.bool_): - result_dtype = dtypes.int_ - elif dtypes.issubdtype(result_dtype, np.unsignedinteger): - if np.iinfo(result_dtype).bits < np.iinfo(dtypes.uint).bits: - result_dtype = dtypes.uint - elif dtypes.issubdtype(result_dtype, np.integer): - if np.iinfo(result_dtype).bits < np.iinfo(dtypes.int_).bits: - result_dtype = dtypes.int_ + result_dtype = _promote_integer_dtype(result_dtype) result_dtype = dtypes.canonicalize_dtype(result_dtype) @@ -199,15 +204,6 @@ def force(x): force, x, "The axis argument must be known statically.") -# TODO(jakevdp) change promote_integers default to False -_PROMOTE_INTEGERS_DOC = """ -promote_integers : bool, default=True - If True, then integer inputs will be promoted to the widest available integer - dtype, following numpy's behavior. If False, the result will have the same dtype - as the input. ``promote_integers`` is ignored if ``dtype`` is specified. -""" - - @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, @@ -219,10 +215,76 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, initial=initial, where_=where, parallel_reduce=lax.psum, promote_integers=promote_integers) -@implements(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) + def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: + r"""Sum of the elements of the array over a given axis. + + JAX implementation of :func:`numpy.sum`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which the sum to be computed. + If None, the sum is computed along all the axes. + dtype: The type of the output array. Default=None. + out: Unused by JAX + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + initial: int or array, Default=None. Initial value for the sum. + where: int or array, default=None. The elements to be used in the sum. Array + should be broadcast compatible to the input. + promote_integers : bool, default=True. If True, then integer inputs will be + promoted to the widest available integer dtype, following numpy's behavior. + If False, the result will have the same dtype as the input. + ``promote_integers`` is ignored if ``dtype`` is specified. + + Returns: + An array of the sum along the given axis. + + See also: + - :func:`jax.numpy.prod`: Compute the product of array elements over a given + axis. 
+ - :func:`jax.numpy.max`: Compute the maximum of array elements over given axis. + - :func:`jax.numpy.min`: Compute the minimum of array elements over given axis. + + Examples: + + By default, the sum is computed along all the axes. + + >>> x = jnp.array([[1, 3, 4, 2], + ... [5, 2, 6, 3], + ... [8, 1, 3, 9]]) + >>> jnp.sum(x) + Array(47, dtype=int32) + + If ``axis=1``, the sum is computed along axis 1. + + >>> jnp.sum(x, axis=1) + Array([10, 16, 21], dtype=int32) + + If ``keepdims=True``, ``ndim`` of the output is equal to that of the input. + + >>> jnp.sum(x, axis=1, keepdims=True) + Array([[10], + [16], + [21]], dtype=int32) + + To include only specific elements in the sum, you can use ``where``. + + >>> where=jnp.array([[0, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.sum(x, axis=1, keepdims=True, where=where) + Array([[ 4], + [ 9], + [12]], dtype=int32) + >>> where=jnp.array([[False], + ... [False], + ... [False]]) + >>> jnp.sum(x, axis=0, keepdims=True, where=where) + Array([[0, 0, 0, 0]], dtype=int32) + """ return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, promote_integers=promote_integers) @@ -238,11 +300,77 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, promote_integers=promote_integers) -@implements(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) + def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: + r"""Return product of the array elements over a given axis. + + JAX implementation of :func:`numpy.prod`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which the product to be computed. + If None, the product is computed along all the axes. + dtype: The type of the output array. Default=None. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + initial: int or array, Default=None. Initial value for the product. + where: int or array, default=None. The elements to be used in the product. + Array should be broadcast compatible to the input. + promote_integers : bool, default=True. If True, then integer inputs will be + promoted to the widest available integer dtype, following numpy's behavior. + If False, the result will have the same dtype as the input. + ``promote_integers`` is ignored if ``dtype`` is specified. + out: Unused by JAX. + + Returns: + An array of the product along the given axis. + + See also: + - :func:`jax.numpy.sum`: Compute the sum of array elements over a given axis. + - :func:`jax.numpy.max`: Compute the maximum of array elements over given axis. + - :func:`jax.numpy.min`: Compute the minimum of array elements over given axis. + + Examples: + By default, ``jnp.prod`` computes along all the axes. + + >>> x = jnp.array([[1, 3, 4, 2], + ... [5, 2, 1, 3], + ... [2, 1, 3, 1]]) + >>> jnp.prod(x) + Array(4320, dtype=int32) + + If ``axis=1``, product is computed along axis 1. + + >>> jnp.prod(x, axis=1) + Array([24, 30, 6], dtype=int32) + + If ``keepdims=True``, ``ndim`` of the output is equal to that of the input. + + >>> jnp.prod(x, axis=1, keepdims=True) + Array([[24], + [30], + [ 6]], dtype=int32) + + To include only specific elements in the product, you can use ``where``.
+ + >>> where=jnp.array([[1, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.prod(x, axis=1, keepdims=True, where=where) + Array([[4], + [3], + [6]], dtype=int32) + >>> where = jnp.array([[False], + ... [False], + ... [False]]) + >>> jnp.prod(x, axis=1, keepdims=True, where=where) + Array([[1], + [1], + [1]], dtype=int32) + """ return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, promote_integers=promote_integers) @@ -256,10 +384,77 @@ def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmax) -@implements(np.max, skip_params=['out']) + def max(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + r"""Return the maximum of the array elements along a given axis. + + JAX implementation of :func:`numpy.max`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which the maximum to be computed. + If None, the maximum is computed along all the axes. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + initial: int or array, default=None. Initial value for the maximum. + where: int or array of boolean dtype, default=None. The elements to be used + in the maximum. Array should be broadcast compatible to the input. + ``initial`` must be specified when ``where`` is used. + out: Unused by JAX. + + Returns: + An array of maximum values along the given axis. + + See also: + - :func:`jax.numpy.min`: Compute the minimum of array elements along a given + axis. + - :func:`jax.numpy.sum`: Compute the sum of array elements along a given axis. + - :func:`jax.numpy.prod`: Compute the product of array elements along a given + axis. + + Examples: + + By default, ``jnp.max`` computes the maximum of elements along all the axes. + + >>> x = jnp.array([[9, 3, 4, 5], + ... [5, 2, 7, 4], + ... [8, 1, 3, 6]]) + >>> jnp.max(x) + Array(9, dtype=int32) + + If ``axis=1``, the maximum will be computed along axis 1. + + >>> jnp.max(x, axis=1) + Array([9, 7, 8], dtype=int32) + + If ``keepdims=True``, ``ndim`` of the output will be same of that of the input. + + >>> jnp.max(x, axis=1, keepdims=True) + Array([[9], + [7], + [8]], dtype=int32) + + To include only specific elements in computing the maximum, you can use + ``where``. It can either have same dimension as input + + >>> where=jnp.array([[0, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.max(x, axis=1, keepdims=True, initial=0, where=where) + Array([[4], + [7], + [8]], dtype=int32) + + or must be broadcast compatible with input. + + >>> where = jnp.array([[False], + ... [False], + ... [False]]) + >>> jnp.max(x, axis=0, keepdims=True, initial=0, where=where) + Array([[0, 0, 0, 0]], dtype=int32) + """ return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, initial=initial, where=where) @@ -271,10 +466,76 @@ def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmin) -@implements(np.min, skip_params=['out']) + def min(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + r"""Return the minimum of array elements along a given axis. 
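A minimal sketch of the ``initial``/``where`` interaction noted in the ``max``/``min`` descriptions above, with illustrative inputs (the exact error message may differ):

import jax.numpy as jnp

x = jnp.array([[9, 3],
               [5, 2]])
mask = jnp.array([[True, False],
                  [False, True]])

# With `initial`, entries masked out by `where` fall back to the initial value.
print(jnp.max(x, axis=1, initial=0, where=mask))  # [9 2]

# Without `initial`, max has no identity element to start from when `where`
# is given, so the call is rejected.
try:
    jnp.max(x, axis=1, where=mask)
except Exception as err:
    print(type(err).__name__, err)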
+ + JAX implementation of :func:`numpy.min`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which the minimum to be computed. + If None, the minimum is computed along all the axes. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + initial: int or array, Default=None. Initial value for the minimum. + where: int or array, default=None. The elements to be used in the minimum. + Array should be broadcast compatible to the input. ``initial`` must be + specified when ``where`` is used. + out: Unused by JAX. + + Returns: + An array of minimum values along the given axis. + + See also: + - :func:`jax.numpy.max`: Compute the maximum of array elements along a given + axis. + - :func:`jax.numpy.sum`: Compute the sum of array elements along a given axis. + - :func:`jax.numpy.prod`: Compute the product of array elements along a given + axis. + + Examples: + By default, the minimum is computed along all the axes. + + >>> x = jnp.array([[2, 5, 1, 6], + ... [3, -7, -2, 4], + ... [8, -4, 1, -3]]) + >>> jnp.min(x) + Array(-7, dtype=int32) + + If ``axis=1``, the minimum is computed along axis 1. + + >>> jnp.min(x, axis=1) + Array([ 1, -7, -4], dtype=int32) + + If ``keepdims=True``, ``ndim`` of the output will be same of that of the input. + + >>> jnp.min(x, axis=1, keepdims=True) + Array([[ 1], + [-7], + [-4]], dtype=int32) + + To include only specific elements in computing the minimum, you can use + ``where``. ``where`` can either have same dimension as input. + + >>> where=jnp.array([[1, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.min(x, axis=1, keepdims=True, initial=0, where=where) + Array([[ 0], + [-2], + [-4]], dtype=int32) + + or must be broadcast compatible with input. + + >>> where = jnp.array([[False], + ... [False], + ... [False]]) + >>> jnp.min(x, axis=0, keepdims=True, initial=0, where=where) + Array([[0, 0, 0, 0]], dtype=int32) + """ return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, initial=initial, where=where) @@ -284,9 +545,53 @@ def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None, return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) -@implements(np.all, skip_params=['out']) + def all(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: + r"""Test whether all array elements along a given axis evaluate to True. + + JAX implementation of :func:`numpy.all`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which to be tested. If None, + tests along all the axes. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + where: int or array of boolean dtype, default=None. The elements to be used + in the test. Array should be broadcast compatible to the input. + out: Unused by JAX. + + Returns: + An array of boolean values. + + Examples: + By default, ``jnp.all`` tests for True values along all the axes. + + >>> x = jnp.array([[True, True, True, False], + ... [True, False, True, False], + ... [True, True, False, False]]) + >>> jnp.all(x) + Array(False, dtype=bool) + + If ``axis=0``, tests for True values along axis 0. + + >>> jnp.all(x, axis=0) + Array([ True, False, False, False], dtype=bool) + + If ``keepdims=True``, ``ndim`` of the output will be same of that of the input. 
+ + >>> jnp.all(x, axis=0, keepdims=True) + Array([[ True, False, False, False]], dtype=bool) + + To include specific elements in testing for True values, you can use ``where``. + + >>> where=jnp.array([[1, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.all(x, axis=0, keepdims=True, where=where) + Array([[ True, True, False, False]], dtype=bool) + """ return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) @@ -296,14 +601,69 @@ def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None, return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) -@implements(np.any, skip_params=['out']) + def any(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: + r"""Test whether any of the array elements along a given axis evaluate to True. + + JAX implementation of :func:`numpy.any`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which to be tested. If None, + tests along all the axes. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + where: int or array of boolean dtype, default=None. The elements to be used + in the test. Array should be broadcast compatible to the input. + out: Unused by JAX. + + Returns: + An array of boolean values. + + Examples: + By default, ``jnp.any`` tests along all the axes. + + >>> x = jnp.array([[True, True, True, False], + ... [True, False, True, False], + ... [True, True, False, False]]) + >>> jnp.any(x) + Array(True, dtype=bool) + + If ``axis=0``, tests along axis 0. + + >>> jnp.any(x, axis=0) + Array([ True, True, True, False], dtype=bool) + + If ``keepdims=True``, ``ndim`` of the output will be the same as that of the input. + + >>> jnp.any(x, axis=0, keepdims=True) + Array([[ True, True, True, False]], dtype=bool) + + To include specific elements in testing for True values, you can use ``where``. + + >>> where=jnp.array([[1, 0, 1, 0], + ... [0, 1, 0, 1], + ... [1, 0, 1, 0]], dtype=bool) + >>> jnp.any(x, axis=0, keepdims=True, where=where) + Array([[ True, False, True, False]], dtype=bool) + """ return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) -amin = min -amax = max +def amin(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: + """Alias of :func:`jax.numpy.min`.""" + return min(a, axis=axis, out=out, keepdims=keepdims, + initial=initial, where=where) + +def amax(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: + """Alias of :func:`jax.numpy.max`.""" + return max(a, axis=axis, out=out, keepdims=keepdims, + initial=initial, where=where) def _axis_size(a: ArrayLike, axis: int | Sequence[int]): if not isinstance(axis, (tuple, list)): @@ -316,10 +676,65 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]): size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name)) return size -@implements(np.mean, skip_params=['out']) + def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: + r"""Return the mean of array elements along a given axis. + + JAX implementation of :func:`numpy.mean`.
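As a complement to the full-shape masks in the examples that follow, a small sketch of a broadcastable ``where`` mask (the arrays here are illustrative):

import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [5, 2, 6, 3],
               [8, 1, 2, 9]])

# `where` only has to be broadcast-compatible with `x`:
# a (3, 1) mask keeps whole rows.
row_mask = jnp.array([[True], [False], [True]])
print(jnp.mean(x, where=row_mask))  # 3.75, i.e. (1+3+4+2 + 8+1+2+9) / 8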
+ + Args: + a: input array. + axis: optional, int or sequence of ints, default=None. Axis along which the + mean to be computed. If None, mean is computed along all the axes. + dtype: The type of the output array. Default=None. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + where: optional, boolean array, default=None. The elements to be used in the + mean. Array should be broadcast compatible to the input. + out: Unused by JAX. + + Returns: + An array of the mean along the given axis. + + See also: + - :func:`jax.numpy.sum`: Compute the sum of array elements over a given axis. + - :func:`jax.numpy.max`: Compute the maximum of array elements over given axis. + - :func:`jax.numpy.min`: Compute the minimum of array elements over given axis. + + Examples: + By default, the mean is computed along all the axes. + + >>> x = jnp.array([[1, 3, 4, 2], + ... [5, 2, 6, 3], + ... [8, 1, 2, 9]]) + >>> jnp.mean(x) + Array(3.8333335, dtype=float32) + + If ``axis=1``, the mean is computed along axis 1. + + >>> jnp.mean(x, axis=1) + Array([2.5, 4. , 5. ], dtype=float32) + + If ``keepdims=True``, ``ndim`` of the output is equal to that of the input. + + >>> jnp.mean(x, axis=1, keepdims=True) + Array([[2.5], + [4. ], + [5. ]], dtype=float32) + + To use only specific elements of ``x`` to compute the mean, you can use + ``where``. + + >>> where = jnp.array([[1, 0, 1, 0], + ... [0, 1, 0, 1], + ... [1, 1, 0, 1]], dtype=bool) + >>> jnp.mean(x, axis=1, keepdims=True, where=where) + Array([[2.5], + [2.5], + [6. ]], dtype=float32) + """ return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims, where=where) @@ -425,16 +840,92 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, return avg -@implements(np.var, skip_params=['out']) def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: ArrayLike | None = None) -> Array: - return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, + where: ArrayLike | None = None, correction: int | float | None = None) -> Array: + r"""Compute the variance along a given axis. + + JAX implementation of :func:`numpy.var`. + + Args: + a: input array. + axis: optional, int or sequence of ints, default=None. Axis along which the + variance is computed. If None, variance is computed along all the axes. + dtype: The type of the output array. Default=None. + ddof: int, default=0. Degrees of freedom. The divisor in the variance computation + is ``N-ddof``, ``N`` is number of elements along given axis. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + where: optional, boolean array, default=None. The elements to be used in the + variance. Array should be broadcast compatible to the input. + correction: int or float, default=None. Alternative name for ``ddof``. + Both ddof and correction can't be provided simultaneously. + out: Unused by JAX. + + Returns: + An array of the variance along the given axis. + + See also: + - :func:`jax.numpy.mean`: Compute the mean of array elements over a given axis. + - :func:`jax.numpy.std`: Compute the standard deviation of array elements over + given axis. + - :func:`jax.numpy.nanvar`: Compute the variance along a given axis, ignoring + NaNs values. + - :func:`jax.numpy.nanstd`: Computed the standard deviation of a given axis, + ignoring NaN values. + + Examples: + By default, ``jnp.var`` computes the variance along all axes. 
+ + >>> x = jnp.array([[1, 3, 4, 2], + ... [5, 2, 6, 3], + ... [8, 4, 2, 9]]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.var(x) + Array(5.74, dtype=float32) + + If ``axis=1``, variance is computed along axis 1. + + >>> jnp.var(x, axis=1) + Array([1.25 , 2.5 , 8.1875], dtype=float32) + + To preserve the dimensions of input, you can set ``keepdims=True``. + + >>> jnp.var(x, axis=1, keepdims=True) + Array([[1.25 ], + [2.5 ], + [8.1875]], dtype=float32) + + If ``ddof=1``: + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.var(x, axis=1, keepdims=True, ddof=1)) + [[ 1.67] + [ 3.33] + [10.92]] + + To include specific elements of the array to compute variance, you can use + ``where``. + + >>> where = jnp.array([[1, 0, 1, 0], + ... [0, 1, 1, 0], + ... [1, 1, 1, 0]], dtype=bool) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.var(x, axis=1, keepdims=True, where=where)) + [[2.25] + [4. ] + [6.22]] + """ + if correction is None: + correction = ddof + elif not isinstance(ddof, int) or ddof != 0: + raise ValueError("ddof and correction can't be provided simultaneously.") + return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, - out: None = None, ddof: int = 0, keepdims: bool = False, *, + out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: check_arraylike("var", a) dtypes.check_user_dtype_supported(dtype, "var") @@ -460,9 +951,9 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, else: normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=computation_dtype, keepdims=keepdims) - normalizer = lax.sub(normalizer, lax.convert_element_type(ddof, computation_dtype)) + normalizer = lax.sub(normalizer, lax.convert_element_type(correction, computation_dtype)) result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where) - return lax.div(result, normalizer).astype(dtype) + return _where(normalizer > 0, lax.div(result, normalizer).astype(dtype), np.nan) def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DType, DType]: @@ -486,16 +977,88 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy return _upcast_f16(computation_dtype), np.dtype(dtype) -@implements(np.std, skip_params=['out']) def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: ArrayLike | None = None) -> Array: - return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, + where: ArrayLike | None = None, correction: int | float | None = None) -> Array: + r"""Compute the standard deviation along a given axis. + + JAX implementation of :func:`numpy.std`. + + Args: + a: input array. + axis: optional, int or sequence of ints, default=None. Axis along which the + standard deviation is computed. If None, standard deviaiton is computed + along all the axes. + dtype: The type of the output array. Default=None. + ddof: int, default=0. Degrees of freedom. The divisor in the standard deviation + computation is ``N-ddof``, ``N`` is number of elements along given axis. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + where: optional, boolean array, default=None. 
The elements to be used in the + standard deviation. Array should be broadcast compatible to the input. + correction: int or float, default=None. Alternative name for ``ddof``. + Both ddof and correction can't be provided simultaneously. + out: Unused by JAX. + + Returns: + An array of the standard deviation along the given axis. + + See also: + - :func:`jax.numpy.var`: Compute the variance of array elements over given + axis. + - :func:`jax.numpy.mean`: Compute the mean of array elements over a given axis. + - :func:`jax.numpy.nanvar`: Compute the variance along a given axis, ignoring + NaNs values. + - :func:`jax.numpy.nanstd`: Computed the standard deviation of a given axis, + ignoring NaN values. + + Examples: + By default, ``jnp.std`` computes the standard deviation along all axes. + + >>> x = jnp.array([[1, 3, 4, 2], + ... [4, 2, 5, 3], + ... [5, 4, 2, 3]]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.std(x) + Array(1.21, dtype=float32) + + If ``axis=0``, computes along axis 0. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.std(x, axis=0)) + [1.7 0.82 1.25 0.47] + + To preserve the dimensions of input, you can set ``keepdims=True``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.std(x, axis=0, keepdims=True)) + [[1.7 0.82 1.25 0.47]] + + If ``ddof=1``: + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.std(x, axis=0, keepdims=True, ddof=1)) + [[2.08 1. 1.53 0.58]] + + To include specific elements of the array to compute standard deviation, you + can use ``where``. + + >>> where = jnp.array([[1, 0, 1, 0], + ... [0, 1, 0, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.std(x, axis=0, keepdims=True, where=where) + Array([[2., 1., 1., 0.]], dtype=float32) + """ + if correction is None: + correction = ddof + elif not isinstance(ddof, int) or ddof != 0: + raise ValueError("ddof and correction can't be provided simultaneously.") + return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, - out: None = None, ddof: int = 0, keepdims: bool = False, *, + out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: check_arraylike("std", a) dtypes.check_user_dtype_supported(dtype, "std") @@ -503,7 +1066,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}") if out is not None: raise NotImplementedError("The 'out' argument to jnp.std is not supported.") - return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where)) + return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where)) @implements(np.ptp, skip_params=['out']) @@ -663,7 +1226,8 @@ def __call__(self, a: ArrayLike, axis: Axis = None, """ def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array], - fill_nan: bool = False, fill_value: ArrayLike = 0) -> CumulativeReduction: + fill_nan: bool = False, fill_value: ArrayLike = 0, + promote_integers: bool = False) -> CumulativeReduction: @implements(np_reduction, skip_params=['out'], lax_description=CUML_REDUCTION_LAX_DESCRIPTION) def cumulative_reduction(a: ArrayLike, axis: Axis = None, @@ -691,12 +1255,18 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None, if fill_nan: a = 
_where(lax_internal._isnan(a), _lax_const(a, fill_value), a) - if not dtype and dtypes.dtype(a) == np.bool_: - dtype = dtypes.canonicalize_dtype(dtypes.int_) - if dtype: - a = lax.convert_element_type(a, dtype) + result_type: DTypeLike = dtypes.dtype(dtype or a) + if dtype is None and promote_integers or dtypes.issubdtype(result_type, np.bool_): + result_type = _promote_integer_dtype(result_type) + result_type = dtypes.canonicalize_dtype(result_type) - return reduction(a, axis) + a = lax.convert_element_type(a, result_type) + result = reduction(a, axis) + + # We downcast to boolean because we accumulate in integer types + if dtypes.issubdtype(dtype, np.bool_): + result = lax.convert_element_type(result, np.bool_) + return result return cumulative_reduction @@ -707,50 +1277,88 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None, fill_nan=True, fill_value=0) nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod, fill_nan=True, fill_value=1) +_cumsum_with_promotion = _make_cumulative_reduction( + np.cumsum, lax.cumsum, fill_nan=False, promote_integers=True +) + +@implements(getattr(np, 'cumulative_sum', None)) +def cumulative_sum( + x: ArrayLike, /, *, axis: int | None = None, + dtype: DTypeLike | None = None, + include_initial: bool = False) -> Array: + check_arraylike("cumulative_sum", x) + x = lax_internal.asarray(x) + if x.ndim == 0: + raise ValueError( + "The input must be non-scalar to take a cumulative sum, however a " + "scalar value or scalar array was given." + ) + if axis is None: + axis = 0 + if x.ndim > 1: + raise ValueError( + f"The input array has rank {x.ndim}, however axis was not set to an " + "explicit value. The axis argument is only optional for one-dimensional " + "arrays.") + + axis = _canonicalize_axis(axis, x.ndim) + dtypes.check_user_dtype_supported(dtype) + out = _cumsum_with_promotion(x, axis=axis, dtype=dtype) + if include_initial: + zeros_shape = list(x.shape) + zeros_shape[axis] = 1 + out = lax_internal.concatenate( + [lax_internal.full(zeros_shape, 0, dtype=out.dtype), out], + dimension=axis) + return out # Quantiles + +# TODO(jakevdp): interpolation argument deprecated 2024-05-16 @implements(np.quantile, skip_params=['out', 'overwrite_input']) -@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', - 'keepdims', 'method')) +@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, interpolation: None = None) -> Array: + keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: check_arraylike("quantile", a, q) if overwrite_input or out is not None: msg = ("jax.numpy.quantile does not support overwrite_input=True or " "out != None") raise ValueError(msg) - if interpolation is not None: + if not isinstance(interpolation, DeprecatedArg): warnings.warn("The interpolation= argument to 'quantile' is deprecated. 
" - "Use 'method=' instead.", DeprecationWarning) - return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, False) + "Use 'method=' instead.", DeprecationWarning, stacklevel=2) + method = interpolation + return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False) +# TODO(jakevdp): interpolation argument deprecated 2024-05-16 @implements(np.nanquantile, skip_params=['out', 'overwrite_input']) -@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', - 'keepdims', 'method')) +@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, interpolation: None = None) -> Array: + keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: check_arraylike("nanquantile", a, q) if overwrite_input or out is not None: msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " "out != None") raise ValueError(msg) - if interpolation is not None: + if not isinstance(interpolation, DeprecatedArg): warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. " - "Use 'method=' instead.", DeprecationWarning) - return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, True) + "Use 'method=' instead.", DeprecationWarning, stacklevel=2) + method = interpolation + return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True) def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, - interpolation: str, keepdims: bool, squash_nans: bool) -> Array: - if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]: - raise ValueError("interpolation can only be 'linear', 'lower', 'higher', " - "'midpoint', or 'nearest'") + method: str, keepdims: bool, squash_nans: bool) -> Array: + if method not in ["linear", "lower", "higher", "midpoint", "nearest"]: + raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'") a, = promote_dtypes_inexact(a) keepdim = [] if dtypes.issubdtype(a.dtype, np.complexfloating): raise ValueError("quantile does not support complex input, as the operation is poorly defined.") if axis is None: + if keepdims: + keepdim = [1] * a.ndim a = a.ravel() axis = 0 elif isinstance(axis, tuple): @@ -842,50 +1450,57 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] 
| None, high_weight = lax.broadcast_in_dim(high_weight, high_value.shape, broadcast_dimensions=(0,)) - if interpolation == "linear": + if method == "linear": result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight), lax.mul(high_value.astype(q.dtype), high_weight)) - elif interpolation == "lower": + elif method == "lower": result = low_value - elif interpolation == "higher": + elif method == "higher": result = high_value - elif interpolation == "nearest": + elif method == "nearest": pred = lax.le(high_weight, _lax_const(high_weight, 0.5)) result = lax.select(pred, low_value, high_value) - elif interpolation == "midpoint": + elif method == "midpoint": result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5)) else: - raise ValueError(f"interpolation={interpolation!r} not recognized") + raise ValueError(f"{method=!r} not recognized") if keepdims and keepdim: if q_ndim > 0: keepdim = [np.shape(q)[0], *keepdim] result = result.reshape(keepdim) return lax.convert_element_type(result, a.dtype) +# TODO(jakevdp): interpolation argument deprecated 2024-05-16 @implements(np.percentile, skip_params=['out', 'overwrite_input']) -@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', - 'keepdims', 'method')) +@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def percentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, interpolation: None = None) -> Array: + keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: check_arraylike("percentile", a, q) q, = promote_dtypes_inexact(q) + if not isinstance(interpolation, DeprecatedArg): + warnings.warn("The interpolation= argument to 'percentile' is deprecated. " + "Use 'method=' instead.", DeprecationWarning, stacklevel=2) + method = interpolation return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, - interpolation=interpolation, method=method, keepdims=keepdims) + method=method, keepdims=keepdims) +# TODO(jakevdp): interpolation argument deprecated 2024-05-16 @implements(np.nanpercentile, skip_params=['out', 'overwrite_input']) -@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', - 'keepdims', 'method')) +@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, interpolation: None = None) -> Array: + keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: check_arraylike("nanpercentile", a, q) q = ufuncs.true_divide(q, 100.0) + if not isinstance(interpolation, DeprecatedArg): + warnings.warn("The interpolation= argument to 'nanpercentile' is deprecated. 
" + "Use 'method=' instead.", DeprecationWarning, stacklevel=2) + method = interpolation return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, - interpolation=interpolation, method=method, - keepdims=keepdims) + method=method, keepdims=keepdims) @implements(np.median, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 15dc52cda55c..6ac7ce804d8f 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -17,7 +17,6 @@ from functools import partial import math import operator -from textwrap import dedent as _dedent from typing import cast, NamedTuple import numpy as np @@ -34,7 +33,7 @@ sort, where, zeros) from jax._src.numpy.reductions import any, cumsum from jax._src.numpy.ufuncs import isnan -from jax._src.numpy.util import check_arraylike, implements +from jax._src.numpy.util import check_arraylike from jax._src.util import canonicalize_axis from jax._src.typing import Array, ArrayLike @@ -61,21 +60,73 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array: else: return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1) -@implements(np.setdiff1d, - lax_description=_dedent(""" - Because the size of the output of ``setdiff1d`` is data-dependent, the function is not - typically compatible with JIT. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.setdiff1d`` to be used within some of JAX's - transformations."""), - extra_params=_dedent(""" - size : int, optional - If specified, the first ``size`` elements of the result will be returned. If there are - fewer elements than ``size`` indicates, the return value will be padded with ``fill_value``. - fill_value : array_like, optional - When ``size`` is specified and there are fewer than the indicated number of elements, the - remaining elements will be filled with ``fill_value``, which defaults to zero.""")) + def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: + """Compute the set difference of two 1D arrays. + + JAX implementation of :func:`numpy.setdiff1d`. + + Because the size of the output of ``setdiff1d`` is data-dependent, the function + semantics are not typically compatible with :func:`~jax.jit` and other JAX + transformations. The JAX version adds the optional ``size`` argument which + must be specified statically for ``jnp.setdiff1d`` to be used in such contexts. + transformations. + + Args: + ar1: first array of elements to be differenced. + ar2: second array of elements to be differenced. + assume_unique: if True, assume the input arrays contain unique values. This allows + a more efficient implementation, but if ``assume_unique`` is True and the input + arrays contain duplicates, the behavior is undefined. default: False. + size: if specified, return only the first ``size`` sorted elements. If there are fewer + elements than ``size`` indicates, the return value will be padded with ``fill_value``. + fill_value: when ``size`` is specified and there are fewer than the indicated number of + elements, fill the remaining entries ``fill_value``. Defaults to the minimum value. + + Returns: + an array containing the set difference of elements in the input array: i.e. the elements + in ``ar1`` that are not contained in ``ar2``. + + See also: + - :func:`jax.numpy.intersect1d`: the set intersection of two 1D arrays. 
+ - :func:`jax.numpy.setxor1d`: the set XOR of two 1D arrays. + - :func:`jax.numpy.union1d`: the set union of two 1D arrays. + + Examples: + Computing the set difference of two arrays: + + >>> ar1 = jnp.array([1, 2, 3, 4]) + >>> ar2 = jnp.array([3, 4, 5, 6]) + >>> jnp.setdiff1d(ar1, ar2) + Array([1, 2], dtype=int32) + + Because the output shape is dynamic, this will fail under :func:`~jax.jit` and other + transformations: + + >>> jax.jit(jnp.setdiff1d)(ar1, ar2) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. + The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. + + In order to ensure statically-known output shapes, you can pass a static ``size`` + argument: + + >>> jit_setdiff1d = jax.jit(jnp.setdiff1d, static_argnames=['size']) + >>> jit_setdiff1d(ar1, ar2, size=2) + Array([1, 2], dtype=int32) + + If ``size`` is too small, the difference is truncated: + + >>> jit_setdiff1d(ar1, ar2, size=1) + Array([1], dtype=int32) + + If ``size`` is too large, then the output is padded with ``fill_value``: + + >>> jit_setdiff1d(ar1, ar2, size=4, fill_value=0) + Array([1, 2, 0, 0], dtype=int32) + """ check_arraylike("setdiff1d", ar1, ar2) if size is None: ar1 = core.concrete_or_error(None, ar1, "The error arose in setdiff1d()") @@ -98,22 +149,68 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value) -@implements(np.union1d, - lax_description=_dedent(""" - Because the size of the output of ``union1d`` is data-dependent, the function is not - typically compatible with JIT. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.union1d`` to be used within some of JAX's - transformations."""), - extra_params=_dedent(""" - size : int, optional - If specified, the first ``size`` elements of the result will be returned. If there are - fewer elements than ``size`` indicates, the return value will be padded with ``fill_value``. - fill_value : array_like, optional - When ``size`` is specified and there are fewer than the indicated number of elements, the - remaining elements will be filled with ``fill_value``, which defaults to the minimum - value of the union.""")) def union1d(ar1: ArrayLike, ar2: ArrayLike, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: + """Compute the set union of two 1D arrays. + + JAX implementation of :func:`numpy.union1d`. + + Because the size of the output of ``union1d`` is data-dependent, the function + semantics are not typically compatible with :func:`~jax.jit` and other JAX + transformations. The JAX version adds the optional ``size`` argument which + must be specified statically for ``jnp.union1d`` to be used in such contexts. + + Args: + ar1: first array of elements to be unioned. + ar2: second array of elements to be unioned. + size: if specified, return only the first ``size`` sorted elements. If there are fewer + elements than ``size`` indicates, the return value will be padded with ``fill_value``. + fill_value: when ``size`` is specified and there are fewer than the indicated number of + elements, fill the remaining entries ``fill_value``.
Defaults to the minimum value. + + Returns: + an array containing the union of elements in the input array. + + See also: + - :func:`jax.numpy.intersect1d`: the set intersection of two 1D arrays. + - :func:`jax.numpy.setxor1d`: the set XOR of two 1D arrays. + - :func:`jax.numpy.setdiff1d`: the set difference of two 1D arrays. + + Examples: + Computing the union of two arrays: + + >>> ar1 = jnp.array([1, 2, 3, 4]) + >>> ar2 = jnp.array([3, 4, 5, 6]) + >>> jnp.union1d(ar1, ar2) + Array([1, 2, 3, 4, 5, 6], dtype=int32) + + Because the output shape is dynamic, this will fail under :func:`~jax.jit` and other + transformations: + + >>> jax.jit(jnp.union1d)(ar1, ar2) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. + The error occurred while tracing the function union1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. + + In order to ensure statically-known output shapes, you can pass a static ``size`` + argument: + + >>> jit_union1d = jax.jit(jnp.union1d, static_argnames=['size']) + >>> jit_union1d(ar1, ar2, size=6) + Array([1, 2, 3, 4, 5, 6], dtype=int32) + + If ``size`` is too small, the union is truncated: + + >>> jit_union1d(ar1, ar2, size=4) + Array([1, 2, 3, 4], dtype=int32) + + If ``size`` is too large, then the output is padded with ``fill_value``: + + >>> jit_union1d(ar1, ar2, size=8, fill_value=0) + Array([1, 2, 3, 4, 5, 6, 0, 0], dtype=int32) + """ check_arraylike("union1d", ar1, ar2) if size is None: ar1 = core.concrete_or_error(None, ar1, "The error arose in union1d()") @@ -125,11 +222,35 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike, return cast(Array, out) -@implements(np.setxor1d, lax_description=""" -In the JAX version, the input arrays are explicitly flattened regardless -of assume_unique value. -""") def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Array: + """Compute the set-wise xor of elements in two arrays. + + JAX implementation of :func:`numpy.setxor1d`. + + Because the size of the output of ``setxor1d`` is data-dependent, the function is not + compatible with JIT or other JAX transformations. + + Args: + ar1: first array of values to intersect. + ar2: second array of values to intersect. + assume_unique: if True, assume the input arrays contain unique values. This allows + a more efficient implementation, but if ``assume_unique`` is True and the input + arrays contain duplicates, the behavior is undefined. default: False. + + Returns: + An array of values that are found in exactly one of the input arrays. + + See also: + - :func:`jax.numpy.intersect1d`: the set intersection of two 1D arrays. + - :func:`jax.numpy.union1d`: the set union of two 1D arrays. + - :func:`jax.numpy.setdiff1d`: the set difference of two 1D arrays. 
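The set-xor operation documented here can be read as the union of the two one-sided differences; a short sketch with illustrative inputs:

import jax.numpy as jnp

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])

xor = jnp.setxor1d(ar1, ar2)                       # values found in exactly one input
both_diffs = jnp.union1d(jnp.setdiff1d(ar1, ar2),  # ar1 minus ar2 -> [1, 2]
                         jnp.setdiff1d(ar2, ar1))  # ar2 minus ar1 -> [5, 6]
print(xor)                               # [1 2 5 6]
print(jnp.array_equal(xor, both_diffs))  # True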
+ + Examples: + >>> ar1 = jnp.array([1, 2, 3, 4]) + >>> ar2 = jnp.array([3, 4, 5, 6]) + >>> jnp.setxor1d(ar1, ar2) + Array([1, 2, 5, 6], dtype=int32) + """ check_arraylike("setxor1d", ar1, ar2) ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()") ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()") @@ -152,9 +273,7 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr @partial(jit, static_argnames=['return_indices']) def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: bool = False) -> tuple[Array, ...]: - """ - Helper function for intersect1d which is jit-able - """ + # JIT-compatible helper function for intersect1d ar = concatenate((ar1, ar2)) if return_indices: iota = lax.broadcasted_iota(np.int64, np.shape(ar), dimension=0) @@ -169,9 +288,70 @@ def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: boo return aux, mask -@implements(np.intersect1d) def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return_indices: bool = False) -> Array | tuple[Array, Array, Array]: + """Compute the set intersection of two 1D arrays. + + JAX implementation of :func:`numpy.intersect1d`. + + Because the size of the output of ``intersect1d`` is data-dependent, the function is not + compatible with JIT or other JAX transformations. + + Args: + ar1: first array of values to intersect. + ar2: second array of values to intersect. + assume_unique: if True, assume the input arrays contain unique values. This allows + a more efficient implementation, but if ``assume_unique`` is True and the input + arrays contain duplicates, the behavior is undefined. default: False. + return_indices: If True, return arrays of indices specifying where the intersected + values first appear in the input arrays. + + Returns: + An array ``intersection``, or if ``return_indices=True``, a tuple of arrays + ``(intersection, ar1_indices, ar2_indices)``. Returned values are + + - ``intersection``: + A 1D array containing each value that appears in both ``ar1`` and ``ar2``. + - ``ar1_indices``: + *(returned if return_indices=True)* an array of shape ``intersection.shape`` containing + the indices in flattened ``ar1`` of values in ``intersection``. For 1D inputs, + ``intersection`` is equivalent to ``ar1[ar1_indices]``. + - ``ar2_indices``: + *(returned if return_indices=True)* an array of shape ``intersection.shape`` containing + the indices in flattened ``ar2`` of values in ``intersection``. For 1D inputs, + ``intersection`` is equivalent to ``ar2[ar2_indices]``. + + See also: + - :func:`jax.numpy.union1d`: the set union of two 1D arrays. + - :func:`jax.numpy.setxor1d`: the set XOR of two 1D arrays. + - :func:`jax.numpy.setdiff1d`: the set difference of two 1D arrays. 
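A hedged sketch of the ``return_indices`` invariants when the inputs contain duplicates (the arrays are illustrative; with the default ``assume_unique=False`` the indices refer to first occurrences):

import jax.numpy as jnp

ar1 = jnp.array([4, 1, 3, 3])
ar2 = jnp.array([3, 5, 4])

vals, idx1, idx2 = jnp.intersect1d(ar1, ar2, return_indices=True)
print(vals)                              # [3 4]
print(jnp.array_equal(vals, ar1[idx1]))  # True
print(jnp.array_equal(vals, ar2[idx2]))  # True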
+ + Examples: + >>> ar1 = jnp.array([1, 2, 3, 4]) + >>> ar2 = jnp.array([3, 4, 5, 6]) + >>> jnp.intersect1d(ar1, ar2) + Array([3, 4], dtype=int32) + + Computing intersection with indices: + + >>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True) + >>> intersection + Array([3, 4], dtype=int32) + + ``ar1_indices`` gives the indices of the intersected values within ``ar1``: + + >>> ar1_indices + Array([2, 3], dtype=int32) + >>> jnp.all(intersection == ar1[ar1_indices]) + Array(True, dtype=bool) + + ``ar2_indices`` gives the indices of the intersected values within ``ar2``: + + >>> ar2_indices + Array([0, 1], dtype=int32) + >>> jnp.all(intersection == ar2[ar2_indices]) + Array(True, dtype=bool) + """ check_arraylike("intersect1d", ar1, ar2) ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()") ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()") @@ -206,11 +386,29 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return int1d -@implements(np.isin, lax_description=""" -In the JAX version, the `assume_unique` argument is not referenced. -""") def isin(element: ArrayLike, test_elements: ArrayLike, assume_unique: bool = False, invert: bool = False) -> Array: + """Determine whether elements in ``element`` appear in ``test_elements``. + + JAX implementation of :func:`numpy.isin`. + + Args: + element: input array of elements for which membership will be checked. + test_elements: N-dimensional array of test values to check for the presence of + each element. + invert: If True, return ``~isin(element, test_elements)``. Default is False. + assume_unique: unused by JAX + + Returns: + A boolean array of shape ``element.shape`` that specifies whether each element + appears in ``test_elements``. + + Examples: + >>> elements = jnp.array([1, 2, 3, 4]) + >>> test_elements = jnp.array([[1, 5, 6, 3, 7, 1]]) + >>> jnp.isin(elements, test_elements) + Array([ True, False, True, False], dtype=bool) + """ del assume_unique # unused check_arraylike("isin", element, test_elements) result = _in1d(element, test_elements, invert=invert) @@ -312,23 +510,176 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo ret += (mask.sum(),) return ret[0] if len(ret) == 1 else ret -@implements(np.unique, skip_params=['axis'], - lax_description=_dedent(""" - Because the size of the output of ``unique`` is data-dependent, the function is not - typically compatible with JIT. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.unique`` to be used within some of JAX's - transformations."""), - extra_params=_dedent(""" - size : int, optional - If specified, the first ``size`` unique elements will be returned. If there are fewer unique - elements than ``size`` indicates, the return value will be padded with ``fill_value``. - fill_value : array_like, optional - When ``size`` is specified and there are fewer than the indicated number of elements, the - remaining elements will be filled with ``fill_value``. The default is the minimum value - along the specified axis of the input.""")) + def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False, return_counts: bool = False, axis: int | None = None, *, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None): + """Return the unique values from an array. + + JAX implementation of :func:`numpy.unique`. 
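A short sketch of the ``equal_nan`` behavior documented below, using an illustrative float array:

import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 2.0, jnp.nan])

# By default NaNs compare equal to one another, so a single NaN is kept.
print(jnp.unique(x))                   # keeps one NaN:   [1., 2., nan]
# With equal_nan=False, every NaN is treated as a distinct value.
print(jnp.unique(x, equal_nan=False))  # keeps both NaNs: [1., 2., nan, nan]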
+ + Because the size of the output of ``unique`` is data-dependent, the function + semantics are not typically compatible with :func:`~jax.jit` and other JAX + transformations. The JAX version adds the optional ``size`` argument which + must be specified statically for ``jnp.unique`` to be used in such contexts. + + Args: + ar: N-dimensional array from which unique values will be extracted. + return_index: if True, also return the indices in ``ar`` where each value occurs + return_inverse: if True, also return the indices that can be used to reconstruct + ``ar`` from the unique values. + return_counts: if True, also return the number of occurrences of each unique value. + axis: if specified, compute unique values along the specified axis. If None (default), + then flatten ``ar`` before computing the unique values. + equal_nan: if True, consider NaN values equivalent when determining uniqueness. + size: if specified, return only the first ``size`` sorted unique elements. If there are fewer + unique elements than ``size`` indicates, the return value will be padded with ``fill_value``. + fill_value: when ``size`` is specified and there are fewer than the indicated number of + elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value. + + Returns: + An array or tuple of arrays, depending on the values of ``return_index``, ``return_inverse``, + and ``return_counts``. Returned values are + + - ``unique_values``: + if ``axis`` is None, a 1D array of length ``n_unique``, If ``axis`` is + specified, shape is ``(*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])``. + - ``unique_index``: + *(returned only if return_index is True)* An array of shape ``(n_unique,)``. Contains + the indices of the first occurrence of each unique value in ``ar``. For 1D inputs, + ``ar[unique_index]`` is equivalent to ``unique_values``. + - ``unique_inverse``: + *(returned only if return_inverse is True)* An array of shape ``(ar.size,)`` if ``axis`` + is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified. + Contains the indices within ``unique_values`` of each value in ``ar``. For 1D inputs, + ``unique_values[unique_inverse]`` is equivalent to ``ar``. + - ``unique_counts``: + *(returned only if return_counts is True)* An array of shape ``(n_unique,)``. + Contains the number of occurrences of each unique value in ``ar``. + + See also: + - :func:`jax.numpy.unique_counts`: shortcut to ``unique(arr, return_counts=True)``. + - :func:`jax.numpy.unique_inverse`: shortcut to ``unique(arr, return_inverse=True)``. + - :func:`jax.numpy.unique_all`: shortcut to ``unique`` with all return values. + - :func:`jax.numpy.unique_values`: like ``unique``, but no optional return values. + + Examples: + >>> x = jnp.array([3, 4, 1, 3, 1]) + >>> jnp.unique(x) + Array([1, 3, 4], dtype=int32) + + **JIT compilation & the size argument** + + If you try this under :func:`~jax.jit` or another transformation, you will get an + error because the output shape is dynamic: + + >>> jax.jit(jnp.unique)(x) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[5]. + The error arose for the first argument of jnp.unique(). To make jnp.unique() compatible with JIT and other transforms, you can specify a concrete value for the size argument, which will determine the output size. 
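As an alternative sketch to the ``static_argnames`` pattern shown next, ``size`` can also be baked in as a Python constant with :func:`functools.partial`:

from functools import partial
import jax
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])

# Closing over `size` makes the output shape static, so the call can be jitted.
unique3 = jax.jit(partial(jnp.unique, size=3))
print(unique3(x))  # [1 3 4]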
+ + The issue is that the output of transformed functions must have static shapes. + In order to make this work, you can pass a static ``size`` parameter: + + >>> jit_unique = jax.jit(jnp.unique, static_argnames=['size']) + >>> jit_unique(x, size=3) + Array([1, 3, 4], dtype=int32) + + If your static size is smaller than the true number of unique values, they will be truncated. + + >>> jit_unique(x, size=2) + Array([1, 3], dtype=int32) + + If the static size is larger than the true number of unique values, they will be padded with + ``fill_value``, which defaults to the minimum unique value: + + >>> jit_unique(x, size=5) + Array([1, 3, 4, 1, 1], dtype=int32) + >>> jit_unique(x, size=5, fill_value=0) + Array([1, 3, 4, 0, 0], dtype=int32) + + **Multi-dimensional unique values** + + If you pass a multi-dimensional array to ``unique``, it will be flattened by default: + + >>> M = jnp.array([[1, 2], + ... [2, 3], + ... [1, 2]]) + >>> jnp.unique(M) + Array([1, 2, 3], dtype=int32) + + If you pass an ``axis`` keyword, you can find unique *slices* of the array along + that axis: + + >>> jnp.unique(M, axis=0) + Array([[1, 2], + [2, 3]], dtype=int32) + + **Returning indices** + + If you set ``return_index=True``, then ``unique`` returns the indices of the + first occurrence of each unique value: + + >>> x = jnp.array([3, 4, 1, 3, 1]) + >>> values, indices = jnp.unique(x, return_index=True) + >>> print(values) + [1 3 4] + >>> print(indices) + [2 0 1] + >>> jnp.all(values == x[indices]) + Array(True, dtype=bool) + + In multiple dimensions, the unique values can be extracted with :func:`jax.numpy.take` + evaluated along the specified axis: + + >>> values, indices = jnp.unique(M, axis=0, return_index=True) + >>> jnp.all(values == jnp.take(M, indices, axis=0)) + Array(True, dtype=bool) + + **Returning inverse** + + If you set ``return_inverse=True``, then ``unique`` returns the indices within the + unique values for every entry in the input array: + + >>> x = jnp.array([3, 4, 1, 3, 1]) + >>> values, inverse = jnp.unique(x, return_inverse=True) + >>> print(values) + [1 3 4] + >>> print(inverse) + [1 2 0 1 0] + >>> jnp.all(values[inverse] == x) + Array(True, dtype=bool) + + In multiple dimensions, the input can be reconstructed using + :func:`jax.numpy.take_along_axis`: + + >>> values, inverse = jnp.unique(M, axis=0, return_inverse=True) + >>> jnp.all(jnp.take_along_axis(values, inverse, axis=0) == M) + Array(True, dtype=bool) + + **Returning counts** + + If you set ``return_counts=True``, then ``unique`` returns the number of occurrences + within the input for every unique value: + + >>> x = jnp.array([3, 4, 1, 3, 1]) + >>> values, counts = jnp.unique(x, return_counts=True) + >>> print(values) + [1 3 4] + >>> print(counts) + [2 2 1] + + For multi-dimensional arrays, this also returns a 1D array of counts + indicating number of occurrences along the specified axis: + + >>> values, counts = jnp.unique(M, axis=0, return_counts=True) + >>> print(values) + [[1 2] + [2 3]] + >>> print(counts) + [2 1] + """ check_arraylike("unique", ar) if size is None: ar = core.concrete_or_error(None, ar, @@ -352,6 +703,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal class _UniqueAllResult(NamedTuple): + """Struct returned by :func:`jax.numpy.unique_all`.""" values: Array indices: Array inverse_indices: Array @@ -359,38 +711,260 @@ class _UniqueAllResult(NamedTuple): class _UniqueCountsResult(NamedTuple): + """Struct returned by :func:`jax.numpy.unique_counts`.""" values: Array counts: Array 
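These named-tuple results support both attribute access and ordinary tuple unpacking; a small sketch with :func:`jax.numpy.unique_counts` (defined further below):

import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])

res = jnp.unique_counts(x)
print(res.values, res.counts)          # [1 3 4] [2 2 1]

values, counts = jnp.unique_counts(x)  # positional unpacking also works
print(values, counts)                  # [1 3 4] [2 2 1]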
class _UniqueInverseResult(NamedTuple): + """Struct returned by :func:`jax.numpy.unique_inverse`.""" values: Array inverse_indices: Array -@implements(getattr(np, "unique_all", None)) -def unique_all(x: ArrayLike, /) -> _UniqueAllResult: +def unique_all(x: ArrayLike, /, *, size: int | None = None, + fill_value: ArrayLike | None = None) -> _UniqueAllResult: + """Return unique values from x, along with indices, inverse indices, and counts. + + JAX implementation of :func:`numpy.unique_all`; this is equivalent to calling + :func:`jax.numpy.unique` with `return_index`, `return_inverse`, `return_counts`, + and `equal_nan` set to True. + + Because the size of the output of ``unique_all`` is data-dependent, the function + semantics are not typically compatible with :func:`~jax.jit` and other JAX + transformations. The JAX version adds the optional ``size`` argument which + must be specified statically for ``jnp.unique`` to be used in such contexts. + + Args: + x: N-dimensional array from which unique values will be extracted. + size: if specified, return only the first ``size`` sorted unique elements. If there are fewer + unique elements than ``size`` indicates, the return value will be padded with ``fill_value``. + fill_value: when ``size`` is specified and there are fewer than the indicated number of + elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value. + + Returns: + A tuple ``(values, indices, inverse_indices, counts)``, with the following properties: + + - ``values``: + an array of shape ``(n_unique,)`` containing the unique values from ``x``. + - ``indices``: + An array of shape ``(n_unique,)``. Contains the indices of the first occurrence of + each unique value in ``x``. For 1D inputs, ``x[indices]`` is equivalent to ``values``. + - ``inverse_indices``: + An array of shape ``x.shape``. Contains the indices within ``values`` of each value + in ``x``. For 1D inputs, ``values[inverse_indices]`` is equivalent to ``x``. + - ``counts``: + An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique + value in ``x``. + + See also: + - :func:`jax.numpy.unique`: general function for computing unique values. + - :func:`jax.numpy.unique_values`: compute only ``values``. + - :func:`jax.numpy.unique_counts`: compute only ``values`` and ``counts``. + - :func:`jax.numpy.unique_inverse`: compute only ``values`` and ``inverse``. + + Examples: + Here we compute the unique values in a 1D array: + + >>> x = jnp.array([3, 4, 1, 3, 1]) + >>> result = jnp.unique_all(x) + + The result is a :class:`~typing.NamedTuple` with four named attributes. + The ``values`` attribute contains the unique values from the array: + + >>> result.values + Array([1, 3, 4], dtype=int32) + + The ``indices`` attribute contains the indices of the unique ``values`` within + the input array: + + >>> result.indices + Array([2, 0, 1], dtype=int32) + >>> jnp.all(result.values == x[result.indices]) + Array(True, dtype=bool) + + The ``inverse_indices`` attribute contains the indices of the input within ``values``: + + >>> result.inverse_indices + Array([1, 2, 0, 1, 0], dtype=int32) + >>> jnp.all(x == result.values[result.inverse_indices]) + Array(True, dtype=bool) + + The ``counts`` attribute contains the counts of each unique value in the input: + + >>> result.counts + Array([2, 2, 1], dtype=int32) + + For examples of the ``size`` and ``fill_value`` arguments, see :func:`jax.numpy.unique`. 
+ """ check_arraylike("unique_all", x) values, indices, inverse_indices, counts = unique( - x, return_index=True, return_inverse=True, return_counts=True, equal_nan=False) + x, return_index=True, return_inverse=True, return_counts=True, equal_nan=False, + size=size, fill_value=fill_value) return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts) -@implements(getattr(np, "unique_counts", None)) -def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult: +def unique_counts(x: ArrayLike, /, *, size: int | None = None, + fill_value: ArrayLike | None = None) -> _UniqueCountsResult: + """Return unique values from x, along with counts. + + JAX implementation of :func:`numpy.unique_counts`; this is equivalent to calling + :func:`jax.numpy.unique` with `return_counts` and `equal_nan` set to True. + + Because the size of the output of ``unique_counts`` is data-dependent, the function + semantics are not typically compatible with :func:`~jax.jit` and other JAX + transformations. The JAX version adds the optional ``size`` argument which + must be specified statically for ``jnp.unique`` to be used in such contexts. + + Args: + x: N-dimensional array from which unique values will be extracted. + size: if specified, return only the first ``size`` sorted unique elements. If there are fewer + unique elements than ``size`` indicates, the return value will be padded with ``fill_value``. + fill_value: when ``size`` is specified and there are fewer than the indicated number of + elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value. + + Returns: + A tuple ``(values, counts)``, with the following properties: + + - ``values``: + an array of shape ``(n_unique,)`` containing the unique values from ``x``. + - ``counts``: + An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique + value in ``x``. + + See also: + - :func:`jax.numpy.unique`: general function for computing unique values. + - :func:`jax.numpy.unique_values`: compute only ``values``. + - :func:`jax.numpy.unique_inverse`: compute only ``values`` and ``inverse``. + - :func:`jax.numpy.unique_all`: compute ``values``, ``indices``, ``inverse_indices``, + and ``counts``. + + Examples: + Here we compute the unique values in a 1D array: + + >>> x = jnp.array([3, 4, 1, 3, 1]) + >>> result = jnp.unique_counts(x) + + The result is a :class:`~typing.NamedTuple` with two named attributes. + The ``values`` attribute contains the unique values from the array: + + >>> result.values + Array([1, 3, 4], dtype=int32) + + The ``counts`` attribute contains the counts of each unique value in the input: + + >>> result.counts + Array([2, 2, 1], dtype=int32) + + For examples of the ``size`` and ``fill_value`` arguments, see :func:`jax.numpy.unique`. + """ check_arraylike("unique_counts", x) - values, counts = unique(x, return_counts=True, equal_nan=False) + values, counts = unique(x, return_counts=True, equal_nan=False, + size=size, fill_value=fill_value) return _UniqueCountsResult(values=values, counts=counts) -@implements(getattr(np, "unique_inverse", None)) -def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult: +def unique_inverse(x: ArrayLike, /, *, size: int | None = None, + fill_value: ArrayLike | None = None) -> _UniqueInverseResult: + """Return unique values from x, along with indices, inverse indices, and counts. 
+
+  JAX implementation of :func:`numpy.unique_inverse`; this is equivalent to calling
+  :func:`jax.numpy.unique` with `return_inverse` and `equal_nan` set to True.
+
+  Because the size of the output of ``unique_inverse`` is data-dependent, the function
+  semantics are not typically compatible with :func:`~jax.jit` and other JAX
+  transformations. The JAX version adds the optional ``size`` argument which
+  must be specified statically for ``jnp.unique`` to be used in such contexts.
+
+  Args:
+    x: N-dimensional array from which unique values will be extracted.
+    size: if specified, return only the first ``size`` sorted unique elements. If there are fewer
+      unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
+    fill_value: when ``size`` is specified and there are fewer than the indicated number of
+      elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value.
+
+  Returns:
+    A tuple ``(values, inverse_indices)``, with the following properties:
+
+    - ``values``:
+        an array of shape ``(n_unique,)`` containing the unique values from ``x``.
+    - ``inverse_indices``:
+        An array of shape ``x.shape``. Contains the indices within ``values`` of each value
+        in ``x``. For 1D inputs, ``values[inverse_indices]`` is equivalent to ``x``.
+
+  See also:
+    - :func:`jax.numpy.unique`: general function for computing unique values.
+    - :func:`jax.numpy.unique_values`: compute only ``values``.
+    - :func:`jax.numpy.unique_counts`: compute only ``values`` and ``counts``.
+    - :func:`jax.numpy.unique_all`: compute ``values``, ``indices``, ``inverse_indices``,
+      and ``counts``.
+
+  Examples:
+    Here we compute the unique values in a 1D array:
+
+    >>> x = jnp.array([3, 4, 1, 3, 1])
+    >>> result = jnp.unique_inverse(x)
+
+    The result is a :class:`~typing.NamedTuple` with two named attributes.
+    The ``values`` attribute contains the unique values from the array:
+
+    >>> result.values
+    Array([1, 3, 4], dtype=int32)
+
+    The ``inverse_indices`` attribute contains the indices of the input within ``values``:
+
+    >>> result.inverse_indices
+    Array([1, 2, 0, 1, 0], dtype=int32)
+    >>> jnp.all(x == result.values[result.inverse_indices])
+    Array(True, dtype=bool)
+
+    For examples of the ``size`` and ``fill_value`` arguments, see :func:`jax.numpy.unique`.
+  """
   check_arraylike("unique_inverse", x)
-  values, inverse_indices = unique(x, return_inverse=True, equal_nan=False)
+  values, inverse_indices = unique(x, return_inverse=True, equal_nan=False,
+                                   size=size, fill_value=fill_value)
   return _UniqueInverseResult(values=values, inverse_indices=inverse_indices)
 
 
-@implements(getattr(np, "unique_values", None))
-def unique_values(x: ArrayLike, /) -> Array:
+def unique_values(x: ArrayLike, /, *, size: int | None = None,
+                  fill_value: ArrayLike | None = None) -> Array:
+  """Return unique values from x.
+
+  JAX implementation of :func:`numpy.unique_values`; this is equivalent to calling
+  :func:`jax.numpy.unique` with `equal_nan` set to True.
+
+  Because the size of the output of ``unique_values`` is data-dependent, the function
+  semantics are not typically compatible with :func:`~jax.jit` and other JAX
+  transformations. The JAX version adds the optional ``size`` argument which
+  must be specified statically for ``jnp.unique`` to be used in such contexts.
+ + Args: + x: N-dimensional array from which unique values will be extracted. + size: if specified, return only the first ``size`` sorted unique elements. If there are fewer + unique elements than ``size`` indicates, the return value will be padded with ``fill_value``. + fill_value: when ``size`` is specified and there are fewer than the indicated number of + elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value. + + Returns: + An array ``values`` of shape ``(n_unique,)`` containing the unique values from ``x``. + + See also: + - :func:`jax.numpy.unique`: general function for computing unique values. + - :func:`jax.numpy.unique_values`: compute only ``values``. + - :func:`jax.numpy.unique_counts`: compute only ``values`` and ``counts``. + - :func:`jax.numpy.unique_inverse`: compute only ``values`` and ``inverse``. + + Examples: + Here we compute the unique values in a 1D array: + + >>> x = jnp.array([3, 4, 1, 3, 1]) + >>> jnp.unique_values(x) + Array([1, 3, 4], dtype=int32) + + For examples of the ``size`` and ``fill_value`` arguments, see :func:`jax.numpy.unique`. + """ check_arraylike("unique_values", x) - return cast(Array, unique(x, equal_nan=False)) + return cast(Array, unique(x, equal_nan=False, size=size, fill_value=fill_value)) diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 2d3eb1edf5c2..2e114193af13 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -16,10 +16,11 @@ from __future__ import annotations +from collections.abc import Callable from functools import partial import math import operator -from typing import Any, Callable +from typing import Any import jax from jax._src.typing import Array, ArrayLike, DTypeLike @@ -258,7 +259,7 @@ def scan_fun(carry, _): _, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0]) return _moveaxis(result, 0, axis) - @implements(np.ufunc.accumulate, module="numpy.ufunc") + @implements(np.ufunc.at, module="numpy.ufunc") @partial(jax.jit, static_argnums=[0], static_argnames=['inplace']) def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, inplace: bool = True) -> Array: @@ -354,7 +355,7 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, """Create a JAX ufunc from an arbitrary JAX-compatible scalar function. Args: - func : a callable that takes `nin` scalar arguments and return `nout` outputs. + func : a callable that takes `nin` scalar arguments and returns `nout` outputs. nin: integer specifying the number of scalar inputs nout: integer specifying the number of scalar outputs identity: (optional) a scalar specifying the identity of the operation, if any. 
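Since the hunk above only touches the ``frompyfunc`` docstring, a short usage sketch may help connect ``nin``/``nout``/``identity`` to actual calls; the lambda and the input values below are arbitrary illustrations, not code from this patch:

    import jax.numpy as jnp

    # Wrap a scalar function into a ufunc-like object; `identity` makes
    # reductions over empty inputs well-defined.
    add = jnp.frompyfunc(lambda a, b: a + b, nin=2, nout=1, identity=0)

    x = jnp.arange(5.0)
    add(x, 10.0)    # broadcasts elementwise: [10., 11., 12., 13., 14.]
    add.reduce(x)   # ufunc-style reduction:  10.0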
diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index dee3d9a12d89..707a09fc6930 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -18,10 +18,11 @@ from __future__ import annotations +from collections.abc import Callable from functools import partial import operator -from textwrap import dedent -from typing import Any, Callable, overload + +import warnings import numpy as np @@ -44,186 +45,404 @@ 64: np.int64, } -UnOp = Callable[[ArrayLike], Array] -BinOp = Callable[[ArrayLike, ArrayLike], Array] - - def _constant_like(x, const): return np.array(const, dtype=dtypes.dtype(x)) - def _replace_inf(x: ArrayLike) -> Array: return lax.select(isposinf(real(x)), lax._zeros(x), x) +def _to_bool(x: Array) -> Array: + return x if x.dtype == bool else lax.ne(x, _lax_const(x, 0)) -def _one_to_one_unop( - numpy_fn: Callable[..., Any], lax_fn: UnOp, - promote_to_inexact: bool = False, lax_doc: bool = False) -> UnOp: - if promote_to_inexact: - fn = lambda x, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x)) - else: - fn = lambda x, /: lax_fn(*promote_args(numpy_fn.__name__, x)) - fn.__name__ = numpy_fn.__name__ - fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}" - fn = jit(fn, inline=True) - if lax_doc: - doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr] - return implements(numpy_fn, lax_description=doc, module='numpy')(fn) - else: - return implements(numpy_fn, module='numpy')(fn) +@implements(np.fabs, module='numpy') +@partial(jit, inline=True) +def fabs(x: ArrayLike, /) -> Array: + return lax.abs(*promote_args_inexact('fabs', x)) +@implements(getattr(np, 'bitwise_invert', np.invert), module='numpy') +@partial(jit, inline=True) +def bitwise_invert(x: ArrayLike, /) -> Array: + return lax.bitwise_not(*promote_args('bitwise_invert', x)) -def _one_to_one_binop( - numpy_fn: Callable[..., Any], lax_fn: BinOp, - promote_to_inexact: bool = False, lax_doc: bool = False, - promote_to_numeric: bool = False) -> BinOp: - if promote_to_inexact: - fn = lambda x1, x2, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x1, x2)) - elif promote_to_numeric: - fn = lambda x1, x2, /: lax_fn(*promote_args_numeric(numpy_fn.__name__, x1, x2)) - else: - fn = lambda x1, x2, /: lax_fn(*promote_args(numpy_fn.__name__, x1, x2)) - fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}" - fn = jit(fn, inline=True) - if lax_doc: - doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr] - return implements(numpy_fn, lax_description=doc, module='numpy')(fn) - else: - return implements(numpy_fn, module='numpy')(fn) - - -def _maybe_bool_binop( - numpy_fn: Callable[..., Any], lax_fn: BinOp, bool_lax_fn: BinOp, - lax_doc: bool = False) -> BinOp: - def fn(x1, x2, /): - x1, x2 = promote_args(numpy_fn.__name__, x1, x2) - return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2) - fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}" - fn = jit(fn, inline=True) - if lax_doc: - doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr] - return implements(numpy_fn, lax_description=doc, module='numpy')(fn) - else: - return implements(numpy_fn, module='numpy')(fn) - - -def _comparison_op(numpy_fn: Callable[..., Any], lax_fn: BinOp) -> BinOp: - def fn(x1, x2, /): - x1, x2 = promote_args(numpy_fn.__name__, x1, x2) - # Comparison on complex types are defined as a lexicographic ordering on - # the (real, imag) pair. 
- if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating): - rx = lax.real(x1) - ry = lax.real(x2) - return lax.select(lax.eq(rx, ry), lax_fn(lax.imag(x1), lax.imag(x2)), - lax_fn(rx, ry)) - return lax_fn(x1, x2) - fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}" - fn = jit(fn, inline=True) - return implements(numpy_fn, module='numpy')(fn) - -@overload -def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp) -> UnOp: ... -@overload -def _logical_op(np_op: Callable[..., Any], bitwise_op: BinOp) -> BinOp: ... -@overload -def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp | BinOp) -> UnOp | BinOp: ... - -def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp | BinOp) -> UnOp | BinOp: - @implements(np_op, update_doc=False, module='numpy') - @partial(jit, inline=True) - def op(*args): - zero = lambda x: lax.full_like(x, shape=(), fill_value=0) - args = (x if dtypes.issubdtype(dtypes.dtype(x), np.bool_) else lax.ne(x, zero(x)) - for x in args) - return bitwise_op(*promote_args(np_op.__name__, *args)) - return op +@implements(np.bitwise_not, module='numpy') +@partial(jit, inline=True) +def bitwise_not(x: ArrayLike, /) -> Array: + return lax.bitwise_not(*promote_args('bitwise_not', x)) +@implements(np.invert, module='numpy') +@partial(jit, inline=True) +def invert(x: ArrayLike, /) -> Array: + return lax.bitwise_not(*promote_args('invert', x)) + +@implements(np.negative, module='numpy') +@partial(jit, inline=True) +def negative(x: ArrayLike, /) -> Array: + return lax.neg(*promote_args('negative', x)) + +@implements(np.positive, module='numpy') +@partial(jit, inline=True) +def positive(x: ArrayLike, /) -> Array: + return lax.asarray(*promote_args('positive', x)) + +@implements(np.sign, module='numpy') +@partial(jit, inline=True) +def sign(x: ArrayLike, /) -> Array: + return lax.sign(*promote_args('sign', x)) + +@implements(np.floor, module='numpy') +@partial(jit, inline=True) +def floor(x: ArrayLike, /) -> Array: + check_arraylike('floor', x) + if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + return lax.asarray(x) + return lax.floor(*promote_args_inexact('floor', x)) + +@implements(np.ceil, module='numpy') +@partial(jit, inline=True) +def ceil(x: ArrayLike, /) -> Array: + check_arraylike('ceil', x) + if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + return lax.asarray(x) + return lax.ceil(*promote_args_inexact('ceil', x)) + +@implements(np.exp, module='numpy') +@partial(jit, inline=True) +def exp(x: ArrayLike, /) -> Array: + return lax.exp(*promote_args_inexact('exp', x)) + +@implements(np.log, module='numpy') +@partial(jit, inline=True) +def log(x: ArrayLike, /) -> Array: + return lax.log(*promote_args_inexact('log', x)) + +@implements(np.expm1, module='numpy') +@partial(jit, inline=True) +def expm1(x: ArrayLike, /) -> Array: + return lax.expm1(*promote_args_inexact('expm1', x)) + +@implements(np.log1p, module='numpy') +@partial(jit, inline=True) +def log1p(x: ArrayLike, /) -> Array: + return lax.log1p(*promote_args_inexact('log1p', x)) + +@implements(np.sin, module='numpy') +@partial(jit, inline=True) +def sin(x: ArrayLike, /) -> Array: + return lax.sin(*promote_args_inexact('sin', x)) + +@implements(np.cos, module='numpy') +@partial(jit, inline=True) +def cos(x: ArrayLike, /) -> Array: + return lax.cos(*promote_args_inexact('cos', x)) + +@implements(np.tan, module='numpy') +@partial(jit, inline=True) +def tan(x: ArrayLike, /) -> Array: + return lax.tan(*promote_args_inexact('tan', x)) + +@implements(np.arcsin, module='numpy') +@partial(jit, 
inline=True) +def arcsin(x: ArrayLike, /) -> Array: + return lax.asin(*promote_args_inexact('arcsin', x)) + +@implements(np.arccos, module='numpy') +@partial(jit, inline=True) +def arccos(x: ArrayLike, /) -> Array: + return lax.acos(*promote_args_inexact('arccos', x)) + +@implements(np.arctan, module='numpy') +@partial(jit, inline=True) +def arctan(x: ArrayLike, /) -> Array: + return lax.atan(*promote_args_inexact('arctan', x)) + +@implements(np.sinh, module='numpy') +@partial(jit, inline=True) +def sinh(x: ArrayLike, /) -> Array: + return lax.sinh(*promote_args_inexact('sinh', x)) + +@implements(np.cosh, module='numpy') +@partial(jit, inline=True) +def cosh(x: ArrayLike, /) -> Array: + return lax.cosh(*promote_args_inexact('cosh', x)) + +@implements(np.arcsinh, module='numpy') +@partial(jit, inline=True) +def arcsinh(x: ArrayLike, /) -> Array: + return lax.asinh(*promote_args_inexact('arcsinh', x)) + +@implements(np.arccosh, module='numpy') @jit -def _arccosh(x: ArrayLike, /) -> Array: - # Note: arccosh is multi-valued for complex input, and lax.acosh uses a different - # convention than np.arccosh. - out = lax.acosh(*promote_args_inexact("arccosh", x)) - if dtypes.issubdtype(out.dtype, np.complexfloating): - out = _where(real(out) < 0, lax.neg(out), out) - return out - -fabs = _one_to_one_unop(np.fabs, lax.abs, True) -bitwise_invert = _one_to_one_unop(getattr(np, 'bitwise_invert', np.invert), lax.bitwise_not) -bitwise_invert = _one_to_one_unop(getattr(np, 'bitwise_invert', np.invert), lax.bitwise_not) -bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not) -invert = _one_to_one_unop(np.invert, lax.bitwise_not) -negative = _one_to_one_unop(np.negative, lax.neg) -positive = _one_to_one_unop(np.positive, lambda x: lax.asarray(x)) -floor = _one_to_one_unop(np.floor, lax.floor, True) -ceil = _one_to_one_unop(np.ceil, lax.ceil, True) -exp = _one_to_one_unop(np.exp, lax.exp, True) -log = _one_to_one_unop(np.log, lax.log, True) -expm1 = _one_to_one_unop(np.expm1, lax.expm1, True) -log1p = _one_to_one_unop(np.log1p, lax.log1p, True) -sin = _one_to_one_unop(np.sin, lax.sin, True) -cos = _one_to_one_unop(np.cos, lax.cos, True) -tan = _one_to_one_unop(np.tan, lax.tan, True) -arcsin = _one_to_one_unop(np.arcsin, lax.asin, True) -arccos = _one_to_one_unop(np.arccos, lax.acos, True) -arctan = _one_to_one_unop(np.arctan, lax.atan, True) -sinh = _one_to_one_unop(np.sinh, lax.sinh, True) -cosh = _one_to_one_unop(np.cosh, lax.cosh, True) -arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True) -arccosh = _one_to_one_unop(np.arccosh, _arccosh, True) -tanh = _one_to_one_unop(np.tanh, lax.tanh, True) -arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True) -sign = _one_to_one_unop(np.sign, lax.sign) -sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True) -cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True) - -add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or) -bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and) -bitwise_left_shift = _one_to_one_binop(getattr(np, "bitwise_left_shift", np.left_shift), lax.shift_left, promote_to_numeric=True) -bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or) -bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor) -left_shift = _one_to_one_binop(np.left_shift, lax.shift_left, promote_to_numeric=True) -equal = _one_to_one_binop(np.equal, lax.eq) -multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and) -not_equal = _one_to_one_binop(np.not_equal, lax.ne) -subtract = _one_to_one_binop(np.subtract, lax.sub) -arctan2 = 
_one_to_one_binop(np.arctan2, lax.atan2, True) -minimum = _one_to_one_binop(np.minimum, lax.min) -maximum = _one_to_one_binop(np.maximum, lax.max) -float_power = _one_to_one_binop(np.float_power, lax.pow, True) -nextafter = _one_to_one_binop(np.nextafter, lax.nextafter, True, True) - -greater_equal = _comparison_op(np.greater_equal, lax.ge) -greater = _comparison_op(np.greater, lax.gt) -less_equal = _comparison_op(np.less_equal, lax.le) -less = _comparison_op(np.less, lax.lt) - -logical_and: BinOp = _logical_op(np.logical_and, lax.bitwise_and) -logical_not: UnOp = _logical_op(np.logical_not, lax.bitwise_not) -logical_or: BinOp = _logical_op(np.logical_or, lax.bitwise_or) -logical_xor: BinOp = _logical_op(np.logical_xor, lax.bitwise_xor) +def arccosh(x: ArrayLike, /) -> Array: + # Note: arccosh is multi-valued for complex input, and lax.acosh + # uses a different convention than np.arccosh. + result = lax.acosh(*promote_args_inexact("arccosh", x)) + if dtypes.issubdtype(result.dtype, np.complexfloating): + result = _where(real(result) < 0, lax.neg(result), result) + return result + +@implements(np.tanh, module='numpy') +@partial(jit, inline=True) +def tanh(x: ArrayLike, /) -> Array: + return lax.tanh(*promote_args_inexact('tanh', x)) + +@implements(np.arctanh, module='numpy') +@partial(jit, inline=True) +def arctanh(x: ArrayLike, /) -> Array: + return lax.atanh(*promote_args_inexact('arctanh', x)) + +@implements(np.sqrt, module='numpy') +@partial(jit, inline=True) +def sqrt(x: ArrayLike, /) -> Array: + return lax.sqrt(*promote_args_inexact('sqrt', x)) + +@implements(np.cbrt, module='numpy') +@partial(jit, inline=True) +def cbrt(x: ArrayLike, /) -> Array: + return lax.cbrt(*promote_args_inexact('cbrt', x)) + +@implements(np.add, module='numpy') +@partial(jit, inline=True) +def add(x: ArrayLike, y: ArrayLike, /) -> Array: + x, y = promote_args("add", x, y) + return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) + +@implements(np.multiply, module='numpy') +@partial(jit, inline=True) +def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: + x, y = promote_args("multiply", x, y) + return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) + +@implements(np.bitwise_and, module='numpy') +@partial(jit, inline=True) +def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.bitwise_and(*promote_args("bitwise_and", x, y)) + +@implements(np.bitwise_or, module='numpy') +@partial(jit, inline=True) +def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.bitwise_or(*promote_args("bitwise_or", x, y)) + +@implements(np.bitwise_xor, module='numpy') +@partial(jit, inline=True) +def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) + +@implements(np.left_shift, module='numpy') +@partial(jit, inline=True) +def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.shift_left(*promote_args_numeric("left_shift", x, y)) + +@implements(getattr(np, "bitwise_left_shift", np.left_shift), module='numpy') +@partial(jit, inline=True) +def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.shift_left(*promote_args_numeric("bitwise_left_shift", x, y)) + +@implements(np.equal, module='numpy') +@partial(jit, inline=True) +def equal(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.eq(*promote_args("equal", x, y)) + +@implements(np.not_equal, module='numpy') +@partial(jit, inline=True) +def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.ne(*promote_args("not_equal", 
x, y)) + +@implements(np.subtract, module='numpy') +@partial(jit, inline=True) +def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.sub(*promote_args("subtract", x, y)) + +@implements(np.arctan2, module='numpy') +@partial(jit, inline=True) +def arctan2(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.atan2(*promote_args_inexact("arctan2", x, y)) + +@implements(np.minimum, module='numpy') +@partial(jit, inline=True) +def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.min(*promote_args("minimum", x, y)) + +@implements(np.maximum, module='numpy') +@partial(jit, inline=True) +def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.max(*promote_args("maximum", x, y)) + +@implements(np.float_power, module='numpy') +@partial(jit, inline=True) +def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.pow(*promote_args_inexact("float_power", x, y)) + +@implements(np.nextafter, module='numpy') +@partial(jit, inline=True) +def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.nextafter(*promote_args_inexact("nextafter", x, y)) + +# Logical ops +@implements(np.logical_and, module='numpy') +@partial(jit, inline=True) +def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) + +@implements(np.logical_or, module='numpy') +@partial(jit, inline=True) +def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) + +@implements(np.logical_xor, module='numpy') +@partial(jit, inline=True) +def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) + +@implements(np.logical_not, module='numpy') +@partial(jit, inline=True) +def logical_not(x: ArrayLike, /) -> Array: + return lax.bitwise_not(*map(_to_bool, promote_args("logical_not", x))) + +# Comparison ops +def _complex_comparison(lax_op: Callable[[ArrayLike, ArrayLike], Array], + x: Array, y: Array): + if dtypes.issubdtype(x.dtype, np.complexfloating): + return lax.select(lax.eq(x.real, y.real), + lax_op(x.imag, y.imag), + lax_op(x.real, y.real)) + return lax_op(x, y) + +@implements(np.greater_equal, module='numpy') +@partial(jit, inline=True) +def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: + return _complex_comparison(lax.ge, *promote_args("greater_equal", x, y)) + +@implements(np.greater, module='numpy') +@partial(jit, inline=True) +def greater(x: ArrayLike, y: ArrayLike, /) -> Array: + return _complex_comparison(lax.gt, *promote_args("greater", x, y)) + +@implements(np.less_equal, module='numpy') +@partial(jit, inline=True) +def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: + return _complex_comparison(lax.le, *promote_args("less_equal", x, y)) + +@implements(np.less, module='numpy') +@partial(jit, inline=True) +def less(x: ArrayLike, y: ArrayLike, /) -> Array: + return _complex_comparison(lax.lt, *promote_args("less", x, y)) # Array API aliases -# TODO(jakevdp): directly reference np_fun when minimum numpy version is 2.0 -acos = _one_to_one_unop(getattr(np, "acos", np.arccos), lax.acos, True) -acosh = _one_to_one_unop(getattr(np, "acosh", np.arccosh), _arccosh, True) -asin = _one_to_one_unop(getattr(np, "asin", np.arcsin), lax.asin, True) -asinh = _one_to_one_unop(getattr(np, "asinh", np.arcsinh), lax.asinh, True) -atan = _one_to_one_unop(getattr(np, "atan", np.arctan), lax.atan, True) -atanh = _one_to_one_unop(getattr(np, "atanh", np.arctanh), 
lax.atanh, True) -atan2 = _one_to_one_binop(getattr(np, "atan2", np.arctan2), lax.atan2, True) +@partial(jit, inline=True) +def acos(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arccos`""" + return arccos(*promote_args('acos', x)) + +@partial(jit, inline=True) +def acosh(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arccosh`""" + return arccosh(*promote_args('acosh', x)) + +@partial(jit, inline=True) +def asin(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arcsin`""" + return arcsin(*promote_args('asin', x)) +@partial(jit, inline=True) +def asinh(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arcsinh`""" + return arcsinh(*promote_args('asinh', x)) + +@partial(jit, inline=True) +def atan(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arctan`""" + return arctan(*promote_args('atan', x)) + +@partial(jit, inline=True) +def atanh(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arctanh`""" + return arctanh(*promote_args('atanh', x)) + +@partial(jit, inline=True) +def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arctan2`""" + return arctan2(*promote_args('atan2', x, y)) -@implements(getattr(np, 'bitwise_count', None), module='numpy') @jit def bitwise_count(x: ArrayLike, /) -> Array: + r"""Counts the number of 1 bits in the binary representation of the absolute value + of each element of ``x``. + + LAX-backend implementation of :func:`numpy.bitwise_count`. + + Args: + x: Input array, only accepts integer subtypes + + Returns: + An array-like object containing the binary 1 bit counts of the absolute value of + each element in ``x``, with the same shape as ``x`` of dtype uint8. + + Examples: + >>> x1 = jnp.array([64, 32, 31, 20]) + >>> # 64 = 0b1000000, 32 = 0b100000, 31 = 0b11111, 20 = 0b10100 + >>> jnp.bitwise_count(x1) + Array([1, 1, 5, 2], dtype=uint8) + + >>> x2 = jnp.array([-16, -7, 7]) + >>> # |-16| = 0b10000, |-7| = 0b111, 7 = 0b111 + >>> jnp.bitwise_count(x2) + Array([1, 3, 3], dtype=uint8) + + >>> x3 = jnp.array([[2, -7],[-9, 7]]) + >>> # 2 = 0b10, |-7| = 0b111, |-9| = 0b1001, 7 = 0b111 + >>> jnp.bitwise_count(x3) + Array([[1, 3], + [2, 3]], dtype=uint8) + """ x, = promote_args_numeric("bitwise_count", x) # Following numpy we take the absolute value and return uint8. return lax.population_count(abs(x)).astype('uint8') -@implements(np.right_shift, module='numpy') @partial(jit, inline=True) def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: + r"""Right shift the bits of ``x1`` to the amount specified in ``x2``. + + LAX-backend implementation of :func:`numpy.right_shift`. + + Args: + x1: Input array, only accepts unsigned integer subtypes + x2: The amount of bits to shift each element in ``x1`` to the right, only accepts + integer subtypes + + Returns: + An array-like object containing the right shifted elements of ``x1`` by the + amount specified in ``x2``, with the same shape as the broadcasted shape of + ``x1`` and ``x2``. + + Note: + If ``x1.shape != x2.shape``, they must be compatible for broadcasting to a + shared shape, this shared shape will also be the shape of the output. Right shifting + a scalar x1 by scalar x2 is equivalent to ``x1 // 2**x2``. + + Examples: + >>> def print_binary(x): + ... 
return [bin(int(val)) for val in x] + + >>> x1 = jnp.array([1, 2, 4, 8]) + >>> print_binary(x1) + ['0b1', '0b10', '0b100', '0b1000'] + >>> x2 = 1 + >>> result = jnp.right_shift(x1, x2) + >>> result + Array([0, 1, 2, 4], dtype=int32) + >>> print_binary(result) + ['0b0', '0b1', '0b10', '0b100'] + + >>> x1 = 16 + >>> print_binary([x1]) + ['0b10000'] + >>> x2 = jnp.array([1, 2, 3, 4]) + >>> result = jnp.right_shift(x1, x2) + >>> result + Array([8, 4, 2, 1], dtype=int32) + >>> print_binary(result) + ['0b1000', '0b100', '0b10', '0b1'] + """ x1, x2 = promote_args_numeric(np.right_shift.__name__, x1, x2) lax_fn = lax.shift_right_logical if \ np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic @@ -237,18 +456,78 @@ def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic return lax_fn(x1, x2) -@implements(np.absolute, module='numpy') + @partial(jit, inline=True) def absolute(x: ArrayLike, /) -> Array: + r"""Calculate the absolute value element-wise. + + LAX-backend implementation of :func:`numpy.absolute`. + + This is the same function as :func:`jax.numpy.abs`. + + Args: + x: Input array + + Returns: + An array-like object containing the absolute value of each element in ``x``, + with the same shape as ``x``. For complex valued input, :math:`a + ib`, + the absolute value is :math:`\sqrt{a^2+b^2}`. + + Examples: + >>> x1 = jnp.array([5, -2, 0, 12]) + >>> jnp.absolute(x1) + Array([ 5, 2, 0, 12], dtype=int32) + + >>> x2 = jnp.array([[ 8, -3, 1],[ 0, 9, -6]]) + >>> jnp.absolute(x2) + Array([[8, 3, 1], + [0, 9, 6]], dtype=int32) + + >>> x3 = jnp.array([8 + 15j, 3 - 4j, -5 + 0j]) + >>> jnp.absolute(x3) + Array([17., 5., 5.], dtype=float32) + """ check_arraylike('absolute', x) dt = dtypes.dtype(x) return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) -abs = implements(np.abs, module='numpy')(absolute) -@implements(np.rint, module='numpy') +@partial(jit, inline=True) +def abs(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.absolute`.""" + return absolute(x) + + @jit def rint(x: ArrayLike, /) -> Array: + """Rounds the elements of x to the nearest integer + + LAX-backend implementation of :func:`numpy.rint`. + + Args: + x: Input array + + Returns: + An array-like object containing the rounded elements of ``x``. Always promotes + to inexact. + + Note: + If an element of x is exactly half way, e.g. ``0.5`` or ``1.5``, rint will round + to the nearest even integer. + + Example: + >>> x1 = jnp.array([5, 4, 7]) + >>> jnp.rint(x1) + Array([5., 4., 7.], dtype=float32) + + >>> x2 = jnp.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5]) + >>> jnp.rint(x2) + Array([-2., -2., -0., 0., 2., 2., 4., 4.], dtype=float32) + + >>> x3 = jnp.array([-2.5+3.5j, 4.5-0.5j]) + >>> jnp.rint(x3) + Array([-2.+4.j, 4.-0.j], dtype=complex64) + """ check_arraylike('rint', x) dtype = dtypes.dtype(x) if dtype == bool or dtypes.issubdtype(dtype, np.integer): @@ -258,9 +537,39 @@ def rint(x: ArrayLike, /) -> Array: return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) -@implements(np.copysign, module='numpy') @jit def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Copies the sign of each element in ``x2`` to the corresponding element in ``x1``. + + LAX-backend implementation of :func:`numpy.copysign`. 
+ + Args: + x1: Input array + x2: The array whose elements will be used to determine the sign, must be + broadcast-compatible with ``x1`` + + Returns: + An array object containing the potentially changed elements of ``x1``, always promotes + to inexact dtype, and has a shape of ``jnp.broadcast_shapes(x1.shape, x2.shape)`` + + Examples: + >>> x1 = jnp.array([5, 2, 0]) + >>> x2 = -1 + >>> jnp.copysign(x1, x2) + Array([-5., -2., -0.], dtype=float32) + + >>> x1 = jnp.array([6, 8, 0]) + >>> x2 = 2 + >>> jnp.copysign(x1, x2) + Array([6., 8., 0.], dtype=float32) + + >>> x1 = jnp.array([2, -3]) + >>> x2 = jnp.array([[1],[-4], [5]]) + >>> jnp.copysign(x1, x2) + Array([[ 2., 3.], + [-2., -3.], + [ 2., 3.]], dtype=float32) + """ x1, x2 = promote_args_inexact("copysign", x1, x2) if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating): raise TypeError("copysign does not support complex-valued inputs") @@ -276,34 +585,93 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: divide = true_divide -@implements(np.floor_divide, module='numpy') @jit def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Calculates the floor division of x1 by x2 element-wise + + LAX-backend implementation of :func:`numpy.floor_divide`. + + Args: + x1: Input array, the dividend + x2: Input array, the divisor + + Returns: + An array-like object containing each of the quotients rounded down + to the nearest integer towards negative infinity. This is equivalent + to ``x1 // x2`` in Python. + + Examples: + >>> x1 = jnp.array([10, 20, 30]) + >>> x2 = jnp.array([3, 4, 7]) + >>> jnp.floor_divide(x1, x2) + Array([3, 5, 4], dtype=int32) + + >>> x1 = jnp.array([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]) + >>> x2 = 3 + >>> jnp.floor_divide(x1, x2) + Array([-2, -2, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=int32) + + >>> x1 = jnp.array([6, 6, 6], dtype=jnp.int32) + >>> x2 = jnp.array([2.0, 2.5, 3.0], dtype=jnp.float32) + >>> jnp.floor_divide(x1, x2) + Array([3., 2., 2.], dtype=float32) + + Note: + ``x1 // x2`` is equivalent to ``jnp.floor_divide(x1, x2)`` for arrays ``x1`` and ``x2`` + + See Also: + :func:`jnp.divide` and :func:`jnp.true_divide` for floating point division + """ x1, x2 = promote_args_numeric("floor_divide", x1, x2) dtype = dtypes.dtype(x1) - if dtypes.issubdtype(dtype, np.integer): + if dtypes.issubdtype(dtype, np.unsignedinteger): + return lax.div(x1, x2) + elif dtypes.issubdtype(dtype, np.integer): quotient = lax.div(x1, x2) select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0) # TODO(mattjj): investigate why subtracting a scalar was causing promotion return _where(select, quotient - 1, quotient) elif dtypes.issubdtype(dtype, np.complexfloating): - x1r = lax.real(x1) - x1i = lax.imag(x1) - x2r = lax.real(x2) - x2i = lax.imag(x2) - which = lax.ge(lax.abs(x2r), lax.abs(x2i)) - rat1 = _where(which, lax.full_like(x2i, 1), lax.div(x2r, x2i)) - rat2 = _where(which, lax.div(x2i, x2r), _lax_const(x2i, 1)) - out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)), - lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2)))) - return lax.convert_element_type(out, dtype) + raise TypeError("floor_divide does not support complex-valued inputs") else: return _float_divmod(x1, x2)[0] -@implements(np.divmod, module='numpy') @jit def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: + """Calculates the integer quotient and remainder of x1 by x2 element-wise + + LAX-backend implementation of :func:`numpy.divmod`. 
+ + Args: + x1: Input array, the dividend + x2: Input array, the divisor + + Returns: + A tuple of arrays ``(x1 // x2, x1 % x2)``. + + Examples: + >>> x1 = jnp.array([10, 20, 30]) + >>> x2 = jnp.array([3, 4, 7]) + >>> jnp.divmod(x1, x2) + (Array([3, 5, 4], dtype=int32), Array([1, 0, 2], dtype=int32)) + + >>> x1 = jnp.array([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]) + >>> x2 = 3 + >>> jnp.divmod(x1, x2) + (Array([-2, -2, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=int32), + Array([1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2], dtype=int32)) + + >>> x1 = jnp.array([6, 6, 6], dtype=jnp.int32) + >>> x2 = jnp.array([1.9, 2.5, 3.1], dtype=jnp.float32) + >>> jnp.divmod(x1, x2) + (Array([3., 2., 1.], dtype=float32), + Array([0.30000007, 1. , 2.9 ], dtype=float32)) + + See Also: + - :func:`jax.numpy.floor_divide`: floor division function + - :func:`jax.numpy.remainder`: remainder function + """ x1, x2 = promote_args_numeric("divmod", x1, x2) if dtypes.issubdtype(dtypes.dtype(x1), np.integer): return floor_divide(x1, x2), remainder(x1, x2) @@ -707,14 +1075,14 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array: return lax.full_like(x, False, dtype=np.bool_) -isposinf: UnOp = implements(np.isposinf, skip_params=['out'])( - lambda x, /, out=None: _isposneginf(np.inf, x, out) -) +@implements(np.isposinf, module='numpy') +def isposinf(x, /, out=None): + return _isposneginf(np.inf, x, out) -isneginf: UnOp = implements(np.isneginf, skip_params=['out'])( - lambda x, /, out=None: _isposneginf(-np.inf, x, out) -) +@implements(np.isposinf, module='numpy') +def isneginf(x, /, out=None): + return _isposneginf(-np.inf, x, out) @implements(np.isnan, module='numpy') @@ -737,12 +1105,22 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: @implements(np.hypot, module='numpy') @jit def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: - check_arraylike("hypot", x1, x2) - x1, x2 = promote_dtypes_inexact(x1, x2) - x1 = lax.abs(x1) - x2 = lax.abs(x2) + x1, x2 = promote_args_inexact("hypot", x1, x2) + + # TODO(micky774): Promote to ValueError when deprecation is complete + # (began 2024-4-14). + if dtypes.issubdtype(x1.dtype, np.complexfloating): + warnings.warn( + "Passing complex-valued inputs to hypot is deprecated and will raise a " + "ValueError in the future. Please convert to real values first, such as " + "by using jnp.real or jnp.imag to take the real or imaginary components " + "respectively.", + DeprecationWarning, stacklevel=2) + x1, x2 = lax.abs(x1), lax.abs(x2) + idx_inf = lax.bitwise_or(isposinf(x1), isposinf(x2)) x1, x2 = maximum(x1, x2), minimum(x1, x2) - return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax._ones(x1), x1))))) + x = _where(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, _where(x1 == 0, lax._ones(x1), x1))))) + return _where(idx_inf, _lax_const(x, np.inf), x) @implements(np.reciprocal, module='numpy') diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 1dfd94d8a6df..21b96deea3c6 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -13,11 +13,11 @@ # limitations under the License. 
from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial import re import textwrap -from typing import Any, Callable, NamedTuple, TypeVar +from typing import Any, NamedTuple, TypeVar import warnings @@ -418,6 +418,8 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array: arr_shape = np.shape(arr) if core.definitely_equal_shape(arr_shape, shape): return arr + elif len(shape) < len(arr_shape): + raise ValueError(f"Cannot broadcast to shape with fewer dimensions: {arr_shape=} {shape=}") else: nlead = len(shape) - len(arr_shape) shape_tail = shape[nlead:] diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index 7749f1b42cce..2c517467e287 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -13,10 +13,10 @@ # limitations under the License. from __future__ import annotations -from collections.abc import Collection, Sequence +from collections.abc import Callable, Collection, Sequence import functools import re -from typing import Any, Callable +from typing import Any from jax._src import api from jax import lax @@ -124,7 +124,7 @@ def _parse_input_dimensions( shapes.append(arg.shape[:ndim]) broadcast_shape = lax.broadcast_shapes(*shapes) # TODO(mattjj): this code needs updating for dynamic shapes (hence ignore) - return broadcast_shape, dim_sizes # type: ignore + return broadcast_shape, dim_sizes def _check_output_dims( diff --git a/jax/_src/op_shardings.py b/jax/_src/op_shardings.py index ee3ea6b4494f..6b248736ce45 100644 --- a/jax/_src/op_shardings.py +++ b/jax/_src/op_shardings.py @@ -26,7 +26,7 @@ def get_num_ways_dim_sharded( hlo_sharding: xc.HloSharding) -> tuple[list[int], int]: - if hlo_sharding.is_replicated(): # type: ignore + if hlo_sharding.is_replicated(): return [], 1 partitions = hlo_sharding.tile_assignment_dimensions() subgroup_types = hlo_sharding.subgroup_types() @@ -50,7 +50,7 @@ def is_op_sharding_replicated(op: xc.OpSharding | xc.HloSharding) -> bool: op = xc.HloSharding.from_proto(op) if op.num_devices() == 1: return True - return op.is_replicated() # type: ignore + return op.is_replicated() def are_op_shardings_equal(op1: xc.OpSharding | xc.HloSharding, diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 5934ae415610..2bcfe96ad2f0 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -16,9 +16,8 @@ from __future__ import annotations -from collections.abc import Sequence -import sys -from typing import Callable, Union +from collections.abc import Callable, Sequence +from typing import Union import warnings import numpy as np @@ -36,11 +35,8 @@ from jax._src.typing import Array, ArrayLike -if sys.version_info >= (3, 10): - from types import EllipsisType - SingleIndex = int | slice | Sequence[int] | Array | EllipsisType | None -else: - SingleIndex = Union[int, slice, Sequence[int], Array, None] +from types import EllipsisType +SingleIndex = int | slice | Sequence[int] | Array | EllipsisType | None Index = Union[SingleIndex, tuple[SingleIndex, ...]] Scalar = Union[complex, float, int, np.number] @@ -103,9 +99,6 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) indexer = jnp._index_to_gather(jnp.shape(x), idx, normalize_indices=normalize_indices) - # TODO(jakevdp): implement scalar boolean logic. 
- if indexer.scalar_bool_dims: - raise TypeError("Scalar boolean indices are not allowed in scatter.") # Avoid calling scatter if the slice shape is empty, both as a fast path and # to handle cases like zeros(0)[array([], int32)]. @@ -121,6 +114,9 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, if indexer.reversed_y_dims: y = lax.rev(y, indexer.reversed_y_dims) + if indexer.scalar_bool_dims: + x = lax.expand_dims(x, indexer.scalar_bool_dims) + # Transpose the gather dimensions into scatter dimensions (cf. # lax._gather_transpose_rule) dnums = lax.ScatterDimensionNumbers( @@ -133,10 +129,11 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, indices_are_sorted=indexer.indices_are_sorted or indices_are_sorted, unique_indices=indexer.unique_indices or unique_indices, mode=mode) + if indexer.scalar_bool_dims: + out = lax.squeeze(out, indexer.scalar_bool_dims) return lax_internal._convert_element_type(out, dtype, weak_type) - def _get_identity(op, dtype): """Get an appropriate identity for a given operation in a given dtype.""" if op is lax.scatter_add: diff --git a/jax/_src/ops/special.py b/jax/_src/ops/special.py index 72826621aa0f..59ad594ef2bc 100644 --- a/jax/_src/ops/special.py +++ b/jax/_src/ops/special.py @@ -30,21 +30,21 @@ @overload def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, - keepdims: bool = False, return_sign: Literal[False] = False) -> Array: ... + keepdims: bool = False, return_sign: Literal[False] = False, where: ArrayLike | None = None) -> Array: ... @overload def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, - keepdims: bool = False, *, return_sign: Literal[True]) -> tuple[Array, Array]: ... + keepdims: bool = False, *, return_sign: Literal[True], where: ArrayLike | None = None) -> tuple[Array, Array]: ... @overload def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, - keepdims: bool = False, return_sign: bool = False) -> Array | tuple[Array, Array]: ... + keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]: ... def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, - keepdims: bool = False, return_sign: bool = False) -> Array | tuple[Array, Array]: + keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]: r"""Log-sum-exp reduction. - Computes + JAX implementation of :func:`scipy.special.logsumexp`. .. math:: \mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij}) @@ -63,6 +63,7 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, where ``sign`` is the sign of the sums and ``result`` contains the logarithms of their absolute values. If ``False`` only ``result`` is returned and it will contain NaN values if the sums are negative. + where: Elements to include in the reduction. 
Returns: Either an array ``result`` or a pair of arrays ``(result, sign)``, depending @@ -75,14 +76,14 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, a_arr, = promote_args_inexact("logsumexp", a) b_arr = a_arr # for type checking pos_dims, dims = _reduction_dims(a_arr, axis) - amax = jnp.max(a_arr.real, axis=dims, keepdims=keepdims) + amax = jnp.max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-jnp.inf) amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0))) amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims) exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype))) if b is not None: exp_a = lax.mul(exp_a, b_arr) - sumexp = exp_a.sum(axis=dims, keepdims=keepdims) + sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where) sign = lax.sign(sumexp) if return_sign or not np.issubdtype(a_arr.dtype, np.complexfloating): sumexp = abs(sumexp) diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index 9b11331618aa..c0fa02131bc8 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "py_deps", ) -load("@rules_python//python:defs.bzl", "py_library") package( default_applicable_licenses = [], @@ -27,13 +27,13 @@ package( py_library( name = "pallas", - srcs = glob( - include = ["**/*.py"], - exclude = [ - "triton/*.py", - "mosaic/*.py", - ], - ), + srcs = [ + "__init__.py", + "core.py", + "pallas_call.py", + "primitives.py", + "utils.py", + ], deps = [ "//jax", "//jax:ad_util", @@ -46,21 +46,3 @@ py_library( "//jax/_src/lib", ] + py_deps("numpy"), ) - -py_library( - name = "gpu", - visibility = [], - deps = [ - ":pallas", - "//jax/_src/pallas/triton", - ], -) - -py_library( - name = "tpu", - visibility = [], - deps = [ - ":pallas", - "//jax/_src/pallas/mosaic", - ], -) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 456c5c48a9f6..9143acbd0aa2 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -15,14 +15,16 @@ """Module for pallas-core functionality.""" from __future__ import annotations -import copy -from collections.abc import Sequence +from collections.abc import Callable, Iterator, Sequence import contextlib +import copy import dataclasses import functools -from typing import Any, Callable, Union -from collections.abc import Iterator +import threading +from typing import Any, Union +import warnings +import jax from jax._src import api_util from jax._src import core as jax_core from jax._src import linear_util as lu @@ -33,12 +35,16 @@ from jax._src.state import discharge as state_discharge import jax.numpy as jnp -# TODO(sharadmv): enable type checking -# mypy: ignore-errors + +class DynamicGridDim: + pass +dynamic_grid_dim = DynamicGridDim() + partial = functools.partial -Grid = tuple[Union[int, None], ...] # None indicates that the bound is dynamic. -StaticGrid = tuple[int, ...] # None indicates that the bound is dynamic. +Grid = tuple[Union[int, jax_core.Array], ...] +StaticGrid = tuple[int, ...] +GridMappingGrid = tuple[Union[int, DynamicGridDim], ...] 
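To make the role of the new grid aliases concrete, here is a small sketch of the normalization that ``GridSpec.get_grid_mapping`` performs further down in this file; the example grid is arbitrary, and the snippet assumes the names defined in this module:

    # A grid spec may mark a dynamically-bounded axis with None;
    # GridMappingGrid replaces such entries with the dynamic_grid_dim sentinel.
    user_grid = (8, None, 4)
    grid_mapping_grid = tuple(
        dynamic_grid_dim if d is None else d for d in user_grid
    )
    # grid_mapping_grid == (8, dynamic_grid_dim, 4);
    # GridMapping.num_dynamic_grid_bounds would then report one dynamic axis.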
split_list = util.split_list map, unsafe_map = util.safe_map, map @@ -86,25 +92,60 @@ def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type): jax_core.raise_to_shaped_mappings[AbstractMemoryRef] = _ref_raise_to_shaped +@dataclasses.dataclass(frozen=True) +class PallasGridContext: + grid: GridMappingGrid + mapped_dims: tuple[int, ...] + + def size(self, axis: int) -> int | DynamicGridDim: + valid_grid = tuple( + s for i, s in enumerate(self.grid) if i not in self.mapped_dims + ) + try: + size = valid_grid[axis] + except IndexError as e: + raise ValueError( + f"Axis {axis} is out of bounds for grid {self.grid}" + ) from e + return size + + @dataclasses.dataclass -class GridEnv: - axis_index: Any - axis_size: int +class PallasTracingEnv(threading.local): + grid_context: PallasGridContext | None = None +_pallas_tracing_env = PallasTracingEnv() + -_grid_env_stack: list[tuple[GridEnv, ...]] = [] +def axis_frame() -> PallasGridContext: + # This is like jax_core.axis_frame, except there should only ever be one + # active PallasGridAxisName for a particular main_trace because we cannot + # nest pallas_calls. + env = _pallas_tracing_env + assert env.grid_context is not None + return env.grid_context + + +@dataclasses.dataclass(frozen=True) +class GridAxis: + index: jax.Array + size: int + +# Stores the kernel execution position and the size along grid axes. +GridEnv = Sequence[GridAxis] + +_grid_env_stack: list[GridEnv] = [] @contextlib.contextmanager -def grid_env(env: tuple[tuple[Any, int], ...]) -> Iterator[None]: - _grid_env_stack.append(tuple(GridEnv(axis_index, axis_size) - for axis_index, axis_size in env)) +def grid_env(env: GridEnv) -> Iterator[None]: + _grid_env_stack.append(env) try: yield finally: _grid_env_stack.pop() -def current_grid_env() -> tuple[GridEnv, ...] | None: +def current_grid_env() -> GridEnv | None: if not _grid_env_stack: return None return _grid_env_stack[-1] @@ -129,20 +170,39 @@ class Blocked: IndexingMode = Union[Blocked, Unblocked] -@dataclasses.dataclass(init=False, unsafe_hash=True) +@dataclasses.dataclass(unsafe_hash=True) class BlockSpec: - index_map: Callable[..., Any] | None - block_shape: tuple[int | None, ...] | None - memory_space: Any - indexing_mode: IndexingMode + """Specifies how an array should be sliced for each iteration of a kernel. + + See :ref:`pallas_blockspec` for more details. + """ + block_shape: tuple[int | None, ...] | None = None + index_map: Callable[..., Any] | None = None + memory_space: Any | None = dataclasses.field(kw_only=True, default=None) + indexing_mode: IndexingMode = dataclasses.field(kw_only=True, default=blocked) + + def __init__( + self, + block_shape: Any | None = None, + index_map: Any | None = None, + *, + memory_space: Any | None = None, + indexing_mode: IndexingMode = blocked, + ) -> None: + if callable(block_shape): + # TODO(slebedev): Remove this code path and update the signature of + # __init__ after October 1, 2024. + warnings.warn( + "BlockSpec now expects ``block_shape`` to be passed before" + " ``index_map``. Update your code by swapping the order of these" + " arguments. For example, ``pl.BlockSpace(lambda i: i, (42,))``" + " should be written as ``pl.BlockSpec((42,), lambda i: i)``.", + DeprecationWarning, + ) + index_map, block_shape = block_shape, index_map - def __init__(self, index_map: Callable[..., Any] | None = None, - block_shape: tuple[int | None, ...] 
| None = None, - memory_space: Any = None, indexing_mode: IndexingMode = blocked): - self.index_map = index_map - if block_shape is not None and not isinstance(block_shape, tuple): - block_shape = tuple(block_shape) self.block_shape = block_shape + self.index_map = index_map self.memory_space = memory_space self.indexing_mode = indexing_mode @@ -155,6 +215,10 @@ def compute_index(self, *args): return out +# A PyTree of BlockSpec | NoBlockSpec. +BlockSpecTree = Any + + @dataclasses.dataclass(frozen=True) class BlockMapping: block_shape: tuple[Mapped | int, ...] @@ -182,25 +246,43 @@ def compute_start_indices(self, loop_idx, *args): replace = dataclasses.replace +@contextlib.contextmanager +def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]): + assert all(i is dynamic_grid_dim or isinstance(i, int) for i in grid) + old_grid_context = _pallas_tracing_env.grid_context + try: + _pallas_tracing_env.grid_context = PallasGridContext(grid, mapped_dims) + yield + finally: + _pallas_tracing_env.grid_context = old_grid_context + + @dataclasses.dataclass(frozen=True) class GridMapping: - grid: Grid + grid: GridMappingGrid block_mappings: tuple[BlockMapping | None, ...] - mapped_dims: tuple[int, ...] - num_index_operands: int - num_scratch_operands: int + mapped_dims: tuple[int, ...] = () + num_index_operands: int = 0 + num_scratch_operands: int = 0 + # Number of constants hoisted to operands by ``_hoist_consts_to_refs``. + num_constant_operands: int = 0 replace = dataclasses.replace @property def num_dynamic_grid_bounds(self): - return sum(b is None for b in self.grid) + return sum(b is dynamic_grid_dim for b in self.grid) @property def static_grid(self) -> StaticGrid: if self.num_dynamic_grid_bounds: raise ValueError("Expected a grid with fully static bounds") - return self.grid # typing: ignore + return self.grid # type: ignore + + @contextlib.contextmanager + def trace_env(self): + with tracing_grid_env(self.grid, self.mapped_dims): + yield def _preprocess_grid(grid: Grid | int | None) -> Grid: @@ -212,9 +294,13 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid: def _convert_block_spec_to_block_mapping( - in_avals: list[jax_core.ShapedArray], block_spec: BlockSpec | None, - aval: jax_core.ShapedArray, in_tree: Any, - ) -> BlockSpec | None: + in_avals: Sequence[jax_core.ShapedArray], + block_spec: BlockSpec, + aval: jax_core.ShapedArray, + in_tree: Any, + grid: GridMappingGrid, + mapped_dims: tuple[int, ...], +) -> BlockMapping | None: if block_spec is no_block_spec: return None if block_spec.index_map is None: @@ -226,11 +312,13 @@ def _convert_block_spec_to_block_mapping( block_shape = tuple( mapped if s is None else s for s in block_shape) flat_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), in_tree) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + with tracing_grid_env(grid, mapped_dims): + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) return BlockMapping( block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode ) + def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None ) -> state.AbstractRef: if block_shape is None: @@ -239,6 +327,15 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] 
| None return ref.update(inner_aval=ref.inner_aval.update(shape=shape)) +def _check_static_ref_shape(ref: state.AbstractRef) -> state.AbstractRef: + shape = ref.shape + if not jax_core.is_constant_shape(shape): + # TODO(necula): thread the tree labels so that we can localize the error + raise ValueError("shape polymorphism for Pallas does not support " + f"dynamically-shaped blocks. Found block_shape: {shape}") + return ref + + def _get_ref_avals(grid, in_avals, in_specs, out_avals, out_specs): def _get_memory_space(spec): if spec is no_block_spec: @@ -256,13 +353,13 @@ def _get_memory_space(spec): in_specs = [None] * len(in_avals) out_specs = [None] * len(out_avals) tiled_in_ref_avals = [ - aval if in_spec is no_block_spec - else _tile_ref(aval, in_spec.block_shape) + _check_static_ref_shape(aval if in_spec is no_block_spec + else _tile_ref(aval, in_spec.block_shape)) for aval, in_spec in zip(in_ref_avals, in_specs) ] tiled_out_ref_avals = [ - aval if out_spec is no_block_spec - else _tile_ref(aval, out_spec.block_shape) + _check_static_ref_shape(aval if out_spec is no_block_spec + else _tile_ref(aval, out_spec.block_shape)) for aval, out_spec in zip(out_ref_avals, out_specs) ] return in_specs, tiled_in_ref_avals, out_specs, tiled_out_ref_avals @@ -271,6 +368,7 @@ class NoBlockSpec: pass no_block_spec = NoBlockSpec() + @dataclasses.dataclass(init=False, unsafe_hash=True) class GridSpec: grid: Grid @@ -282,12 +380,8 @@ class GridSpec: def __init__( self, grid: Grid | None = None, - in_specs: BlockSpec - | Sequence[BlockSpec | NoBlockSpec] - | NoBlockSpec = no_block_spec, - out_specs: BlockSpec - | Sequence[BlockSpec | NoBlockSpec] - | NoBlockSpec = no_block_spec, + in_specs: BlockSpecTree = no_block_spec, + out_specs: BlockSpecTree = no_block_spec, ): # Be more lenient for in/out_specs if isinstance(in_specs, list): @@ -331,6 +425,10 @@ def _get_in_out_specs(self, in_avals, in_tree, out_avals, out_tree): def get_grid_mapping( self, in_avals, in_tree, out_avals, out_tree ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]: + assert all(i is None or isinstance(i, int) for i in self.grid) + grid_mapping_grid = tuple( + dynamic_grid_dim if d is None else d for d in self.grid + ) flat_in_specs, flat_out_specs = self._get_in_out_specs( in_avals, in_tree, out_avals, out_tree) in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals( @@ -340,25 +438,45 @@ def get_grid_mapping( # Create args, kwargs pytree def grid_tree = tree_util.tree_structure((tuple(grid_avals), {})) in_block_mappings = map( - partial(_convert_block_spec_to_block_mapping, grid_avals, - in_tree=grid_tree), in_specs, in_ref_avals) + partial( + _convert_block_spec_to_block_mapping, + grid_avals, + in_tree=grid_tree, + grid=grid_mapping_grid, + mapped_dims=(), + ), + in_specs, + in_ref_avals, + ) out_block_mappings = map( - partial(_convert_block_spec_to_block_mapping, grid_avals, - in_tree=grid_tree), out_specs, out_ref_avals) + partial( + _convert_block_spec_to_block_mapping, + grid_avals, + in_tree=grid_tree, + grid=grid_mapping_grid, + mapped_dims=(), + ), + out_specs, + out_ref_avals, + ) grid_mapping = GridMapping( - self.grid, (*in_block_mappings, *out_block_mappings), (), - num_index_operands=0, num_scratch_operands=0) + grid_mapping_grid, (*in_block_mappings, *out_block_mappings) # type: ignore + ) jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals) jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals) if not isinstance(jaxpr_out_avals, (tuple, list)): jaxpr_out_avals = 
(jaxpr_out_avals,) return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping - def unzip_dynamic_grid_bounds(self) -> tuple[GridSpec, tuple[Any, ...]]: - static_grid = tuple(d if isinstance(d, int) else None for d in self.grid) + def unzip_dynamic_grid_bounds( + self, + ) -> tuple[GridSpec, tuple[Any, ...]]: + static_grid = tuple( + d if isinstance(d, int) else None for d in self.grid + ) dynamic_bounds = tuple(d for d in self.grid if not isinstance(d, int)) # We can't use dataclasses.replace, because our fields are incompatible # with __init__'s signature. static_self = copy.copy(self) - static_self.grid = static_grid + static_self.grid = static_grid # type: ignore return static_self, dynamic_bounds diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index deecfdcdc09f..4c849dfba267 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -14,12 +14,8 @@ # Package for Mosaic-specific Pallas extensions -load( - "//jaxlib:jax.bzl", - "py_deps", - "py_library_providing_imports_info", -) load("@rules_python//python:defs.bzl", "py_library") +load("//jaxlib:jax.bzl", "py_deps") package( default_applicable_licenses = [], @@ -28,20 +24,6 @@ package( ], ) -py_library_providing_imports_info( - name = "mosaic", - srcs = ["__init__.py"], - lib_rule = py_library, - deps = [ - ":core", - ":kernel_regeneration_util", - ":lowering", - ":pallas_call_registration", - ":pipeline", - ":primitives", - ], -) - py_library( name = "core", srcs = ["core.py"], @@ -95,14 +77,6 @@ py_library( ] + py_deps("numpy"), ) -py_library( - name = "kernel_regeneration_util", - srcs = ["kernel_regeneration_util.py"], - deps = [ - "//third_party/py/mlir:ir", - ], -) - py_library( name = "pipeline", srcs = ["pipeline.py"], @@ -113,5 +87,15 @@ py_library( "//jax:api_util", "//jax:util", "//jax/_src/pallas", - ], + ] + py_deps("numpy"), +) + +py_library( + name = "random", + srcs = ["random.py"], + deps = [ + ":primitives", + "//jax", + "//jax:typing", + ] + py_deps("numpy"), ) diff --git a/jax/_src/pallas/mosaic/__init__.py b/jax/_src/pallas/mosaic/__init__.py index 7e7a06b2801a..38d13f42da99 100644 --- a/jax/_src/pallas/mosaic/__init__.py +++ b/jax/_src/pallas/mosaic/__init__.py @@ -11,42 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -"""Module for Mosaic lowering of Pallas call.""" - -from jax._src.pallas.mosaic import core -from jax._src.pallas.mosaic import pallas_call_registration -from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec -from jax._src.pallas.mosaic.core import SemaphoreType -from jax._src.pallas.mosaic.core import TPUMemorySpace -from jax._src.pallas.mosaic.core import semaphore -from jax._src.pallas.mosaic.core import dma_semaphore -from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata -from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata -from jax._src.pallas.mosaic.lowering import LoweringException -from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations -from jax._src.pallas.mosaic.pipeline import emit_pipeline -from jax._src.pallas.mosaic.pipeline import PipelineCallbackArgs -from jax._src.pallas.mosaic.pipeline import PipelinePrefetchArgs -from jax._src.pallas.mosaic.pipeline import ManualPrefetchArgs -from jax._src.pallas.mosaic.primitives import DeviceIdType -from jax._src.pallas.mosaic.primitives import async_copy -from jax._src.pallas.mosaic.primitives import async_remote_copy -from jax._src.pallas.mosaic.primitives import bitcast -from jax._src.pallas.mosaic.primitives import device_id -from jax._src.pallas.mosaic.primitives import get_barrier_semaphore -from jax._src.pallas.mosaic.primitives import make_async_copy -from jax._src.pallas.mosaic.primitives import make_async_remote_copy -from jax._src.pallas.mosaic.primitives import repeat -from jax._src.pallas.mosaic.primitives import roll -from jax._src.pallas.mosaic.primitives import run_scoped -from jax._src.pallas.mosaic.primitives import semaphore_signal -from jax._src.pallas.mosaic.primitives import semaphore_wait -from jax._src.pallas.mosaic.primitives import trace - -ANY = TPUMemorySpace.ANY -CMEM = TPUMemorySpace.CMEM -SMEM = TPUMemorySpace.SMEM -VMEM = TPUMemorySpace.VMEM - -del pallas_call_registration diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 200dfbc5296d..f4a794792253 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -28,20 +28,17 @@ import jax.numpy as jnp from jax._src.pallas import core as pallas_core -# TODO(sharadmv): enable type checking -# mypy: ignore-errors - map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip partial = functools.partial Grid = pallas_core.Grid BlockSpec = pallas_core.BlockSpec +BlockSpecTree = pallas_core.BlockSpecTree GridMapping = pallas_core.GridMapping NoBlockSpec = pallas_core.NoBlockSpec AbstractMemoryRef = pallas_core.AbstractMemoryRef no_block_spec = pallas_core.no_block_spec -_preprocess_grid = pallas_core._preprocess_grid _convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping split_list = util.split_list @@ -65,8 +62,14 @@ class semaphore(semaphore_dtype): pass class dma_semaphore(semaphore_dtype): pass class barrier_semaphore(semaphore_dtype): pass +class AbstractSemaphoreTyRules: + @staticmethod + def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), jnp.dtype('int32')) + class AbstractSemaphoreTy(dtypes.ExtendedDType): name: str + _rules = AbstractSemaphoreTyRules def __repr__(self) -> str: return self.name @@ -75,7 +78,7 @@ def __eq__(self, other): return self.__class__ == other.__class__ def __hash__(self) -> int: - return hash((self.__class__)) + return hash(self.__class__) # TODO(sharadmv): implement 
dtype rules for AbstractSemaphoreTy @@ -97,6 +100,7 @@ class SemaphoreType(enum.Enum): BARRIER = "barrier" def __call__(self, shape: tuple[int, ...]): + dtype: Any if self == SemaphoreType.DMA: dtype = DmaSemaphoreTy() elif self == SemaphoreType.BARRIER: @@ -105,7 +109,7 @@ def __call__(self, shape: tuple[int, ...]): dtype = SemaphoreTy() return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) - def get_aval(self) -> "AbstractMemoryRef": + def get_aval(self) -> AbstractMemoryRef: return self(()).get_aval() @dataclasses.dataclass(frozen=True) @@ -157,12 +161,8 @@ def __init__( self, num_scalar_prefetch: int, grid: Grid | None = None, - in_specs: BlockSpec - | Sequence[BlockSpec | NoBlockSpec] - | NoBlockSpec = no_block_spec, - out_specs: BlockSpec - | Sequence[BlockSpec | NoBlockSpec] - | NoBlockSpec = no_block_spec, + in_specs: BlockSpecTree = no_block_spec, + out_specs: BlockSpecTree = no_block_spec, scratch_shapes: Any | Sequence[Any] = () ): super().__init__(grid, in_specs, out_specs) @@ -172,6 +172,10 @@ def __init__( def get_grid_mapping( self, in_avals, in_tree, out_avals, out_tree ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]: + assert all(i is None or isinstance(i, int) for i in self.grid) + grid_mapping_grid = tuple( + pallas_core.dynamic_grid_dim if d is None else d for d in self.grid + ) all_avals = tree_util.tree_unflatten(in_tree, in_avals) flat_scratch_shapes, scratch_tree = tree_util.tree_flatten( self.scratch_shapes) @@ -197,15 +201,29 @@ def get_grid_mapping( ((*grid_avals, *scalar_avals), {}) ) in_block_mappings = map( - partial(_convert_block_spec_to_block_mapping, - (*grid_avals, *scalar_ref_avals), - in_tree=index_map_in_tree), in_specs, in_ref_avals) + partial( + _convert_block_spec_to_block_mapping, + (*grid_avals, *scalar_ref_avals), + in_tree=index_map_in_tree, + grid=grid_mapping_grid, + mapped_dims=(), + ), + in_specs, + in_ref_avals, + ) out_block_mappings = map( - partial(_convert_block_spec_to_block_mapping, - (*grid_avals, *scalar_ref_avals), - in_tree=index_map_in_tree), out_specs, out_ref_avals) + partial( + _convert_block_spec_to_block_mapping, + (*grid_avals, *scalar_ref_avals), + in_tree=index_map_in_tree, + grid=grid_mapping_grid, + mapped_dims=(), + ), + out_specs, + out_ref_avals, + ) grid_mapping = GridMapping( - grid=self.grid, + grid=grid_mapping_grid, # type: ignore block_mappings=(*in_block_mappings, *out_block_mappings), mapped_dims=(), num_index_operands=num_flat_scalar_prefetch, diff --git a/jax/_src/pallas/mosaic/kernel_regeneration_util.py b/jax/_src/pallas/mosaic/kernel_regeneration_util.py deleted file mode 100644 index 236d398c6ccf..000000000000 --- a/jax/_src/pallas/mosaic/kernel_regeneration_util.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Helpers to encode and decode Mosaic kernel regeneration metadata.""" - -import base64 -import json -from typing import Any -from jaxlib.mlir import ir - - -def encode_kernel_regeneration_metadata( - metadata: dict[str, Any] -) -> dict[str, bytes]: - """Serializes the given kernel regeneration metadata. - - This function hides the serialization details from the end user. - - Args: - metadata: dictionary with user-defined data to be serialized in the backend - config. - - Returns: - A dict that can be directly passed to pallas_call as a 'mosaic_params' - argument. - - Raises: - TypeError: when the input metadata is not serializable in json format. - """ - serialized_metadata = bytes(json.dumps(metadata), encoding="utf-8") - return dict(kernel_regeneration_metadata=serialized_metadata) - - -def extract_kernel_regeneration_metadata(op: ir.Operation) -> dict[str, Any]: - """Extract kernel regeneration metadata from the given Operation. - - This function hides the serialization details from the end user. - - Args: - op: the tpu custom_call mlir Operation that contains the kernel metadata. - - Returns: - The decoded metadata in the form of a dict. This corresponds to the dict - in input to the 'encode' function. - """ - kernel_regeneration_metadata = ir.StringAttr( - op.attributes["kernel_regeneration_metadata"] - ).value - return json.loads(base64.b64decode(kernel_regeneration_metadata)) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 27ab8082a8fb..80411f12a838 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -15,21 +15,24 @@ """Module for lowering JAX to Mosaic-compatible MLIR dialects.""" from __future__ import annotations +from collections.abc import Callable, Sequence import dataclasses import functools -from typing import Any, Callable -from collections.abc import Sequence +import string +from typing import Any import jax from jax import core as jax_core from jax import lax from jax import tree_util +from jax._src import ad_util from jax._src import custom_derivatives from jax._src import debugging +from jax._src import dtypes from jax._src import linear_util as lu -from jax._src import ad_util from jax._src import mesh as mesh_lib from jax._src import pjit +from jax._src import prng from jax._src import source_info_util from jax._src import state from jax._src.interpreters import mlir @@ -55,9 +58,9 @@ from jax._src.util import safe_zip from jax._src.util import split_list from jax._src.util import unzip2 -from jax._src.util import unzip3 from jax.experimental.mosaic.dialects import tpu import jax.numpy as jnp +from jaxlib.mlir.ir import Module import numpy as np # TODO(sharadmv): enable type checking @@ -68,17 +71,23 @@ VMEM = tpu_core.TPUMemorySpace.VMEM SMEM = tpu_core.TPUMemorySpace.SMEM -# The value interpreter as a dynamic dimension by MLIR. +# The value interpreted as a dynamic dimension by MLIR. MLIR_DYNAMIC = -9223372036854775808 partial = functools.partial map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin +UNSIGNED_TO_SIGNED = { + np.dtype('uint8'): np.dtype('int8'), + np.dtype('uint16'): np.dtype('int16'), + np.dtype('uint32'): np.dtype('int32'), + np.dtype('uint64'): np.dtype('int64'), +} @dataclasses.dataclass class MeshContext: - logical_to_mesh: ir.Value + mesh_shape: tuple[int, ...] axis_names: tuple[str, ...] mesh_strides: tuple[int, ...] 
@@ -86,7 +95,9 @@ class MeshContext: @dataclasses.dataclass class LoweringContext: ir_context: ir.Context - grid_indices: Sequence[ir.Value] | None + grid_rank: int # Includes both user and vmap axes. + mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions. + user_grid_indices: Sequence[ir.Value] | None block_shapes: list[tuple[int | pl_core.Mapped, ...]] name_stack: source_info_util.NameStack mesh_context: MeshContext | None @@ -120,7 +131,13 @@ def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type: return ir.Type.parse("!tpu.semaphore") else: raise NotImplementedError - return mlir.dtype_to_ir_type(dtype) + # TODO(justinfu): Remove after mosaic supports unsigned types. + # This conversion makes mosaic interpret all unsigned types as signed types. + type = mlir.dtype_to_ir_type(dtype) + if isinstance(type, ir.IntegerType): + return ir.IntegerType.get_signless(type.width) + else: + return type def aval_to_ir_type(aval, shape=None, memory_space: TPUMemorySpace | None = None): if isinstance(aval, tpu_core.AbstractSemaphore): @@ -134,6 +151,15 @@ def aval_to_ir_type(aval, shape=None, memory_space: TPUMemorySpace | None = None raise ValueError(f"Cannot allocate {aval.sem_type}.") memspace = _memory_space_to_tpu_memspace(TPUMemorySpace.SEMAPHORE) return ir.MemRefType.get((), sem_type, memory_space=memspace) + if dtypes.issubdtype(aval.dtype, dtypes.prng_key): + shape = aval.dtype._impl.key_shape + if memory_space is None: + memory_space = TPUMemorySpace.SMEM + if memory_space != TPUMemorySpace.SMEM: + raise ValueError(f"PRNG keys must be stored in SMEM. Got {memory_space}") + memspace = _memory_space_to_tpu_memspace(memory_space) + return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)), + memory_space=memspace) if isinstance(aval, state.AbstractRef): if shape is None: shape = aval.shape @@ -297,20 +323,7 @@ def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None): mesh_strides = pallas_utils.strides_from_shape(tuple( mesh.shape[a] for a in axis_names )) - logical_to_mesh = np.empty((mesh.size, len(axis_names)), dtype=np.int32) - for i, idx in enumerate(np.ndindex(*mesh.device_ids.shape)): - logical_to_mesh[i] = np.array(idx) - self.mesh_info = MeshInfo(logical_to_mesh, axis_names, mesh_strides) - l_to_m_aval = pl_core.AbstractMemoryRef( - jax_core.raise_to_shaped(jax_core.get_aval(logical_to_mesh)), - TPUMemorySpace.SMEM, - ) - # We are now passing in the logical -> mesh index mapping - # TODO(sharadmv,apaszke): avoid stalling pipeline by marking the index - # mapping as scalar prefetch and instead just mark it as an SMEM operand. - self.scalar_prefetch_types = ( - _get_arg_type(l_to_m_aval, None)[0], - *self.scalar_prefetch_types) + self.mesh_info = MeshInfo(mesh.device_ids.shape, axis_names, mesh_strides) def maybe_compress_grid(self): # If we have many leading parallel dimensions, we should "compress" them @@ -323,9 +336,7 @@ def has_communication(self) -> bool: return bool(jax_core.used_axis_names_jaxpr(self.jaxpr)) def get_extra_args(self) -> tuple[Any, ...]: - if self.mesh_info is None: - return () - return (self.mesh_info.logical_to_mesh,) + return () def get_dimension_semantics(self) -> ir.ArrayAttr: @@ -343,7 +354,7 @@ def _get_semantics(s: str | None) -> str: @dataclasses.dataclass class MeshInfo: - logical_to_mesh: np.ndarray + mesh_shape: tuple[int, ...] axis_names: list[str] mesh_strides: tuple[int, ...] @@ -355,7 +366,7 @@ def lower_jaxpr_to_module( jaxpr: jax_core.Jaxpr, dimension_semantics: tuple[str | None, ...] 
| None, mesh: mesh_lib.Mesh | None = None -) -> ir.Module: +) -> tuple[Module, tuple[Any, ...]]: mosaic_grid_mapping = MosaicGridMapping( jaxpr, grid_mapping, dimension_semantics, mesh) mosaic_grid_mapping.maybe_compress_grid() @@ -392,13 +403,20 @@ def lower_jaxpr_to_module( raise NotImplementedError("Index map jaxpr with consts not supported.") # ANY operands don't support windowing and require empty window_params. if aval.memory_space == tpu_core.TPUMemorySpace.ANY: - requires_windowing = bm.block_shape != full_ty.shape - for atom in bm.index_map_jaxpr.jaxpr.outvars: - if requires_windowing: - break - requires_windowing = not ( - isinstance(atom, jax_core.Literal) and atom.val == 0 - ) + # We may not require windowing if our block_shape matches the original + # shape or the dimensions are mapped. + requires_windowing = any( + b != s + for b, s in zip(bm.block_shape, full_ty.shape) + if not (b is pl_core.mapped and s == 1) + ) + if np.prod(grid) != 1: + for atom in bm.index_map_jaxpr.jaxpr.outvars: + if requires_windowing: + break + requires_windowing = not ( + isinstance(atom, jax_core.Literal) and atom.val == 0 + ) if requires_windowing: raise NotImplementedError( "Operands in placed in the TPUMemorySpace.ANY memory space don't" @@ -433,7 +451,9 @@ def lower_jaxpr_to_module( m.body.append(mlir_func) sym_tab.insert(mlir_func) func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params) - static_grid = [MLIR_DYNAMIC if b is None else b for b in grid] + static_grid = [ + MLIR_DYNAMIC if b is pl_core.dynamic_grid_dim else b for b in grid + ] func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid) func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get( @@ -468,13 +488,15 @@ def body_func(*args): mesh_info = mosaic_grid_mapping.mesh_info if mesh_info is not None: - (l_to_m,), scalar_prefetch = split_list(scalar_prefetch, [1]) - mesh_context = MeshContext(l_to_m, mesh_info.axis_names, - mesh_info.mesh_strides) + mesh_context = MeshContext( + mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides + ) else: mesh_context = None lowering_context = LoweringContext( ctx, + len(mosaic_grid_mapping.grid), + mosaic_grid_mapping.mapped_dims, None, arg_block_shapes, source_info_util.NameStack(), @@ -524,13 +546,15 @@ def body_func(*args): if i not in mosaic_grid_mapping.mapped_dims) mesh_info = mosaic_grid_mapping.mesh_info if mesh_info is not None: - (l_to_m,), scalar_prefetch = split_list(scalar_prefetch, [1]) - mesh_context = MeshContext(l_to_m, mesh_info.axis_names, - mesh_info.mesh_strides) + mesh_context = MeshContext( + mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides + ) else: mesh_context = None lowering_context = LoweringContext( ctx, + len(mosaic_grid_mapping.grid), + mosaic_grid_mapping.mapped_dims, jaxpr_indices, arg_block_shapes, source_info_util.NameStack(), @@ -575,6 +599,29 @@ class LoweringException(Exception): pass +def _compute_name_stack_updates( + old_name_stack: list[str], + new_name_stack: list[str] +) -> tuple[list[str], list[str]]: + """Computes the popped/pushed items to the name stack after an update. + + Args: + old_name_stack: The name stack prior to the update. + new_name_stack: The name stack after the update. + + Returns: + popped: A list of names popped from the name stack as part of the update. + pushed: A list of names pushed to the name stack as part of the update. 
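+
+    For example, updating ["outer", "a"] to ["outer", "b"] returns
+    popped=["a"] and pushed=["b"]; the shared prefix ["outer"] is untouched.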
+ """ + common_prefix_idx = 0 + for i, (old, new) in enumerate(unsafe_zip(old_name_stack, new_name_stack)): + if old == new: + common_prefix_idx = i+1 + else: + break + return old_name_stack[common_prefix_idx:], new_name_stack[common_prefix_idx:] + + def jaxpr_subcomp( ctx: LoweringContext, jaxpr: jax_core.Jaxpr, *args: ir.Value ) -> Sequence[ir.Value]: @@ -591,13 +638,18 @@ def read_env(atom: jax_core.Atom): return atom.val if isinstance(atom, jax_core.Literal) else env[atom] def write_env(var: jax_core.Var, val): - assert isinstance(val, ir.Value), type(val) + is_valid_type = isinstance(val, (ir.Value, KeyScalarBundle)) + assert is_valid_type, type(val) env[var] = val for invar, bs in zip(jaxpr.invars, ctx.block_shapes): block_shape_env[invar] = bs map(write_env, jaxpr.invars, args) + initial_name_stack = [scope.name for scope in ctx.name_stack.stack] + current_name_stack: list[str] = [] + # TODO(justinfu): Handle transform scopes. + current_name_stack.extend(initial_name_stack) for eqn in jaxpr.eqns: invals = map(read_env, eqn.invars) source_info = eqn.source_info.replace( @@ -618,6 +670,17 @@ def write_env(var: jax_core.Var, val): [v.aval for v in eqn.outvars], block_shapes, ) + + # Insert trace_start and trace_stop ops on named_scope boundaries. + name_stack = [scope.name for scope in source_info.name_stack.stack] + popped, pushed = _compute_name_stack_updates( + current_name_stack, name_stack) + current_name_stack = name_stack + for _ in popped: + tpu.TraceStopOp() + for name in pushed: + tpu.TraceStartOp(message=name, level=10) + try: ans = lowering_rules[eqn.primitive]( rule_context, *invals, **eqn.params @@ -632,6 +695,7 @@ def write_env(var: jax_core.Var, val): " inval" f" types={map(lambda t: getattr(t, 'type', None), invals)}\nIn" f" jaxpr:\n{jaxpr}" + f"\nException: {e}" ) from e else: raise NotImplementedError( @@ -642,6 +706,14 @@ def write_env(var: jax_core.Var, val): map(write_env, eqn.outvars, ans) else: write_env(eqn.outvars[0], ans) + + # Drain the name stack at the end of a jaxpr and insert trace_stop ops. 
+ popped, pushed = _compute_name_stack_updates( + current_name_stack, initial_name_stack) + for _ in popped: + tpu.TraceStopOp() + assert len(pushed) == 0 + outvals = map(read_env, jaxpr.outvars) outvals = [ ir_constant(x) if isinstance(var, jax_core.Literal) else x @@ -653,6 +725,8 @@ def write_env(var: jax_core.Var, val): def _ensure_mlir_value(val, aval): if isinstance(val, ir.Value): return val + if isinstance(val, KeyScalarBundle): + return val elif isinstance(val, (np.generic, np.ndarray, int, float)): return ir_constant(val, _dtype_to_ir_type(aval.dtype)) else: @@ -740,80 +814,129 @@ def _maybe_cast_to_index(cast_to_index, x): return _make_index(x) return _ensure_mlir_value(x, aval=jax_core.ShapedArray((), jnp.int32)) -def _index_to_start_size(idx: tuple[indexing.Slice | int | ir.Value, ...], - cast_to_index: bool) -> tuple[ir.Value, int, bool]: + +def _index_to_start_size_stride( + idx: tuple[indexing.Slice | int | ir.Value, ...], cast_to_index: bool +) -> tuple[ir.Value, int | ir.Value, int, bool]: assert not isinstance(idx, slice) if isinstance(idx, indexing.Slice): start = _maybe_cast_to_index(cast_to_index, idx.start) size = idx.size + stride = idx.stride squeeze = False elif isinstance(idx, int): start = _maybe_cast_to_index(cast_to_index, idx) size = 1 + stride = 1 squeeze = True else: if np.shape(idx): raise ValueError(f"Can only use ()-shaped and slice indexing: {idx}") start = _maybe_cast_to_index(cast_to_index, idx) size = 1 + stride = 1 squeeze = True - return start, size, squeeze + return start, size, stride, squeeze -def _indexer_to_start_size( - indexer: NDIndexer, ref_block_shape: tuple[int | pl_core.Mapped, ...], *, +def _indexer_to_start_size_stride( + indexer: NDIndexer, + ref_block_shape: tuple[int | pl_core.Mapped, ...], + *, cast_to_index: bool, -) -> tuple[tuple[ir.Value, ...], tuple[int, ...], tuple[bool, ...], - tuple[int | pl_core.Mapped, ...]]: +) -> tuple[ + tuple[ir.Value, ...], + tuple[int | ir.Value, ...], + tuple[int, ...], + tuple[bool, ...], + tuple[int | pl_core.Mapped, ...], +]: indices_iter = iter(indexer.indices) - starts, sizes, squeeze_dims = unzip3( - ( - _maybe_cast_to_index(cast_to_index, 0), - 1, - True, - ) - if s is pl_core.mapped - else _index_to_start_size(next(indices_iter), cast_to_index) - for s in ref_block_shape - ) + starts, sizes, strides, squeeze_dims = [], [], [], [] + for s in ref_block_shape: + start, size, stride, squeeze_dim = ( + ( + _maybe_cast_to_index(cast_to_index, 0), + 1, + 1, + True, + ) + if s is pl_core.mapped + else _index_to_start_size_stride(next(indices_iter), cast_to_index) + ) + starts.append(start) + sizes.append(size) + strides.append(stride) + squeeze_dims.append(squeeze_dim) next_index = next(indices_iter, None) assert next_index is None, (indexer.indices, ref_block_shape) new_ref_block_shape = tuple(s for s, squeeze in zip(sizes, squeeze_dims) if not squeeze) - return tuple(starts), tuple(sizes), tuple(squeeze_dims), new_ref_block_shape + return ( + tuple(starts), + tuple(sizes), + tuple(strides), + tuple(squeeze_dims), + new_ref_block_shape, + ) def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef, indexer: NDIndexer, ref_block_shape: tuple[int | pl_core.Mapped, ...] 
- ) -> tuple[ir.Value, state.AbstractRef, tuple[int | pl_core.Mapped, ...], + ) -> tuple[ir.Value, tuple[int | pl_core.Mapped, ...], tuple[int | pl_core.Mapped, ...]]: assert ref_block_shape is not None target_shape = indexer.get_indexer_shape() - starts, sizes, squeeze_dims, ref_block_shape = _indexer_to_start_size( - indexer, ref_block_shape, cast_to_index=False, + starts, sizes, strides, squeeze_dims, ref_block_shape = ( + _indexer_to_start_size_stride( + indexer, + ref_block_shape, + cast_to_index=False, + ) ) + if not all((s is None or s == 1) for s in strides): + raise NotImplementedError("Strided slices of references are unsupported.") + dynamic_sizes = tuple(s for s in sizes if isinstance(s, ir.Value)) + ir_dynamic_size = ir.ShapedType.get_dynamic_size() + static_sizes = tuple(s if not isinstance(s, ir.Value) + else ir_dynamic_size for s in sizes) target_ref_ty = ir.MemRefType.get( - tuple(sizes), _dtype_to_ir_type(ref_aval.dtype), + static_sizes, _dtype_to_ir_type(ref_aval.dtype), memory_space=ref.type.memory_space) - inner_aval = ref_aval.inner_aval - out_aval = ref_aval.update(inner_aval=inner_aval.update(shape=target_shape)) - out = tpu.MemRefSliceOp(target_ref_ty, ref, starts).result + out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, dynamic_sizes).result if any(squeeze_dims): # We need to squeeze out some dimensions + static_sizes = tuple(s if not isinstance(s, ir.Value) + else ir_dynamic_size for s in target_shape) squeezed_ref_ty = ir.MemRefType.get( - tuple(target_shape), _dtype_to_ir_type(ref_aval.dtype), + static_sizes, _dtype_to_ir_type(ref_aval.dtype), memory_space=ref.type.memory_space) out = tpu.MemRefSqueezeOp(squeezed_ref_ty, out).result - return out, out_aval, ref_block_shape + return out, ref_block_shape def _index_ref(ref, ref_aval, ref_block_shape, indexers): for indexer in indexers: - ref, ref_aval, ref_block_shape = _slice_memref(ref, ref_aval, indexer, - ref_block_shape) - return ref, ref_aval, ref_block_shape - + ref, ref_block_shape = _slice_memref(ref, ref_aval, indexer, + ref_block_shape) + return ref, ref_block_shape + +@dataclasses.dataclass(frozen=True) +class KeyScalarBundle: + """A container class for PRNG key data. + + We pass around keys as a KeyScalarBundle in the lowering pass rather than + as a vector, since we want the key data to live in scalar registers rather + than vector registers. This special dataclass exists so we can return + multiple scalar values from load_op, because the load_op primitive does + not allow multiple results. + + Attributes: + scalars: A list of OpResults representing scalar key data during the + lowering pass. + """ + scalars: list[ir.OpResult] def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): ref, indexers, mask, _ = args_tree.unflatten(args_flat) @@ -826,12 +949,18 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): raise NotImplementedError ref_block_shape, *_ = ctx.block_shapes - ref, _, ref_block_shape = _index_ref( + ref, ref_block_shape = _index_ref( ref, ref_aval, ref_block_shape, slice_indexers) ref_type = ir.MemRefType(ref.type) is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space" ref_aval, *_ = ctx.avals_in (aval_out,) = ctx.avals_out + if isinstance(aval_out.dtype, prng.KeyTy): + if not is_smem_load: + raise ValueError("PRNG keys must be loaded from SMEM. 
Did you set " + "the memory space to TPUMemorySpace.SMEM in the " + "BlockSpec for the PRNG key input?") + return _prng_key_load_lowering_rule(ctx, *args_flat, args_tree=args_tree) if not is_smem_load and not ref_block_shape: raise NotImplementedError( "Indexing into a ()-shaped Ref not yet supported on TPU.") @@ -840,14 +969,21 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): for a in idx_aval.indices ): raise ValueError("Cannot do int indexing on TPU") - starts, sizes, _, _ = _indexer_to_start_size( - idx, ref_block_shape, cast_to_index=True, + starts, sizes, strides, _, _ = _indexer_to_start_size_stride( + idx, + ref_block_shape, + cast_to_index=True, ) + need_stride = not all((s is None or s == 1) for s in strides) load_aval = jax_core.ShapedArray(sizes, dtype=ref_aval.dtype) if is_smem_load: if ctx.avals_out[0].shape: raise ValueError("Can only load scalars from SMEM") return memref.LoadOp(ref, starts).result + if need_stride: + load_val = tpu.StridedLoadOp( + aval_to_ir_type(load_aval), ref, starts, strides + ).result else: load_val = vector.LoadOp(aval_to_ir_type(load_aval), ref, starts).result if load_aval == aval_out: @@ -856,6 +992,37 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): _dtype_to_ir_type(aval_out.dtype)) return vector.ShapeCastOp(vec_type, load_val).result +def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle: + """Lowering rule for loading PRNG keys from SMEM. + + PRNG key loads are currently lowered as a list of scalar loads from SMEM, + rather than a single vector load. + We store these scalars in a bundle type called KeyScalarBundle, which has + special case handling for functions that consume the key such as set_seed. + """ + ref, _, _, _ = args_tree.unflatten(args_flat) + (aval_out,) = ctx.avals_out + assert isinstance(aval_out.dtype, prng.KeyTy) + ref_block_shape = aval_out.dtype._impl.key_shape + + if len(ref_block_shape) != 2: + raise NotImplementedError("Seed key_data must be 2D.") + if tuple(ref_block_shape) != (1, 1): + raise NotImplementedError( + f"Seed key_data of shape != (1, 1) not supported. 
Got: {ref_block_shape}") + + load_ops = [] + for i in range(ref_block_shape[0]): + idx = NDIndexer(indices=(0, i), shape=ref_block_shape, + int_indexer_shape=tuple()) + starts, _, _, _, _ = _indexer_to_start_size_stride( + idx, + ref_block_shape, + cast_to_index=True, + ) + load_ops.append(memref.LoadOp(ref, starts).result) + return KeyScalarBundle(scalars=load_ops) + lowering_rules[primitives.load_p] = _load_lowering_rule skip_mlir_conversions.add(primitives.load_p) @@ -873,7 +1040,7 @@ def _masked_swap_lowering_rule( raise NotImplementedError ref_block_shape, *_ = ctx.block_shapes - ref, _, ref_block_shape = _index_ref( + ref, ref_block_shape = _index_ref( ref, ref_aval, ref_block_shape, slice_indexers) ref_type = ir.MemRefType(ref.type) @@ -890,10 +1057,12 @@ def _masked_swap_lowering_rule( raise NotImplementedError( "Indexing into a ()-shaped Ref not yet supported on TPU.") - starts, _, _, _ = _indexer_to_start_size( - idx, ref_block_shape, cast_to_index=True, + starts, _, strides, _, _ = _indexer_to_start_size_stride( + idx, + ref_block_shape, + cast_to_index=True, ) - + need_stride = not all((s is None or s == 1) for s in strides) if is_smem_store: if val_aval.shape: raise ValueError("Can only store scalars to SMEM") @@ -912,7 +1081,10 @@ def _masked_swap_lowering_rule( mem_aval = aval_out.update(shape=tuple(mem_slice_shape)) mem_aval_vec_type = ir.VectorType.get(mem_aval.shape, _dtype_to_ir_type(mem_aval.dtype)) - result = vector.LoadOp(mem_aval_vec_type, ref, starts).result + if need_stride: + result = tpu.StridedLoadOp(mem_aval_vec_type, ref, starts, strides).result + else: + result = vector.LoadOp(mem_aval_vec_type, ref, starts).result if mem_aval != aval_out: # We are slicing a scalar so provided dummy 1 indices result_vec_type = ir.VectorType.get(aval_out.shape, @@ -921,7 +1093,10 @@ def _masked_swap_lowering_rule( val_vec_type = ir.VectorType.get(mem_aval.shape, _dtype_to_ir_type(mem_aval.dtype)) val = vector.ShapeCastOp(val_vec_type, val).result - vector.StoreOp(val, ref, starts) + if need_stride: + tpu.StridedStoreOp(val, ref, starts, strides) + else: + vector.StoreOp(val, ref, starts) return result @@ -941,13 +1116,15 @@ def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values): def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): (x_aval,) = ctx.avals_in + if not ctx.avals_out[0].shape: + raise NotImplementedError( + "Cannot lower reductions to scalar. Reduce to one element vector" + " instead, using keepdims=True." + ) + out_type = aval_to_ir_type(ctx.avals_out[0]) if jnp.issubdtype(x_aval.dtype, jnp.floating): - # TODO(apaszke): Remove in 03/2024. - if hasattr(vector.CombiningKind, "MAXIMUMF"): - kind = vector.CombiningKind.MAXIMUMF - else: - kind = vector.CombiningKind.MAXF + kind = vector.CombiningKind.MAXIMUMF val = ir.FloatAttr.get(ir.F32Type.get(), float("-inf")) identity = ir.DenseElementsAttr.get_splat(out_type, val) elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger): @@ -973,6 +1150,12 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): (x_aval,) = ctx.avals_in + if not ctx.avals_out[0].shape: + raise NotImplementedError( + "Cannot lower reductions to scalar. Reduce to one element vector" + " instead, using keepdims=True." 
+ ) + out_type = aval_to_ir_type(ctx.avals_out[0]) if jnp.issubdtype(x_aval.dtype, jnp.floating): kind = ir.Attribute.parse("#vector.kind") @@ -1031,14 +1214,24 @@ def _dot_general_lowering_rule( (aval_out,) = ctx.avals_out out_type = aval_to_ir_type(aval_out) val_type = out_type.element_type - if any(cls.isinstance(val_type) for cls in [ir.BF16Type, ir.F32Type]): + if any( + cls.isinstance(val_type) + for cls in [ + ir.BF16Type, + ir.F32Type, + ir.Float8E5M2Type, + ir.Float8E4M3FNType, + ] + ): val = ir.FloatAttr.get(val_type, 0.0) elif ir.IntegerType.isinstance(val_type): val = ir.IntegerAttr.get(val_type, 0) else: raise NotImplementedError(ctx.avals_out[0].dtype) if any(len(a.shape) != 2 for a in ctx.avals_in): - raise NotImplementedError(ctx.avals_in) + raise NotImplementedError( + f"Only 2D tensors supported in dot; received: {ctx.avals_in}" + ) lhs_aval, _ = ctx.avals_in # This is really a matrix-vector product. It only looks like matrix-matrix. if lhs_dims == (1,) and rhs_dims == (1,) and ctx.avals_in[1].shape[0] == 1: @@ -1110,14 +1303,14 @@ def _convert_helper(x, *, to_dtype): if jnp.issubdtype(from_dtype, jnp.dtype("bool")): x = x.astype(jnp.int32) return _convert_helper(x, to_dtype=to_dtype) - if jnp.issubdtype(from_dtype, jnp.integer): + if jnp.issubdtype(from_dtype, jnp.signedinteger): if from_dtype.itemsize < 4: x = x.astype(jnp.int32) if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4: x = x.astype(jnp.float32) return x.astype(to_dtype) if jnp.issubdtype(from_dtype, jnp.floating): - if jnp.issubdtype(to_dtype, jnp.integer): + if jnp.issubdtype(to_dtype, jnp.signedinteger): if from_dtype.itemsize < 4: x = x.astype(jnp.float32) if to_dtype.itemsize < 4: @@ -1138,6 +1331,11 @@ def _convert_element_type_lowering_rule( out_aval = ctx.avals_out[0] old_dtype = ctx.avals_in[0].dtype out_type = aval_to_ir_type(out_aval) + + # TODO(justinfu): Remove after mosaic supports unsigned types. + # This conversion makes mosaic interpret all unsigned types as signed types. 
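+  # For example, a convert_element_type to jnp.uint32 is lowered here as a
+  # convert to jnp.int32 of the same width, via UNSIGNED_TO_SIGNED above.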
+ if np.issubdtype(new_dtype, jnp.unsignedinteger): + new_dtype = UNSIGNED_TO_SIGNED[new_dtype] if old_dtype == new_dtype: return x if jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype( @@ -1234,24 +1432,30 @@ def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation): def _bcast(x, y, x_aval, y_aval, out_aval): + x_dtype = x_aval.dtype + y_dtype = y_aval.dtype + if y_aval.weak_type: + y_dtype = x_aval.dtype + elif x_aval.weak_type: + x_dtype = y_aval.dtype if isinstance(x, (np.ndarray, np.number, int, float)): - if hasattr(y, "type") and y.type == ir.IndexType.get(): + if getattr(y, "type", None) == ir.IndexType.get(): mlir_type = y.type else: - mlir_type = _dtype_to_ir_type(x_aval.dtype) + mlir_type = _dtype_to_ir_type(x_dtype) x = ir_constant(x, mlir_type) if isinstance(y, (np.ndarray, np.number, int, float)): - if hasattr(x, "type") and x.type == ir.IndexType.get(): + if getattr(x, "type", None) == ir.IndexType.get(): mlir_type = x.type else: - mlir_type = _dtype_to_ir_type(y_aval.dtype) + mlir_type = _dtype_to_ir_type(y_dtype) y = ir_constant(y, mlir_type) out_shape = list(out_aval.shape) if x_aval.shape != out_aval.shape: - x_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(x_aval.dtype)) + x_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(x_dtype)) x = vector.BroadcastOp(x_ty, x) if y_aval.shape != out_aval.shape: - y_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(y_aval.dtype)) + y_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(y_dtype)) y = vector.BroadcastOp(y_ty, y) return x, y @@ -1484,6 +1688,20 @@ def _log1p_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.log1p_p] = _log1p_lowering_rule +def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method): + if rounding_method == 0: + return math.RoundOp(x).result + elif rounding_method == 1: + return math.RoundEvenOp(x).result + else: + raise NotImplementedError(f"Unsupported rounding method: {rounding_method}") + + +lowering_rules[lax.round_p] = _round_lowering_rule + + +# See https://mlir.llvm.org/docs/Dialects/ArithOps/#arithcmpi-arithcmpiop for +# the mapping from comparison type to integer predicates for int comparisons. _cmpi_lowering_types = { lax.eq_p: 0, lax.ne_p: 1, @@ -1493,10 +1711,15 @@ def _log1p_lowering_rule(ctx: LoweringRuleContext, x): lax.ge_p: 5, } +# See https://mlir.llvm.org/docs/Dialects/ArithOps/#arithcmpf-arithcmpfop for +# the mapping from comparison type to integer predicate for float comparisons. 
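+# (In that enumeration the ordered predicates are 1=oeq, 2=ogt, 3=oge, 4=olt,
+# 5=ole and 6=one, which is what the table below encodes.)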
_cmpf_lowering_types = { lax.eq_p: 1, - lax.gt_p: 2, lax.ne_p: 6, + lax.lt_p: 4, + lax.le_p: 5, + lax.gt_p: 2, + lax.ge_p: 3, } @@ -1523,18 +1746,22 @@ def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y): lowering_rules[lax.ge_p] = functools.partial(_cmp_lowering_rule, lax.ge_p) -def _and_lowering_rule(ctx: LoweringRuleContext, lhs, rhs): - return arith.AndIOp(lhs, rhs).result +def _and_lowering_rule(ctx: LoweringRuleContext, x, y): + x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) + return arith.AndIOp(x, y).result lowering_rules[lax.and_p] = _and_lowering_rule +skip_mlir_conversions.add(lax.and_p) -def _or_lowering_rule(ctx: LoweringRuleContext, lhs, rhs): - return arith.OrIOp(lhs, rhs).result +def _or_lowering_rule(ctx: LoweringRuleContext, x, y): + x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) + return arith.OrIOp(x, y).result lowering_rules[lax.or_p] = _or_lowering_rule +skip_mlir_conversions.add(lax.or_p) def _not_lowering_rule(ctx: LoweringRuleContext, x): @@ -1631,8 +1858,8 @@ def _for_lowering_rule( def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext, - jaxpr: jax_core.Jaxpr, start: int, - num_steps: int, consts, *args, + jaxpr: jax_core.Jaxpr, start: int | ir.Value, + num_steps: int | ir.Value, consts, *args, has_loop_index: bool, unroll: int): def _run_body(i, args): @@ -1649,7 +1876,11 @@ def _run_body(i, args): args = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args) return args - if num_steps == unroll: + if ( + not isinstance(start, ir.Value) + and not isinstance(num_steps, ir.Value) + and num_steps == unroll + ): # No need for an scf.For. We can just unroll completely for i in range(start, start + num_steps): args = _run_body( @@ -1660,13 +1891,9 @@ def _run_body(i, args): if unroll != 1: raise NotImplementedError( f"Only unroll={num_steps=} and unroll=1 supported. 
Got {unroll=}.") - lbd = ir_constant(0, mlir_type=mlir.dtype_to_ir_type(jnp.dtype("int32"))) - if isinstance(num_steps, int): - ubd = ir_constant( - num_steps, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")) - ) - else: - ubd = num_steps + i32 = jax_core.ShapedArray((), jnp.int32) + lbd = _ensure_mlir_value(start, i32) + ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, i32)) step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))) for_op = scf.ForOp(lbd, ubd, step, args) with ir.InsertionPoint(for_op.body): @@ -1705,10 +1932,12 @@ def _scan_lowering_rule( linear: tuple[bool, ...], length: int, reverse: bool, - unroll: bool, + unroll: bool | int, num_consts: int, num_carry: int, + _split_transpose: bool, ): + del _split_transpose # Can only handle fori_loop-like scans num_extensive = len(args) - num_consts - num_carry if num_extensive: raise NotImplementedError @@ -1719,9 +1948,9 @@ def _scan_lowering_rule( if jaxpr_consts: raise NotImplementedError del jaxpr_consts - jaxpr, has_loop_index = ( - pallas_utils.pattern_match_scan_to_fori_loop(jaxpr, num_consts, num_carry) - ) + jaxpr, has_loop_index = pallas_utils.pattern_match_scan_to_fori_loop( + jaxpr, num_consts, num_carry + ) consts, args = split_list(args, [num_consts]) consts_avals, args_avals = split_list(ctx.avals_in, [num_consts]) if has_loop_index: @@ -1744,17 +1973,15 @@ def _scan_lowering_rule( skip_mlir_conversions.add(lax.scan_p) -def _while_lowering_rule( +def _lower_while_via_fori( ctx: LoweringRuleContext, *args, + fori_jaxpr, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr, ): - jaxpr = pallas_utils.pattern_match_while_to_fori_loop( - cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts - ) _, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts]) (lb, ub), args = carry[:2], carry[2:] for_out = _lower_jaxpr_to_for_loop( @@ -1762,9 +1989,9 @@ def _while_lowering_rule( block_shapes=ctx.block_shapes[: body_nconsts + 1] + ctx.block_shapes[body_nconsts + 2 :], ), - jaxpr, + fori_jaxpr, lb, - ub, + arith.subi(ub, lb), body_consts, *args, has_loop_index=True, @@ -1773,9 +2000,87 @@ def _while_lowering_rule( return [ub, ub, *for_out] +def _while_lowering_rule( + ctx: LoweringRuleContext, + *args, + cond_nconsts, + cond_jaxpr, + body_nconsts, + body_jaxpr, +): + # First try to lower via a simpler fori loop, which may optimize better. + fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop( + cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts + ) + if fori_jaxpr is not None: + return _lower_while_via_fori( + ctx, + *args, + fori_jaxpr=fori_jaxpr, + cond_nconsts=cond_nconsts, + cond_jaxpr=cond_jaxpr, + body_nconsts=body_nconsts, + body_jaxpr=body_jaxpr, + ) + + # If we fail conversion to fori, fallback to an ordinary while loop. 
+ cond_consts, body_consts, carry = split_list( + args, [cond_nconsts, body_nconsts] + ) + cond_const_block_shapes, body_const_block_shapes, carry_block_shapes = ( + split_list(ctx.block_shapes, [cond_nconsts, body_nconsts]) + ) + cond_const_types = [a.type for a in cond_consts] + body_const_types = [a.type for a in body_consts] + carry_types = [a.type for a in carry] + all_types = [*cond_const_types, *body_const_types, *carry_types] + while_op = scf.WhileOp(all_types, args) + + before_block = while_op.before.blocks.append(*all_types) + cond_consts_, _, carry_ = split_list( + before_block.arguments, + [cond_nconsts, body_nconsts], + ) + cond_args = [*cond_consts_, *carry_] + with ir.InsertionPoint.at_block_begin(before_block): + [cond] = jaxpr_subcomp( + ctx.lowering_context.replace( + block_shapes=[*cond_const_block_shapes, *carry_block_shapes] + ), + cond_jaxpr.jaxpr, + *cond_args, + ) + scf.condition(cond, before_block.arguments) + + after_block = while_op.after.blocks.append(*all_types) + cond_consts_, body_consts_, carry_ = split_list( + after_block.arguments, + [cond_nconsts, body_nconsts], + ) + all_args = [*cond_consts_, *body_consts_, *carry_] + cond_const_args, body_const_args, carry_args = split_list( + all_args, [cond_nconsts, body_nconsts] + ) + with ir.InsertionPoint.at_block_begin(after_block): + loop_out = jaxpr_subcomp( + ctx.lowering_context.replace( + block_shapes=[*body_const_block_shapes, *carry_block_shapes], + ), + body_jaxpr.jaxpr, + *body_const_args, + *carry_args, + ) + all_handles = [*cond_const_args, *body_const_args, *loop_out] + if all_handles: + scf.yield_(all_handles) + + all_out = list(while_op.results_) + return all_out[cond_nconsts + body_nconsts :] + + lowering_rules[lax.while_p] = _while_lowering_rule -def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, linear): +def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches): index, *args = args out_types = map(aval_to_ir_type, ctx.avals_out) pred = arith.CmpIOp( @@ -1794,7 +2099,6 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, linear): arith.SubIOp(index, ir_constant(1, index.type)).result, *args, branches=branches[1:], - linear=linear, ) else: out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args) @@ -1846,22 +2150,32 @@ def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int): - if ctx.lowering_context.grid_indices is None: + if ctx.lowering_context.user_grid_indices is None: raise ValueError( f"program id: {axis} was passed, but user did not provide a grid." 
) - length = len(ctx.lowering_context.grid_indices) + length = len(ctx.lowering_context.user_grid_indices) if not (0 <= axis < length): raise ValueError( f"user passed in program id with axis: {axis}, but grid only has" f" length: {length}" ) - return ctx.lowering_context.grid_indices[axis] + return ctx.lowering_context.user_grid_indices[axis] lowering_rules[primitives.program_id_p] = _program_id_lowering_rule def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int): - del ctx - return tpu.iteration_bound(axis) + mapped_axes = set(ctx.lowering_context.mapped_dims) + seen_user_axes = 0 + for i in range(ctx.lowering_context.grid_rank): + seen_user_axes += int(i not in mapped_axes) + if seen_user_axes == axis + 1: + break + else: + raise ValueError( + f"user passed in program id with axis: {axis}, but grid only has" + f" length: {len(ctx.lowering_context.grid_rank)}" + ) + return tpu.iteration_bound(i) lowering_rules[primitives.num_programs_p] = _num_programs_lowering_rule @@ -1874,9 +2188,11 @@ def _repeat_lowering_rule(ctx: LoweringRuleContext, x, *, repeats, axis): def _roll_lowering_rule( - ctx: LoweringRuleContext, x, *, shift, axis, stride, stride_axis + ctx: LoweringRuleContext, x, shift, *, axis, stride, stride_axis ): - return tpu.RotateOp( + (out_aval,) = ctx.avals_out + return tpu.DynamicRotateOp( + aval_to_ir_type(out_aval), x, shift, axis, @@ -1906,24 +2222,30 @@ def _slice_lowering_rule( def _xor_lowering_rule(ctx: LoweringRuleContext, x, y): + x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return arith.XOrIOp(x, y).result lowering_rules[lax.xor_p] = _xor_lowering_rule +skip_mlir_conversions.add(lax.xor_p) def _shift_left_lowering_rule(ctx: LoweringRuleContext, x, d): + x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) return arith.ShLIOp(x, d).result lowering_rules[lax.shift_left_p] = _shift_left_lowering_rule +skip_mlir_conversions.add(lax.shift_left_p) def _shift_right_logical_lowering_rules(ctx: LoweringRuleContext, x, d): + x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) return arith.ShRUIOp(x, d).result lowering_rules[lax.shift_right_logical_p] = _shift_right_logical_lowering_rules +skip_mlir_conversions.add(lax.shift_right_logical_p) def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty): @@ -1931,24 +2253,16 @@ def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty): (out_aval,) = ctx.avals_out return tpu.BitcastOp(aval_to_ir_type(out_aval), x).result - lowering_rules[tpu_primitives.bitcast_p] = _bitcast_lowering_rule -def _trace_start_lowering_rule( - ctx: LoweringRuleContext, *, message: str, level: int -): - return tpu.TraceStartOp(message=message, level=level).results - - -lowering_rules[tpu_primitives.trace_start_p] = _trace_start_lowering_rule - - -def _trace_stop_lowering_rule(ctx: LoweringRuleContext): - return tpu.TraceStopOp().results - - -lowering_rules[tpu_primitives.trace_stop_p] = _trace_stop_lowering_rule - +def _bitcast_convert_type_lowering_rule( + ctx: LoweringRuleContext, x, *, new_dtype): + (in_aval, ) = ctx.avals_in + (out_aval,) = ctx.avals_out + if in_aval.dtype.itemsize != new_dtype.itemsize: + raise NotImplementedError("Changing bitwidths not supported.") + return tpu.BitcastOp(aval_to_ir_type(out_aval), x).result +lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value: if isinstance(aval, pl_core.AbstractMemoryRef): @@ -1993,7 +2307,7 @@ def _device_id_to_logical( device_ids = tree_util.tree_leaves(device_id) mesh_strides = 
ctx.lowering_context.mesh_context.mesh_strides def _linearize_mesh_indices(*indices): - return sum([a * b for a, b in zip(indices, mesh_strides)]) + return sum(a * b for a, b in zip(indices, mesh_strides)) lower_ctx = LoweringRuleContext( lowering_context=ctx.lowering_context, avals_in=[jax_core.ShapedArray((), jnp.int32)] * len(device_ids), @@ -2006,18 +2320,34 @@ def _linearize_mesh_indices(*indices): return device_id raise NotImplementedError(f"Unsupported device id type: {device_id_type}") + +def _semaphore_read_lowering_rule( + ctx: LoweringRuleContext, + *args, + args_tree, +): + sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) + sem, indexers = tree_util.tree_unflatten(args_tree, args) + sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) + return tpu.SemaphoreReadOp(sem).result + + +lowering_rules[tpu_primitives.semaphore_read_p] = _semaphore_read_lowering_rule + def _semaphore_signal_lowering_rule( ctx: LoweringRuleContext, *args, args_tree, device_id_type: tpu_primitives.DeviceIdType, ): - sem_aval, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) - sem, indexers, value, device_id = tree_util.tree_unflatten(args_tree, args) - sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) + sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) + sem, indexers, value, device_id, core_index = tree_util.tree_unflatten(args_tree, args) + sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) - return tpu.SemaphoreSignalOp(sem, value, device_id=device_id).results + return tpu.SemaphoreSignalOp( + sem, value, device_id=device_id, core_id=core_index + ).results lowering_rules[tpu_primitives.semaphore_signal_p] = ( @@ -2027,7 +2357,7 @@ def _semaphore_signal_lowering_rule( def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) sem, indexers, value = tree_util.tree_unflatten(args_tree, args) - sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) + sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) return tpu.SemaphoreWaitOp(sem, value).results lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule @@ -2049,16 +2379,16 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, ) block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes) src_ref_block_shape, dst_ref_block_shape = block_shapes[0], block_shapes[2] - src_ref, _, _ = _index_ref( + src_ref, _ = _index_ref( src_ref, src_ref_aval, src_ref_block_shape, src_indexers ) if src_sem is not None: - src_sem, _, _ = _index_ref( + src_sem, _ = _index_ref( src_sem, src_sem_aval, src_sem_aval.shape, src_sem_indexers) - dst_ref, _, _ = _index_ref( + dst_ref, _ = _index_ref( dst_ref, dst_ref_aval, dst_ref_block_shape, dst_indexers ) - sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers) + sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) return tpu.EnqueueDMAOp(src_ref, dst_ref, sem, source_semaphore=src_sem, @@ -2073,10 +2403,10 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, sem_aval, _, ref_aval, _ = tree_util.tree_unflatten(tree, ctx.avals_in) block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes) ref_block_shape = block_shapes[2] - ref, _, _ = _index_ref( + ref, _ = _index_ref( ref, 
ref_aval, ref_block_shape, indexers ) - sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers) + sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers) return tpu.WaitDMAOp(sem, ref).results lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule @@ -2085,14 +2415,135 @@ def _device_id_lowering_rule(ctx: LoweringRuleContext): lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: str): - device_id = _make_index(tpu.DeviceIdOp().result) - l_to_m = ctx.lowering_context.mesh_context.logical_to_mesh + device_id = tpu.DeviceIdOp().result + mesh_shape = ctx.lowering_context.mesh_context.mesh_shape axis_names = ctx.lowering_context.mesh_context.axis_names - col = _make_index(axis_names.index(axis_name)) - return memref.LoadOp(l_to_m, [device_id, col]).result + axis_index = axis_names.index(axis_name) + axis_size = ir_constant(mesh_shape[axis_index]) + minor_divisor = ir_constant( + np.prod(mesh_shape[axis_index + 1 :], dtype=np.int32) + ) + return arith.remsi(arith.divsi(device_id, minor_divisor), axis_size) lowering_rules[lax.axis_index_p] = _axis_index_rule def _get_barrier_semaphore_rule(ctx: LoweringRuleContext): memref_type = aval_to_ir_type(ctx.avals_out[0]) return tpu.GetBarrierSemaphoreOp(memref_type).result lowering_rules[tpu_primitives.get_barrier_semaphore_p] = _get_barrier_semaphore_rule + + +def _delay_rule(ctx: LoweringRuleContext, nanos: int): + return tpu.DelayOp(nanos).results + + +lowering_rules[tpu_primitives.delay_p] = _delay_rule + + +def _debug_print_rule( + ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool +): + primitives.check_debug_print_format(fmt, *args) + if has_placeholders: + if not all( + isinstance(arg.type, ir.IntegerType) and arg.type.width == 32 + for arg in args + ): + raise TypeError( + "All arguments must be 32-bit integers when using" + " placeholders (`{...}`). If you need to print values of other types," + " remove placeholders from the format string." + ) + + # TPU expects $0, $1 etc as placeholders. + tpu_fmt = "".join( + f"{text}${idx}" + for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt)) + ) + else: + tpu_fmt = fmt + tpu.log(args, tpu_fmt, formatted=has_placeholders) + return () + + +lowering_rules[primitives.debug_print_p] = _debug_print_rule + + +def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds): + del ctx + # In the KeyScalarBundle case we unpack the bundle and set the seed with + # the list of scalars. + if len(seeds) == 1 and isinstance(seeds[0], KeyScalarBundle): + return tpu.PRNGSeed32Op(seeds[0].scalars).results + # For integer seeds, we can set the seed directly as PRNGSeed32Op natively + # takes in a list of integers as input. + all_integers = all(isinstance(seed.type, ir.IntegerType) for seed in seeds) + if not all_integers: + seed_types = [seed.type for seed in seeds] + raise ValueError(f"All seed data must be scalar integers. Got {seed_types}") + return tpu.PRNGSeed32Op(seeds).results +lowering_rules[tpu_primitives.prng_seed_p] = _prng_seed_lowering_rule + + +def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape): + if len(shape) <= 1: + # TODO(b/342054464): Support implicit dims for PRNGRandomBitsOp. 
+ raise NotImplementedError("random_bits only supports rank>=2 outputs.") + out_aval = ctx.avals_out[0] + out_type = aval_to_ir_type(out_aval) + return tpu.PRNGRandomBitsOp(out_type).result +lowering_rules[tpu_primitives.prng_random_bits_p] = _prng_random_bits_lowering_rule + + +def random_seed_lowering(ctx, seeds, *, impl): + seed_lowering = lower_fun( + impl.seed, multiple_results=False) + return seed_lowering(ctx, seeds) +lowering_rules[prng.random_seed_p] = random_seed_lowering + + +def random_bits_lowering(ctx, keys, *, bit_width, shape): + assert bit_width == 32, "Only 32-bit PRNG supported." + aval, = ctx.avals_in + impl = aval.dtype._impl + bits_lowering = lower_fun( + impl.random_bits, multiple_results=False) + return bits_lowering(ctx, keys, bit_width=bit_width, shape=shape) +lowering_rules[prng.random_bits_p] = random_bits_lowering + + +def random_fold_in_lowering(ctx, keys, msgs): + keys_aval, _ = ctx.avals_in + impl = keys_aval.dtype._impl + fold_in_lowering = lower_fun( + impl.fold_in, multiple_results=False) + return fold_in_lowering(ctx, keys, msgs) +lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering + + +def random_unwrap_lowering(ctx, key): + del ctx, key + raise NotImplementedError("key_data not implemented.") +lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering + + +def random_wrap_lowering(ctx, key_data, *, impl): + del ctx, impl + if isinstance(key_data.type, ir.VectorType): + # If the key data lives in vregs, need to unpack it to sregs. + key_data_list = [] + key_data_shape = key_data.type.shape + if len(key_data_shape) != 2: + raise NotImplementedError("Seed key_data must be 2D.") + if tuple(key_data_shape) != (1, 1): + raise NotImplementedError( + "Seed key_data of shape != (1, 1) not supported. " + f"Got: {key_data_shape}") + for i in range(key_data_shape[1]): + key_data_list.append(vector.ExtractOp(key_data, [], [0, i])) + return KeyScalarBundle(scalars=key_data_list) + if isinstance(key_data, KeyScalarBundle): + return key_data + else: + raise NotImplementedError(f"key_data wrap {type(key_data)}") + +lowering_rules[prng.random_wrap_p] = random_wrap_lowering diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index aeb2a733e817..8f307e560bf0 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -17,24 +17,25 @@ from __future__ import annotations from typing import Any +import warnings import jax from jax import core as jax_core -from jax.experimental import mosaic -from jax.experimental.mosaic.dialects import tpu +from jax._src import core as jax_src_core from jax._src import sharding_impls from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.pallas import core from jax._src.pallas.mosaic import lowering from jax._src.pallas.pallas_call import pallas_call_p +from jax.experimental import mosaic +from jax.experimental.mosaic.dialects import tpu def pallas_call_tpu_lowering_rule( ctx: mlir.LoweringRuleContext, *in_nodes, jaxpr: jax_core.Jaxpr, name: str, - which_linear: tuple[bool, ...], grid_mapping: core.GridMapping, input_output_aliases: tuple[tuple[int, int], ...], in_shapes: tuple[jax.ShapeDtypeStruct, ...], @@ -47,18 +48,23 @@ def pallas_call_tpu_lowering_rule( return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes, in_shapes=in_shapes, - which_linear=which_linear, interpret=interpret, debug=debug, 
input_output_aliases=input_output_aliases, grid_mapping=grid_mapping, compiler_params=compiler_params) if debug: print(jaxpr) - if 'mosaic_params' in compiler_params: - assert 'mosaic' not in compiler_params - mosaic_params = compiler_params['mosaic_params'] - elif 'mosaic' in compiler_params: - mosaic_params = compiler_params['mosaic'] + if "mosaic_params" in compiler_params: + # TODO(slebedev): Remove this branch after July 12th 2024. + warnings.warn( + "Passing Mosaic parameters via compiler_params=dict(mosaic_params=...)" + " is deprecated. Use compiler_params=dict(mosaic=...) instead.", + DeprecationWarning, + ) + assert "mosaic" not in compiler_params + mosaic_params = compiler_params["mosaic_params"] + elif "mosaic" in compiler_params: + mosaic_params = compiler_params["mosaic"] else: mosaic_params = {} mesh = None @@ -70,39 +76,42 @@ def pallas_call_tpu_lowering_rule( mlir_ctx.append_dialect_registry(mlir.upstream_dialects) mlir_ctx.load_all_available_dialects() tpu.register_dialect(mlir_ctx) - if mosaic_params is None: - mosaic_params = {} dimension_semantics = mosaic_params.get("dimension_semantics", None) - kernel_regeneration_metadata = mosaic_params.get( - "kernel_regeneration_metadata" - ) mosaic_module, extra_args = lowering.lower_jaxpr_to_module( mlir_ctx, grid_mapping, in_shapes, out_shapes, jaxpr, dimension_semantics=dimension_semantics, mesh=mesh) if debug: print(mosaic_module) - if extra_args and input_output_aliases: - raise NotImplementedError( - "Cannot use both input_output_aliases and extra_args." - ) + num_extra_args = len(extra_args) num_dyn_bounds = grid_mapping.num_dynamic_grid_bounds input_output_aliases = tuple( - (a[0] + num_dyn_bounds, a[1]) for a in input_output_aliases + (a[0] + num_dyn_bounds + num_extra_args, a[1]) + for a in input_output_aliases ) out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes] + + # Replace in_avals to physical avals. + # This step is required for mapping logical types to physical types. + # (e.g. PRNG key -> uint32[2]) + physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in] + ctx = ctx.replace(avals_in=physical_avals) + def _lower_fun(*args): # Dynamic grid bounds have to go at the front. dynamic_grid_args, args = args[:num_dyn_bounds], args[num_dyn_bounds:], return mosaic.as_tpu_kernel( mosaic_module, out_avals, - backend=ctx.module_context.backend, + backend="tpu", kernel_name=name, - kernel_regeneration_metadata=kernel_regeneration_metadata, - cost_estimate=mosaic_params.get("cost_estimate", None), - vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes", None), - flags=mosaic_params.get("flags", None), + cost_estimate=mosaic_params.get("cost_estimate"), + vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"), + flags=mosaic_params.get("flags"), + allow_input_fusion=mosaic_params.get("allow_input_fusion"), input_output_aliases=input_output_aliases, + internal_scratch_in_bytes=mosaic_params.get( + "internal_scratch_in_bytes" + ), )( *dynamic_grid_args, *extra_args, @@ -111,5 +120,3 @@ def _lower_fun(*args): ) return mlir.lower_fun(_lower_fun, multiple_results=True)( ctx, *in_nodes) -mlir.register_lowering(pallas_call_p, pallas_call_tpu_lowering_rule, - platform="tpu") diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 8c3d0de88739..0d778a60c711 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -11,1261 +11,1055 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. + """Module for emitting custom TPU pipelines within a Pallas call.""" +from __future__ import annotations +from collections.abc import Sequence import dataclasses +import enum import functools -import math -from typing import Any, Callable, Generic, NamedTuple, Optional, Protocol, Sequence, TypeVar, Union, cast +import itertools +import operator +from typing import Union, Any import jax from jax import lax from jax import tree_util -from jax._src.api_util import flatten_axes -from jax._src.pallas import core -from jax._src.pallas import primitives +from jax._src import util as jax_util +from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import primitives as tpu_primitives -from jax._src.state import indexing -from jax._src.util import split_list +from jax.experimental import pallas as pl import jax.numpy as jnp +import numpy as np + SMEM = tpu_core.TPUMemorySpace.SMEM VMEM = tpu_core.TPUMemorySpace.VMEM DMA = tpu_core.SemaphoreType.DMA REF = tpu_core.MemoryRef - -partial = functools.partial - -T = TypeVar("T") - - -@tree_util.register_pytree_node_class -@dataclasses.dataclass(frozen=True) -class PipelineArg(Generic[T]): - """Wrapper for pipeline arguments that exist for inputs, outputs, and accums.""" - input: T - out: T - in_out: T - - @property - def input_and_in_out(self) -> T: - return cast(Any, self.input) + cast(Any, self.in_out) - - def tree_flatten(self): - return ((self.input, self.out, self.in_out), None) - - @classmethod - def tree_unflatten(cls, _, children): - return cls(*children) - - -class PipelineBuffer(NamedTuple): - """Current and next buffer indices for an input/output/accum ref.""" - current: Union[REF, jax.Array] - next: Union[REF, jax.Array] - -# TODO(enriqueps): Add SMEM support. -class PipelineAllocation(NamedTuple): - """Allocated VMEM ref and semaphore for an input/output/accum ref.""" - vmem_ref: REF - semaphore: tpu_core.SemaphoreType - -# PyTree versions of the various arguments. -PipelineBlockSpecs = Union[Sequence[core.BlockSpec], Any] -PipelineRefs = Union[Sequence[REF], Any] -PipelineBuffers = Union[Sequence[PipelineBuffer], Any] -PipelineAllocations = Union[Sequence[PipelineAllocation], Any] +SemaphoreType = tpu_core.SemaphoreType +ArrayRef = Union[REF, jax.Array] GridIndices = tuple[jax.Array, ...] CondVal = Union[jax.Array, bool] +PipelineBlockSpecs = Union[Sequence[pallas_core.BlockSpec], Any] +PipelineRefs = Union[Sequence[REF], Any] -def _broadcast_pytree_to(name: str, from_pytree: Any, to_pytree: Any) -> Any: - """Broadcasts a prefix-pytree of to_pytree, to the shape of to_pytree. - - Useful for supporting passing in prefixes of things as arguments, like in - jax.vmap. - - Args: - name: Name for error messages. - from_pytree: Prefix tree. - to_pytree: Target pytree. - - Returns: - Broadcasted pytree. 
- """ - to_treedef = tree_util.tree_structure(to_pytree) - return tree_util.tree_unflatten( - to_treedef, flatten_axes(name, to_treedef, from_pytree) - ) - - -def _tree_map_with_kwargs(f, *args, **kwargs): - """jax.tree_util.tree_map that supports kwargs.""" - kwargs_keys = kwargs.keys() - kwargs_values = kwargs.values() - return tree_util.tree_map( - lambda arg0, partial_f, *args: partial_f(arg0, *args), - args[0], - tree_util.tree_map( - lambda _, *tree_mapped_kwargs_values: partial( - f, **dict(zip(kwargs_keys, tree_mapped_kwargs_values)) - ), - args[0], - *kwargs_values, - is_leaf=lambda x: x is None, - ), - *args[1:], - ) - - -def _get_next_indices(grid: core.StaticGrid, indices: GridIndices) -> GridIndices: - """Takes a grid and current indices and returns the next indices. - - grid: (3, 4, 5) - indices: [1, 0, 4] - returns: [1, 1, 0] - - Args: - grid: Grid spec. - indices: Current indices. - - Returns: - Incremented indices. - """ - next_indices = [] - carry = True - for dim_size, index in reversed(list(zip(grid, indices))): - i = jnp.where(carry, index + 1, index) - carry = dim_size == i - next_indices.append(jnp.where(carry, 0, i)) - return tuple(reversed(next_indices)) - - -def _replace_nones_in_block_spec(block_spec: core.BlockSpec) -> core.BlockSpec: - """Replaces Nones in a block spec shape with 1s.""" - block_shape = cast(tuple[int, ...], block_spec.block_shape) - block_shape = tuple([1 if dim is None else dim for dim in block_shape]) - return dataclasses.replace(block_spec, block_shape=block_shape) - - -def _run_block_spec( - block_spec: core.BlockSpec, indices: GridIndices -) -> tuple[Union[slice, indexing.Slice], ...]: - """Runs a block spec for the given indices and returns the slices. - - Args: - block_spec: Block spec to run. - indices: Grid indices to run on. - - Returns: - Array slices for the block spec. - """ - index_map = block_spec.index_map - if index_map is None: - raise ValueError("Block spec index_map is None.") - block_indices = index_map(*indices) - return tuple( - indexing.ds( - primitives.multiple_of(index * block_size, block_size), block_size - ) - for index, block_size in zip( - block_indices, cast(Any, block_spec.block_shape) - ) +# TODO(sharadmv): make this a parameter and make it queryable from the Device. +_TILING = (8, 128) + +def _broadcast_pytree_to(from_pytree, to_pytree): + """Broadcast a prefix pytree to a given full tree.""" + proxy = object() + treedef = tree_util.tree_structure(to_pytree) + broadcast_leaves = [] + def add_leaves(i, x): + broadcast_leaves.extend( + [i] * tree_util.tree_structure(x).num_leaves) + try: + tree_util.tree_map(add_leaves, from_pytree, to_pytree, + is_leaf=lambda x: x is None) + except ValueError: + raise ValueError(f"Cannot broadcast tree {from_pytree} " + f"to full tree structure {treedef}.") from None + broadcast_leaves = [None if a is proxy else a for a in broadcast_leaves] + assert len(broadcast_leaves) == treedef.num_leaves + return tree_util.tree_unflatten(treedef, broadcast_leaves) + + +def _get_tpu_generation() -> int: + kind = jax.devices()[0].device_kind + if kind.endswith(' lite'): + kind = kind[:-len(' lite')] + assert kind[:5] == "TPU v", kind + return int(kind[5]) + +def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]: + # For a n-dimensional shape, returns (8, 128) for the last 2 dimensions + # and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and + # (2, 3, 128, 128) -> (1, 1, 8, 128). 
+ if len(shape) < 2: + raise ValueError(f"Shape must have at least 2 dimensions: {shape=}") + leading_dims, final_dims = shape[:-2], shape[-2:] + # We want to find the minimum power of 2 that fits the second-minor dimension + # of shape, with maximum value 8. + second_minor, _ = final_dims + packing = 4 // dtype.itemsize + max_tiling = _TILING[0] + second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing + while second_minor_tiling < min(second_minor, max_tiling): + second_minor_tiling *= 2 + return (*(1,) * len(leading_dims), second_minor_tiling, _TILING[1]) + + +def _mod(a, n): + """"Calculates a mod n for positive and negative a with |a| <= n.""" + return lax.rem(a + n, n) + + +def _round_up_to_nearest_multiple(s: int, multiple: int) -> int: + if s % multiple == 0: + return s + # Subtract off the remainder, then add multiple + return s - s % multiple + multiple + + +def _make_ds( + idx: jax.Array | int, size: jax.Array | int +) -> pl.Slice: + """Make a DMA slice with mosaic size hints.""" + out = pl.ds(idx * size, size) + assert isinstance(out, pl.Slice) + return out + + +def _make_block_slice( + block_index: jax.Array, block_size: int, size: int, tiling: int +) -> pl.Slice | slice: + # Computes a slice given a block index and block size. In the default case, + # we return slice(block_index * block_size, (block_index + 1) * block_size). + # However, if the total size of the ref does not divide block size and we are + # selecting the last block, we need to pick the lowest tiling size multiple + # that contains the block. + if size % block_size == 0: + return _make_ds(block_index, block_size) + if block_size % tiling != 0: + raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") + num_blocks = pl.cdiv(size, block_size) + is_last = block_index == num_blocks - 1 + rounded_size = jnp.where( + is_last, + _round_up_to_nearest_multiple(size % block_size, tiling), + block_size, ) + rounded_size = pl.multiple_of(rounded_size, tiling) + return pl.ds(block_index * block_size, rounded_size) -def _dma_slice_not_equal( - dma_slice_a: tuple[Union[slice, indexing.Slice], ...], - dma_slice_b: tuple[Union[slice, indexing.Slice], ...], -) -> jax.Array: - """Returns True if the two slices are not equal.""" - dma_slice_not_equal = cast(jax.Array, False) - for a, b in zip(dma_slice_a, dma_slice_b): - dma_slice_not_equal = jnp.logical_or( - dma_slice_not_equal, a.start != b.start - ) - return dma_slice_not_equal - - -def _block_copy( - block_spec: core.BlockSpec, - ref: REF, - allocation: Optional[PipelineAllocation], - buffers: PipelineBuffer, - accum_allocation: Optional[PipelineAllocation] = None, - accum_buffers: Optional[PipelineBuffer] = None, - *, - indices: tuple[ - GridIndices, - GridIndices, - GridIndices, - ], - is_input: bool, - is_wait: bool, - force_copy: Optional[Union[jax.Array, bool]] = None, - force_skip: Optional[Union[jax.Array, bool]] = None, - use_accum: Optional[Union[jax.Array, bool]] = None, - accum_if_skipping: Optional[Union[jax.Array, bool]] = None, - zero_accum_if_skipping: Optional[Union[jax.Array, bool]] = None, -): - """General purpose input/output block copys. - - Basic flow: - - - Wait on input copy if previous block spec was different. - - Start input copy if block spec is changing and it's not the last step. - - Wait on output copy if previous block spec was different. - - Start output copy if block spec is changing or is last step. 
+def _tuples_differ(xs, ys): + """Dynamic index-tuple comparison calculation.""" + differences = jax.tree.map(lambda x, y: x != y, xs, ys) + return functools.reduce(lambda x, y: x | y, differences, False) - The step constraints are enforced with force_copy and caller conds. - Args: - block_spec: Block spec. - ref: HBM ref. - allocation: VMEM ref and semaphore. If this is None it means the source refs - are already in VMEM and we can avoid copy operations. - buffers: Current and next buffer indices. - accum_allocation: Accumulator VMEM ref and semaphore. - accum_buffers: Accumulator current and next buffer indices. - indices: Current grid indices. - is_input: True if is input copy. - is_wait: True if we want to wait on instead of start a copy. - force_copy: Force copy if this condition is True. force_skip overrides this. - force_skip: Force skipping the operation if this condition is True. - use_accum: Whether to add the accum to the VMEM ref before copying. - accum_if_skipping: Whether to accumulate into the current accum buffer if - skipping copy. - zero_accum_if_skipping: Whether to zero out the existing accum before - accumulating into it if skipping copy. +def _grid_size(grid): + """Dynamic grid size calculation.""" + size = jnp.array(1, jnp.int32) + for dim in grid: + size *= dim + return size - Returns: - Current and next buffer indices, swapped if a copy was started. - """ - if allocation is None: - # Has existing allocation. - return buffers - (vmem_ref, sem) = allocation.vmem_ref, allocation.semaphore - (prev_indices, curr_indices, next_indices) = indices - - prev_dma_slice = _run_block_spec(block_spec, prev_indices) - dma_slice = _run_block_spec(block_spec, curr_indices) - next_dma_slice = _run_block_spec(block_spec, next_indices) - - prev_dma_slice_changed = _dma_slice_not_equal(prev_dma_slice, dma_slice) - dma_slice_is_changing = _dma_slice_not_equal(dma_slice, next_dma_slice) - - buffer, next_buffer = buffers.current, buffers.next - if is_input: - if is_wait: - # We wait for inputs of the current body iteration. - used_dma_slice = dma_slice - used_buffer = buffer - else: - # We send to the next ones. - used_dma_slice = next_dma_slice - used_buffer = next_buffer - else: - if is_wait: - # We wait for the outputs of the previous body iteration. - used_dma_slice = prev_dma_slice - used_buffer = next_buffer - else: - # We send the current ones. 
- used_dma_slice = dma_slice - used_buffer = buffer - if is_input: - from_ref = ref.at[used_dma_slice] - to_ref = vmem_ref.at[used_buffer] - else: - from_ref = vmem_ref.at[used_buffer] - to_ref = ref.at[used_dma_slice] - - async_copy = tpu_primitives.make_async_copy( - from_ref, - to_ref, - sem, +def _get_indices(step, grid, offsets): + """Get indices for a given step and grid.""" + extended_grid = grid + (1,) + strides = tuple( + itertools.accumulate(extended_grid[::-1], func=operator.mul))[::-1] + indices = tuple( + lax.div(lax.rem(step, a), b) + for a, b in zip(strides[:-1], strides[1:]) ) - - if is_wait: - cond = prev_dma_slice_changed - do_fn = async_copy.wait - advance_buffers = False - else: - cond = dma_slice_is_changing - do_fn = async_copy.start - advance_buffers = True - - if force_copy is not None: - cond = jnp.logical_or(cond, force_copy) - if force_skip is not None: - cond = jnp.logical_and(cond, jnp.logical_not(force_skip)) - - def do_and_advance_buffers(): - if accum_allocation is not None: - - def accum(): - with tpu_primitives.trace("ep_accum_copy"): - accum_dtype = jnp.float32 - if vmem_ref.dtype == jnp.int32: - accum_dtype = jnp.int32 - accum_vmem_ref = accum_allocation.vmem_ref - vmem_ref[used_buffer] = ( - vmem_ref[used_buffer].astype(accum_dtype) - + accum_vmem_ref[accum_buffers.current].astype(accum_dtype) - ).astype(vmem_ref.dtype) - - lax.cond(use_accum, accum, lambda: None) - - do_fn() - if advance_buffers: - return PipelineBuffer(next_buffer, buffer) - return buffers - - def dont_advance_buffers(): - if accum_allocation is not None: - - def accum(): - with tpu_primitives.trace("ep_accum_store"): - - def zero_accum(): - accum_vmem_ref = accum_allocation.vmem_ref - accum_vmem_ref[...] = jnp.zeros_like(accum_vmem_ref[...]) - - lax.cond(zero_accum_if_skipping, zero_accum, lambda: None) - - accum_dtype = jnp.float32 - if vmem_ref.dtype == jnp.int32: - accum_dtype = jnp.int32 - accum_vmem_ref = accum_allocation.vmem_ref - accum_vmem_ref[accum_buffers.current] = ( - accum_vmem_ref[accum_buffers.current].astype(accum_dtype) - + vmem_ref[used_buffer].astype(accum_dtype) - ).astype(accum_vmem_ref.dtype) - - lax.cond(accum_if_skipping, accum, lambda: None) - - return buffers - - return lax.cond(cond, do_and_advance_buffers, dont_advance_buffers) - - -# Start copying an input's next block to its next buffer. -_start_block_copy_in = partial(_block_copy, is_input=True, is_wait=False) -# Wait for the copy of an input's current block to its current buffer. -_wait_block_copy_in = partial(_block_copy, is_input=True, is_wait=True) -# Start copying an output's current block from its current buffer. -_start_block_copy_out = partial(_block_copy, is_input=False, is_wait=False) -# Wait for the copy of an output's previous block from its previous buffer. -_wait_block_copy_out = partial(_block_copy, is_input=False, is_wait=True) - - -class PipelineBody(Protocol): - """Body of a pipeline.""" - - def __call__(self, *ref_args: PipelineRefs) -> None: - ... - - -class MakePipelineRefs(Protocol): - """Makes pipeline refs from flat user friendly function args.""" - - def __call__(self, *ref_args: PipelineRefs) -> PipelineArg[PipelineRefs]: - ... + return tuple(a + b for a, b in zip(indices, offsets, strict=True)) -class MakePipelineAllocations(Protocol): - """Makes pipeline allocations from flat user friendly function args.""" - - def __call__( - self, *ref_args: PipelineRefs, return_treedef: bool = False - ) -> Any: - ... 
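The grid helpers above treat the pipeline grid as a flat step counter: `_grid_size` multiplies the dimensions together and `_get_indices` maps a step back to per-dimension indices (last dimension fastest), shifted by `offsets` when the grid is partitioned across cores. A rough pure-Python equivalent, shown only as an illustration (the `_reference_get_indices` name is made up for this sketch and is not part of the patch):

def _reference_get_indices(step, grid, offsets):
  # Peel off the fastest-varying (last) dimension first, as _get_indices does.
  indices = []
  for dim in reversed(grid):
    step, idx = divmod(step, dim)
    indices.append(idx)
  return tuple(i + o for i, o in zip(reversed(indices), offsets))

# With grid=(3, 4) and offsets=(0, 0): step 0 -> (0, 0), step 5 -> (1, 1),
# step 11 -> (2, 3).
assert _reference_get_indices(11, (3, 4), (0, 0)) == (2, 3)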
+class BufferType(enum.Enum): + """Buffer type for the arguments to an emitted pipeline.""" + INPUT = 1 + OUTPUT = 2 + ACCUMULATOR = 3 +@tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) -class PipelinePrefetchArgs: - """Args for pipeline prefetch.""" - pipeline_refs: PipelineArg[PipelineRefs] - pipeline_allocations: PipelineArg[PipelineAllocations] - pipeline_buffers: PipelineArg[PipelineBuffers] - - -class StartPipelinePrefetch(Protocol): - """Starts pipeline prefetch. - - Use force_copy if a spec's indices don't change from last to first grid - indices and you still want to force a copy. This must be used in conjunction - with the prologue's return value to force a wait. +class BufferedRef: + """A helper class to automate VMEM double buffering in pallas pipelines. + + Attributes: + spec: pallas blockspec. + dtype: dtype for buffers. + buffer_type: enum indicating whether this is an input, output, or in/out + accumulator buffered reference. + vmem_ref: a double-buffer to hold a working buffer and a dirty buffer used + to copy into and out of. In the case of a BufferedRef targeting a VMEM + reference, this simply points to the existing ref. + accum_ref: accumulating buffer used by accumulator BufferedRefs. + current_slot: current slot index to the working buffer. + next_slot: slot that will point to the working buffer in the next iteration. + sem_recv: semaphore for input DMAs. + sem_send: semaphore for output DMAs. + + block_shape: passthrough property for the BlockSpec's block_shape. + compute_index: passthrough property for the BlockSpec's compute_index. + memory_space: passthrough property for the BlockSpec's memory_space. + current_ref: points to the current working slice of the double-buffer. + is_input: whether this BufferedRef acts as a pipeline input. + is_output: whether this BufferedRef acts as a pipeline output. + is_accumulator: whether this BufferedRef is an accumulator. """ + spec: pl.BlockSpec # static metadata + dtype: Any # static metadata + buffer_type: BufferType # static metadata + vmem_ref: REF | None + accum_ref: REF | None + current_slot: ArrayRef | None + next_slot: ArrayRef | None + sem_recv: SemaphoreType | None + sem_send: SemaphoreType | None - def __call__( - self, - prefetch_args: PipelinePrefetchArgs, - *, - force_copy: Union[ - bool, tuple[Union[CondVal, Any], Union[CondVal, Any]] - ] = False, - force_skip: Union[ - bool, tuple[Union[CondVal, Any], Union[CondVal, Any]] - ] = False, - ) -> tuple[PipelineBuffers, PipelineBuffers]: - ... + def tree_flatten(self): + return ((self.vmem_ref, self.accum_ref, self.current_slot, + self.next_slot, self.sem_recv, self.sem_send), + (self.spec, self.dtype, self.buffer_type)) + @classmethod + def tree_unflatten(cls, meta, data): + return cls(*meta, *data) -@dataclasses.dataclass(frozen=True) -class ManualPrefetchArgs: - """Args for pipeline prefetch.""" + @classmethod + def create(cls, spec, dtype, buffer_type) -> BufferedRef: + """Create a BufferedRef. + + Args: + spec: pallas blockspec. + dtype: dtype for buffers. + buffer_type: enum indicating whether this is an input, output, or in/out + accumulator buffered reference. + + Returns: + Initialized BufferedRef + """ + block_shape = tuple([1 if x is None else x for x in spec.block_shape]) + if spec.memory_space == VMEM: + # We don't need to do any double-buffering in the case that our pipeline + # reference is already in VMEM, we just need allocate the accumulation + # buffer and we will refer to the original reference slices directly. 
+ return cls( + spec=spec, dtype=dtype, + buffer_type=buffer_type, + vmem_ref=None, # to be bound to existing ref by the pipeline routine + accum_ref=(VMEM(block_shape, dtype) + if buffer_type is BufferType.ACCUMULATOR else None), + current_slot=None, next_slot=None, sem_recv=None, sem_send=None) + else: + return cls( + spec=spec, dtype=dtype, + buffer_type=buffer_type, + vmem_ref=VMEM((2,) + block_shape, dtype), + accum_ref=(VMEM(block_shape, dtype) + if buffer_type is BufferType.ACCUMULATOR else None), + current_slot=SMEM((1,), jnp.int32), + next_slot=SMEM((1,), jnp.int32), + sem_recv=(None if buffer_type is BufferType.OUTPUT + else SemaphoreType.DMA), + sem_send=(None if buffer_type is BufferType.INPUT + else SemaphoreType.DMA),) - pipeline_specs: PipelineBlockSpecs - pipeline_refs: PipelineRefs - pipeline_allocations: PipelineAllocations - pipeline_buffers: PipelineBuffers + @classmethod + def input(cls, spec, dtype): + return cls.create(spec, dtype, BufferType.INPUT) + @classmethod + def output(cls, spec, dtype): + return cls.create(spec, dtype, BufferType.OUTPUT) -class StartManualPrefetch(Protocol): - """Starts manual prefetch. + @classmethod + def accumulator(cls, spec, dtype): + return cls.create(spec, dtype, BufferType.ACCUMULATOR) - Use force_copy if a spec's indices don't change from last to first grid - indices and you still want to force a copy. This must be used in conjunction - with the prologue's return value to force a wait. - """ + @property + def block_shape(self): + return self.spec.block_shape - def __call__( - self, - prefetch_args: ManualPrefetchArgs, - *, - indices: GridIndices, - force_copy: Union[bool, Union[CondVal, Any]] = False, - force_skip: Union[bool, Union[CondVal, Any]] = False, - ) -> PipelineBuffers: - ... + @property + def compute_index(self): + return self.spec.compute_index + @property + def memory_space(self): + return self.spec.memory_space -@dataclasses.dataclass(frozen=True) -class PipelineCallbackArgs: - """Args for pipeline prologue and epilogue.""" - pipeline_specs: PipelineArg[PipelineBlockSpecs] - pipeline_refs: PipelineArg[PipelineRefs] - pipeline_buffer_refs: PipelineArg[PipelineBuffers] - pipeline_allocations: PipelineArg[PipelineAllocations] - pipeline_buffers: PipelineArg[PipelineBuffers] - make_pipeline_refs: MakePipelineRefs - start_pipeline_prefetch: StartPipelinePrefetch - start_manual_prefetch: StartManualPrefetch - run_manual_compute: Callable[[Callable[[], None]], None] - - -PipelinePrologue = Callable[ - [PipelineCallbackArgs], - # Returns a tuple of tuples of prefix-pytrees for inputs and accums. The - # first specifies which ones to skip the prologue input copy for and the - # second specifies which ones to force the prologue input wait on. - tuple[ - tuple[Union[CondVal, Any], Union[CondVal, Any]], - tuple[Union[CondVal, Any], Union[CondVal, Any]], - ], -] -PipelineEpilogue = Callable[ - [PipelineCallbackArgs], tuple[PipelineBuffers, PipelineBuffers] -] -PipelineOutPrologue = Callable[[PipelineCallbackArgs], Union[CondVal, Any]] -PipelineOutEpilogue = Callable[[PipelineCallbackArgs], Union[CondVal, Any]] - - -class Pipeline(Protocol): - - def __call__( - self, - *ref_args: PipelineRefs, - scratchs: PipelineRefs = None, - allocations: Union[None, Any] = None, - init_allocations: CondVal = False, - prologue: Union[PipelinePrologue, None] = None, - epilogue: Union[PipelineEpilogue, None] = None, - out_prologue: Union[PipelineOutPrologue, None] = None, - out_epilogue: Union[PipelineOutEpilogue, None] = None, - ) -> None: - ... 
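  # For reference: for an input whose BlockSpec has block_shape (512, 512),
  # float32 dtype, and a source ref that is not already in VMEM, create() above
  # allocates a (2, 512, 512) float32 VMEM double buffer (two 1 MiB slots), two
  # scalar SMEM slot counters, and a DMA receive semaphore (sem_send stays None
  # for pure inputs). A VMEM-resident source allocates none of this and is
  # aliased directly via bind_existing_ref below.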
+ @property + def current_ref(self): + buffer_slice = tuple( + [0 if x is None else slice(None) for x in self.block_shape]) + if self.memory_space == VMEM: + return self.vmem_ref.at[buffer_slice] + else: + return self.vmem_ref.at[(self.current_slot[0], *buffer_slice)] + @property + def is_input(self): + return self.buffer_type in [BufferType.INPUT, BufferType.ACCUMULATOR] -def emit_pipeline_with_allocations( - body: PipelineBody, - *, - grid: core.StaticGrid, - in_specs: PipelineBlockSpecs, - out_specs: PipelineBlockSpecs, - should_accumulate_out: Union[Sequence[bool], Any] = False, -) -> tuple[Pipeline, MakePipelineAllocations]: - """Wraps body function in a custom pipeline defined by grid and specs. + @property + def is_output(self): + return self.buffer_type in [BufferType.OUTPUT, BufferType.ACCUMULATOR] - This has the same semantics as pallas_call but is meant to be called inside - pallas_call for nesting grids. This is useful when you need to have separate - windowing strategies for example for communication vs. computation. + @property + def is_accumulator(self): + return self.buffer_type == BufferType.ACCUMULATOR + + def bind_existing_ref(self, vmem_ref, indices): + """For handling VMEM references, the pipeline aliases the existing ref.""" + if self.memory_space == VMEM: + return dataclasses.replace( + self, vmem_ref=vmem_ref.at[self.compute_slice(indices)]) + return self + + def compute_slice(self, grid_indices): + """Compute DMA slice from grid indices.""" + block_shape = tuple([1 if x is None else x for x in self.block_shape]) + indices = self.compute_index(*grid_indices) + return jax.tree.map(_make_ds, indices, block_shape) + + def init_slots(self): + """Initialize slot indices.""" + if self.memory_space == VMEM: return + self.current_slot[0] = 0 + self.next_slot[0] = 0 + + def swap_slots(self): + """Switch to the next slot.""" + if self.memory_space == VMEM: return + self.current_slot[0] = self.next_slot[0] + + def get_dma_slice(self, src_shape, src_dtype, grid_indices): + # We need to handle blocks that might go OOB in the src array. An in bounds + # block looks like this (for array shape (600, 600) and block shape + # (256, 256)): + # + # +--------------+------------------| + # | Block (0,0) | | + # | (256, 256) | | + # +--------------+ | + # | A (600, 600) | + # | | + # +---------------------------------+ + # + # For in-bounds blocks, we don't need to do anything special. + # An out-of-bounds block looks like this: + # + # +--------------+------------------| + # | | + # | | + # + | + # | A (600, 600) | + # +--------------+ | + # | Block (2,0) | | + # + --------------------------------| + # | XXXXXXXXXX | + # +--------------+ + # where the X's indicate where the block is out of bounds. + # + # When we have an out of bounds block like this, we need to truncate it to + # a tile boundary (tiles are (8, 128) along the two minormost dimensions). + # In this case, we'll have a block that is indexing the + # 512:768 elements of A along the first dimension. We need to convert 768 + # into 600 (600 % 8 == 0), so our indexing will look like this: + + # +--------------+------------------| + # | | + # | | + # + | + # | A (600, 600) | + # +--------------+ | + # | Block (2,0) | | + # + --------------------------------| + # where it is now a (88, 256) sized block. + # + # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block + # for the last iteration on that dimension, we will pick the next highest + # tile multiple, i.e. (96, 256). 
+ if len(src_shape) < 2: + raise NotImplementedError("Must use >1D values.") + + tiling = _make_tiling(src_shape, src_dtype) + block_shape = tuple(1 if b is None else b for b in self.block_shape) + block_indices = self.compute_index(*grid_indices) + return jax.tree.map( + _make_block_slice, block_indices, block_shape, src_shape, tiling + ) - By default outputs are written to but `should_accumulate_out` can be used to - specify which outputs we should add to instead. This is so we can reduce - across pipeline calls within and across parent grid iterations. + def copy_in(self, src_ref, grid_indices): + """Starts copy of HBM dma slice into the current slot.""" + assert self.is_input + if self.memory_space == VMEM: return + next_slot = lax.rem(self.current_slot[0] + 1, 2) + self.next_slot[0] = next_slot + src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) + dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) + tpu_primitives.make_async_copy( + src_ref.at[src_slice], + self.vmem_ref.at[next_slot].at[dst_slice], + self.sem_recv).start() + + def copy_out(self, dst_ref, grid_indices): + """Starts copy of HBM dma slice from the current slot.""" + assert self.is_output + if self.memory_space == VMEM: return + slot = self.current_slot[0] + self.next_slot[0] = lax.rem(slot + 1, 2) + dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) + src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) + tpu_primitives.make_async_copy( + self.vmem_ref.at[slot].at[src_slice], + dst_ref.at[dst_slice], + self.sem_send).start() + + def wait_in(self, src_ref, grid_indices): + """Waits for input copy to finish.""" + assert self.is_input + if self.memory_space == VMEM: return + src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) + dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) + tpu_primitives.make_async_copy( + src_ref.at[src_slice], # nb: doesn't matter + self.vmem_ref.at[self.current_slot[0]].at[dst_slice], # only dst shape is important + self.sem_recv).wait() + + def wait_out(self, dst_ref, grid_indices): + """Waits for output copy to finish.""" + assert self.is_output + if self.memory_space == VMEM: return + prev_slot = lax.rem(self.current_slot[0] + 1, 2) + dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) + src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) + tpu_primitives.make_async_copy( + self.vmem_ref.at[prev_slot].at[src_slice], # nb: doesn't matter + dst_ref.at[dst_slice], # only dst shape is important + self.sem_send).wait() + + # Accumulator methods + # + # Accumulating inline in VMEM saves half the HBM<->VMEM bandwidth cost of + # doing another full loop around HBM to do a reduction, at the current cost + # of allocating another VMEM buffer. + # + # NB: there's no actual need to have an additional accumulation buffer, if + # we just rewrote inner kernels to handle the initial-zero-init and output + # reduction, we don't need to waste VMEM. Consider removing this magic + # init and reduce support. + + def set_accumulator(self, init=False): + """Set accumulator or zero it out to initialize.""" + assert self.is_accumulator + if self.accum_ref is not None: + def _init(): + self.accum_ref[...] = jnp.zeros_like(self.accum_ref[...]) + def _set(): + self.accum_ref[...] 
= self.current_ref[...].astype(self.accum_ref) + lax.cond(init, _init, _set) + + def accumulate(self): + """Add into the current slot.""" + assert self.is_accumulator + if self.accum_ref is not None: + accum_dtype = jnp.float32 + if self.vmem_ref.dtype == jnp.int32: + accum_dtype = jnp.int32 + # TODO(levskaya): we could generalize init and reduction functions, + # could it ever be useful to support more generic monoids? + self.current_ref[...] = ( + self.current_ref[...].astype(accum_dtype) + + self.accum_ref[...].astype(accum_dtype) + ).astype(self.vmem_ref.dtype) + + +# Helper to tree map over BufferedRefs as leaves. +map_brefs = functools.partial( + jax.tree.map, + is_leaf=lambda x: isinstance(x, BufferedRef)) + + +class Scheduler: + """Sequences input and output copies and waits for a pipeline.""" + + def __init__(self, + step: jax.Array, + grid: tuple[int | jax.Array, ...], + grid_offsets: tuple[int | jax.Array, ...], + first_cycle=None, + last_cycle=None, + init_accumulators=None, + ): + """Initializes scheduler. + + Args: + step: inner step number. + grid: pallas grid for BufferedRefs. + grid_offsets: offsets for grid indices (used for megacore). + first_cycle: whether this is the first invocation of the pipeline. + last_cycle: whether this is the last invocation of the pipeline. + init_accumulators: do we zero-initialize accumulator state for this + invocation of the pipeline. + """ + self.step = step + self.grid = grid + self.first_cycle = first_cycle + self.last_cycle = last_cycle + self.init_accumulators = init_accumulators + + # Total number of linear steps. + self.num_steps = _grid_size(grid) + + # First and last inner step conditionals. + self.first_step = step == 0 + self.last_step = step == self.num_steps - 1 + + # First and last total step conditionals. + self.first_step_ever = first_cycle & self.first_step + self.last_step_ever = last_cycle & self.last_step + + # Cyclic steps + self.prev_step = _mod(step - 1, self.num_steps) + self.next_step = _mod(step + 1, self.num_steps) + + # Derived grid indices for present, previous, and next steps. + self.indices = _get_indices(step, grid, grid_offsets) + self.prev_indices = _get_indices( + self.prev_step, grid, grid_offsets + ) + self.next_indices = _get_indices( + self.next_step, grid, grid_offsets + ) - This is like `pltpu.emit_pipeline` but also returns a function for creating - the allocation descriptors for the pipeline so they can be allocated at a - parent grid and passed in so they live across the parent grid's iterations. + def grid_env(self): + return pallas_core.grid_env( + list(map(pallas_core.GridAxis, self.indices, self.grid))) + + def has_changed(self, buffered_ref): + indices = buffered_ref.compute_index(*self.indices) + prev_indices = buffered_ref.compute_index(*self.prev_indices) + return _tuples_differ(indices, prev_indices) + + def will_change(self, buffered_ref): + indices = buffered_ref.compute_index(*self.indices) + next_indices = buffered_ref.compute_index(*self.next_indices) + return _tuples_differ(indices, next_indices) + + def alias_local_refs(self, buffered_ref, ref): + return buffered_ref.bind_existing_ref(ref, self.indices) + + # SCHEDULE ---------------------------------------------------------------- + + # Below is the sequence of conditional waits and copies used for inputs, + # outputs, and in-out accumulators. 
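  # Roughly, each inner step walks every BufferedRef through:
  #   initialize (first step ever only) -> wait_in -> copy_in (next step's
  #   inputs) -> prefetch (first inputs of the next cycle, last step only) ->
  #   wait_out (previous step's outputs) -> copy_out -> finalize (last step
  #   ever only),
  # with the kernel body running once wait_in has landed its inputs.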
+ + def initialize(self, buffered_ref, src_ref, schedule=None): + pred = self.first_step_ever + if schedule is not None: + pred = schedule['prologue_copy_in'](self, buffered_ref, src_ref) + + with jax.named_scope("ep_initialize"): + @pl.when(self.first_step_ever) + def _init_slots(): + buffered_ref.init_slots() + + @pl.when(pred) + def _start(): + if buffered_ref.is_input: + buffered_ref.copy_in(src_ref, self.indices) + + buffered_ref.swap_slots() + + def wait_in(self, buffered_ref, src_ref, schedule=None): + pred = self.has_changed(buffered_ref) | self.first_step + if schedule is not None: + pred = schedule['wait_in'](self, buffered_ref, src_ref) + + @jax.named_scope("ep_wait_in") + def _wait(): + if buffered_ref.is_input: + buffered_ref.wait_in(src_ref, self.indices) + if buffered_ref.is_accumulator: + buffered_ref.set_accumulator(self.init_accumulators) + @jax.named_scope("ep_set_accum") + def _no_wait(): + if buffered_ref.is_accumulator: + @pl.when(self.first_step) + def _set_accumulator(): + buffered_ref.set_accumulator(self.init_accumulators) + lax.cond(pred, _wait, _no_wait) + + def copy_in(self, buffered_ref, src_ref, schedule=None): + pred = self.will_change(buffered_ref) & ~self.last_step_ever + if schedule is not None: + pred = schedule['copy_in'](self, buffered_ref, src_ref) + + @pl.when(pred) + @jax.named_scope("ep_copy_in") + def _send(): + if buffered_ref.is_input: + @pl.when(~self.last_step) + def _copy_in(): + buffered_ref.copy_in(src_ref, self.next_indices) + + # --> Call prefetch here to grab the first inputs of next cycle. + + # convenience method for prefetch callbacks. + def prefetch(self, buffered_ref, src_ref, schedule=None): + pred = ((self.will_change(buffered_ref) | self.last_step) & + ~self.last_step_ever) + if schedule is not None: + pred = schedule['prefetch'](self, buffered_ref, src_ref) + + @pl.when(pred) + @jax.named_scope("ep_prefetch") + def _send(): + if buffered_ref.is_input: + @pl.when(self.last_step) + def _prefetch_in(): + buffered_ref.copy_in(src_ref, self.next_indices) + + def wait_out(self, buffered_ref, dst_ref, schedule=None): + pred = ((self.has_changed(buffered_ref) | self.first_step) & + ~self.first_step_ever) + if schedule is not None: + pred = schedule['wait_out'](self, buffered_ref, dst_ref) + + @pl.when(pred) + @jax.named_scope("ep_wait_out") + def _wait(): + if buffered_ref.is_output: + buffered_ref.wait_out(dst_ref, self.prev_indices) + + # --> Call "postyeet" here, after last output copy is finished from previous + # cycle + + def copy_out(self, buffered_ref, dst_ref, schedule=None): + pred = self.will_change(buffered_ref) | self.last_step + if schedule is not None: + pred = schedule['copy_out'](self, buffered_ref, dst_ref) + + @jax.named_scope("ep_copy_out") + def _copy_out_and_accumulate(): + if buffered_ref.is_accumulator: + buffered_ref.accumulate() + if buffered_ref.is_output: + buffered_ref.copy_out(dst_ref, self.indices) + @jax.named_scope("ep_accum") + def _just_accumulate(): + if buffered_ref.is_accumulator: + @pl.when(self.last_step) + def _accumulate(): + buffered_ref.accumulate() + lax.cond(pred, _copy_out_and_accumulate, _just_accumulate) + + def finalize(self, buffered_ref, dst_ref, schedule=None): + pred = self.last_step_ever + if schedule is not None: + pred = schedule['epilogue_wait_out'](self, buffered_ref, dst_ref) + + @pl.when(pred) + @jax.named_scope("ep_finalize") + def _end(): + if buffered_ref.is_output: + buffered_ref.swap_slots() # formally correct, not actually necessary. 
+ buffered_ref.wait_out(dst_ref, self.indices) + + # END SCHEDULE -------------------------------------------------------------- + + +# Scheduling overrides. + +# When trying to fuse across pipelines that use accumulator arguments, we +# sometimes need to mess with the default scheduling above to avoid data-races +# or to maximize performance. A schedule is simply a set of functions that +# calculate predicates for whether or not the pipeline input and output +# BufferedRefs should do copies and waits. + + +# Copy of the default pipeline schedule. The default schedule tacitly assumes +# that the source and target HBM Refs change with each cycle. +_default_schedule = dict( + prologue_copy_in=lambda s, bref, _: s.first_step_ever, + wait_in=lambda s, bref, _: s.has_changed(bref) | s.first_step, + copy_in=lambda s, bref, _: s.will_change(bref) & ~s.last_step_ever, + prefetch=lambda s, bref, _: ( + (s.will_change(bref) | s.last_step) & ~s.last_step_ever), + wait_out=lambda s, bref, _: ( + (s.has_changed(bref) | s.first_step) & ~s.first_step_ever), + copy_out=lambda s, bref, _: s.will_change(bref) | s.last_step, + epilogue_wait_out=lambda s, bref, _: s.last_step_ever, +) + + +# Alternative schedule needed for accumulators reading and writing to a fixed +# HBM reference to avoid HBM data races for trivially small grids: only +# read/write when tiles change or at the very beginning or end of a fused +# pipeline schedule. +_fixed_schedule = dict( + prologue_copy_in=lambda s, bref, _: s.first_step_ever, + wait_in=lambda s, bref, _: s.has_changed(bref) | s.first_step_ever, + copy_in=lambda s, bref, _: s.will_change(bref) & ~s.last_step_ever, + prefetch=lambda s, bref, _: s.will_change(bref) & ~s.last_step_ever, + wait_out=lambda s, bref, _: s.has_changed(bref) & ~s.first_step_ever, + copy_out=lambda s, bref, _: s.will_change(bref) | s.last_step_ever, + epilogue_wait_out=lambda s, bref, _: s.last_step_ever, +) + + +def get_pipeline_schedule(schedule) -> Any: + """Retrieve a named pipeline schedule or pass through fully specified one.""" + predefined_schedules = { + 'default': _default_schedule, + 'fixed': _fixed_schedule + } + if isinstance(schedule, str): + return predefined_schedules[schedule].copy() + return schedule + + +# Main pipeline methods + + +def make_pipeline_allocations( + *refs, + in_specs=None, + out_specs=None, + should_accumulate_out=False, +): + """Create BufferedRefs for the pipeline. + + This function creates buffered refs for an inner pipeline that can be + created at the top-level of a pallas call such that they may be reused across + multiple invocations of the inner pipeline. Args: - body: Pipeline body. - grid: Pallas grid. - in_specs: Input block specs. - out_specs: Output block specs. - should_accumulate_out: Prefix-pytree of out_specs specifying which should be - accumulated into with True. + in_specs: input pallas block specs + out_specs: output pallas block specs + should_accumulate_out: booleans to indicate which outputs should be treated + as accumulators. Returns: - Tuple of wrapped pipelined body and a function to create the allocation - descriptors. + A list of BufferedRefs, one corresponding to each ref specified in the + in_specs and out_specs. """ + # TODO(levskaya): generalize argument tree handling here and in emit_pipeline. 
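  # Example (ref and spec names here are illustrative only): for a kernel with
  # two inputs and one accumulated output, the buffered refs could be created
  # once as
  #   brefs = make_pipeline_allocations(
  #       x_hbm_ref, y_hbm_ref, o_hbm_ref,
  #       in_specs=[x_spec, y_spec],
  #       out_specs=[o_spec],
  #       should_accumulate_out=(True,))
  # and then reused across invocations of the inner pipeline.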
+ num_in_specs = len(in_specs) if not isinstance(in_specs, (list, tuple)): - in_specs = [in_specs] + in_specs = (in_specs,) if not isinstance(out_specs, (list, tuple)): - out_specs = [out_specs] - should_accumulate_out = _broadcast_pytree_to( - "should_accumulate_out", should_accumulate_out, out_specs - ) - in_out_specs = tree_util.tree_map( - lambda spec, accum: spec if accum else None, - out_specs, - should_accumulate_out, - ) - pipeline_specs: PipelineArg[PipelineBlockSpecs] = PipelineArg( - in_specs, out_specs, in_out_specs - ) - del in_specs, out_specs, should_accumulate_out, in_out_specs - pipeline_specs_with_nones = pipeline_specs - pipeline_specs = jax.tree_util.tree_map( - _replace_nones_in_block_spec, pipeline_specs_with_nones - ) - - def make_pipeline_refs( - *ref_args: PipelineRefs, - ) -> PipelineArg[PipelineRefs]: - in_refs, out_refs = split_list(ref_args, [len(pipeline_specs.input)]) - return PipelineArg(in_refs, out_refs, out_refs) - - def start_pipeline_prefetch( - prefetch_args: PipelinePrefetchArgs, - *, - indices: GridIndices, - force_copy: Union[ - bool, tuple[Union[CondVal, Any], Union[CondVal, Any]] - ] = False, - force_skip: Union[ - bool, tuple[Union[CondVal, Any], Union[CondVal, Any]] - ] = False, - ) -> tuple[PipelineBuffers, PipelineBuffers]: - if isinstance(force_copy, bool): - force_copy = (force_copy, force_copy) - if isinstance(force_skip, bool): - force_skip = (force_skip, force_skip) - force_input_copy, force_in_out_copy = force_copy - force_input_copy = _broadcast_pytree_to( - "force_input_copy", - force_input_copy, - pipeline_specs.input, - ) - force_in_out_copy = _broadcast_pytree_to( - "force_in_out_copy", - force_in_out_copy, - pipeline_specs.in_out, + out_specs = (out_specs,) + if isinstance(in_specs, list): + in_specs = tuple(in_specs) + if isinstance(out_specs, list): + out_specs = tuple(out_specs) + in_refs = refs[:num_in_specs] + out_refs = refs[num_in_specs:] + def make_input_bref(in_spec, in_ref): + return BufferedRef.input(in_spec, in_ref.dtype) + in_brefs = jax.tree.map(make_input_bref, in_specs, in_refs) + def make_output_bref(out_spec, out_ref, accumulate): + if accumulate: + return BufferedRef.accumulator(out_spec, out_ref.dtype) + return BufferedRef.output(out_spec, out_ref.dtype) + out_brefs = jax.tree.map( + make_output_bref, out_specs, out_refs, should_accumulate_out) + return (*in_brefs, *out_brefs) + + +class GridDimensionSemantics: + pass +PARALLEL = GridDimensionSemantics() +ARBITRARY = GridDimensionSemantics() + + +def _partition_grid( + grid: tuple[int | jax.Array, ...], + core_axis: int | None, + dimension_semantics: tuple[GridDimensionSemantics, ...] | None, +) -> tuple[tuple[int | jax.Array, ...], tuple[int | jax.Array, ...]]: + if core_axis is None: + # We aren't partitioning the grid + return grid, (0,) * len(grid) + num_cores = pl.num_programs(core_axis) + # Check that num_cores is statically known + if not isinstance(num_cores, int): + raise NotImplementedError( + f"Cannot partition grid over dynamic number of cores: {core_axis=}" ) - force_input_skip, force_in_out_skip = force_skip - force_input_skip = _broadcast_pytree_to( - "force_input_skip", - force_input_skip, - pipeline_specs.input, + if num_cores == 1: + # We aren't partitioning the grid + return grid, (0,) * len(grid) + + # If dimension_semantics aren't provided, we assume it is all arbitrary. 
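  # For example, with grid=(8, 128, 128), a 2-core core_axis, and
  # dimension_semantics=(PARALLEL, ARBITRARY, ARBITRARY), dimension 0 is
  # parallel and divisible by 2, so core 0 gets grid (4, 128, 128) with offsets
  # (0, 0, 0) and core 1 gets (4, 128, 128) with offsets (4, 0, 0).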
+ if dimension_semantics is None: + dimension_semantics = (ARBITRARY,) * len(grid) + if len(dimension_semantics) != len(grid): + raise ValueError("dimension_semantics must be the same length as grid.") + + parallel_dimensions = {i for i, d in enumerate(dimension_semantics) + if d == PARALLEL} + # If there are no parallel dimensions, we can't partition the grid + if not parallel_dimensions: + # TODO(sharadmv): enable running kernel on just one core + raise NotImplementedError( + "Cannot partition over cores without parallel grid dimensions:" + f" {dimension_semantics=}" ) - force_in_out_skip = _broadcast_pytree_to( - "force_in_out_skip", - force_in_out_skip, - pipeline_specs.in_out, + if all(not isinstance(grid[i], int) for i in parallel_dimensions): + raise NotImplementedError( + f"Cannot partition cores over only dynamic grid dimensions: {grid=}" ) - next_in_and_in_out_buffers = _tree_map_with_kwargs( - partial(_start_block_copy_in, indices=indices), - pipeline_specs.input_and_in_out, - prefetch_args.pipeline_refs.input_and_in_out, - prefetch_args.pipeline_allocations.input_and_in_out, - prefetch_args.pipeline_buffers.input_and_in_out, - force_copy=force_input_copy + force_in_out_copy, - force_skip=force_input_skip + force_in_out_skip, + # Try to find a divisible dimension to partition the grid on + divisible_dimensions = { + i for i in parallel_dimensions + if isinstance(grid[i], int) and grid[i] % num_cores == 0 + } + if divisible_dimensions: + first_divisible_dimension, *_ = ( + i for i in range(len(dimension_semantics)) if i in divisible_dimensions ) - next_in_buffers, next_in_out_buffers = split_list( - next_in_and_in_out_buffers, [len(pipeline_specs.input)] + partitioned_dim_size = grid[first_divisible_dimension] // num_cores + partitioned_dim_offset = pl.program_id(core_axis) * partitioned_dim_size + new_grid = jax_util.tuple_update( + grid, first_divisible_dimension, partitioned_dim_size ) - return next_in_buffers, next_in_out_buffers - - def start_manual_prefetch( - prefetch_args: ManualPrefetchArgs, - *, - indices: GridIndices, - force_copy: Union[bool, Union[CondVal, Any]] = False, - force_skip: Union[bool, Union[CondVal, Any]] = False, - ) -> PipelineBuffers: - force_copy = _broadcast_pytree_to( - "force_input_copy", - force_copy, - prefetch_args.pipeline_specs, + offsets = jax_util.tuple_update( + (0,) * len(grid), first_divisible_dimension, partitioned_dim_offset ) - force_skip = _broadcast_pytree_to( - "force_skip", - force_skip, - prefetch_args.pipeline_specs, + else: + # No divisible dimensions, so we can't evenly partition the grid. Let's pick + # the largest dimension and try to divide it as evenly as possible. + # TODO(sharadmv): take the product of many nondivisible dimensions to + # potentially divide it more evenly + largest_parallel_dimension = max(grid[i] for i in parallel_dimensions + if isinstance(grid[i], int)) # type: ignore + partition_dimension, *_ = ( + i + for i, d in enumerate(grid) + if isinstance(d, int) and d == largest_parallel_dimension ) - next_buffers = _tree_map_with_kwargs( - partial(_start_block_copy_in, indices=indices), - prefetch_args.pipeline_specs, - prefetch_args.pipeline_refs, - prefetch_args.pipeline_allocations, - prefetch_args.pipeline_buffers, - force_copy=force_copy, - force_skip=force_skip, + base_num_iters, rem = divmod(grid[partition_dimension], num_cores) + assert rem > 0, rem + # We have some remainder iterations that we need to assign somewhere. 
We + # know that rem < num_cores, so we can assign one extra iteration to each + # core except for the last (num_cores - rem). + core_index = pl.program_id(core_axis) + num_iters = jnp.where(core_index < rem, base_num_iters + 1, + base_num_iters) + new_grid = jax_util.tuple_update(grid, partition_dimension, num_iters) + # Ordinarily, we would compute the offset as: + # grid_offset = pl.program_id(core_axis) * num_iters + # However, since we have some cores that don't have an extra iteration, we + # need to adjust the offset by `rem`. + grid_offset = jnp.where( + core_index < rem, + core_index * num_iters, + core_index * base_num_iters + rem, ) - return next_buffers - - def run_manual_compute(fn: Callable[[], None]) -> None: - fn() - - def make_pipeline_allocations( - *ref_args: PipelineRefs, - return_treedef: bool = False, - ) -> tuple[tuple[Any, tree_util.PyTreeDef], Any]: - pipeline_buffers = tree_util.tree_map( - lambda _: PipelineBuffer(*((SMEM((1,), jnp.int32),) * 2)), - pipeline_specs, + offsets = jax_util.tuple_update( + (0,) * len(grid), partition_dimension, grid_offset ) - pipeline_refs = make_pipeline_refs(*ref_args) - - def make_allocation(spec, ref): - if ref.memory_space == VMEM: - # Don't make an allocation the ref is already in VMEM, we can use it - # directly for free. - return None - return PipelineAllocation( - VMEM((2, *spec.block_shape), getattr(ref, "dtype", ref)), - DMA, - ) + return new_grid, offsets - pipeline_allocations = tree_util.tree_map( - make_allocation, pipeline_specs, pipeline_refs - ) - def grab_allocation(_, ref): - if ref.memory_space == VMEM: - return ref - return None +def emit_pipeline( + body, + *, + grid: tuple[int | jax.Array, ...], + in_specs=None, + out_specs=None, + should_accumulate_out=False, + core_axis: int | None = None, + dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None +): + """Creates a function to emit a manual pallas pipeline. - pipeline_existing_allocations = tree_util.tree_map( - grab_allocation, pipeline_specs, pipeline_refs - ) + This has the same semantics as pallas_call but is meant to be called inside + pallas_call for nesting grids. This is useful when you need to have separate + windowing strategies for communication and computation. - def make_in_out_existing_allocations(spec, ref): - if ref.memory_space == VMEM: - return VMEM(spec.block_shape, getattr(ref, "dtype", ref)) - return None + The new argument `should_accumulate_out` can be used to specify which outputs + we should accumulate into automatically within and across pipeline + invocations. - in_out_existing_allocations = tree_util.tree_map( - make_in_out_existing_allocations, - pipeline_specs.in_out, - pipeline_refs.in_out, + Args: + body: pallas kernel to set up pipeline for. + grid: a pallas grid definition. + in_specs: input pallas block specs + out_specs: output pallas block specs + should_accumulate_out: booleans to indicate which outputs should be treated + as accumulators. + core_axis: optional int, indicates whether or not to partition the grid + along the core axis. + dimension_semantics: optional tuple of GridDimensionSemantics (e.g. PARALLEL + or ARBITRARY). 
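The remainder bookkeeping in `_partition_grid` above is easiest to see with concrete numbers. The following plain-Python sketch (the values 10 and 4 are illustrative, not taken from any real mesh) reproduces the per-core iteration counts and offsets for a parallel dimension of 10 steps split over 4 cores:

# Plain-Python sketch of the uneven split: 10 iterations over 4 cores.
dim_size, num_cores = 10, 4
base_num_iters, rem = divmod(dim_size, num_cores)   # base_num_iters=2, rem=2
for core_index in range(num_cores):
    extra = core_index < rem
    num_iters = base_num_iters + 1 if extra else base_num_iters
    # Cores with an extra iteration start at core_index * num_iters; the rest
    # are shifted by the `rem` extra iterations handed out before them.
    offset = core_index * num_iters if extra else core_index * base_num_iters + rem
    print(core_index, num_iters, offset)
# Prints (core, iters, offset): (0, 3, 0), (1, 3, 3), (2, 2, 6), (3, 2, 8),
# which tiles the 10 iterations exactly once with no gaps or overlap.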
+ """ + if any(not isinstance(d, (int, jax.Array)) for d in grid): + grid_types = tuple(type(d) for d in grid) + raise ValueError( + f"Grid must consist of Python integers and JAX Arrays: {grid_types}" ) + grid, grid_offsets = _partition_grid(grid, core_axis, dimension_semantics) - flat_allocations, allocations_treedef = tree_util.tree_flatten(( - pipeline_buffers, - pipeline_allocations, - in_out_existing_allocations, - )) - if return_treedef: - flat_allocations = cast(Any, (tuple(flat_allocations), allocations_treedef)) - return ( - (flat_allocations, allocations_treedef), - pipeline_existing_allocations, - ) + num_steps = _grid_size(grid) + if not isinstance(in_specs, (list, tuple)): + in_specs = (in_specs,) + if not isinstance(out_specs, (list, tuple)): + out_specs = (out_specs,) + if isinstance(in_specs, list): + in_specs = tuple(in_specs) + if isinstance(out_specs, list): + out_specs = tuple(out_specs) + should_accumulate_out = _broadcast_pytree_to(should_accumulate_out, out_specs) def pipeline( - *ref_args: PipelineRefs, - scratchs: Union[PipelineRefs, None] = None, - allocations: Union[ - None, - tuple[PipelineArg[PipelineBuffers], PipelineArg[PipelineAllocations]], - ] = None, - init_allocations: CondVal = False, - prologue: Union[PipelinePrologue, None] = None, - epilogue: Union[PipelineEpilogue, None] = None, - out_prologue: Union[PipelineOutPrologue, None] = None, - out_epilogue: Union[PipelineOutEpilogue, None] = None, - ) -> None: - use_in_out = jnp.logical_not(init_allocations) - if scratchs is None: - scratchs = [] - if not isinstance(scratchs, (list, tuple)): - scratchs = [scratchs] - - def pipeline_body( - pipeline_refs: PipelineArg[PipelineRefs], - pipeline_existing_allocations: PipelineArg[PipelineRefs], - pipeline_buffer_refs: PipelineArg[PipelineBuffers], - pipeline_allocations: PipelineArg[PipelineAllocations], - in_out_existing_allocations: PipelineRefs, - ): - - def init_pipeline_allocations(): - def init_buffer_ref(_, buffer_ref): - buffer_ref.current[0] = 0 - buffer_ref.next[0] = 1 - - tree_util.tree_map( - init_buffer_ref, - pipeline_specs, - pipeline_buffer_refs, - ) - - do_init_pipeline_allocations = jnp.logical_or( - allocations is None, init_allocations - ) - lax.cond( - do_init_pipeline_allocations, - init_pipeline_allocations, - lambda: None, - ) - - zero_indices = (jnp.array(0, dtype=jnp.int32),) * len(grid) - last_indices = tuple( - [jnp.asarray(dim_size - 1, dtype=jnp.int32) for dim_size in grid] - ) - indices = zero_indices - pipeline_buffers: PipelineArg[PipelineBuffers] = tree_util.tree_map( - lambda buffer_ref: buffer_ref[0], - pipeline_buffer_refs, - ) - if prologue is not None: - (skip_input_prologue, skip_in_out_prologue), ( - force_input_prologue_wait, - force_in_out_prologue_wait, - ) = prologue( - PipelineCallbackArgs( - pipeline_specs=pipeline_specs, - pipeline_refs=pipeline_refs, - pipeline_buffer_refs=pipeline_buffer_refs, - pipeline_allocations=pipeline_allocations, - pipeline_buffers=pipeline_buffers, - make_pipeline_refs=make_pipeline_refs, - start_pipeline_prefetch=partial( - cast(Any, start_pipeline_prefetch), - indices=(last_indices, zero_indices, indices), - ), - start_manual_prefetch=partial( - cast(Any, start_manual_prefetch), - indices=(last_indices, zero_indices, indices), - ), - run_manual_compute=run_manual_compute, - ) - ) - else: - skip_input_prologue = False - skip_in_out_prologue = False - force_input_prologue_wait = False - force_in_out_prologue_wait = False - skip_input_prologue = _broadcast_pytree_to( - 
"skip_input_prologue", - skip_input_prologue, - pipeline_specs.input, - ) - skip_in_out_prologue = _broadcast_pytree_to( - "skip_in_out_prologue", - skip_in_out_prologue, - pipeline_specs.out, - ) - force_input_prologue_wait = _broadcast_pytree_to( - "force_input_prologue_wait", - force_input_prologue_wait, - pipeline_specs.input, - ) - force_in_out_prologue_wait = _broadcast_pytree_to( - "force_in_out_prologue_wait", - force_in_out_prologue_wait, - pipeline_specs.out, - ) - _tree_map_with_kwargs( - partial( - _start_block_copy_in, - indices=(indices, indices, indices), - force_copy=True, - ), - pipeline_specs.input, - pipeline_refs.input, - pipeline_allocations.input, - tree_util.tree_map( - lambda _, buffers: PipelineBuffer(buffers.next, buffers.current), - pipeline_specs.input, - pipeline_buffers.input, - ), - force_skip=skip_input_prologue, - ) - lax.cond( - use_in_out, - lambda: _tree_map_with_kwargs( - partial( - _start_block_copy_in, - indices=(indices, indices, indices), - force_copy=True, - ), - pipeline_specs.in_out, - pipeline_refs.in_out, - pipeline_allocations.in_out, - tree_util.tree_map( - lambda _, buffers: PipelineBuffer( - buffers.next, buffers.current - ), - pipeline_specs.in_out, - pipeline_buffers.in_out, - ), - force_skip=skip_in_out_prologue, - ), - lambda: pipeline_buffers.in_out, - ) - total_iterations = math.prod(grid) - - def fori_loop_body( - i: jax.Array, - carry: tuple[ - GridIndices, - GridIndices, - PipelineArg[PipelineBuffers], - ], - ) -> tuple[ - GridIndices, - GridIndices, - PipelineArg[PipelineBuffers], - ]: - (prev_indices, indices, pipeline_buffers) = carry - next_indices = _get_next_indices(grid, indices) - copy_indices = (prev_indices, indices, next_indices) - - with tpu_primitives.trace("ep_wait_input"): - input_copy_args = [ - pipeline_specs.input, - pipeline_refs.input, - pipeline_allocations.input, - pipeline_buffers.input, - ] - in_out_copy_args = [ - pipeline_specs.in_out, - pipeline_refs.in_out, - pipeline_allocations.in_out, - pipeline_buffers.in_out, - ] - input_force_copy = lambda skip, wait: jnp.logical_and( - i == 0, jnp.logical_or(jnp.logical_not(skip), wait) - ) - _tree_map_with_kwargs( - partial( - _wait_block_copy_in, - indices=copy_indices, - ), - *input_copy_args, - force_copy=tree_util.tree_map( - input_force_copy, - skip_input_prologue, - force_input_prologue_wait, - ), - ) - lax.cond( - use_in_out, - lambda: _tree_map_with_kwargs( - partial( - _wait_block_copy_in, - indices=copy_indices, - ), - *in_out_copy_args, - force_copy=tree_util.tree_map( - input_force_copy, - skip_in_out_prologue, - force_in_out_prologue_wait, - ), - ), - lambda: pipeline_buffers.in_out, - ) - - def start_next_iteration_in_block_copies(): - next_in_buffers = tree_util.tree_map( - partial( - _start_block_copy_in, - indices=copy_indices, - ), - *input_copy_args, - ) - next_in_out_buffers = lax.cond( - use_in_out, - lambda: tree_util.tree_map( - partial( - _start_block_copy_in, - indices=copy_indices, - ), - *in_out_copy_args, - ), - lambda: pipeline_buffers.in_out, - ) - return next_in_buffers, next_in_out_buffers - - @tpu_primitives.trace("ep_run_epilogue") - def run_epilogue(): - if epilogue is None: - return pipeline_buffers.input, pipeline_buffers.in_out - return epilogue( - PipelineCallbackArgs( - pipeline_specs=pipeline_specs, - pipeline_refs=pipeline_refs, - pipeline_buffer_refs=pipeline_buffer_refs, - pipeline_allocations=pipeline_allocations, - pipeline_buffers=pipeline_buffers, - make_pipeline_refs=make_pipeline_refs, - 
start_pipeline_prefetch=partial( - cast(Any, start_pipeline_prefetch), - indices=(prev_indices, indices, zero_indices), - ), - start_manual_prefetch=partial( - cast(Any, start_manual_prefetch), - indices=(prev_indices, indices, zero_indices), - ), - run_manual_compute=run_manual_compute, - ) - ) - - next_in_buffers, next_in_out_buffers = lax.cond( - i == total_iterations - 1, - run_epilogue, - start_next_iteration_in_block_copies, - ) - - with tpu_primitives.trace("ep_kernel"): - - def grab_body_ref( - spec_with_nones, - spec, - allocation, - buffers, - existing_allocation, - in_out_existing_allocation=None, - ): - if existing_allocation is None: - buffer_slice = tuple([ - 0 if dim is None else slice(None) - for dim in spec_with_nones.block_shape - ]) - return allocation.vmem_ref.at[(buffers.current, *buffer_slice)] - dma_slice = _run_block_spec(spec, indices) - dma_slice = tuple([ - 0 if dim is None else _slice - for dim, _slice in zip(spec_with_nones.block_shape, dma_slice) - ]) - if in_out_existing_allocation is None: - return existing_allocation.at[dma_slice] - return in_out_existing_allocation.at[dma_slice] - - in_args = tree_util.tree_map( - grab_body_ref, - pipeline_specs_with_nones.input, - pipeline_specs.input, - pipeline_allocations.input, - pipeline_buffers.input, - pipeline_existing_allocations.input, - ) - out_args = tree_util.tree_map( - grab_body_ref, - pipeline_specs_with_nones.out, - pipeline_specs.out, - pipeline_allocations.out, - pipeline_buffers.out, - pipeline_existing_allocations.out, - in_out_existing_allocations, - ) - with core.grid_env(cast(Any, zip(indices, grid))): - body(*in_args, *out_args, *scratchs) - - def accum_existing_in_out_existing_allocation( - spec, - existing_allocation, - in_out_existing_allocation, - ): - if ( - existing_allocation is not None - and in_out_existing_allocation is not None - ): - dma_slice = _run_block_spec(spec, indices) - next_dma_slice = _run_block_spec(spec, next_indices) - dma_slice_is_changing = _dma_slice_not_equal( - dma_slice, next_dma_slice - ) - - def init(): - existing_allocation[dma_slice] = in_out_existing_allocation[ - dma_slice - ] - - def accum(): - existing_allocation[dma_slice] = ( - existing_allocation[dma_slice].astype(jnp.float32) - + in_out_existing_allocation[dma_slice].astype(jnp.float32) - ).astype(existing_allocation.dtype) - - lax.cond( - jnp.logical_or( - dma_slice_is_changing, i == total_iterations - 1 - ), - lambda: lax.cond(use_in_out, accum, init), - lambda: None, - ) - - tree_util.tree_map( - accum_existing_in_out_existing_allocation, - pipeline_specs.out, - pipeline_existing_allocations.out, - in_out_existing_allocations, - ) - - with tpu_primitives.trace("ep_wait_output"): - - def run_out_prologue(): - if out_prologue is not None: - skip_out_prologue_wait = out_prologue( - PipelineCallbackArgs( - pipeline_specs=pipeline_specs, - pipeline_refs=pipeline_refs, - pipeline_buffer_refs=pipeline_buffer_refs, - pipeline_allocations=pipeline_allocations, - pipeline_buffers=pipeline_buffers, - make_pipeline_refs=make_pipeline_refs, - start_pipeline_prefetch=partial( - cast(Any, start_pipeline_prefetch), - indices=copy_indices, - ), - start_manual_prefetch=partial( - cast(Any, start_manual_prefetch), - indices=copy_indices, - ), - run_manual_compute=run_manual_compute, - ) - ) - skip_out_prologue_wait = _broadcast_pytree_to( - "skip_out_prologue_wait", - skip_out_prologue_wait, - pipeline_specs.out, - ) - _tree_map_with_kwargs( - partial( - _wait_block_copy_out, - indices=copy_indices, - ), - 
pipeline_specs.out, - pipeline_refs.out, - pipeline_allocations.out, - pipeline_buffers.out, - force_skip=skip_out_prologue_wait, - ) - - @tpu_primitives.trace("ep_wait_prev_iteration_out_block_copies") - def wait_prev_iteration_out_block_copies(): - tree_util.tree_map( - partial( - _wait_block_copy_out, - indices=copy_indices, - ), - pipeline_specs.out, - pipeline_refs.out, - pipeline_allocations.out, - pipeline_buffers.out, - ) - - lax.cond( - i == 0, - run_out_prologue, - wait_prev_iteration_out_block_copies, - ) - if out_epilogue is not None: - skip_out_epilogue_wait = out_epilogue( - PipelineCallbackArgs( - pipeline_specs=pipeline_specs, - pipeline_refs=pipeline_refs, - pipeline_buffer_refs=pipeline_buffer_refs, - pipeline_allocations=pipeline_allocations, - pipeline_buffers=pipeline_buffers, - make_pipeline_refs=make_pipeline_refs, - start_pipeline_prefetch=cast( - Any, lambda *args, **kwargs: None - ), - start_manual_prefetch=cast( - Any, lambda *args, **kwargs: None - ), - run_manual_compute=cast(Any, lambda *args, **kwargs: None), - ) - ) - else: - skip_out_epilogue_wait = cast(Any, False) - skip_out_epilogue_wait = _broadcast_pytree_to( - "skip_out_epilogue_wait", - skip_out_epilogue_wait, - pipeline_specs.out, - ) - force_start_block_copy_out = jax.tree_util.tree_map( - lambda skip_out_wait: jnp.logical_and( - jnp.logical_not(skip_out_wait), i == total_iterations - 1 - ), - skip_out_epilogue_wait, - ) - next_out_buffers = _tree_map_with_kwargs( - partial( - _start_block_copy_out, - indices=copy_indices, - use_accum=use_in_out, - # If an output tile doesn't change from last to first, we need - # to add its accum since the body overwrites outputs each - # pipeline. - accum_if_skipping=i == total_iterations - 1, - # Initialize the accum if this is the first time this is - # happening. 
- zero_accum_if_skipping=do_init_pipeline_allocations, - ), - pipeline_specs.out, - pipeline_refs.out, - pipeline_allocations.out, - pipeline_buffers.out, - pipeline_allocations.in_out, - pipeline_buffers.in_out, - force_copy=force_start_block_copy_out, - ) - - prev_indices = indices - indices = next_indices - return ( - prev_indices, - indices, - PipelineArg(next_in_buffers, next_out_buffers, next_in_out_buffers), - ) - - (prev_indices, indices, pipeline_buffers) = lax.fori_loop( - 0, - total_iterations, - fori_loop_body, - (last_indices, indices, pipeline_buffers), - ) - - def set_buffer_ref(buffer_ref, buffer): - buffer_ref[0] = buffer - - tree_util.tree_map( - set_buffer_ref, - pipeline_buffer_refs, - pipeline_buffers, - ) - - with tpu_primitives.trace("ep_end_pipeline"): - with tpu_primitives.trace("ep_wait_output"): - if out_epilogue is not None: - skip_out_epilogue_wait = out_epilogue( - PipelineCallbackArgs( - pipeline_specs=pipeline_specs, - pipeline_refs=pipeline_refs, - pipeline_buffer_refs=pipeline_buffer_refs, - pipeline_allocations=pipeline_allocations, - pipeline_buffers=pipeline_buffers, - make_pipeline_refs=make_pipeline_refs, - start_pipeline_prefetch=partial( - cast(Any, start_pipeline_prefetch), - indices=(prev_indices, indices, zero_indices), - ), - start_manual_prefetch=partial( - cast(Any, start_manual_prefetch), - indices=(prev_indices, indices, zero_indices), - ), - run_manual_compute=run_manual_compute, - ) - ) - else: - skip_out_epilogue_wait = None - skip_out_epilogue_wait = _broadcast_pytree_to( - "skip_out_epilogue_wait", - skip_out_epilogue_wait, - pipeline_specs.out, - ) - _tree_map_with_kwargs( - partial( - _wait_block_copy_out, - indices=(prev_indices, indices, zero_indices), - force_copy=True, - ), - pipeline_specs.out, - pipeline_refs.out, - pipeline_allocations.out, - pipeline_buffers.out, - force_skip=skip_out_epilogue_wait, - ) - - pipeline_refs = make_pipeline_refs(*ref_args) + *refs: Any, + scratches=None, + allocations=None, + first_cycle: CondVal = True, + last_cycle: CondVal = True, + init_accumulators: CondVal = False, + prefetch=None, + postyeet=None, + schedule=None, + ): + """ + Run the pipeline. + + Args: + *ref_args: a list of pallas refs (or more generally a list of pytrees of + pallas refs) + scratches: scratch buffers for the inner kernel + allocations: a list of BufferedRefs, one corresponding to each ref + first_cycle: boolean indicating if this is the first invocation of the + inner pipeline cycle. + last_cycle: boolean indicating if this is the last invocation of the + inner pipeline cycle. + init_accumulators: whether to zero-init accumulators during this cycle. + prefetch: callback called as fn(*brefs, scheduler) that is used to fetch + the next cycle invocations first inputs. Called during the inputs phase + in the final inner step. + postyeet: callback called as fn(*brefs, scheduler) that is used to finish + any writes or transfers from the last output of the previous cycle. + Called during the outputs phase in the first inner step. + schedule: manually specified pipeline schedules for brefs, None indicates + default schedule. 
+ """ + if scratches is None: + scratches = () if allocations is None: - (flat_allocations, allocations_treedef), existing_allocations = ( - make_pipeline_allocations(*ref_args) - ) - tpu_primitives.run_scoped( - partial(pipeline_body, pipeline_refs, existing_allocations), - *tree_util.tree_unflatten(allocations_treedef, flat_allocations), - ) - else: - (_, allocations_treedef), existing_allocations = ( - make_pipeline_allocations(*ref_args) - ) - pipeline_body( - pipeline_refs, - existing_allocations, - *tree_util.tree_unflatten(allocations_treedef, list(allocations)), + # run with inline scoped allocations + return tpu_primitives.run_scoped( + lambda allocations: pipeline( + *refs, + scratches=scratches, + allocations=allocations, + first_cycle=first_cycle, + last_cycle=last_cycle, + init_accumulators=init_accumulators, + prefetch=prefetch, + postyeet=postyeet, + schedule=schedule, + ), + make_pipeline_allocations( + *refs, + in_specs=in_specs, + out_specs=out_specs, + should_accumulate_out=should_accumulate_out), ) + if isinstance(allocations, list): + allocations = tuple(allocations) + # Normalize custom schedule arguments. + if schedule is None: + schedule = map_brefs(lambda x: None, allocations) + if not isinstance(schedule, (list, tuple)): + schedule = map_brefs(lambda x: schedule, allocations) + if isinstance(schedule, list): + schedule = tuple(schedule) + schedule = map_brefs( + lambda _, x: get_pipeline_schedule(x), allocations, schedule) + + def loop_body(step, _): + nonlocal allocations + scheduler = Scheduler( + step, + grid, + grid_offsets=grid_offsets, + first_cycle=first_cycle, + last_cycle=last_cycle, + init_accumulators=init_accumulators) + + # prepare any local VMEM aliases + brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) + + # loop input handling phase + map_brefs(scheduler.initialize, brefs, refs, schedule) + map_brefs(scheduler.wait_in, brefs, refs, schedule) + map_brefs(scheduler.copy_in, brefs, refs, schedule) + + # prefetch inputs for the *next* invocation of this pipeline + with jax.named_scope("ep_prefetch"): + if prefetch is not None: + lax.cond(step == num_steps - 1, + lambda: prefetch(*brefs, scheduler), + lambda: None) + + # run the kernel! + current_refs = map_brefs(lambda x: x.current_ref, brefs) + with jax.named_scope("ep_run_kernel"): + with scheduler.grid_env(): + body(*current_refs, *scratches) + + # loop output handling phase + map_brefs(scheduler.wait_out, brefs, refs, schedule) + # handle writes for the *last* invocation of this pipeline's outputs + with jax.named_scope("ep_postyeet"): + if postyeet is not None: + lax.cond(step == 0, + lambda: postyeet(*brefs, scheduler), + lambda: None) + map_brefs(scheduler.copy_out, brefs, refs, schedule) + map_brefs(scheduler.finalize, brefs, refs, schedule) + + return () + + # run pipeline + lax.fori_loop(0, num_steps, loop_body, ()) + + return pipeline - return pipeline, lambda *args, **kwargs: tuple( - make_pipeline_allocations(*args, **kwargs)[0][0] - ) - -def emit_pipeline( - body: PipelineBody, +def emit_pipeline_with_allocations( + body, *, - grid: core.StaticGrid, - in_specs: PipelineBlockSpecs, - out_specs: PipelineBlockSpecs, - should_accumulate_out: Union[Sequence[bool], Any] = False, -) -> Pipeline: - """Wraps body function in a custom pipeline defined by grid and specs. - - This has the same semantics as pallas_call but is meant to be called inside - pallas_call for nesting grids. This is useful when you need to have separate - windowing strategies for example for communication vs. 
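A minimal usage sketch of the new `emit_pipeline` entry point: an inner pipeline nested inside a `pallas_call` whose operands stay in HBM (`ANY` memory space). The re-export spelling `pltpu.emit_pipeline`, the `TPUMemorySpace.ANY` name, and the block sizes are assumptions for illustration, not prescribed by this patch.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def add_one_block(x_ref, o_ref):
  # Inner kernel: sees one (256, 128) VMEM block per inner grid step.
  o_ref[...] = x_ref[...] + 1.0

def outer_kernel(x_hbm_ref, o_hbm_ref):
  # Inner pipeline with its own grid and windowing, independent of the outer
  # pallas_call grid; emit_pipeline handles the HBM<->VMEM buffering.
  inner = pltpu.emit_pipeline(
      add_one_block,
      grid=(4,),
      in_specs=[pl.BlockSpec(block_shape=(256, 128), index_map=lambda i: (i, 0))],
      out_specs=[pl.BlockSpec(block_shape=(256, 128), index_map=lambda i: (i, 0))],
  )
  inner(x_hbm_ref, o_hbm_ref)

x = jnp.zeros((1024, 128), jnp.float32)
y = pl.pallas_call(
    outer_kernel,
    in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)],
    out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    out_shape=jax.ShapeDtypeStruct((1024, 128), jnp.float32),
)(x)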
computation. - - By default outputs are written to but `should_accumulate_out` can be used to - specify which outputs we should add to instead. This is so we can reduce - across pipeline calls within and across parent grid iterations. + grid, + in_specs=None, + out_specs=None, + should_accumulate_out=False, +): + """Creates pallas pipeline and top-level allocation preparation functions. Args: - body: Pipeline body. - grid: Pallas grid. - in_specs: Input block specs. - out_specs: Output block specs. - should_accumulate_out: Prefix-pytree of out_specs specifying which should be - accumulated into with True. + body: pallas kernel to set up pipeline for. + grid: a pallas grid definition. + in_specs: input pallas block specs + out_specs: output pallas block specs + should_accumulate_out: booleans to indicate which outputs should be treated + as accumulators. Returns: - Wrapped pipelined body. + (emit_pipeline, make_allocations) function pair, where: + emit_pipeline is the pallas pipeline function. + make_allocations is a function to create buffered refs for the inner + pipeline that can be created at the top-level of a pallas call to be + reused across multiple invocations of the inner pipeline. + """ - return emit_pipeline_with_allocations( + make_allocations = functools.partial(make_pipeline_allocations, + in_specs=in_specs, + out_specs=out_specs, + should_accumulate_out=should_accumulate_out) + pipeline = emit_pipeline( body, grid=grid, in_specs=in_specs, out_specs=out_specs, - should_accumulate_out=should_accumulate_out, - )[0] + should_accumulate_out=should_accumulate_out) + + return pipeline, make_allocations diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index e7092684c544..f4c24e4e5e16 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -15,10 +15,10 @@ """Module for Pallas:TPU-specific JAX primitives and functions.""" from __future__ import annotations -import contextlib +from collections.abc import Callable import dataclasses import enum -from typing import Any, Callable +from typing import Any import jax from jax._src import api_util @@ -36,9 +36,12 @@ from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pl_core from jax._src.pallas.mosaic import core as tpu_core +from jax._src.state import discharge as state_discharge from jax._src.typing import DTypeLike import jax.numpy as jnp +Slice = indexing.Slice + map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip @@ -101,22 +104,18 @@ def _bitcast(x): mlir.register_lowering(bitcast_p, _bitcast_lowering_rule) -trace_start_p = jax_core.Primitive('trace_start') -trace_start_p.multiple_results = True - - roll_p = jax_core.Primitive("roll") def roll( x, - shift: int, + shift, axis: int, *, stride: int | None = None, stride_axis: int | None = None, ): - if shift < 0: + if isinstance(shift, int) and shift < 0: raise ValueError("shift must be non-negative.") if axis < 0 or axis >= len(x.shape): raise ValueError("axis is out of range.") @@ -130,19 +129,20 @@ def roll( if axis == stride_axis: raise ValueError("expected axis and stride_axis are different.") return roll_p.bind( - x, shift=shift, axis=axis, stride=stride, stride_axis=stride_axis + x, shift, axis=axis, stride=stride, stride_axis=stride_axis ) @roll_p.def_abstract_eval -def _roll_abstract_eval(x, **_): +def _roll_abstract_eval(x, shift, **_): + del shift return jax_core.raise_to_shaped(x) def _roll_lowering_rule( - ctx: mlir.LoweringRuleContext, 
x, *, shift, axis, stride, stride_axis + ctx: mlir.LoweringRuleContext, x, shift, *, axis, stride, stride_axis ): - def _roll(x): + def _roll(x, shift): if stride is None: return jnp.roll(x, shift, axis) outputs = [ @@ -151,45 +151,12 @@ def _roll(x): ] return jnp.concatenate(outputs, stride_axis) - return mlir.lower_fun(_roll, multiple_results=False)(ctx, x) + return mlir.lower_fun(_roll, multiple_results=False)(ctx, x, shift) mlir.register_lowering(roll_p, _roll_lowering_rule) -@trace_start_p.def_impl -def _trace_start_impl(*, message: str, level: int): - del message, level - return [] - -@trace_start_p.def_abstract_eval -def _trace_start_abstract_eval(*, message: str, level: int): - del message, level - return [] - -mlir.register_lowering(trace_start_p, lambda ctx, **_: []) - - -trace_stop_p = jax_core.Primitive('trace_stop') -trace_stop_p.multiple_results = True - -@trace_stop_p.def_impl -def _trace_stop_impl(): - return [] - -@trace_stop_p.def_abstract_eval -def _trace_stop_abstract_eval(): - return [] - -mlir.register_lowering(trace_stop_p, lambda ctx: []) - -@contextlib.contextmanager -def trace(message: str, level: int = 10): - trace_start_p.bind(message=message, level=level) - yield - trace_stop_p.bind() - - run_scoped_p = jax_core.Primitive('run_scoped') run_scoped_p.multiple_results = True @@ -222,6 +189,42 @@ class DeviceIdType(enum.Enum): LOGICAL = "logical" +def check_sem_avals(sem_aval, sem_indexers_avals, name): + if not isinstance(sem_aval, state.AbstractRef): + raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") + sem_shape = sem_aval.shape + if sem_indexers_avals: + sem_shape = sem_indexers_avals[-1].get_indexer_shape() + if sem_shape: + raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") + sem_dtype = sem_aval.dtype + if not ( + jnp.issubdtype(sem_dtype, tpu_core.semaphore) + or jnp.issubdtype(sem_dtype, tpu_core.barrier_semaphore) + ): + raise ValueError(f"Must {name} a REGULAR or BARRIER semaphore: {sem_dtype}") + + +semaphore_read_p = jax_core.Primitive("semaphore_read") +semaphore_read_p.multiple_results = False + + +def semaphore_read(sem_or_view): + ref, indexers = _get_ref_and_indexers(sem_or_view) + args = [ref, indexers] + flat_args, args_tree = tree_util.tree_flatten(args) + return semaphore_read_p.bind(*flat_args, args_tree=args_tree) + +@semaphore_read_p.def_abstract_eval +def _semaphore_read_abstract_eval( + *avals, + args_tree, +): + sem_aval, sem_indexers_avals = tree_util.tree_unflatten(args_tree, avals) + check_sem_avals(sem_aval, sem_indexers_avals, "read") + return jax_core.ShapedArray((), jnp.dtype("int32")) + + semaphore_signal_p = jax_core.Primitive('semaphore_signal') semaphore_signal_p.multiple_results = True @@ -232,10 +235,11 @@ def semaphore_signal( *, device_id: int | jax.Array | None | tuple[int | jax.Array, ...] 
= None, device_id_type: DeviceIdType = DeviceIdType.MESH, + core_index: int | jax.Array | None = None, ): ref, indexers = _get_ref_and_indexers(sem_or_view) inc = jnp.asarray(inc, dtype=jnp.int32) - args = [ref, indexers, inc, device_id] + args = [ref, indexers, inc, device_id, core_index] flat_args, args_tree = tree_util.tree_flatten(args) semaphore_signal_p.bind( *flat_args, @@ -251,20 +255,10 @@ def _semaphore_signal_abstract_eval( device_id_type: DeviceIdType, ): del device_id_type - sem_aval, sem_indexers_avals, value_aval, device_id_avals = ( + sem_aval, sem_indexers_avals, value_aval, device_id_avals, core_index_aval = ( tree_util.tree_unflatten(args_tree, avals) ) - if not isinstance(sem_aval, state.AbstractRef): - raise ValueError(f"Cannot signal on a non-Ref: {sem_aval}") - sem_shape = sem_aval.shape - if sem_indexers_avals: - sem_shape = sem_indexers_avals[-1].get_indexer_shape() - if sem_shape: - raise ValueError(f"Cannot signal on a non-()-shaped semaphore: {sem_shape}") - sem_dtype = sem_aval.dtype - if not (jnp.issubdtype(sem_dtype, tpu_core.semaphore) or jnp.issubdtype( - sem_dtype, tpu_core.barrier_semaphore)): - raise ValueError(f"Must signal a REGULAR or BARRIER semaphore: {sem_dtype}") + check_sem_avals(sem_aval, sem_indexers_avals, "signal") if value_aval.dtype != jnp.dtype("int32"): raise ValueError("Must signal an int32 value.") if device_id_avals is not None: @@ -286,6 +280,7 @@ def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, sem_indexers, value, device_ids, + _, ) = tree_util.tree_unflatten(tree, invars) out = pp.concat([ pp.text('semaphore_signal'), @@ -319,19 +314,9 @@ def semaphore_wait(sem_or_view, dec: int | jax.Array = 1): @semaphore_wait_p.def_abstract_eval def _semaphore_wait_abstract_eval(*avals, args_tree): sem_aval, sem_indexers_avals, value_aval = tree_util.tree_unflatten(args_tree, avals) - if not isinstance(sem_aval, state.AbstractRef): - raise ValueError(f"Cannot signal on a non-semaphore Ref: {sem_aval}") - sem_shape = sem_aval.shape - if sem_indexers_avals: - sem_shape = sem_indexers_avals[-1].get_indexer_shape() - if sem_shape: - raise ValueError(f"Cannot signal on a non-()-shaped semaphore: {sem_shape}") - sem_dtype = sem_aval.dtype - if not (jnp.issubdtype(sem_dtype, tpu_core.semaphore) or jnp.issubdtype( - sem_dtype, tpu_core.barrier_semaphore)): - raise ValueError(f"Must signal a REGULAR or BARRIER semaphore: {sem_dtype}") + check_sem_avals(sem_aval, sem_indexers_avals, "wait") if value_aval.dtype != jnp.dtype("int32"): - raise ValueError("Must signal an int32 value.") + raise ValueError("Must wait an int32 value.") return [] def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, @@ -480,6 +465,117 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn +def dma_start_discharge_rule(in_avals, out_avals, + *args, tree, device_id_type): + ( + src_ref, + src_indexers, + dst_ref, + dst_indexers, + dst_sem, + dst_sem_indexers, + src_sem, + src_sem_indexers, + device_id, + ) = tree_util.tree_unflatten(tree, args) + del out_avals, dst_sem, dst_sem_indexers + is_remote = src_sem is not None and device_id is not None + if is_remote: + if device_id_type == DeviceIdType.MESH: + raise NotImplementedError("Mesh device_id_type not supported.") + else: + assert src_sem is None + assert src_sem_indexers is None + assert device_id is None + + def _find_slice_start_size(indexer): + num_scalar_idxs = 0 + # TODO(b/329733289): support strided load/store in interpret mode. 
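For orientation, a short sketch of the semaphore helpers changed above, including the new `semaphore_read`, inside a kernel body. It assumes the helpers are re-exported under `jax.experimental.pallas.tpu` and that `sem` is a REGULAR semaphore supplied via `scratch_shapes=[pltpu.SemaphoreType.REGULAR]`.

from jax.experimental.pallas import tpu as pltpu

def sem_kernel(x_ref, o_ref, sem):
  # Signal by one, observe the count with semaphore_read, then wait the
  # semaphore back down to zero before the kernel exits.
  pltpu.semaphore_signal(sem, 1)
  count = pltpu.semaphore_read(sem)            # ()-shaped int32
  o_ref[...] = x_ref[...] + count.astype(x_ref.dtype)
  pltpu.semaphore_wait(sem, 1)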
+ for s in indexer.indices: + if isinstance(s, Slice) and s.stride > 1: + raise NotImplementedError("Strides not supported in discharge" + " rule of dma_start.") + if not isinstance(s, Slice): + num_scalar_idxs += 1 + indices = indexer.indices + scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices] + slice_starts = [s.start if isinstance(s, Slice) else s for s in indices] + slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices) + return scalar_dims, slice_starts, slice_sizes, num_scalar_idxs + + num_src_index_vals = 0 + if src_indexers: + if len(src_indexers) != 1: + raise NotImplementedError("Multiple indexers not supported.") + idx = src_indexers[0] + if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): + (_, slice_starts, + slice_sizes, num_scalar_idxs) = _find_slice_start_size(idx) + num_src_index_vals += num_scalar_idxs + updates = jax.lax.dynamic_slice( + src_ref, slice_starts, slice_sizes=slice_sizes) + else: + updates = src_ref[idx.indices] + else: + updates = src_ref + + if is_remote: + # Note that this code only works in SPMD mode. If not all devices execute + # the DMA then the devices that do will hang. + # TODO(justinfu): Verify that code only works in SPMD mode. + axis_env = jax_core.thread_local_state.trace_state.axis_env + axis_names = tuple(frame.name for frame in axis_env) + nonempty_axis_names = tuple(name for name in axis_names if name is not None) + if len(nonempty_axis_names) > 1: + raise NotImplementedError("Sharding with more than one named axis not " + "implemented in dma_start_p.") + shard_axis = nonempty_axis_names[0] + my_axis = jax.lax.axis_index(shard_axis) + # Update dst_ref from the perspective of the current device as the + # receiver. + who_copy_to_me = jax.lax.all_gather(device_id, shard_axis) == my_axis + # TODO(justinfu): Add a checkify for verifying there is at most one source. + # TODO(justinfu): Handle the case where no other device is copying to + # this device. + index = jnp.argmax(who_copy_to_me, axis=0) + global_updates = jax.lax.all_gather(updates, shard_axis) + updates = jax.lax.dynamic_index_in_dim( + global_updates, index, axis=0, keepdims=False) + + num_dst_index_vals = 0 + if dst_indexers: + if len(dst_indexers) != 1: + raise NotImplementedError("Multiple indexers not supported.") + idx = dst_indexers[0] + if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): + (_, slice_starts, slice_sizes, + num_scalar_idxs) = _find_slice_start_size(idx) + num_dst_index_vals += num_scalar_idxs + if updates.shape != slice_sizes: + raise ValueError("DMA src and dst slices must have same shape. " + f"Got src={updates.shape}, dst={slice_sizes}") + new_dst = jax.lax.dynamic_update_slice( + dst_ref, updates, slice_starts) + else: + new_dst = dst_ref.at[idx.indices].set(updates) + else: + new_dst = updates + + # TODO(b/345505876): Implement semaphore counting. 
+ new_avals = (None,) # src_aval + new_avals += (None,) * num_src_index_vals + new_avals += (new_dst,) # dst_aval + new_avals += (None,) * num_dst_index_vals + new_avals += (None,) # dst_sem_aval + if is_remote: + new_avals += (None, None) # src_sem_aval, device_id + assert (len(new_avals) == + len(in_avals)), f"{len(new_avals), new_avals} != {len(in_avals)}" + return new_avals, [] + +state_discharge.register_discharge_rule(dma_start_p)(dma_start_discharge_rule) + + dma_wait_p = jax_core.Primitive('dma_wait') dma_wait_p.multiple_results = True @@ -505,6 +601,13 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn, jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn +def dma_wait_discharge_rule(in_avals, out_avals, + *args, tree, device_id_type): + del out_avals, args, tree, device_id_type + # TODO(justinfu): Implement semaphore counting. + return (None,) * len(in_avals), [] +state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule) + def _get_ref_and_indexers(ref): if isinstance(ref, state.RefView): return ref.ref, ref.indexers @@ -527,6 +630,24 @@ def async_copy(src_ref, dst_ref, sem): def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type: DeviceIdType = DeviceIdType.MESH): + """Creates a description of a remote copy operation. + + Copies data from src_ref on the current device to dst_ref on the device + specified by device_id. Both semaphores should be waited on using the + descriptor on both source and target devices. + + Note that device_id can also refer to the current device. + + Args: + src_ref: The source Reference. + dst_ref: The destination Reference. + send_sem: The semaphore on the source device. + recv_sem: The semaphore on the destination device. + device_id: The device id of the destination device. + device_id_type: The type of the device id. + Returns: + An AsyncCopyDescriptor. + """ src_ref, src_indexers = _get_ref_and_indexers(src_ref) send_sem, send_sem_indexers = _get_ref_and_indexers(send_sem) dst_ref, dst_indexers = _get_ref_and_indexers(dst_ref) @@ -560,4 +681,68 @@ def _get_barrier_semaphore_abstract_eval(): ) def get_barrier_semaphore(): + """Returns a barrier semaphore. + + This function returns a barrier semaphore based on the collective_id of the + current pallas kernel. + + It's very important that the semaphore is wait-ed back down to 0, or else the + semaphores will become corrupted. + + It's also very important that the collective_id is different for each pallas + kernel with communication. E.g. if you have two pallas kernels, one that syncs + across the X axis of the device mesh and the second that syncs across the Y + axis, they must have different collective_ids. + However it is legal for two kernels that perform the same synchronization + pattern (e.g. only communicating with neighbours on the same mesh axis) + to share a collective_id. However, if in doubt, prefer not sharing + collective_ids, as doing so incorrectly can lead to silent data corruption or + crashes. + Note that re-using the same collective_id doesn't guarantee that the same + semaphore is provided by XLA. 
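Putting the two docstrings above together, a hedged sketch of the usual pattern: barrier with both neighbors on the collective barrier semaphore, then issue a remote copy to the right neighbor and wait on both DMA semaphores. The axis name "x", the neighbor arithmetic, and the `pltpu` re-export names are illustrative assumptions; the kernel is assumed to run under `shard_map` with a `collective_id` set in the Mosaic compiler params and the DMA semaphores passed as scratch.

import jax
from jax.experimental.pallas import tpu as pltpu

def right_neighbor_copy_kernel(x_ref, o_ref, send_sem, recv_sem):
  num_devices = jax.lax.psum(1, "x")
  my_id = jax.lax.axis_index("x")
  right = jax.lax.rem(my_id + 1, num_devices)
  left = jax.lax.rem(my_id - 1 + num_devices, num_devices)

  # Lightweight barrier with both neighbors before starting the RDMA.
  barrier_sem = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(barrier_sem, device_id=left,
                         device_id_type=pltpu.DeviceIdType.LOGICAL)
  pltpu.semaphore_signal(barrier_sem, device_id=right,
                         device_id_type=pltpu.DeviceIdType.LOGICAL)
  pltpu.semaphore_wait(barrier_sem, 2)

  # Copy our block into the right neighbor's output buffer.
  rdma = pltpu.make_async_remote_copy(
      src_ref=x_ref, dst_ref=o_ref,
      send_sem=send_sem, recv_sem=recv_sem,
      device_id=right,
      device_id_type=pltpu.DeviceIdType.LOGICAL,
  )
  rdma.start()
  rdma.wait()  # waits for both the local send and the incoming receive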
+ """ return get_barrier_semaphore_p.bind() + +delay_p = jax_core.Primitive("delay") +delay_p.multiple_results = True + + +@delay_p.def_abstract_eval +def _delay_abstract_eval(nanos): + del nanos + return [] + + +def delay(nanos): + """Delays vector execution for the given number of nanosconds.""" + delay_p.bind(nanos) + + +# RNG Ops +prng_seed_p = jax_core.Primitive("prng_seed") +prng_seed_p.multiple_results = True + +@prng_seed_p.def_abstract_eval +def _(*_): + return [] + + +def prng_seed(*seeds: int | jax.Array) -> None: + """Sets the seed for PRNG. + + Args: + seeds: One or more integer seeds for setting the PRNG seed. If + more than one seed is passed in, the seed material will be + mixed before setting the internal PRNG state. + """ + prng_seed_p.bind(*seeds) + +prng_random_bits_p = jax_core.Primitive( + 'prng_random_bits') + +@prng_random_bits_p.def_abstract_eval +def _(*, shape): + return jax_core.ShapedArray(shape, jnp.dtype("int32")) + +def prng_random_bits(shape): + return prng_random_bits_p.bind(shape=shape) diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py new file mode 100644 index 000000000000..c642d99578cd --- /dev/null +++ b/jax/_src/pallas/mosaic/random.py @@ -0,0 +1,211 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections.abc import Callable + +import jax +import numpy as np +from jax import numpy as jnp +from jax import random as jax_api_random +from jax._src import blocked_sampler +from jax._src import typing +from jax._src.pallas.mosaic.primitives import prng_seed +from jax._src.pallas.mosaic.primitives import prng_random_bits +from jax._src.pallas import primitives +from jax._src import prng as jax_prng + + +Shape = jax_prng.Shape +SampleFnType = blocked_sampler.SampleFn +KeylessSampleFnType = Callable[..., jax.Array] + +set_seed = prng_seed + +FOLD_IN_ROUNDS = 128 + + +def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray: + """Helper function for converting non-Pallas PRNG keys into Pallas keys.""" + batch_dims = key.shape + key_data = jax_api_random.key_data(key) + pallas_key_size = np.prod(tpu_key_impl.key_shape) + if np.prod(key_data.shape[len(batch_dims):]) < pallas_key_size: + raise ValueError(f"Key data must be at least {pallas_key_size} bytes.") + pallas_key_data = jnp.reshape( + key_data, batch_dims + (-1,))[..., :pallas_key_size] + pallas_key_data = jnp.reshape(pallas_key_data, + batch_dims + tpu_key_impl.key_shape) + return jax_api_random.wrap_key_data(pallas_key_data, impl="pallas_tpu") + +def _seed_func(seed: jnp.int32): + seed_data = jnp.zeros(tpu_key_impl.key_shape, dtype=jnp.int32) + return (seed_data + seed).astype(jnp.uint32) + +def _random_bits(key: typing.Array, bit_width: int, shape: Shape): + if bit_width != 32: + raise ValueError("Bit width must be 32") + prng_seed(key) + return prng_random_bits(shape) + +def _fold_in(key: jax_prng.PRNGKeyArray, data: typing.Array): + # Roughly, we compute the new key as follows: + # new_key = random_bits(data)[..., 127] ^ random_bits(old_key)[..., 127] + # Because the TPU generates random numbers in (8, 128) blocks at once, we + # can generate that many values without additional cost which will reduce + # correlation between the old and new keys. + key_shape = tpu_key_impl.key_shape + + prng_seed(data) + data_bits = prng_random_bits( + key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) + prng_seed(key) + key_bits = prng_random_bits( + key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) + + mixed = key_bits[..., FOLD_IN_ROUNDS-1] ^ data_bits[..., FOLD_IN_ROUNDS-1] + assert mixed.shape == key_shape + return jax.random.wrap_key_data(mixed, impl="pallas_tpu") + +def _split(key: typing.Array, shape: Shape): + del key, shape + raise NotImplementedError() + +tpu_key_impl = jax_prng.PRNGImpl( + # Pallas currently only supports 2D+ windows, so set the key_shape + # to be 2D to have better compatibility with setting BlockSpecs. + key_shape=(1, 1), + seed=_seed_func, + split=_split, + random_bits=_random_bits, + fold_in=_fold_in, + name="pallas_tpu", + tag="pl" +) +jax_prng.register_prng(tpu_key_impl) + +# Implementation of the stateful Pallas PRNG API. +# Users should set the seed using the `set_seed` function, +# and call the appropriate stateful sampling functions. +# The actual key impl should never be used. The impl +# serves as internal boilerplate code because JAX's existing +# random functions expect a key as an argument, and +# the keys are only generated as part of unused arguments. + +def _pl_stateful_seed_func(seed: jnp.int32): + del seed + # Unused. Return the correct shape and dtype. 
+ return jnp.empty((), dtype=jnp.int32) + +def _pl_stateful_random_bits(key: typing.Array, bit_width: int, shape: Shape): + del key + assert bit_width == 32, "Bit width must be 32" + return prng_random_bits(shape) + +def _pl_stateful_fold_in(key: typing.Array, data: typing.Array): + del key, data + raise NotImplementedError() + +def _pl_stateful_split(key: typing.Array, shape: Shape): + del key, shape + raise NotImplementedError() + + +tpu_internal_stateful_impl = jax_prng.PRNGImpl( + key_shape=(), + seed=_pl_stateful_seed_func, + split=_pl_stateful_split, + random_bits=_pl_stateful_random_bits, + fold_in=_pl_stateful_fold_in, + name="_pallas_internal_stateful", + tag="_pl_stateful" +) +jax_prng.register_prng(tpu_internal_stateful_impl) + +def _make_stateful_sampler(sampler: SampleFnType) -> KeylessSampleFnType: + """Converts a jax.random sampling function to a stateful version. + + Args: + sampler: A sampling function that consumes a key and returns + random samples. + + Returns: + A stateful sampling function with the key argument removed. + """ + def new_sampler(*args, **kwargs): + # Pass in a placeholder key into the sampling function. + # The key is ignored by the stateful random_bits function, but all jax + # sampling functions expect a key as input so we must pass one in here. + placeholder_key = jax_api_random.key(0, impl=tpu_internal_stateful_impl) + return sampler(placeholder_key, *args, **kwargs) + # Remove key argument from docstring. + if sampler.__doc__: + doc_lines = filter( + lambda line: "key:" not in line, sampler.__doc__.split("\n")) + new_sampler.__doc__ = "\n".join(doc_lines) + return new_sampler + +bits = _make_stateful_sampler(jax_api_random.bits) # type: ignore +uniform = _make_stateful_sampler(jax_api_random.uniform) # type: ignore +bernoulli = _make_stateful_sampler(jax_api_random.bernoulli) # type: ignore + + +def sample_block(sampler_fn: SampleFnType, + global_key: jax_prng.PRNGKeyArray, + block_size: Shape, + tile_size: Shape, + total_size: Shape, + block_index: tuple[typing.ArrayLike, ...] | None = None, + **kwargs) -> jax.Array: + """Samples a block of random values with invariance guarantees. + + `sample_block` allows the sampling of identical blocks of random values + across kernels with different block shapes and iteration orders. Each call + to `sample_block` returns a `block_size`-shaped array of random samples + corresponding to the `block_index`. + + `tile_size` should be chosen such that it is a divisor to all block sizes + one needs to be invariant to. The larger the `tile_size`, the more + efficient the sampling process wil be and therefore the best choice is + typically the greatest common divisor between all possible block sizes. + + Args: + sampler_fn: A sampling function that consumes a key and returns + random samples. + global_key: The global key to use for sampling. + block_size: The shape of an individual block. + tile_size: The shape of a `tile`, which is the smallest unit at + which samples are generated. This should be selected to be a divisor + of all block sizes one needs to be invariant to. + total_size: The total size of the array to sample. + block_index: The index denoting which block to generate keys for. Defaults + to the program_id for each block axis. + **kwargs: Additional arguments to pass to the sampler_fn. + + Returns: + A `block_size` shaped array of samples for the current block corresponding + to `block_index`. 
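A small sketch of the stateful flow described above: seed once per kernel invocation, then draw samples with no key threading (the keyless `uniform`/`bernoulli` wrappers defined in this module behave the same way). The re-export spellings `pltpu.prng_seed`/`pltpu.prng_random_bits` and the SMEM seed ref are assumptions.

from jax.experimental.pallas import tpu as pltpu

def random_mask_kernel(x_ref, seed_ref, o_ref):
  # Seed the per-kernel PRNG once, then draw raw 32-bit samples statefully.
  pltpu.prng_seed(seed_ref[0])
  bits = pltpu.prng_random_bits(x_ref.shape)    # int32 bits, same shape as x
  keep = (bits >= 0).astype(x_ref.dtype)        # ~Bernoulli(0.5) mask
  o_ref[...] = x_ref[...] * keep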
+ """ + if len(block_size) != len(tile_size): + raise ValueError(f"block_size ({len(block_size)}) and tile_size " + f"({len(tile_size)}) must have the same length.") + + if block_index is None: + num_axes = len(block_size) + block_index = tuple( + primitives.program_id(axis) for axis in range(num_axes)) + + keys = blocked_sampler.blocked_fold_in( + global_key, total_size, block_size, tile_size, block_index) + return blocked_sampler.sample_block( + sampler_fn, keys, block_size, tile_size, **kwargs) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD new file mode 100644 index 000000000000..6f39f2686d4f --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -0,0 +1,63 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Package for Mosaic-specific Pallas extensions + +load("@rules_python//python:defs.bzl", "py_library") +load( + "//jaxlib:jax.bzl", + "py_deps", + "pytype_strict_library", +) + +package( + default_applicable_licenses = [], + default_visibility = [ + "//:__subpackages__", + ], +) + +py_library( + name = "mosaic_gpu", + srcs = ["__init__.py"], + deps = [ + ":pallas_call_registration", + ], +) + +pytype_strict_library( + name = "pallas_call_registration", + srcs = ["pallas_call_registration.py"], + deps = [ + ":lowering", + "//jax", + "//jax:mlir", + "//jax:mosaic_gpu", + "//jax/_src/pallas", + ], +) + +pytype_strict_library( + name = "lowering", + srcs = ["lowering.py"], + deps = [ + "//jax", + "//jax:core", + "//jax:mlir", + "//jax:mosaic_gpu", + "//jax:util", + "//jax/_src/lib", + "//jax/_src/pallas", + ] + py_deps("numpy"), +) diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py new file mode 100644 index 000000000000..862a661e24b9 --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py new file mode 100644 index 000000000000..5b4db68f2552 --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -0,0 +1,419 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for lowering JAX primitives to Mosaic GPU.""" + +from __future__ import annotations + +from collections.abc import Sequence +import dataclasses +import functools +import math +from typing import Any, cast + +import jax +from jax._src import core as jax_core +from jax._src import pjit +from jax._src import util +from jax._src.interpreters import mlir +from jax._src.lax import lax +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith as arith_dialect +from jax._src.lib.mlir.dialects import memref as memref_dialect +from jax._src.lib.mlir.dialects import nvgpu as nvgpu_dialect +from jax._src.pallas import core as pl_core +from jax._src.pallas import primitives +from jax._src.state import primitives as sp +from jax.experimental.mosaic import gpu as mosaic_gpu +from jax.experimental.mosaic.gpu import dsl as mgpu +import jax.numpy as jnp +import numpy as np + + +# TODO(slebedev): Enable type checking. +# mypy: ignore-errors +# pytype: skip-file + +map, unsafe_map = util.safe_map, map +zip, unsafe_zip = util.safe_zip, zip + +partial = functools.partial + + +@dataclasses.dataclass +class ModuleContext: + name: str + grid_mapping: pl_core.GridMapping + runtime_smem: ir.Value # ir.MemRefType + + def scratch_view(self, shapes: list[jax.ShapeDtypeStruct]) -> list[ir.Value]: + """Return memref views into the runtime scrath based on the shapes.""" + + smem_scratch_bytes = math.prod(ir.MemRefType(self.runtime_smem.type).shape) + required_scratch_bytes = sum( + math.prod(sh.shape) * jnp.dtype(sh.dtype).itemsize for sh in shapes + ) + if smem_scratch_bytes < required_scratch_bytes: + raise ValueError( + f"Too few {smem_scratch_bytes=} provided (pass via compiler_params), we" + f" need {required_scratch_bytes} ({shapes=})" + ) + + views = [] + off = 0 + smem = ir.Attribute.parse("#gpu.address_space") + for sh in shapes: + sh_bytes = math.prod(sh.shape) * jnp.dtype(sh.dtype).itemsize + strides = (*np.cumprod(sh.shape)[:-1:-1], 1) + + # We need scratch to be able to store 128 items of x. + scratch = memref_dialect.subview( + self.runtime_smem, + offsets=[_index(off)], + sizes=[_index(sh_bytes)], + strides=[_index(i) for i in strides], + ) + scratch_ty = ir.MemRefType.get( + [np.prod(sh.shape)], mlir.dtype_to_ir_type(sh.dtype), memory_space=smem + ) + off += sh_bytes + views.append(memref_dialect.view(scratch_ty, scratch, _index(off), [])) + + return views + + +@dataclasses.dataclass +class LoweringRuleContext: + module_context: ModuleContext + avals_in: Sequence[jax_core.ShapedArray] + avals_out: Sequence[jax_core.ShapedArray] + block_shapes: list[tuple[int | pl_core.Mapped, ...]] | None + + replace = dataclasses.replace + + +@dataclasses.dataclass +class LoweringResult: + module: ir.Module + grid: tuple[int, ...] + gmem_scratch_bytes: int + out_structs: tuple[jax.ShapeDtypeStruct, ...] + + +@dataclasses.dataclass +class BlockInfo: + full_shape_dtype: jax.ShapeDtypeStruct + start_indices: Sequence[Any] + block_shape: tuple[int, ...] 
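The SMEM accounting in `scratch_view` above boils down to summing the byte sizes of the requested shapes against the scratch sized via `compiler_params`; a plain-Python sketch with two hypothetical requests:

import math
import numpy as np

requested = [((128,), np.float32), ((4,), np.float32)]   # hypothetical shapes
required = sum(math.prod(shape) * np.dtype(dtype).itemsize
               for shape, dtype in requested)             # 512 + 16 = 528 bytes
smem_scratch_bytes = 1024                                 # e.g. from compiler_params
# If smem_scratch_bytes were below 528, scratch_view would raise the
# "Too few smem_scratch_bytes" error above; otherwise views are carved out
# at increasing byte offsets in request order.
assert required <= smem_scratch_bytes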
+ + +class LoweringError(Exception): + pass + + +def lower_jaxpr_to_module( + grid_mapping: pl_core.GridMapping, + in_structs: tuple[jax.ShapeDtypeStruct, ...], + out_structs: tuple[jax.ShapeDtypeStruct, ...], + jaxpr: jax_core.Jaxpr, + name: str, + compiler_params: dict[str, Any], +) -> LoweringResult: + assert len(jaxpr.outvars) == 0 + assert not grid_mapping.mapped_dims + grid = grid_mapping.grid + if len(grid) < 3: + grid += (1,) * (3 - len(grid)) + block = (128,) + (1,) * (len(grid) - 1) + + def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers): + *buffers_gmem, (*buffers_smem, runtime_smem) = buffers + assert len(buffers_gmem) == len(buffers_smem) + in_buffers_gmem = buffers_gmem[: len(in_structs)] + in_buffers_smem = buffers_smem[: len(in_structs)] + out_buffers_gmem = buffers_gmem[len(in_structs) :] + out_buffers_smem = buffers_smem[len(in_structs) :] + + # arrival_count= determines the expected number of arrivals for each + # barrier in the array. It is not accidental that we do just a single + # mbarrier_arrive_expect_tx below. + # TODO(slebedev): Consider enforcing this in the mgpu.BarrierArray. + [barrier] = mgpu.BarrierArray(1, arrival_count=1) + + with mgpu.single_thread(): + nvgpu_dialect.mbarrier_arrive_expect_tx( + barrier.barrier_array.value, + _index( + sum(math.prod(s.shape) * s.dtype.itemsize for s in in_structs) + ), + barrier.offset, + ) + + for b_gmem, b_smem in zip(in_buffers_gmem, in_buffers_smem): + # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. + launch_ctx.async_copy( + src_ref=b_gmem, + dst_ref=b_smem, + barrier=barrier, + swizzle=None, + arrive=False, + uniform=False, + ) + + barrier.wait() + + module_ctx = ModuleContext(name, grid_mapping, runtime_smem) + _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, None, *buffers_smem) + + for b_gmem, b_smem in zip(out_buffers_gmem, out_buffers_smem): + # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. 
+ launch_ctx.async_copy(src_ref=b_smem, dst_ref=b_gmem, swizzle=None) + + launch_ctx.await_async_copy(0) + + extra_smem_scratch = [ + jax.ShapeDtypeStruct( + shape=[compiler_params.get("smem_scratch_bytes", 0)], dtype=np.int8 + ) + ] + module, out_structs, gmem_scratch_bytes, _ = mosaic_gpu._lower_as_gpu_kernel( + body, + grid, + block, + in_shapes=in_structs, + out_shape=out_structs, + smem_scratch_shape=(*in_structs, *out_structs, *extra_smem_scratch), + ) + + return LoweringResult(module, grid, gmem_scratch_bytes, out_structs) + + +mosaic_lowering_rules = {} + + +def register_lowering_rule(primitive: jax_core.Primitive): + def deco(fn): + mosaic_lowering_rules[primitive] = fn + return fn + + return deco + + +def lower_jaxpr_to_mosaic_gpu( + ctx: ModuleContext, + jaxpr: jax_core.Jaxpr, + block_infos: Sequence[BlockInfo | None] | None, + *args, +) -> Sequence[ir.Value]: + env = {} + block_info_env = {} + + def read_env(atom: jax_core.Atom): + return atom.val if isinstance(atom, jax_core.Literal) else env[atom] + + def read_block_info_env(atom: jax_core.Atom): + if isinstance(atom, jax_core.Literal): + return None + return block_info_env.get(atom, None) + + def write_env(var: jax_core.Var, val): + env[var] = val + + if block_infos is None: + block_infos = [None] * len(jaxpr.invars) + for invar, block_info in zip(jaxpr.invars, block_infos): + block_info_env[invar] = block_info + map(write_env, jaxpr.invars, args) + for eqn in jaxpr.eqns: + invals = map(read_env, eqn.invars) + if eqn.primitive not in mosaic_lowering_rules: + raise NotImplementedError( + "Unimplemented primitive in Pallas Mosaic GPU lowering: " + f"{eqn.primitive.name}. " + "Please file an issue on https://github.com/google/jax/issues." + ) + rule = mosaic_lowering_rules[eqn.primitive] + rule_ctx = LoweringRuleContext( + ctx, + avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars], + avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars], + block_shapes=map(read_block_info_env, eqn.invars), + ) + try: + outvals = rule(rule_ctx, *invals, **eqn.params) + except LoweringError: + raise # We only add the extra info to the innermost exception. + except Exception as e: + inval_types = map(lambda t: getattr(t, "type", None), invals) + raise LoweringError( + f"Exception while lowering eqn:\n {eqn}\nWith context:\n " + f" {rule_ctx}\nWith inval types={inval_types}\nIn jaxpr:\n{jaxpr}" + ) from e + if eqn.primitive.multiple_results: + map(write_env, eqn.outvars, outvals) + else: + write_env(eqn.outvars[0], outvals) + return map(read_env, jaxpr.outvars) + + +@register_lowering_rule(sp.get_p) +def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *indexers, tree): + del tree # Unused. + if indexers: + raise NotImplementedError("No support for indexers yet") + return mgpu.FragmentedArray.load_strided(x_smem) + + +@register_lowering_rule(sp.swap_p) +def _swap_lowering_rule( + ctx: LoweringRuleContext, x_smem, value, *indexers, tree +): + del tree # Unused. 
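The `register_lowering_rule` decorator above is the extension point for new primitives. Purely as an illustration (not part of this patch, and assuming `FragmentedArray` exposes an elementwise `exp` helper), an `exp` rule would follow the same shape as the surrounding rules:

@register_lowering_rule(lax.exp_p)
def _exp_lowering_rule(ctx: LoweringRuleContext, x):
  # Coerce Python scalars / ndarrays to a FragmentedArray, then apply exp
  # elementwise on the registers held by this thread.
  return _ensure_fa(x, *ctx.avals_in).exp()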
+ if indexers: + raise NotImplementedError("No support for indexers yet") + old_value = mgpu.FragmentedArray.load_strided(x_smem) + value.store_untiled(x_smem) + return old_value + + +@register_lowering_rule(pjit.pjit_p) +def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): + if jaxpr.consts: + raise NotImplementedError + return lower_jaxpr_to_mosaic_gpu(ctx.module_context, jaxpr.jaxpr, None, *args) + + +@register_lowering_rule(lax.broadcast_in_dim_p) +def _broadcast_in_dim_lowering_rule( + ctx: LoweringRuleContext, + x: mgpu.FragmentedArray, + *, + broadcast_dimensions, + shape, +): + if broadcast_dimensions: + raise NotImplementedError + return x.broadcast(shape) + + +@register_lowering_rule(lax.convert_element_type_p) +def _convert_element_type_lowering_rule( + ctx: LoweringRuleContext, x, *, new_dtype, weak_type +): + return _ensure_fa(x, *ctx.avals_in).astype(mlir.dtype_to_ir_type(new_dtype)) + + +def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): + x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) + return impl(x, y) + + +mosaic_lowering_rules.update({ + lax.add_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x + y), + lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y), + lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y), + lax.div_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x / y), +}) + + +@register_lowering_rule(lax.integer_pow_p) +def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): + x = _ensure_fa(x, *ctx.avals_in) + if y == 2: + return x * x + raise NotImplementedError + + +@register_lowering_rule(lax.rsqrt_p) +def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): + return _ensure_fa(x, *ctx.avals_in).rsqrt() + + +@register_lowering_rule(lax.reduce_sum_p) +def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): + if axes != (0,): + raise NotImplementedError("No support for axes other than 0 yet") + [x_aval] = ctx.avals_in + [scratch] = ctx.module_context.scratch_view( + [jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype)] + ) + return mgpu.FragmentedArray.splat(x.reduce_sum(scratch), ()) + + +@register_lowering_rule(primitives.debug_print_p) +def _debug_print_lowering_rule( + ctx: LoweringRuleContext, + *args, + fmt, + has_placeholders: bool, +): + del has_placeholders + primitives.check_debug_print_format(fmt, *args) + mgpu.debug_print(fmt, *args) + return () + + +def _bcast( + x: ir.Value, + y: ir.Value, + x_aval: jax_core.ShapedArray, + y_aval: jax_core.ShapedArray, + out_aval: jax_core.ShapedArray, +) -> ir.Value: + if isinstance(x, (np.ndarray, np.number, int, float)): + x_dtype = x_aval.dtype + if x_aval.weak_type: + x_dtype = y_aval.dtype + x = mgpu.FragmentedArray.splat( + _ir_constant(x, mlir.dtype_to_ir_type(x_dtype)), () + ) + if isinstance(y, (np.ndarray, np.number, int, float)): + y_dtype = y_aval.dtype + if y_aval.weak_type: + y_dtype = x_aval.dtype + y = mgpu.FragmentedArray.splat( + _ir_constant(y, mlir.dtype_to_ir_type(y_dtype)), () + ) + assert isinstance(x, mgpu.FragmentedArray) + assert isinstance(y, mgpu.FragmentedArray) + if x_aval.shape != out_aval.shape: + x = x.broadcast(out_aval.shape) + if y_aval.shape != out_aval.shape: + y = y.broadcast(out_aval.shape) + return x, y + + +def _ensure_fa(x: object, aval: jax_core.ShapedArray) -> mgpu.FragmentedArray: + if isinstance(x, mgpu.FragmentedArray): + return x + elif isinstance(x, (np.number, np.ndarray, int, float)): + return mgpu.FragmentedArray.splat( + _ir_constant(x,
mlir.dtype_to_ir_type(aval.dtype)), () + ) + raise NotImplementedError + + +def _ir_constant(v: object, t: ir.Type) -> ir.Value: + if isinstance(v, (np.number, np.ndarray, int, float)): + if isinstance(t, ir.IntegerType): + v = int(v) + else: + assert isinstance(t, ir.FloatType) + v = float(v) + return arith_dialect.constant(t, v) + raise NotImplementedError(f"Unsupported constant: {v!r}") + + +def _index(i: int) -> ir.Value: + return arith_dialect.constant(ir.IndexType.get(), int(i)) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py new file mode 100644 index 000000000000..740f0c31ebb7 --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -0,0 +1,89 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module registering a lowering rule for pallas_call on GPU.""" + + +from __future__ import annotations + +from typing import Any + +import jax +from jax import core as jax_core +from jax._src.interpreters import mlir +from jax._src.pallas import core as pallas_core +from jax._src.pallas.mosaic_gpu import lowering +from jax._src.pallas.pallas_call import pallas_call_p +from jax.experimental.mosaic import gpu as mosaic_gpu + + +def pallas_call_lowering( + ctx: mlir.LoweringRuleContext, + *args, + jaxpr: jax_core.Jaxpr, + name: str, + in_shapes: tuple[jax.ShapeDtypeStruct, ...], + out_shapes: tuple[jax.ShapeDtypeStruct, ...], + interpret: bool, + debug: bool, + input_output_aliases: tuple[tuple[int, int], ...], + grid_mapping: pallas_core.GridMapping, + compiler_params: dict[str, Any], +): + if interpret: + return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( + ctx, + *args, + jaxpr=jaxpr, + name=name, + out_shapes=out_shapes, + in_shapes=in_shapes, + interpret=interpret, + debug=debug, + input_output_aliases=input_output_aliases, + grid_mapping=grid_mapping, + compiler_params=compiler_params, + ) + + if grid_mapping.num_dynamic_grid_bounds: + raise NotImplementedError( + "dynamic grid bounds not supported in the Mosaic GPU backend" + ) + if input_output_aliases: + raise NotImplementedError( + "input_output_aliases not supported in the Mosaic GPU backend" + ) + + if debug: + print(jaxpr) + print(grid_mapping) + + lowering_result = lowering.lower_jaxpr_to_module( + grid_mapping, + in_shapes, + out_shapes, + jaxpr, + name, + compiler_params, + ) + if debug: + print(lowering_result.module.operation) + + return mosaic_gpu._mosaic_gpu_lowering_rule( + ctx, + *args, + module=lowering_result.module, + gmem_scratch_bytes=lowering_result.gmem_scratch_bytes, + out_types=lowering_result.out_structs, + ) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index ef6067841022..41b2d2e61c10 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -15,44 +15,50 @@ """Module for calling pallas functions from JAX.""" from __future__ import annotations +from collections.abc import Callable, Sequence +from functools 
import partial, reduce import itertools -from functools import partial -from functools import reduce - -from typing import Any, Callable -from collections.abc import Sequence +from typing import Any import jax from jax import api_util -from jax import tree_util from jax import lax +from jax import tree_util +from jax._src import ad_util +from jax._src import checkify +from jax._src import config +from jax._src import core as jax_core +from jax._src import effects +from jax._src import linear_util as lu from jax._src import state from jax._src.interpreters import ad from jax._src.interpreters import batching -from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla -from jax._src import ad_util -from jax._src import core as jax_core -from jax._src.state import primitives as sp -from jax._src import linear_util as lu +from jax._src.pallas import core as pallas_core +from jax._src.pallas.primitives import uninitialized_value from jax._src.state import discharge as state_discharge +from jax._src.state import primitives as sp from jax._src.util import ( - split_list, safe_map, safe_zip, weakref_lru_cache, - tuple_insert, partition_list, merge_lists) + safe_map, + safe_zip, + split_list, + tuple_insert, + weakref_lru_cache, +) import jax.numpy as jnp import numpy as np -from jax._src.pallas import core as pallas_core - map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip Grid = pallas_core.Grid -BlockSpec = pallas_core.BlockSpec GridSpec = pallas_core.GridSpec BlockMapping = pallas_core.BlockMapping GridMapping = pallas_core.GridMapping +BlockSpec = pallas_core.BlockSpec +BlockSpecTree = pallas_core.BlockSpecTree NoBlockSpec = pallas_core.NoBlockSpec no_block_spec = pallas_core.no_block_spec @@ -83,12 +89,70 @@ def _maybe_dynamic_update_slice(start_idx, block_shape, value, update, assert update.shape == block_shape return lax.dynamic_update_slice(value, update, start_idx) -def _uninitialized_value(shape, dtype): - if jnp.issubdtype(dtype, jnp.floating): - return jnp.full(shape, jnp.nan, dtype) - elif jnp.issubdtype(dtype, jnp.integer): - return jnp.full(shape, jnp.iinfo(dtype).min, dtype) - raise NotImplementedError(dtype) +def _pad_values_to_block_dimension(value, + block_shape): + """Pads values so the shape evenly divides into block dimensions. + + For example, if values has a shape of (33, 2, 5) with a block_shape of + (32, 2, 4), this function will pad the value of shape to (64, 2, 8). + + Args: + value: Array to be padded. + block_shape: Block shapes to use for padding. If None, no padding will + be performed. + + Returns: + A padded array. 
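+
+  Continuing the example above, the padded extent of each dimension is the
+  smallest multiple of the block size that covers it, computed as
+  ``((v - 1) // b + 1) * b``::
+
+    ((33 - 1) // 32 + 1) * 32 == 64
+    (( 2 - 1) //  2 + 1) *  2 ==  2
+    (( 5 - 1) //  4 + 1) *  4 ==  8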
+ """ + if block_shape is None: + return value + padded_shape = tuple( + ((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape) + ) + if padded_shape != value.shape: + pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape)) + pad_value = uninitialized_value(shape=(), dtype=value.dtype) + value = jnp.pad(value, pad_width, constant_values=pad_value) + return value + +def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]: + scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals) + return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals) + +def _initialize_output_vals( + out_shapes, input_args, input_output_aliases) -> Sequence[jax.Array]: + oi_map = {v: k for k, v in input_output_aliases} + output_vals = [] + for i, out_shape in enumerate(out_shapes): + if i in oi_map: + output_vals.append(input_args[oi_map[i]]) + else: + output_vals.append(uninitialized_value(out_shape.shape, out_shape.dtype)) + return output_vals + +def _logical_to_interpret_mode_dtype(dtype): + """Converts logical dtypes into JAX dtypes for interpret mode. + + This function is used to convert device-specific dtypes that have no + corresponding equivalent in JAX/XLA into a type that can be executed + by the XLA interpreter (e.g. TPU semaphores -> int32). + """ + if (hasattr(dtype, "_rules") and + hasattr(dtype._rules, "pallas_interpret_element_aval")): + return dtype._rules.pallas_interpret_element_aval(dtype).dtype + return dtype + +def _logical_aval_to_interpret_mode_aval(aval): + """Logical to interpret mode aval conversion.""" + if isinstance(aval, pallas_core.AbstractMemoryRef): + inner_aval = _logical_aval_to_interpret_mode_aval(aval.inner_aval) + return aval.update(inner_aval=inner_aval) + if isinstance(aval, jax_core.ShapedArray): + inner_dtype = _logical_to_interpret_mode_dtype(aval.dtype) + return jax_core.ShapedArray(aval.shape, + inner_dtype, + weak_type=aval.weak_type, named_shape=aval.named_shape) + return aval def _get_next_indices(grid, indices): next_indices = [] @@ -99,7 +163,7 @@ def _get_next_indices(grid, indices): next_indices.append(jnp.where(carry, 0, i)) return tuple(reversed(next_indices)) -def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear, +def _pallas_call_impl(*args, jaxpr, name, out_shapes, interpret, debug: bool, in_shapes, input_output_aliases: tuple[tuple[int, int], ...], @@ -114,21 +178,16 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear, # will do. 
dynamic_grid_args_iter = iter(dynamic_grid_args) grid = tuple( - a if a is not None else next(dynamic_grid_args_iter) + a if a is not pallas_core.dynamic_grid_dim + else next(dynamic_grid_args_iter) for a in grid_mapping.grid ) assert next(dynamic_grid_args_iter, None) is None - discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ()) + with grid_mapping.trace_env(): + discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ()) if debug: print(discharged_jaxpr) - oi_map = {v: k for k, v in input_output_aliases} - out = [] - for i, out_shape in enumerate(out_shapes): - if i in oi_map: - out.append(args[oi_map[i]]) - else: - # TODO(sharadmv): use unitialized values for outputs - out.append(jnp.zeros(out_shape.shape, out_shape.dtype)) + out = _initialize_output_vals(out_shapes, args, input_output_aliases) scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore # invars: [*scalar_prefetch, *inputs, *outputs, *scratch] num_invars = len(jaxpr.invars) @@ -141,12 +200,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear, jaxpr.invars, [grid_mapping.num_index_operands, num_inputs_outputs] ) scratch_avals = [v.aval for v in scratch_invars] - if not all( - hasattr(a, "shape") and hasattr(a, "dtype") for a in scratch_avals - ): - raise NotImplementedError(f"Cannot initialize scratch: {scratch_avals}") - scratch_values = [_uninitialized_value(a.shape, a.dtype) - for a in scratch_avals] + scratch_values = _initialize_scratch_vals(scratch_avals) carry = [] for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings): @@ -157,7 +211,29 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear, raise NotImplementedError("Padding with aliasing not supported.") x = lax.pad(x, jnp.zeros((), x.dtype), [(*p, 0) for p in padding]) carry.append(x) + + block_shapes_without_mapped_dims = [ + None if block_mapping is None else block_mapping.block_shape + for block_mapping in grid_mapping.block_mappings + ] + is_indexing_dim = [ + None if bm is None else tuple(b is pallas_core.mapped for b in bm) + for bm in block_shapes_without_mapped_dims + ] + block_shapes = [ + None if (bm is None or iid is None) + else tuple(1 if i else b for i, b in zip(iid, bm)) + for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims) + ] + + # Pad values to evenly divide into block dimensions. + # This allows interpret mode to catch errors on OOB memory accesses + # by poisoning values with NaN. It also fixes an inconsistency with + # lax.dynamic_slice where if the slice goes out of bounds, it will instead + # move the start_index backwards so the slice will fit in memory. 
+ carry = map(_pad_values_to_block_dimension, carry, block_shapes) carry.extend(scratch_values) + num_inout = len(args) + len(out) grid_start_indices = (jnp.int32(0),) * len(grid) if grid: @@ -171,7 +247,7 @@ def cond(carry): def body(carry): i, loop_idx, *carry = carry local_grid_env = tuple( - (idx, b) + pallas_core.GridAxis(idx, b) for dim, (idx, b) in enumerate(zip(loop_idx, grid)) if dim not in grid_mapping.mapped_dims ) @@ -180,18 +256,6 @@ def body(carry): start_indices = [ None if bm is None else bm.compute_start_indices(loop_idx, *scalars) for bm in grid_mapping.block_mappings] - block_shapes_without_mapped_dims = [ - None if block_mapping is None else block_mapping.block_shape - for block_mapping in grid_mapping.block_mappings - ] - is_indexing_dim = [ - None if bm is None else tuple(b is pallas_core.mapped for b in bm) - for bm in block_shapes_without_mapped_dims - ] - block_shapes = [ - None if bm is None else tuple(1 if i else b for i, b in zip(iid, bm)) - for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims) - ] blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry, is_indexing_dim) with pallas_core.grid_env(local_grid_env): @@ -224,13 +288,13 @@ def body(carry): if input_output_aliases: raise NotImplementedError("Padding with aliasing not supported.") pad_low, pad_high = zip(*padding) - limit_indices = [s - p for s, p in zip(out.shape, pad_high)] + limit_indices = [s - p for s, p in zip(o.shape, pad_high)] o = lax.slice(o, pad_low, limit_indices) out_nopad.append(o) return out_nopad return xla.apply_primitive(pallas_call_p, *args, jaxpr=jaxpr, name=name, in_shapes=in_shapes, - out_shapes=out_shapes, which_linear=which_linear, + out_shapes=out_shapes, grid_mapping=grid_mapping, interpret=interpret, debug=debug, input_output_aliases=input_output_aliases, @@ -241,7 +305,7 @@ def _pallas_call_abstract_eval(*avals, out_shapes, **_): return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes) pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) -def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, +def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, input_output_aliases: tuple[tuple[int, int], ...], in_shapes, out_shapes, grid_mapping, debug, interpret, compiler_params: Any): if grid_mapping.num_dynamic_grid_bounds: @@ -267,8 +331,14 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, jvp_jaxpr.invars, [len(primals), len(out_shapes), len(tangents)] ) invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs) - # TODO(sharadmv): Fix state effect tracking after invar switch. 
- jvp_jaxpr = jvp_jaxpr.replace(invars=invars) + effs = [] + for eff in jvp_jaxpr.effects: + if isinstance(eff, effects.JaxprInputEffect): + eff = eff.replace( + input_index=invars.index(jvp_jaxpr.invars[eff.input_index]) + ) + effs.append(eff) + jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs) if debug: print(jvp_jaxpr) in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)]) @@ -281,7 +351,6 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, in_shapes=(*in_shapes, *in_shapes), out_shapes=(*out_shapes, *out_shapes), grid_mapping=grid_mapping.replace(block_mappings=jvp_bms), - which_linear=which_linear + (True,) * len(tangents), interpret=interpret, debug=debug, input_output_aliases=(), @@ -291,7 +360,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, return out_primals, out_tangents ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule -def _batch_block_mapping(grid: tuple[int, ...], aval: jax_core.ShapedArray, +def _batch_block_mapping(grid_mapping: GridMapping, aval: jax_core.ShapedArray, dim: int | batching.NotMapped, block_mapping: BlockMapping | None) -> BlockMapping: def _block_map_function(new_idx, *args): @@ -306,11 +375,12 @@ def _block_map_function(new_idx, *args): return tuple(indices) i32_aval = jax_core.ShapedArray((), jnp.int32) if block_mapping is None: - idx_avals = [i32_aval] * (len(grid) + 1) + idx_avals = [i32_aval] * (len(grid_mapping.grid) + 1) else: idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals] - block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(_block_map_function), idx_avals) + with grid_mapping.trace_env(): + block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(_block_map_function), idx_avals) shape = aval.shape if block_mapping is None else block_mapping.block_shape if dim is batching.not_mapped: new_block_shape = shape @@ -326,17 +396,149 @@ def _block_map_function(new_idx, *args): return block_mapping.replace(block_shape=new_block_shape, index_map_jaxpr=jaxpr) -def _pallas_call_batching_rule(args, dims, *, - jaxpr: jax_core.Jaxpr, - name: str, - in_shapes: tuple[jax.ShapeDtypeStruct, ...], - out_shapes: tuple[jax.ShapeDtypeStruct, ...], - grid_mapping: GridMapping, - input_output_aliases: tuple[tuple[int, int], ...], - debug: bool, - interpret: bool, - which_linear: tuple[bool, ...], - compiler_params: Any): + +def _broadcast_input_output_aliases( + args: Sequence[jax.Array], + dims: Sequence[int | batching.NotMapped], + *, + input_output_aliases: tuple[tuple[int, int], ...], + axis_size: int, +) -> tuple[tuple[jax.Array, ...], tuple[int | batching.NotMapped, ...]]: + """Broadcast input/output operands. + + When we have input/output aliasing, since the output will be mapped, we need + to make sure to broadcast the input across that dimension if it is not + mapped. If the input is mapped, but on a different axis, we transpose the input + to match the output. + """ + + args_ = list(args) + dims_ = list(dims) + for input_index, _ in input_output_aliases: + dim = dims_[input_index] + dims_[input_index] = 0 + if dim is batching.not_mapped: + args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0) + elif dim != 0: + # TODO(cjfj): Change output batching axis instead?
+ args_[input_index] = jnp.moveaxis(args[input_index], dim, 0) + + return tuple(args_), tuple(dims_) + + +def _batch_with_explicit_loop( + args: Sequence[jax.Array], + dims: Sequence[int | batching.NotMapped], + *, + jaxpr: jax_core.Jaxpr, + name: str, + in_shapes: tuple[jax.ShapeDtypeStruct, ...], + out_shapes: tuple[jax.ShapeDtypeStruct, ...], + grid_mapping: GridMapping, + input_output_aliases: tuple[tuple[int, int], ...], + debug: bool, + interpret: bool, + compiler_params: Any, +): + """Batch the pallas_call by calling it in loop over the batch size. + + This function provides a fallback implementation of batching a pallas_call + for the cases in which adding a batch dimension to the pallas grid is not + supported. This is currently the case when the batched dimension corresponds + to a dynamic axis or a scalar prefetch argument. + + This implementation builds a HLO loop that dynamic_slices the inputs according + to the current iteration index and dynamic_updates an (initially empty) output + allocation. + """ + + if not dims: + raise NotImplementedError("vmapping pallas_call with no arguments.") + + (axis_size,) = { + arg.shape[dim] + for arg, dim in zip(args, dims) + if dim is not batching.not_mapped + } + + args, dims = _broadcast_input_output_aliases( + args, + dims, + input_output_aliases=input_output_aliases, + axis_size=axis_size, + ) + + # The output arrays are completelly overwritten, so we can just initialize + # empty arrays. + initial_state = [ + jnp.empty( + tuple_insert(out_shape.shape, 0, axis_size), dtype=out_shape.dtype + ) + for out_shape in out_shapes + ] + + def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: + batch_args = [] + + for arg, dim in zip(args, dims): + # If the argument is mapped, extract a slice of size 1 in the mapped + # dimension at the current index. + if dim is batching.not_mapped: + batch_args.append(arg) + else: + batch_args.append( + jnp.squeeze( + jax.lax.dynamic_slice_in_dim( + operand=arg, + start_index=batch_index, + slice_size=1, + axis=dim, + ), + axis=dim, + ) + ) + + batch_out = pallas_call_p.bind( + *batch_args, + jaxpr=jaxpr, + name=name, + in_shapes=in_shapes, + out_shapes=out_shapes, + grid_mapping=grid_mapping, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + compiler_params=compiler_params, + ) + for i, batch_out_array in enumerate(batch_out): + state[i] = jax.lax.dynamic_update_index_in_dim( + state[i], + batch_out_array, + batch_index, + axis=0, + ) + + return state + + result = jax.lax.fori_loop(0, axis_size, body, initial_state, unroll=False) + + return result, (0,) * len(result) + + +def _pallas_call_batching_rule( + args, + dims, + *, + jaxpr: jax_core.Jaxpr, + name: str, + in_shapes: tuple[jax.ShapeDtypeStruct, ...], + out_shapes: tuple[jax.ShapeDtypeStruct, ...], + grid_mapping: GridMapping, + input_output_aliases: tuple[tuple[int, int], ...], + debug: bool, + interpret: bool, + compiler_params: Any, +): def _maybe_squeeze_out_bdim( x: jax.Array, bdim: int | batching.NotMapped @@ -345,6 +547,27 @@ def _maybe_squeeze_out_bdim( return x return jnp.squeeze(x, axis=bdim) + axis_size, = {x.shape[d] for x, d in zip(args, dims) + if d is not batching.not_mapped} + if axis_size == 1: + # Why are we even vmapping? 
+ args = map(_maybe_squeeze_out_bdim, args, dims) + out = pallas_call_p.bind( + *args, + jaxpr=jaxpr, + name=name, + in_shapes=in_shapes, + out_shapes=out_shapes, + grid_mapping=grid_mapping, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + compiler_params=compiler_params, + ) + return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out) + + # The first num_dynamic_grid_bounds arguments are size-1 arrays that store + # the size of the dynamic bounds. dynamic_grid_args, args = split_list( args, [grid_mapping.num_dynamic_grid_bounds] ) @@ -356,10 +579,23 @@ def _maybe_squeeze_out_bdim( for arg, bdim in zip(dynamic_grid_args, dynamic_grid_dims) ): dynamic_grid_args = safe_map( - _maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims) + _maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims + ) elif any(bdim is not batching.not_mapped for bdim in dynamic_grid_dims): - raise NotImplementedError( - f"Batched dynamic grid bounds unsupported: {dynamic_grid_dims}" + # TODO(amagni, sharadmv): Explore possibility of batching dynamic grid + # bounds. + return _batch_with_explicit_loop( + args=dynamic_grid_args + args, + dims=dynamic_grid_dims + dims, + jaxpr=jaxpr, + name=name, + in_shapes=in_shapes, + out_shapes=out_shapes, + grid_mapping=grid_mapping, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + compiler_params=compiler_params, ) else: pass # No dynamic grid dimensions @@ -380,12 +616,24 @@ def _maybe_squeeze_out_bdim( args = (*scalar_args, *args) dims = (*scalar_bdims, *bdims) else: - # TODO(sharadmv,apaszke): enable batching over prefetched scalar args - raise NotImplementedError + # TODO(amagni,sharadmv,apaszke): enable efficient batching over + # prefetched scalar args. + return _batch_with_explicit_loop( + args=scalar_args + args, + dims=scalar_bdims + bdims, + jaxpr=jaxpr, + name=name, + in_shapes=in_shapes, + out_shapes=out_shapes, + grid_mapping=grid_mapping, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + compiler_params=compiler_params, + ) + if not dims: raise NotImplementedError("vmapping pallas_call with no arguments.") - axis_size, = {x.shape[d] for x, d in zip(args, dims) - if d is not batching.not_mapped} block_mappings = grid_mapping.block_mappings avals = [v.aval for v in jaxpr.invars] # How should we pick output dimensions? This actually matters because XLA @@ -395,18 +643,9 @@ def _maybe_squeeze_out_bdim( # TODO(sharadmv): explore inferring better output dimensions via a heuristic # TODO(sharadmv): explore a long term solution to output dim inference - # When we have input/output aliasing, since the output will be mapped, we need - # to make sure to broadcast the input across that dimension if it is not - # mapped. - dims_ = list(dims) - args_ = list(args) - for input_index, _ in input_output_aliases: - dim = dims_[input_index] - if dim is batching.not_mapped: - dims_[input_index] = 0 - args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0) - args = tuple(args_) - dims = tuple(dims_) + args, dims = _broadcast_input_output_aliases( + args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size + ) all_dims = list(dims) + [0] * len(out_shapes) @@ -418,7 +657,7 @@ def _maybe_squeeze_out_bdim( # operands (the last in the list). 
avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)] batched_block_mappings = map( - partial(_batch_block_mapping, grid_mapping.grid), + partial(_batch_block_mapping, grid_mapping), avals_to_batch, all_dims[num_index_operands:], block_mappings, @@ -444,7 +683,6 @@ def _maybe_squeeze_out_bdim( name=f"batched_{name}", in_shapes=batched_in_shapes, out_shapes=batched_out_shapes, - which_linear=which_linear, grid_mapping=batched_grid_mapping, input_output_aliases=input_output_aliases, debug=debug, @@ -452,42 +690,229 @@ def _maybe_squeeze_out_bdim( compiler_params=compiler_params, ) return out, (0,) * len(out) + + batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr: - all_const_avals = [var.aval for var in jaxpr.constvars] - is_const_ref = [isinstance(var.aval, state.AbstractRef) for var in - jaxpr.constvars] - const_avals, const_ref_avals = partition_list(is_const_ref, all_const_avals) - const_avals = map(state.AbstractRef, const_avals) - merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals) - arg_avals = [var.aval for var in jaxpr.invars] - in_avals = [*merged_const_avals, *arg_avals] - num_consts = len(merged_const_avals) + """Hoists the constants in the given jaxpr into invars. + + Args: + jaxpr: The jaxpr. + + Returns: + A new jaxpr where the constants were hoisted into invars as ``Ref``s. + The invars for the constants are added *before* any existing invars. + """ + if not jaxpr.constvars: + return jaxpr # Nothing to hoist. + + is_const_ref = [ + isinstance(var.aval, state.AbstractRef) for var in jaxpr.constvars + ] + const_avals = [ + var.aval if is_ref else state.AbstractRef(var.aval) + for is_ref, var in zip(is_const_ref, jaxpr.constvars) + ] + in_avals = const_avals + [var.aval for var in jaxpr.invars] def _hoist(*consts_args): - all_consts, args = split_list(consts_args, [num_consts]) - consts, const_refs = partition_list(is_const_ref, all_consts) + all_consts, args = split_list(consts_args, [len(const_avals)]) # We immediately read the const values out of the `Ref`s. - consts = map(lambda x: sp.ref_get(x, ()), consts) - all_consts = merge_lists(is_const_ref, consts, const_refs) + all_consts = [ + c if is_ref else sp.ref_get(c, ()) + for is_ref, c in zip(is_const_ref, all_consts) + ] return jax_core.eval_jaxpr(jaxpr, all_consts, *args) + hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( lu.wrap_init(_hoist), in_avals) assert not consts, "All consts should have been converted to refs" return hoisted_jaxpr + +def checkify_pallas_kernel_body_jaxpr( + body_jaxpr: jax_core.ClosedJaxpr, + enabled_errors, + error: checkify.Error, + grid_mapping: GridMapping) -> tuple[ + jax_core.ClosedJaxpr, tree_util.PyTreeDef, set[checkify.ErrorEffect]]: + err_vals, err_tree = tree_util.tree_flatten(error) + err_vals = map(checkify.get_shaped_aval, err_vals) + flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals] + + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + checked_jaxpr, out_tree, error_effects = checkify.jaxpr_to_checkify_jaxpr( + body_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals) + return checked_jaxpr, out_tree, error_effects + +def pallas_call_checkify_rule(error: checkify.Error, + enabled_errors, + *args: jax_core.Value, + jaxpr: jax_core.Jaxpr, + interpret: bool, + input_output_aliases: tuple[tuple[int, int], ...], + grid_mapping: GridMapping, + out_shapes, + **kwargs): + # TODO(b/346651778): Support TPU/GPU checkify. 
+ if not interpret: + raise NotImplementedError( + "Checkify for pallas_call only supports interpret mode.") + # We implement the checkify rule in 4 steps: + # 1) First, trace the kernel body to get the expected error shapes. + # 2) Checkify the kernel body to obtain a jaxpr with errors as inputs + # and outputs. + # 3) Create a new kernel which stores the errors in output memrefs instead of + # returning them, since pallas kernels do not return outputs. + # 4) Create block specs for the error state and call pallas_call with + # the new kernel. + dynamic_grid_bounds, scalars, args = split_list( # type: ignore + args, [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands] + ) + num_scalars = len(scalars) + num_invars = len(jaxpr.invars) + num_inputs_outputs = ( + num_invars + - grid_mapping.num_index_operands + - grid_mapping.num_scratch_operands + ) + num_kernel_inputs = len(args) + num_scratch = num_invars - num_inputs_outputs + num_kernel_outputs = num_invars - num_scratch - num_kernel_inputs + + # Trace the jaxpr to get an initial error value so the kernel jaxpr has all of + # the required inputs. + closed_jaxpr = pe.close_jaxpr(jaxpr) + _jaxpr, _, error_effects = checkify_pallas_kernel_body_jaxpr( + closed_jaxpr, enabled_errors, error, grid_mapping) + error = error._add_placeholder_effects(error_effects) + err_vals, err_tree = jax.tree.flatten(error) + shaped_err_avals = map(checkify.get_shaped_aval, err_vals) + + # Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have + # all enabled errors removed, but have the errors as inputs and return values. + input_avals = [v.aval for v in jaxpr.invars] + num_err_vals = len(err_vals) + shaped_input_avals = tuple(jax_core.raise_to_shaped(x) for x in input_avals) + checkify_in_avals = [*shaped_err_avals, + *shaped_input_avals] + closed_kernel_jaxpr = pe.close_jaxpr(jaxpr) + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + checked_jaxpr, out_tree, _ = checkify.jaxpr_to_checkify_jaxpr( + closed_kernel_jaxpr, enabled_errors, err_tree, *checkify_in_avals) + + # Create a new kernel to remove the errors as return values and instead + # write them to a memref. This is because pallas kernels are expected + # to have no return values but instead write their outputs to a ref. + def checked_kernel_fn(*args): + (scalars, _, inputs, out_error_refs, outputs, scratch + ) = split_list( + args, + [num_scalars, num_err_vals, + num_kernel_inputs, num_err_vals, num_kernel_outputs]) + input_error_vals = [err_ref[...] for err_ref in out_error_refs] + # We need to re-order the inputs here. A checkified jaxpr always expects + # errors before other arguments. + jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch] + assert len(checked_jaxpr.jaxpr.invars) == len(jaxpr_args) + result_flat = jax.core.eval_jaxpr( + checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args) + output_errors, _ = split_list(result_flat, [num_err_vals]) + # Store new errors back in the error refs. + for out_ref, error in zip(out_error_refs, output_errors): + out_ref[...] = error + return [] + + # Trace the new checked_kernel_fn with Memref inputs so that + # we can replace the old kernel jaxpr with the new checked jaxpr in + # pallas_call. + # TODO(justinfu): Place errors in scalar memory for non-interpret mode.
+ error_mem_space = None + error_memref_aval = [pallas_core.AbstractMemoryRef( + err_val, error_mem_space) for err_val in shaped_err_avals] + shaped_scalar_avals, input_aval, output_aval, scratch_aval = split_list( + shaped_input_avals, [num_scalars, num_kernel_inputs, num_kernel_outputs]) + retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval, + *error_memref_aval, *output_aval, *scratch_aval] + jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals) + wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(checked_kernel_fn), jaxpr_in_tree) + debug = pe.debug_info( + checked_kernel_fn, jaxpr_in_tree, out_tree_thunk, False, "checkify_pallas") + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + wrapped_kernel_with_err, jaxpr_flat_avals, debug) + + # Prepare pallas_call inputs. We need to create new block specs + # for the new error inputs and outputs. + scalar_avals = map(checkify.get_shaped_aval, scalars) + error_block_specs = [no_block_spec] * num_err_vals + grid_avals = [ + jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid) + # TODO(justinfu): Place these in device-specific scalar memory. + scalar_ref_avals = [ + pallas_core.AbstractMemoryRef( + jax_core.ShapedArray(aval.shape, aval.dtype), None) + for aval in scalar_avals] + grid_tree = tree_util.tree_structure(((*grid_avals, *scalar_avals), {})) + error_block_mappings = map( + partial( + pallas_core._convert_block_spec_to_block_mapping, + (*grid_avals, *scalar_ref_avals), + in_tree=grid_tree, + grid=grid_mapping.grid, + mapped_dims=grid_mapping.mapped_dims), + error_block_specs, error_memref_aval) + input_block_mappings, output_block_mappings = split_list( + grid_mapping.block_mappings, [num_kernel_inputs,]) + grid_mapping_with_error = grid_mapping.replace( + block_mappings=(*error_block_mappings, *input_block_mappings, + *error_block_mappings, *output_block_mappings) + ) + error_out_shapes = tuple( + jax.ShapeDtypeStruct(e.shape, e.dtype) for e in shaped_err_avals) + # Bump all input_output_aliases by num_err_vals to make room for error + # TODO(justinfu): Don't bump scalars here. 
+ input_output_aliases = tuple( + (i+num_err_vals, o+num_err_vals) for (i, o) in input_output_aliases) + input_output_aliases_with_error = tuple( + (i+num_scalars, i) for i in range(num_err_vals)) + input_output_aliases + + new_vals_in = [*scalars, *err_vals, *args] + result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in, + jaxpr=final_jaxpr, + interpret=interpret, + grid_mapping=grid_mapping_with_error, + input_output_aliases=input_output_aliases_with_error, + out_shapes=error_out_shapes + out_shapes, + **kwargs) + errors, results = split_list(result, [num_err_vals]) + new_error, _ = jax.tree.unflatten(out_tree, errors) + return new_error, results +checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule + @weakref_lru_cache def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, flat_in_avals, - flat_out_avals, in_tree, out_tree): + flat_out_avals, in_tree, out_tree, interpret: bool): avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree, flat_out_avals, out_tree) + if interpret: + avals = jax.tree_util.tree_map(_logical_aval_to_interpret_mode_aval, avals) jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals) wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(fun), jaxpr_in_tree) debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call") - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals, debug) - jaxpr = _hoist_consts_to_refs(jaxpr) + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, + jaxpr_flat_avals, debug) + if consts: + jaxpr = _hoist_consts_to_refs(jaxpr) + # Pad ``block_mappings`` to account for the hoisted constants. + grid_mapping = grid_mapping.replace( + block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)), + num_constant_operands=len(consts), + ) return grid_mapping, jaxpr, consts, out_tree_thunk() def _extract_function_name(f: Callable, name: str | None) -> str: @@ -496,30 +921,78 @@ def _extract_function_name(f: Callable, name: str | None) -> str: return name -def _pallas_call_default_lowering( - ctx: mlir.LoweringRuleContext, - *in_nodes, - interpret: bool, - **params): - platforms = ctx.module_context.platforms - if len(platforms) > 1: - raise ValueError("Can only lower pallas_call on a single platform.") - platform = platforms[0] +_PALLAS_USE_MOSAIC_GPU = config.bool_flag( + "jax_pallas_use_mosaic_gpu", + default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False), + help=( + "If True, lower Pallas kernels to the experimental Mosaic GPU" + " dialect, instead of Triton IR." + ), +) + + +def _unsupported_lowering_error(platform: str) -> Exception: + return ValueError( + f"Cannot lower pallas_call on platform: {platform}. To use Pallas on GPU," + " install jaxlib GPU 0.4.24 or newer. To use Pallas on TPU, install" + " jaxlib TPU and libtpu. See" + " https://jax.readthedocs.io/en/latest/installation.html." + ) + + +def _pallas_call_lowering( + ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params +): if interpret: # If we are in interpret mode, we don't care what platform we are on. impl = partial(_pallas_call_impl, **params, interpret=True) return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes) - if platform == "cpu": - # We only support interpret mode on the CPU backend.
+ + def cpu_lowering(ctx: mlir.LoweringRuleContext, + *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value], + **params): raise ValueError("Only interpret mode is supported on CPU backend.") - # If we are actually using a specific backend (GPU or TPU), we should have - # already registered backend-specific lowerings. If we get this far, it means - # those backends aren't present. - raise ValueError( - f"Cannot lower pallas_call on platform: {platform}. " - "To use Pallas on GPU, please install Triton and JAX-Triton. " - "To use Pallas on TPU, please install Jaxlib TPU and libtpu.") -mlir.register_lowering(pallas_call_p, _pallas_call_default_lowering) + + def tpu_lowering(ctx: mlir.LoweringRuleContext, + *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value], + **params): + try: + from jax._src.pallas.mosaic import pallas_call_registration + except ImportError: + raise _unsupported_lowering_error("tpu") + else: + return pallas_call_registration.pallas_call_tpu_lowering_rule( + ctx, *in_nodes, **params + ) + + def gpu_lowering(ctx: mlir.LoweringRuleContext, + *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value], + **params): + try: + if _PALLAS_USE_MOSAIC_GPU.value: + from jax._src.pallas.mosaic_gpu import pallas_call_registration + else: + from jax._src.pallas.triton import pallas_call_registration # type: ignore + except ImportError: + raise _unsupported_lowering_error("gpu") + else: + return pallas_call_registration.pallas_call_lowering( + ctx, *in_nodes, **params + ) + + return mlir.lower_per_platform(ctx, "pallas_call", + dict(cpu=cpu_lowering, + tpu=tpu_lowering, + cuda=gpu_lowering, + rocm=gpu_lowering), + None, # default_rule + effects.no_effects, + *in_nodes, + interpret=interpret, + **params) + + +mlir.register_lowering(pallas_call_p, _pallas_call_lowering) def pallas_call( @@ -529,21 +1002,55 @@ def pallas_call( grid_spec: GridSpec | None = None, debug: bool = False, grid: Grid | None = None, - in_specs: Sequence[BlockSpec | NoBlockSpec] | NoBlockSpec = no_block_spec, - out_specs: BlockSpec | NoBlockSpec - | Sequence[BlockSpec | NoBlockSpec] = no_block_spec, + in_specs: BlockSpecTree = no_block_spec, + out_specs: BlockSpecTree = no_block_spec, input_output_aliases: dict[int, int] = {}, interpret: bool = False, name: str | None = None, compiler_params: dict[str, Any] | None = None, - **compiler_params_: Any, -): +) -> Callable[..., Any]: + """Invokes a Pallas kernel on some inputs. + + See `Pallas Quickstart `_. + + Args: + f: the kernel function, that receives a Ref for each input and output. + The shape of the Refs are given by the ``block_shape`` in the + corresponding ``in_specs`` and ``out_specs``. + out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape + and dtypes of the outputs. + grid_spec: TO BE DOCUMENTED. + debug: if True, Pallas prints various intermediate forms of the kernel + as it is being processed. + grid: the iteration space, as a tuple of integers. The kernel is executed + as many times as ``prod(grid)``. The default value ``None`` is equivalent + to ``()``. + See details at :ref:`pallas_grid`. + in_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with + a structure matching that of the positional arguments. + See details at :ref:`pallas_blockspec`. + out_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with + a structure matching that of the outputs. + See details at :ref:`pallas_blockspec`. + The default value for `out_specs` specifies the whole array, + e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``. 
+ input_output_aliases: a dictionary mapping the index of some inputs to + the index of the output that aliases them. + interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the + grid whose body is the kernel lowered as a JAX function. This does not + require a TPU or a GPU, and is the only way to run Pallas kernels on CPU. + This is useful for debugging. + name: TO BE DOCUMENTED. + compiler_params: TO BE DOCUMENTED. + + Returns: + A function that can be called on a number of positional array arguments to + invoke the Pallas kernel. + + """ name = _extract_function_name(f, name) if compiler_params is None: compiler_params = {} - assert not (compiler_params and compiler_params_) - if compiler_params_: - compiler_params = compiler_params_ if grid is not None and grid_spec is not None: raise ValueError("Cannot specify both grid and grid_spec at the same time.") if grid_spec is None: @@ -563,11 +1070,10 @@ def wrapped(*args): for v in flat_out_shapes) grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr( f, grid_spec, flat_in_avals, flat_out_avals, in_tree, - out_tree) - which_linear = (False,) * len(flat_args) + out_tree, interpret=interpret) out_flat = pallas_call_p.bind( *dynamic_grid_bounds, *consts, *flat_args, - jaxpr=jaxpr, name=name, which_linear=which_linear, + jaxpr=jaxpr, name=name, in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype) for a in flat_args), out_shapes=tuple(flat_out_shapes), debug=debug, diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index a38c3c4c3df9..ce87f2bc026c 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Module for pallas-specific JAX primitives and functions.""" +"""Pallas-specific JAX primitives.""" + from __future__ import annotations + import enum import functools - +import string from typing import Any import jax @@ -24,66 +26,73 @@ from jax import tree_util from jax._src import ad_util from jax._src import core as jax_core +from jax._src import effects from jax._src import pretty_printer as pp from jax._src import state -from jax._src.util import (safe_map, safe_zip) +from jax._src import util +from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.pallas import core as pallas_core from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import primitives as sp -from jax._src.interpreters import ad from jax.interpreters import mlir import jax.numpy as jnp -from jax._src.pallas import core as pallas_core - -# TODO(sharadmv): enable type checking -# mypy: ignore-errors - partial = functools.partial Slice = indexing.Slice NDIndexer = indexing.NDIndexer -map, unsafe_map = safe_map, map -zip, unsafe_zip = safe_zip, zip +map, unsafe_map = util.safe_map, map +zip, unsafe_zip = util.safe_zip, zip program_id_p = jax_core.Primitive("program_id") -def program_id(axis): +def program_id(axis: int) -> jax.Array: + """Returns the kernel execution position along the given axis of the grid. + + For example, with a 2D `grid` in the kernel execution corresponding to the + grid coordinates `(1, 2)`, + `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`. + + Args: + axis: the axis of the grid along which to count the program. 
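+
+  A minimal usage sketch (the kernel and ref names here are illustrative only),
+  assuming the usual ``import jax.experimental.pallas as pl`` alias::
+
+    def kernel(x_ref, o_ref):
+      i = pl.program_id(axis=0)   # position of this instance along grid axis 0
+      o_ref[...] = x_ref[...] * i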
+ """ return program_id_p.bind(axis=axis) def program_id_bind(*, axis: int): grid_env = pallas_core.current_grid_env() if grid_env: - return grid_env[axis].axis_index + return grid_env[axis].index + frame = pallas_core.axis_frame() + # Query the size of the axis to make sure its a valid axis (and error + # otherwise). + _ = frame.size(axis) return jax_core.Primitive.bind(program_id_p, axis=axis) program_id_p.def_custom_bind(program_id_bind) -def _program_id_impl(*, axis: int): - grid_env = pallas_core.current_grid_env() - return grid_env[axis].axis_index -program_id_p.def_impl(_program_id_impl) - def _program_id_abstract_eval(**_): return jax_core.ShapedArray((), jnp.int32) program_id_p.def_abstract_eval(_program_id_abstract_eval) - num_programs_p = jax_core.Primitive("num_programs") -def num_programs(axis): +def num_programs(axis: int) -> int | jax.Array: + """Returns the size of the grid along the given axis.""" return num_programs_p.bind(axis=axis) @num_programs_p.def_custom_bind def _num_programs_bind(*, axis: int): + # We might be using a local grid env grid_env = pallas_core.current_grid_env() if grid_env: - return jnp.asarray(grid_env[axis].axis_size, dtype=jnp.int32) - return jax_core.Primitive.bind(num_programs_p, axis=axis) - -@num_programs_p.def_impl -def _num_programs_impl(*, axis: int): - grid_env = pallas_core.current_grid_env() - return jnp.asarray(grid_env[axis].axis_size, dtype=jnp.int32) + return grid_env[axis].size + # Otherwise, we look up the size of the grid in the axis env + frame = pallas_core.axis_frame() + size = frame.size(axis) + if size is pallas_core.dynamic_grid_dim: + return jax_core.Primitive.bind(num_programs_p, axis=axis) + return size @num_programs_p.def_abstract_eval def _num_programs_abstract_eval(**_): @@ -223,7 +232,7 @@ def _max_contiguous_abstract_eval(aval, **_): multiple_of_p.def_impl(lambda x, **_: x) mlir.register_lowering(multiple_of_p, lambda _, x, **__: [x]) -def multiple_of(x, values): +def multiple_of(x: jax.Array, values: list[int] | int) -> jax.Array: if not isinstance(values, list): values = [values] return multiple_of_p.bind(x, values=values) @@ -280,12 +289,12 @@ def _load_jvp(primals, tangents, args_tree, **params): other_tangent = ad_util.instantiate(other_tangent) return ( load_p.bind( - *tree_util.flatten(ref_primal, indexers, mask, other_primal), + *tree_util.tree_leaves((ref_primal, indexers, mask, other_primal)), args_tree=args_tree, **params, ), load_p.bind( - *tree_util.flatten(ref_tangent, indexers, mask, other_tangent), + *tree_util.tree_leaves((ref_tangent, indexers, mask, other_tangent)), args_tree=args_tree, **params, ), @@ -294,6 +303,41 @@ def _load_jvp(primals, tangents, args_tree, **params): ad.primitive_jvps[load_p] = _load_jvp +def uninitialized_value(shape, dtype): + if jnp.issubdtype(dtype, jnp.floating): + return jnp.full(shape, jnp.nan, dtype) + elif jnp.issubdtype(dtype, jnp.integer): + return jnp.full(shape, jnp.iinfo(dtype).min, dtype) + elif jnp.issubdtype(dtype, jnp.bool): + return jnp.full(shape, False, dtype) + raise NotImplementedError(dtype) + +def _pad_values_to_avoid_dynamic_slice_oob_shift(value, + slice_sizes, unpad=False): + """ + DynamicSlice and DynamicUpdateSlice adjust the start index in cases where the + requested slice overruns the bounds of the array. This pads the array with + uninitialised values such that the requested slice will never overrun. + + For example, if arr is [1.,2.,3.,4.] 
and a slice of size 4, start index 2 is + requested then the result will be [3.,4.,NaN,NaN] after padding, rather than + [1.,2.,3.,4.] from the unpadded array + + unpad=True performs the inverse operation + """ + + padding_config = tuple((0, slice_size, 0) for slice_size in slice_sizes) + if unpad: + padding_config = tuple((-low, -high, -interior) + for (low, high, interior) in padding_config) + padding_value = uninitialized_value(shape=(), dtype=value.dtype) + value = lax.pad(value, + padding_config=padding_config, + padding_value=padding_value) + return value + +_unpad_values_to_avoid_dynamic_slice_oob_shift = partial( + _pad_values_to_avoid_dynamic_slice_oob_shift, unpad=True) def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): del out_avals # Unused. @@ -303,10 +347,18 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): raise NotImplementedError("Only one indexer supported in discharge rule.") idx = indexers[0] if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): + # TODO(b/329733289): support strided load/store in interpret mode. + for s in idx.indices: + if isinstance(s, Slice) and s.stride > 1: + raise NotImplementedError("Unimplemented stride support.") indices = idx.indices scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices] slice_starts = [s.start if isinstance(s, Slice) else s for s in indices] slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices) + # fixes an inconstency with lax.dynamic_slice where if the slice goes out + # of bounds, it will instead move the start_index backwards so the slice + # will fit in memory. + ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes) out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes) out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims) out = out_ones[out_indexer] @@ -382,12 +434,12 @@ def _swap_jvp(primals, tangents, *, args_tree, **params): val_tangent = ad_util.instantiate(val_tangent) return ( swap_p.bind( - *tree_util.flatten(ref_primal, indexers, val_primal, mask), + *tree_util.tree_leaves((ref_primal, indexers, val_primal, mask)), args_tree=args_tree, **params, ), swap_p.bind( - *tree_util.flatten(ref_tangent, indexers, val_tangent, mask), + *tree_util.tree_leaves((ref_tangent, indexers, val_tangent, mask)), args_tree=args_tree, **params, ), @@ -404,6 +456,10 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): raise NotImplementedError("Only one indexer supported in discharge rule.") idx = indexers[0] if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): + # TODO(b/329733289): support strided load/store in interpret mode. + for s in idx.indices: + if isinstance(s, Slice) and s.stride > 1: + raise NotImplementedError("Unimplemented stride support.") indices = idx.indices scalar_dims = [ i @@ -412,6 +468,10 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): ] slice_starts = [s.start if isinstance(s, Slice) else s for s in indices] slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices) + # Fixes an inconsistency with lax.dynamic_update_slice where if the slice + # goes out of bounds, it will instead move the start_index backwards so the + # slice will fit in memory. 
+ ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes) out = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes) out = jnp.squeeze(out, scalar_dims) if mask is not None: @@ -420,6 +480,7 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): val = jnp.where(mask, val, out_) val = jnp.expand_dims(val, scalar_dims) x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts) + x_new = _unpad_values_to_avoid_dynamic_slice_oob_shift(x_new, slice_sizes) elif all(not isinstance(s, Slice) for s in idx.indices): out = ref[idx.indices] if mask is not None: @@ -476,3 +537,121 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False, precision=precision, preferred_element_type=jnp.float32, ) + + +class PrintEffect(effects.Effect): + __str__ = lambda self: "Print" + + +debug_print_effect = PrintEffect() + +# TODO(slebedev): Consider making the effect ordered. +effects.lowerable_effects.add_type(PrintEffect) +effects.control_flow_allowed_effects.add_type(PrintEffect) +effects.remat_allowed_effects.add_type(PrintEffect) +effects.custom_derivatives_allowed_effects.add_type(PrintEffect) + + +debug_print_p = jax_core.Primitive("debug_print") +debug_print_p.multiple_results = True + + +def debug_print(fmt: str, *args: jax.ArrayLike): + """Prints scalar values from inside a Pallas kernel. + + Args: + fmt: A format string to be included in the output. The restrictions on the + format string depend on the backend: + * On GPU, when using Triton, ``fmt`` must not contain any placeholders + (``{...}``), since it is always printed before any of the values. + * On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must + contain a placeholder for each value to be printed. Format specs and + conversions are not supported. + * In TPU, if ``fmt`` contains placeholders, all values must be 32-bit + integers. If there are no placeholders, the values are printed after + the format string. + *args: The scalar values to print. 
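+
+  A small illustrative sketch (the kernel is hypothetical, and ``pl`` is
+  assumed to be ``jax.experimental.pallas``)::
+
+    def kernel(x_ref, o_ref):
+      # One ``{}`` placeholder per value (Mosaic GPU / TPU style).
+      pl.debug_print("step = {}", pl.program_id(0))
+      # No placeholders; the values are printed after the string (Triton style).
+      pl.debug_print("processing block", pl.program_id(0))
+      o_ref[...] = x_ref[...]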
+ """ # fmt: skip + has_placeholders = False + if fmt: + _, field_name, *_ = next(iter(string.Formatter().parse(fmt))) + has_placeholders = field_name is not None + return debug_print_p.bind(*args, fmt=fmt, has_placeholders=has_placeholders) + + +def check_debug_print_format( + fmt: str, *args: jax.ArrayLike +): + n_placeholders = 0 + for _, field, spec, conversion in string.Formatter().parse(fmt): + if field is not None: + n_placeholders += 1 + if spec or conversion: + raise ValueError( + "The format string should not contain any format specs or conversions" + ) + if field: + raise ValueError( + "The format string should not reference arguments by position or name" + ) + + if len(args) != n_placeholders: + raise TypeError( + f"The format string expects {n_placeholders} " + f"argument{'' if n_placeholders == 1 else 's'}, but got {len(args)}" + ) + + +@debug_print_p.def_impl +def debug_print_impl(*args: Any, fmt: str, has_placeholders: bool): + if has_placeholders: + print(fmt.format(*args)) + else: + print(fmt, *args) + return () + + +@debug_print_p.def_effectful_abstract_eval +def debug_print_abstract_eval(*avals: Any, fmt: str, has_placeholders: bool): + del fmt, has_placeholders + if any(aval.shape for aval in avals): + raise ValueError("Only scalar values are supported") + return [], {debug_print_effect} + + +def debug_print_batching_rule(args, dims, **params): + """Unrolls the print primitive across the mapped axis.""" + axis_size = next(x.shape[i] for x, i in zip(args, dims) if i is not None) + + # TODO(sharadmv): implement in terms of rolled loop instead of unrolled. + def get_arg_at_dim(i, dim, arg): + if dim is batching.not_mapped: + # Broadcast unmapped argument + return arg + return lax.index_in_dim(arg, i, axis=dim, keepdims=False) + + outs = [] + for i in range(axis_size): + args_idx = map(functools.partial(get_arg_at_dim, i), dims, args) + outs.append(debug_print_p.bind(*args_idx, **params)) + outs = [jnp.stack(xs) for xs in zip(*outs)] + return outs, (0,) * len(outs) + + +batching.primitive_batchers[debug_print_p] = ( + debug_print_batching_rule +) + + +@functools.partial(mlir.register_lowering, debug_print_p) +def debug_print_lowering_rule(ctx, *args, **params): + result, _, _ = mlir.emit_python_callback( + ctx, + functools.partial(debug_print_p.impl, **params), + None, + list(args), + ctx.avals_in, + ctx.avals_out, + has_side_effect=True, + ) + return result diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index 042cf5fc212e..370cbb713ac5 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -17,10 +17,8 @@ load( "//jaxlib:jax.bzl", "py_deps", - "py_library_providing_imports_info", "pytype_strict_library", ) -load("@rules_python//python:defs.bzl", "py_library") package( default_applicable_licenses = [], @@ -29,24 +27,39 @@ package( ], ) -py_library_providing_imports_info( - name = "triton", - srcs = ["__init__.py"], - lib_rule = pytype_strict_library, +pytype_strict_library( + name = "primitives", + srcs = ["primitives.py"], deps = [ ":lowering", - ":pallas_call_registration", + "//jax", + "//jax:ad_util", + "//jax:api_util", + "//jax:core", + "//jax:mlir", + "//jax:partial_eval", + "//jax:source_info_util", + "//jax:util", "//jax/_src/lib", - ], + "//jax/_src/pallas", + ] + py_deps("numpy"), ) -# TODO(slebedev): Enable pytype for this target.
-py_library( +pytype_strict_library( name = "lowering", srcs = ["lowering.py"], deps = [ "//jax", - ], + "//jax:ad_util", + "//jax:api_util", + "//jax:core", + "//jax:mlir", + "//jax:partial_eval", + "//jax:source_info_util", + "//jax:util", + "//jax/_src/lib", + "//jax/_src/pallas", + ] + py_deps("numpy"), ) pytype_strict_library( @@ -60,5 +73,5 @@ pytype_strict_library( "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", - ] + py_deps("jax_triton"), + ], ) diff --git a/jax/_src/pallas/triton/__init__.py b/jax/_src/pallas/triton/__init__.py index 4769311ec0fb..38d13f42da99 100644 --- a/jax/_src/pallas/triton/__init__.py +++ b/jax/_src/pallas/triton/__init__.py @@ -11,23 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""Contains Triton-specific pallas modules.""" - -from jax._src.pallas.triton import pallas_call_registration -from jax._src.lib import gpu_triton as triton_kernel_call_lib - - -try: - get_compute_capability = triton_kernel_call_lib.get_compute_capability -except AttributeError: - - def get_compute_capability() -> int: - raise RuntimeError( - "get_compute_capability is not available. Try installing jaxlib with" - " GPU support following instructions in" - " https://jax.readthedocs.io/en/latest/installation.html." - ) - - -del pallas_call_registration, triton_kernel_call_lib diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 976e51a71e82..b11aaa0266d5 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -16,11 +16,12 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools +import math import operator -from typing import Any, Callable +from typing import Any, TypeVar import jax from jax import lax @@ -32,6 +33,7 @@ from jax._src import custom_derivatives from jax._src import linear_util as lu from jax._src import pjit +from jax._src import source_info_util from jax._src import state from jax._src import util from jax._src.interpreters import mlir @@ -39,12 +41,12 @@ from jax._src.lax.control_flow import for_loop from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect +from jax._src.lib.mlir.dialects import math as math_dialect from jax._src.lib.mlir.dialects import scf as scf_dialect from jax._src.lib.triton import dialect as tt_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils -from jax._src.state import AbstractRef from jax._src.state import discharge from jax._src.state import indexing from jax._src.state import primitives as sp @@ -54,9 +56,11 @@ import jax.numpy as jnp import numpy as np - # TODO(sharadmv): Enable type checking. 
# mypy: ignore-errors +# pytype: skip-file + +_T = TypeVar("_T") map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip @@ -73,6 +77,8 @@ class ModuleContext: name: str grid_mapping: GridMapping program_ids: Sequence[ir.Value] + traceback_caches: mlir.TracebackCaches = dataclasses.field(repr=False) + platform: str @dataclasses.dataclass @@ -131,6 +137,10 @@ def _bcast_to(a: ir.Value, shape: tuple[int, ...]) -> ir.Value: a_type = ir.RankedTensorType(a.type) if a_type.shape == [*shape]: return a + if a_type.rank != len(shape) or not all( + a_type.shape[i] in (dim, 1) for i, dim in enumerate(shape) + ): + raise ValueError(f"Cannot broadcast from {a_type.shape} to {[*shape]}") return tt_dialect.broadcast( ir.RankedTensorType.get(shape, a_type.element_type, a_type.encoding), a ) @@ -163,6 +173,13 @@ def _bcast( triton_lowering_rules = {} +def register_lowering(primitive: jax_core.Primitive) -> Callable[[_T], _T]: + def wrapper(fn): + triton_lowering_rules[primitive] = fn + return fn + return wrapper + + def _process_grid_to_3d_grid(grid_mapping: GridMapping): launch_grid = [] launch_grid_to_pallas_grid = [] @@ -203,7 +220,7 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping): return prog_id_dims, prog_ids else: - new_grid = [np.prod(collapse_dims), *prog_id_dims] + new_grid = [math.prod(collapse_dims), *prog_id_dims] assert new_grid[0] < 2**31 - 1, \ "Cannot fix pallas kernel launch grid within CUDA limits" @@ -214,8 +231,8 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping): for i, s in enumerate(collapse_dims): out_idx = launch_grid_to_pallas_grid[i] s = _i32_constant(s) - out_indices[out_idx] = _mod(grid0, s) - grid0 = _floordiv(grid0, s) + out_indices[out_idx] = _mod(grid0, s, signed=False) + grid0 = _floordiv(grid0, s, signed=False) for i in range(len(prog_id_dims)): out_idx = launch_grid_to_pallas_grid[num_collapse + i] @@ -237,9 +254,8 @@ def lower_jaxpr_to_triton_module( in_shapes, grid_mapping: GridMapping, name: str, - cuda_options: Any, + platform: str ) -> LoweringResult: - # TODO(slebedev): Use cuda_options= during lowering. jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True) with _new_ir_context(), ir.Location.unknown(): module = ir.Module.create() @@ -269,7 +285,9 @@ def lower_jaxpr_to_triton_module( for i, pid in enumerate(program_ids) if i not in grid_mapping.mapped_dims ] - ctx = ModuleContext(name, grid_mapping, local_program_ids) + ctx = ModuleContext( + name, grid_mapping, local_program_ids, mlir.TracebackCaches(), platform + ) if grid_mapping.num_index_operands: raise NotImplementedError( "Scalar prefetch not supported in Triton lowering." 
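# The `register_lowering` decorator introduced in the hunk above replaces the
# explicit `triton_lowering_rules[prim] = rule` assignments used previously:
# it records the rule for the given primitive and returns the function
# unchanged. A minimal sketch, using a hypothetical primitive that is not part
# of this change:

_example_p = jax_core.Primitive("example")


@register_lowering(_example_p)
def _example_lowering_rule(ctx: LoweringRuleContext, x):
  # Identity lowering: return the operand's IR value untouched.
  return x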
@@ -292,7 +310,9 @@ def lower_jaxpr_to_triton_module( if block_mapping is not None else None for shape_dtype, block_mapping, start_idx in zip( - in_shapes, grid_mapping.block_mappings, start_indices + (*in_shapes, *[()] * grid_mapping.num_constant_operands), + grid_mapping.block_mappings, + start_indices, ) ] () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *entry.arguments) @@ -320,11 +340,12 @@ def read_block_info_env(atom: jax_core.Atom): def write_env(var: jax_core.Var, val): env[var] = val - if block_infos is None: - block_infos = [None] * len(jaxpr.invars) - for invar, block_info in zip(jaxpr.invars, block_infos): - block_info_env[invar] = block_info + if block_infos is not None: + for invar, block_info in zip(jaxpr.invars, block_infos): + block_info_env[invar] = block_info + map(write_env, jaxpr.invars, args) + for eqn in jaxpr.eqns: invals = map(read_env, eqn.invars) if eqn.primitive not in triton_lowering_rules: @@ -336,9 +357,13 @@ def write_env(var: jax_core.Var, val): avals_in = [v.aval for v in eqn.invars] avals_out = [v.aval for v in eqn.outvars] eqn_block_infos = map(read_block_info_env, eqn.invars) + loc = mlir._source_info_to_location( + ctx, eqn.primitive, eqn.params, eqn.source_info + ) rule_ctx = LoweringRuleContext(ctx, avals_in, avals_out, eqn_block_infos) try: - outvals = rule(rule_ctx, *invals, **eqn.params) + with source_info_util.user_context(eqn.source_info.traceback), loc: + outvals = rule(rule_ctx, *invals, **eqn.params) except LoweringError: raise # We only add the extra info to the innermost exception. except Exception as e: @@ -351,6 +376,7 @@ def write_env(var: jax_core.Var, val): map(write_env, eqn.outvars, outvals) else: write_env(eqn.outvars[0], outvals) + return map(read_env, jaxpr.outvars) @@ -364,21 +390,17 @@ def _program_id(axis: int) -> ir.Value: return tt_dialect.get_program_id(axis) +@register_lowering(primitives.program_id_p) def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis): return ctx.context.program_ids[axis] -triton_lowering_rules[primitives.program_id_p] = _program_id_lowering_rule - - +@register_lowering(primitives.num_programs_p) def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis): if axis not in range(3): raise ValueError(f"axis must be in [0, 3), but got: {axis}") return tt_dialect.get_num_programs(axis) -triton_lowering_rules[primitives.num_programs_p] = _num_programs_lowering_rule - - def _atomic_rmw( op: tt_dialect.RMWOp, ptr: ir.Value, @@ -400,6 +422,7 @@ def _atomic_rmw( ) +@register_lowering(primitives.atomic_rmw_p) def _atomic_lowering_rule( ctx: LoweringRuleContext, *args_flat, @@ -439,9 +462,7 @@ def _atomic_lowering_rule( return _atomic_rmw(op, ptr, val, mask=mask) -triton_lowering_rules[primitives.atomic_rmw_p] = _atomic_lowering_rule - - +@register_lowering(primitives.atomic_cas_p) def _atomic_cas_lowering_rule(ctx: LoweringRuleContext, ptr, cmp, val): _, cmp_aval, val_aval = ctx.avals_in if ir.RankedTensorType.isinstance(ptr.type): @@ -462,9 +483,6 @@ def _atomic_cas_lowering_rule(ctx: LoweringRuleContext, ptr, cmp, val): ) -triton_lowering_rules[primitives.atomic_cas_p] = _atomic_cas_lowering_rule - - def _associative_scan_lowering(body, ctx: LoweringRuleContext, args, axes): flat_args = tree_util.tree_leaves(args) (axis,) = axes @@ -497,6 +515,7 @@ def _associative_scan_lowering(body, ctx: LoweringRuleContext, args, axes): return list(scan_op.result) +@register_lowering(lax.cumsum_p) def _cumsum_lowering_rule( ctx: LoweringRuleContext, x, *, axis: int, reverse: bool ): @@ -505,18 
+524,10 @@ def _cumsum_lowering_rule( return _associative_scan_lowering(jnp.add, ctx, x, (axis,))[0] -triton_lowering_rules[lax.cumsum_p] = _cumsum_lowering_rule - - +@register_lowering(lax.not_p) def _not_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - if not np.issubdtype(x_aval.dtype, jnp.integer): - raise NotImplementedError(f"unsupported type: {x_aval.dtype}") - one = _full(x.type, 0xFFFFFFFFFFFFFFFF) - return arith_dialect.xori(x, one) - - -triton_lowering_rules[lax.not_p] = _not_lowering_rule + return arith_dialect.xori(x, _full(x.type, ~x_aval.dtype.type(0))) @dataclasses.dataclass(frozen=True) @@ -525,22 +536,45 @@ class _Extern: symbol: str result_type: str - def matches(self, args: Sequence[jax_core.ShapedArray]) -> bool: - if len(args) != len(self.arg_types): + def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool: + if len(avals) != len(self.arg_types): return False return all( aval.weak_type or aval.dtype.name == arg_type - for aval, arg_type in zip(args, self.arg_types) + for aval, arg_type in zip(avals, self.arg_types) + ) + + def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]): + [out_aval] = ctx.avals_out + result_type = _dtype_to_ir_type(jnp.dtype(self.result_type)) + if out_aval.shape: + result_type = ir.RankedTensorType.get(out_aval.shape, result_type) + return tt_dialect.extern_elementwise( + result_type, + args, + libname="", + libpath="", + symbol=self.symbol, + pure=True, ) -def _extern_elementwise( - name: str, table: Sequence[_Extern] +@dataclasses.dataclass(frozen=True) +class _Fallback: + arg_types: Sequence[str] + lower: Callable[..., ir.Value] + + matches = _Extern.matches + + +def _make_dispatch_table( + name: str, **tables: Sequence[_Extern | _Fallback] ) -> Callable[..., ir.Value]: def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: - extern = next((e for e in table if e.matches(ctx.avals_in)), None) - if extern is None: + table = tables[ctx.context.platform] + h = next((e for e in table if e.matches(ctx.avals_in)), None) + if h is None: arg_aval_dtypes = tuple(aval.dtype.name for aval in ctx.avals_in) raise NotImplementedError( f"unsupported types for {name}: {arg_aval_dtypes}" @@ -548,227 +582,376 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: [out_aval] = ctx.avals_out bcast_args = [] - for aval, arg, arg_type in zip(ctx.avals_in, args, extern.arg_types): + for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types): bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape) if aval.weak_type and aval.dtype.name != arg_type: - bcast_arg = _cast(bcast_arg, _dtype_to_ir_type(jnp.dtype(arg_type))) + bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type)) bcast_args.append(bcast_arg) - - result_type = _dtype_to_ir_type(jnp.dtype(extern.result_type)) - if out_aval.shape: - result_type = ir.RankedTensorType.get(out_aval.shape, result_type) - return tt_dialect.extern_elementwise( - result_type, - bcast_args, - libname="", - libpath="", - symbol=extern.symbol, - pure=True, - ) + return h.lower(ctx, *bcast_args) return inner +_abs_dispatch_table = _make_dispatch_table( + "abs", + cuda=[ + _Extern(["int32"], "__nv_abs", "int32"), + _Extern(["int64"], "__nv_llabs", "int64"), + _Extern(["float32"], "__nv_fabsf", "float32"), + _Extern(["float64"], "__nv_fabs", "float64"), + ], + rocm=[ + _Fallback(["int32"], lambda ctx, x: math_dialect.absi(x)), + _Fallback(["int64"], lambda ctx, x: math_dialect.absi(x)), + _Fallback(["float32"], lambda ctx, x: 
math_dialect.absf(x)), + _Fallback(["float64"], lambda ctx, x: math_dialect.absf(x)), + ], +) + + +@register_lowering(lax.abs_p) +def _abs_lowering_rule(ctx: LoweringRuleContext, x): + try: + return _abs_dispatch_table(ctx, x) + except NotImplementedError as e: + [x_aval] = ctx.avals_in + if jnp.issubdtype(x_aval, jnp.integer): + return math_dialect.absi(x) + elif jnp.issubdtype(x_aval, jnp.floating): + return math_dialect.absf(x) + else: + raise e from None + + triton_lowering_rules.update({ lax.neg_p: lambda ctx, x: _minus(x), - lax.abs_p: _extern_elementwise( - "abs", - [ - _Extern(["int32"], "__nv_abs", "int32"), - _Extern(["int64"], "__nv_llabs", "int64"), - _Extern(["float32"], "__nv_fabsf", "float32"), - _Extern(["float64"], "__nv_fabs", "float64"), - ], - ), - lax.ceil_p: _extern_elementwise( + lax.ceil_p: _make_dispatch_table( "ceil", - [ + cuda=[ _Extern(["float32"], "__nv_ceilf", "float32"), _Extern(["float64"], "__nv_ceil", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_ceil_f32", "float32"), + _Extern(["float64"], "__ocml_ceil_f64", "float64"), + ], ), - lax.floor_p: _extern_elementwise( + lax.floor_p: _make_dispatch_table( "floor", - [ + cuda=[ _Extern(["float32"], "__nv_floorf", "float32"), _Extern(["float64"], "__nv_floor", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)), + ], + rocm=[ + _Extern(["float32"], "__ocml_floor_f32", "float32"), + _Extern(["float64"], "__ocml_floor_f64", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)), ], ), - lax.exp_p: _extern_elementwise( + lax.exp_p: _make_dispatch_table( "exp", - [ + cuda=[ _Extern(["float32"], "__nv_expf", "float32"), _Extern(["float64"], "__nv_exp", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)), + ], + rocm=[ + _Fallback(["float32"], lambda ctx, x: math_dialect.exp(x)), + _Fallback(["float64"], lambda ctx, x: math_dialect.exp(x)), + _Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)), ], ), - lax.exp2_p: _extern_elementwise( + lax.exp2_p: _make_dispatch_table( "exp2", - [ + cuda=[ _Extern(["float32"], "__nv_exp2f", "float32"), _Extern(["float64"], "__nv_exp2", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)), + ], + rocm=[ + _Extern(["float32"], "__ocml_exp2_f32", "float32"), + _Extern(["float64"], "__ocml_exp2_f64", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)), ], ), - lax.expm1_p: _extern_elementwise( + lax.expm1_p: _make_dispatch_table( "expm1", - [ + cuda=[ _Extern(["float32"], "__nv_expm1f", "float32"), _Extern(["float64"], "__nv_expm1", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_expm1_f32", "float32"), + _Extern(["float64"], "__ocml_expm1_f64", "float64"), + ], ), - lax.log_p: _extern_elementwise( + lax.log_p: _make_dispatch_table( "log", - [ + cuda=[ _Extern(["float32"], "__nv_logf", "float32"), _Extern(["float64"], "__nv_log", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.log(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)), + ], + rocm=[ + _Extern(["float32"], "__ocml_log_f32", "float32"), + _Extern(["float64"], "__ocml_log_f64", "float64"), 
+ _Fallback(["float16"], lambda ctx, x: math_dialect.log(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)), ], ), - lax.log1p_p: _extern_elementwise( + lax.log1p_p: _make_dispatch_table( "log1p", - [ + cuda=[ _Extern(["float32"], "__nv_log1pf", "float32"), _Extern(["float64"], "__nv_log1p", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_log1p_f32", "float32"), + _Extern(["float64"], "__ocml_log1p_f64", "float64"), + ], ), - lax.sqrt_p: _extern_elementwise( + lax.sqrt_p: _make_dispatch_table( "sqrt", - [ + cuda=[ _Extern(["float32"], "__nv_sqrtf", "float32"), _Extern(["float64"], "__nv_sqrt", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)), + ], + rocm=[ + _Extern(["float32"], "__ocml_sqrt_f32", "float32"), + _Extern(["float64"], "__ocml_sqrt_f64", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)), ], ), - lax.pow_p: _extern_elementwise( + lax.pow_p: _make_dispatch_table( "pow", - [ + cuda=[ _Extern(["float32", "int32"], "__nv_powif", "float32"), _Extern(["float64", "int32"], "__nv_powi", "float64"), _Extern(["float32", "float32"], "__nv_powf", "float32"), _Extern(["float64", "float64"], "__nv_pow", "float64"), ], + rocm=[ + _Extern(["float32", "int32"], "__ocml_pown_f32", "float32"), + _Extern(["float64", "int32"], "__ocml_pown_f64", "float64"), + _Extern(["float32", "float32"], "__ocml_pow_f32", "float32"), + _Extern(["float64", "float64"], "__ocml_pow_f64", "float64"), + ], ), - lax.cbrt_p: _extern_elementwise( + lax.cbrt_p: _make_dispatch_table( "cbrt", - [ + cuda=[ _Extern(["float32"], "__nv_cbrtf", "float32"), _Extern(["float64"], "__nv_cbrt", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_cbrt_f32", "float32"), + _Extern(["float64"], "__ocml_cbrt_f64", "float64"), + ], ), - lax.rsqrt_p: _extern_elementwise( + lax.rsqrt_p: _make_dispatch_table( "rsqrt", - [ + cuda=[ _Extern(["float32"], "__nv_rsqrtf", "float32"), _Extern(["float64"], "__nv_rsqrt", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_rsqrt_f32", "float32"), + _Extern(["float64"], "__ocml_rsqrt_f64", "float64"), + ], ), - lax.sin_p: _extern_elementwise( + lax.sin_p: _make_dispatch_table( "sin", - [ + cuda=[ _Extern(["float32"], "__nv_sinf", "float32"), _Extern(["float64"], "__nv_sin", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)), + ], + rocm=[ + _Extern(["float32"], "__ocml_sin_f32", "float32"), + _Extern(["float64"], "__ocml_sin_f64", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)), ], ), - lax.cos_p: _extern_elementwise( + lax.cos_p: _make_dispatch_table( "cos", - [ + cuda=[ _Extern(["float32"], "__nv_cosf", "float32"), _Extern(["float64"], "__nv_cos", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)), + ], + rocm=[ + _Extern(["float32"], "__ocml_cos_f32", "float32"), + _Extern(["float64"], "__ocml_cos_f64", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)), ], ), - lax.tan_p: _extern_elementwise( + lax.tan_p: _make_dispatch_table( "tan", - [ + cuda=[ _Extern(["float32"], "__nv_tanf", "float32"), _Extern(["float64"], "__nv_tan", "float64"), ], + rocm=[ + 
_Extern(["float32"], "__ocml_tan_f32", "float32"), + _Extern(["float64"], "__ocml_tan_f64", "float64"), + ], ), - lax.asin_p: _extern_elementwise( + lax.asin_p: _make_dispatch_table( "asin", - [ + cuda=[ _Extern(["float32"], "__nv_asinf", "float32"), _Extern(["float64"], "__nv_asin", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_asin_f32", "float32"), + _Extern(["float64"], "__ocml_asin_f64", "float64"), + ], ), - lax.acos_p: _extern_elementwise( + lax.acos_p: _make_dispatch_table( "acos", - [ + cuda=[ _Extern(["float32"], "__nv_acosf", "float32"), _Extern(["float64"], "__nv_acos", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_acos_f32", "float32"), + _Extern(["float64"], "__ocml_acos_f64", "float64"), + ], ), - lax.atan_p: _extern_elementwise( + lax.atan_p: _make_dispatch_table( "atan", - [ + cuda=[ _Extern(["float32"], "__nv_atanf", "float32"), _Extern(["float64"], "__nv_atan", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_atan_f32", "float32"), + _Extern(["float64"], "__ocml_atan_f64", "float64"), + ], ), - lax.atan2_p: _extern_elementwise( + lax.atan2_p: _make_dispatch_table( "atan2", - [ - _Extern(["float32"], "__nv_atan2f", "float32"), - _Extern(["float64"], "__nv_atan2", "float64"), + cuda=[ + _Extern(["float32", "float32"], "__nv_atan2f", "float32"), + _Extern(["float64", "float64"], "__nv_atan2", "float64"), + ], + rocm=[ + _Extern(["float32", "float32"], "__ocml_atan2_f32", "float32"), + _Extern(["float64", "float64"], "__ocml_atan2_f64", "float64"), ], ), - lax.sinh_p: _extern_elementwise( + lax.sinh_p: _make_dispatch_table( "sinh", - [ + cuda=[ _Extern(["float32"], "__nv_sinhf", "float32"), _Extern(["float64"], "__nv_sinh", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_sinh_f32", "float32"), + _Extern(["float64"], "__ocml_sinh_f64", "float64"), + ], ), - lax.cosh_p: _extern_elementwise( + lax.cosh_p: _make_dispatch_table( "cosh", - [ + cuda=[ _Extern(["float32"], "__nv_coshf", "float32"), _Extern(["float64"], "__nv_cosh", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_cosh_f32", "float32"), + _Extern(["float64"], "__ocml_cosh_f64", "float64"), + ], ), - lax.tanh_p: _extern_elementwise( + lax.tanh_p: _make_dispatch_table( "tanh", - [ + cuda=[ _Extern(["float32"], "__nv_tanhf", "float32"), _Extern(["float64"], "__nv_tanh", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_tanh_f32", "float32"), + _Extern(["float64"], "__ocml_tanh_f64", "float64"), + ], ), - lax.asinh_p: _extern_elementwise( + lax.asinh_p: _make_dispatch_table( "asinh", - [ + cuda=[ _Extern(["float32"], "__nv_asinhf", "float32"), _Extern(["float64"], "__nv_asinh", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_asinh_f32", "float32"), + _Extern(["float64"], "__ocml_asinh_f64", "float64"), + ], ), - lax.acosh_p: _extern_elementwise( + lax.acosh_p: _make_dispatch_table( "acosh", - [ - _Extern(["float32"], "__nv_acosf", "float32"), + cuda=[ + _Extern(["float32"], "__nv_acoshf", "float32"), _Extern(["float64"], "__nv_acosh", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_acosh_f32", "float32"), + _Extern(["float64"], "__ocml_acosh_f64", "float64"), + ], ), - lax.atanh_p: _extern_elementwise( + lax.atanh_p: _make_dispatch_table( "atanh", - [ + cuda=[ _Extern(["float32"], "__nv_atanhf", "float32"), _Extern(["float64"], "__nv_atanh", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_atanh_f32", "float32"), + _Extern(["float64"], "__ocml_atanh_f64", "float64"), + ], ), - lax.population_count_p: _extern_elementwise( + lax.population_count_p: 
_make_dispatch_table( "population_count", - [ + cuda=[ _Extern(["int32"], "__nv_popc", "int32"), - _Extern(["int64"], "__nv_popcll", "int64"), + _Extern(["int64"], "__nv_popcll", "int32"), + ], + rocm=[ + _Fallback(["int32"], lambda ctx, x: math_dialect.ctpop(x)), + _Fallback(["int64"], lambda ctx, x: math_dialect.ctpop(x)), ], ), - lax.clz_p: _extern_elementwise( + lax.clz_p: _make_dispatch_table( "clz", - [ + cuda=[ _Extern(["int32"], "__nv_clz", "int32"), - _Extern(["int64"], "__nv_clzll", "int64"), + _Extern(["int64"], "__nv_clzll", "int32"), + ], + rocm=[ + _Fallback(["int32"], lambda ctx, x: math_dialect.ctlz(x)), + _Fallback(["int64"], lambda ctx, x: math_dialect.ctlz(x)), ], ), - lax.nextafter_p: _extern_elementwise( + lax.nextafter_p: _make_dispatch_table( "nextafter", - [ - _Extern(["float32", "float32"], "_nv_nextafterf", "float32"), - _Extern(["float64", "float64"], "_nv_nextafter", "float64"), + cuda=[ + _Extern(["float32", "float32"], "__nv_nextafterf", "float32"), + _Extern(["float64", "float64"], "__nv_nextafter", "float64"), + ], + rocm=[ + _Extern(["float32", "float32"], "__ocml_nextafter_f32", "float32"), + _Extern(["float64", "float64"], "__ocml_nextafter_f64", "float64"), ], ), }) @@ -824,35 +1007,39 @@ def _mul(x: ir.Value, y: ir.Value) -> ir.Value: raise NotImplementedError(f"unsupported types: {x.type} and {y.type}") -def _floordiv(x: ir.Value, y: ir.Value) -> ir.Value: +def _floordiv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value: assert x.type == y.type, (str(x.type), str(y.type)) x_element_type = _element_type(x.type) + if isinstance(x_element_type, (ir.F32Type, ir.F64Type)): + return arith_dialect.divf(x, y) if not isinstance(x_element_type, ir.IntegerType): raise NotImplementedError(f"unsupported types: {x.type} and {y.type}") - if x_element_type.is_signed: + if signed: return arith_dialect.divsi(x, y) else: return arith_dialect.divui(x, y) -def _truediv(x: ir.Value, y: ir.Value) -> ir.Value: +def _truediv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value: assert x.type == y.type, (str(x.type), str(y.type)) x_element_type = _element_type(x.type) if isinstance(x_element_type, ir.IntegerType): x_element_type = ir.F32Type.get() - x = _int_float_cast(x, x_element_type) - y = _int_float_cast(y, x_element_type) - if isinstance(x_element_type, ir.FloatType): + x = _int_float_cast(x, x_element_type, signed=signed) + y = _int_float_cast(y, x_element_type, signed=signed) + if isinstance(x_element_type, (ir.F32Type, ir.F64Type)): return arith_dialect.divf(x, y) raise NotImplementedError(f"unsupported types: {x.type} and {y.type}") -def _mod(x: ir.Value, y: ir.Value) -> ir.Value: +def _mod(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value: assert x.type == y.type, (str(x.type), str(y.type)) x_element_type = _element_type(x.type) + if isinstance(x_element_type, ir.FloatType): + return arith_dialect.remf(x, y) if not isinstance(x_element_type, ir.IntegerType): raise NotImplementedError(f"unsupported types: {x.type} and {y.type}") - if x_element_type.is_signed: + if signed: return arith_dialect.remsi(x, y) else: return arith_dialect.remui(x, y) @@ -864,13 +1051,13 @@ def _cmp( si_pred: arith_dialect.CmpIPredicate, ui_pred: arith_dialect.CmpIPredicate, f_pred: arith_dialect.CmpFPredicate, + *, + signed: bool, ) -> ir.Value: assert x.type == y.type, (str(x.type), str(y.type)) x_element_type = _element_type(x.type) if isinstance(x_element_type, ir.IntegerType): - return arith_dialect.cmpi( - si_pred if x_element_type.is_signed else ui_pred, x, y - ) + return 
arith_dialect.cmpi(si_pred if signed else ui_pred, x, y) elif isinstance(x_element_type, ir.FloatType): return arith_dialect.cmpf(f_pred, x, y) else: @@ -919,29 +1106,58 @@ def _cmp( lax.add_p: _add, lax.sub_p: _sub, lax.mul_p: _mul, - lax.rem_p: _mod, lax.and_p: arith_dialect.andi, lax.or_p: arith_dialect.ori, lax.xor_p: arith_dialect.xori, lax.shift_left_p: arith_dialect.shli, lax.shift_right_arithmetic_p: arith_dialect.shrsi, lax.shift_right_logical_p: arith_dialect.shrui, + ad_util.add_any_p: _add, +} + +for prim, fn in _JAX_TO_TRITON_BINARY.items(): + + def signless_rule(ctx: LoweringRuleContext, x, y, fn=fn): + x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) + return fn(x, y) + + triton_lowering_rules[prim] = signless_rule + + +_JAX_TO_TRITON_SIGNED_BINARY = { + lax.rem_p: _mod, lax.eq_p: _equal, lax.ne_p: _not_equal, lax.gt_p: _greater_than, lax.ge_p: _greater_equal, lax.lt_p: _less_than, lax.le_p: _less_equal, - ad_util.add_any_p: _add, } -for prim, fn in _JAX_TO_TRITON_BINARY.items(): +for prim, fn in _JAX_TO_TRITON_SIGNED_BINARY.items(): - def rule(ctx: LoweringRuleContext, x, y, fn=fn): + def signed_rule(ctx: LoweringRuleContext, x, y, fn=fn): + x_aval, _ = ctx.avals_in x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) - return fn(x, y) + return fn(x, y, signed=jnp.issubdtype(x_aval.dtype, jnp.signedinteger)) - triton_lowering_rules[prim] = rule + triton_lowering_rules[prim] = signed_rule + + +@register_lowering(primitives.debug_print_p) +def debug_print_lowering_rule( + ctx: LoweringRuleContext, + *args: ir.Value, + fmt: str, + has_placeholders: bool, +): + if has_placeholders: + raise ValueError( + "pl.debug_print() does not support placeholders when lowering to Triton" + ) + + tt_dialect.print_(f" {fmt} ", hex=False, args=args) + return () def _set_attr(v: ir.Value, name: str, attr: ir.Attribute) -> None: @@ -959,6 +1175,7 @@ def _set_attr(v: ir.Value, name: str, attr: ir.Attribute) -> None: op.attributes[name] = attr +@register_lowering(primitives.multiple_of_p) def _multiple_of_rule(ctx: LoweringRuleContext, x, values: Sequence[int]): [x_aval] = ctx.avals_in assert max(1, len(x_aval.shape)) == len(values) @@ -970,9 +1187,7 @@ def _multiple_of_rule(ctx: LoweringRuleContext, x, values: Sequence[int]): return x -triton_lowering_rules[primitives.multiple_of_p] = _multiple_of_rule - - +@register_lowering(primitives.max_contiguous_p) def _max_contiguous_rule(ctx: LoweringRuleContext, x, values: Sequence[int]): [x_aval] = ctx.avals_in assert len(x_aval.shape) == len(values) @@ -984,25 +1199,38 @@ def _max_contiguous_rule(ctx: LoweringRuleContext, x, values: Sequence[int]): return x -triton_lowering_rules[primitives.max_contiguous_p] = _max_contiguous_rule - - +@register_lowering(sp.broadcast_to_p) def _broadcast_to_rule(ctx: LoweringRuleContext, x, shape: Sequence[int]): (x_aval,) = ctx.avals_in return _bcast_to(_ensure_ir_value(x, x_aval), shape) -triton_lowering_rules[sp.broadcast_to_p] = _broadcast_to_rule +@register_lowering(lax.integer_pow_p) +def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): + if y == 0: + return _full(x.type, 1) + is_reciprocal = y < 0 + if is_reciprocal: + y = -y -def _integer_pow(a, *, y): - if y == 2: - return a * a - if y == 3: - return a * a * a - if y == -2: - return 1.0 / (a * a) - return jax.lax.pow(a, y) + acc = None + while y > 0: + y, mod = divmod(y, 2) + if mod: + acc = x if acc is None else _mul(acc, x) + if y > 0: + x = _mul(x, x) + assert acc is not None + + [x_aval] = ctx.avals_in + [out_aval] = ctx.avals_out + acc = 
_cast(acc, x_aval.dtype, out_aval.dtype) + if is_reciprocal: + signed = jnp.issubdtype(out_aval.dtype, jnp.signedinteger) + return _truediv(_full(acc.type, 1), acc, signed=signed) + else: + return acc def lower_fun( @@ -1022,7 +1250,6 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params): _JAX_FN_MAPPING = { lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max), - lax.integer_pow_p: _integer_pow, lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)), } @@ -1030,6 +1257,7 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params): triton_lowering_rules[prim] = lower_fun(fn, multiple_results=False) +@register_lowering(lax.min_p) def _min_lowering_rule(ctx: LoweringRuleContext, x, y): # TODO(slebedev): Consider allowing customizing nan behavior. x_aval, y_aval = ctx.avals_in @@ -1047,9 +1275,7 @@ def _min_lowering_rule(ctx: LoweringRuleContext, x, y): return arith_dialect.minui(x, y) -triton_lowering_rules[lax.min_p] = _min_lowering_rule - - +@register_lowering(lax.max_p) def _max_lowering_rule(ctx: LoweringRuleContext, x, y): # TODO(slebedev): Consider allowing customizing nan behavior. x_aval, y_aval = ctx.avals_in @@ -1067,29 +1293,39 @@ def _max_lowering_rule(ctx: LoweringRuleContext, x, y): return arith_dialect.maxui(x, y) -triton_lowering_rules[lax.max_p] = _max_lowering_rule - - +@register_lowering(lax.div_p) def _div_lowering_rule(ctx: LoweringRuleContext, x, y): x_aval, y_aval = ctx.avals_in x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) - if np.issubdtype(x_aval.dtype, np.floating) or np.issubdtype( + signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger) or jnp.issubdtype( + y_aval.dtype, jnp.signedinteger + ) + if jnp.issubdtype(x_aval.dtype, np.floating) or jnp.issubdtype( y_aval.dtype, np.floating ): - return _truediv(x, y) - return _floordiv(x, y) + return _truediv(x, y, signed=signed) + return _floordiv(x, y, signed=signed) -triton_lowering_rules[lax.div_p] = _div_lowering_rule +@register_lowering(lax.sign_p) +def _sign_lowering_rule(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger) + zero = _full(x.type, 0) + return _sub( + _cast(_greater_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype), + _cast(_less_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype), + ) +@register_lowering(lax.iota_p) def _iota_lowering_rule(ctx: LoweringRuleContext, *, dtype, shape, dimension): - if dimension != 0: - raise NotImplementedError - return _cast(_make_range(0, *shape), _dtype_to_ir_type(dtype)) - - -triton_lowering_rules[lax.iota_p] = _iota_lowering_rule + iota = _make_range(0, shape[dimension]) + iota = _cast(iota, jnp.int32, dtype) + for i in range(len(shape)): + if i != dimension: + iota = _expand_dims(iota, i) + return _bcast_to(iota, shape) def _element_type(t: ir.Type) -> ir.Type: @@ -1161,49 +1397,66 @@ def _float_float_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: raise NotImplementedError -def _int_int_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: +def _int_int_cast(src: ir.Value, dst_type: ir.Type, signed: bool) -> ir.Value: src_element_type = ir.IntegerType(_element_type(src.type)) dst_element_type = ir.IntegerType(_element_type(dst_type)) assert src_element_type != dst_element_type if dst_element_type.width == 1: - return _not_equal(src, _full(src.type, 0)) + return _not_equal(src, _full(src.type, 0), signed=signed) - is_signed = src_element_type.is_signed and src_element_type.width != 1 if src_element_type.width == dst_element_type.width: return 
arith_dialect.bitcast(dst_type, src) elif src_element_type.width > dst_element_type.width: return arith_dialect.trunci(dst_type, src) - elif is_signed: + elif signed and src_element_type.width != 1: return arith_dialect.extsi(dst_type, src) else: return arith_dialect.extui(dst_type, src) -def _float_int_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: +def _float_int_cast( + src: ir.Value, dst_type: ir.Type, *, signed: bool +) -> ir.Value: src_element_type = _element_type(src.type) if not isinstance(src_element_type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)): raise NotImplementedError(f"cannot cast {src} tp {dst_type}") dst_element_type = ir.IntegerType(_element_type(dst_type)) if dst_element_type.width == 1: - return _not_equal(src, _full(src.type, 0)) - elif dst_element_type.is_signed: + return _not_equal(src, _full(src.type, 0), signed=signed) + elif signed: return arith_dialect.fptosi(dst_type, src) else: return arith_dialect.fptoui(dst_type, src) -def _int_float_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: +def _int_float_cast( + src: ir.Value, dst_type: ir.Type, *, signed: bool +) -> ir.Value: src_element_type = ir.IntegerType(_element_type(src.type)) dst_element_type = _element_type(dst_type) - if not isinstance(dst_element_type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)): + if not isinstance( + dst_element_type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type) + ): raise NotImplementedError(f"cannot cast {src} tp {dst_type}") - if src_element_type.width == 1 or not src_element_type.is_signed: + if src_element_type.width == 1 or not signed: return arith_dialect.uitofp(dst_type, src) else: return arith_dialect.sitofp(dst_type, src) -def _cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: +def _cast( + src: ir.Value, + src_type: jax.typing.DTypeLike, + dst_type: jax.typing.DTypeLike, +) -> ir.Value: + return _ir_cast( + src, + _dtype_to_ir_type(dst_type), + signed=jnp.issubdtype(src_type, jnp.signedinteger), + ) + + +def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value: if ir.RankedTensorType.isinstance( src.type ) and not ir.RankedTensorType.isinstance(dst_type): @@ -1227,7 +1480,9 @@ def _cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: if isinstance(src_element_type, (ir.F16Type, ir.BF16Type)) and not isinstance( dst_element_type, ir.F32Type ): - return _cast(_cast(src, ir.F32Type.get()), dst_type) + return _ir_cast( + _ir_cast(src, ir.F32Type.get(), signed=False), dst_type, signed=False + ) if isinstance(src_element_type, ir.FloatType) and isinstance( dst_element_type, ir.FloatType @@ -1237,26 +1492,26 @@ def _cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: if isinstance(src_element_type, ir.IntegerType) and isinstance( dst_element_type, ir.IntegerType ): - return _int_int_cast(src, dst_type) + return _int_int_cast(src, dst_type, signed=signed) if isinstance(src_element_type, ir.FloatType) and isinstance( dst_element_type, ir.IntegerType ): - return _float_int_cast(src, dst_type) + return _float_int_cast(src, dst_type, signed=signed) if isinstance(src_element_type, ir.IntegerType) and isinstance( dst_element_type, ir.FloatType ): - return _int_float_cast(src, dst_type) + return _int_float_cast(src, dst_type, signed=signed) if tt_dialect.PointerType.isinstance(src_element_type) and isinstance( dst_element_type, ir.IntegerType ): if dst_element_type.width == 64: return tt_dialect.ptr_to_int(dst_type, src) - else: - x = _cast(src, ir.IntegerType.get_signless(64)) + elif dst_element_type.width == 1: + x = _ir_cast(src, 
ir.IntegerType.get_signless(64), signed=signed) zero = _full(x.type, 0) - return _cast(_not_equal(x, zero), dst_type) + return _ir_cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed) if isinstance( src_element_type, ir.IntegerType ) and tt_dialect.PointerType.isinstance(dst_element_type): @@ -1269,6 +1524,7 @@ def _cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: raise NotImplementedError(f"cannot cast {src} to {dst_type}") +@register_lowering(lax.convert_element_type_p) def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type ): @@ -1276,14 +1532,10 @@ def _convert_element_type_lowering_rule( x = _ensure_ir_value(x, x_aval) if new_dtype == x_aval.dtype: return x - return _cast(x, _dtype_to_ir_type(new_dtype)) - - -triton_lowering_rules[lax.convert_element_type_p] = ( - _convert_element_type_lowering_rule -) + return _cast(x, x_aval.dtype, new_dtype) +@register_lowering(lax.select_n_p) def select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, y): pred_aval, a_aval, b_aval = ctx.avals_in [out_aval] = ctx.avals_out @@ -1292,9 +1544,7 @@ def select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, y): return arith_dialect.select(pred, y, x) -triton_lowering_rules[lax.select_n_p] = select_n_lowering_rule - - +@register_lowering(lax.broadcast_in_dim_p) def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, x, *, broadcast_dimensions, shape ): @@ -1307,19 +1557,12 @@ def _broadcast_in_dim_lowering_rule( return _bcast_to(x, shape) -triton_lowering_rules[jax.lax.broadcast_in_dim_p] = ( - _broadcast_in_dim_lowering_rule -) - - +@register_lowering(lax.squeeze_p) def _squeeze_lowering_rule(ctx: LoweringRuleContext, a, *, dimensions): del dimensions return _reshape_lowering_rule(ctx, a, new_sizes=None, dimensions=None) -triton_lowering_rules[lax.squeeze_p] = _squeeze_lowering_rule - - def _reshape(x: ir.Value, shape: Sequence[int]) -> ir.Value: if not shape: raise ValueError("cannot reshape to an empty shape") @@ -1331,6 +1574,7 @@ def _reshape(x: ir.Value, shape: Sequence[int]) -> ir.Value: ) +@register_lowering(lax.reshape_p) def _reshape_lowering_rule( ctx: LoweringRuleContext, a, *, new_sizes, dimensions ): @@ -1374,9 +1618,6 @@ def _reshape_lowering_rule( return a -triton_lowering_rules[jax.lax.reshape_p] = _reshape_lowering_rule - - def _compute_pointers_from_indices( root_ptr: ir.Value, block_info: BlockInfo | None, @@ -1415,14 +1656,22 @@ def _compute_pointers_from_indices( else: index = next(indexer_iter) if isinstance(index, primitives.Slice): - # Handle slices with static and dynamic indices and static sizes - if isinstance(index.start, int): - ptr_dim_offset = _make_range(index.start, index.start + index.size) - else: + if index.is_dynamic_start: + # Compute the offset as start + range(0, size). ptr_dim_offset = _add( _bcast_to(index.start, [index.size]), - _cast(_make_range(0, index.size), index.start.type), + _ir_cast(_make_range(0, index.size), index.start.type, signed=False), ) + elif index.stride > 1: + # Compute the offset as start + range(0, size) * stride. 
+ iota = _make_range(0, index.size) + ptr_dim_offset = _add( + _bcast_to(_i32_constant(index.start), [index.size]), + _mul(iota, _full(iota.type, index.stride)), + ) + else: + ptr_dim_offset = _make_range(index.start, index.start + index.size) + # We need to add broadcastable dimensions for the advanced int indexing # and for previous slices num_left_expand_dims = len(int_indexer_shape) + other_shape_idx @@ -1459,7 +1708,7 @@ def _compute_pointers_from_indices( ptr_dim_offset = _bcast_to(ptr_dim_offset, indexer_shape) index_type = ir.IntegerType(_element_type(ptr_dim_offset.type)) if start_offset is not None: - start_offset = _cast(start_offset, index_type) + start_offset = _ir_cast(start_offset, index_type, signed=False) ptr_dim_offset = _add( ptr_dim_offset, _bcast_to(start_offset, indexer_shape) ) @@ -1476,14 +1725,7 @@ def _compute_pointers_from_indices( ) -def _pack_indices(non_slice_idx, indexed_dims): - non_slice_idx_iter = iter(non_slice_idx) - return tuple( - next(non_slice_idx_iter) if indexed else slice(None) - for indexed in indexed_dims - ) - - +@register_lowering(sp.get_p) def _get_lowering_rule(ctx: LoweringRuleContext, ptr, *idx, tree): indexers = tree_util.tree_unflatten(tree, idx) if not tt_dialect.PointerType.isinstance(ptr.type): @@ -1503,27 +1745,10 @@ def _get_lowering_rule(ctx: LoweringRuleContext, ptr, *idx, tree): ) -triton_lowering_rules[sp.get_p] = _get_lowering_rule - - _STR_TO_EVICTION_POLICY = {str(e): e for e in tt_dialect.EvictionPolicy} _STR_TO_CACHE_MODIFIER = {str(c): c for c in tt_dialect.CacheModifier} -def _infer_load_return_type(ptr: ir.Value) -> ir.Type: - if ir.RankedTensorType.isinstance(ptr.type): - ptr_type = ir.RankedTensorType(ptr.type) - element_type = tt_dialect.PointerType(ptr_type.element_type) - return ir.RankedTensorType.get( - ptr_type.shape, - element_type.pointee_type, - ptr_type.encoding, - ) - else: - ptr_type = tt_dialect.PointerType(ptr.type) - return ptr_type.pointee_type - - def _load( ptr: ir.Value, mask: ir.Value | None = None, @@ -1570,15 +1795,16 @@ def _load( is_int1 = isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1 if is_int1: pointee_type = ir.IntegerType.get_signless(8) - ptr = _cast( - ptr, tt_dialect.PointerType.get(pointee_type, ptr_type.address_space) + ptr = _ir_cast( + ptr, + tt_dialect.PointerType.get(pointee_type, ptr_type.address_space), + signed=False, ) if other is not None: - other = _cast(other, pointee_type) + other = _ir_cast(other, pointee_type, signed=False) result = tt_dialect.load( - _infer_load_return_type(ptr), ptr, mask=mask, other=other, @@ -1587,10 +1813,13 @@ def _load( is_volatile=is_volatile, ) return ( - result if not is_int1 else _cast(result, ir.IntegerType.get_signless(1)) + result + if not is_int1 + else _ir_cast(result, ir.IntegerType.get_signless(1), signed=False) ) +@register_lowering(primitives.load_p) def _masked_load_lowering_rule( ctx: LoweringRuleContext, *args_flat, @@ -1626,9 +1855,7 @@ def _masked_load_lowering_rule( ) -triton_lowering_rules[primitives.load_p] = _masked_load_lowering_rule - - +@register_lowering(sp.swap_p) def _swap_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree): indexers = tree_util.tree_unflatten(tree, idx) if not tt_dialect.PointerType.isinstance(ptr.type): @@ -1643,9 +1870,6 @@ def _swap_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree): ) -triton_lowering_rules[sp.swap_p] = _swap_lowering_rule - - def _store( ptr: ir.Value, value: ir.Value, @@ -1688,16 +1912,19 @@ def _store( pointee_type = 
ptr_type.pointee_type if isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1: pointee_type = ir.IntegerType.get_signless(8) - ptr = _cast( - ptr, tt_dialect.PointerType.get(pointee_type, ptr_type.address_space) + ptr = _ir_cast( + ptr, + tt_dialect.PointerType.get(pointee_type, ptr_type.address_space), + signed=False, ) - value = _cast(value, pointee_type) + value = _ir_cast(value, pointee_type, signed=False) return tt_dialect.store( ptr, value, mask=mask, cache=cache_modifier, evict=eviction_policy ) +@register_lowering(primitives.swap_p) def _masked_swap_lowering_rule( ctx: LoweringRuleContext, *args_flat, args_tree, eviction_policy ): @@ -1722,9 +1949,7 @@ def _masked_swap_lowering_rule( return old_value -triton_lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule - - +@register_lowering(sp.addupdate_p) def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree): indexers = tree_util.tree_unflatten(tree, idx) if not tt_dialect.PointerType.isinstance(ptr.type): @@ -1746,16 +1971,11 @@ def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree): return [] -triton_lowering_rules[sp.addupdate_p] = _addupdate_lowering_rule - - +@register_lowering(lax.transpose_p) def _transpose_lowering(ctx: LoweringRuleContext, x, *, permutation): return tt_dialect.trans(x, permutation) -triton_lowering_rules[lax.transpose_p] = _transpose_lowering - - def _check_dot_operands( x_type: ir.RankedTensorType, y_type: ir.RankedTensorType, options: Any ): @@ -1817,12 +2037,24 @@ def _dot( else: max_num_imprecise_acc = 0 - return tt_dialect.dot(x, y, acc, allow_tf32, max_num_imprecise_acc) + # Ideally, replace all allow_tf32 usages with InputPrecision directly. + input_precision = tt_dialect.InputPrecision.IEEE + if allow_tf32: + input_precision = tt_dialect.InputPrecision.TF32 + + return tt_dialect.dot( + x, + y, + acc, + max_num_imprecise_acc=max_num_imprecise_acc, + input_precision=input_precision + ) _TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT) +@register_lowering(lax.dot_general_p) def _dot_general_lowering( ctx: LoweringRuleContext, a, @@ -1859,13 +2091,11 @@ def _dot_general_lowering( allow_tf32=allow_tf32, out_type=_dtype_to_ir_type(acc_dtype), ), - _dtype_to_ir_type(out_dtype), + acc_dtype, + out_dtype, ) -triton_lowering_rules[lax.dot_general_p] = _dot_general_lowering - - def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes): flat_args = tree_util.tree_leaves(a) (axis,) = axes @@ -1978,6 +2208,7 @@ def _reduce_argmin_combine(left, right): ) +@register_lowering(pjit.pjit_p) def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): if jaxpr.consts: raise NotImplementedError @@ -1986,9 +2217,8 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): ) -triton_lowering_rules[pjit.pjit_p] = _pjit_lowering_rule - - +@register_lowering(jax_core.closed_call_p) +@register_lowering(custom_derivatives.custom_jvp_call_p) def _closed_call_lowering_rule( ctx: LoweringRuleContext, *args, call_jaxpr, **_ ): @@ -1998,17 +2228,11 @@ def _closed_call_lowering_rule( return lower_jaxpr_to_triton_ir(ctx.context, jaxpr, ctx.block_infos, *args) -triton_lowering_rules[jax_core.closed_call_p] = _closed_call_lowering_rule -triton_lowering_rules[custom_derivatives.custom_jvp_call_p] = ( - _closed_call_lowering_rule -) - - +@register_lowering(ad_checkpoint.remat_p) def _remat_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): return lower_jaxpr_to_triton_ir(ctx.context, jaxpr, ctx.block_infos, *args) 
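# Note on the switch to `InputPrecision` in `_dot` above: for f32 operands the
# JAX precision argument now selects the Triton input precision, with
# `lax.Precision.DEFAULT`/`HIGH` allowing TF32 and `lax.Precision.HIGHEST`
# requesting IEEE. A minimal kernel-body sketch (illustrative only; `a_ref`,
# `b_ref` and `o_ref` are hypothetical Pallas refs):

def _matmul_kernel_sketch(a_ref, b_ref, o_ref):
  # HIGHEST disables TF32 for this dot; DEFAULT/HIGH would allow it.
  o_ref[...] = jnp.dot(
      a_ref[...], b_ref[...], precision=jax.lax.Precision.HIGHEST
  )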
-triton_lowering_rules[ad_checkpoint.remat_p] = _remat_lowering_rule triton_lowering_rules[ad_util.stop_gradient_p] = lambda _, x: x @@ -2022,6 +2246,7 @@ def _is_read_only(ref_effects) -> bool: return isinstance(eff, state.ReadEffect) +@register_lowering(for_loop.for_p) def _for_lowering_rule( ctx: LoweringRuleContext, *args, @@ -2039,7 +2264,9 @@ def _for_lowering_rule( step = _i32_constant(1) init_args = map(_ensure_ir_value, args, ctx.avals_in) # Partially discharge state from jaxpr for non-pointers - should_discharge = [not isinstance(a, AbstractRef) for a in ctx.avals_in] + should_discharge = [ + not isinstance(a, state.AbstractRef) for a in ctx.avals_in + ] discharged_jaxpr, () = discharge.discharge_state( jaxpr, (), should_discharge=[True, *should_discharge] ) @@ -2073,9 +2300,6 @@ def _for_lowering_rule( return merge_lists(is_loop_arg, non_loop_args, list(for_op.results_)) -triton_lowering_rules[for_loop.for_p] = _for_lowering_rule - - def _lower_jaxpr_to_for_loop( ctx: LoweringRuleContext, jaxpr: jax_core.Jaxpr, @@ -2112,6 +2336,7 @@ def _lower_jaxpr_to_for_loop( return list(for_op.results_) +@register_lowering(lax.scan_p) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, @@ -2122,7 +2347,9 @@ def _scan_lowering_rule( unroll, num_consts, num_carry, + _split_transpose, ): + del _split_transpose # Only implements fori_loop-like scans num_extensive = len(args) - num_consts - num_carry if num_extensive: raise NotImplementedError @@ -2159,9 +2386,6 @@ def _scan_lowering_rule( return for_out -triton_lowering_rules[lax.scan_p] = _scan_lowering_rule - - def _maybe_pattern_match_fori_loop( ctx: LoweringRuleContext, *args, @@ -2242,6 +2466,7 @@ def _maybe_pattern_match_fori_loop( return [ub, ub, *for_out] +@register_lowering(lax.while_p) def _while_lowering_rule( ctx: LoweringRuleContext, *args, @@ -2311,15 +2536,12 @@ def _while_lowering_rule( return all_out[cond_nconsts + body_nconsts :] -triton_lowering_rules[lax.while_p] = _while_lowering_rule - - +@register_lowering(lax.cond_p) def _cond_lowering_rule( ctx: LoweringRuleContext, index, *args, # *consts, *ops branches, # tuple(jaxprs) - linear, ): block_infos = ctx.block_infos @@ -2331,7 +2553,7 @@ def to_type(out_aval): out_types = [to_type(out) for out in ctx.avals_out] - use_branch0 = _equal(index, _ir_constant(0, index.type)) + use_branch0 = _equal(index, _ir_constant(0, index.type), signed=False) # TODO(bjp): Switch to scf.index_switch once exposed in triton.cc if_op = scf_dialect.IfOp(use_branch0, out_types, hasElse=True) with ir.InsertionPoint.at_block_begin(if_op.then_block): @@ -2349,7 +2571,6 @@ def to_type(out_aval): _sub(index, _ir_constant(1, index.type)), *args, branches=branches[1:], - linear=linear, ) else: outs1 = lower_jaxpr_to_triton_ir( @@ -2362,9 +2583,6 @@ def to_type(out_aval): return list(if_op.results_) -triton_lowering_rules[lax.cond_p] = _cond_lowering_rule - - def _ensure_ir_value(x: object, aval: jax_core.ShapedArray) -> ir.Value: if isinstance(x, ir.Value): return x @@ -2393,7 +2611,7 @@ def _i64_constant(v: int) -> ir.Value: def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type: - if np.issubdtype(dtype, np.integer): + if jnp.issubdtype(dtype, np.integer): # All integer types in Triton are signless. 
return ir.IntegerType.get_signless(dtype.itemsize * 8) return mlir.dtype_to_ir_type(dtype) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index aa7b90b1bc15..e6d521692ec2 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -14,79 +14,18 @@ """Module registering a lowering rule for pallas_call on GPU.""" -# TODO(sharadmv): Enable type checking. -# mypy: ignore-errors - from __future__ import annotations -import dataclasses import io from typing import Any -import zlib import jax from jax import core as jax_core -from jax._src import config from jax._src.interpreters import mlir -from jax._src.lib import gpu_triton as triton_kernel_call_lib from jax._src.lib.mlir import ir from jax._src.pallas import core as pallas_core from jax._src.pallas.pallas_call import pallas_call_p from jax._src.pallas.triton import lowering -from jax._src import util - - -@dataclasses.dataclass -class CompilationResult: - kernel_name: str - ttir: str - ptx: str - shared_mem_bytes: int - compute_capability: int - lowering_result: lowering.LoweringResult - - -@util.weakref_lru_cache -def compile_jaxpr( - jaxpr: jax_core.Jaxpr, - in_shapes, - grid_mapping: pallas_core.GridMapping, - name: str, - num_warps: int, - num_stages: int, - debug: bool, -) -> CompilationResult: - from jax_triton.triton_lib import compile_ttir_to_ptx_inplace - import triton.backends.nvidia.compiler as cb - - # TODO(sharadmv): handle multiple devices, right now we assume device 0 - # which is fine when we have multiple of the same GPU but this won't work in - # general. - device = 0 - compute_capability = triton_kernel_call_lib.get_compute_capability(device) - target = ("cuda", compute_capability) - cuda_backend = cb.CUDABackend(target) - cuda_options = cuda_backend.parse_options( - dict( - num_warps=num_warps, - num_stages=num_stages, - debug=debug, - ) - ) - lowering_result = lowering.lower_jaxpr_to_triton_module( - jaxpr, in_shapes, grid_mapping, name, cuda_options - ) - - ttir = str(lowering_result.module) - ptx, name, shared_mem_bytes, _ = compile_ttir_to_ptx_inplace( - lowering_result.module, - cuda_backend, - cuda_options, - compute_capability, - ) - return CompilationResult( - name, ttir, ptx, shared_mem_bytes, compute_capability, lowering_result - ) def normalize_grid(grid: pallas_core.StaticGrid) -> tuple[int, int, int]: @@ -94,119 +33,63 @@ def normalize_grid(grid: pallas_core.StaticGrid) -> tuple[int, int, int]: grid = (grid,) elif len(grid) > 3: raise ValueError("`grid` should have three or fewer dimensions.") - return tuple(grid) + (1,) * (3 - len(grid)) + return tuple(grid) + (1,) * (3 - len(grid)) # type: ignore def avals_to_layouts(avals): return [list(reversed(range(aval.ndim))) for aval in avals] -def _pallas_call_ptx_lowering( +def pallas_call_lowering( ctx: mlir.LoweringRuleContext, *in_nodes, jaxpr: jax_core.Jaxpr, name: str, in_shapes: tuple[jax.ShapeDtypeStruct, ...], out_shapes: tuple[jax.ShapeDtypeStruct, ...], + interpret: bool, debug: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, - triton_params: dict[str, Any], - num_warps: int, - num_stages: int, + compiler_params: dict[str, Any], ): - compilation_result = compile_jaxpr( - jaxpr, - (*in_shapes, *out_shapes), - grid_mapping, - name, - num_warps, - num_stages, - debug=debug, - ) - # Triton returns a tuple for ROCm. 
We just want file path to be passed - if ctx.module_context.platforms[0] == 'rocm': - compilation_result.ptx = compilation_result.ptx[1] - - if debug: - compilation_result.lowering_result.module.dump() - - kernel = triton_kernel_call_lib.TritonKernel( - compilation_result.kernel_name, - num_warps, - compilation_result.shared_mem_bytes, - compilation_result.ptx, - compilation_result.ttir, - compilation_result.compute_capability, - 1, - 1, - 1, # TODO(giorgioa): Add support for clustering on H100s on Pallas. - ) - - grid = normalize_grid(compilation_result.lowering_result.grid) - - kernel_params = [] - for _ in range(len(in_shapes) + len(out_shapes)): - kernel_params.append( - triton_kernel_call_lib.create_array_parameter( - 0, # bytes to zero # TODO(cjfj): Expose through user API. - 16, # divisible by 16 - ) + if interpret: + return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( + ctx, + *in_nodes, + jaxpr=jaxpr, + name=name, + out_shapes=out_shapes, + in_shapes=in_shapes, + interpret=interpret, + debug=debug, + input_output_aliases=input_output_aliases, + grid_mapping=grid_mapping, + compiler_params=compiler_params, ) - kernel_call = triton_kernel_call_lib.TritonKernelCall( - kernel, grid[0], grid[1], grid[2], kernel_params - ) - - out_types = [ - ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype)) - for shape in out_shapes - ] - - serialized_metadata = triton_params.get("serialized_metadata", b"") - kernel_call_proto = kernel_call.to_proto(name, serialized_metadata) - return mlir.custom_call( - call_target_name="triton_kernel_call", - result_types=out_types, - operands=in_nodes, - backend_config=zlib.compress(kernel_call_proto), - operand_layouts=avals_to_layouts(ctx.avals_in), - result_layouts=avals_to_layouts(ctx.avals_out), - operand_output_aliases=dict(input_output_aliases), - ).results - + if grid_mapping.num_dynamic_grid_bounds: + raise NotImplementedError( + "dynamic grid bounds not supported in the Triton backend" + ) + triton_params = compiler_params.get("triton", compiler_params) + num_warps = triton_params.pop("num_warps", 4) + [lowering_platform] = ctx.platforms or ctx.module_context.platforms + if lowering_platform == "rocm": + num_stages = triton_params.pop("num_stages", 1) + else: + num_stages = triton_params.pop("num_stages", 3) -def _pallas_call_ttir_lowering( - ctx: mlir.LoweringRuleContext, - *in_nodes, - jaxpr: jax_core.Jaxpr, - name: str, - in_shapes: tuple[jax.ShapeDtypeStruct, ...], - out_shapes: tuple[jax.ShapeDtypeStruct, ...], - debug: bool, - input_output_aliases: tuple[tuple[int, int], ...], - grid_mapping: pallas_core.GridMapping, - triton_params: dict[str, Any] | None = None, - num_warps: int, - num_stages: int, -): - # TODO(sharadmv): handle multiple devices, right now we assume device 0 - # which is fine when we have multiple of the same GPU but this won't work in - # general. 
- device = 0 - compute_capability = triton_kernel_call_lib.get_compute_capability(device) - cuda_options = dict( - compute_capability=compute_capability, - num_warps=num_warps, - num_stages=num_stages, - debug=debug, - ) + if debug: + print(jaxpr) + print(grid_mapping) lowering_result = lowering.lower_jaxpr_to_triton_module( - jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, cuda_options + jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, lowering_platform ) + module_op = lowering_result.module.operation if debug: - lowering_result.module.dump() + print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True)) grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid) out_types = [ @@ -214,10 +97,10 @@ def _pallas_call_ttir_lowering( for shape in out_shapes ] buf = io.BytesIO() - lowering_result.module.operation.write_bytecode(buf) + module_op.write_bytecode(buf) backend_config = dict( name=ir.StringAttr.get(name), - ir=ir.StringAttr.get(buf.getvalue()), + ir=ir.StringAttr.get(buf.getvalue()), # type: ignore num_stages=mlir.i32_attr(num_stages), num_warps=mlir.i32_attr(num_warps), grid_x=mlir.i32_attr(grid_x), @@ -240,82 +123,3 @@ def _pallas_call_ttir_lowering( result_layouts=avals_to_layouts(ctx.avals_out), operand_output_aliases=dict(input_output_aliases), ).results - - -_TRITON_COMPILE_VIA_XLA = config.DEFINE_bool( - "jax_triton_compile_via_xla", - default=config.bool_env("JAX_TRITON_COMPILE_VIA_XLA", True), - help="If True, Pallas delegates Triton kernel compilation to XLA.", -) - - -def pallas_call_lowering( - ctx: mlir.LoweringRuleContext, - *in_nodes, - jaxpr: jax_core.Jaxpr, - name: str, - in_shapes: tuple[jax.ShapeDtypeStruct, ...], - out_shapes: tuple[jax.ShapeDtypeStruct, ...], - which_linear: tuple[bool, ...], - interpret: bool, - debug: bool, - input_output_aliases: tuple[tuple[int, int], ...], - grid_mapping: pallas_core.GridMapping, - compiler_params: dict[str, Any], -): - if interpret: - return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( - ctx, - *in_nodes, - jaxpr=jaxpr, - name=name, - out_shapes=out_shapes, - in_shapes=in_shapes, - which_linear=which_linear, - interpret=interpret, - debug=debug, - input_output_aliases=input_output_aliases, - grid_mapping=grid_mapping, - compiler_params=compiler_params, - ) - - if grid_mapping.num_dynamic_grid_bounds: - raise NotImplementedError( - "dynamic grid bounds not supported in the Triton backend" - ) - triton_params = compiler_params.get("triton_params", {}) - triton_compiler_params = compiler_params.get("triton", {}) - num_warps = triton_compiler_params.pop("num_warps", 4) - if len(ctx.module_context.platforms) > 1: - raise NotImplementedError("multi-platform lowering for Pallas kernels") - if ctx.module_context.platforms[0] == "rocm": - num_stages = triton_compiler_params.pop("num_stages", 1) - else: - num_stages = triton_compiler_params.pop("num_stages", 3) - - if debug: - print(jaxpr) - print(grid_mapping) - - if _TRITON_COMPILE_VIA_XLA.value: - lowering_fn = _pallas_call_ttir_lowering - else: - lowering_fn = _pallas_call_ptx_lowering - - return lowering_fn( - ctx, - *in_nodes, - jaxpr=jaxpr, - name=name, - in_shapes=in_shapes, - out_shapes=out_shapes, - debug=debug, - input_output_aliases=input_output_aliases, - grid_mapping=grid_mapping, - triton_params=triton_params, - num_warps=num_warps, - num_stages=num_stages, - ) - - -mlir.register_lowering(pallas_call_p, pallas_call_lowering, platform="gpu") diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py 
new file mode 100644 index 000000000000..8518a94ed9cf --- /dev/null +++ b/jax/_src/pallas/triton/primitives.py @@ -0,0 +1,122 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for GPU-specific JAX primitives.""" + +from __future__ import annotations + +from collections.abc import Sequence + +import jax +from jax import core as jax_core +from jax._src.lib.triton import dialect as tt_dialect +from jax._src.pallas.triton import lowering +from jax.interpreters import mlir +import jax.numpy as jnp + + +def approx_tanh(x: jax.Array) -> jax.Array: + r"""Elementwise approximate hyperbolic tangent: :math:`\mathrm{tanh}(x)`. + + See + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-tanh. + """ + if x.dtype == jnp.float16: + asm = "tanh.approx.f16 $0, $1;" + constraint = "h" + elif x.dtype == jnp.bfloat16: + asm = "tanh.approx.bf16 $0, $1;" + constraint = "h" + elif x.dtype == jnp.float32: + asm = "tanh.approx.f32 $0, $1;" + constraint = "f" + else: + raise TypeError(f"approx_tanh does not accept {x.dtype} arrays") + + [result] = elementwise_inline_asm( + asm, + args=[x], + constraints=f"={constraint},{constraint}", + pack=1, + result_shape_dtypes=[jax.ShapeDtypeStruct(x.shape, x.dtype)], + ) + return result + + +def elementwise_inline_asm( + asm: str, + *, + args: Sequence[jax.Array], + constraints: str, + pack: int, + result_shape_dtypes: Sequence[jax.ShapeDtypeStruct], +) -> Sequence[jax.Array]: + """Inline assembly applying an elementwise operation. + + Args: + asm: The assembly code to run. + args: The arguments to pass to the assembly code. + constraints: LLVM inline assembly `constraints + `_. + pack: The number of elements from each argument expected by a single + instance of the assembly code. + result_shape_dtypes: The shapes and dtypes of the results produced by the + assembly code. + + Returns: + The results produced by the assembly code. + """ + return elementwise_inline_asm_p.bind( + *args, + asm=asm, + constraints=constraints, + pack=pack, + result_shape_dtypes=result_shape_dtypes, + ) + + +elementwise_inline_asm_p = jax_core.Primitive("elementwise_inline_asm_p") +elementwise_inline_asm_p.multiple_results = True + + +@elementwise_inline_asm_p.def_abstract_eval +def _elementwise_inline_asm_abstract_eval( + *avals: jax_core.ShapedArray, result_shape_dtypes, **kwargs +) -> Sequence[jax_core.ShapedArray]: + del kwargs # Unused. + if not all(x.shape == y.shape for x, y in zip(avals, avals[1:])): + raise ValueError( + "All arguments of elementwise_inline_asm must have the same shape" + ) + return [jax_core.ShapedArray(s.shape, s.dtype) for s in result_shape_dtypes] + + +@lowering.register_lowering(elementwise_inline_asm_p) +def _elementwise_inline_asm_lowering( + ctx: lowering.LoweringRuleContext, + *args, + asm, + constraints, + pack, + result_shape_dtypes, +): + del result_shape_dtypes # Unused. 
+ return tt_dialect.ElementwiseInlineAsmOp( + [*map(mlir.aval_to_ir_type, ctx.avals_out)], + asm, + constraints=constraints, + pure=True, + packed_element=pack, + args=args, + ).result diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index 3229288d3ad1..41466be0822d 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -13,7 +13,11 @@ # limitations under the License. """Pallas utility functions.""" -import math + +from __future__ import annotations +from typing import overload + +import jax from jax import lax from jax._src import core as jax_core from jax._src.util import split_list @@ -30,9 +34,26 @@ def _wrapped(f): lax.cond(condition, f, lambda: None) return _wrapped - +@overload def cdiv(a: int, b: int) -> int: - return (a + b - 1) // b + ... + +@overload +def cdiv(a: int, b: jax.Array) -> jax.Array: + ... + +@overload +def cdiv(a: jax.Array, b: int) -> jax.Array: + ... + +@overload +def cdiv(a: jax.Array, b: jax.Array) -> jax.Array: + ... + +def cdiv(a: int | jax.Array, b: int | jax.Array) -> int | jax.Array: + if isinstance(a, int) and isinstance(b, int): + return (a + b - 1) // b + return lax.div(a + b - 1, b) def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]: @@ -45,9 +66,10 @@ def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]: def next_power_of_2(x: int) -> int: - if x == 0: - return 1 - return int(2 ** math.ceil(math.log2(x))) + """Returns the next power of two greater than or equal to `x`.""" + if x < 0: + raise ValueError("`next_power_of_2` requires a non-negative integer.") + return 1 if x == 0 else 2 ** (x - 1).bit_length() def pattern_match_scan_to_fori_loop( @@ -97,44 +119,42 @@ def pattern_match_while_to_fori_loop( cond_nconsts: int, body_jaxpr: jax_core.Jaxpr, body_nconsts: int, -) -> tuple[jax_core.Jaxpr, bool]: +) -> tuple[jax_core.Jaxpr | None, str | None]: # Try to pattern match to fori loop. + # Successful matches produce (jaxpr, None), while failures use the str + # component of the return tuple to capture information about the failure. if cond_nconsts: - raise NotImplementedError("Conditional jaxpr can't contain consts.") + return (None, "Conditional jaxpr can't contain consts.") _, cond_invars = split_list(cond_jaxpr.jaxpr.invars, [cond_nconsts]) cond_in_avals = [v.aval for v in cond_invars] if len(cond_in_avals) < 2: - raise NotImplementedError("Conditional jaxpr have only two carry args.") + return (None, "Conditional jaxpr have only two carry args.") # Check that the first two carry values are scalar ints a1, a2 = cond_in_avals[:2] if a1.shape or a1.dtype not in (jnp.int32, jnp.int64): - raise NotImplementedError( - "First conditional jaxpr carry arg is not a scalar int." - ) + return (None, "First conditional jaxpr carry arg is not a scalar int.") if a2.shape or a2.dtype not in (jnp.int32, jnp.int64): - raise NotImplementedError( - "Second conditional jaxpr carry arg is not a scalar int." 
- ) + return (None, "Second conditional jaxpr carry arg is not a scalar int.") # Check that the only eqn in the cond checks the loop index condition v1, v2 = cond_invars[:2] outvar = cond_jaxpr.jaxpr.outvars[0] assert outvar.aval.dtype == jnp.bool_ if len(cond_jaxpr.jaxpr.eqns) != 1: - raise NotImplementedError("Non-trivial conditional jaxprs not supported.") + return (None, "Non-trivial conditional jaxprs not supported.") eqn = cond_jaxpr.jaxpr.eqns[0] if eqn.primitive != lax.lt_p: - raise NotImplementedError("Non-trivial conditional jaxprs not supported.") + return (None, "Non-trivial conditional jaxprs not supported.") if eqn.outvars != [outvar]: - raise NotImplementedError("Non-trivial conditional jaxprs not supported.") + return (None, "Non-trivial conditional jaxprs not supported.") if eqn.invars != [v1, v2]: - raise NotImplementedError("Non-trivial conditional jaxprs not supported.") + return (None, "Non-trivial conditional jaxprs not supported.") # Check that the carry is updated in the body appropriately _, body_invars = split_list(body_jaxpr.jaxpr.invars, [body_nconsts]) v1, v2 = body_invars[:2] vo1, vo2 = body_jaxpr.jaxpr.outvars[:2] # Upper bound should be constant if v2 is not vo2: - raise NotImplementedError("Loop upper bound is not constant.") + return (None, "Loop upper bound is not constant.") # Check that we increment the loop index in the body for i, eqn in enumerate(body_jaxpr.jaxpr.eqns): if eqn.primitive is lax.add_p: @@ -145,7 +165,7 @@ def pattern_match_while_to_fori_loop( eqn_index = i break else: - raise NotImplementedError("Loop index not incremented in body.") + return (None, "Loop index not incremented in body.") jaxpr = body_jaxpr.jaxpr new_invars = ( *jaxpr.invars[:body_nconsts], @@ -158,4 +178,4 @@ def pattern_match_while_to_fori_loop( invars=new_invars, outvars=new_outvars, ) - return jaxpr + return jaxpr, None diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 05200fd347a7..18e7d18d931d 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -14,7 +14,7 @@ class _UnconstrainedPartitionSingleton: - def __str__(self): + def __repr__(self): return "UNCONSTRAINED" diff --git a/jax/_src/pickle_util.py b/jax/_src/pickle_util.py index e5fc1f989c28..2b82bb66074b 100644 --- a/jax/_src/pickle_util.py +++ b/jax/_src/pickle_util.py @@ -18,7 +18,7 @@ from typing import Any try: - import cloudpickle # type: ignore[import] + import cloudpickle # type: ignore[import-not-found] except ImportError: cloudpickle = None diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index d166894e8613..454611424a02 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -14,21 +14,23 @@ from __future__ import annotations -from collections.abc import Sequence, Iterable +from collections import defaultdict +from collections.abc import Callable, Sequence, Iterable import dataclasses -from functools import partial, lru_cache +from functools import partial import inspect import itertools as it import logging import operator as op import weakref -from typing import Callable, cast, NamedTuple, Any, Union, Optional +from typing import NamedTuple, Any, Union, cast import threading import warnings import numpy as np from jax._src import api +from jax._src import ad_util from jax._src import api_util from jax._src import config from jax._src import core @@ -37,6 +39,7 @@ from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import op_shardings +from jax._src import profiler from jax._src import sharding_impls from jax._src 
import source_info_util from jax._src import stages @@ -53,30 +56,32 @@ from jax._src.interpreters import partial_eval as pe from jax._src.partition_spec import PartitionSpec from jax._src.interpreters import xla - from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect +from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension_version +from jax._src import sharding from jax._src.sharding_impls import ( - NamedSharding, XLACompatibleSharding, GSPMDSharding, - XLADeviceAssignment, SingleDeviceSharding, PmapSharding, - AUTO, UNSPECIFIED, UnspecifiedValue, + NamedSharding, GSPMDSharding, + SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified, is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding) -from jax._src.state import discharge as state_discharge, RefEffect +from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout +from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef from jax._src.traceback_util import api_boundary from jax._src.tree_util import ( - tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, - treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr) + tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, tree_leaves, + treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr, + PyTreeDef, none_leaf_registry as none_lr) from jax._src.util import ( HashableFunction, safe_map, safe_zip, wraps, distributed_debug_log, split_list, weakref_lru_cache, - merge_lists, flatten, unflatten, subs_list2) + merge_lists, flatten, unflatten, subs_list, fun_name) map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -112,13 +117,14 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name, if arg_names is None: arg_names = [''] * len(args_flat) for a, n in zip(args_flat, arg_names): - da = a.sharding._device_assignment if hasattr(a, 'sharding') else None + da = (a.sharding._device_assignment + if getattr(a, 'sharding', None) is not None else None) arg_list.append((n, da, shaped_abstractify(a))) mismatched_args_msg = _find_arg_mismatch(arg_list, fails, fun_name) if len(mismatched_args_msg) == 2: - first, second = mismatched_args_msg # type: ignore + first, second = mismatched_args_msg # pytype: disable=bad-unpacking extra_msg = f" Got {first} and {second}" elif len(mismatched_args_msg) == 1: first, second = fails @@ -132,57 +138,114 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name, return msg -def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs): - args_flat, _, params, _, out_tree, _, _, _, arg_names, attrs_tracked = \ - infer_params_fn(*args, **kwargs) +class PjitInfo(NamedTuple): + """Things that we know about a jit instance before it is called. + + In other words, this structure contains arguments to jit()/pjit(), + preprocessed and validated. + """ + fun_sourceinfo: str | None + fun_signature: inspect.Signature | None + # Shardings, as specified by the user. These can either be UNSPECIFIED or they + # can be a tree (prefix) of shardings or None. + user_specified_in_shardings: bool + in_shardings_treedef: PyTreeDef + in_shardings_leaves: tuple[Any, ...] 
+ out_shardings_treedef: PyTreeDef + out_shardings_leaves: tuple[Any, ...] + in_layouts_treedef: PyTreeDef + in_layouts_leaves: tuple[Any, ...] + out_layouts_treedef: PyTreeDef + out_layouts_leaves: tuple[Any, ...] + static_argnums: tuple[int, ...] + static_argnames: tuple[str, ...] + donate_argnums: tuple[int, ...] + donate_argnames: tuple[str, ...] + device: xc.Device | None + backend: str | None + keep_unused: bool + inline: bool + abstracted_axes: Any | None + has_explicit_sharding: bool + use_resource_env: bool # False for jit, True for pjit + + # Hash and compare PjitInfo by identity when used as a cache key. + def __hash__(self): + return id(self) + + def __eq__(self, other): + return self is other + + +def _python_pjit_helper(fun, jit_info, *args, **kwargs): + p, args_flat = _infer_params(fun, jit_info, args, kwargs) + for arg in args_flat: dispatch.check_arg(arg) - if attrs_tracked: - init_states = _get_states(attrs_tracked) + + if p.attrs_tracked: + init_states = _get_states(p.attrs_tracked) args_flat = [*init_states, *args_flat] + try: - out_flat = pjit_p.bind(*args_flat, **params) + out_flat = pjit_p.bind(*args_flat, **p.params) except pxla.DeviceAssignmentMismatchError as e: fails, = e.args - api_name = 'jit' if params['resource_env'] is None else 'pjit' + api_name = 'jit' if p.params['resource_env'] is None else 'pjit' fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) msg = _device_assignment_mismatch_error( - fun_name, fails, args_flat, api_name, arg_names) + fun_name, fails, args_flat, api_name, p.arg_names) raise ValueError(msg) from None - if attrs_tracked: - final_states, out_flat = split_list(out_flat, [len(attrs_tracked)]) - _set_states(attrs_tracked, final_states) - outs = tree_unflatten(out_tree, out_flat) - return outs, out_flat, out_tree, args_flat, params['jaxpr'], attrs_tracked - -def _set_states(attrs_tracked, vals): - from jax.experimental.attrs import jax_setattr # type: ignore - for ((obj, attr), val) in zip(attrs_tracked, vals): - jax_setattr(obj, attr, val) + except xla.InvalidInputException as e: + arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names + # Run canonicalization again to figure out which arg failed. + if p.params['jaxpr'].consts: + raise TypeError(e.args[0]) from e + else: + for arg, name, aval in zip(args_flat, arg_names, p.in_avals): + try: + xla.canonicalize_dtype(arg) + except xla.InvalidInputException as _: + # Reraise as TypeError with the new message. 
+ raise TypeError( + f"Argument '{name}' of shape {aval.str_short()} of type" + f' {type(arg)} is not a valid JAX type.') from e + raise AssertionError("Unreachable") from e -def _get_states(attrs_tracked): - from jax.experimental.attrs import jax_getattr # type: ignore - return [jax_getattr(obj, attr) for (obj, attr) in attrs_tracked] + if p.attrs_tracked: + num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in p.attrs_tracked) + final_states, out_flat = split_list(out_flat, [num_states_out]) + _set_states(p.attrs_tracked, final_states) + outs = tree_unflatten(p.out_tree, out_flat) + return outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'], p.attrs_tracked -def _python_pjit(fun: Callable, infer_params_fn): - @wraps(fun) - @api_boundary - def wrapped(*args, **kwargs): - if config.disable_jit.value: - return fun(*args, **kwargs) - return _python_pjit_helper(fun, infer_params_fn, *args, **kwargs)[0] - - def _python_pjit_evict_fn(): - _create_pjit_jaxpr.evict_function(fun) # type: ignore - wrapped.clear_cache = _python_pjit_evict_fn - return wrapped +def _set_states(attrs_tracked, vals): + from jax.experimental.attrs import jax_setattr + valss = split_list(vals, [td.num_leaves for _, td, _ in attrs_tracked[:-1]]) + for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss): + val = tree_unflatten(treedef, leaves) + jax_setattr(obj, attr, val) +def _get_states(attrs_tracked): + from jax.experimental.attrs import jax_getattr + vals = [] + for treedef, _, (obj, attr) in attrs_tracked: + tree = jax_getattr(obj, attr) + leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + return vals + +def _need_to_rebuild_with_fdo(pgle_profiler): + return (pgle_profiler is not None and pgle_profiler.is_enabled() + and not pgle_profiler.is_fdo_consumed()) def _get_fastpath_data( - executable, out_tree, args_flat, out_flat, attrs_tracked, effects -) -> Optional[pxla.MeshExecutableFastpathData]: + executable, out_tree, args_flat, out_flat, attrs_tracked, effects, + consts, abstracted_axes, pgle_profiler +) -> pxla.MeshExecutableFastpathData | None: out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat) use_fastpath = ( @@ -194,14 +257,16 @@ def _get_fastpath_data( and not executable.unsafe_call.has_unordered_effects and not executable.unsafe_call.has_host_callbacks and all(isinstance(x, xc.ArrayImpl) for x in out_reflattened) + and abstracted_axes is None # no attr state effects and not attrs_tracked # no ref state effects and not any(isinstance(e, RefEffect) for e in effects) # no prng reuse checking - and not (config.enable_key_reuse_checks.value and any( + and not (config.debug_key_reuse.value and any( hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key) - for arg in (*args_flat, *out_flat))) + for arg in (*args_flat, *out_flat, *consts))) + and not _need_to_rebuild_with_fdo(pgle_profiler) ) if use_fastpath: @@ -210,7 +275,7 @@ def _get_fastpath_data( kept_var_bitvec = [i in executable._kept_var_idx for i in range(len(args_flat))] in_shardings = [ - a.dtype._rules.physical_sharding(a, s) + sharding_impls.physical_sharding(a, s) if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) else s for s, a in zip(executable._in_shardings, executable.in_avals) @@ -228,6 +293,7 @@ def _get_fastpath_data( class _MostRecentPjitCallExecutable(threading.local): def __init__(self): self.weak_key_dict = weakref.WeakKeyDictionary() + self.weak_pgle_profiler_dict = weakref.WeakKeyDictionary() 
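The reworked `_get_states`/`_set_states` above flatten each tracked attribute into pytree leaves before binding and rebuild the value from the recorded treedef afterwards. A small sketch of that round-trip using the public pytree API; the `state` dict is an arbitrary example value:

import jax
import jax.numpy as jnp

state = {"w": jnp.ones((2,)), "b": 0.5}
leaves, treedef = jax.tree_util.tree_flatten(state)       # leaves feed the jaxpr
restored = jax.tree_util.tree_unflatten(treedef, leaves)  # rebuilt after the call
assert jax.tree_util.tree_structure(restored) == treedef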
_most_recent_pjit_call_executable = _MostRecentPjitCallExecutable() @@ -236,9 +302,15 @@ def _read_most_recent_pjit_call_executable(jaxpr): return _most_recent_pjit_call_executable.weak_key_dict.get(jaxpr, None) +def _read_pgle_profiler(jaxpr): + return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get( + jaxpr, None + ) + def _cpp_pjit_evict_fn(self): self._clear_cache() - _create_pjit_jaxpr.evict_function(self._fun) # type: ignore + _create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error + _infer_params_cached.cache_clear() # The entries are doubled here from the default 4096 because _pjit_call_impl @@ -254,31 +326,27 @@ def _get_cpp_global_cache(pjit_has_explicit_sharding): return _cpp_pjit_cache -def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames, - donate_argnums, pjit_has_explicit_sharding): +def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @api_boundary def cache_miss(*args, **kwargs): outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper( - fun, infer_params_fn, *args, **kwargs) + fun, jit_info, *args, **kwargs) executable = _read_most_recent_pjit_call_executable(jaxpr) + pgle_profiler = _read_pgle_profiler(jaxpr) maybe_fastpath_data = _get_fastpath_data( - executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects) - return outs, maybe_fastpath_data - - if xla_extension_version >= 226: - cpp_pjit_f = xc._xla.pjit( # type: ignore - getattr(fun, "__name__", ""), - fun, cache_miss, static_argnums, static_argnames, - donate_argnums, tree_util.dispatch_registry, - pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg, # type: ignore - _get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore - else: - cpp_pjit_f = xc._xla.pjit( # type: ignore - getattr(fun, "__name__", ""), - fun, cache_miss, static_argnums, static_argnames, - donate_argnums, tree_util.dispatch_registry, - _get_cpp_global_cache(pjit_has_explicit_sharding)) + executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects, + jaxpr.consts, jit_info.abstracted_axes, + pgle_profiler) + + return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) + + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), + fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, + jit_info.donate_argnums, tree_util.dispatch_registry, + lambda x, sharding: pxla.shard_args([sharding], [x])[0], + _get_cpp_global_cache(jit_info.has_explicit_sharding)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun @@ -286,10 +354,53 @@ def cache_miss(*args, **kwargs): return cpp_pjitted_f -def pre_infer_params(fun, in_shardings, out_shardings, - donate_argnums, donate_argnames, - static_argnums, static_argnames, device, - backend, abstracted_axes): +def _pjit_explicit_sharding(in_shardings, out_shardings, device, + backend) -> bool: + in_shardings_flat, _ = tree_flatten(in_shardings) + out_shardings_flat, _ = tree_flatten(out_shardings) + return (device is not None or + backend is not None or + any(not is_unspecified(i) for i in in_shardings_flat) or + any(not is_unspecified(i) for i in out_shardings_flat)) + + +def _split_layout_and_sharding(entries): + entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None) + layouts, shardings = [], [] + + for e in entries_flat: + if e is None or is_unspecified_or_auto(e): + layouts.append(None) + shardings.append(e) + elif isinstance(e, Layout): + layouts.append(e.device_local_layout) + shardings.append(e.sharding) + elif isinstance(e, 
(DeviceLocalLayout, AutoLayout)): + raise ValueError( + '`jax.jit` does not accept device-local layouts directly. Create ' + 'a `Layout` instance wrapping this device-local layout and pass ' + f'that to `jit` instead. Got {e}') + else: + layouts.append(None) + shardings.append(e) + + assert len(layouts) == len(shardings) + return tree_unflatten(treedef, layouts), tree_unflatten(treedef, shardings) + + +def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, + donate_argnums: int | Sequence[int] | None, + donate_argnames: str | Iterable[str] | None, + static_argnums: int | Sequence[int] | None, + static_argnames: str | Iterable[str] | None, + device: xc.Device | None, backend: str | None, + abstracted_axes: Any | None, keep_unused: bool, + inline: bool, use_resource_env: bool) -> PjitInfo: + """Parses the arguments to jit/pjit. + + Performs any preprocessing and validation of the arguments that we can do + ahead of time before the jit()-ed function is invoked. + """ if abstracted_axes and not config.dynamic_shapes.value: raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes") @@ -297,11 +408,9 @@ def pre_infer_params(fun, in_shardings, out_shardings, if backend is not None or device is not None: warnings.warn( - 'backend and device argument on jit is deprecated. You can use a ' - '`jax.sharding.Mesh` context manager or device_put the arguments ' - 'before passing them to `jit`. Please see ' - 'https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html ' - 'for more information.', DeprecationWarning) + 'backend and device argument on jit is deprecated. You can use' + ' `jax.device_put(..., jax.local_devices("cpu")[0])` on the inputs to' + ' the jitted function to get the same behavior.', DeprecationWarning) if device is not None and backend is not None: raise ValueError("can't specify both a device and a backend for jit, " f"got {device=} and {backend=}") @@ -320,175 +429,190 @@ def pre_infer_params(fun, in_shardings, out_shardings, # rather than raising an error. 
https://github.com/google/jax/issues/2367 in_shardings = tuple(in_shardings) - in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings') - out_shardings, _, _ = prepare_axis_resources(out_shardings, 'out_shardings') + in_layouts, in_shardings = _split_layout_and_sharding(in_shardings) + out_layouts, out_shardings = _split_layout_and_sharding(out_shardings) + + in_shardings = prepare_axis_resources(in_shardings, 'in_shardings') + out_shardings = prepare_axis_resources(out_shardings, 'out_shardings') + + user_specified_in_shardings = (in_shardings is not None and + not is_unspecified(in_shardings)) + + in_shardings_leaves, in_shardings_treedef = none_lr.flatten(in_shardings) + out_shardings_leaves, out_shardings_treedef = none_lr.flatten(out_shardings) + in_layouts_leaves, in_layouts_treedef = none_lr.flatten(in_layouts) + out_layouts_leaves, out_layouts_treedef = none_lr.flatten(out_layouts) + + fun_sourceinfo = api_util.fun_sourceinfo(fun) + fun_signature = api_util.fun_signature(fun) donate_argnums, donate_argnames, static_argnums, static_argnames = resolve_argnums( - fun, donate_argnums, donate_argnames, static_argnums, static_argnames) + fun, fun_signature, donate_argnums, donate_argnames, static_argnums, + static_argnames) - return (in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames) + has_explicit_sharding = _pjit_explicit_sharding( + in_shardings, out_shardings, device, backend) + + return PjitInfo( + fun_sourceinfo=fun_sourceinfo, + fun_signature=fun_signature, + user_specified_in_shardings=user_specified_in_shardings, + in_shardings_treedef=in_shardings_treedef, + in_shardings_leaves=tuple(in_shardings_leaves), + out_shardings_treedef=out_shardings_treedef, + out_shardings_leaves=tuple(out_shardings_leaves), + in_layouts_treedef=in_layouts_treedef, + in_layouts_leaves=tuple(in_layouts_leaves), + out_layouts_treedef=out_layouts_treedef, + out_layouts_leaves=tuple(out_layouts_leaves), + static_argnums=static_argnums, + static_argnames=static_argnames, donate_argnums=donate_argnums, + donate_argnames=donate_argnames, device=device, backend=backend, + keep_unused=keep_unused, inline=inline, + abstracted_axes=abstracted_axes, + has_explicit_sharding=has_explicit_sharding, + use_resource_env=use_resource_env) -def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames, - donate_argnums, abstracted_axes, - pjit_has_explicit_sharding): - if abstracted_axes is None: - wrapped = _cpp_pjit(fun, infer_params_fn, static_argnums, static_argnames, - donate_argnums, pjit_has_explicit_sharding) - else: - wrapped = _python_pjit(fun, infer_params_fn) +def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo): @api_boundary def lower(*args, **kwargs): - lowering_parameters = kwargs.pop( - '_experimental_lowering_parameters', mlir.LoweringParameters()) - # TODO(yashkatariya): Remove this when it's added on jit. 
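The reworded deprecation warning above steers users away from jit's `backend`/`device` arguments toward explicit placement with `jax.device_put`. A short sketch of that replacement, assuming a CPU backend is available; the lambda is a stand-in for a real jitted function:

import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
# Instead of jax.jit(f, backend="cpu"), commit the inputs to a device:
x_cpu = jax.device_put(x, jax.local_devices(backend="cpu")[0])
y = jax.jit(lambda v: v * 2)(x_cpu)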
- in_layouts = kwargs.pop('_in_layouts', None) - out_layouts = kwargs.pop('_out_layouts', None) - (args_flat, flat_global_in_avals, params, in_tree, out_tree, - donated_invars, in_layouts_flat, out_layouts_flat, - arg_names, ()) = infer_params_fn( - *args, **kwargs, _in_layouts=in_layouts, _out_layouts=out_layouts) - resource_env = params['resource_env'] - mesh = None if resource_env is None else resource_env.physical_mesh + traced = trace(*args, **kwargs) try: - in_shardings = _resolve_in_shardings( - args_flat, params['in_shardings'], params['out_shardings'], mesh) - lowering = _pjit_lower( - params['jaxpr'], in_shardings, params['out_shardings'], - params['resource_env'], params['donated_invars'], params['name'], - params['keep_unused'], params['inline'], in_layouts=in_layouts_flat, - out_layouts=out_layouts_flat, lowering_parameters=lowering_parameters) + return traced.lower() except pxla.DeviceAssignmentMismatchError as e: fails, = e.args - api_name = 'jit' if params['resource_env'] is None else 'pjit' - fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) + fun_name = getattr(fun, '__qualname__', + getattr(fun, '__name__', str(fun))) msg = _device_assignment_mismatch_error( - fun_name, fails, args_flat, api_name, arg_names) + fun_name, fails, traced._args_flat, 'jit', traced._arg_names) raise ValueError(msg) from None - donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d) - return stages.Lowered.from_flat_info( - lowering, in_tree, flat_global_in_avals, donate_argnums, - out_tree) - @api_boundary def eval_shape(*args, **kwargs): - _, _, params, _, out_tree, _, _, _, _, _ = infer_params_fn( - *args, **kwargs, _in_layouts=None, _out_layouts=None) - out_s = [None if is_unspecified(s) else getattr(s, '_original_sharding', s) - for s in params['out_shardings']] + p, _ = _infer_params(fun, jit_info, args, kwargs) + out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']] + # TODO(yashkatariya): Add `Layout` to SDS. out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s) - for x, s in zip(params['jaxpr'].out_avals, out_s)] - return tree_unflatten(out_tree, out) + for x, s in zip(p.params['jaxpr'].out_avals, out_s)] + return tree_unflatten(p.out_tree, out) + @api_boundary + def trace(*args, **kwargs) -> stages.Traced: + p, args_flat = _infer_params(fun, jit_info, args, kwargs) + donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d) + args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums) + lower_callable = partial(_resolve_and_lower, args_flat, **p.params, + pgle_profiler=None) + return stages.Traced( + p.params['jaxpr'], args_info, p.params["name"],p.out_tree, + lower_callable, args_flat, p.arg_names, p.num_consts) + + wrapped = _cpp_pjit(fun, jit_info) wrapped.lower = lower wrapped.eval_shape = eval_shape + wrapped.trace = trace return wrapped -def _pjit_explicit_sharding(in_shardings, out_shardings, device, - backend) -> bool: - in_shardings_flat, _ = tree_flatten(in_shardings) - out_shardings_flat, _ = tree_flatten(out_shardings) - return (device is not None or - backend is not None or - any(not is_unspecified(i) for i in in_shardings_flat) or - any(not is_unspecified(i) for i in out_shardings_flat)) - - -class PjitInfo(NamedTuple): - fun: Callable - fun_sourceinfo: str | None - fun_signature: inspect.Signature - in_shardings: Any - out_shardings: Any - static_argnums: tuple[int, ...] - static_argnames: tuple[str, ...] - donate_argnums: tuple[int, ...] 
- donate_argnames: tuple[str, ...] - device: xc.Device | None - backend: str | None - keep_unused: bool - inline: bool - resource_env: Any - abstracted_axes: Any | None - in_layouts: Any # pytree[XlaCompatibleLayout] | None - out_layouts: Any # pytree[XlaCompatibleLayout] | None - - -def common_infer_params(pjit_info_args, *args, **kwargs): - (fun, fun_sourceinfo, fun_signature, user_in_shardings, user_out_shardings, - static_argnums, static_argnames, - donate_argnums, donate_argnames, device, backend, keep_unused, inline, - resource_env, abstracted_axes, in_layouts, out_layouts) = pjit_info_args - - if (kwargs and user_in_shardings is not None and - not is_unspecified(user_in_shardings)): +def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any, + donate_argnums: int | Sequence[int] | None, + donate_argnames: str | Iterable[str] | None, + static_argnums: int | Sequence[int] | None, + static_argnames: str | Iterable[str] | None, + device: xc.Device | None, backend: str | None, + abstracted_axes: Any | None, keep_unused: bool, + inline: bool, use_resource_env: bool) -> Any: + """jit() and pjit() are thin wrappers around this function.""" + jit_info = _parse_jit_arguments( + fun, in_shardings, out_shardings, donate_argnums, donate_argnames, + static_argnums, static_argnames, device, backend, abstracted_axes, + keep_unused, inline, use_resource_env) + return _make_jit_wrapper(fun, jit_info) + + +class PjitParams(NamedTuple): + consts: list[Any] # Only jaxpr constants, we can't keep other arguments alive + params: dict[str, Any] + in_avals: tuple[core.AbstractValue, ...] + in_tree: PyTreeDef + out_tree: PyTreeDef + donated_invars: tuple[bool, ...] + arg_names: tuple[str, ...] | None + num_consts: int + attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] + + +def _infer_params_impl( + fun: Callable, + ji: PjitInfo, + pjit_mesh: mesh_lib.Mesh | None, + resource_env: mesh_lib.ResourceEnv | None, + args: tuple[Any, ...], + kwargs: dict[str, Any], + in_avals: tuple[core.AbstractValue, ...] 
| None, +) -> tuple[PjitParams, list[Any]]: + have_kwargs = bool(kwargs) + if have_kwargs and ji.user_specified_in_shardings: raise ValueError( "pjit does not support kwargs when in_shardings is specified.") - if resource_env is not None: - pjit_mesh = resource_env.physical_mesh + if pjit_mesh is not None: + jit_name = 'pjit' + if (ji.backend or ji.device) and not pjit_mesh.empty: + raise ValueError( + "Mesh context manager should not be used with jit when backend or " + "device is also specified as an argument to jit.") else: - pjit_mesh = None + jit_name = 'jit' - if (backend or device) and pjit_mesh is not None and not pjit_mesh.empty: - raise ValueError( - "Mesh context manager should not be used with jit when backend or " - "device is also specified as an argument to jit.") + axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs) - axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs) - - jit_name = 'jit' if resource_env is None else 'pjit' - - dbg = debug_info(jit_name, fun_sourceinfo, fun_signature, args, kwargs, - static_argnums, static_argnames) + dbg = debug_info(jit_name, ji.fun_sourceinfo, ji.fun_signature, args, kwargs, + ji.static_argnums, ji.static_argnames) f = lu.wrap_init(fun) f, res_paths = result_paths(f) - f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=True) + f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True) del args - f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs) + f, dyn_kwargs = argnames_partial_except(f, ji.static_argnames, kwargs) explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs)) flat_fun, out_tree = flatten_fun(f, in_tree) flat_fun, explicit_args = hoist_obj_attrs(flat_fun, explicit_args) - if (donate_argnums or donate_argnames) and not config.debug_nans.value: - donated_invars = donation_vector( - donate_argnums, donate_argnames, dyn_args, dyn_kwargs) + if (ji.donate_argnums or ji.donate_argnames) and not config.debug_nans.value: + donated_invars = donation_vector(ji.donate_argnums, ji.donate_argnames, in_tree) else: donated_invars = (False,) * len(explicit_args) - del donate_argnums, donate_argnames # If backend or device is set as an arg on jit, then resolve them to # in_shardings and out_shardings as if user passed in in_shardings # and out_shardings. 
- device_or_backend_set = False - if backend or device: - in_shardings = out_shardings = _create_sharding_with_device_backend( - device, backend) - device_or_backend_set = True + device_or_backend_set = bool(ji.backend or ji.device) + if device_or_backend_set: + sharding = _create_sharding_with_device_backend(ji.device, ji.backend) + leaves, treedef = tree_flatten(sharding) + in_shardings_leaves = out_shardings_leaves = tuple(leaves) + in_shardings_treedef = out_shardings_treedef = treedef else: - in_shardings = tree_map( - lambda x: _create_sharding_for_array(pjit_mesh, x, 'in_shardings', - jit_name), - user_in_shardings, is_leaf=lambda x: x is None) - out_shardings = tree_map( - lambda x: _create_sharding_for_array(pjit_mesh, x, 'out_shardings', - jit_name), - user_out_shardings, is_leaf=lambda x: x is None) - - del user_in_shardings, user_out_shardings - - assert in_shardings is not None or all(i is not None for i in in_shardings) - assert out_shardings is not None or all(o is not None for o in out_shardings) - + in_shardings_leaves = tuple( + _create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name) + for x in ji.in_shardings_leaves) + in_shardings_treedef = ji.in_shardings_treedef + out_shardings_leaves = tuple( + _create_sharding_for_array(pjit_mesh, x, 'out_shardings', jit_name) + for x in ji.out_shardings_leaves) + out_shardings_treedef = ji.out_shardings_treedef + + assert None not in in_shardings_leaves + assert None not in out_shardings_leaves + + in_type: core.InputType | tuple[core.AbstractValue, ...] if config.dynamic_shapes.value: in_type = pe.infer_lambda_input_type(axes_specs, explicit_args) in_avals = tuple(a for a, e in in_type if e) - else: + elif in_avals is None: avals = [] for i, a in enumerate(explicit_args): try: @@ -501,46 +625,127 @@ def common_infer_params(pjit_info_args, *args, **kwargs): f"computation, whose {arg_path}." 
) from e in_type = in_avals = tuple(avals) - - canonicalized_in_shardings_flat, in_layouts_flat = _process_in_axis_resources( - hashable_pytree(in_shardings), hashable_pytree(in_layouts), in_avals, - in_tree, resource_env, dbg, device_or_backend_set, True if kwargs else False) - - jaxpr, consts, out_shardings, out_layouts_flat, attrs_tracked = _pjit_jaxpr( - flat_fun, hashable_pytree(out_shardings), hashable_pytree(out_layouts), - in_type, dbg, device_or_backend_set, HashableFunction(out_tree, closure=()), - HashableFunction(res_paths, closure=()), inline) - - assert len(explicit_args) == len(canonicalized_in_shardings_flat) == len(in_layouts_flat) + else: + in_type = in_avals + + in_shardings_flat, in_layouts_flat = _process_in_axis_resources( + in_shardings_treedef, in_shardings_leaves, + ji.in_layouts_treedef, ji.in_layouts_leaves, + in_avals, in_tree, dbg, device_or_backend_set, have_kwargs) + + attr_token = _attr_token(flat_fun, in_type) + jaxpr, consts, out_type, attrs_tracked = _create_pjit_jaxpr( + flat_fun, in_type, attr_token, dbg, + HashableFunction(res_paths, closure=()), + IgnoreKey(ji.inline)) + _attr_update(flat_fun, in_type, attr_token, attrs_tracked) + out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( + out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef, + ji.out_layouts_leaves, HashableFunction(out_tree, closure=()), + tuple(out_type), jaxpr.jaxpr.debug_info, device_or_backend_set) + + assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat) if config.dynamic_shapes.value: - implicit_args = _extract_implicit_args(in_type, explicit_args) + implicit_args = _extract_implicit_args( + cast(core.InputType, in_type), explicit_args) else: implicit_args = [] args_flat = [*implicit_args, *explicit_args] - num_extra_args = len(implicit_args) + len(attrs_tracked) + len(consts) - canonicalized_in_shardings_flat = \ - (UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat + num_states_in = sum(init_tree.num_leaves for init_tree, _, _ in attrs_tracked) + num_extra_args = len(implicit_args) + num_states_in + len(consts) + in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat in_layouts_flat = (None,) * num_extra_args + in_layouts_flat donated_invars = (False,) * num_extra_args + donated_invars - assert (len(canonicalized_in_shardings_flat) == len(in_layouts_flat) == - len(donated_invars) == len(attrs_tracked) + len(consts) + len(args_flat)) + assert (len(in_shardings_flat) == len(in_layouts_flat) == + len(donated_invars) == num_states_in + len(consts) + len(args_flat)) - # in_shardings and out_shardings here are all GSPMDSharding. 
params = dict( jaxpr=jaxpr, - in_shardings=canonicalized_in_shardings_flat, - out_shardings=out_shardings, + in_shardings=in_shardings_flat, + out_shardings=out_shardings_flat, + in_layouts=in_layouts_flat, + out_layouts=out_layouts_flat, resource_env=resource_env, donated_invars=donated_invars, - name=getattr(flat_fun, '__name__', ''), - keep_unused=keep_unused, - inline=inline, + name=fun_name(flat_fun), + keep_unused=ji.keep_unused, + inline=ji.inline, ) - return (consts + args_flat, in_type, params, in_tree, out_tree(), - donated_invars, in_layouts_flat, out_layouts_flat, - dbg.arg_names if dbg else None, attrs_tracked) + return PjitParams(consts, params, in_avals, in_tree, out_tree(), + donated_invars, dbg.arg_names if dbg else None, len(consts), + attrs_tracked), args_flat + + + +class InferParamsCacheEntry: + """Mutable value object for _infer_params_cached.""" + __slots__ = ['pjit_params'] + + pjit_params: PjitParams | None + + def __init__(self): + self.pjit_params = None + + +# We use an outer cache that is keyed on the signature of the arguments, but +# when populating a cache entry using _infer_params_impl, we need to provide +# actual arguments. In principle we could refactor _infer_params_impl to look +# only at an argument signature instead of args/kwargs in those cases that we +# cache, but this was a more minimal change. +@util.weakref_lru_cache +def _infer_params_cached( + fun: Callable, + jit_info: PjitInfo, + signature: jax_jit.ArgumentSignature, + in_avals: tuple[core.AbstractValue, ...], + pjit_mesh: mesh_lib.Mesh | None, + resource_env: mesh_lib.ResourceEnv | None, +) -> InferParamsCacheEntry: + return InferParamsCacheEntry() + + +def _infer_params( + fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> tuple[PjitParams, list[Any]]: + if ji.use_resource_env: + # We need to fetch the mesh from inside the wrapped function, because + # meshes are dynamically scoped (i.e., with a context manager). + resource_env = mesh_lib.thread_resources.env + pjit_mesh = resource_env.physical_mesh + else: + resource_env = None + pjit_mesh = None + + skip_cache = xla_extension_version < 273 or config.dynamic_shapes.value + if not skip_cache: + signature, dynargs = jax_jit.parse_arguments( + args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums, + ji.static_argnames, tree_util.default_registry) + try: + avals = tuple(shaped_abstractify(a) for a in dynargs) + except (OverflowError, TypeError): + # If we see something we don't understand, use the slow path. + skip_cache = True + + if skip_cache: + p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, args, + kwargs, in_avals=None) + return p, p.consts + args_flat + + entry = _infer_params_cached( + fun, ji, signature, avals, pjit_mesh, resource_env) + if entry.pjit_params is None: + p, args_flat = _infer_params_impl( + fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals) + if p.attrs_tracked: + # If there are attrs_tracked, don't use the cache. 
+ return p, p.consts + args_flat + else: + entry.pjit_params = p + return entry.pjit_params, entry.pjit_params.consts + dynargs + def _extract_implicit_args( in_type: Sequence[tuple[core.AbstractValue, bool]], @@ -570,7 +775,7 @@ def _extract_implicit_args( args[d1.val] = d2 assert core.same_referent(args[d1.val], d2) assert all(x is not None for x in args) - return [x for x, (_, e) in zip(args, in_type) if not e] # type: ignore + return [x for x, (_, e) in zip(args, in_type) if not e] # pytype: disable=bad-return-type def _flat_axes_specs(abstracted_axes, *args, **kwargs ) -> list[pe.AbstractedAxesSpec] | None: @@ -588,6 +793,9 @@ def eval_shape(self, *args, **kwargs): """See ``jax.eval_shape``.""" raise NotImplementedError + def trace(self, *args, **kwargs) -> stages.Traced: + raise NotImplementedError + # in_shardings and out_shardings can't be None as the default value # because `None` means that the input is fully replicated. @@ -638,19 +846,10 @@ def pjit( processes run the same :func:`~pjit`'d function in the same order. When running in this configuration, the mesh should contain devices across - all processes. However, any input argument dimensions partitioned over - multi-process mesh axes should be of size equal to the corresponding *local* - mesh axis size, and outputs will be similarly sized according to the local - mesh. ``fun`` will still be executed across *all* devices in the mesh, + all processes. All inputs arguments must be globally shaped. + ``fun`` will still be executed across *all* devices in the mesh, including those from other processes, and will be given a global view of the - data spread across multiple processes as a single array. However, outside - of :func:`~pjit` every process only "sees" its local piece of the input and output, - corresponding to its local sub-mesh. - - This means that each process's participating local devices must form a - _contiguous_ local sub-mesh within the full global mesh. A contiguous - sub-mesh is one where all of its devices are adjacent within the global - mesh, and form a rectangular prism. + data spread across multiple processes as a single array. The SPMD model also requires that the same multi-process :func:`~pjit`'d functions must be run in the same order on all processes, but they can be @@ -676,7 +875,7 @@ def pjit( The valid resource assignment specifications are: - - :py:class:`XLACompatibleSharding`, which will decide how the value + - :py:class:`Sharding`, which will decide how the value will be partitioned. With this, using a mesh context manager is not required. - :py:obj:`None` is a special case whose semantics are: @@ -782,39 +981,10 @@ def pjit( ... print(f(x)) # doctest: +SKIP [ 0.5 2. 4. 6. 8. 10. 12. 10. ] """ - (in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, - static_argnames) = pre_infer_params( + return make_jit( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes) - - fun_sourceinfo = api_util.fun_sourceinfo(fun) - fun_signature = api_util.fun_signature(fun) - - def infer_params(*args, **kwargs): - # Putting this outside of wrapped would make resources lexically scoped - resource_env = mesh_lib.thread_resources.env - # TODO(yashkatariya): Remove this when it's added on jit. Also default to - # layout.DefaultLayout() when out of experimental. 
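The `_infer_params` caching added above keys an outer weakref LRU cache on the hashable argument signature and lazily fills a mutable entry object from the concrete arguments on the first miss. A minimal sketch of that two-level pattern, with `functools.lru_cache` standing in for `util.weakref_lru_cache`; all names here are illustrative:

import functools

class _Entry:
    # Mutable value object, in the spirit of InferParamsCacheEntry.
    __slots__ = ["value"]
    def __init__(self):
        self.value = None

@functools.lru_cache(maxsize=4096)
def _entry_for(signature):
    return _Entry()

def cached_infer(signature, compute):
    entry = _entry_for(signature)
    if entry.value is None:
        entry.value = compute()   # slow path runs once per distinct signature
    return entry.value

assert cached_infer(("f32[3]",), lambda: "params A") == "params A"
assert cached_infer(("f32[3]",), lambda: "params B") == "params A"  # cache hit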
- in_layouts = kwargs.pop('_in_layouts', None) - out_layouts = kwargs.pop('_out_layouts', None) - pjit_info_args = PjitInfo( - fun=fun, - fun_sourceinfo=fun_sourceinfo, - fun_signature=fun_signature, - in_shardings=in_shardings, - out_shardings=out_shardings, static_argnums=static_argnums, - static_argnames=static_argnames, donate_argnums=donate_argnums, - donate_argnames=donate_argnames, device=device, backend=backend, - keep_unused=keep_unused, inline=inline, resource_env=resource_env, - abstracted_axes=abstracted_axes, in_layouts=in_layouts, - out_layouts=out_layouts) - return common_infer_params(pjit_info_args, *args, **kwargs) - - has_explicit_sharding = _pjit_explicit_sharding( - in_shardings, out_shardings, device, backend) - return post_infer_params(fun, infer_params, static_argnums, static_argnames, - donate_argnums, abstracted_axes, - has_explicit_sharding) + static_argnums, static_argnames, device, backend, abstracted_axes, + keep_unused, inline, use_resource_env=True) def hashable_pytree(pytree): @@ -827,10 +997,10 @@ def hashable_pytree(pytree): def _create_sharding_for_array(mesh, x, name, api_name): if x is None and (mesh is None or mesh.empty): return UNSPECIFIED - if isinstance(x, XLACompatibleSharding) or is_unspecified_or_auto(x): + if isinstance(x, sharding.Sharding) or is_unspecified_or_auto(x): return x if mesh is None: - msg = ('jax.jit only supports `XLACompatibleSharding`s being passed to' + msg = ('jax.jit only supports `Sharding`s being passed to' f' {name}. Looks like you are passing either `PartitionSpec` or `None`' f' which is not allowed in jax.jit.\n') if name == 'in_shardings': @@ -846,7 +1016,7 @@ def _create_sharding_for_array(mesh, x, name, api_name): raise RuntimeError( f'{api_name} requires a non-empty mesh if you are passing' f' `PartitionSpec`s or `None` to {name}! Is a mesh defined at the call' - f' site? Alternatively, provide `XLACompatibleSharding`s to {name} and' + f' site? Alternatively, provide `Sharding`s to {name} and' ' then the mesh context manager is not required.') # A nice user error is raised in prepare_axis_resources. assert x is None or isinstance(x, ParsedPartitionSpec), x @@ -861,6 +1031,9 @@ def _create_sharding_with_device_backend(device, backend): elif backend is not None: assert device is None out = SingleDeviceSharding(xb.get_backend(backend).local_devices()[0]) + else: + raise AssertionError('Unreachable!') + out._device_backend = True return out @@ -922,14 +1095,15 @@ class PytreeLeaf: def __repr__(self): return "pytree leaf" -@lru_cache(maxsize=4096) -def _process_in_axis_resources(in_shardings_thunk, in_layouts_thunk, in_avals, - in_tree, resource_env, debug_info, +@util.cache(max_size=4096, trace_context_in_key=False) +def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, + in_layouts_treedef, in_layouts_leaves, + in_avals, in_tree, debug_info, device_or_backend_set, kws): if not kws: in_tree, _ = treedef_children(in_tree) - orig_in_shardings = in_shardings_thunk() + orig_in_shardings = tree_unflatten(in_shardings_treedef, in_shardings_leaves) # Only do this if original in_shardings are unspecified. If it is AUTO, go # via flatten_axis_resources. 
if is_unspecified(orig_in_shardings): @@ -938,7 +1112,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_layouts_thunk, in_avals, in_shardings_flat = flatten_axis_resources( "pjit in_shardings", in_tree, orig_in_shardings, tupled_args=True) - in_layouts = in_layouts_thunk() + in_layouts = tree_unflatten(in_layouts_treedef, in_layouts_leaves) if in_layouts is None: in_layouts_flat = (in_layouts,) * len(in_avals) else: @@ -951,11 +1125,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_layouts_thunk, in_avals, pjit_check_aval_sharding(in_shardings_flat, in_avals, None if debug_info is None else debug_info.arg_names, "pjit arguments", allow_uneven_sharding=False) - canonicalized_shardings = tuple( - i if is_unspecified_or_auto(i) else - to_gspmd_sharding(i, aval.ndim, device_or_backend_set) - for i, aval in zip(in_shardings_flat, in_avals)) - return canonicalized_shardings, tuple(in_layouts_flat) + return in_shardings_flat, in_layouts_flat callsites: set[str] = set() @@ -964,9 +1134,9 @@ def explain_tracing_cache_miss( if config.check_tracer_leaks.value: return def unpack(key): - transforms, (), _, (in_type, debug_info, _, inline), *_, ctx = key + transforms, (), _, (in_type, _, debug_info, _, inline), *_, ctx = key # TODO(dougalm,mattjj): enable cache miss explanation with attrs - _, (_, (in_tree,)), (_, ()) = transforms + _, (_, (in_tree,)), *_ = transforms return in_tree, in_type, debug_info, inline.val, ctx in_tree, in_type, debug_info, inline, ctx = unpack(key) if inline: return @@ -1010,7 +1180,9 @@ def unpack(key): f" {', '.join(map(repr, kwarg_keys))}") dont_match = [set(t[1].node_data()[1]) for t in args_kwargs_trees # type: ignore if t != [args_tree, kwargs_tree]] - close_kwargs = min(dont_match, key=set(kwarg_keys).symmetric_difference) + close_kwargs = min( + dont_match, key=set(kwarg_keys).symmetric_difference, default=None + ) if not close_kwargs: p(" closest seen is passing no keyword args") else: @@ -1046,7 +1218,7 @@ def unpack(key): for path, thing1, thing2, explanation in errs: fst, *path = path # type: ignore base = ['args', 'kwargs'][fst.idx] - p(f" * at {base}{keystr(path)}, seen {thing2} but now given {thing1}," # type: ignore + p(f" * at {base}{keystr(tuple(path))}, seen {thing2} but now given {thing1}," f" so {explanation}") return done() @@ -1087,9 +1259,16 @@ def unpack(key): p("explanation unavailable! 
please open an issue at https://github.com/google/jax") return done() - @partial(lu.cache, explain=explain_tracing_cache_miss) -def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline): +def _create_pjit_jaxpr( + fun: lu.WrappedFun, + in_type: core.InputType | Sequence[core.AbstractValue], + attr_data: int, + debug_info: lu.TracingDebugInfo, + out_paths: Callable, + ignored_inline: IgnoreKey +) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue], + list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: del ignored_inline # just for explain_cache_miss with dispatch.log_elapsed_time( "Finished tracing + transforming {fun_name} for pjit in {elapsed_time} sec", @@ -1097,17 +1276,18 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline): pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for) if config.dynamic_shapes.value: jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2( - lu.annotate(fun, in_type), debug_info=pe_debug) + lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug) attrs_tracked = [] else: jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic( fun, in_type, debug_info=pe_debug) + # assert attr_data is sentinel or attr_data matches attrs_tracked # TODO(dougalm,mattjj): enable debug info with attrs_tracked if not config.dynamic_shapes.value and not attrs_tracked: jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths()) - if config.enable_key_reuse_checks.value: + if config.debug_key_reuse.value: # Import here to avoid circular imports from jax.experimental.key_reuse._core import check_key_reuse_jaxpr check_key_reuse_jaxpr(jaxpr) @@ -1121,24 +1301,20 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline): return closed_jaxpr, final_consts, global_out_avals, attrs_tracked -@lru_cache(maxsize=4096) +@util.cache(max_size=4096, trace_context_in_key=False) def _check_and_canonicalize_out_shardings( - out_shardings_thunk, out_layouts_thunk, out_tree, out_type, debug_info, - device_or_backend_set): - orig_out_shardings = out_shardings_thunk() - # TODO(yashkatariya): Remove the if branch and fix flatten_axis_resources - # instead. This condition exists because flatten_axis_resources passes in an - # `object()` while unflattening which breaks assertion is user defined - # pytrees (which shouldn't exist but they do). 
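One detail worth noting above: `explain_tracing_cache_miss` now passes `default=None` to `min`, because `dont_match` can be empty and `min` on an empty sequence raises `ValueError`. A tiny illustration of the behavior the fix guards against:

assert min([], key=len, default=None) is None   # with default: no error
try:
    min([], key=len)                            # without default: ValueError
except ValueError:
    pass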
+ out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, + out_layouts_leaves, out_tree, out_type, debug_info, device_or_backend_set): + orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves) if (is_unspecified(orig_out_shardings) or - isinstance(orig_out_shardings, XLACompatibleSharding)): + isinstance(orig_out_shardings, sharding.Sharding)): out_shardings_flat = (orig_out_shardings,) * len(out_type) else: out_shardings_flat = flatten_axis_resources( "pjit out_shardings", out_tree(), orig_out_shardings, tupled_args=False) - out_layouts = out_layouts_thunk() + out_layouts = tree_unflatten(out_layouts_treedef, out_layouts_leaves) if out_layouts is None: out_layouts_flat = (out_layouts,) * len(out_type) else: @@ -1150,23 +1326,46 @@ def _check_and_canonicalize_out_shardings( out_shardings_flat, out_type, None if debug_info is None else debug_info.result_paths, "pjit outputs", allow_uneven_sharding=False) - - canonicalized_out_shardings_flat = tuple( - o if is_unspecified(o) or is_auto(o) else - to_gspmd_sharding(o, aval.ndim, device_or_backend_set) - for o, aval in zip(out_shardings_flat, out_type) - ) - return canonicalized_out_shardings_flat, tuple(out_layouts_flat) - - -def _pjit_jaxpr(fun, out_shardings_thunk, out_layouts_thunk, in_type, debug_info, - device_or_backend_set, out_tree, result_paths, inline): - jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr( - fun, in_type, debug_info, result_paths, IgnoreKey(inline)) - canonicalized_out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( - out_shardings_thunk, out_layouts_thunk, out_tree, tuple(out_type), - jaxpr.jaxpr.debug_info, device_or_backend_set) - return jaxpr, final_consts, canonicalized_out_shardings_flat, out_layouts_flat, attrs_tracked + return out_shardings_flat, out_layouts_flat + + +AttrRecord = tuple[object, str, PyTreeDef, list[core.AbstractValue]] +_seen_attrs = weakref.WeakKeyDictionary() # type: ignore + +def seen_attrs_get( + fun: lu.WrappedFun, + in_type: core.InputType | tuple[core.AbstractValue, ...] +) -> list: + cache = _seen_attrs.setdefault(fun.f, defaultdict(list)) + assert fun.in_type is None or fun.in_type == in_type + return cache[(fun.transforms, fun.params, in_type)] + +def _attr_token( + fun: lu.WrappedFun, + in_type: core.InputType | tuple[core.AbstractValue, ...] 
+) -> int: + from jax.experimental.attrs import jax_getattr + cases = seen_attrs_get(fun, in_type) + for i, records in enumerate(cases): + for obj, attr, treedef, avals in records: + val = jax_getattr(obj, attr) + vals, treedef_ = tree_flatten(val) + avals_ = map(shaped_abstractify, vals) + if treedef != treedef_ or avals != avals_: break + else: + return i + return len(cases) + +def _attr_update(fun, in_type, i, attrs_tracked): + from jax.experimental.attrs import jax_getattr + leaves = lambda obj, attr: tree_leaves(jax_getattr(obj, attr)) + records = [(obj, attr, init_tree, map(shaped_abstractify, leaves(obj, attr))) + for init_tree, _, (obj, attr) in attrs_tracked] + cases = seen_attrs_get(fun, in_type) + if i == len(cases): + cases.append(records) + else: + assert i < len(cases) and cases[i] == records @dataclasses.dataclass(frozen=True) @@ -1185,14 +1384,13 @@ def pjit_check_aval_sharding( for aval, s, name in zip(flat_avals, shardings, new_names): if is_unspecified_or_auto(s): continue - s = getattr(s, '_original_sharding', s) name_str = f' with pytree key path {name}' if name else '' shape = aval.shape try: - # Sharding interfaces can implement `is_compatible_aval` as an optional + # Sharding interfaces can implement `check_compatible_aval` as an optional # method to raise a more meaningful error. - if hasattr(s, 'is_compatible_aval'): - s.is_compatible_aval(shape) + if hasattr(s, 'check_compatible_aval'): + s.check_compatible_aval(shape) else: s._to_xla_hlo_sharding(len(shape)) except ValueError as e: @@ -1201,7 +1399,7 @@ def pjit_check_aval_sharding( f'annotation {s}: {e}') # Use the `OpSharding` proto to find out how many ways each dimension of # the aval is sharded. This approach will work across all - # XLACompatibleSharding. + # Sharding. hlo_sharding = s._to_xla_hlo_sharding(len(shape)) assert hlo_sharding is not None num_ways_dim_sharded, _ = op_shardings.get_num_ways_dim_sharded(hlo_sharding) @@ -1220,6 +1418,56 @@ def pjit_check_aval_sharding( pjit_p.multiple_results = True +def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): + # If device or backend is set, return the default layout. This is because you + # can pass arrays on cpu (with untiled layouts) to jit with backend='tpu' + # which causes error checks to fail. Returning the default layout allows + # this to exist. It's the same for handling shardings. + if pxla.check_device_backend_on_shardings(resolved_in_shardings): + return (None,) * len(jit_in_layouts) + + resolved_in_layouts = [] + for arg, jit_in_l, rs, aval in safe_zip( + args, jit_in_layouts, resolved_in_shardings, in_avals): + arg_layout, committed = ( + pxla._maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l, rs, aval), + getattr(arg, '_committed', True)) + # Sharding can be unspecified when array is committed if it's a PmapSharding. + is_pmap_sharding = (is_unspecified(rs) or + isinstance(getattr(arg, 'sharding', None), PmapSharding)) + if jit_in_l is None: + if committed: + if is_pmap_sharding: + resolved_in_layouts.append(None) + else: + resolved_in_layouts.append(arg_layout) + else: + resolved_in_layouts.append(None) + else: + # arg_layout can be None because some backends don't implement the + # required layout methods. 
Hence `arr.layout` can return + # `Layout(None, sharding)` + if (committed and not is_pmap_sharding and + arg_layout is not None and arg_layout != jit_in_l): + extra_msg = '' + if isinstance(jit_in_l, AutoLayout): + extra_msg = ( + ' The layout given to `jax.jit` is `DeviceLocalLayout.AUTO` but' + ' the corresponding argument passed is a `jax.Array` with a' + ' concrete layout. Consider passing a `jax.ShapeDtypeStruct`' + ' instead of `jax.Array` as an argument to the jitted function ' + ' when using `DeviceLocalLayout.AUTO`.' + ) + raise ValueError('Layout passed to jit does not match the layout ' + 'on the respective arg. ' + f'Got pjit layout: {jit_in_l},\n' + f'arg layout: {arg_layout} for ' + f'arg shape: {shaped_abstractify(arg).str_short()}.' + f'{extra_msg}') + resolved_in_layouts.append(jit_in_l) + return tuple(resolved_in_layouts) + + def _resolve_in_shardings( args, pjit_in_shardings: Sequence[PjitSharding], out_shardings: Sequence[PjitSharding], @@ -1240,9 +1488,6 @@ def _resolve_in_shardings( # not allow None as the sharding. if arg_s is None: continue - if not isinstance(arg_s, XLACompatibleSharding): - raise ValueError(f'One of the argument to pjit got sharding {arg_s} ' - 'which is not a subclass of XLACompatibleSharding.') # Don't consider PmapSharding inputs as committed. They will get resharded # unconditionally. if isinstance(arg_s, PmapSharding): @@ -1256,8 +1501,10 @@ def _resolve_in_shardings( pxla._get_and_check_device_assignment( it.chain( util.stable_unique(committed_arg_shardings), - ((i, pxla.MismatchType.IN_SHARDING, None) for i in util.stable_unique(pjit_in_shardings)), - ((o, pxla.MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings))), + ((i, pxla.MismatchType.IN_SHARDING, None) + for i in util.stable_unique(pjit_in_shardings)), + ((o, pxla.MismatchType.OUT_SHARDING, None) + for o in util.stable_unique(out_shardings))), (None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat))) resolved_in_shardings = [] @@ -1276,8 +1523,7 @@ def _resolve_in_shardings( if isinstance(arg_s, PmapSharding): resolved_in_shardings.append(UNSPECIFIED) else: - resolved_in_shardings.append(to_gspmd_sharding( - cast(XLACompatibleSharding, arg_s), arg.ndim)) + resolved_in_shardings.append(arg_s) else: if dispatch.is_single_device_sharding(arg_s): resolved_in_shardings.append(UNSPECIFIED) @@ -1309,17 +1555,16 @@ def _resolve_in_shardings( raise ValueError( 'Memory kinds passed to jax.jit does not match memory kind on the' f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore - f'arg memory kind: {arg_s.memory_kind} for ' # type: ignore + f'arg memory kind: {arg_s.memory_kind} for ' # pytype: disable=attribute-error f'arg shape: {shaped_abstractify(arg).str_short()}') if (committed and not isinstance(arg_s, PmapSharding) and not op_shardings.are_op_shardings_equal( pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore arg_s._to_xla_hlo_sharding(arg.ndim))): - op = getattr(pjit_in_s, '_original_sharding', pjit_in_s) raise ValueError('Sharding passed to pjit does not match the sharding ' 'on the respective arg. 
' - f'Got pjit sharding: {op},\n' + f'Got pjit sharding: {pjit_in_s},\n' f'arg sharding: {arg_s} for ' f'arg shape: {shaped_abstractify(arg).str_short()}') resolved_in_shardings.append(pjit_in_s) @@ -1327,24 +1572,64 @@ def _resolve_in_shardings( return tuple(resolved_in_shardings) -def _pjit_call_impl_python( - *args, jaxpr, in_shardings, out_shardings, resource_env, donated_invars, - name, keep_unused, inline): - global _most_recent_pjit_call_executable - +def _resolve_and_lower( + args, jaxpr, in_shardings, out_shardings, in_layouts, + out_layouts, resource_env, donated_invars, name, keep_unused, inline, + lowering_platforms, lowering_parameters, pgle_profiler): in_shardings = _resolve_in_shardings( args, in_shardings, out_shardings, resource_env.physical_mesh if resource_env is not None else None) - - compiled = _pjit_lower( - jaxpr, in_shardings, out_shardings, resource_env, + in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, + jaxpr.in_avals) + lowered = _pjit_lower( + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, - lowering_parameters=mlir.LoweringParameters()).compile() + lowering_platforms=lowering_platforms, + lowering_parameters=lowering_parameters, + pgle_profiler=pgle_profiler) + return lowered + +def _pjit_call_impl_python( + *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, + resource_env, donated_invars, name, keep_unused, inline): + global _most_recent_pjit_call_executable + + compile_options = None + pgle_profiler = None + pgle_profiler_dict = _most_recent_pjit_call_executable.weak_pgle_profiler_dict + if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: + if jaxpr not in pgle_profiler_dict: + pgle_profiler_dict[jaxpr] = profiler.PGLEProfiler( + config.pgle_profiling_runs.value, + config.pgle_aggregation_percentile.value) + + pgle_profiler = pgle_profiler_dict[jaxpr] + # The method below will return FDO profile when module was profiled + # config.jax_pgle_profiling_runs amount of times, otherwise the result will + # be None. + fdo_profile = pgle_profiler.consume_fdo_profile() + if fdo_profile is not None: + compile_options = {'fdo_profile': fdo_profile} + + # TODO(patrios): Do not pass mutable profile session through cached lowering + # chain. Instead we need to move profilers dictionary to pxla module and use + # module as key. Right now we can't do that since there is no way to evict _pjit_lower_cached cache for in PGLE mode. + compiled = _resolve_and_lower( + args, jaxpr=jaxpr, in_shardings=in_shardings, + out_shardings=out_shardings, in_layouts=in_layouts, + out_layouts=out_layouts, resource_env=resource_env, + donated_invars=donated_invars, name=name, keep_unused=keep_unused, + inline=inline, lowering_platforms=None, + lowering_parameters=mlir.LoweringParameters(), + pgle_profiler=pgle_profiler + ).compile(compile_options) + _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled # This check is expensive so only do it if enable_checks is on. 
if compiled._auto_spmd_lowering and config.enable_checks.value: - pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings, - jaxpr.jaxpr.debug_info) + pxla.check_array_xla_sharding_layout_match( + args, compiled._in_shardings, compiled._in_layouts, + jaxpr.jaxpr.debug_info, compiled._kept_var_idx) if config.distributed_debug.value: # Defensively only perform fingerprint logic if debug logging is enabled # NOTE(skyewm): I didn't benchmark this @@ -1356,6 +1641,8 @@ def _pjit_call_impl_python( distributed_debug_log(("Running pjit'd function", name), ("in_shardings", in_shardings), ("out_shardings", out_shardings), + ("in_layouts", in_layouts), + ("out_layouts", out_layouts), ("abstract args", map(xla.abstractify, args)), ("fingerprint", fingerprint)) try: @@ -1387,8 +1674,9 @@ def _pjit_call_impl_python( @weakref_lru_cache -def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, resource_env, - donated_invars, name, keep_unused, inline): +def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, + out_layouts, resource_env, donated_invars, name, + keep_unused, inline): # The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so # returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to # the jaxpr defeating the purpose of weakref_lru_cache. So return a function @@ -1400,109 +1688,57 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, resource_env, def _pjit_call_impl(*args, jaxpr, - in_shardings, out_shardings, resource_env, + in_shardings, out_shardings, in_layouts, out_layouts, + resource_env, donated_invars, name, keep_unused, inline): def call_impl_cache_miss(*args_, **kwargs_): out_flat, compiled = _pjit_call_impl_python( *args, jaxpr=jaxpr, in_shardings=in_shardings, - out_shardings=out_shardings, resource_env=resource_env, + out_shardings=out_shardings, in_layouts=in_layouts, + out_layouts=out_layouts, resource_env=resource_env, donated_invars=donated_invars, name=name, keep_unused=keep_unused, inline=inline) + pgle_profiler = _read_pgle_profiler(jaxpr) fastpath_data = _get_fastpath_data( - compiled, tree_structure(out_flat), args, out_flat, [], set()) - return out_flat, fastpath_data + compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects, + jaxpr.consts, None, pgle_profiler) + return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) f = _get_jaxpr_as_fun( - jaxpr, tuple(getattr(i, '_original_sharding', i) for i in in_shardings), - tuple(getattr(o, '_original_sharding', o) for o in out_shardings), + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline) donated_argnums = [i for i, d in enumerate(donated_invars) if d] has_explicit_sharding = _pjit_explicit_sharding( in_shardings, out_shardings, None, None) - if xla_extension_version >= 226: - return xc._xla.pjit( - name, f, call_impl_cache_miss, [], [], donated_argnums, - tree_util.dispatch_registry, - pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg, # type: ignore - _get_cpp_global_cache(has_explicit_sharding))(*args) - else: - return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, # type: ignore - tree_util.dispatch_registry, - _get_cpp_global_cache(has_explicit_sharding))(*args) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], donated_argnums, + tree_util.dispatch_registry, + lambda x, sharding: pxla.shard_args([sharding], [x])[0], + _get_cpp_global_cache(has_explicit_sharding))(*args) 
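# Illustrative sketch, not part of the patch: the control flow that the PGLE
# branch in _pjit_call_impl_python above follows -- execute the first N runs
# under a profiler, then feed the accumulated FDO profile back into
# compilation. All names here (SketchProfiler, compile_fn, run_profiled,
# profilers) are hypothetical stand-ins, not the real JAX/XLA profiler API.
class SketchProfiler:
    def __init__(self, profiling_runs: int):
        self.runs_left = profiling_runs
        self.samples: list[bytes] = []

    def consume_fdo_profile(self) -> bytes | None:
        # Hand back a profile only once the configured number of profiled runs
        # has completed; before that, compilation proceeds without FDO data.
        if self.runs_left > 0 or not self.samples:
            return None
        return b"".join(self.samples)

def sketch_call(jaxpr_key, args, profilers: dict, profiling_runs: int,
                compile_fn, run_profiled):
    profiler = profilers.setdefault(jaxpr_key, SketchProfiler(profiling_runs))
    fdo_profile = profiler.consume_fdo_profile()
    compile_options = ({"fdo_profile": fdo_profile}
                       if fdo_profile is not None else None)
    compiled = compile_fn(compile_options)  # stands in for _resolve_and_lower(...).compile(...)
    if profiler.runs_left > 0:
        # Still in the profiling window: run under the profiler and keep the
        # collected sample for the eventual FDO profile.
        sample, out = run_profiled(compiled, args)
        profiler.samples.append(sample)
        profiler.runs_left -= 1
        return out
    return compiled(*args)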
pjit_p.def_impl(_pjit_call_impl) -@dataclasses.dataclass(frozen=True) -class SameDeviceAssignmentTuple: - shardings: tuple[PjitSharding, ...] - # device_assignment is Optional because shardings can contain `AUTO` and in - # that case `mesh` is compulsory to be used. So in that case - # `_pjit_lower_cached` cache, resource_env will check against the devices. - device_assignment: XLADeviceAssignment | None - - def __hash__(self): - shardings_hash = tuple( - (s._hlo_sharding_hash, s.memory_kind) # type: ignore - if isinstance(s, GSPMDSharding) else s - for s in self.shardings) - if self.device_assignment is None: - return hash(shardings_hash) - else: - return hash((shardings_hash, *self.device_assignment)) - - def __eq__(self, other): - if not isinstance(other, SameDeviceAssignmentTuple): - return False - eq = [] - for s, o in zip(self.shardings, other.shardings): - s = getattr(s, "_original_sharding", s) - o = getattr(o, "_original_sharding", o) - if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding): - eq.append( - op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding) - and s.memory_kind == o.memory_kind) - else: - eq.append(s == o) - return all(eq) and self.device_assignment == other.device_assignment - - -def _pjit_lower( - jaxpr: core.ClosedJaxpr, - in_shardings, - out_shardings, - *args, **kwargs): - da = _fast_path_get_device_assignment(it.chain(in_shardings, out_shardings)) - in_shardings = SameDeviceAssignmentTuple(tuple(in_shardings), da) - out_shardings = SameDeviceAssignmentTuple(tuple(out_shardings), da) - return _pjit_lower_cached(jaxpr, in_shardings, out_shardings, *args, **kwargs) +def _pjit_lower(*args, **kwargs): + return _pjit_lower_cached(*args, **kwargs) @weakref_lru_cache def _pjit_lower_cached( jaxpr: core.ClosedJaxpr, - sdat_in_shardings: SameDeviceAssignmentTuple, - sdat_out_shardings: SameDeviceAssignmentTuple, + in_shardings, + out_shardings, + in_layouts: pxla.MaybeLayout, + out_layouts: pxla.MaybeLayout, resource_env, donated_invars, name: str, keep_unused: bool, inline: bool, *, + lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, - in_layouts: pxla.MaybeLayout | None = None, - out_layouts: pxla.MaybeLayout | None = None): - in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast( - tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings) - out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings - - # TODO(yashkatariya): Remove this when layouts are supported on jit and - # passed to params. 
- if in_layouts is None: - in_layouts = (None,) * len(in_shardings) - if out_layouts is None: - out_layouts = (None,) * len(out_shardings) - + pgle_profiler: profiler.PGLEProfiler | None): if resource_env is not None: pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit") @@ -1521,35 +1757,45 @@ def _pjit_lower_cached( jaxpr, api_name, name, mesh, in_shardings, out_shardings, donated_invars, True, jaxpr.in_avals, tiling_method=None, + lowering_platforms=lowering_platforms, lowering_parameters=lowering_parameters) else: return pxla.lower_sharding_computation( jaxpr, api_name, name, in_shardings, out_shardings, - tuple(donated_invars), tuple(jaxpr.in_avals), + in_layouts, out_layouts, tuple(donated_invars), keep_unused=keep_unused, inline=inline, devices_from_context=( None if mesh is None or mesh.empty else list(mesh.devices.flat)), - lowering_parameters=lowering_parameters, in_layouts=in_layouts, - out_layouts=out_layouts) + lowering_platforms=lowering_platforms, + lowering_parameters=lowering_parameters, + pgle_profiler=pgle_profiler) def pjit_staging_rule(trace, *args, **params): + jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( + params['jaxpr'], params['out_shardings'], params['out_layouts']) + params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, + out_layouts=out_layouts) + if (params["inline"] and all(is_unspecified(i) for i in params["in_shardings"]) and - all(is_unspecified(o) for o in params["out_shardings"])): - jaxpr = params['jaxpr'] + all(is_unspecified(o) for o in params["out_shardings"]) and + all(i is None for i in params["in_layouts"]) and + all(o is None for o in params["out_layouts"])): + if config.dynamic_shapes.value: # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic # shapes are enabled, use eval_jaxpr, which uses the tracing machinery, # but redundantly performs abstract evaluation again. 
- return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, - propagate_source_info=False) + out_tracers = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, + propagate_source_info=False) else: - return pe.inline_jaxpr_into_trace(trace, jaxpr.jaxpr, jaxpr.consts, *args) + out_tracers = pe.inline_jaxpr_into_trace( + trace, jaxpr.jaxpr, jaxpr.consts, *args) elif config.dynamic_shapes.value: source_info = source_info_util.current() out_tracers = [] - for aval in _out_type(params['jaxpr']): + for aval in _out_type(jaxpr): if type(aval) is core.DShapedArray: shape = [args[d.val] if type(d) is core.InDBIdx else out_tracers[d.val] if type(d) is core.OutDBIdx else @@ -1558,13 +1804,51 @@ def pjit_staging_rule(trace, *args, **params): out_tracers.append(pe.DynamicJaxprTracer(trace, aval, source_info)) eqn = core.new_jaxpr_eqn( map(trace.getvar, args), map(trace.makevar, out_tracers), pjit_p, params, - params['jaxpr'].effects, source_info) + jaxpr.effects, source_info) trace.frame.add_eqn(eqn) - return out_tracers + elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts): + jaxpr, consts = pxla._move_mutable_consts(jaxpr) + consts = map(trace.instantiate_const, consts) + in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts) + in_layouts = (*params['in_layouts'],) + (None,) * len(consts) + donated_invars = (*params['donated_invars'],) + (False,) * len(consts) + new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings, + in_layouts=in_layouts, donated_invars=donated_invars) + out_tracers = trace.default_process_primitive( + pjit_p, (*args, *consts), new_params) else: - return trace.default_process_primitive(pjit_p, args, params) + out_tracers = trace.default_process_primitive(pjit_p, args, params) + + out_tracers_ = iter(out_tracers) + out_tracers = [args[f] if type(f) is int else next(out_tracers_) + for f in in_fwd] + assert next(out_tracers_, None) is None + return out_tracers pe.custom_staging_rules[pjit_p] = pjit_staging_rule + +def _pjit_forwarding(jaxpr, out_shardings, out_layouts): + in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr) + in_fwd = [fwd if is_unspecified(os) and ol is None else None for fwd, os, ol + in zip(in_fwd, out_shardings, out_layouts)] + keep = [f is None for f in in_fwd] + jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep) + out_shardings = [o for o, k in zip(out_shardings, keep) if k] + out_layouts = [o for o, k in zip(out_layouts , keep) if k] + return jaxpr, in_fwd, out_shardings, out_layouts + +def pjit_forwarding_rule(eqn): + jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( + eqn.params['jaxpr'], eqn.params['out_shardings'], eqn.params['out_layouts']) + new_outvars = [v for v, f in zip(eqn.outvars, in_fwd) if f is None] + new_params = dict(eqn.params, jaxpr=jaxpr, out_shardings=(*out_shardings,), + out_layouts=(*out_layouts,)) + new_eqn = eqn.replace(params=new_params, outvars=new_outvars) + fwd_vars = [eqn.invars[f] if f is not None else None for f in in_fwd] + return fwd_vars, new_eqn +pe.forwarding_rules[pjit_p] = pjit_forwarding_rule + + # TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them, # since it's actually not possible in general to infer the type from the term def _out_type(jaxpr: core.ClosedJaxpr) -> list[core.AbstractValue]: @@ -1589,13 +1873,14 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params): core.custom_typechecks[pjit_p] = _pjit_typecheck -def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env, **_): +def _pjit_abstract_eval(*args, jaxpr, 
**_): return jaxpr.out_avals, jaxpr.effects pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval) def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, - out_shardings, api_name): + out_shardings, in_layouts, out_layouts, + api_name): mod_ctx = ctx.module_context axis_ctx = ctx.module_context.axis_context num_devices = None @@ -1604,29 +1889,29 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, elif isinstance(axis_ctx, sharding_impls.SPMDAxisContext): num_devices = axis_ctx.mesh.size key = (pjit_p, name, jaxpr, effects, num_devices, - pxla.SemanticallyEqualShardings(in_shardings), - pxla.SemanticallyEqualShardings(out_shardings), api_name) + pxla.SemanticallyEqualShardings(in_shardings, jaxpr.in_avals), + pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals), + in_layouts, out_layouts, api_name) func = mod_ctx.cached_primitive_lowerings.get(key, None) if func is None: - arg_shardings = [None if is_unspecified(i) else i._to_xla_hlo_sharding(aval.ndim) - for aval, i in zip(ctx.avals_in, in_shardings)] - result_shardings = [None if is_unspecified(o) else o._to_xla_hlo_sharding(aval.ndim) - for aval, o in zip(ctx.avals_out, out_shardings)] + arg_shardings = [None if is_unspecified(i) else i for i in in_shardings] + result_shardings = [None if is_unspecified(o) else o for o in out_shardings] # TODO(b/228598865): inlined calls cannot have shardings set directly on the # inputs or outputs because they are lost during MLIR->HLO conversion. # using_sharding_annotation=False means we add an identity operation instead. func = mlir.lower_jaxpr_to_fun( mod_ctx, name, jaxpr, effects, ctx.name_stack, arg_shardings=arg_shardings, result_shardings=result_shardings, - use_sharding_annotations=False, api_name=api_name) + use_sharding_annotations=False, api_name=api_name, + arg_layouts=in_layouts, result_layouts=out_layouts) mod_ctx.cached_primitive_lowerings[key] = func return func def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, - out_shardings, resource_env, donated_invars, - keep_unused, inline): + out_shardings, in_layouts, out_layouts, resource_env, + donated_invars, keep_unused, inline): effects = list(ctx.tokens_in.effects()) output_types = map(mlir.aval_to_ir_types, ctx.avals_out) output_types = [mlir.token_type()] * len(effects) + output_types @@ -1634,13 +1919,15 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, func = _pjit_cached_lower_jaxpr_to_fun( ctx, name, jaxpr, tuple(effects), in_shardings, - out_shardings, api_name=('jit' if resource_env is None else 'pjit')) + out_shardings, in_layouts, out_layouts, + api_name=('jit' if resource_env is None else 'pjit')) tokens_in = [ctx.tokens_in.get(eff) for eff in effects] args = (*ctx.dim_var_values, *tokens_in, *args) call = func_dialect.CallOp(flat_output_types, ir.FlatSymbolRefAttr.get(func.name.value), mlir.flatten_lowering_ir_args(args)) + mlir.wrap_compute_type_in_place(ctx, call) out_nodes = unflatten(call.results, map(len, output_types)) tokens, out_nodes = split_list(out_nodes, [len(effects)]) tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens))) @@ -1653,7 +1940,7 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, def _pjit_batcher(insert_axis, spmd_axis_name, axis_size, axis_name, main_type, vals_in, dims_in, - jaxpr, in_shardings, out_shardings, + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline): segment_lens, dims_in = 
batching.indirectify_ragged_axes(dims_in) new_jaxpr, axes_out = batching.batch_jaxpr2( @@ -1678,16 +1965,24 @@ def _pjit_batcher(insert_axis, spmd_axis_name, _pjit_batcher_for_sharding(o, axis_out, new_parts, mesh, aval.ndim) if axis_out is not None else o for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals)) + # TODO(yashkatariya): Figure out layouts should change under vmap. + if not (all(l is None for l in in_layouts) and + all(l is None for l in out_layouts)): + raise NotImplementedError + vals_out = pjit_p.bind( *vals_in, jaxpr=new_jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, + in_layouts=in_layouts, + out_layouts=out_layouts, resource_env=resource_env, donated_invars=donated_invars, name=name, keep_unused=keep_unused, inline=inline) + resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs( vals_in, vals_out, axes_out) return vals_out, resolved_axes_out @@ -1697,56 +1992,53 @@ def _pjit_batcher(insert_axis, spmd_axis_name, pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None) def _pjit_batcher_for_sharding( - s: GSPMDSharding | UnspecifiedValue, + s: sharding.Sharding | UnspecifiedValue, dim: int, val: tuple[str, ...], mesh, ndim: int): if is_unspecified(s): return s + hlo_s = s._to_xla_hlo_sharding(ndim) # type: ignore if not val: - if sharding_impls.is_op_sharding_replicated(s._hlo_sharding): # type: ignore + if sharding_impls.is_op_sharding_replicated(hlo_s): return s - old_op = s._hlo_sharding.to_proto() # type: ignore - new_op = old_op.clone() # type: ignore + new_op = hlo_s.to_proto().clone() tad = list(new_op.tile_assignment_dimensions) tad.insert(dim, 1) new_op.tile_assignment_dimensions = tad - if xla_extension_version >= 234: - new_gs = GSPMDSharding( - s._device_assignment, new_op, # type: ignore - _device_list=getattr(s, '_internal_device_list', None)) - else: - new_gs = GSPMDSharding(s._device_assignment, new_op) # type: ignore - if hasattr(s, '_original_sharding'): - vmapped_s = pxla._get_out_sharding_from_orig_sharding( - [new_gs], [None], s._original_sharding, None)[0] # type: ignore - new_gs = to_gspmd_sharding(vmapped_s, ndim) - return new_gs + new_gs = GSPMDSharding( + s._device_assignment, new_op, # type: ignore + _device_list=getattr(s, '_internal_device_list', None)) + return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0] else: - assert isinstance(s, GSPMDSharding) - if isinstance(getattr(s, '_original_sharding', None), NamedSharding): - mesh = s._original_sharding.mesh # type: ignore + if isinstance(s, NamedSharding): + mesh = s.mesh if mesh is None or mesh.empty: - s_type = (f', got: {s._original_sharding!r}' - if hasattr(s, '_original_sharding') else '') raise ValueError( 'If you are using xmap or spmd_axis_name parameter of jax.vmap,' ' please make sure to run your jitted function inside the mesh' ' context manager. 
Only `jax.lax.with_sharding_constraint` with' ' `jax.sharding.NamedSharding` as an input can be transformed with' ' spmd_axis_name batching rules outside of an explicit mesh context' - f' manager scope{s_type}') - parsed_pspec = parse_flatten_op_sharding(s._hlo_sharding, mesh)[0] # type: ignore + f' manager scope{s!r}') + parsed_pspec = parse_flatten_op_sharding(hlo_s, mesh)[0] parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val) - mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec) - if xla_extension_version >= 234: - return GSPMDSharding(mps._device_assignment, mps._to_xla_hlo_sharding(ndim), - _device_list=getattr(mps, '_internal_device_list', None)) - else: - return GSPMDSharding(mps._device_assignment, mps._to_xla_hlo_sharding(ndim)) + return NamedSharding._from_parsed_pspec(mesh, parsed_pspec) def _pjit_jvp(primals_in, tangents_in, - jaxpr, in_shardings, out_shardings, + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline): + if any(isinstance(c, core.MutableArray) for c in jaxpr.consts): + jaxpr, mut_primals = pxla._move_mutable_consts(jaxpr) + mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals) + primals_in = [*primals_in, *mut_primals] + tangents_in = [*tangents_in, *mut_tangents] + in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut_primals) + in_layouts = (*in_layouts,) + (None,) * len(mut_primals) + donated_invars = (*donated_invars,) + (False,) * len(mut_primals) + + tangents_in = [ad_util.zeros_like_aval(a) if isinstance(a, AbstractRef) else x + for x, a in zip(tangents_in, jaxpr.in_avals)] + is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in] jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr( jaxpr, is_nz_tangents_in, instantiate=False) @@ -1760,6 +2052,8 @@ def _filter_zeros(is_nz_l, l): jaxpr=jaxpr_jvp, in_shardings=(*in_shardings, *_filter_zeros_in(in_shardings)), out_shardings=(*out_shardings, *_filter_zeros_out(out_shardings)), + in_layouts=(*in_layouts, *_filter_zeros_in(in_layouts)), + out_layouts=(*out_layouts, *_filter_zeros_out(out_layouts)), resource_env=resource_env, donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)), name=name, @@ -1785,64 +2079,84 @@ def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr, def _pjit_partial_eval(trace, *in_tracers, jaxpr, in_shardings, out_shardings, - resource_env, donated_invars, name, keep_unused, inline): + in_layouts, out_layouts, resource_env, donated_invars, + name, keep_unused, inline): in_pvals = [t.pval for t in in_tracers] known_ins = tuple(pv.is_known() for pv in in_pvals) unknown_ins = tuple(not k for k in known_ins) - known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = pe.partial_eval_jaxpr_nounits( - jaxpr, unknown_ins, instantiate=False) + known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \ + pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False) unknown_outs = tuple(unknown_outs) known_outs = tuple(not uk for uk in unknown_outs) num_residuals = len(res_avals) res_shardings = (UNSPECIFIED,) * num_residuals + res_layouts = (None,) * num_residuals def keep_where(l, should_keep): return tuple(x for x, keep in zip(l, should_keep) if keep) - # Compute which outputs are just forwarded inputs. 
- num_out_primals = len(known_jaxpr.out_avals) - num_residuals - in_fwd = pe._jaxpr_forwarding(known_jaxpr.jaxpr) + known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings + known_out_layouts = keep_where(out_layouts, known_outs) + res_layouts + # Input-to-output forwarding: compute which outputs are just forwarded inputs. + num_out_primals = len(known_jaxpr.out_avals) - num_residuals + in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr) # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED. in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals]) - in_fwd = [fwd if is_unspecified(os) else None for os, fwd in - zip(keep_where(out_shardings, known_outs), in_fwd_primal) - ] + in_fwd_res + in_fwd = [ + fwd if is_unspecified(os) and ol is None else None + for os, ol, fwd in zip( + keep_where(out_shardings, known_outs), + keep_where(out_layouts, known_outs), in_fwd_primal) + ] + in_fwd_res del in_fwd_primal, in_fwd_res + # Prune jaxpr outputs and out_shardings by removing the input-forwards. + keep = [f is None for f in in_fwd] + known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep) + known_out_shardings = keep_where(known_out_shardings, keep) + known_out_layouts = keep_where(known_out_layouts, keep) + # Update num_out_primals to reflect pruning. + kept_primals, kept_res = split_list(keep, [num_out_primals]) + num_out_primals = sum(kept_primals) + del keep, kept_primals, kept_res - # Compute which residuals are just primal outputs. + # Output-to-output forwarding: compute which residuals are just primal outputs out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals]) idx_map = {id(v): i for i, v in enumerate(out_vars)} out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars] - - # Prune jaxpr outputs and out_shardings by removing forwards. - keep = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] + # Prune jaxpr outputs and out_shardings by removing forwarded residuals. + keep = [f is None for f in out_fwd] known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep) - known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings known_out_shardings = keep_where(known_out_shardings, keep) - del keep, num_out_primals + known_out_layouts = keep_where(known_out_layouts, keep) + del keep known_params = dict( jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins), - out_shardings=known_out_shardings, resource_env=resource_env, + out_shardings=known_out_shardings, + in_layouts=keep_where(in_layouts, known_ins), + out_layouts=known_out_layouts, resource_env=resource_env, donated_invars=keep_where(donated_invars, known_ins), name=name, keep_unused=keep_unused, inline=inline) assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals) + assert len(known_params['out_layouts']) == len(known_params['jaxpr'].out_avals) # Bind known things to pjit_p. known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()] all_known_outs = pjit_p.bind(*known_inputs, **known_params) - all_known_outs = subs_list2(in_fwd, out_fwd, known_inputs, all_known_outs, - all_known_outs) + # Add back in the output fwds. + all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs) + # Add back in the input fwds. 
+ all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs) known_out_vals, residual_vals = \ split_list(all_known_outs, [len(all_known_outs) - num_residuals]) residual_tracers = map(trace.new_instantiated_const, residual_vals) - # The convention of partial_eval_jaxpr_nounits is to place residual binders - # at the front of the jaxpr produced, so we move them to the back since both - # the jaxpr equation built below and the pjit transpose rule assume a + # The convention of partial_eval_jaxpr_nounits is to place residual binders at + # the front of the jaxpr produced, so we move them to the back since both the + # jaxpr equation built below and the pjit transpose rule assume a # residual-inputs-last convention. unknown_jaxpr = pe.move_binders_to_back( unknown_jaxpr, [True] * num_residuals + [False] * sum(unknown_ins)) @@ -1851,6 +2165,8 @@ def keep_where(l, should_keep): jaxpr=unknown_jaxpr, in_shardings=(keep_where(in_shardings, unknown_ins) + res_shardings), out_shardings=keep_where(out_shardings, unknown_outs), + in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts), + out_layouts=keep_where(out_layouts, unknown_outs), resource_env=resource_env, donated_invars=(keep_where(donated_invars, unknown_ins) + (False,) * num_residuals), @@ -1884,28 +2200,41 @@ def _pjit_partial_eval_custom_params_updater( donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars']) in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings']) _, out_shardings_known = pe.partition_list(kept_outs_known, params_known['out_shardings']) + in_layouts_known, _ = pe.partition_list(unks_in, params_known['in_layouts']) + _, out_layouts_known = pe.partition_list(kept_outs_known, params_known['out_layouts']) + new_params_known = dict(params_known, in_shardings=tuple(in_shardings_known), out_shardings=(*out_shardings_known, *[UNSPECIFIED] * num_res_out), + in_layouts=tuple(in_layouts_known), + out_layouts=(*out_layouts_known, *[None] * num_res_out), donated_invars=tuple(donated_invars_known)) assert len(new_params_known['in_shardings']) == len(params_known['jaxpr'].in_avals) assert len(new_params_known['out_shardings']) == len(params_known['jaxpr'].out_avals) + assert len(new_params_known['in_layouts']) == len(params_known['jaxpr'].in_avals) + assert len(new_params_known['out_layouts']) == len(params_known['jaxpr'].out_avals) # added num_res new inputs to jaxpr_staged, and pruning according to inst_in _, donated_invars_staged = pe.partition_list(inst_in, params_staged['donated_invars']) donated_invars_staged = [False] * num_res_in + donated_invars_staged _, in_shardings_staged = pe.partition_list(inst_in, params_staged['in_shardings']) in_shardings_staged = [*[UNSPECIFIED] * num_res_in, *in_shardings_staged] - _, out_shardings_staged = pe.partition_list(kept_outs_staged, params_staged['out_shardings']) + _, in_layouts_staged = pe.partition_list(inst_in, params_staged['in_layouts']) + in_layouts_staged = [*[None] * num_res_in, *in_layouts_staged] + _, out_layouts_staged = pe.partition_list(kept_outs_staged, params_staged['out_layouts']) new_params_staged = dict(params_staged, in_shardings=tuple(in_shardings_staged), out_shardings=tuple(out_shardings_staged), + in_layouts=tuple(in_layouts_staged), + out_layouts=tuple(out_layouts_staged), donated_invars=tuple(donated_invars_staged)) assert len(new_params_staged['in_shardings']) == len(params_staged['jaxpr'].in_avals) assert len(new_params_staged['out_shardings']) == len(params_staged['jaxpr'].out_avals) + assert 
len(new_params_staged['in_layouts']) == len(params_staged['jaxpr'].in_avals) + assert len(new_params_staged['out_layouts']) == len(params_staged['jaxpr'].out_avals) return new_params_known, new_params_staged pe.partial_eval_jaxpr_custom_rules[pjit_p] = \ @@ -1922,7 +2251,7 @@ def _pjit_transpose_trace(fun, in_avals): def _pjit_transpose(cts_in, *primals_in, - jaxpr, in_shardings, out_shardings, + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline): def prune_type(ty, xs, maybe_zeros): return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty) @@ -1936,6 +2265,10 @@ def prune_type(ty, xs, maybe_zeros): *prune_type(ad.UndefinedPrimal, in_shardings, primals_in), *prune_type(ad.Zero, out_shardings, cts_in) ) + transpose_in_layouts = ( + *prune_type(ad.UndefinedPrimal, in_layouts, primals_in), + *prune_type(ad.Zero, out_layouts, cts_in) + ) global_cts_in_avals = tuple(core.raise_to_shaped(core.get_aval(ct)) for ct in primals_and_nz_cts_in) @@ -1946,26 +2279,36 @@ def prune_type(ty, xs, maybe_zeros): ad.Zero, in_shardings, tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves)) + transpose_out_layouts = prune_type( + ad.Zero, + in_layouts, + tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves)) if attrs_tracked: init_states = _get_states(attrs_tracked) primals_and_nz_cts_in = [*init_states, *primals_and_nz_cts_in] transpose_in_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_in_shardings transpose_out_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_out_shardings + transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts + transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts nz_cts_out = pjit_p.bind( *primals_and_nz_cts_in, jaxpr=transpose_jaxpr, in_shardings=transpose_in_shardings, out_shardings=transpose_out_shardings, + in_layouts=transpose_in_layouts, + out_layouts=transpose_out_layouts, resource_env=resource_env, donated_invars=(False,) * len(primals_and_nz_cts_in), name=name, keep_unused=keep_unused, inline=inline) + if attrs_tracked: final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)]) _set_states(attrs_tracked, final_states) + return tree_unflatten(cts_out_treedef, nz_cts_out) ad.reducing_transposes[pjit_p] = _pjit_transpose @@ -1992,6 +2335,8 @@ def keep_where(xs, keeps): jaxpr=dced_jaxpr, in_shardings=keep_where(eqn_params["in_shardings"], used_inputs), out_shardings=keep_where(eqn_params["out_shardings"], used_outputs), + in_layouts=keep_where(eqn_params["in_layouts"], used_inputs), + out_layouts=keep_where(eqn_params["out_layouts"], used_outputs), donated_invars=keep_where(eqn_params["donated_invars"], used_inputs), ) if not any(used_inputs) and not any(used_outputs) and not dced_jaxpr.effects: @@ -2000,7 +2345,7 @@ def keep_where(xs, keeps): new_eqn = core.new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [v for v, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, dced_jaxpr.effects, eqn.source_info) + eqn.primitive, new_params, dced_jaxpr.effects, eqn.source_info, eqn.ctx) return used_inputs, new_eqn pe.dce_rules[pjit_p] = dce_jaxpr_pjit_rule @@ -2030,13 +2375,12 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r for aval, s in zip(jaxpr.in_avals, params['in_shardings']): if is_unspecified(s) or is_auto(s): continue - elif hasattr(s, '_original_sharding') and hasattr( - s._original_sharding, 
'_parsed_pspec'): - parsed_pspec = s._original_sharding._parsed_pspec + elif hasattr(s, '_parsed_pspec'): + parsed_pspec = s._parsed_pspec else: if resource_env is not None and not resource_env.physical_mesh.empty: parsed_pspec = parse_flatten_op_sharding( - s._hlo_sharding, resource_env.physical_mesh)[0] + s._to_xla_hlo_sharding(aval.ndim), resource_env.physical_mesh)[0] else: parsed_pspec = None if parsed_pspec is not None: @@ -2052,13 +2396,12 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r for aval, s in zip(jaxpr.out_avals, params['out_shardings']): if is_unspecified(s) or is_auto(s): continue - elif hasattr(s, '_original_sharding') and hasattr( - s._original_sharding, '_parsed_pspec'): - parsed_pspec = s._original_sharding._parsed_pspec + elif hasattr(s, '_parsed_pspec'): + parsed_pspec = s._parsed_pspec else: if resource_env is not None and not resource_env.physical_mesh.empty: parsed_pspec = parse_flatten_op_sharding( - s._hlo_sharding, resource_env.physical_mesh)[0] + s._to_xla_hlo_sharding(aval.ndim), resource_env.physical_mesh)[0] else: parsed_pspec = None if parsed_pspec is not None: @@ -2077,6 +2420,10 @@ def _pjit_pp_rule(eqn, context, settings): del params['in_shardings'] if all(is_unspecified(s) for s in params['out_shardings']): del params['out_shardings'] + if all(l is None for l in params['in_layouts']): + del params['in_layouts'] + if all(l is None for l in params['out_layouts']): + del params['out_layouts'] if not params['keep_unused']: del params['keep_unused'] if (params['resource_env'] is None or @@ -2091,21 +2438,31 @@ def _pjit_pp_rule(eqn, context, settings): def _pjit_state_discharge_rule( - in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, **params): + in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, + in_layouts, out_layouts, **params): if not (all(map(is_unspecified, in_shardings)) and - all(map(is_unspecified, out_shardings))): raise NotImplementedError + all(map(is_unspecified, out_shardings))): + raise NotImplementedError + + if not (all(l is None for l in in_layouts) and + all(l is None for l in out_layouts)): + raise NotImplementedError + jaxpr, consts = jaxpr.jaxpr, jaxpr.consts num_outs = len(jaxpr.outvars) discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, consts) discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts) new_in_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.invars) new_out_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.outvars) + new_in_layouts = (None,) * len(discharged_jaxpr.invars) + new_out_layouts = (None,) * len(discharged_jaxpr.outvars) out_and_ref_vals = pjit_p.bind( *args, jaxpr=discharged_closed_jaxpr, in_shardings=new_in_shardings, - out_shardings=new_out_shardings, **params) + out_shardings=new_out_shardings, in_layouts=new_in_layouts, + out_layouts=new_out_layouts, **params) out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs]) ref_vals_iter = iter(ref_vals) - new_invals = tuple(next(ref_vals_iter) if isinstance(aval, state_discharge.AbstractRef) + new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef) else None for aval in in_avals) sentinel = object() assert next(ref_vals_iter, sentinel) is sentinel @@ -2131,7 +2488,10 @@ def with_sharding_constraint(x, shardings): .. 
_Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html """ x_flat, tree = tree_flatten(x) - user_shardings, _, _ = prepare_axis_resources( + + layouts, shardings = _split_layout_and_sharding(shardings) + + user_shardings = prepare_axis_resources( shardings, "shardings", allow_unconstrained_dims=True) del shardings @@ -2139,6 +2499,10 @@ def with_sharding_constraint(x, shardings): flatten_axes("with_sharding_constraint shardings", tree, user_shardings)) del user_shardings + user_layouts_flat = tuple( + flatten_axes("with_sharding_constraint layouts", tree, layouts)) + del layouts + resource_env = mesh_lib.thread_resources.env mesh = resource_env.physical_mesh @@ -2154,19 +2518,27 @@ def with_sharding_constraint(x, shardings): shardings_flat, x_flat, None, "with_sharding_constraint arguments", allow_uneven_sharding=True) - outs = [sharding_constraint_p.bind(xf, sharding=to_gspmd_sharding(i, xf.ndim), + outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l, resource_env=resource_env, unconstrained_dims=ud) - for xf, i, ud in zip(x_flat, shardings_flat, unconstrained_dims)] + for xf, s, l, ud in zip(x_flat, shardings_flat, user_layouts_flat, + unconstrained_dims)] return tree_unflatten(tree, outs) def _identity_fn(x): return x -def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims): - if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim): - return x - # Run a jit here to raise good errors when device assignment don't match. - return api.jit(_identity_fn, out_shardings=sharding)(x) +def _sharding_constraint_impl(x, sharding, layout, resource_env, + unconstrained_dims): + if layout is None: + if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim): + return x + # Run a jit here to raise good errors when device assignment don't match. + return api.jit(_identity_fn, out_shardings=sharding)(x) + else: + if (hasattr(x, 'layout') and x.layout.device_local_layout == layout and + x.sharding.is_equivalent_to(sharding, x.ndim)): + return x + return api.jit(_identity_fn, out_shardings=Layout(layout, sharding))(x) sharding_constraint_p = core.Primitive("sharding_constraint") @@ -2175,7 +2547,7 @@ def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims): ad.deflinear2(sharding_constraint_p, lambda ct, _, **params: (sharding_constraint_p.bind(ct, **params),)) -def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, +def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout, resource_env, unconstrained_dims): aval, = ctx.avals_in out_aval, = ctx.avals_out @@ -2184,34 +2556,59 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, # NamedSharding. So update the NamedSharding to have the manual axes. 
if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): mesh = resource_env.physical_mesh - parsed_pspec = parse_flatten_op_sharding(sharding._hlo_sharding, mesh)[0] + if mesh.empty and isinstance(sharding, NamedSharding): + mesh = sharding.mesh + parsed_pspec = parse_flatten_op_sharding( + sharding._to_xla_hlo_sharding(aval.ndim), mesh)[0] sharding = NamedSharding._from_parsed_pspec( mesh, parsed_pspec, _manual_axes=axis_ctx.manual_axes) - return [ - mlir.wrap_with_sharding_op(ctx, - x_node, out_aval, - sharding._to_xla_hlo_sharding(aval.ndim).to_proto(), - unspecified_dims=unconstrained_dims) - ] + out = mlir.wrap_with_sharding_op( + ctx, x_node, out_aval, sharding._to_xla_hlo_sharding(aval.ndim).to_proto(), + unspecified_dims=unconstrained_dims) + if layout is not None: + out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, aval) + return [out] mlir.register_lowering(sharding_constraint_p, _sharding_constraint_hlo_lowering) -def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size, - axis_name, main_type, vals_in, dims_in, - sharding, resource_env, unconstrained_dims): +def _sharding_constraint_batcher( + insert_axis, spmd_axis_name, axis_size, axis_name, main_type, vals_in, + dims_in, sharding, layout, resource_env, unconstrained_dims): + if spmd_axis_name is not None and isinstance(sharding, NamedSharding): + used = {n for ns in sharding.spec + for n in (ns if isinstance(ns, tuple) else (ns,))} + if set(spmd_axis_name) & used: + raise ValueError(f"vmap spmd_axis_name {spmd_axis_name} cannot appear in " + "with_sharding_constraint spec, but got spec " + f"{sharding.spec}") x, = vals_in d, = dims_in # None means unconstrained in ParsedPartitionSpec new_parts = (axis_name,) if insert_axis else ( None if spmd_axis_name is None else spmd_axis_name) unconstrained_dims = {ud + (d <= ud) for ud in unconstrained_dims} + if new_parts is None: unconstrained_dims.add(d) + + vmapped_sharding = _pjit_batcher_for_sharding( + sharding, d, new_parts, resource_env.physical_mesh, x.ndim) + if unconstrained_dims and isinstance(vmapped_sharding, NamedSharding): + new_spec = list(vmapped_sharding.spec) + [None] * (x.ndim - len(vmapped_sharding.spec)) + for u in unconstrained_dims: + new_spec[u] = PartitionSpec.UNCONSTRAINED + vmapped_sharding = NamedSharding( + vmapped_sharding.mesh, PartitionSpec(*new_spec)) + + # TODO(yashkatariya): Figure out layouts should change under vmap. 
+ if layout is not None: + raise NotImplementedError + y = sharding_constraint_p.bind( x, - sharding=_pjit_batcher_for_sharding( - sharding, d, new_parts, resource_env.physical_mesh, x.ndim), + sharding=vmapped_sharding, + layout=layout, resource_env=resource_env, unconstrained_dims=unconstrained_dims) return y, d @@ -2226,55 +2623,29 @@ def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size, def _resource_typing_sharding_constraint(avals, params, source_info, resource_env, named_axis_resources): aval, = avals - if hasattr(params['sharding'], '_original_sharding'): - parsed_pspec = params['sharding']._original_sharding._parsed_pspec + parsed_pspec = None + if isinstance(params['sharding'], NamedSharding): + parsed_pspec = params['sharding']._parsed_pspec else: - parsed_pspec = parse_flatten_op_sharding( - params['sharding']._hlo_sharding, resource_env.physical_mesh)[0] - _check_resources_against_named_axes( - "with_sharding_constraint input", aval, parsed_pspec, named_axis_resources) + if not resource_env.physical_mesh.empty: + parsed_pspec = parse_flatten_op_sharding( + params['sharding']._to_xla_hlo_sharding(aval.ndim), + resource_env.physical_mesh)[0] + if parsed_pspec is not None: + _check_resources_against_named_axes( + "with_sharding_constraint input", aval, parsed_pspec, named_axis_resources) pxla.custom_resource_typing_rules[sharding_constraint_p] = \ _resource_typing_sharding_constraint # -------------------- helpers -------------------- -@lru_cache(maxsize=2048) -def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int, - device_or_backend_set: bool = False) -> GSPMDSharding: - if isinstance(s, GSPMDSharding): - return s - if xla_extension_version >= 234: - gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim), - memory_kind=s.memory_kind, - _device_list=getattr(s, '_internal_device_list', None)) - else: - gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim), - memory_kind=s.memory_kind) - gs._original_sharding = s - if device_or_backend_set: - gs._original_sharding._device_backend = device_or_backend_set - return gs - - def get_unconstrained_dims(sharding: NamedSharding): assert sharding._parsed_pspec is not None return {i for i, axes in enumerate(sharding._parsed_pspec) if axes is None} -def _fast_path_get_device_assignment( - shardings: Iterable[PjitSharding]) -> XLADeviceAssignment | None: - da = None - for i in shardings: - if is_unspecified(i): - continue - if is_auto(i): - return i.mesh._flat_devices_tuple # type: ignore - return i._device_assignment # type: ignore - return da - - def _get_partition_spec( ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]: return [get_single_pspec(p) for p in ppspec] diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index ec5c34ab846a..5c1e7e1198e8 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -31,7 +31,7 @@ import enum from functools import partial import sys -from typing import NamedTuple +from typing import Any, NamedTuple from jax._src import config from jax._src import util @@ -42,7 +42,7 @@ colorama = None -_PPRINT_USE_COLOR = config.DEFINE_bool( +_PPRINT_USE_COLOR = config.bool_flag( 'jax_pprint_use_color', config.bool_env('JAX_PPRINT_USE_COLOR', True), help='Enable jaxpr pretty-printing with colorful syntax highlighting.' 
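# Illustrative sketch, not part of the patch: the user-level entry point that
# the with_sharding_constraint / sharding_constraint_p changes above serve.
# The mesh axis name "x" and the function f are made up for illustration; the
# layout-carrying form mentioned in the comment is the new path added by this
# patch and its exact public spelling is an assumption.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:1], ("x",))  # a trivial one-device mesh for the sketch

@jax.jit
def f(a):
    # Constrain the intermediate to be sharded along mesh axis "x". With this
    # patch, the same call can also carry a layout alongside the sharding, in
    # which case the lowering additionally wraps the value with a layout op.
    a = jax.lax.with_sharding_constraint(a, NamedSharding(mesh, P("x")))
    return a * 2

print(f(jnp.arange(8.0)))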
@@ -69,12 +69,23 @@ def _can_use_color() -> bool: class Doc(util.StrictABC): __slots__ = () - def format(self, width: int = 80, use_color: bool | None = None, - annotation_prefix=" # ") -> str: + def format( + self, width: int = 80, *, use_color: bool | None = None, + annotation_prefix: str = " # ", + source_map: list[list[tuple[int, int, Any]]] | None = None + ) -> str: + """ + Formats a pretty-printer document as a string. + + Args: + source_map: for each line in the output, contains a list of + (start column, end column, source) tuples. Each tuple associates a + region of output text with a source. + """ if use_color is None: use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value return _format(self, width, use_color=use_color, - annotation_prefix=annotation_prefix) + annotation_prefix=annotation_prefix, source_map=source_map) def __str__(self): return self.format() @@ -147,6 +158,21 @@ def __init__(self, n: int, child: Doc): def __repr__(self): return f"nest({self.n, self.child})" +_NO_SOURCE = object() + +class _SourceMapDoc(Doc): + __slots__ = ("child", "source") + child: Doc + source: Any + + def __init__(self, child: Doc, source: Any): + assert isinstance(child, Doc), child + self.child = child + self.source = source + + def __repr__(self): return f"source({self.child}, {self.source})" + + Color = enum.Enum("_Color", ["BLACK", "RED", "GREEN", "YELLOW", "BLUE", "MAGENTA", "CYAN", "WHITE", "RESET"]) Intensity = enum.Enum("_Intensity", ["DIM", "NORMAL", "BRIGHT"]) @@ -193,7 +219,7 @@ def _fits(doc: Doc, width: int, agenda: list[tuple[int, _BreakMode, Doc]] agenda.append((i + doc.n, m, doc.child)) elif isinstance(doc, _GroupDoc): agenda.append((i, _BreakMode.FLAT, doc.child)) - elif isinstance(doc, _ColorDoc): + elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): agenda.append((i, m, doc.child)) else: raise ValueError("Invalid document ", doc) @@ -224,7 +250,7 @@ def _sparse(doc: Doc) -> bool: agenda.append(doc.child) elif isinstance(doc, _GroupDoc): agenda.append(doc.child) - elif isinstance(doc, _ColorDoc): + elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): agenda.append(doc.child) else: raise ValueError("Invalid document ", doc) @@ -241,6 +267,7 @@ class _State(NamedTuple): mode: _BreakMode doc: Doc color: _ColorState + source_map: Any class _Line(NamedTuple): text: str @@ -283,17 +310,29 @@ def _align_annotations(lines): -def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str: +def _format( + doc: Doc, width: int, *, use_color: bool, annotation_prefix: str, + source_map: list[list[tuple[int, int, Any]]] | None +) -> str: lines = [] default_colors = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL) annotation_colors = _ColorState(Color.RESET, Color.RESET, Intensity.DIM) color_state = default_colors - agenda = [_State(0, _BreakMode.BREAK, doc, default_colors)] + source_start = 0 # The column at which the current source region starts. + source = _NO_SOURCE # The currently active source region. + line_source_map = [] # Source maps for the current line of text. 
+ agenda = [_State(0, _BreakMode.BREAK, doc, default_colors, source)] k = 0 line_text = "" line_annotations = [] while len(agenda) > 0: - i, m, doc, color = agenda.pop() + i, m, doc, color, agenda_source = agenda.pop() + if source_map is not None and agenda_source != source: + pos = len(line_text) + if source_start != pos and source is not _NO_SOURCE: + line_source_map.append((source_start, pos, source)) + source = agenda_source + source_start = pos if isinstance(doc, _NilDoc): pass elif isinstance(doc, _TextDoc): @@ -304,7 +343,7 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str: line_annotations.append(doc.annotation) k += len(doc.text) elif isinstance(doc, _ConcatDoc): - agenda.extend(_State(i, m, d, color) + agenda.extend(_State(i, m, d, color, source) for d in reversed(doc.children)) elif isinstance(doc, _BreakDoc): if m == _BreakMode.BREAK: @@ -313,6 +352,13 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str: annotation_colors) line_text += color_str lines.append(_Line(line_text, k, line_annotations)) + if source_map is not None: + pos = len(line_text) + if source_start != pos and source is not _NO_SOURCE: + line_source_map.append((source_start, pos, source)) + source_map.append(line_source_map) + line_source_map = [] + source_start = i line_text = " " * i line_annotations = [] k = i @@ -322,20 +368,22 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str: line_text += doc.text k += len(doc.text) elif isinstance(doc, _NestDoc): - agenda.append(_State(i + doc.n, m, doc.child, color)) + agenda.append(_State(i + doc.n, m, doc.child, color, source)) elif isinstance(doc, _GroupDoc): # In Lindig's paper, _fits is passed the remainder of the document. # I'm pretty sure that's a bug and we care only if the current group fits! if (_sparse(doc) and _fits(doc, width - k, [(i, _BreakMode.FLAT, doc.child)])): - agenda.append(_State(i, _BreakMode.FLAT, doc.child, color)) + agenda.append(_State(i, _BreakMode.FLAT, doc.child, color, source)) else: - agenda.append(_State(i, _BreakMode.BREAK, doc.child, color)) + agenda.append(_State(i, _BreakMode.BREAK, doc.child, color, source)) elif isinstance(doc, _ColorDoc): color = _ColorState(doc.foreground or color.foreground, doc.background or color.background, doc.intensity or color.intensity) - agenda.append(_State(i, m, doc.child, color)) + agenda.append(_State(i, m, doc.child, color, source)) + elif isinstance(doc, _SourceMapDoc): + agenda.append(_State(i, m, doc.child, color, doc.source)) else: raise ValueError("Invalid document ", doc) @@ -343,6 +391,11 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str: color_state, color_str = _update_color(use_color, color_state, annotation_colors) line_text += color_str + if source_map is not None: + pos = len(line_text) + if source_start != pos and source is not _NO_SOURCE: + line_source_map.append((source_start, pos, source)) + source_map.append(line_source_map) lines.append(_Line(line_text, k, line_annotations)) lines = _align_annotations(lines) out = "\n".join( @@ -406,6 +459,17 @@ def color(doc: Doc, *, foreground: Color | None = None, intensity=intensity) +def source_map(doc: Doc, source: Any): + """Source mapping. + + A source map associates a region of the pretty-printer's text output with a + source location that produced it. For the purposes of the pretty printer a + ``source`` may be any object: we require only that we can compare sources for + equality. 
A text region to source object mapping can be populated as a side + output of the ``format`` method. + """ + return _SourceMapDoc(doc, source) + type_annotation = partial(color, intensity=Intensity.NORMAL, foreground=Color.MAGENTA) keyword = partial(color, intensity=Intensity.BRIGHT, foreground=Color.BLUE) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 7f125bd44741..18ae281629b6 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -13,11 +13,11 @@ # limitations under the License. from __future__ import annotations -from collections.abc import Iterator, Sequence +from collections.abc import Callable, Iterator, Sequence from functools import partial, reduce import math import operator as op -from typing import Any, Callable, NamedTuple +from typing import Any, NamedTuple import numpy as np @@ -33,10 +33,9 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import pretty_printer as pp -from jax._src import sharding_specs +from jax._src import source_info_util from jax._src import tree_util as tree_util_internal from jax._src import typing -from jax._src import op_shardings from jax._src.api import jit, vmap from jax._src.dtypes import float0 from jax._src.interpreters import ad @@ -46,15 +45,15 @@ from jax._src.interpreters import xla from jax._src.lax import lax as lax_internal from jax._src.lax import utils as lax_utils -from jax._src.lib.mlir import ir from jax._src.lib import gpu_prng from jax._src.lib import xla_client as xc +from jax._src.lib import version as jaxlib_version +from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.array_methods import ( _array_operators, _set_array_base_attributes, _IndexUpdateHelper) -from jax._src.partition_spec import PartitionSpec from jax._src.sharding_impls import ( - NamedSharding, PmapSharding, GSPMDSharding, XLACompatibleSharding) + NamedSharding, PmapSharding, physical_sharding, logical_sharding) from jax._src.typing import Array from jax._src.util import safe_map, safe_zip @@ -66,7 +65,7 @@ Shape = tuple[int, ...] UINT_DTYPES = { - 8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} # type: ignore[has-type] + 8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} # -- PRNG implementation interface @@ -154,6 +153,7 @@ class behave like an array whose base elements are keys, hiding the _impl: PRNGImpl _base_array: typing.Array _consumed: bool | np.ndarray # Used in jax.experimental.key_reuse. 
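# Illustration (a minimal sketch, not part of this file) of the source_map API
# added to jax/_src/pretty_printer.py above: passing a list as `source_map` to
# `format` fills it with one list of (start column, end column, source) tuples
# per output line.
#
#   import jax._src.pretty_printer as pp
#   doc = pp.concat([pp.text("x = "), pp.source_map(pp.text("sin(y)"), "eqn0")])
#   spans: list = []
#   text = doc.format(source_map=spans)
#   # text == "x = sin(y)"; spans == [[(4, 10, "eqn0")]]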
+ _source_info: None | source_info_util.SourceInfo = None def __init__(self, impl, key_data: Any): assert not isinstance(key_data, core.Tracer) @@ -195,7 +195,7 @@ def itemsize(self): _device = property(op.attrgetter('_base_array._device')) _committed = property(op.attrgetter('_base_array._committed')) - device = property(op.attrgetter('_base_array.device')) # type: ignore[assignment] + device = property(op.attrgetter('_base_array.device')) devices = property(op.attrgetter('_base_array.devices')) # type: ignore[assignment] is_fully_addressable = property(op.attrgetter('_base_array.is_fully_addressable')) # type: ignore[assignment] is_fully_replicated = property(op.attrgetter('_base_array.is_fully_replicated')) # type: ignore[assignment] @@ -233,8 +233,7 @@ def global_shards(self) -> list[Shard]: @property def sharding(self): - phys_sharding = self._base_array.sharding - return KeyTyRules.logical_sharding(self.aval, phys_sharding) + return logical_sharding(self.aval, self._base_array.sharding) def _is_scalar(self): base_ndim = len(self._impl.key_shape) @@ -322,44 +321,6 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape): base_ndim = len(impl.key_shape) return base_arr_shape[:-base_ndim] -def make_key_array_phys_sharding(aval, sharding): - if dispatch.is_single_device_sharding(sharding): - return sharding - elif isinstance(sharding, PmapSharding): - key_shape = aval.dtype._impl.key_shape - trailing_sharding = [sharding_specs.NoSharding()] * len(key_shape) - phys_sharding_spec = sharding_specs.ShardingSpec( - sharding=(*sharding.sharding_spec.sharding, *trailing_sharding), - mesh_mapping=sharding.sharding_spec.mesh_mapping) - return PmapSharding(devices=sharding.devices, - sharding_spec=phys_sharding_spec) - elif isinstance(sharding, NamedSharding): - key_shape = aval.dtype._impl.key_shape - trailing_spec = [None] * len(key_shape) - return NamedSharding( - sharding.mesh, - PartitionSpec(*sharding.spec, *trailing_spec)) - else: - hlos = sharding._to_xla_hlo_sharding(aval.ndim) - return GSPMDSharding( - sharding._device_assignment, - KeyTyRules.physical_hlo_sharding(aval, hlos)) - - -def get_logical_gspmd_sharding(aval, phys_sharding): - key_shape = aval.dtype._impl.key_shape - phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding( - aval.ndim + len(key_shape)) - partitions, num_replicas = op_shardings.get_num_ways_dim_sharded( - phys_hlo_sharding) - suffix = [] if num_replicas == 1 else [num_replicas] - # Create logical sharding by cutting off the replicated trailing dims. 
- logical_op_sharding = phys_hlo_sharding.to_proto().clone() - tad = partitions[:-len(key_shape)] + suffix - logical_op_sharding.tile_assignment_dimensions = tad - return GSPMDSharding(phys_sharding._device_assignment, - xc.HloSharding.from_proto(logical_op_sharding)) - class KeyTyRules: @@ -382,43 +343,6 @@ def physical_element_aval(dtype) -> core.ShapedArray: def physical_const(val) -> Array: return val._base_array - @staticmethod - def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: - key_shape = aval.dtype._impl.key_shape - new_op_sharding = hlo_sharding.to_proto().clone() # type: ignore - partitions, num_replicas = op_shardings.get_num_ways_dim_sharded( - hlo_sharding) - suffix = [] if num_replicas == 1 else [num_replicas] - tad = partitions + [1] * len(key_shape) + suffix - new_op_sharding.tile_assignment_dimensions = tad - return xc.HloSharding.from_proto(new_op_sharding) - - @staticmethod - def physical_sharding( - aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding: - return make_key_array_phys_sharding(aval, sharding) - - @staticmethod - def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding: - # The trailing dims should always be replicated. - aval.dtype._rules.check_replicated_trailing_dims(phys_sharding, aval) - - if dispatch.is_single_device_sharding(phys_sharding): - return phys_sharding - elif isinstance(phys_sharding, PmapSharding): - key_shape = aval.dtype._impl.key_shape - logical_sharding_spec = sharding_specs.ShardingSpec( - sharding=phys_sharding.sharding_spec.sharding[:-len(key_shape)], - mesh_mapping=phys_sharding.sharding_spec.mesh_mapping) - return PmapSharding(devices=phys_sharding.devices, - sharding_spec=logical_sharding_spec) - elif isinstance(phys_sharding, NamedSharding): - logical_gs = get_logical_gspmd_sharding(aval, phys_sharding) - return pxla._gspmd_to_named_sharding_via_mesh( - logical_gs, phys_sharding.mesh) - else: - return get_logical_gspmd_sharding(aval, phys_sharding) - @staticmethod def result_handler(sticky_device, aval): def handler(_, buf): @@ -434,7 +358,7 @@ def local_sharded_result_handler(aval, sharding, indices): # set up a grounded sharding (with a grounded sharding spec) if isinstance(sharding, (PmapSharding, NamedSharding)): - phys_sharding = make_key_array_phys_sharding(aval, sharding) + phys_sharding = physical_sharding(aval, sharding) else: assert False, f'impossible sharding {sharding} in local sharded result handler' @@ -456,7 +380,7 @@ def global_sharded_result_handler(aval, out_sharding, committed): phys_aval = core.physical_aval(aval) phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] - phys_sharding = make_key_array_phys_sharding(aval, out_sharding) + phys_sharding = physical_sharding(aval, out_sharding) phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) def handler(bufs): return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs)) @@ -468,7 +392,7 @@ def make_sharded_array(aval, sharding, arrays, committed): phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] phys_arrays = [random_unwrap(arr) for arr in arrays] - phys_sharding = make_key_array_phys_sharding(aval, sharding) + phys_sharding = physical_sharding(aval, sharding) phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) phys_result = phys_handler(phys_arrays) return PRNGKeyArray(aval.dtype._impl, phys_result) @@ -477,8 +401,9 @@ def make_sharded_array(aval, sharding, arrays, committed): def device_put_sharded(vals, aval, sharding, devices): 
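    # Key arrays store their data in a uint32 base array with trailing
    # key-shape dimensions, so device placement unwraps to the base arrays,
    # computes the matching physical sharding, places the physical data, and
    # re-wraps the result as a key array.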
physical_aval = core.physical_aval(aval) physical_buffers = tree_util.tree_map(random_unwrap, vals) - physical_sharding = make_key_array_phys_sharding(aval, sharding) - physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices)) + phys_sharding = physical_sharding(aval, sharding) + physical_result = pxla.batched_device_put(physical_aval, phys_sharding, + physical_buffers, list(devices)) return random_wrap(physical_result, impl=aval.dtype._impl) @staticmethod @@ -486,37 +411,11 @@ def device_put_replicated(val, aval, sharding, devices): physical_aval = core.physical_aval(aval) assert len(xla.aval_to_xla_shapes(physical_aval)) == 1 physical_buf = random_unwrap(val) - physical_sharding = make_key_array_phys_sharding(aval, sharding) - physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices) + phys_sharding = physical_sharding(aval, sharding) + physical_result = pxla.batched_device_put( + physical_aval, phys_sharding, [physical_buf] * len(devices), devices) return random_wrap(physical_result, impl=aval.dtype._impl) - @staticmethod - def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval): - if isinstance(sharding, PmapSharding): - return - phys_aval = core.physical_aval(aval) - hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim) - partitions, _ = op_shardings.get_num_ways_dim_sharded(hlo_s) - num_trailing_dims = phys_aval.ndim - aval.ndim - if not all(i == 1 for i in partitions[-num_trailing_dims:]): - raise AssertionError( - "The trailing dims of extended dtypes should be replicated. Got" - f" sharding: {sharding}, partitions: {partitions}, " - f"num_trailing_dims: {num_trailing_dims}") - - @staticmethod - def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: - # Set the sharding of extended dtypes to be UNCONSTRAINED - # (i.e. XLA will choose) on aval.shape. - # For the trailing dims i.e. the dimension of key_shape on the base_array, - # the sharding is set to be REPLICATED always. - # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2), - # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None). - # The below custom call achieves the sharding like above example. 
- return mlir.wrap_with_sharding_op( - ctx, val, aval, xc.HloSharding.replicate().to_proto(), - unspecified_dims=set(range(aval.ndim))) - @staticmethod def tangent_dtype(_): return dtypes.float0 @@ -569,10 +468,11 @@ def __hash__(self) -> int: xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x -def key_array_shard_arg_handler(x: PRNGKeyArray, sharding): - arr = x._base_array - phys_sharding = make_key_array_phys_sharding(x.aval, sharding) - return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding) +def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings): + arrs = [x._base_array for x in xs] + phys_shardings = [physical_sharding(x.aval, sharding) + for x, sharding in zip(xs, shardings)] + return pxla.shard_args(phys_shardings, arrs) pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler @@ -1003,7 +903,29 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True): return tuple(x) -def _threefry2x32_gpu_lowering(lowering_func, ctx, k1, k2, x1, x2): +_threefry2x32_lowering_rule = mlir.lower_fun( + partial(_threefry2x32_lowering, use_rolled_loops=False), + multiple_results=True) + +_threefry2x32_cpu_lowering_rule = mlir.lower_fun( + partial(_threefry2x32_lowering, use_rolled_loops=True), + multiple_results=True) + + +def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2): + if not config.threefry_gpu_kernel_lowering.value: # back to default lowering + return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2) + + # TODO(b/338022728): when we export, use the old custom call target for now. + # Make forward_compatibility_mode False after 3 weeks. + # TODO(b/350111820): figure out why we cannot use the new cu_threefry2x32_ffi + # in Kokoro tests. For now, use the old cu_threefry2x32. + # lowering_parameters = ctx.module_context.lowering_parameters + # forward_compatibility_mode = ( + # lowering_parameters.for_export and + # not lowering_parameters.export_ignore_forward_compatibility) + forward_compatibility_mode = True + aval_out, aval_out_2 = ctx.avals_out assert aval_out == aval_out_2 k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in @@ -1026,29 +948,34 @@ def _broadcast(x, aval): length = int(out_len) # will be passed statically output_shape = None - return lowering_func( - (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)), - (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length, - output_shape) + if jaxlib_version >= (0, 4, 31): + return lowering_func( + (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)), + (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length, + output_shape, + forward_compatibility_mode) + else: + return lowering_func( + (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)), + (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length, + output_shape) threefry2x32_p = core.Primitive("threefry2x32") threefry2x32_p.multiple_results = True threefry2x32_p.def_impl(partial(dispatch.apply_primitive, threefry2x32_p)) threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval) batching.defbroadcasting(threefry2x32_p) -mlir.register_lowering(threefry2x32_p, mlir.lower_fun( - partial(_threefry2x32_lowering, use_rolled_loops=False), - multiple_results=True)) -mlir.register_lowering(threefry2x32_p, mlir.lower_fun( - partial(_threefry2x32_lowering, use_rolled_loops=True), - multiple_results=True), platform='cpu') +mlir.register_lowering( + threefry2x32_p, _threefry2x32_lowering_rule) +mlir.register_lowering( + threefry2x32_p, _threefry2x32_cpu_lowering_rule, platform='cpu') mlir.register_lowering( 
threefry2x32_p, - partial(_threefry2x32_gpu_lowering, gpu_prng.cuda_threefry2x32), + partial(_threefry2x32_gpu_lowering_rule, gpu_prng.cuda_threefry2x32), platform='cuda') mlir.register_lowering( threefry2x32_p, - partial(_threefry2x32_gpu_lowering, gpu_prng.rocm_threefry2x32), + partial(_threefry2x32_gpu_lowering_rule, gpu_prng.rocm_threefry2x32), platform='rocm') @@ -1116,8 +1043,8 @@ def bcast_iotas_to_reshaped_iota( mul: Callable[[core.DimSize, ir.Value], ir.Value], shape: core.Shape, iotas: Sequence[ir.Value]) -> ir.Value: - strides: core.Shape = (*(np.cumprod(shape[1:][::-1])[::-1]), 1) # type: ignore - return reduce(add, [mul(s, i) for i, s in zip(iotas, strides)]) # type: ignore + strides: core.Shape = (*(np.cumprod(shape[1:][::-1])[::-1]), 1) + return reduce(add, [mul(s, i) for i, s in zip(iotas, strides)]) def iota_2x32_shape_lowering(ctx, *, shape): aval_out, _ = ctx.avals_out @@ -1130,11 +1057,11 @@ def _mul(x: core.DimSize, y: ir.Value) -> ir.Value: if core.is_constant_dim(x): x_const = mlir.ir_constant(np.array(x, np.dtype('uint64'))) else: - x_const, = mlir.eval_dynamic_shape(ctx, (x,)) + x_shape, = mlir.eval_dynamic_shape(ctx, (x,)) x_const = hlo.convert( ir.RankedTensorType.get( - (), - mlir.dtype_to_ir_type(np.dtype('uint64'))), x_const) + [], + mlir.dtype_to_ir_type(np.dtype('uint64'))), x_shape) x_bcast = mlir.broadcast_in_dim(ctx, x_const, aval_u64, broadcast_dimensions=[]) return mlir.hlo.multiply(x_bcast, y) @@ -1194,9 +1121,9 @@ def threefry_split(key: typing.Array, shape: Shape) -> typing.Array: @partial(jit, static_argnums=(1,)) def _threefry_split(key, shape) -> typing.Array: if config.threefry_partitionable.value: - return _threefry_split_foldlike(key, shape) # type: ignore + return _threefry_split_foldlike(key, shape) else: - return _threefry_split_original(key, shape) # type: ignore + return _threefry_split_original(key, shape) @partial(jit, static_argnums=(1,), inline=True) def _threefry_split_original(key, shape) -> typing.Array: @@ -1377,26 +1304,3 @@ def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array: tag='urbg') register_prng(unsafe_rbg_prng_impl) - - -# Primitives related to key reuse -reuse_key_p = core.Primitive("reuse_key") -reuse_key_p.def_impl(lambda x: x) -reuse_key_p.def_abstract_eval(lambda x: x) -batching.defvectorized(reuse_key_p) -mlir.register_lowering(reuse_key_p, lambda _, k: [k]) - -def reuse_key(key): - """Explicitly mark a key as unconsumed. - - Outside the context of key reuse checking (see :mod:`jax.experimental.key_reuse`) - this function operates as an identity. - - Example: - - >>> import jax - >>> key = jax.random.key(0) - >>> data = jax.random.uniform(key) - >>> same_data = jax.random.uniform(reuse_key(key)) - """ - return reuse_key_p.bind(key) diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 0b013b786dd5..cad4826ba801 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -14,6 +14,7 @@ from __future__ import annotations +from collections.abc import Callable from contextlib import contextmanager from functools import wraps import glob @@ -24,8 +25,7 @@ import os import socketserver import threading - -from typing import Callable +from typing import Any from jax._src import traceback_util traceback_util.register_exclusion(__file__) @@ -99,7 +99,7 @@ def start_trace(log_dir, create_perfetto_link: bool = False, The resulting trace can be viewed with TensorBoard. Note that TensorBoard doesn't need to be running when collecting the trace. - Only once trace may be collected a time. 
A RuntimeError will be raised if + Only one trace may be collected at a time. A RuntimeError will be raised if :func:`start_trace` is called while another trace is running. Args: @@ -211,7 +211,7 @@ def stop_trace(): _profile_state.reset() -def stop_and_get_fdo_profile() -> bytes: +def stop_and_get_fdo_profile() -> bytes | str: """Stops the currently-running profiler trace and export fdo_profile. Currently, this is only supported for GPU. @@ -236,7 +236,7 @@ def trace(log_dir, create_perfetto_link=False, create_perfetto_trace=False): The resulting trace can be viewed with TensorBoard. Note that TensorBoard doesn't need to be running when collecting the trace. - Only once trace may be collected a time. A RuntimeError will be raised if a + Only one trace may be collected at a time. A RuntimeError will be raised if a trace is started while another trace is running. Args: @@ -381,3 +381,62 @@ def save_device_memory_profile(filename, backend: str | None = None) -> None: profile = device_memory_profile(backend) with open(filename, "wb") as f: f.write(profile) + + +# Allows to run model with profiler given amount of times. After required amount +# of retries achived client can collect FDO data. +class PGLEProfiler: + + def __init__(self, retries: int, percentile: int): + self.retries: int = retries + self.percentile: int = percentile + self.collected_fdo: str | None = None + self.called_times: int = 0 + self.fdo_profiles: list[Any] = [] + self.current_session: xla_client.profiler.ProfilerSession | None = None + + def consume_fdo_profile(self) -> str | None: + if self.collected_fdo is not None: + return self.collected_fdo + + if not self.is_enabled() or self.called_times != self.retries: + return None + + self.collected_fdo = xla_client.profiler.aggregate_profiled_instructions( + self.fdo_profiles, self.percentile + ) + return self.collected_fdo + + def is_fdo_consumed(self): + return self.collected_fdo is not None + + def disable(self): + self.retries = 0 + + def is_enabled(self): + return self.retries > 0 + + def is_running(self): + return self.current_session is not None + + @classmethod + @contextmanager + def trace(cls, runner: PGLEProfiler | None): + if (runner is None or runner.is_running() + or not runner.is_enabled() or runner.is_fdo_consumed()): + yield + else: + options = xla_client.profiler.ProfileOptions() + options.enable_hlo_proto = True + runner.current_session = xla_client.profiler.ProfilerSession(options) + + try: + yield + finally: + xspace = runner.current_session.stop() + runner.fdo_profiles.append( + xla_client.profiler.get_fdo_profile(xspace) + ) + runner.current_session = None + + runner.called_times += 1 diff --git a/jax/_src/random.py b/jax/_src/random.py index f9045ebf3135..6a0a3c0f9932 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -30,6 +30,7 @@ from jax._src import config from jax._src import core +from jax._src import dispatch from jax._src import dtypes from jax._src import prng from jax._src import xla_bridge @@ -69,11 +70,8 @@ def _isnan(x: ArrayLike) -> Array: return lax.ne(x, x) -# TODO(jakevdp) Finalize batched input deprecation by setting error_on_batched=True. 
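# Sketch of how the PGLEProfiler added to jax/_src/profiler.py above appears
# intended to be used (an illustration only; `run_step` is a hypothetical
# stand-in for the profiled computation):
#
#   profiler = PGLEProfiler(retries=3, percentile=90)
#   for _ in range(3):
#     with PGLEProfiler.trace(profiler):
#       run_step()  # each profiled run appends an FDO profile
#   fdo = profiler.consume_fdo_profile()  # aggregated result once all retries ran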
-# FutureWarning Added 2024-01-17 def _check_prng_key(name: str, key: KeyArrayLike, *, - allow_batched: bool = False, - error_on_batched: bool = False) -> tuple[KeyArray, bool]: + allow_batched: bool = False) -> tuple[KeyArray, bool]: if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key): wrapped_key = key wrapped = False @@ -101,13 +99,8 @@ def _check_prng_key(name: str, key: KeyArrayLike, *, raise TypeError(f'unexpected PRNG key type {type(key)}') if (not allow_batched) and wrapped_key.ndim: - msg = (f"{name} accepts a single key, but was given a key array of " - f"shape {np.shape(key)} != (). Use jax.vmap for batching.") - if error_on_batched: - raise ValueError(msg) - else: - warnings.warn(msg + " In a future JAX version, this will be an error.", - FutureWarning, stacklevel=3) + raise ValueError(f"{name} accepts a single key, but was given a key array of" + f" shape {np.shape(key)} != (). Use jax.vmap for batching.") return wrapped_key, wrapped @@ -244,14 +237,14 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray: """Folds in data to a PRNG key to form a new PRNG key. Args: - key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``). - data: a 32bit integer representing data to be folded in to the key. + key: a PRNG key (from ``key``, ``split``, ``fold_in``). + data: a 32-bit integer representing data to be folded into the key. Returns: A new PRNG key that is a deterministic function of the inputs and is statistically safe for producing a stream of new pseudo-random values. """ - key, wrapped = _check_prng_key("fold_in", key, error_on_batched=True) + key, wrapped = _check_prng_key("fold_in", key) if np.ndim(data): raise TypeError("fold_in accepts a scalar, but was given an array of" f"shape {np.shape(data)} != (). Use jax.vmap for batching.") @@ -274,14 +267,14 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: """Splits a PRNG key into `num` new keys by adding a leading axis. Args: - key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``). + key: a PRNG key (from ``key``, ``split``, ``fold_in``). num: optional, a positive integer (or tuple of integers) indicating the number (or shape) of keys to produce. Defaults to 2. Returns: An array-like object of `num` new PRNG keys. """ - typed_key, wrapped = _check_prng_key("split", key, error_on_batched=True) + typed_key, wrapped = _check_prng_key("split", key) return _return_prng_keys(wrapped, _split(typed_key, num)) @@ -394,7 +387,7 @@ def uniform(key: KeyArrayLike, f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) shape = core.as_named_shape(shape) - return _uniform(key, shape, dtype, minval, maxval) # type: ignore + return _uniform(key, shape, dtype, minval, maxval) @partial(jit, static_argnums=(1, 2)) def _uniform(key, shape, dtype, minval, maxval) -> Array: @@ -410,8 +403,10 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array: finfo = jnp.finfo(dtype) nbits, nmant = finfo.bits, finfo.nmant - if nbits not in (16, 32, 64): - raise TypeError(f"uniform only accepts 16-, 32-, or 64-bit dtypes, got {dtype}.") + if nbits not in (8, 16, 32, 64): + raise TypeError( + f"uniform only accepts 8-, 16-, 32-, or 64-bit dtypesgot {dtype}." 
+ ) rng_bits = nbits if nmant < 8: @@ -539,7 +534,7 @@ def shuffle(key: KeyArrayLike, x: ArrayLike, axis: int = 0) -> Array: "Use jax.random.permutation with independent=True.") warnings.warn(msg, FutureWarning) key, _ = _check_prng_key("shuffle", key) - return _shuffle(key, x, axis) # type: ignore + return _shuffle(key, x, axis) def permutation(key: KeyArrayLike, @@ -623,7 +618,7 @@ def choice(key: KeyArrayLike, e.g., ``(m, n)``, then ``m * n`` samples are drawn. Default is (), in which case a single value is returned. replace : boolean. Whether the sample is with or without replacement. - default is True. + Default is True. p : 1-D array-like, The probabilities associated with each entry in a. If not given the sample assumes a uniform distribution over all entries in a. @@ -707,7 +702,7 @@ def normal(key: KeyArrayLike, f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) shape = core.as_named_shape(shape) - return _normal(key, shape, dtype) # type: ignore + return _normal(key, shape, dtype) @partial(jit, static_argnums=(1, 2)) def _normal(key, shape, dtype) -> Array: @@ -720,14 +715,14 @@ def _normal(key, shape, dtype) -> Array: _im = _normal_real(key_im, shape, real_dtype).astype(dtype) return (_re + 1j * _im) / sqrt2 else: - return _normal_real(key, shape, dtype) # type: ignore + return _normal_real(key, shape, dtype) @partial(jit, static_argnums=(1, 2)) def _normal_real(key, shape, dtype) -> Array: _check_shape("normal", shape) lo = np.nextafter(np.array(-1., dtype), np.array(0., dtype), dtype=dtype) hi = np.array(1., dtype) - u = uniform(key, shape, dtype, lo, hi) # type: ignore[arg-type] + u = uniform(key, shape, dtype, lo, hi) return lax.mul(np.array(np.sqrt(2), dtype), lax.erf_inv(u)) @@ -779,7 +774,7 @@ def multivariate_normal(key: KeyArrayLike, f"dtype, got {dtype}") if shape is not None: shape = core.canonicalize_shape(shape) - return _multivariate_normal(key, mean, cov, shape, dtype, method) # type: ignore + return _multivariate_normal(key, mean, cov, shape, dtype, method) @partial(jit, static_argnums=(3, 4, 5)) def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array: @@ -854,7 +849,7 @@ def truncated_normal(key: KeyArrayLike, dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.as_named_shape(shape) - return _truncated_normal(key, lower, upper, shape, dtype) # type: ignore + return _truncated_normal(key, lower, upper, shape, dtype) @partial(jit, static_argnums=(3, 4)) def _truncated_normal(key, lower, upper, shape, dtype) -> Array: @@ -912,7 +907,7 @@ def bernoulli(key: KeyArrayLike, msg = "bernoulli probability `p` must have a floating dtype, got {}." raise TypeError(msg.format(dtype)) p = lax.convert_element_type(p, dtype) - return _bernoulli(key, p, shape) # type: ignore + return _bernoulli(key, p, shape) @partial(jit, static_argnums=(2,)) def _bernoulli(key, p, shape) -> Array: @@ -1034,7 +1029,7 @@ def dirichlet(key: KeyArrayLike, The values are distributed according to the probability density function: .. 
math:: - f(\{x_i\}; \{\alpha_i\}) = \propto \prod_{i=1}^k x_i^{\alpha_i - 1} + f(\{x_i\}; \{\alpha_i\}) \propto \prod_{i=1}^k x_i^{\alpha_i - 1} Where :math:`k` is the dimension, and :math:`\{x_i\}` satisfies @@ -1233,7 +1228,7 @@ def _gamma_impl(key, a, *, log_space, use_vmap=False): keys = keys.flatten() alphas = a.flatten() - if use_vmap: + if use_vmap and _key_impl(key) is prng.threefry_prng_impl: samples = vmap(partial(_gamma_one, log_space=log_space))(keys, alphas) else: samples = lax.map( @@ -1759,7 +1754,7 @@ def chisquare(key: KeyArrayLike, The values are distributed according to the probability density function: .. math:: - f(x; \nu) \propto x^{k/2 - 1}e^{-x/2} + f(x; \nu) \propto x^{\nu/2 - 1}e^{-x/2} on the domain :math:`0 < x < \infty`, where :math:`\nu > 0` represents the degrees of freedom, given by the parameter ``df``. @@ -1812,7 +1807,7 @@ def f(key: KeyArrayLike, The values are distributed according to the probability density function: .. math:: - f(x; \nu) \propto x^{\nu_1/2 - 1}\left(1 + \frac{\nu_1}{\nu_2}x\right)^{ + f(x; \nu_1, \nu_2) \propto x^{\nu_1/2 - 1}\left(1 + \frac{\nu_1}{\nu_2}x\right)^{ -(\nu_1 + \nu_2) / 2} on the domain :math:`0 < x < \infty`. Here :math:`\nu_1` is the degrees of @@ -1867,7 +1862,7 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array: def rademacher(key: KeyArrayLike, - shape: Shape, + shape: Shape = (), dtype: DTypeLikeInt = int) -> Array: r"""Sample from a Rademacher distribution. @@ -1876,11 +1871,11 @@ def rademacher(key: KeyArrayLike, .. math:: f(k) = \frac{1}{2}(\delta(k - 1) + \delta(k + 1)) - on the domain :math:`k \in \{-1, 1}`, where `\delta(x)` is the dirac delta function. + on the domain :math:`k \in \{-1, 1\}`, where :math:`\delta(x)` is the dirac delta function. Args: key: a PRNG key. - shape: The shape of the returned samples. + shape: The shape of the returned samples. Default (). dtype: The type used for samples. Returns: @@ -2144,7 +2139,7 @@ def rayleigh(key: KeyArrayLike, .. math:: f(x;\sigma) \propto xe^{-x^2/(2\sigma^2)} - on the domain :math:`-\infty < x < \infty`, and where `\sigma > 0` is the scale + on the domain :math:`-\infty < x < \infty`, and where :math:`\sigma > 0` is the scale parameter of the distribution. Args: @@ -2361,7 +2356,6 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array: return tri - def lognormal(key: KeyArrayLike, sigma: RealArray = np.float32(1), shape: Shape | None = None, @@ -2395,7 +2389,7 @@ def lognormal(key: KeyArrayLike, dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) - return _lognormal(key, sigma, shape, dtype) # type: ignore + return _lognormal(key, sigma, shape, dtype) @partial(jit, static_argnums=(2, 3), inline=True) def _lognormal(key, sigma, shape, dtype) -> Array: @@ -2611,3 +2605,28 @@ def binomial( if shape is not None: shape = core.canonicalize_shape(shape) return _binomial(key, n, p, shape, dtype) + + +# Functions related to key reuse checking +random_clone_p = core.Primitive("random_clone") +dispatch.simple_impl(random_clone_p) +random_clone_p.def_abstract_eval(lambda x: x) +batching.defvectorized(random_clone_p) +mlir.register_lowering(random_clone_p, lambda _, k: [k]) + +def clone(key): + """Clone a key for reuse + + Outside the context of key reuse checking (see :mod:`jax.experimental.key_reuse`) + this function operates as an identity. 
+ + Examples: + + >>> import jax + >>> key = jax.random.key(0) + >>> data = jax.random.uniform(key) + >>> cloned_key = jax.random.clone(key) + >>> same_data = jax.random.uniform(cloned_key) + >>> assert data == same_data + """ + return random_clone_p.bind(key) diff --git a/jax/_src/scipy/cluster/vq.py b/jax/_src/scipy/cluster/vq.py index 8a071ee89f57..a82c8928644d 100644 --- a/jax/_src/scipy/cluster/vq.py +++ b/jax/_src/scipy/cluster/vq.py @@ -11,36 +11,64 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import operator -import scipy.cluster.vq -import textwrap - from jax import vmap import jax.numpy as jnp -from jax._src.numpy.util import implements, check_arraylike, promote_dtypes_inexact - - -_no_chkfinite_doc = textwrap.dedent(""" -Does not support the Scipy argument ``check_finite=True``, -because compiled JAX code cannot perform checks of array values at runtime -""") - - -@implements(scipy.cluster.vq.vq, lax_description=_no_chkfinite_doc, skip_params=('check_finite',)) -def vq(obs, code_book, check_finite=True): - check_arraylike("scipy.cluster.vq.vq", obs, code_book) - if obs.ndim != code_book.ndim: - raise ValueError("Observation and code_book should have the same rank") - obs, code_book = promote_dtypes_inexact(obs, code_book) - if obs.ndim == 1: - obs, code_book = obs[..., None], code_book[..., None] - if obs.ndim != 2: - raise ValueError("ndim different than 1 or 2 are not supported") - - # explicitly rank promotion - dist = vmap(lambda ob: jnp.linalg.norm(ob[None] - code_book, axis=-1))(obs) - code = jnp.argmin(dist, axis=-1) - dist_min = vmap(operator.getitem)(dist, code) - return code, dist_min +from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact +from jax._src.typing import Array, ArrayLike + + +def vq(obs: ArrayLike, code_book: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]: + """Assign codes from a code book to a set of observations. + + JAX implementation of :func:`scipy.cluster.vq.vq`. + + Assigns each observation vector in ``obs`` to a code from ``code_book`` + based on the nearest Euclidean distance. + + Args: + obs: array of observation vectors of shape ``(M, N)``. Each row represents + a single observation. If ``obs`` is one-dimensional, then each entry is + treated as a length-1 observation. + code_book: array of codes with shape ``(K, N)``. Each row represents a single + code vector. If ``code_book`` is one-dimensional, then each entry is treated + as a length-1 code. + check_finite: unused in JAX + + Returns: + A tuple of arrays ``(code, dist)`` + + - ``code`` is an integer array of shape ``(M,)`` containing indices ``0 <= i < K`` + of the closest entry in ``code_book`` for the given entry in ``obs``. + - ``dist`` is a float array of shape ``(M,)`` containing the euclidean + distance between each observation and the nearest code. + + Examples: + >>> obs = jnp.array([[1.1, 2.1, 3.1], + ... [5.9, 4.8, 6.2]]) + >>> code_book = jnp.array([[1., 2., 3.], + ... [2., 3., 4.], + ... [3., 4., 5.], + ... 
[4., 5., 6.]]) + >>> codes, distances = jax.scipy.cluster.vq.vq(obs, code_book) + >>> print(codes) + [0 3] + >>> print(distances) + [0.17320499 1.9209373 ] + """ + del check_finite # unused + check_arraylike("scipy.cluster.vq.vq", obs, code_book) + obs_arr, cb_arr = promote_dtypes_inexact(obs, code_book) + if obs_arr.ndim != cb_arr.ndim: + raise ValueError("Observation and code_book should have the same rank") + if obs_arr.ndim == 1: + obs_arr, cb_arr = obs_arr[..., None], cb_arr[..., None] + if obs_arr.ndim != 2: + raise ValueError("ndim different than 1 or 2 are not supported") + dist = vmap(lambda ob: jnp.linalg.norm(ob[None] - cb_arr, axis=-1))(obs_arr) + code = jnp.argmin(dist, axis=-1) + dist_min = vmap(operator.getitem)(dist, code) + return code, dist_min diff --git a/jax/_src/scipy/fft.py b/jax/_src/scipy/fft.py index 0db98ddc0f40..f1d907cf3f3b 100644 --- a/jax/_src/scipy/fft.py +++ b/jax/_src/scipy/fft.py @@ -18,11 +18,10 @@ from functools import partial import math -import scipy.fft as osp_fft from jax import lax import jax.numpy as jnp from jax._src.util import canonicalize_axis -from jax._src.numpy.util import implements, promote_dtypes_complex +from jax._src.numpy.util import promote_dtypes_complex from jax._src.typing import Array def _W4(N: int, k: Array) -> Array: @@ -42,9 +41,61 @@ def _dct_ortho_norm(out: Array, axis: int) -> Array: # Implementation based on # John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980) -@implements(osp_fft.dct) + def dct(x: Array, type: int = 2, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: + """Computes the discrete cosine transform of the input + + JAX implementation of :func:`scipy.fft.dct`. + + Args: + x: array + type: integer, default = 2. Currently only type 2 is supported. + n: integer, default = x.shape[axis]. The length of the transform. + If larger than ``x.shape[axis]``, the input will be zero-padded, if + smaller, the input will be truncated. + axis: integer, default=-1. The axis along which the dct will be performed. + norm: string. The normalization mode. Currently only ``"ortho"`` is supported. + + Returns: + array containing the discrete cosine transform of x + + See Also: + - :func:`jax.scipy.fft.dctn`: multidimensional DCT + - :func:`jax.scipy.fft.idct`: inverse DCT + - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT + + Examples: + >>> x = jax.random.normal(jax.random.key(0), (3, 3)) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dct(x)) + [[-0.58 -0.33 -1.08] + [-0.88 -1.01 -1.79] + [-1.06 -2.43 1.24]] + + When ``n`` smaller than ``x.shape[axis]`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dct(x, n=2)) + [[-0.22 -0.9 ] + [-0.57 -1.68] + [-2.52 -0.11]] + + When ``n`` smaller than ``x.shape[axis]`` and ``axis=0`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dct(x, n=2, axis=0)) + [[-2.22 1.43 -0.67] + [ 0.52 -0.26 -0.04]] + + When ``n`` larger than ``x.shape[axis]`` and ``axis=1`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... 
print(jax.scipy.fft.dct(x, n=4, axis=1)) + [[-0.58 -0.35 -0.64 -1.11] + [-0.88 -0.9 -1.46 -1.68] + [-1.06 -2.25 -1.15 1.93]] + """ if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') @@ -81,11 +132,68 @@ def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array: return out -@implements(osp_fft.dctn) def dctn(x: Array, type: int = 2, s: Sequence[int] | None=None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: + """Computes the multidimensional discrete cosine transform of the input + + JAX implementation of :func:`scipy.fft.dctn`. + + Args: + x: array + type: integer, default = 2. Currently only type 2 is supported. + s: integer or sequence of integers. Specifies the shape of the result. If not + specified, it will default to the shape of ``x`` along the specified ``axes``. + axes: integer or sequence of integers. Specifies the axes along which the + transform will be computed. + norm: string. The normalization mode. Currently only ``"ortho"`` is supported. + + Returns: + array containing the discrete cosine transform of x + + See Also: + - :func:`jax.scipy.fft.dct`: one-dimensional DCT + - :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT + - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT + + Examples: + + ``jax.scipy.fft.dctn`` computes the transform along both the axes by default + when ``axes`` argument is ``None``. + + >>> x = jax.random.normal(jax.random.key(0), (3, 3)) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dctn(x)) + [[-5.04 -7.54 -3.26] + [ 0.83 3.64 -4.03] + [ 0.12 -0.73 3.74]] + + When ``s=[2]``, dimension of the transform along ``axis 0`` will be ``2`` + and dimension along ``axis 1`` will be same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dctn(x, s=[2])) + [[-2.92 -2.68 -5.74] + [ 0.42 0.97 1. ]] + + When ``s=[2]`` and ``axes=[1]``, dimension of the transform along ``axis 1`` will + be ``2`` and dimension along ``axis 0`` will be same as that of input. + Also when ``axes=[1]``, transform will be computed only along ``axis 1``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dctn(x, s=[2], axes=[1])) + [[-0.22 -0.9 ] + [-0.57 -1.68] + [-2.52 -0.11]] + + When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dctn(x, s=[2, 4])) + [[-2.92 -2.49 -4.21 -5.57] + [ 0.42 0.79 1.16 0.8 ]] + """ if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') @@ -109,9 +217,69 @@ def dctn(x: Array, type: int = 2, return x -@implements(osp_fft.idct) def idct(x: Array, type: int = 2, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: + """Computes the inverse discrete cosine transform of the input + + JAX implementation of :func:`scipy.fft.idct`. + + Args: + x: array + type: integer, default = 2. Currently only type 2 is supported. + n: integer, default = x.shape[axis]. The length of the transform. + If larger than ``x.shape[axis]``, the input will be zero-padded, if + smaller, the input will be truncated. + axis: integer, default=-1. The axis along which the dct will be performed. + norm: string. The normalization mode. Currently only ``"ortho"`` is supported. 
+ + Returns: + array containing the inverse discrete cosine transform of x + + See Also: + - :func:`jax.scipy.fft.dct`: DCT + - :func:`jax.scipy.fft.dctn`: multidimensional DCT + - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT + + Examples: + + >>> x = jax.random.normal(jax.random.key(0), (3, 3)) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idct(x)) + [[-0.02 -0. -0.17] + [-0.02 -0.07 -0.28] + [-0.16 -0.36 0.18]] + + When ``n`` smaller than ``x.shape[axis]`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idct(x, n=2)) + [[ 0. -0.19] + [-0.03 -0.34] + [-0.38 0.04]] + + When ``n`` smaller than ``x.shape[axis]`` and ``axis=0`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idct(x, n=2, axis=0)) + [[-0.35 0.23 -0.1 ] + [ 0.17 -0.09 0.01]] + + When ``n`` larger than ``x.shape[axis]`` and ``axis=0`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idct(x, n=4, axis=0)) + [[-0.34 0.03 0.07] + [ 0. 0.18 -0.17] + [ 0.14 0.09 -0.14] + [ 0. -0.18 0.14]] + + ``jax.scipy.fft.idct`` can be used to reconstruct ``x`` from the result + of ``jax.scipy.fft.dct`` + + >>> x_dct = jax.scipy.fft.dct(x) + >>> jnp.allclose(x, jax.scipy.fft.idct(x_dct)) + Array(True, dtype=bool) + """ if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') @@ -126,7 +294,6 @@ def idct(x: Array, type: int = 2, n: int | None = None, x = _dct_ortho_norm(x, axis) x = _dct_ortho_norm(x, axis) - k = lax.expand_dims(jnp.arange(N, dtype=jnp.float32), [a for a in range(x.ndim) if a != axis]) # everything is complex from here... w4 = _W4(N,k) @@ -139,11 +306,76 @@ def idct(x: Array, type: int = 2, n: int | None = None, out = _dct_deinterleave(x.real, axis) return out -@implements(osp_fft.idctn) + def idctn(x: Array, type: int = 2, - s: Sequence[int] | None=None, - axes: Sequence[int] | None = None, - norm: str | None = None) -> Array: + s: Sequence[int] | None=None, + axes: Sequence[int] | None = None, + norm: str | None = None) -> Array: + """Computes the multidimensional inverse discrete cosine transform of the input + + JAX implementation of :func:`scipy.fft.idctn`. + + Args: + x: array + type: integer, default = 2. Currently only type 2 is supported. + s: integer or sequence of integers. Specifies the shape of the result. If not + specified, it will default to the shape of ``x`` along the specified ``axes``. + axes: integer or sequence of integers. Specifies the axes along which the + transform will be computed. + norm: string. The normalization mode. Currently only ``"ortho"`` is supported. + + Returns: + array containing the inverse discrete cosine transform of x + + See Also: + - :func:`jax.scipy.fft.dct`: one-dimensional DCT + - :func:`jax.scipy.fft.dctn`: multidimensional DCT + - :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT + + Examples: + + ``jax.scipy.fft.idctn`` computes the transform along both the axes by default + when ``axes`` argument is ``None``. + + >>> x = jax.random.normal(jax.random.key(0), (3, 3)) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idctn(x)) + [[-0.03 -0.08 -0.08] + [ 0.05 0.12 -0.09] + [-0.02 -0.04 0.08]] + + When ``s=[2]``, dimension of the transform along ``axis 0`` will be ``2`` + and dimension along ``axis 1`` will be the same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idctn(x, s=[2])) + [[-0.01 -0.03 -0.14] + [ 0. 
0.03 0.06]] + + When ``s=[2]`` and ``axes=[1]``, dimension of the transform along ``axis 1`` will + be ``2`` and dimension along ``axis 0`` will be same as that of input. + Also when ``axes=[1]``, transform will be computed only along ``axis 1``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idctn(x, s=[2], axes=[1])) + [[ 0. -0.19] + [-0.03 -0.34] + [-0.38 0.04]] + + When ``s=[2, 4]``, shape of the transform will be ``(2, 4)`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idctn(x, s=[2, 4])) + [[-0.01 -0.01 -0.05 -0.11] + [ 0. 0.01 0.03 0.04]] + + ``jax.scipy.fft.idctn`` can be used to reconstruct ``x`` from the result + of ``jax.scipy.fft.dctn`` + + >>> x_dctn = jax.scipy.fft.dctn(x) + >>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn)) + Array(True, dtype=bool) + """ if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') diff --git a/jax/_src/scipy/integrate.py b/jax/_src/scipy/integrate.py index 97cfe0ff1d0e..b61cdb163b8d 100644 --- a/jax/_src/scipy/integrate.py +++ b/jax/_src/scipy/integrate.py @@ -16,29 +16,54 @@ from functools import partial -import scipy.integrate - from jax import jit -from jax._src.numpy import util from jax._src.typing import Array, ArrayLike import jax.numpy as jnp -@util.implements(scipy.integrate.trapezoid) + @partial(jit, static_argnames=('axis',)) def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: - # TODO(phawkins): remove this annotation after fixing jnp types. - dx_array: Array - if x is None: - util.check_arraylike('trapezoid', y) - y_arr, = util.promote_dtypes_inexact(y) - dx_array = jnp.asarray(dx) - else: - util.check_arraylike('trapezoid', y, x) - y_arr, x_arr = util.promote_dtypes_inexact(y, x) - if x_arr.ndim == 1: - dx_array = jnp.diff(x_arr) - else: - dx_array = jnp.moveaxis(jnp.diff(x_arr, axis=axis), axis, -1) - y_arr = jnp.moveaxis(y_arr, axis, -1) - return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1) + r""" + Integrate along the given axis using the composite trapezoidal rule. + + JAX implementation of :func:`scipy.integrate.trapezoid` + + The trapezoidal rule approximates the integral under a curve by summing the + areas of trapezoids formed between adjacent data points. + + Args: + y: array of data to integrate. + x: optional array of sample points corresponding to the ``y`` values. If not + provided, ``x`` defaults to equally spaced with spacing given by ``dx``. + dx: The spacing between sample points when `x` is None (default: 1.0). + axis: The axis along which to integrate (default: -1) + + Returns: + The definite integral approximated by the trapezoidal rule. 
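For reference, with sample points :math:`x_i` and values :math:`y_i` the composite rule computed here is

.. math::

   \int y \, dx \approx \sum_{i=1}^{n-1} \frac{(x_i - x_{i-1})(y_i + y_{i-1})}{2}

which reduces to ``dx * (y[0]/2 + y[1] + ... + y[n-2] + y[n-1]/2)`` for equally spaced points.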
+ + See also: + :func:`jax.numpy.trapezoid`: NumPy-style API for trapezoidal integration + + Examples: + Integrate over a regular grid, with spacing 1.0: + + >>> y = jnp.array([1, 2, 3, 2, 3, 2, 1]) + >>> jax.scipy.integrate.trapezoid(y, dx=1.0) + Array(13., dtype=float32) + + Integrate over an irregular grid: + + >>> x = jnp.array([0, 2, 5, 7, 10, 15, 20]) + >>> jax.scipy.integrate.trapezoid(y, x) + Array(43., dtype=float32) + + Approximate :math:`\int_0^{2\pi} \sin^2(x)dx`, which equals :math:`\pi`: + + >>> x = jnp.linspace(0, 2 * jnp.pi, 1000) + >>> y = jnp.sin(x) ** 2 + >>> result = jax.scipy.integrate.trapezoid(y, x) + >>> jnp.allclose(result, jnp.pi) + Array(True, dtype=bool) + """ + return jnp.trapezoid(y, x, dx, axis) diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 8f6855eb0627..5458d71dedf4 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -17,7 +17,6 @@ from functools import partial import numpy as np -import scipy.linalg import textwrap from typing import overload, Any, Literal @@ -29,7 +28,7 @@ from jax._src.lax import linalg as lax_linalg from jax._src.lax import qdwh from jax._src.numpy.util import ( - check_arraylike, implements, promote_dtypes, promote_dtypes_inexact, + check_arraylike, promote_dtypes, promote_dtypes_inexact, promote_dtypes_complex) from jax._src.typing import Array, ArrayLike @@ -46,17 +45,111 @@ def _cholesky(a: ArrayLike, lower: bool) -> Array: l = lax_linalg.cholesky(a if lower else jnp.conj(a.mT), symmetrize_input=False) return l if lower else jnp.conj(l.mT) -@implements(scipy.linalg.cholesky, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) + def cholesky(a: ArrayLike, lower: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> Array: + """Compute the Cholesky decomposition of a matrix. + + JAX implementation of :func:`scipy.linalg.cholesky`. + + The Cholesky decomposition of a matrix `A` is: + + .. math:: + + A = U^HU = LL^H + + where `U` is an upper-triangular matrix and `L` is a lower-triangular matrix. + + Args: + a: input array, representing a (batched) positive-definite hermitian matrix. + Must have shape ``(..., N, N)``. + lower: if True, compute the lower Cholesky decomposition `L`. if False + (default), compute the upper Cholesky decomposition `U`. + overwrite_a: unused by JAX + check_finite: unused by JAX + + Returns: + array of shape ``(..., N, N)`` representing the cholesky decomposition + of the input. + + See Also: + - :func:`jax.numpy.linalg.cholesky`: NumPy-stype Cholesky API + - :func:`jax.lax.linalg.cholesky`: XLA-style Cholesky API + - :func:`jax.scipy.linalg.cho_factor` + - :func:`jax.scipy.linalg.cho_solve` + + Examples: + A small real Hermitian positive-definite matrix: + + >>> x = jnp.array([[2., 1.], + ... [1., 2.]]) + + Upper Cholesky factorization: + + >>> jax.scipy.linalg.cholesky(x) + Array([[1.4142135 , 0.70710677], + [0. , 1.2247449 ]], dtype=float32) + + Lower Cholesky factorization: + + >>> jax.scipy.linalg.cholesky(x, lower=True) + Array([[1.4142135 , 0. 
], + [0.70710677, 1.2247449 ]], dtype=float32) + + Reconstructing ``x`` from its factorization: + + >>> L = jax.scipy.linalg.cholesky(x, lower=True) + >>> jnp.allclose(x, L @ L.T) + Array(True, dtype=bool) + """ del overwrite_a, check_finite # Unused return _cholesky(a, lower) -@implements(scipy.linalg.cho_factor, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) + def cho_factor(a: ArrayLike, lower: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, bool]: + """Factorization for Cholesky-based linear solves + + JAX implementation of :func:`scipy.linalg.cho_factor`. This function returns + a result suitable for use with :func:`jax.scipy.linalg.cho_solve`. For direct + Cholesky decompositions, prefer :func:`jax.scipy.linalg.cholesky`. + + Args: + a: input array, representing a (batched) positive-definite hermitian matrix. + Must have shape ``(..., N, N)``. + lower: if True, compute the lower triangular Cholesky decomposition (default: False). + overwrite_a: unused by JAX + check_finite: unused by JAX + + Returns: + ``(c, lower)``: ``c`` is an array of shape ``(..., N, N)`` representing the lower or + upper cholesky decomposition of the input; ``lower`` is a boolean specifying whether + this is the lower or upper decomposition. + + See Also: + - :func:`jax.scipy.linalg.cholesky` + - :func:`jax.scipy.linalg.cho_solve` + + Examples: + A small real Hermitian positive-definite matrix: + + >>> x = jnp.array([[2., 1.], + ... [1., 2.]]) + + Compute the cholesky factorization via :func:`~jax.scipy.linalg.cho_factor`, + and use it to solve a linear equation via :func:`~jax.scipy.linalg.cho_solve`. + + >>> b = jnp.array([3., 4.]) + >>> cfac = jax.scipy.linalg.cho_factor(x) + >>> y = jax.scipy.linalg.cho_solve(cfac, b) + >>> y + Array([0.6666666, 1.6666666], dtype=float32) + + Check that the result is consistent: + + >>> jnp.allclose(x @ y, b) + Array(True, dtype=bool) + """ del overwrite_a, check_finite # Unused return (cholesky(a, lower=lower), lower) @@ -70,10 +163,49 @@ def _cho_solve(c: ArrayLike, b: ArrayLike, lower: bool) -> Array: transpose_a=lower, conjugate_a=lower) return b -@implements(scipy.linalg.cho_solve, update_doc=False, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite')) + def cho_solve(c_and_lower: tuple[ArrayLike, bool], b: ArrayLike, overwrite_b: bool = False, check_finite: bool = True) -> Array: + """Solve a linear system using a Cholesky factorization + + JAX implementation of :func:`scipy.linalg.cho_solve`. Uses the output + of :func:`jax.scipy.linalg.cho_factor`. + + Args: + c_and_lower: ``(c, lower)``, where ``c`` is an array of shape ``(..., N, N)`` + representing the lower or upper cholesky decomposition of the matrix, and + ``lower`` is a boolean specifying whether this is the lower or upper decomposition. + b: right-hand-side of linear system. Must have shape ``(..., N)`` + overwrite_a: unused by JAX + check_finite: unused by JAX + + Returns: + Array of shape ``(..., N)`` representing the solution of the linear system. + + See Also: + - :func:`jax.scipy.linalg.cholesky` + - :func:`jax.scipy.linalg.cho_factor` + + Examples: + A small real Hermitian positive-definite matrix: + + >>> x = jnp.array([[2., 1.], + ... [1., 2.]]) + + Compute the cholesky factorization via :func:`~jax.scipy.linalg.cho_factor`, + and use it to solve a linear equation via :func:`~jax.scipy.linalg.cho_solve`. 
+ + >>> b = jnp.array([3., 4.]) + >>> cfac = jax.scipy.linalg.cho_factor(x) + >>> y = jax.scipy.linalg.cho_solve(cfac, b) + >>> y + Array([0.6666666, 1.6666666], dtype=float32) + + Check that the result is consistent: + + >>> jnp.allclose(x @ y, b) + Array(True, dtype=bool) + """ del overwrite_b, check_finite # Unused c, lower = c_and_lower return _cho_solve(c, b, lower) @@ -112,17 +244,112 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') -> Array | tuple[Array, Array, Array]: ... -@implements(scipy.linalg.svd, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lapack_driver')) + def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') -> Array | tuple[Array, Array, Array]: + r"""Compute the singular value decomposition. + + JAX implementation of :func:`scipy.linalg.svd`. + + The SVD of a matrix `A` is given by + + .. math:: + + A = U\Sigma V^H + + - :math:`U` contains the left singular vectors and satisfies :math:`U^HU=I` + - :math:`V` contains the right singular vectors and satisfies :math:`V^HV=I` + - :math:`\Sigma` is a diagonal matrix of singular values. + + Args: + a: input array, of shape ``(..., N, M)`` + full_matrices: if True (default) compute the full matrices; i.e. ``u`` and ``vh`` have + shape ``(..., N, N)`` and ``(..., M, M)``. If False, then the shapes are + ``(..., N, K)`` and ``(..., K, M)`` with ``K = min(N, M)``. + compute_uv: if True (default), return the full SVD ``(u, s, vh)``. If False then return + only the singular values ``s``. + overwrite_a: unused by JAX + check_finite: unused by JAX + lapack_driver: unused by JAX + + Returns: + A tuple of arrays ``(u, s, vh)`` if ``compute_uv`` is True, otherwise the array ``s``. + + - ``u``: left singular vectors of shape ``(..., N, N)`` if ``full_matrices`` is True + or ``(..., N, K)`` otherwise. + - ``s``: singular values of shape ``(..., K)`` + - ``vh``: conjugate-transposed right singular vectors of shape ``(..., M, M)`` + if ``full_matrices`` is True or ``(..., K, M)`` otherwise. + + where ``K = min(N, M)``. + + See also: + - :func:`jax.numpy.linalg.svd`: NumPy-style SVD API + - :func:`jax.lax.linalg.svd`: XLA-style SVD API + + Examples: + Consider the SVD of a small real-valued array: + + >>> x = jnp.array([[1., 2., 3.], + ... [6., 5., 4.]]) + >>> u, s, vt = jax.scipy.linalg.svd(x, full_matrices=False) + >>> s # doctest: +SKIP + Array([9.361919 , 1.8315067], dtype=float32) + + The singular vectors are in the columns of ``u`` and ``v = vt.T``. 
These vectors are + orthonormal, which can be demonstrated by comparing the matrix product with the + identity matrix: + + >>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5) + Array(True, dtype=bool) + >>> v = vt.T + >>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5) + Array(True, dtype=bool) + + Given the SVD, ``x`` can be reconstructed via matrix multiplication: + + >>> x_reconstructed = u @ jnp.diag(s) @ vt + >>> jnp.allclose(x_reconstructed, x) + Array(True, dtype=bool) + """ del overwrite_a, check_finite, lapack_driver # unused return _svd(a, full_matrices=full_matrices, compute_uv=compute_uv) -@implements(scipy.linalg.det, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) + def det(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array: + """Compute the determinant of a matrix + + JAX implementation of :func:`scipy.linalg.det`. + + Args: + a: input array, of shape ``(..., N, N)`` + overwrite_a: unused by JAX + check_finite: unused by JAX + + Returns + Determinant of shape ``a.shape[:-2]`` + + See Also: + :func:`jax.numpy.linalg.det`: NumPy-style determinant API + + Examples: + Determinant of a small 2D array: + + >>> x = jnp.array([[1., 2.], + ... [3., 4.]]) + >>> jax.scipy.linalg.det(x) + Array(-2., dtype=float32) + + Batch-wise determinant of multiple 2D arrays: + + >>> x = jnp.array([[[1., 2.], + ... [3., 4.]], + ... [[8., 5.], + ... [7., 9.]]]) + >>> jax.scipy.linalg.det(x) + Array([-2., 37.], dtype=float32) + """ del overwrite_a, check_finite # unused return jnp.linalg.det(a) @@ -182,13 +409,70 @@ def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]: ... -@implements(scipy.linalg.eigh, - lax_description=_no_overwrite_and_chkfinite_doc, - skip_params=('overwrite_a', 'overwrite_b', 'turbo', 'check_finite')) def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, eigvals_only: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]: + """Compute eigenvalues and eigenvectors for a Hermitian matrix + + JAX implementation of :func:`scipy.linalg.eigh`. + + Args: + a: Hermitian input array of shape ``(..., N, N)`` + b: optional Hermitian input of shape ``(..., N, N)``. If specified, compute + the generalized eigenvalue problem. + lower: if True (default) access only the lower portion of the input matrix. + Otherwise access only the upper portion. + eigvals_only: If True, compute only the eigenvalues. If False (default) compute + both eigenvalues and eigenvectors. + type: if ``b`` is specified, ``type`` gives the type of generalized eigenvalue + problem to be computed. Denoting ``(λ, v)`` as an eigenvalue, eigenvector pair: + + - ``type = 1`` solves ``a @ v = λ * b @ v`` (default) + - ``type = 2`` solves ``a @ b @ v = λ * v`` + - ``type = 3`` solves ``b @ a @ v = λ * v`` + + eigvals: a ``(low, high)`` tuple specifying which eigenvalues to compute. + overwrite_a: unused by JAX. + overwrite_b: unused by JAX. + turbo: unused by JAX. + check_finite: unused by JAX. + + Returns: + A tuple of arrays ``(eigvals, eigvecs)`` if ``eigvals_only`` is False, otherwise + an array ``eigvals``. + + - ``eigvals``: array of shape ``(..., N)`` containing the eigenvalues. + - ``eigvecs``: array of shape ``(..., N, N)`` containing the eigenvectors. 
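# A minimal sketch of the ``eigvals_only`` flag described above, using an arbitrary
# symmetric 2x2 example: requesting only the eigenvalues should agree with the
# eigenvalue part of the full ``(eigvals, eigvecs)`` output.
import jax.scipy.linalg
import jax.numpy as jnp

a = jnp.array([[2., 1.],
               [1., 2.]])
w_only = jax.scipy.linalg.eigh(a, eigvals_only=True)
w, v = jax.scipy.linalg.eigh(a)
assert jnp.allclose(w_only, w)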
+ + See also: + - :func:`jax.numpy.linalg.eigh`: NumPy-style eigh API. + - :func:`jax.lax.linalg.eigh`: XLA-style eigh API. + - :func:`jax.numpy.linalg.eig`: non-hermitian eigenvalue problem. + - :func:`jax.scipy.linalg.eigh_tridiagonal`: tri-diagonal eigenvalue problem. + + Examples: + Compute the standard eigenvalue decomposition of a simple 2x2 matrix: + + >>> a = jnp.array([[2., 1.], + ... [1., 2.]]) + >>> eigvals, eigvecs = jax.scipy.linalg.eigh(a) + >>> eigvals + Array([1., 3.], dtype=float32) + >>> eigvecs + Array([[-0.70710677, 0.70710677], + [ 0.70710677, 0.70710677]], dtype=float32) + + Eigenvectors are orthonormal: + + >>> jnp.allclose(eigvecs.T @ eigvecs, jnp.eye(2), atol=1E-5) + Array(True, dtype=bool) + + Solution satisfies the eigenvalue problem: + + >>> jnp.allclose(a @ eigvecs, eigvecs @ jnp.diag(eigvals)) + Array(True, dtype=bool) + """ del overwrite_a, overwrite_b, turbo, check_finite # unused return _eigh(a, b, lower, eigvals_only, eigvals, type) @@ -198,35 +482,230 @@ def _schur(a: Array, output: str) -> tuple[Array, Array]: a = a.astype(dtypes.to_complex_dtype(a.dtype)) return lax_linalg.schur(a) -@implements(scipy.linalg.schur) def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]: + """Compute the Schur decomposition + + JAX implementation of :func:`scipy.linalg.schur`. + + The Schur form `T` of a matrix `A` satisfies: + + .. math:: + + A = Z T Z^H + + where `Z` is unitary, and `T` is upper-triangular for the complex-valued Schur + decomposition (i.e. ``output="complex"``) and is quasi-upper-triangular for the + real-valued Schur decomposition (i.e. ``output="real"``). In the quasi-triangular + case, the diagonal may include 2x2 blocks associated with complex-valued + eigenvalue pairs of `A`. + + Args: + a: input array of shape ``(..., N, N)`` + output: Specify whether to compute the ``"real"`` (default) or ``"complex"`` + Schur decomposition. + + Returns: + A tuple of arrays ``(T, Z)`` + + - ``T`` is a shape ``(..., N, N)`` array containing the upper-triangular + Schur form of the input. + - ``Z`` is a shape ``(..., N, N)`` array containing the unitary Schur + transformation matrix. + + See also: + - :func:`jax.scipy.linalg.rsf2csf`: convert real Schur form to complex Schur form. + - :func:`jax.lax.linalg.schur`: XLA-style API for Schur decomposition. + + Examples: + A Schur decomposition of a 3x3 matrix: + + >>> a = jnp.array([[1., 2., 3.], + ... [1., 4., 2.], + ... [3., 2., 1.]]) + >>> T, Z = jax.scipy.linalg.schur(a) + + The Schur form ``T`` is quasi-upper-triangular in general, but is truly + upper-triangular in this case because the input matrix is symmetric: + + >>> T # doctest: +SKIP + Array([[-2.0000005 , 0.5066295 , -0.43360388], + [ 0. , 1.5505103 , 0.74519426], + [ 0. , 0. , 6.449491 ]], dtype=float32) + + The transformation matrix ``Z`` is unitary: + + >>> jnp.allclose(Z.T @ Z, jnp.eye(3), atol=1E-5) + Array(True, dtype=bool) + + The input can be reconstructed from the outputs: + + >>> jnp.allclose(Z @ T @ Z.T, a) + Array(True, dtype=bool) + """ if output not in ('real', 'complex'): raise ValueError( f"Expected 'output' to be either 'real' or 'complex', got {output=}.") return _schur(a, output) -@implements(scipy.linalg.inv, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) + def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array: + """Return the inverse of a square matrix + + JAX implementation of :func:`scipy.linalg.inv`. 
+ + Args: + a: array of shape ``(..., N, N)`` specifying square array(s) to be inverted. + overwrite_a: unused in JAX + check_finite: unused in JAX + + Returns: + Array of shape ``(..., N, N)`` containing the inverse of the input. + + Notes: + In most cases, explicitly computing the inverse of a matrix is ill-advised. For + example, to compute ``x = inv(A) @ b``, it is more performant and numerically + precise to use a direct solve, such as :func:`jax.scipy.linalg.solve`. + + See Also: + - :func:`jax.numpy.linalg.inv`: NumPy-style API for matrix inverse + - :func:`jax.scipy.linalg.solve`: direct linear solver + + Examples: + Compute the inverse of a 3x3 matrix + + >>> a = jnp.array([[1., 2., 3.], + ... [2., 4., 2.], + ... [3., 2., 1.]]) + >>> a_inv = jax.scipy.linalg.inv(a) + >>> a_inv # doctest: +SKIP + Array([[ 0. , -0.25 , 0.5 ], + [-0.25 , 0.5 , -0.25000003], + [ 0.5 , -0.25 , 0. ]], dtype=float32) + + Check that multiplying with the inverse gives the identity: + + >>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5) + Array(True, dtype=bool) + + Multiply the inverse by a vector ``b``, to find a solution to ``a @ x = b`` + + >>> b = jnp.array([1., 4., 2.]) + >>> a_inv @ b + Array([ 0. , 1.25, -0.5 ], dtype=float32) + + Note, however, that explicitly computing the inverse in such a case can lead + to poor performance and loss of precision as the size of the problem grows. + Instead, you should use a direct solver like :func:`jax.scipy.linalg.solve`: + + >>> jax.scipy.linalg.solve(a, b) + Array([ 0. , 1.25, -0.5 ], dtype=float32) + """ del overwrite_a, check_finite # unused return jnp.linalg.inv(a) -@implements(scipy.linalg.lu_factor, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) @partial(jit, static_argnames=('overwrite_a', 'check_finite')) def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]: + """Factorization for LU-based linear solves + + JAX implementation of :func:`scipy.linalg.lu_factor`. + + This function returns a result suitable for use with :func:`jax.scipy.linalg.lu_solve`. + For direct LU decompositions, prefer :func:`jax.scipy.linalg.lu`. + + Args: + a: input array of shape ``(..., M, N)``. + overwrite_a: unused by JAX + check_finite: unused by JAX + + Returns: + A tuple ``(lu, piv)`` + + - ``lu`` is an array of shape ``(..., M, N)``, containing ``L`` in its + lower triangle and ``U`` in its upper. + - ``piv`` is an array of shape ``(..., K)`` with ``K = min(M, N)``, + which encodes the pivots. + + See Also: + - :func:`jax.scipy.linalg.lu` + - :func:`jax.scipy.linalg.lu_solve` + + Examples: + Solving a small linear system via LU factorization: + + >>> a = jnp.array([[2., 1.], + ... [1., 2.]]) + + Compute the lu factorization via :func:`~jax.scipy.linalg.lu_factor`, + and use it to solve a linear equation via :func:`~jax.scipy.linalg.lu_solve`. 
+ + >>> b = jnp.array([3., 4.]) + >>> lufac = jax.scipy.linalg.lu_factor(a) + >>> y = jax.scipy.linalg.lu_solve(lufac, b) + >>> y + Array([0.6666666, 1.6666667], dtype=float32) + + Check that the result is consistent: + + >>> jnp.allclose(a @ y, b) + Array(True, dtype=bool) + """ del overwrite_a, check_finite # unused a, = promote_dtypes_inexact(jnp.asarray(a)) lu, pivots, _ = lax_linalg.lu(a) return lu, pivots -@implements(scipy.linalg.lu_solve, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite')) @partial(jit, static_argnames=('trans', 'overwrite_b', 'check_finite')) def lu_solve(lu_and_piv: tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0, overwrite_b: bool = False, check_finite: bool = True) -> Array: + """Solve a linear system using an LU factorization + + JAX implementation of :func:`scipy.linalg.lu_solve`. Uses the output + of :func:`jax.scipy.linalg.lu_factor`. + + Args: + lu_and_piv: ``(lu, piv)``, output of :func:`~jax.scipy.linalg.lu_factor`. + ``lu`` is an array of shape ``(..., M, N)``, containing ``L`` in its lower + triangle and ``U`` in its upper. ``piv`` is an array of shape ``(..., K)``, + with ``K = min(M, N)``, which encodes the pivots. + b: right-hand-side of linear system. Must have shape ``(..., M)`` + trans: type of system to solve. Options are: + + - ``0``: :math:`A x = b` + - ``1``: :math:`A^Tx = b` + - ``2``: :math:`A^Hx = b` + + overwrite_b: unused by JAX + check_finite: unused by JAX + + Returns: + Array of shape ``(..., N)`` representing the solution of the linear system. + + See Also: + - :func:`jax.scipy.linalg.lu` + - :func:`jax.scipy.linalg.lu_factor` + + Examples: + Solving a small linear system via LU factorization: + + >>> a = jnp.array([[2., 1.], + ... [1., 2.]]) + + Compute the lu factorization via :func:`~jax.scipy.linalg.lu_factor`, + and use it to solve a linear equation via :func:`~jax.scipy.linalg.lu_solve`. + + >>> b = jnp.array([3., 4.]) + >>> lufac = jax.scipy.linalg.lu_factor(a) + >>> y = jax.scipy.linalg.lu_solve(lufac, b) + >>> y + Array([0.6666666, 1.6666667], dtype=float32) + + Check that the result is consistent: + + >>> jnp.allclose(a @ y, b) + Array(True, dtype=bool) + """ del overwrite_b, check_finite # unused lu, pivots = lu_and_piv m, _ = lu.shape[-2:] @@ -269,11 +748,75 @@ def lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... -@implements(scipy.linalg.lu, update_doc=False, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) + @partial(jit, static_argnames=('permute_l', 'overwrite_a', 'check_finite')) def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: + """Compute the LU decomposition + + JAX implementation of :func:`scipy.linalg.lu`. + + The LU decomposition of a matrix `A` is: + + .. math:: + + A = P L U + + where `P` is a permutation matrix, `L` is lower-triangular and `U` is upper-triangular. + + Args: + a: array of shape ``(..., M, N)`` to decompose. 
+ permute_l: if True, then permute ``L`` and return ``(P @ L, U)`` (default: False) + overwrite_a: not used by JAX + check_finite: not used by JAX + + Returns: + A tuple of arrays ``(P @ L, U)`` if ``permute_l`` is True, else ``(P, L, U)``: + + - ``P`` is a permutation matrix of shape ``(..., M, M)`` + - ``L`` is a lower-triangular matrix of shape ``(... M, K)`` + - ``U`` is an upper-triangular matrix of shape ``(..., K, N)`` + + with ``K = min(M, N)`` + + See also: + - :func:`jax.numpy.linalg.lu`: NumPy-style API for LU decomposition. + - :func:`jax.lax.linalg.lu`: XLA-style API for LU decomposition. + - :func:`jax.scipy.linalg.lu_solve`: LU-based linear solver. + + Examples: + An LU decomposition of a 3x3 matrix: + + >>> a = jnp.array([[1., 2., 3.], + ... [5., 4., 2.], + ... [3., 2., 1.]]) + >>> P, L, U = jax.scipy.linalg.lu(a) + + ``P`` is a permutation matrix: i.e. each row and column has a single ``1``: + + >>> P + Array([[0., 1., 0.], + [1., 0., 0.], + [0., 0., 1.]], dtype=float32) + + ``L`` and ``U`` are lower-triangular and upper-triangular matrices: + + >>> with jnp.printoptions(precision=3): + ... print(L) + ... print(U) + [[ 1. 0. 0. ] + [ 0.2 1. 0. ] + [ 0.6 -0.333 1. ]] + [[5. 4. 2. ] + [0. 1.2 2.6 ] + [0. 0. 0.667]] + + The original matrix can be reconstructed by multiplying the three together: + + >>> a_reconstructed = P @ L @ U + >>> jnp.allclose(a, a_reconstructed) + Array(True, dtype=bool) + """ del overwrite_a, check_finite # unused return _lu(a, permute_l) @@ -320,10 +863,77 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Lit def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full", pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: ... -@implements(scipy.linalg.qr, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lwork')) + def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full", pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: + """Compute the QR decomposition of an array + + JAX implementation of :func:`scipy.linalg.qr`. + + The QR decomposition of a matrix `A` is given by + + .. math:: + + A = QR + + Where `Q` is a unitary matrix (i.e. :math:`Q^HQ=I`) and `R` is an upper-triangular + matrix. + + Args: + a: array of shape (..., M, N) + mode: Computational mode. Supported values are: + + - ``"full"`` (default): return `Q` of shape ``(M, M)`` and `R` of shape ``(M, N)``. + - ``"r"``: return only `R` + - ``"economic"``: return `Q` of shape ``(M, K)`` and `R` of shape ``(K, N)``, + where K = min(M, N). + + pivoting: Not implemented in JAX. + overwrite_a: unused in JAX + lwork: unused in JAX + check_finite: unused in JAX + + Returns: + A tuple ``(Q, R)`` (if ``mode`` is not ``"r"``) otherwise an array ``R``, + where: + + - ``Q`` is an orthogonal matrix of shape ``(..., M, M)`` (if ``mode`` is ``"full"``) + or ``(..., M, K)`` (if ``mode`` is ``"economic"``). + - ``R`` is an upper-triangular matrix of shape ``(..., M, N)`` (if ``mode`` is + ``"r"`` or ``"full"``) or ``(..., K, N)`` (if ``mode`` is ``"economic"``) + + with ``K = min(M, N)``. + + See also: + - :func:`jax.numpy.linalg.qr`: NumPy-style QR decomposition API + - :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API + + Examples: + Compute the QR decomposition of a matrix: + + >>> a = jnp.array([[1., 2., 3., 4.], + ... [5., 4., 2., 1.], + ... 
[6., 3., 1., 5.]]) + >>> Q, R = jax.scipy.linalg.qr(a) + >>> Q # doctest: +SKIP + Array([[-0.12700021, -0.7581426 , -0.6396022 ], + [-0.63500065, -0.43322435, 0.63960224], + [-0.7620008 , 0.48737738, -0.42640156]], dtype=float32) + >>> R # doctest: +SKIP + Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ], + [ 0. , -1.7870499, -2.6534991, -1.028908 ], + [ 0. , 0. , -1.0660033, -4.050814 ]], dtype=float32) + + Check that ``Q`` is orthonormal: + + >>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5) + Array(True, dtype=bool) + + Reconstruct the input: + + >>> jnp.allclose(Q @ R, a) + Array(True, dtype=bool) + """ del overwrite_a, lwork, check_finite # unused return _qr(a, mode, pivoting) @@ -352,12 +962,59 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array: return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b) -@implements(scipy.linalg.solve, - lax_description=_no_overwrite_and_chkfinite_doc, - skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite')) def solve(a: ArrayLike, b: ArrayLike, lower: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, debug: bool = False, check_finite: bool = True, assume_a: str = 'gen') -> Array: + """Solve a linear system of equations + + JAX implementation of :func:`scipy.linalg.solve`. + + This solves a (batched) linear system of equations ``a @ x = b`` for ``x`` + given ``a`` and ``b``. + + Args: + a: array of shape ``(..., N, N)``. + b: array of shape ``(..., N)`` or ``(..., N, M)`` + lower: Referenced only if ``assume_a != 'gen'``. If True, only use the lower + triangle of the input, If False (default), only use the upper triangle. + assume_a: specify what properties of ``a`` can be assumed. Options are: + + - ``"gen"``: generic matrix (default) + - ``"sym"``: symmetric matrix + - ``"her"``: hermitian matrix + - ``"pos"``: positive-definite matrix + + overwrite_a: unused by JAX + overwrite_b: unused by JAX + debug: unused by JAX + check_finite: unused by JAX + + Returns: + An array of the same shape as ``b`` containing the solution to the linear system. + + See also: + - :func:`jax.scipy.linalg.lu_solve`: Solve via LU factorization. + - :func:`jax.scipy.linalg.cho_solve`: Solve via Cholesky factorization. + - :func:`jax.scipy.linalg.solve_triangular`: Solve a triangular system. + - :func:`jax.numpy.linalg.solve`: NumPy-style API for solving linear systems. + - :func:`jax.lax.custom_linear_solve`: matrix-free linear solver. + + Examples: + A simple 3x3 linear system: + + >>> A = jnp.array([[1., 2., 3.], + ... [2., 4., 2.], + ... 
[3., 2., 1.]]) + >>> b = jnp.array([14., 16., 10.]) + >>> x = jax.scipy.linalg.solve(A, b) + >>> x + Array([1., 2., 3.], dtype=float32) + + Confirming that the result solves the system: + + >>> jnp.allclose(A @ x, b) + Array(True, dtype=bool) + """ del overwrite_a, overwrite_b, debug, check_finite #unused valid_assume_a = ['gen', 'sym', 'her', 'pos'] if assume_a not in valid_assume_a: @@ -391,32 +1048,122 @@ def _solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str, else: return out -@implements(scipy.linalg.solve_triangular, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'debug', 'check_finite')) + def solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str = 0, lower: bool = False, unit_diagonal: bool = False, overwrite_b: bool = False, debug: Any = None, check_finite: bool = True) -> Array: - del overwrite_b, debug, check_finite # unused - return _solve_triangular(a, b, trans, lower, unit_diagonal) + """Solve a triangular linear system of equations -_expm_description = textwrap.dedent(""" -In addition to the original NumPy argument(s) listed below, -also supports the optional boolean argument ``upper_triangular`` -to specify whether the ``A`` matrix is upper triangular, and the optional -argument ``max_squarings`` to specify the max number of squarings allowed -in the scaling-and-squaring approximation method. Return nan if the actual -number of squarings required is more than ``max_squarings``. + JAX implementation of :func:`scipy.linalg.solve_triangular`. -The number of required squarings = max(0, ceil(log2(norm(A)) - c) -where norm() denotes the L1 norm, and + This solves a (batched) linear system of equations ``a @ x = b`` for ``x`` + given a triangular matrix ``a`` and a vector or matrix ``b``. + + Args: + a: array of shape ``(..., N, N)``. Only part of the array will be accessed, + depending on the ``lower`` and ``unit_diagonal`` arguments. + b: array of shape ``(..., N)`` or ``(..., N, M)`` + lower: If True, only use the lower triangle of the input, If False (default), + only use the upper triangle. + unit_diagonal: If True, ignore diagonal elements of ``a`` and assume they are + ``1`` (default: False). + trans: specify what properties of ``a`` can be assumed. Options are: + + - ``0`` or ``'N'``: solve :math:`Ax=b` + - ``1`` or ``'T'``: solve :math:`A^Tx=b` + - ``2`` or ``'C'``: solve :math:`A^Hx=b` + + overwrite_b: unused by JAX + debug: unused by JAX + check_finite: unused by JAX + + Returns: + An array of the same shape as ``b`` containing the solution to the linear system. + + See also: + :func:`jax.scipy.linalg.solve`: Solve a general linear system. + + Examples: + A simple 3x3 triangular linear system: + + >>> A = jnp.array([[1., 2., 3.], + ... [0., 3., 2.], + ... [0., 0., 5.]]) + >>> b = jnp.array([10., 8., 5.]) + >>> x = jax.scipy.linalg.solve_triangular(A, b) + >>> x + Array([3., 2., 1.], dtype=float32) + + Confirming that the result solves the system: + + >>> jnp.allclose(A @ x, b) + Array(True, dtype=bool) + + Computing the transposed problem: + + >>> x = jax.scipy.linalg.solve_triangular(A, b, trans='T') + >>> x + Array([10. , -4. 
, -3.4], dtype=float32) + + Confirming that the result solves the system: + + >>> jnp.allclose(A.T @ x, b) + Array(True, dtype=bool) + """ + del overwrite_b, debug, check_finite # unused + return _solve_triangular(a, b, trans, lower, unit_diagonal) -- c=2.42 for float64 or complex128, -- c=1.97 for float32 or complex64 -""") -@implements(scipy.linalg.expm, lax_description=_expm_description) @partial(jit, static_argnames=('upper_triangular', 'max_squarings')) def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array: + """Compute the matrix exponential + + JAX implementation of :func:`scipy.linalg.expm`. + + Args: + A: array of shape ``(..., N, N)`` + upper_triangular: if True, then assume that ``A`` is upper-triangular. Default=False. + max_squarings: The number of squarings in the scaling-and-squaring approximation method + (default: 16). + + Returns: + An array of shape ``(..., N, N)`` containing the matrix exponent of ``A``. + + Notes: + This uses the scaling-and-squaring approximation method, with computational complexity + controlled by the optional ``max_squarings`` argument. Theoretically, the number of + required squarings is ``max(0, ceil(log2(norm(A))) - c)`` where ``norm(A)`` is the L1 + norm and ``c=2.42`` for float64/complex128, or ``c=1.97`` for float32/complex64. + + See Also: + :func:`jax.scipy.linalg.expm_frechet` + + Examples: + + ``expm`` is the matrix exponential, and has similar properties to the more + familiar scalar exponential. For scalars ``a`` and ``b``, :math:`e^{a + b} + = e^a e^b`. However, for matrices, this property only holds when ``A`` and + ``B`` commute (``AB = BA``). In this case, ``expm(A+B) = expm(A) @ expm(B)`` + + >>> A = jnp.array([[2, 0], + ... [0, 1]]) + >>> B = jnp.array([[3, 0], + ... [0, 4]]) + >>> jnp.allclose(jax.scipy.linalg.expm(A+B), + ... jax.scipy.linalg.expm(A) @ jax.scipy.linalg.expm(B), + ... rtol=0.0001) + Array(True, dtype=bool) + + If a matrix ``X`` is invertible, then + ``expm(X @ A @ inv(X)) = X @ expm(A) @ inv(X)`` + + >>> X = jnp.array([[3, 1], + ... [2, 5]]) + >>> X_inv = jax.scipy.linalg.inv(X) + >>> jnp.allclose(jax.scipy.linalg.expm(X @ A @ X_inv), + ... X @ jax.scipy.linalg.expm(A) @ X_inv) + Array(True, dtype=bool) + """ A, = promote_dtypes_inexact(A) if A.ndim < 2 or A.shape[-1] != A.shape[-2]: @@ -554,12 +1301,6 @@ def _pade13(A: Array) -> tuple[Array, Array]: return U,V -_expm_frechet_description = textwrap.dedent(""" -Does not currently support the Scipy argument ``jax.numpy.asarray_chkfinite``, -because `jax.numpy.asarray_chkfinite` does not exist at the moment. Does not -support the ``method='blockEnlarge'`` argument. -""") - @overload def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[True] = True) -> tuple[Array, Array]: ... @@ -572,34 +1313,89 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) -> Array | tuple[Array, Array]: ... -@implements(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description) + @partial(jit, static_argnames=('method', 'compute_expm')) def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) -> Array | tuple[Array, Array]: + """Compute the Frechet derivative of the matrix exponential. 
+ + JAX implementation of :func:`scipy.linalg.expm_frechet` + + Args: + A: array of shape ``(..., N, N)`` + E: array of shape ``(..., N, N)``; specifies the direction of the derivative. + compute_expm: if True (default) then compute and return ``expm(A)``. + method: ignored by JAX + + Returns: + A tuple ``(expm_A, expm_frechet_AE)`` if ``compute_expm`` is True, else + the array ``expm_frechet_AE``. Both returned arrays have shape ``(..., N, N)``. + + See also: + :func:`jax.scipy.linalg.expm` + + Examples: + We can use this API to compute the matrix exponential of ``A``, as well as its + derivative in the direction ``E``: + + >>> key1, key2 = jax.random.split(jax.random.key(3372)) + >>> A = jax.random.normal(key1, (3, 3)) + >>> E = jax.random.normal(key2, (3, 3)) + >>> expmA, expm_frechet_AE = jax.scipy.linalg.expm_frechet(A, E) + + This can be equivalently computed using JAX's automatic differentiation methods; + here we'll compute the derivative of :func:`~jax.scipy.linalg.expm` in the + direction of ``E`` using :func:`jax.jvp`, and find the same results: + + >>> expmA2, expm_frechet_AE2 = jax.jvp(jax.scipy.linalg.expm, (A,), (E,)) + >>> jnp.allclose(expmA, expmA2) + Array(True, dtype=bool) + >>> jnp.allclose(expm_frechet_AE, expm_frechet_AE2) + Array(True, dtype=bool) + """ + del method # unused A_arr = jnp.asarray(A) E_arr = jnp.asarray(E) - if A_arr.ndim != 2 or A_arr.shape[0] != A_arr.shape[1]: - raise ValueError('expected A to be a square matrix') - if E_arr.ndim != 2 or E_arr.shape[0] != E_arr.shape[1]: - raise ValueError('expected E to be a square matrix') + if A_arr.ndim < 2 or A_arr.shape[-2] != A_arr.shape[-1]: + raise ValueError(f'expected A to be a (batched) square matrix, got A.shape={A_arr.shape}') + if E_arr.ndim < 2 or E_arr.shape[-2] != E_arr.shape[-1]: + raise ValueError(f'expected E to be a (batched) square matrix, got E.shape={E_arr.shape}') if A_arr.shape != E_arr.shape: - raise ValueError('expected A and E to be the same shape') - if method is None: - method = 'SPS' - if method == 'SPS': - bound_fun = partial(expm, upper_triangular=False, max_squarings=16) - expm_A, expm_frechet_AE = jvp(bound_fun, (A_arr,), (E_arr,)) - else: - raise ValueError('only method=\'SPS\' is supported') + raise ValueError('expected A and E to be the same shape, got ' + f'A.shape={A_arr.shape} E.shape={E_arr.shape}') + bound_fun = partial(expm, upper_triangular=False, max_squarings=16) + expm_A, expm_frechet_AE = jvp(bound_fun, (A_arr,), (E_arr,)) if compute_expm: return expm_A, expm_frechet_AE else: return expm_frechet_AE -@implements(scipy.linalg.block_diag) @jit def block_diag(*arrs: ArrayLike) -> Array: + """Create a block diagonal matrix from input arrays. + + JAX implementation of :func:`scipy.linalg.block_diag`. + + Args: + *arrs: arrays of at most two dimensions + + Returns: + 2D block-diagonal array constructed by placing the input arrays + along the diagonal. 
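# A small sketch of the layout described above, with arbitrary rectangular blocks:
# the output shape is the sum of the block shapes along each axis, and each block
# occupies its own diagonal slot (zeros elsewhere).
import jax.scipy.linalg
import jax.numpy as jnp

A = jnp.ones((1, 3))
B = 2 * jnp.ones((2, 2))
M = jax.scipy.linalg.block_diag(A, B)
assert M.shape == (1 + 2, 3 + 2)
assert jnp.allclose(M[:1, :3], A) and jnp.allclose(M[1:, 3:], B)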
+ + Examples: + >>> A = jnp.ones((1, 1)) + >>> B = jnp.ones((2, 2)) + >>> C = jnp.ones((3, 3)) + >>> jax.scipy.linalg.block_diag(A, B, C) + Array([[1., 0., 0., 0., 0., 0.], + [0., 1., 1., 0., 0., 0.], + [0., 1., 1., 0., 0., 0.], + [0., 0., 0., 1., 1., 1.], + [0., 0., 0., 1., 1., 1.], + [0., 0., 0., 1., 1., 1.]], dtype=float32) + """ if len(arrs) == 0: arrs = (jnp.zeros((1, 0)),) arrs = tuple(promote_dtypes(*arrs)) @@ -619,11 +1415,54 @@ def block_diag(*arrs: ArrayLike) -> Array: return acc -@implements(scipy.linalg.eigh_tridiagonal) @partial(jit, static_argnames=("eigvals_only", "select", "select_range")) def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False, select: str = 'a', select_range: tuple[float, float] | None = None, tol: float | None = None) -> Array: + """Solve the eigenvalue problem for a symmetric real tridiagonal matrix + + JAX implementation of :func:`scipy.linalg.eigh_tridiagonal`. + + Args: + d: real-valued array of shape ``(N,)`` specifying the diagonal elements. + e: real-valued array of shape ``(N - 1,)`` specifying the off-diagonal elements. + eigvals_only: If True, return only the eigenvalues (default: False). Computation + of eigenvectors is not yet implemented, so ``eigvals_only`` must be set to True. + select: specify which eigenvalues to calculate. Supported values are: + + - ``'a'``: all eigenvalues + - ``'i'``: eigenvalues with indices ``select_range[0] <= i <= select_range[1]`` + + JAX does not currently implement ``select = 'v'``. + select_range: range of values used when ``select='i'``. + tol: absolute tolerance to use when solving for the eigenvalues. + + Returns: + An array of eigenvalues with shape ``(N,)``. + + See also: + :func:`jax.scipy.linalg.eigh`: general Hermitian eigenvalue solver + + Examples: + >>> d = jnp.array([1., 2., 3., 4.]) + >>> e = jnp.array([1., 1., 1.]) + >>> eigvals = jax.scipy.linalg.eigh_tridiagonal(d, e, eigvals_only=True) + >>> eigvals + Array([0.2547188, 1.8227171, 3.1772828, 4.745281 ], dtype=float32) + + For comparison, we can construct the full matrix and compute the same result + using :func:`~jax.scipy.linalg.eigh`: + + >>> A = jnp.diag(d) + jnp.diag(e, 1) + jnp.diag(e, -1) + >>> A + Array([[1., 1., 0., 0.], + [1., 2., 1., 0.], + [0., 1., 3., 1.], + [0., 0., 1., 4.]], dtype=float32) + >>> eigvals_full = jax.scipy.linalg.eigh(A, eigvals_only=True) + >>> jnp.allclose(eigvals, eigvals_full) + Array(True, dtype=bool) + """ if not eigvals_only: raise NotImplementedError("Calculation of eigenvectors is not implemented") @@ -829,6 +1668,36 @@ def polar(a: ArrayLike, side: str = 'right', *, method: str = 'qdwh', eps: float ``posdef`` is either :math:`n \times n` or :math:`m \times m` depending on whether ``side`` is ``"right"`` or ``"left"``, respectively. + Examples: + + Polar decomposition of a 3x3 matrix: + + >>> a = jnp.array([[1., 2., 3.], + ... [5., 4., 2.], + ... [3., 2., 1.]]) + >>> U, P = jax.scipy.linalg.polar(a) + + U is a Unitary Matrix: + + >>> jnp.round(U.T @ U) + Array([[ 1., -0., -0.], + [-0., 1., 0.], + [-0., 0., 1.]], dtype=float32) + + P is positive-semidefinite Matrix: + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(P) + [[4.79 3.25 1.23] + [3.25 3.06 2.01] + [1.23 2.01 2.91]] + + The original matrix can be reconstructed by multiplying the U and P: + + >>> a_reconstructed = U @ P + >>> jnp.allclose(a, a_reconstructed) + Array(True, dtype=bool) + .. 
_QDWH: https://epubs.siam.org/doi/abs/10.1137/090774999 """ arr = jnp.asarray(a) @@ -901,26 +1770,103 @@ def _sqrtm(A: ArrayLike) -> Array: return jnp.matmul(jnp.matmul(Z, sqrt_T, precision=lax.Precision.HIGHEST), jnp.conj(Z.T), precision=lax.Precision.HIGHEST) -@implements(scipy.linalg.sqrtm, - lax_description=""" -This differs from ``scipy.linalg.sqrtm`` in that the return type of -``jax.scipy.linalg.sqrtm`` is always ``complex64`` for 32-bit input, -and ``complex128`` for 64-bit input. -This function implements the complex Schur method described in [A]. It does not use recursive blocking -to speed up computations as a Sylvester Equation solver is not available yet in JAX. - -[A] Björck, Å., & Hammarling, S. (1983). - "A Schur method for the square root of a matrix". Linear algebra and its applications, 52, 127-140. -""") def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array: + """Compute the matrix square root + + JAX implementation of :func:`scipy.linalg.sqrtm`. + + Args: + A: array of shape ``(N, N)`` + blocksize: Not supported in JAX; JAX always uses ``blocksize=1``. + + Returns: + An array of shape ``(N, N)`` containing the matrix square root of ``A`` + + See Also: + :func:`jax.scipy.linalg.expm` + + Examples: + >>> a = jnp.array([[1., 2., 3.], + ... [2., 4., 2.], + ... [3., 2., 1.]]) + >>> sqrt_a = jax.scipy.linalg.sqrtm(a) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(sqrt_a) + [[0.92+0.71j 0.54+0.j 0.92-0.71j] + [0.54+0.j 1.85+0.j 0.54-0.j ] + [0.92-0.71j 0.54-0.j 0.92+0.71j]] + + By definition, matrix multiplication of the matrix square root with itself should + equal the input: + + >>> jnp.allclose(a, sqrt_a @ sqrt_a) + Array(True, dtype=bool) + + Notes: + This function implements the complex Schur method described in [1]_. It does not use + recursive blocking to speed up computations as a Sylvester Equation solver is not + yet available in JAX. + + References: + .. [1] Björck, Å., & Hammarling, S. (1983). "A Schur method for the square root of a matrix". + Linear algebra and its applications, 52, 127-140. + """ if blocksize > 1: raise NotImplementedError("Blocked version is not implemented yet.") return _sqrtm(A) -@implements(scipy.linalg.rsf2csf, lax_description=_no_chkfinite_doc) + @partial(jit, static_argnames=('check_finite',)) def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]: + """Convert real Schur form to complex Schur form. + + JAX implementation of :func:`scipy.linalg.rsf2csf`. + + Args: + T: array of shape ``(..., N, N)`` containing the real Schur form of the input. + Z: array of shape ``(..., N, N)`` containing the corresponding Schur transformation + matrix. + check_finite: unused by JAX + + Returns: + A tuple of arrays ``(T, Z)`` of the same shape as the inputs, containing the + Complex Schur form and the associated Schur transformation matrix. + + See Also: + :func:`jax.scipy.linalg.schur`: Schur decomposition + + Examples: + >>> A = jnp.array([[0., 3., 3.], + ... [0., 1., 2.], + ... [2., 0., 1.]]) + >>> Tr, Zr = jax.scipy.linalg.schur(A) + >>> Tc, Zc = jax.scipy.linalg.rsf2csf(Tr, Zr) + + Both the real and complex form can be used to reconstruct the input matrix + to float32 precision: + + >>> jnp.allclose(Zr @ Tr @ Zr.T, A, atol=1E-5) + Array(True, dtype=bool) + >>> jnp.allclose(Zc @ Tc @ Zc.conj().T, A, atol=1E-5) + Array(True, dtype=bool) + + The real-valued Schur form is only quasi-upper-triangular, as we can see in this case: + + >>> with jax.numpy.printoptions(precision=2, suppress=True): + ... 
print(Tr) + [[ 3.76 -2.17 1.38] + [ 0. -0.88 -0.35] + [ 0. 2.37 -0.88]] + + By contrast, the complex form is truly upper-triangular: + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(Tc) + [[ 3.76+0.j 1.29-0.78j 2.02-0.5j ] + [ 0. +0.j -0.88+0.91j -2.02+0.j ] + [ 0. +0.j 0. +0.j -0.88-0.91j]] + """ del check_finite # unused T_arr = jnp.asarray(T) @@ -987,11 +1933,57 @@ def hessenberg(a: ArrayLike, *, calc_q: Literal[False], overwrite_a: bool = Fals def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]: ... -@implements(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc) + @partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a')) def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> Array | tuple[Array, Array]: - del overwrite_a, check_finite + """Compute the Hessenberg form of the matrix + + JAX implementation of :func:`scipy.linalg.hessenberg`. + + The Hessenberg form `H` of a matrix `A` satisfies: + + .. math:: + + A = Q H Q^H + + where `Q` is unitary and `H` is zero below the first subdiagonal. + + Args: + a : array of shape ``(..., N, N)`` + calc_q: if True, calculate the ``Q`` matrix (default: False) + overwrite_a: unused by JAX + check_finite: unused by JAX + + Returns: + A tuple of arrays ``(H, Q)`` if ``calc_q`` is True, else an array ``H`` + + - ``H`` has shape ``(..., N, N)`` and is the Hessenberg form of ``a`` + - ``Q`` has shape ``(..., N, N)`` and is the associated unitary matrix + + Examples: + Computing the Hessenberg form of a 4x4 matrix + + >>> a = jnp.array([[1., 2., 3., 4.], + ... [1., 4., 2., 3.], + ... [3., 2., 1., 4.], + ... [2., 3., 2., 2.]]) + >>> H, Q = jax.scipy.linalg.hessenberg(a, calc_q=True) + >>> with jnp.printoptions(suppress=True, precision=3): + ... print(H) + [[ 1. -5.078 1.167 1.361] + [-3.742 5.786 -3.613 -1.825] + [ 0. -2.992 2.493 -0.577] + [ 0. 0. -0.043 -1.279]] + + Notice the zeros in the subdiagonal positions. The original matrix + can be reconstructed using the ``Q`` vectors: + + >>> a_reconstructed = Q @ H @ Q.conj().T + >>> jnp.allclose(a_reconstructed, a) + Array(True, dtype=bool) + """ + del overwrite_a, check_finite # unused n = jnp.shape(a)[-1] if n == 0: if calc_q: @@ -1010,8 +2002,64 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, else: return h -@implements(scipy.linalg.toeplitz) + def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: + r"""Construct a Toeplitz matrix + + JAX implementation of :func:`scipy.linalg.toeplitz`. + + A Toeplitz matrix has equal diagonals: :math:`A_{ij} = k_{i - j}` + for :math:`0 \le i < n` and :math:`0 \le j < n`. This function + specifies the diagonals via the first column ``c`` and the first row + ``r``, such that for row `i` and column `j`: + + .. math:: + + A_{ij} = \begin{cases} + c_{i - j} & i \ge j \\ + r_{j - i} & i < j + \end{cases} + + Notice this implies that :math:`r_0` is ignored. + + Args: + c: array specifying the first column. Will be flattened + if not 1-dimensional. + r: (optional) array specifying the first row. If not specified, defaults + to ``conj(c)``. Will be flattened if not 1-dimensional. + + Returns: + toeplitz matrix of shape ``(c.size, r.size)``. 
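# A short sketch building on the ``c``/``r`` convention above: choosing ``r`` to be
# the reversed first column, rolled by one, yields a circulant matrix (first column
# ``c``, each row a cyclic shift of the row above). The values here are arbitrary.
import jax.scipy.linalg
import jax.numpy as jnp

c = jnp.array([1., 2., 3., 4.])
r = jnp.roll(c[::-1], 1)                      # [1., 4., 3., 2.]
C = jax.scipy.linalg.toeplitz(c, r)
assert jnp.allclose(C[:, 0], c)               # first column is c
assert jnp.allclose(C[1], jnp.roll(C[0], 1))  # rows are cyclic shifts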
+ + Examples: + Specifying ``c`` only: + + >>> c = jnp.array([1, 2, 3]) + >>> jax.scipy.linalg.toeplitz(c) + Array([[1, 2, 3], + [2, 1, 2], + [3, 2, 1]], dtype=int32) + + Specifying ``c`` and ``r``: + + >>> r = jnp.array([-1, -2, -3]) + >>> jax.scipy.linalg.toeplitz(c, r) # Note r[0] is ignored + Array([[ 1, -2, -3], + [ 2, 1, -2], + [ 3, 2, 1]], dtype=int32) + + If specifying only complex-valued ``c``, ``r`` defaults to ``c.conj()``, + resulting in a Hermitian matrix if ``c[0].imag == 0``: + + >>> c = jnp.array([1, 2+1j, 1+2j]) + >>> M = jax.scipy.linalg.toeplitz(c) + >>> M + Array([[1.+0.j, 2.-1.j, 1.-2.j], + [2.+1.j, 1.+0.j, 2.-1.j], + [1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64) + >>> print("M is Hermitian:", jnp.all(M == M.conj().T)) + M is Hermitian: True + """ if r is None: check_arraylike("toeplitz", c) r = jnp.conjugate(jnp.asarray(c)) @@ -1035,3 +2083,36 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: (nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'), precision=lax.Precision.HIGHEST)[0] return jnp.flip(patches, axis=0) + + +@partial(jit, static_argnames=("n",)) +def hilbert(n: int) -> Array: + r"""Create a Hilbert matrix of order n. + + JAX implementation of :func:`scipy.linalg.hilbert`. + + The Hilbert matrix is defined by: + + .. math:: + + H_{ij} = \frac{1}{i + j + 1} + + for :math:`1 \le i \le n` and :math:`1 \le j \le n`. + + Args: + n: the size of the matrix to create. + + Returns: + A Hilbert matrix of shape ``(n, n)`` + + Examples: + >>> jax.scipy.linalg.hilbert(2) + Array([[1. , 0.5 ], + [0.5 , 0.33333334]], dtype=float32) + >>> jax.scipy.linalg.hilbert(3) + Array([[1. , 0.5 , 0.33333334], + [0.5 , 0.33333334, 0.25 ], + [0.33333334, 0.25 , 0.2 ]], dtype=float32) + """ + a = lax.broadcasted_iota(jnp.float64, (n, 1), 0) + return 1/(a + a.T + 1) diff --git a/jax/_src/scipy/ndimage.py b/jax/_src/scipy/ndimage.py index 1b01af5f4670..d81008308b94 100644 --- a/jax/_src/scipy/ndimage.py +++ b/jax/_src/scipy/ndimage.py @@ -12,20 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools import itertools import operator -import textwrap -from typing import Callable - -import scipy.ndimage from jax._src import api from jax._src import util from jax import lax import jax.numpy as jnp -from jax._src.numpy.util import implements from jax._src.typing import ArrayLike, Array from jax._src.util import safe_zip as zip @@ -120,22 +115,69 @@ def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike], else: all_valid = functools.reduce(operator.and_, validities) contribution = jnp.where(all_valid, input_arr[indices], cval) - outputs.append(_nonempty_prod(weights) * contribution) + outputs.append(_nonempty_prod(weights) * contribution) # type: ignore result = _nonempty_sum(outputs) if jnp.issubdtype(input_arr.dtype, jnp.integer): result = _round_half_away_from_zero(result) return result.astype(input_arr.dtype) -@implements(scipy.ndimage.map_coordinates, lax_description=textwrap.dedent("""\ +""" Only nearest neighbor (``order=0``), linear interpolation (``order=1``) and modes ``'constant'``, ``'nearest'``, ``'wrap'`` ``'mirror'`` and ``'reflect'`` are currently supported. 
- Note that interpolation near boundaries differs from the scipy function, - because we fixed an outstanding bug (https://github.com/scipy/scipy/issues/2640); - this function interprets the ``mode`` argument as documented by SciPy, but - not as implemented by SciPy. - """)) + + """ + def map_coordinates( - input: ArrayLike, coordinates: Sequence[ArrayLike], order: int, mode: str = 'constant', cval: ArrayLike = 0.0, + input: ArrayLike, coordinates: Sequence[ArrayLike], order: int, + mode: str = 'constant', cval: ArrayLike = 0.0, ): + """ + Map the input array to new coordinates using interpolation. + + JAX implementation of :func:`scipy.ndimage.map_coordinates` + + Given an input array and a set of coordinates, this function returns the + interpolated values of the input array at those coordinates. + + Args: + input: N-dimensional input array from which values are interpolated. + coordinates: length-N sequence of arrays specifying the coordinates + at which to evaluate the interpolated values + order: The order of interpolation. JAX supports the following: + + * 0: Nearest-neighbor + * 1: Linear + + mode: Points outside the boundaries of the input are filled according to the given mode. + JAX supports one of ``('constant', 'nearest', 'mirror', 'wrap', 'reflect')``. Note the + ``'wrap'`` mode in JAX behaves as ``'grid-wrap'`` mode in SciPy, and ``'constant'`` + mode in JAX behaves as ``'grid-constant'`` mode in SciPy. This discrepancy was caused + by a former bug in those modes in SciPy (https://github.com/scipy/scipy/issues/2640), + which was first fixed in JAX by changing the behavior of the existing modes, and later + on fixed in SciPy, by adding modes with new names, rather than fixing the existing + ones, for backwards compatibility reasons. Default is 'constant'. + cval: Value used for points outside the boundaries of the input if ``mode='constant'`` + Default is 0.0. + + Returns: + The interpolated values at the specified coordinates. + + Examples: + >>> input = jnp.arange(12.0).reshape(3, 4) + >>> input + Array([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]], dtype=float32) + >>> coordinates = [jnp.array([0.5, 1.5]), + ... jnp.array([1.5, 2.5])] + >>> jax.scipy.ndimage.map_coordinates(input, coordinates, order=1) + Array([3.5, 8.5], dtype=float32) + + Note: + Interpolation near boundaries differs from the scipy function, because JAX + fixed an outstanding bug; see https://github.com/google/jax/issues/11097. + This function interprets the ``mode`` argument as documented by SciPy, but + not as implemented by SciPy. 
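# A brief sketch contrasting the two supported interpolation orders at interior
# coordinates (away from the boundary behavior discussed in the note above).
# The input values and coordinates are arbitrary.
import jax.scipy.ndimage
import jax.numpy as jnp

x = jnp.arange(4.0)                                                # [0., 1., 2., 3.]
coords = [jnp.array([1.3, 2.6])]
nearest = jax.scipy.ndimage.map_coordinates(x, coords, order=0)    # -> [1., 3.]
linear = jax.scipy.ndimage.map_coordinates(x, coords, order=1)     # -> [1.3, 2.6]
assert jnp.allclose(linear, jnp.array([1.3, 2.6]))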
+ """ return _map_coordinates(input, coordinates, order, mode, cval) diff --git a/jax/_src/scipy/optimize/_lbfgs.py b/jax/_src/scipy/optimize/_lbfgs.py index a1719d5280ec..aa82ab4fd0c8 100644 --- a/jax/_src/scipy/optimize/_lbfgs.py +++ b/jax/_src/scipy/optimize/_lbfgs.py @@ -15,8 +15,9 @@ from __future__ import annotations -from typing import Callable, NamedTuple +from collections.abc import Callable from functools import partial +from typing import NamedTuple import jax import jax.numpy as jnp @@ -171,9 +172,9 @@ def body_fun(state: LBFGSResults): # replacements for next iteration status = jnp.array(0) status = jnp.where(state.f_k - f_kp1 < ftol, 4, status) - status = jnp.where(state.ngev >= maxgrad, 3, status) # type: ignore - status = jnp.where(state.nfev >= maxfun, 2, status) # type: ignore - status = jnp.where(state.k >= maxiter, 1, status) # type: ignore + status = jnp.where(state.ngev >= maxgrad, 3, status) + status = jnp.where(state.nfev >= maxfun, 2, status) + status = jnp.where(state.k >= maxiter, 1, status) status = jnp.where(ls_results.failed, 5, status) converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol diff --git a/jax/_src/scipy/optimize/bfgs.py b/jax/_src/scipy/optimize/bfgs.py index b6fd9f9dda17..657b7610e6e1 100644 --- a/jax/_src/scipy/optimize/bfgs.py +++ b/jax/_src/scipy/optimize/bfgs.py @@ -15,8 +15,9 @@ from __future__ import annotations +from collections.abc import Callable from functools import partial -from typing import Callable, NamedTuple +from typing import NamedTuple import jax import jax.numpy as jnp diff --git a/jax/_src/scipy/optimize/line_search.py b/jax/_src/scipy/optimize/line_search.py index 078d23d97a96..189009693cdd 100644 --- a/jax/_src/scipy/optimize/line_search.py +++ b/jax/_src/scipy/optimize/line_search.py @@ -118,7 +118,7 @@ def body(state): # This will cause the line search to stop, and since the Wolfe conditions # are not satisfied the minimization should stop too. - threshold = jnp.where((jnp.finfo(dalpha).bits < 64), 1e-5, 1e-10) + threshold = jnp.where((jnp.finfo(dalpha.dtype).bits < 64), 1e-5, 1e-10) state = state._replace(failed=state.failed | (dalpha <= threshold)) # Cubmin is sometimes nan, though in this case the bounds check will fail. 
diff --git a/jax/_src/scipy/optimize/minimize.py b/jax/_src/scipy/optimize/minimize.py index 830f1228424a..4fc006be6df0 100644 --- a/jax/_src/scipy/optimize/minimize.py +++ b/jax/_src/scipy/optimize/minimize.py @@ -14,8 +14,8 @@ from __future__ import annotations -from collections.abc import Mapping -from typing import Any, Callable +from collections.abc import Callable, Mapping +from typing import Any import jax from jax._src.scipy.optimize.bfgs import minimize_bfgs diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 284509be29eb..1282650ae1e5 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -14,16 +14,13 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial import math import operator -from typing import Callable import warnings import numpy as np -import scipy.signal as osp_signal -from scipy.fft import next_fast_len as osp_fft_next_fast_len import jax import jax.numpy.fft @@ -34,15 +31,66 @@ from jax._src.lax.lax import PrecisionLike from jax._src.numpy import linalg from jax._src.numpy.util import ( - check_arraylike, implements, promote_dtypes_inexact, promote_dtypes_complex) + check_arraylike, promote_dtypes_inexact, promote_dtypes_complex) from jax._src.third_party.scipy import signal_helper from jax._src.typing import Array, ArrayLike from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert -@implements(osp_signal.fftconvolve) def fftconvolve(in1: ArrayLike, in2: ArrayLike, mode: str = "full", axes: Sequence[int] | None = None) -> Array: + """ + Convolve two N-dimensional arrays using Fast Fourier Transform (FFT). + + JAX implementation of :func:`scipy.signal.fftconvolve`. + + Args: + in1: left-hand input to the convolution. + in2: right-hand input to the convolution. Must have ``in1.ndim == in2.ndim``. + mode: controls the size of the output. Available operations are: + + * ``"full"``: (default) output the full convolution of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``in1``. + * ``"valid"``: return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + axes: optional sequence of axes along which to apply the convolution. + + Returns: + Array containing the convolved result. + + See Also: + - :func:`jax.numpy.convolve`: 1D convolution + - :func:`jax.scipy.signal.convolve`: direct convolution + + Examples: + A few 1D convolution examples. Because FFT-based convolution is approximate, + We use :func:`jax.numpy.printoptions` below to adjust the printing precision: + + >>> x = jnp.array([1, 2, 3, 2, 1]) + >>> y = jnp.array([1, 1, 1]) + + Full convolution uses implicit zero-padding at the edges: + + >>> with jax.numpy.printoptions(precision=3): + ... print(jax.scipy.signal.fftconvolve(x, y, mode='full')) + [1. 3. 6. 7. 6. 3. 1.] + + Specifying ``mode = 'same'`` returns a centered convolution the same size + as the first input: + + >>> with jax.numpy.printoptions(precision=3): + ... print(jax.scipy.signal.fftconvolve(x, y, mode='same')) + [3. 6. 7. 6. 3.] + + Specifying ``mode = 'valid'`` returns only the portion where the two arrays + fully overlap: + + >>> with jax.numpy.printoptions(precision=3): + ... print(jax.scipy.signal.fftconvolve(x, y, mode='valid')) + [6. 7. 6.] 
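# A quick sketch of how this relates to direct convolution: up to floating-point
# error, the FFT-based result matches ``convolve`` with ``method='direct'``.
import jax.scipy.signal
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 2., 1.])
y = jnp.array([1., 1., 1.])
via_fft = jax.scipy.signal.fftconvolve(x, y, mode='full')
direct = jax.scipy.signal.convolve(x, y, mode='full', method='direct')
assert jnp.allclose(via_fft, direct, atol=1e-5)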
+ """ check_arraylike('fftconvolve', in1, in2) in1, in2 = promote_dtypes_inexact(in1, in2) if in1.ndim != in2.ndim: @@ -63,7 +111,9 @@ def fftconvolve(in1: ArrayLike, in2: ArrayLike, mode: str = "full", def _fftconvolve_unbatched(in1: Array, in2: Array, mode: str) -> Array: full_shape = tuple(s1 + s2 - 1 for s1, s2 in zip(in1.shape, in2.shape)) - fft_shape = tuple(osp_fft_next_fast_len(s) for s in full_shape) + + # TODO(jakevdp): potentially use next_fast_len to evaluate with a more efficient shape. + fft_shape = full_shape # tuple(next_fast_len(s) for s in full_shape) if mode == 'valid': no_swap = all(s1 >= s2 for s1, s2 in zip(in1.shape, in2.shape)) @@ -133,9 +183,63 @@ def _convolve_nd(in1: Array, in2: Array, mode: str, *, precision: PrecisionLike) return result[0, 0] -@implements(osp_signal.convolve) def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto', precision: PrecisionLike = None) -> Array: + """Convolution of two N-dimensional arrays. + + JAX implementation of :func:`scipy.signal.convolve`. + + Args: + in1: left-hand input to the convolution. + in2: right-hand input to the convolution. Must have ``in1.ndim == in2.ndim``. + mode: controls the size of the output. Available operations are: + + * ``"full"``: (default) output the full convolution of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``in1``. + * ``"valid"``: return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + method: controls the computation method. Options are + + * ``"auto"``: (default) always uses the ``"direct"`` method. + * ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`. + * ``"fft"``: compute the result via a fast Fourier transform. + + precision: Specify the precision of the computation. Refer to + :class:`jax.lax.Precision` for a description of available values. + + Returns: + Array containing the convolved result. + + See Also: + - :func:`jax.numpy.convolve`: 1D convolution + - :func:`jax.scipy.signal.convolve2d`: 2D convolution + - :func:`jax.scipy.signal.correlate`: ND correlation + + Examples: + A few 1D convolution examples: + + >>> x = jnp.array([1, 2, 3, 2, 1]) + >>> y = jnp.array([1, 1, 1]) + + Full convolution uses implicit zero-padding at the edges: + + >>> jax.scipy.signal.convolve(x, y, mode='full') + Array([1., 3., 6., 7., 6., 3., 1.], dtype=float32) + + Specifying ``mode = 'same'`` returns a centered convolution the same size + as the first input: + + >>> jax.scipy.signal.convolve(x, y, mode='same') + Array([3., 6., 7., 6., 3.], dtype=float32) + + Specifying ``mode = 'valid'`` returns only the portion where the two arrays + fully overlap: + + >>> jax.scipy.signal.convolve(x, y, mode='valid') + Array([6., 7., 6.], dtype=float32) + """ if method == 'fft': return fftconvolve(in1, in2, mode=mode) elif method in ['direct', 'auto']: @@ -144,9 +248,73 @@ def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto', raise ValueError(f"Got {method=}; expected 'auto', 'fft', or 'direct'.") -@implements(osp_signal.convolve2d) def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill', fillvalue: float = 0, precision: PrecisionLike = None) -> Array: + """Convolution of two 2-dimensional arrays. + + JAX implementation of :func:`scipy.signal.convolve2d`. + + Args: + in1: left-hand input to the convolution. Must have ``in1.ndim == 2``. + in2: right-hand input to the convolution. Must have ``in2.ndim == 2``. 
+ mode: controls the size of the output. Available operations are: + + * ``"full"``: (default) output the full convolution of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``in1``. + * ``"valid"``: return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + boundary: only ``"fill"`` is supported. + fillvalue: only ``0`` is supported. + method: controls the computation method. Options are + + * ``"auto"``: (default) always uses the ``"direct"`` method. + * ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`. + * ``"fft"``: compute the result via a fast Fourier transform. + + precision: Specify the precision of the computation. Refer to + :class:`jax.lax.Precision` for a description of available values. + + Returns: + Array containing the convolved result. + + See Also: + - :func:`jax.numpy.convolve`: 1D convolution + - :func:`jax.scipy.signal.convolve`: ND convolution + - :func:`jax.scipy.signal.correlate`: ND correlation + + Examples: + A few 2D convolution examples: + + >>> x = jnp.array([[1, 2], + ... [3, 4]]) + >>> y = jnp.array([[2, 1, 1], + ... [4, 3, 4], + ... [1, 3, 2]]) + + Full 2D convolution uses implicit zero-padding at the edges: + + >>> jax.scipy.signal.convolve2d(x, y, mode='full') + Array([[ 2., 5., 3., 2.], + [10., 22., 17., 12.], + [13., 30., 32., 20.], + [ 3., 13., 18., 8.]], dtype=float32) + + Specifying ``mode = 'same'`` returns a centered 2D convolution of the same size + as the first input: + + >>> jax.scipy.signal.convolve2d(x, y, mode='same') + Array([[22., 17.], + [30., 32.]], dtype=float32) + + Specifying ``mode = 'valid'`` returns only the portion of 2D convolution + where the two arrays fully overlap: + + >>> jax.scipy.signal.convolve2d(x, y, mode='valid') + Array([[22., 17.], + [30., 32.]], dtype=float32) + """ if boundary != 'fill' or fillvalue != 0: raise NotImplementedError("convolve2d() only supports boundary='fill', fillvalue=0") if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2: @@ -154,15 +322,134 @@ def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill return _convolve_nd(in1, in2, mode, precision=precision) -@implements(osp_signal.correlate) def correlate(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto', precision: PrecisionLike = None) -> Array: + """Cross-correlation of two N-dimensional arrays. + + JAX implementation of :func:`scipy.signal.correlate`. + + Args: + in1: left-hand input to the cross-correlation. + in2: right-hand input to the cross-correlation. Must have ``in1.ndim == in2.ndim``. + mode: controls the size of the output. Available operations are: + + * ``"full"``: (default) output the full cross-correlation of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``in1``. + * ``"valid"``: return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + method: controls the computation method. Options are + + * ``"auto"``: (default) always uses the ``"direct"`` method. + * ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`. + * ``"fft"``: compute the result via a fast Fourier transform. + + precision: Specify the precision of the computation. Refer to + :class:`jax.lax.Precision` for a description of available values. + + Returns: + Array containing the cross-correlation result. 
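# A minimal sketch of the relationship between correlation and convolution used by
# this module: for real inputs, ``correlate(x, y)`` matches ``convolve`` applied to
# the reversed kernel (conjugation is a no-op for real data).
import jax.scipy.signal
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 2., 1.])
y = jnp.array([1., 3., 2.])
corr = jax.scipy.signal.correlate(x, y, mode='full')
conv = jax.scipy.signal.convolve(x, y[::-1], mode='full')
assert jnp.allclose(corr, conv)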
+ + See Also: + - :func:`jax.numpy.correlate`: 1D cross-correlation + - :func:`jax.scipy.signal.correlate2d`: 2D cross-correlation + - :func:`jax.scipy.signal.convolve`: ND convolution + + Examples: + A few 1D correlation examples: + + >>> x = jnp.array([1, 2, 3, 2, 1]) + >>> y = jnp.array([1, 3, 2]) + + Full 1D correlation uses implicit zero-padding at the edges: + + >>> jax.scipy.signal.correlate(x, y, mode='full') + Array([ 2., 7., 13., 15., 11., 5., 1.], dtype=float32) + + Specifying ``mode = 'same'`` returns a centered 1D correlation of the same + size as the first input: + + >>> jax.scipy.signal.correlate(x, y, mode='same') + Array([ 7., 13., 15., 11., 5.], dtype=float32) + + Specifying ``mode = 'valid'`` returns only the portion of 1D correlation + where the two arrays fully overlap: + + >>> jax.scipy.signal.correlate(x, y, mode='valid') + Array([13., 15., 11.], dtype=float32) + """ return convolve(in1, jnp.flip(in2.conj()), mode, precision=precision, method=method) -@implements(osp_signal.correlate2d) def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill', fillvalue: float = 0, precision: PrecisionLike = None) -> Array: + """Cross-correlation of two 2-dimensional arrays. + + JAX implementation of :func:`scipy.signal.correlate2d`. + + Args: + in1: left-hand input to the cross-correlation. Must have ``in1.ndim == 2``. + in2: right-hand input to the cross-correlation. Must have ``in2.ndim == 2``. + mode: controls the size of the output. Available operations are: + + * ``"full"``: (default) output the full cross-correlation of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``in1``. + * ``"valid"``: return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + boundary: only ``"fill"`` is supported. + fillvalue: only ``0`` is supported. + method: controls the computation method. Options are + + * ``"auto"``: (default) always uses the ``"direct"`` method. + * ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`. + * ``"fft"``: compute the result via a fast Fourier transform. + + precision: Specify the precision of the computation. Refer to + :class:`jax.lax.Precision` for a description of available values. + + Returns: + Array containing the cross-correlation result. + + See Also: + - :func:`jax.numpy.correlate`: 1D cross-correlation + - :func:`jax.scipy.signal.correlate`: ND cross-correlation + - :func:`jax.scipy.signal.convolve`: ND convolution + + Examples: + A few 2D correlation examples: + + >>> x = jnp.array([[2, 1, 3], + ... [1, 3, 1], + ... [4, 1, 2]]) + >>> y = jnp.array([[1, 3], + ... 
[4, 2]]) + + Full 2D correlation uses implicit zero-padding at the edges: + + >>> jax.scipy.signal.correlate2d(x, y, mode='full') + Array([[ 4., 10., 10., 12.], + [ 8., 15., 24., 7.], + [11., 28., 14., 9.], + [12., 7., 7., 2.]], dtype=float32) + + Specifying ``mode = 'same'`` returns a centered 2D correlation of the same + size as the first input: + + >>> jax.scipy.signal.correlate2d(x, y, mode='same') + Array([[15., 24., 7.], + [28., 14., 9.], + [ 7., 7., 2.]], dtype=float32) + + Specifying ``mode = 'valid'`` returns only the portion of 2D correlation + where the two arrays fully overlap: + + >>> jax.scipy.signal.correlate2d(x, y, mode='valid') + Array([[15., 24.], + [28., 14.]], dtype=float32) + """ if boundary != 'fill' or fillvalue != 0: raise NotImplementedError("correlate2d() only supports boundary='fill', fillvalue=0") if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2: @@ -191,9 +478,51 @@ def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fil return result -@implements(osp_signal.detrend) def detrend(data: ArrayLike, axis: int = -1, type: str = 'linear', bp: int = 0, overwrite_data: None = None) -> Array: + """ + Remove linear or piecewise linear trends from data. + + JAX implementation of :func:`scipy.signal.detrend`. + + Args: + data: The input array containing the data to detrend. + axis: The axis along which to detrend. Default is -1 (the last axis). + type: The type of detrending. Can be: + + * ``'linear'``: Fit a single linear trend for the entire data. + * ``'constant'``: Remove the mean value of the data. + + bp: A sequence of breakpoints. If given, piecewise linear trends + are fit between these breakpoints. + overwrite_data: This argument is not supported by JAX's implementation. + + Returns: + The detrended data array. + + Examples: + A simple detrend operation in one dimension: + + >>> data = jnp.array([1., 4., 8., 8., 9.]) + + Removing a linear trend from the data: + + >>> detrended = jax.scipy.signal.detrend(data) + >>> with jnp.printoptions(precision=3, suppress=True): # suppress float error + ... print("Detrended:", detrended) + ... print("Underlying trend:", data - detrended) + Detrended: [-1. -0. 2. -0. -1.] + Underlying trend: [ 2. 4. 6. 8. 10.] + + Removing a constant trend from the data: + + >>> detrended = jax.scipy.signal.detrend(data, type='constant') + >>> with jnp.printoptions(precision=3): # suppress float error + ... print("Detrended:", detrended) + ... print("Underlying trend:", data - detrended) + Detrended: [-5. -2. 2. 2. 3.] + Underlying trend: [6. 6. 6. 6. 6.] 
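+ 
+     As an additional illustrative sketch (the printed values are indicative
+     and skipped by doctest), detrending can also be applied along a chosen
+     axis of a multi-dimensional input; here the mean of each row is removed:
+ 
+     >>> data2d = jnp.array([[1., 3., 5.],
+     ...                     [2., 4., 6.]])
+     >>> jax.scipy.signal.detrend(data2d, axis=1, type='constant')  # doctest: +SKIP
+     Array([[-2.,  0.,  2.],
+            [-2.,  0.,  2.]], dtype=float32)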
+ """ if overwrite_data is not None: raise NotImplementedError("overwrite_data argument not implemented.") if type not in ['constant', 'linear']: @@ -358,7 +687,7 @@ def pad(x, n, axis=-1): if nperseg is not None: # if specified by user nperseg_int = jax.core.concrete_or_error(int, nperseg, "nperseg of windowed-FFT") - if nperseg_int < 1: # type: ignore[operator] + if nperseg_int < 1: raise ValueError('nperseg must be a positive integer') # parse window; if array like, then set nperseg = win.shape win, nperseg_int = signal_helper._triage_segments( @@ -366,7 +695,7 @@ def pad(x, n, axis=-1): input_length=x.shape[axis], dtype=x.dtype) if noverlap is None: - noverlap_int = nperseg_int // 2 # type: ignore[operator] + noverlap_int = nperseg_int // 2 else: noverlap_int = jax.core.concrete_or_error(int, noverlap, "noverlap of windowed-FFT") @@ -383,7 +712,7 @@ def pad(x, n, axis=-1): return jnp.zeros(x.shape, freq_dtype), jnp.zeros(x.shape, freq_dtype), jnp.zeros(x.shape, result_dtype) else: if x.size == 0 or y_arr.size == 0: - shape = tuple_insert(outershape, min([x.shape[axis], y_arr.shape[axis]]), axis) + shape = tuple_insert(outershape, min(x.shape[axis], y_arr.shape[axis]), axis) return jnp.zeros(shape, freq_dtype), jnp.zeros(shape, freq_dtype), jnp.zeros(shape, result_dtype) # Move time-axis to the end @@ -499,11 +828,44 @@ def detrend_func(d): return freqs, time, result -@implements(osp_signal.stft) def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256, noverlap: int | None = None, nfft: int | None = None, detrend: bool = False, return_onesided: bool = True, boundary: str | None = 'zeros', padded: bool = True, axis: int = -1) -> tuple[Array, Array, Array]: + """ + Compute the short-time Fourier transform (STFT). + + JAX implementation of :func:`scipy.signal.stft`. + + Args: + x: Array representing a time series of input values. + fs: Sampling frequency of the time series (default: 1.0). + window: Data tapering window to apply to each segment. Can be a window function name, + a tuple specifying a window length and function, or an array (default: ``'hann'``). + nperseg: Length of each segment (default: 256). + noverlap: Number of points to overlap between segments (default: ``nperseg // 2``). + nfft: Length of the FFT used, if a zero-padded FFT is desired. If ``None`` (default), + the FFT length is ``nperseg``. + detrend: Specifies how to detrend each segment. Can be ``False`` (default: no detrending), + ``'constant'`` (remove mean), ``'linear'`` (remove linear trend), or a callable + accepting a segment and returning a detrended segment. + return_onesided: If True (default), return a one-sided spectrum for real inputs. + If False, return a two-sided spectrum. + boundary: Specifies whether the input signal is extended at both ends, and how. + Options are ``None`` (no extension), ``'zeros'`` (default), ``'even'``, ``'odd'``, + or ``'constant'``. + padded: Specifies whether the input signal is zero-padded at the end to make its + length a multiple of `nperseg`. If True (default), the padded signal length is + the next multiple of ``nperseg``. + axis: Axis along which the STFT is computed; the default is over the last axis (-1). + + Returns: + A length-3 tuple of arrays ``(f, t, Zxx)``. ``f`` is the Array of sample frequencies. + ``t`` is the Array of segment times, and ``Zxx`` is the STFT of ``x``. + + See Also: + :func:`jax.scipy.signal.istft`: inverse short-time Fourier transform. 
+ """ return _spectral_helper(x, None, fs, window, nperseg, noverlap, nfft, detrend, return_onesided, scaling='spectrum', axis=axis, @@ -511,19 +873,56 @@ def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256 padded=padded) -_csd_description = """ -The original SciPy function exhibits slightly different behavior between -``csd(x, x)``` and ```csd(x, x.copy())```. The LAX-backend version is designed -to follow the latter behavior. For using the former behavior, call this -function as `csd(x, None)`.""" - - -@implements(osp_signal.csd, lax_description=_csd_description) def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int | None = None, noverlap: int | None = None, nfft: int | None = None, detrend: str = 'constant', return_onesided: bool = True, scaling: str = 'density', axis: int = -1, average: str = 'mean') -> tuple[Array, Array]: + """ + Estimate cross power spectral density (CSD) using Welch's method. + + This is a JAX implementation of :func:`scipy.signal.csd`. It is similar to + :func:`jax.scipy.signal.welch`, but it operates on two input signals and + estimates their cross-spectral density instead of the power spectral density + (PSD). + + Args: + x: Array representing a time series of input values. + y: Array representing the second time series of input values, the same length as ``x`` + along the specified ``axis``. If not specified, then assume ``y = x`` and compute + the PSD ``Pxx`` of ``x`` via Welch's method. + fs: Sampling frequency of the inputs (default: 1.0). + window: Data tapering window to apply to each segment. Can be a window function name, + a tuple specifying a window length and function, or an array (default: ``'hann'``). + nperseg: Length of each segment (default: 256). + noverlap: Number of points to overlap between segments (default: ``nperseg // 2``). + nfft: Length of the FFT used, if a zero-padded FFT is desired. If ``None`` (default), + the FFT length is ``nperseg``. + detrend: Specifies how to detrend each segment. Can be ``False`` (default: no detrending), + ``'constant'`` (remove mean), ``'linear'`` (remove linear trend), or a callable + accepting a segment and returning a detrended segment. + return_onesided: If True (default), return a one-sided spectrum for real inputs. + If False, return a two-sided spectrum. + scaling: Selects between computing the power spectral density (``'density'``, default) + or the power spectrum (``'spectrum'``) + axis: Axis along which the CSD is computed (default: -1). + average: The type of averaging to use on the periodograms; one of ``'mean'`` (default) + or ``'median'``. + + Returns: + A length-2 tuple of arrays ``(f, Pxy)``. ``f`` is the array of sample frequencies, + and ``Pxy`` is the cross spectral density of `x` and `y` + + Notes: + The original SciPy function exhibits slightly different behavior between + ``csd(x, x)`` and ``csd(x, x.copy())``. The LAX-backend version is designed + to follow the latter behavior. To replicate the former, call this function + function as ``csd(x, None)``. + + See Also: + - :func:`jax.scipy.signal.welch`: Power spectral density. + - :func:`jax.scipy.signal.stft`: Short-time Fourier transform. 
+ """ freqs, _, Pxy = _spectral_helper(x, y, fs, window, nperseg, noverlap, nfft, detrend, return_onesided, scaling, axis, mode='psd') @@ -551,12 +950,46 @@ def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann' return freqs, Pxy -@implements(osp_signal.welch) def welch(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int | None = None, noverlap: int | None = None, nfft: int | None = None, detrend: str = 'constant', return_onesided: bool = True, scaling: str = 'density', axis: int = -1, average: str = 'mean') -> tuple[Array, Array]: + """ + Estimate power spectral density (PSD) using Welch's method. + + This is a JAX implementation of :func:`scipy.signal.welch`. It divides the + input signal into overlapping segments, computes the modified periodogram for + each segment, and averages the results to obtain a smoother estimate of the PSD. + + Args: + x: Array representing a time series of input values. + fs: Sampling frequency of the inputs (default: 1.0). + window: Data tapering window to apply to each segment. Can be a window function name, + a tuple specifying a window length and function, or an array (default: ``'hann'``). + nperseg: Length of each segment (default: 256). + noverlap: Number of points to overlap between segments (default: ``nperseg // 2``). + nfft: Length of the FFT used, if a zero-padded FFT is desired. If ``None`` (default), + the FFT length is ``nperseg``. + detrend: Specifies how to detrend each segment. Can be ``False`` (default: no detrending), + ``'constant'`` (remove mean), ``'linear'`` (remove linear trend), or a callable + accepting a segment and returning a detrended segment. + return_onesided: If True (default), return a one-sided spectrum for real inputs. + If False, return a two-sided spectrum. + scaling: Selects between computing the power spectral density (``'density'``, default) + or the power spectrum (``'spectrum'``) + axis: Axis along which the PSD is computed (default: -1). + average: The type of averaging to use on the periodograms; one of ``'mean'`` (default) + or ``'median'``. + + Returns: + A length-2 tuple of arrays ``(f, Pxx)``. ``f`` is the array of sample frequencies, + and ``Pxx`` is the power spectral density of ``x``. + + See Also: + - :func:`jax.scipy.signal.csd`: Cross power spectral density. + - :func:`jax.scipy.signal.stft`: Short-time Fourier transform. + """ freqs, Pxx = csd(x, None, fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, @@ -613,12 +1046,54 @@ def _overlap_and_add(x: Array, step_size: int) -> Array: return x.reshape(tuple(batch_shape) + (-1,)) -@implements(osp_signal.istft) def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int | None = None, noverlap: int | None = None, nfft: int | None = None, input_onesided: bool = True, boundary: bool = True, time_axis: int = -1, freq_axis: int = -2) -> tuple[Array, Array]: + """ + Perform the inverse short-time Fourier transform (ISTFT). + + JAX implementation of :func:`scipy.signal.istft`; computes the inverse of + :func:`jax.scipy.signal.stft`. + + Args: + Zxx: STFT of the signal to be reconstructed. + fs: Sampling frequency of the time series (default: 1.0) + window: Data tapering window to apply to each segment. Can be a window function name, + a tuple specifying a window length and function, or an array (default: ``'hann'``). + nperseg: Number of data points per segment in the STFT. 
If ``None`` (default), the + value is determined from the size of ``Zxx``. + noverlap: Number of points to overlap between segments (default: ``nperseg // 2``). + nfft: Number of FFT points used in the STFT. If ``None`` (default), the + value is determined from the size of ``Zxx``. + input_onesided: If Tru` (default), interpret the input as a one-sided STFT + (positive frequencies only). If False, interpret the input as a two-sided STFT. + boundary: If True (default), it is assumed that the input signal was extended at + its boundaries by ``stft``. If `False`, the input signal is assumed to have been truncated at the boundaries by `stft`. + time_axis: Axis in `Zxx` corresponding to time segments (default: -1). + freq_axis: Axis in `Zxx` corresponding to frequency bins (default: -2). + + Returns: + A length-2 tuple of arrays ``(t, x)``. ``t`` is the Array of signal times, and ``x`` + is the reconstructed time series. + + See Also: + :func:`jax.scipy.signal.stft`: short-time Fourier transform. + + Examples: + Demonstrate that this gives the inverse of :func:`~jax.scipy.signal.stft`: + + >>> x = jnp.array([1., 2., 3., 2., 1., 0., 1., 2.]) + >>> f, t, Zxx = jax.scipy.signal.stft(x, nperseg=4) + >>> print(Zxx) # doctest: +SKIP + [[ 1. +0.j 2.5+0.j 1. +0.j 1. +0.j 0.5+0.j ] + [-0.5+0.5j -1.5+0.j -0.5-0.5j -0.5+0.5j 0. -0.5j] + [ 0. +0.j 0.5+0.j 0. +0.j 0. +0.j -0.5+0.j ]] + >>> t, x_reconstructed = jax.scipy.signal.istft(Zxx) + >>> print(x_reconstructed) + [1. 2. 3. 2. 1. 0. 1. 2.] + """ # Input validation check_arraylike("istft", Zxx) if Zxx.ndim < 2: @@ -668,9 +1143,18 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', xsubs = ifunc(Zxx, axis=-2, n=nfft)[..., :nperseg_int, :] # Get window as array - if isinstance(window, (str, tuple)): - win = osp_signal.get_window(window, nperseg_int) - win = jnp.asarray(win, dtype=xsubs.dtype) + if window == 'hann': + # Implement the default case without scipy + win = jnp.array([1.0]) if nperseg_int == 1 else jnp.sin(jnp.linspace(0, jnp.pi, nperseg_int, endpoint=False)) ** 2 + win = win.astype(xsubs.dtype) + elif isinstance(window, (str, tuple)): + # TODO(jakevdp): implement get_window() in JAX to remove optional scipy dependency + try: + from scipy.signal import get_window + except ImportError as err: + raise ImportError(f"scipy must be available to use {window=}") from err + win = get_window(window, nperseg_int) + win = jnp.array(win, dtype=xsubs.dtype) else: win = jnp.asarray(window) if len(win.shape) != 1: diff --git a/jax/_src/scipy/spatial/transform.py b/jax/_src/scipy/spatial/transform.py index 3f96511a116b..ec7165e32ffd 100644 --- a/jax/_src/scipy/spatial/transform.py +++ b/jax/_src/scipy/spatial/transform.py @@ -18,17 +18,46 @@ import re import typing -import scipy.spatial.transform import jax import jax.numpy as jnp -from jax._src.numpy.util import implements -@implements(scipy.spatial.transform.Rotation) class Rotation(typing.NamedTuple): - """Rotation in 3 dimensions.""" + """Rotation in 3 dimensions. + JAX implementation of :class:`scipy.spatial.transform.Rotation`. + + Examples: + Construct an object describing a 90 degree rotation about the z-axis: + + >>> from jax.scipy.spatial.transform import Rotation + >>> r = Rotation.from_euler('z', 90, degrees=True) + + Convert to a rotation vector: + + >>> r.as_rotvec() + Array([0. , 0. , 1.5707964], dtype=float32) + + Convert to rotation matrix: + + >>> r.as_matrix() + Array([[ 0. , -0.99999994, 0. ], + [ 0.99999994, 0. , 0. ], + [ 0. , 0. 
, 0.99999994]], dtype=float32) + + Compose with another rotation: + + >>> r2 = Rotation.from_euler('x', 90, degrees=True) + >>> r3 = r * r2 + >>> r3.as_matrix() + Array([[0., 0., 1.], + [1., 0., 0.], + [0., 1., 0.]], dtype=float32) + + See the scipy :class:`~scipy.spatial.transform.Rotation` documentation for + further examples of manipulating Rotation objects. + """ quat: jax.Array @classmethod @@ -86,7 +115,7 @@ def identity(cls, num: int | None = None, dtype=float): def random(cls, random_key: jax.Array, num: int | None = None): """Generate uniformly distributed rotations.""" # Need to implement scipy.stats.special_ortho_group for this to work... - raise NotImplementedError + raise NotImplementedError() def __getitem__(self, indexer): """Extract rotation(s) at given index(es) from object.""" @@ -169,9 +198,31 @@ def single(self) -> bool: return self.quat.ndim == 1 -@implements(scipy.spatial.transform.Slerp) class Slerp(typing.NamedTuple): - """Spherical Linear Interpolation of Rotations.""" + """Spherical Linear Interpolation of Rotations. + + JAX implementation of :class:`scipy.spatial.transform.Slerp`. + + Examples: + Create a Slerp instance from a series of rotations: + + >>> import math + >>> from jax.scipy.spatial.transform import Rotation, Slerp + >>> rots = jnp.array([[90, 0, 0], + ... [0, 45, 0], + ... [0, 0, -30]]) + >>> key_rotations = Rotation.from_euler('zxy', rots, degrees=True) + >>> key_times = [0, 1, 2] + >>> slerp = Slerp.init(key_times, key_rotations) + >>> times = [0, 0.5, 1, 1.5, 2] + >>> interp_rots = slerp(times) + >>> interp_rots.as_euler('zxy') + Array([[ 1.5707963e+00, 0.0000000e+00, 0.0000000e+00], + [ 8.5309029e-01, 3.8711953e-01, 1.7768645e-01], + [-2.3841858e-07, 7.8539824e-01, 0.0000000e+00], + [-5.6668043e-02, 3.9213133e-01, -2.8347540e-01], + [ 0.0000000e+00, 0.0000000e+00, -5.2359891e-01]], dtype=float32) + """ times: jnp.ndarray timedelta: jnp.ndarray diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index d4aced143016..70f3ccd2ef80 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -16,10 +16,9 @@ from functools import partial import operator -from typing import cast, Any +from typing import cast, overload, Any import numpy as np -import scipy.special as osp_special import jax.numpy as jnp from jax import jit @@ -29,116 +28,470 @@ from jax._src import core from jax._src import custom_derivatives +from jax._src import deprecations from jax._src import dtypes from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact -from jax._src.numpy.util import implements from jax._src.ops import special as ops_special from jax._src.third_party.scipy.betaln import betaln as _betaln_impl from jax._src.typing import Array, ArrayLike +from jax._src.nn.functions import softmax as nn_softmax +from jax._src.nn.functions import log_softmax as nn_log_softmax -@implements(osp_special.gammaln, module='scipy.special') def gammaln(x: ArrayLike) -> Array: + r"""Natural log of the absolute value of the gamma function. + + JAX implementation of :obj:`scipy.special.gammaln`. + + .. math:: + + \mathrm{gammaln}(x) = \log(|\Gamma(x)|) + + Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. + + Args: + x: arraylike, real valued. 
+ + Returns: + array containing the values of the log-gamma function + + See Also: + - :func:`jax.scipy.special.gammaln`: the natural log of the gamma function + - :func:`jax.scipy.special.gammasgn`: the sign of the gamma function + + Notes: + ``gammaln`` does not support complex-valued inputs. + """ x, = promote_args_inexact("gammaln", x) return lax.lgamma(x) -@implements(osp_special.gamma, module='scipy.special', lax_description="""\ -The JAX version only accepts real-valued inputs.""") +def gammasgn(x: ArrayLike) -> Array: + r"""Sign of the gamma function. + + JAX implementation of :obj:`scipy.special.gammasgn`. + + .. math:: + + \mathrm{gammasgn}(x) = \begin{cases} + +1 & \Gamma(x) > 0 \\ + -1 & \Gamma(x) < 0 + \end{cases} + + Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. + Because :math:`\Gamma(x)` is never zero, no condition is required for this case. + + Args: + x: arraylike, real valued. + + Returns: + array containing the sign of the gamma function + + See Also: + - :func:`jax.scipy.special.gamma`: the gamma function + - :func:`jax.scipy.special.gammaln`: the natural log of the gamma function + """ + x, = promote_args_inexact("gammasgn", x) + floor_x = lax.floor(x) + return jnp.where((x > 0) | (x == floor_x) | (floor_x % 2 == 0), 1.0, -1.0) + + def gamma(x: ArrayLike) -> Array: + r"""The gamma function. + + JAX implementation of :obj:`scipy.special.gamma`. + + The gamma function is defined for :math:`\Re(z)>0` as + + .. math:: + + \mathrm{gamma}(z) = \Gamma(z) = \int_0^\infty t^{z-1}e^{-t}\mathrm{d}t + + and is extended by analytic continuation to arbitrary complex values `z`. + For positive integers `n`, the gamma function is related to the + :func:`~jax.scipy.special.factorial` function via the following identity: + + .. math:: + + \Gamma(n) = (n - 1)! + + Args: + x: arraylike, real valued. + + Returns: + array containing the values of the gamma function + + See Also: + - :func:`jax.scipy.special.factorial`: the factorial function. + - :func:`jax.scipy.special.gammaln`: the natural log of the gamma function + - :func:`jax.scipy.special.gammasgn`: the sign of the gamma function + + Notes: + Unlike the scipy version, JAX's ``gamma`` does not support complex-valued inputs. + """ x, = promote_args_inexact("gamma", x) - # Compute the sign for negative x, matching the semantics of scipy.special.gamma - floor_x = lax.floor(x) - sign = jnp.where((x > 0) | (x == floor_x), 1.0, (-1.0) ** floor_x) - return sign * lax.exp(lax.lgamma(x)) + return gammasgn(x) * lax.exp(lax.lgamma(x)) + -betaln = implements( - osp_special.betaln, - module='scipy.special', - update_doc=False -)(_betaln_impl) +def betaln(a: ArrayLike, b: ArrayLike) -> Array: + r"""Natural log of the absolute value of the beta function + + JAX implementation of :obj:`scipy.special.betaln`. + + .. math:: + + \mathrm{betaln}(a, b) = \log B(a, b) + + where :math:`B` is the :func:`~jax.scipy.special.beta` function. + + Args: + a: arraylike, real-valued. Parameter *a* of the beta distribution. + b: arraylike, real-valued. Parameter *b* of the beta distribution. + + Returns: + array containing the values of the log-beta function + + See Also: + :func:`jax.scipy.special.beta` + """ + a, b = promote_args_inexact("betaln", a, b) + return _betaln_impl(a, b) -@implements(osp_special.factorial, module='scipy.special') def factorial(n: ArrayLike, exact: bool = False) -> Array: + r"""Factorial function + + JAX implementation of :obj:`scipy.special.factorial` + + .. math:: + + \mathrm{factorial}(n) = n! 
= \prod_{k=1}^n k + + Args: + n: arraylike, values for which factorial will be computed elementwise + exact: bool, only ``exact=False`` is supported. + + Returns: + array containing values of the factorial. + + Notes: + This computes the float-valued factorial via the :func:`~jax.scipy.special.gamma` + function. JAX does not support exact factorials, because it is not particularly + useful: above ``n=20``, the exact result cannot be represented by 64-bit integers, + which are the largest integers available to JAX. + + See Also: + :func:`jax.scipy.special.gamma` + """ if exact: raise NotImplementedError("factorial with exact=True") n, = promote_args_inexact("factorial", n) return jnp.where(n < 0, 0, lax.exp(lax.lgamma(n + 1))) +@overload +def beta(a: ArrayLike, b: ArrayLike) -> Array: ... -@implements(osp_special.beta, module='scipy.special') -def beta(x: ArrayLike, y: ArrayLike) -> Array: - x, y = promote_args_inexact("beta", x, y) - return lax.exp(betaln(x, y)) +@overload +def beta(a: ArrayLike, *, y: ArrayLike) -> Array: ... + +@overload +def beta(*, x: ArrayLike, y: ArrayLike) -> Array: ... + +def beta(*args, **kwds): + r"""The beta function + + JAX implementation of :obj:`scipy.special.beta`. + + .. math:: + + \mathrm{beta}(a, b) = B(a, b) = \frac{\Gamma(a)\Gamma(b)}{\Gamma(a + b)} + + where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. + + Args: + a: arraylike, real-valued. Parameter *a* of the beta distribution. + b: arraylike, real-valued. Parameter *b* of the beta distribution. + + Returns: + array containing the values of the beta function. + + See Also: + - :func:`jax.scipy.special.gamma` + - :func:`jax.scipy.special.betaln` + """ + # TODO(jakevdp): deprecation warning added 2024-06-10; finalize after 2024-09-10 + if 'x' in kwds: + msg = "The `x` parameter of jax.scipy.special.beta is deprecated, use `a` instead." + deprecations.warn('jax-scipy-beta-args', msg, stacklevel=2) + if 'a' in kwds: + raise TypeError("beta() got both parameter 'a' and parameter 'x'.") + kwds['a'] = kwds.pop('x') + if 'y' in kwds: + msg = "The `y` parameter of jax.scipy.special.beta is deprecated, use `b` instead." + deprecations.warn('jax-scipy-beta-args', msg, stacklevel=2) + if 'b' in kwds: + raise TypeError("beta() got both parameter 'b' and parameter 'y'.") + kwds['b'] = kwds.pop('y') + if extra := kwds.keys() - {'a', 'b'}: + raise TypeError(f"beta() got unexpected keyword arguments {list(extra)}") + return _beta(*args, **kwds) + +def _beta(a, b): + a, b = promote_args_inexact("beta", a, b) + sign = gammasgn(a) * gammasgn(b) * gammasgn(a + b) + return sign * lax.exp(betaln(a, b)) -@implements(osp_special.betainc, module='scipy.special') def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: + r"""The regularized incomplete beta function. + + JAX implementation of :obj:`scipy.special.betainc`. + + .. math:: + + \mathrm{betainc}(a, b, x) = B(a, b)\int_0^x t^{a-1}(1-t^{b-1})\mathrm{d}t + + where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function. + + Args: + a: arraylike, real-valued. Parameter *a* of the beta distribution. + b: arraylike, real-valued. Parameter *b* of the beta distribution. + x: arraylike, real-valued. Upper limit of the integration. 
+ + Returns: + array containing values of the betainc function + + See Also: + - :func:`jax.scipy.special.beta` + - :func:`jax.scipy.special.betaln` + """ a, b, x = promote_args_inexact("betainc", a, b, x) return lax.betainc(a, b, x) -@implements(osp_special.digamma, module='scipy.special', lax_description="""\ -The JAX version only accepts real-valued inputs.""") def digamma(x: ArrayLike) -> Array: + r"""The digamma function + + JAX implementation of :obj:`scipy.special.digamma`. + + .. math:: + + \mathrm{digamma}(z) = \psi(z) = \frac{\mathrm{d}}{\mathrm{d}z}\log \Gamma(z) + + where :math:`\Gamma(z)` is the :func:`~jax.scipy.special.gamma` function. + + Args: + x: arraylike, real-valued. + + Returns: + array containing values of the digamma function. + + Notes: + The JAX version of `digamma` accepts real-valued inputs. + + See also: + - :func:`jax.scipy.special.gamma` + - :func:`jax.scipy.special.polygamma` + """ x, = promote_args_inexact("digamma", x) return lax.digamma(x) -@implements(osp_special.gammainc, module='scipy.special', update_doc=False) def gammainc(a: ArrayLike, x: ArrayLike) -> Array: + r"""The regularized lower incomplete gamma function. + + JAX implementation of :obj:`scipy.special.gammainc`. + + .. math:: + + \mathrm{gammainc}(x; a) = \frac{1}{\Gamma(a)}\int_0^x t^{a-1}e^{-t}\mathrm{d}t + + where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function. + + Args: + a: arraylike, real-valued. Positive shape parameter of the gamma distribution. + x: arraylike, real-valued. Non-negative upper limit of integration + + Returns: + array containing values of the gammainc function. + + See Also: + - :func:`jax.scipy.special.gamma` + - :func:`jax.scipy.special.gammaincc` + """ a, x = promote_args_inexact("gammainc", a, x) return lax.igamma(a, x) -@implements(osp_special.gammaincc, module='scipy.special', update_doc=False) def gammaincc(a: ArrayLike, x: ArrayLike) -> Array: + r"""The regularized upper incomplete gamma function. + + JAX implementation of :obj:`scipy.special.gammaincc`. + + .. math:: + + \mathrm{gammaincc}(x; a) = \frac{1}{\Gamma(a)}\int_x^\infty t^{a-1}e^{-t}\mathrm{d}t + + where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function. + + Args: + a: arraylike, real-valued. Positive shape parameter of the gamma distribution. + x: arraylike, real-valued. Non-negative lower limit of integration + + Returns: + array containing values of the gammaincc function. + + See Also: + - :func:`jax.scipy.special.gamma` + - :func:`jax.scipy.special.gammainc` + """ a, x = promote_args_inexact("gammaincc", a, x) return lax.igammac(a, x) -@implements(osp_special.erf, module='scipy.special', skip_params=["out"], - lax_description="Note that the JAX version does not support complex inputs.") def erf(x: ArrayLike) -> Array: + r"""The error function + + JAX implementation of :obj:`scipy.special.erf`. + + .. math:: + + \mathrm{erf}(x) = \frac{2}{\sqrt\pi} \int_{0}^x e^{-t^2} \mathrm{d}t + + Args: + x: arraylike, real-valued. + + Returns: + array containing values of the error function. + + Notes: + The JAX version only supports real-valued inputs. + + See also: + - :func:`jax.scipy.special.erfc` + - :func:`jax.scipy.special.erfinv` + """ x, = promote_args_inexact("erf", x) return lax.erf(x) -@implements(osp_special.erfc, module='scipy.special', update_doc=False) def erfc(x: ArrayLike) -> Array: + r"""The complement of the error function + + JAX implementation of :obj:`scipy.special.erfc`. + + .. 
math:: + + \mathrm{erfc}(x) = \frac{2}{\sqrt\pi} \int_{x}^\infty e^{-t^2} \mathrm{d}t + + This is the complement of the error function :func:`~jax.scipy.special.erf`, + ``erfc(x) = 1 - erf(x)``. + + Args: + x: arraylike, real-valued. + + Returns: + array containing values of the complement of the error function. + + Notes: + The JAX version only supports real-valued inputs. + + See also: + - :func:`jax.scipy.special.erf` + - :func:`jax.scipy.special.erfinv` + """ x, = promote_args_inexact("erfc", x) return lax.erfc(x) -@implements(osp_special.erfinv, module='scipy.special') def erfinv(x: ArrayLike) -> Array: + """The inverse of the error function + + JAX implementation of :obj:`scipy.special.erfinv`. + + Returns the inverse of :func:`~jax.scipy.special.erf`. + + Args: + x: arraylike, real-valued. + + Returns: + array containing values of the inverse error function. + + Notes: + The JAX version only supports real-valued inputs. + + See also: + - :func:`jax.scipy.special.erf` + - :func:`jax.scipy.special.erfc` + """ x, = promote_args_inexact("erfinv", x) return lax.erf_inv(x) @custom_derivatives.custom_jvp -@implements(osp_special.logit, module='scipy.special', update_doc=False) def logit(x: ArrayLike) -> Array: + r"""The logit function + + JAX implementation of :obj:`scipy.special.logit`. + + .. math:: + + \mathrm{logit}(p) = \log\frac{p}{1 - p} + + Args: + x: arraylike, real-valued. + + Returns: + array containing values of the logit function. + """ x, = promote_args_inexact("logit", x) return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x))) logit.defjvps( lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(_lax_const(x, 1), x)))) -@implements(osp_special.expit, module='scipy.special', update_doc=False) def expit(x: ArrayLike) -> Array: + r"""The logistic sigmoid (expit) function + + JAX implementation of :obj:`scipy.special.expit`. + + .. math:: + + \mathrm{expit}(x) = \frac{1}{1 + e^{-x}} + + Args: + x: arraylike, real-valued. + + Returns: + array containing values of the expit function. + """ x, = promote_args_inexact("expit", x) return lax.logistic(x) -logsumexp = implements(osp_special.logsumexp, module='scipy.special')(ops_special.logsumexp) +logsumexp = ops_special.logsumexp @custom_derivatives.custom_jvp -@implements(osp_special.xlogy, module='scipy.special') def xlogy(x: ArrayLike, y: ArrayLike) -> Array: + """Compute x*log(y), returning 0 for x=0. + + JAX implementation of :obj:`scipy.special.xlogy`. + + This is defined to return zero when :math:`(x, y) = (0, 0)`, with a custom + derivative rule so that automatic differentiation is well-defined at this point. + + Args: + x: arraylike, real-valued. + y: arraylike, real-valued. + + Returns: + array containing xlogy values. + + See also: + :func:`jax.scipy.special.xlog1py` + """ # Note: xlogy(0, 0) should return 0 according to the function documentation. x, y = promote_args_inexact("xlogy", x, y) x_ok = x != 0. @@ -153,8 +506,24 @@ def _xlogy_jvp(primals, tangents): @custom_derivatives.custom_jvp -@implements(osp_special.xlog1py, module='scipy.special', update_doc=False) def xlog1py(x: ArrayLike, y: ArrayLike) -> Array: + """Compute x*log(1 + y), returning 0 for x=0. + + JAX implementation of :obj:`scipy.special.xlog1py`. + + This is defined to return 0 when :math:`(x, y) = (0, -1)`, with a custom + derivative rule so that automatic differentiation is well-defined at this point. + + Args: + x: arraylike, real-valued. + y: arraylike, real-valued. + + Returns: + array containing xlog1py values. 
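+ 
+   For example (a small illustrative check), the special point noted above
+   evaluates to zero, as does any input with ``x = 0``:
+ 
+   >>> float(jax.scipy.special.xlog1py(0.0, -1.0))
+   0.0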
+ + See also: + :func:`jax.scipy.special.xlogy` + """ # Note: xlog1py(0, -1) should return 0 according to the function documentation. x, y = promote_args_inexact("xlog1py", x, y) x_ok = x != 0. @@ -179,15 +548,62 @@ def _xlogx_jvp(primals, tangents): _xlogx.defjvp(_xlogx_jvp) -@implements(osp_special.entr, module='scipy.special') def entr(x: ArrayLike) -> Array: + r"""The entropy function + + JAX implementation of :obj:`scipy.special.entr`. + + .. math:: + + \mathrm{entr}(x) = \begin{cases} + -x\log(x) & x > 0 \\ + 0 & x = 0\\ + -\infty & x > 0 + \end{cases} + + Args: + x: arraylike, real-valued. + + Returns: + array containing entropy values. + + See also: + - :func:`jax.scipy.special.kl_div` + - :func:`jax.scipy.special.rel_entr` + """ x, = promote_args_inexact("entr", x) return lax.select(lax.lt(x, _lax_const(x, 0)), lax.full_like(x, -np.inf), lax.neg(_xlogx(x))) -@implements(osp_special.multigammaln, update_doc=False) + def multigammaln(a: ArrayLike, d: ArrayLike) -> Array: + r"""The natural log of the multivariate gamma function. + + JAX implementation of :func:`scipy.special.multigammaln`. + + .. math:: + + \mathrm{multigammaln}(a, d) = \log\Gamma_d(a) + + where + + .. math:: + + \Gamma_d(a) = \pi^{d(d-1)/4}\prod_{i=1}^d\Gamma(a-(i-1)/2) + + and :math:`\Gamma(x)` is the :func:`~jax.scipy.special.gamma` function. + + Args: + a: arraylike, real-valued. + d: int, the dimension of the integration space. + + Returns: + array containing values of the log-multigamma function. + + See also: + - :func:`jax.scipy.special.gamma` + """ d = core.concrete_or_error(int, d, "d argument of multigammaln") a, d_ = promote_args_inexact("multigammaln", a, d) @@ -201,49 +617,76 @@ def multigammaln(a: ArrayLike, d: ArrayLike) -> Array: return res + constant -@implements(osp_special.kl_div, module="scipy.special") def kl_div( p: ArrayLike, q: ArrayLike, ) -> Array: - p, q = promote_args_inexact("kl_div", p, q) - zero = _lax_const(p, 0.0) - both_gt_zero_mask = lax.bitwise_and(lax.gt(p, zero), lax.gt(q, zero)) - one_zero_mask = lax.bitwise_and(lax.eq(p, zero), lax.ge(q, zero)) - - safe_p = jnp.where(both_gt_zero_mask, p, 1) - safe_q = jnp.where(both_gt_zero_mask, q, 1) - - log_val = lax.sub( - lax.add( - lax.sub(_xlogx(safe_p), xlogy(safe_p, safe_q)), - safe_q, - ), - safe_p, - ) - result = jnp.where( - both_gt_zero_mask, log_val, jnp.where(one_zero_mask, q, np.inf) - ) - return result - - -@implements(osp_special.rel_entr, module="scipy.special") + r"""The Kullback-Leibler divergence. + + JAX implementation of :obj:`scipy.special.kl_div`. + + .. math:: + + \mathrm{kl\_div}(p, q) = \begin{cases} + p\log(p/q)-p+q & p>0,q>0\\ + q & p=0,q\ge 0\\ + \infty & \mathrm{otherwise} + \end{cases} + + Args: + p: arraylike, real-valued. + q: arraylike, real-valued. 
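+ 
+   For example (a small illustrative check), the divergence of a distribution
+   from itself is zero:
+ 
+   >>> float(jax.scipy.special.kl_div(0.5, 0.5))
+   0.0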
+ + Returns: + array of KL-divergence values + + See also: + - :func:`jax.scipy.special.entr` + - :func:`jax.scipy.special.rel_entr` + """ + p, q = promote_args_inexact("kl_div", p, q) + return rel_entr(p, q) - p + q + + def rel_entr( p: ArrayLike, q: ArrayLike, ) -> Array: - p, q = promote_args_inexact("rel_entr", p, q) - zero = _lax_const(p, 0.0) - both_gt_zero_mask = lax.bitwise_and(lax.gt(p, zero), lax.gt(q, zero)) - one_zero_mask = lax.bitwise_and(lax.eq(p, zero), lax.ge(q, zero)) - - safe_p = jnp.where(both_gt_zero_mask, p, 1) - safe_q = jnp.where(both_gt_zero_mask, q, 1) - log_val = lax.sub(_xlogx(safe_p), xlogy(safe_p, safe_q)) - result = jnp.where( - both_gt_zero_mask, log_val, jnp.where(one_zero_mask, q, jnp.inf) - ) - return result + r"""The relative entropy function. + + JAX implementation of :obj:`scipy.special.rel_entr`. + + .. math:: + + \mathrm{rel\_entr}(p, q) = \begin{cases} + p\log(p/q) & p>0,q>0\\ + 0 & p=0,q\ge 0\\ + \infty & \mathrm{otherwise} + \end{cases} + + Args: + p: arraylike, real-valued. + q: arraylike, real-valued. + + Returns: + array of relative entropy values. + + See also: + - :func:`jax.scipy.special.entr` + - :func:`jax.scipy.special.kl_div` + """ + p, q = promote_args_inexact("rel_entr", p, q) + zero = _lax_const(p, 0.0) + both_gt_zero_mask = lax.bitwise_and(lax.gt(p, zero), lax.gt(q, zero)) + one_zero_mask = lax.bitwise_and(lax.eq(p, zero), lax.ge(q, zero)) + + safe_p = jnp.where(both_gt_zero_mask, p, 1) + safe_q = jnp.where(both_gt_zero_mask, q, 1) + log_val = lax.sub(_xlogx(safe_p), xlogy(safe_p, safe_q)) + result = jnp.where( + both_gt_zero_mask, log_val, jnp.where(one_zero_mask, zero, jnp.inf) + ) + return result # coefs of (2k)! / B_{2k} where B are bernoulli numbers # those numbers are obtained using https://www.wolframalpha.com @@ -268,8 +711,23 @@ def rel_entr( @custom_derivatives.custom_jvp -@implements(osp_special.zeta, module='scipy.special') def zeta(x: ArrayLike, q: ArrayLike | None = None) -> Array: + r"""The Hurwitz zeta function. + + JAX implementation of :func:`scipy.special.zeta`. JAX does not implement + the Riemann zeta function (i.e. ``q = None``). + + .. math:: + + \zeta(x, q) = \sum_{n=0}^\infty \frac{1}{(n + q)^x} + + Args: + x: arraylike, real-valued + q: arraylike, real-valued + + Returns: + array of zeta function values + """ if q is None: raise NotImplementedError( "Riemann zeta function not implemented; pass q != None to compute the Hurwitz Zeta function.") @@ -301,18 +759,38 @@ def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array: m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim))) s_over_a = (s_ + m) / (a_ + N) T1 = jnp.cumprod(s_over_a, -1)[..., ::2] - T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max) + T1 = jnp.clip(T1, max=jnp.finfo(dtype).max) coefs = np.expand_dims(np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype), tuple(range(a.ndim))) T1 = T1 / coefs T = T0 * (dtype(0.5) + T1.sum(-1)) return S + I + T -zeta.defjvp(partial(jvp, _zeta_series_expansion)) # type: ignore[arg-type] +zeta.defjvp(partial(jvp, _zeta_series_expansion)) -@implements(osp_special.polygamma, module='scipy.special', update_doc=False) def polygamma(n: ArrayLike, x: ArrayLike) -> Array: + r"""The polygamma function. + + JAX implementation of :func:`scipy.special.polygamma`. + + .. math:: + + \mathrm{polygamma}(n, x) = \psi^{(n)}(x) = \frac{\mathrm{d}^n}{\mathrm{d}x^n}\log \Gamma(x) + + where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. + + Args: + n: arraylike, integer-valued. 
The order of the derivative. + x: arraylike, real-valued. The value at which to evaluate the function. + + Returns: + array + + See also: + - :func:`jax.scipy.special.gamma` + - :func:`jax.scipy.special.digamma` + """ assert jnp.issubdtype(lax.dtype(n), jnp.integer) n_arr, x_arr = promote_args_inexact("polygamma", n, x) return lax.polygamma(n_arr, x_arr) @@ -398,6 +876,8 @@ def polygamma(n: ArrayLike, x: ArrayLike) -> Array: def ndtr(x: ArrayLike) -> Array: r"""Normal distribution function. + JAX implementation of :obj:`scipy.special.ndtr`. + Returns the area under the Gaussian probability density function, integrated from minus infinity to x: @@ -444,11 +924,13 @@ def _ndtr(x: ArrayLike) -> Array: def ndtri(p: ArrayLike) -> Array: r"""The inverse of the CDF of the Normal distribution function. + JAX implementation of :obj:`scipy.special.ndtri`. + Returns `x` such that the area under the PDF from :math:`-\infty` to `x` is equal to `p`. A piece-wise rational approximation is done for the function. - This is a based on the implementation in netlib. + This is based on the implementation in netlib. Args: p: an array of type `float32`, `float64`. @@ -540,7 +1022,7 @@ def _create_polynomial(var, coeffs): # later on. The result from the computation when p == 0 is not used so any # number that doesn't result in NaNs is fine. sanitized_mcp = jnp.where( - maybe_complement_p <= dtype(0.), + maybe_complement_p == dtype(0.), jnp.full(shape, dtype(0.5)), maybe_complement_p) @@ -571,15 +1053,17 @@ def _create_polynomial(var, coeffs): x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x) infinity = jnp.full(shape, dtype(np.inf)) - x_nan_replaced = jnp.where( - p <= dtype(0.0), -infinity, jnp.where(p >= dtype(1.0), infinity, x)) - return x_nan_replaced + x_fix_boundaries = jnp.where( + p == dtype(0.0), -infinity, jnp.where(p == dtype(1.0), infinity, x)) + return x_fix_boundaries @partial(custom_derivatives.custom_jvp, nondiff_argnums=(1,)) def log_ndtr(x: ArrayLike, series_order: int = 3) -> Array: r"""Log Normal distribution function. + JAX implementation of :obj:`scipy.special.log_ndtr`. + For details of the Normal distribution function see `ndtr`. This function calculates :math:`\log(\mathrm{ndtr}(x))` by either calling @@ -725,23 +1209,103 @@ def _norm_logpdf(x): log_normalizer = _lax_const(x, _norm_logpdf_constant) return lax.sub(lax.mul(neg_half, lax.square(x)), log_normalizer) -@implements(osp_special.i0e, module='scipy.special') + def i0e(x: ArrayLike) -> Array: + r"""Exponentially scaled modified bessel function of zeroth order. + + JAX implementation of :obj:`scipy.special.i0e`. + + .. math:: + + \mathrm{i0e}(x) = e^{-|x|} I_0(x) + + where :math:`I_0(x)` is the modified Bessel function :func:`~jax.scipy.special.i0`. + + Args: + x: array, real-valued + + Returns: + array of bessel function values. + + See also: + - :func:`jax.scipy.special.i0` + - :func:`jax.scipy.special.i1` + - :func:`jax.scipy.special.i1e` + """ x, = promote_args_inexact("i0e", x) return lax.bessel_i0e(x) -@implements(osp_special.i0, module='scipy.special') + def i0(x: ArrayLike) -> Array: + r"""Modified bessel function of zeroth order. + + JAX implementation of :obj:`scipy.special.i0`. + + .. math:: + + \mathrm{i0}(x) = I_0(x) = \sum_{k=0}^\infty \frac{(x^2/4)^k}{(k!)^2} + + Args: + x: array, real-valued + + Returns: + array of bessel function values. 
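+ 
+   For example (a small illustrative check; the printed value is indicative and
+   skipped by doctest), :math:`I_0(0) = 1`:
+ 
+   >>> float(jax.scipy.special.i0(0.0))  # doctest: +SKIP
+   1.0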
+ + See also: + - :func:`jax.scipy.special.i0e` + - :func:`jax.scipy.special.i1` + - :func:`jax.scipy.special.i1e` + """ x, = promote_args_inexact("i0", x) return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i0e(x)) -@implements(osp_special.i1e, module='scipy.special') + def i1e(x: ArrayLike) -> Array: + r"""Exponentially scaled modified bessel function of first order. + + JAX implementation of :obj:`scipy.special.i1e`. + + .. math:: + + \mathrm{i1e}(x) = e^{-|x|} I_1(x) + + where :math:`I_1(x)` is the modified Bessel function :func:`~jax.scipy.special.i1`. + + Args: + x: array, real-valued + + Returns: + array of bessel function values + + See also: + - :func:`jax.scipy.special.i0` + - :func:`jax.scipy.special.i0e` + - :func:`jax.scipy.special.i1` + """ x, = promote_args_inexact("i1e", x) return lax.bessel_i1e(x) -@implements(osp_special.i1, module='scipy.special') + def i1(x: ArrayLike) -> Array: + r"""Modified bessel function of first order. + + JAX implementation of :obj:`scipy.special.i1`. + + .. math:: + + \mathrm{i1}(x) = I_1(x) = \frac{1}{2}x\sum_{k=0}^\infty\frac{(x^2/4)^k}{k!(k+1)!} + + Args: + x: array, real-valued + + Returns: + array of bessel function values + + See also: + - :func:`jax.scipy.special.i0` + - :func:`jax.scipy.special.i0e` + - :func:`jax.scipy.special.i1e` + """ x, = promote_args_inexact("i1", x) return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i1e(x)) @@ -826,7 +1390,7 @@ def bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array: def _gen_recurrence_mask( l_max: int, is_normalized: bool, dtype: Any ) -> tuple[Array, Array]: - """Generates mask for recurrence relation on the remaining entries. + """Generates a mask for recurrence relation on the remaining entries. The remaining entries are with respect to the diagonal and offdiagonal entries. @@ -981,7 +1545,7 @@ def _gen_associated_legendre(l_max: int, `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the normalization factor and θ and φ are the colatitude and longitude, respectively. `N_l^m` is chosen in the way that the spherical harmonics form - a set of orthonormal basis function of L^2(S^2). For the computational + a set of orthonormal basis functions of L^2(S^2). For the computational efficiency of spherical harmonics transform, the normalization factor is used in the computation of the ALFs. In addition, normalizing `P_l^m` avoids overflow/underflow and achieves better numerical stability. Three @@ -1005,7 +1569,7 @@ def _gen_associated_legendre(l_max: int, operation, `W` is a diagonal matrix containing the quadrature weights, and `I` is the identity matrix. The Gauss-Chebyshev points are equally spaced, which only provide approximate discrete orthogonality. The - Driscoll & Healy qudarture points are equally spaced and provide the + Driscoll & Healy quadrature points are equally spaced and provide the exact discrete orthogonality. The number of sampling points is required to be twice as the number of frequency points (modes) in the Driscoll & Healy approach, which enables FFT and achieves a fast spherical harmonics @@ -1216,7 +1780,7 @@ def sph_harm(m: Array, Args: m: The order of the harmonic; must have `|m| <= n`. Return values for - `|m| > n` ara undefined. + `|m| > n` are undefined. n: The degree of the harmonic; must have `n >= 0`. The standard notation for degree in descriptions of spherical harmonics is `l (lower case L)`. We use `n` here to be consistent with `scipy.special.sph_harm`. Return @@ -1226,7 +1790,7 @@ def sph_harm(m: Array, n_max: The maximum degree `max(n)`. 
If the supplied `n_max` is not the true maximum value of `n`, the results are clipped to `n_max`. For example, `sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)` - acutually returns + actually returns `sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)` Returns: A 1D array containing the spherical harmonics at (m, n, theta, phi). @@ -1459,8 +2023,25 @@ def _expi_neg(x: Array) -> Array: @custom_derivatives.custom_jvp @jit -@implements(osp_special.expi, module='scipy.special') def expi(x: ArrayLike) -> Array: + r"""Exponential integral function. + + JAX implementation of :obj:`scipy.special.expi` + + .. math:: + + \mathrm{expi}(x) = \int_{-\infty}^x \frac{e^t}{t} \mathrm{d}t + + Args: + x: arraylike, real-valued + + Returns: + array of expi values + + See also: + - :func:`jax.scipy.special.expn` + - :func:`jax.scipy.special.exp1` + """ x_arr, = promote_args_inexact("expi", x) return jnp.piecewise(x_arr, [x_arr < 0], [_expi_neg, _expi_pos]) @@ -1577,9 +2158,27 @@ def _expn3(n: int, x: Array) -> Array: @partial(custom_derivatives.custom_jvp, nondiff_argnums=(0,)) @jnp.vectorize -@implements(osp_special.expn, module='scipy.special') @jit def expn(n: ArrayLike, x: ArrayLike) -> Array: + r"""Generalized exponential integral function. + + JAX implementation of :obj:`scipy.special.expn`. + + .. math:: + + \mathrm{expn}(x) = E_n(x) = x^{n-1}\int_x^\infty\frac{e^{-t}}{t^n}\mathrm{d}t + + Args: + n: arraylike, real-valued + x: arraylike, real-valued + + Returns: + array of expn values + + See also: + - :func:`jax.scipy.special.expi` + - :func:`jax.scipy.special.exp1` + """ n, x = promote_args_inexact("expn", n, x) _c = _lax_const zero = _c(x, 0) @@ -1615,8 +2214,26 @@ def expn_jvp(n, primals, tangents): ) -@implements(osp_special.exp1, module="scipy.special") -def exp1(x: ArrayLike, module='scipy.special') -> Array: +def exp1(x: ArrayLike) -> Array: + r"""Exponential integral function. + + JAX implementation of :obj:`scipy.special.exp1` + + .. math:: + + \mathrm{exp1}(x) = E_1(x) = x^{n-1}\int_x^\infty\frac{e^{-t}}{t}\mathrm{d}t + + + Args: + x: arraylike, real-valued + + Returns: + array of exp1 values + + See also: + - :func:`jax.scipy.special.expi` + - :func:`jax.scipy.special.expn` + """ x, = promote_args_inexact("exp1", x) # Casting because custom_jvp generic does not work correctly with mypy. return cast(Array, expn(1, x)) @@ -1675,13 +2292,15 @@ def _spence(x: Array) -> Array: def spence(x: Array) -> Array: - r""" - Spence's function, also known as the dilogarithm for real values. + r"""Spence's function, also known as the dilogarithm for real values. + + JAX implementation of :obj:`scipy.special.spence`. + It is defined to be: .. math:: - \begin{equation} - \int_1^z \frac{\log(t)}{1 - t}dt + \mathrm{spence}(x) = \begin{equation} + \int_1^x \frac{\log(t)}{1 - t}dt \end{equation} Unlike the SciPy implementation, this is only defined for positive @@ -1706,7 +2325,7 @@ def spence(x: Array) -> Array: -\int_0^z \frac{\log(1 - t)}{t}dt \end{equation} - this is our spence(1 - z). + This is our spence(1 - z). """ x = jnp.asarray(x) dtype = lax.dtype(x) @@ -1716,8 +2335,21 @@ def spence(x: Array) -> Array: return _spence(x) -@implements(osp_special.bernoulli, module='scipy.special') def bernoulli(n: int) -> Array: + """Generate the first N Bernoulli numbers. + + JAX implementation of :func:`scipy.special.bernoulli`. + + Args: + n: integer, the number of Bernoulli terms to generate. + + Returns: + Array containing the first ``n`` Bernoulli numbers. 
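+ 
+   For example (a small illustrative check; the printed value is indicative and
+   skipped by doctest), the second entry is :math:`B_1 = -1/2` under the
+   convention described below:
+ 
+   >>> float(jax.scipy.special.bernoulli(3)[1])  # doctest: +SKIP
+   -0.5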
+ + Notes: + ``bernoulli`` generates numbers using the :math:`B_n^-` convention, + such that :math:`B_1=-1/2`. + """ # Generate Bernoulli numbers using the Chowla and Hartung algorithm. n = core.concrete_or_error(operator.index, n, "Argument n of bernoulli") if n < 0: @@ -1734,10 +2366,27 @@ def bernoulli(n: int) -> Array: @custom_derivatives.custom_jvp -@implements(osp_special.poch, module='scipy.special', lax_description="""\ -The JAX version only accepts positive and real inputs.""") def poch(z: ArrayLike, m: ArrayLike) -> Array: - # Factorial definition when m is close to an integer, otherwise gamma definition. + r"""The Pochammer symbol. + + JAX implementation of :obj:`scipy.special.poch`. + + .. math:: + + \mathrm{poch}(z, m) = (z)_m = \frac{\Gamma(z + m)}{\Gamma(z)} + + where :math:`\Gamma(z)` is the :func:`~jax.scipy.special.gamma` function. + + Args: + z: arraylike, real-valued + m: arraylike, real-valued + + Returns: + array of Pochammer values. + + Notes: + The JAX version supports only real-valued inputs. + """ z, m = promote_args_inexact("poch", z, m) return jnp.where(m == 0., jnp.array(1, dtype=z.dtype), gamma(z + m) / gamma(z)) @@ -1774,6 +2423,8 @@ def _hyp1f1_serie(a, b, x): https://doi.org/10.48550/arXiv.1407.7786 """ + precision = jnp.finfo(x.dtype).eps + def body(state): serie, k, term = state serie += term @@ -1785,7 +2436,7 @@ def body(state): def cond(state): serie, k, term = state - return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-8) + return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) init = 1, 1, a / b * x @@ -1799,6 +2450,8 @@ def _hyp1f1_asymptotic(a, b, x): https://doi.org/10.48550/arXiv.1407.7786 """ + precision = jnp.finfo(x.dtype).eps + def body(state): serie, k, term = state serie += term @@ -1810,7 +2463,7 @@ def body(state): def cond(state): serie, k, term = state - return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-8) + return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) init = 1, 1, (b - a) * (1 - a) / x serie = lax.while_loop(cond, body, init)[0] @@ -1826,6 +2479,8 @@ def _hyp1f1_a_derivative(a, b, x): https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/01/ """ + precision = jnp.finfo(x.dtype).eps + def body(state): serie, k, term = state serie += term * (digamma(a + k) - digamma(a)) @@ -1837,7 +2492,7 @@ def body(state): def cond(state): serie, k, term = state - return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-15) + return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) init = 0, 1, a / b * x @@ -1852,6 +2507,8 @@ def _hyp1f1_b_derivative(a, b, x): https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/02/ """ + precision = jnp.finfo(x.dtype).eps + def body(state): serie, k, term = state serie += term * (digamma(b) - digamma(b + k)) @@ -1863,7 +2520,7 @@ def body(state): def cond(state): serie, k, term = state - return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-15) + return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) init = 0, 1, a / b * x @@ -1883,17 +2540,33 @@ def _hyp1f1_x_derivative(a, b, x): @custom_derivatives.custom_jvp @jit @jnp.vectorize -@implements(osp_special.hyp1f1, module='scipy.special', lax_description="""\ -The JAX version only accepts positive and real inputs. Values of a, b and x -leading to high values of 1F1 might be erroneous, considering enabling double -precision. 
Convention for a = b = 0 is 1, unlike in scipy's implementation.""") -def hyp1f1(a, b, x): - """ - Implementation of the 1F1 hypergeometric function for real valued inputs - Backed by https://doi.org/10.48550/arXiv.1407.7786 - There is room for improvement in the implementation using recursion to - evaluate lower values of hyp1f1 when a or b or both are > 60-80 +def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: + r"""The 1F1 hypergeometric function. + + JAX implementation of :obj:`scipy.special.hyp1f1`. + + .. math:: + + \mathrm{hyp1f1}(a, b, x) = {}_1F_1(x;a, b) = \sum_{k=0}^\infty \frac{(a)_k}{(b)_kk!}x^k + + where :math:`(\cdot)_k` is the Pochammer symbol (refer to :func:`~jax.scipy.special.poch`). + + The JAX version only accepts positive and real inputs. Values of ``a``, ``b``, + and ``x``, leading to high values of 1F1 may lead to erroneous results; + consider enabling double precision in this case. The convention for + ``a = b = 0`` is ``1``, unlike in scipy's implementation. + + Args: + a: arraylike, real-valued + b: arraylike, real-valued + x: arraylike, real-valued + + Returns: + array of 1F1 values. """ + # This is backed by https://doi.org/10.48550/arXiv.1407.7786 + # There is room for improvement in the implementation using recursion to + # evaluate lower values of hyp1f1 when a or b or both are > 60-80 a, b, x = promote_args_inexact('hyp1f1', a, b, x) result = lax.cond(lax.abs(x) < 100, _hyp1f1_serie, _hyp1f1_asymptotic, a, b, x) @@ -1911,3 +2584,72 @@ def hyp1f1(a, b, x): lambda b_dot, primal_out, a, b, x: _hyp1f1_b_derivative(a, b, x) * b_dot, lambda x_dot, primal_out, a, b, x: _hyp1f1_x_derivative(a, b, x) * x_dot ) + + +def softmax(x: ArrayLike, + /, + *, + axis: int | tuple[int, ...] | None = None, + ) -> Array: + r"""Softmax function. + + JAX implementation of :func:`scipy.special.softmax`. + + Computes the function which rescales elements to the range :math:`[0, 1]` + such that the elements along :code:`axis` sum to :math:`1`. + + .. math :: + \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} + + Args: + x : input array + axis: the axis or axes along which the softmax should be computed. The + softmax output summed across these dimensions should sum to :math:`1`. + + Returns: + An array of the same shape as ``x``. + + Note: + If any input values are ``+inf``, the result will be all ``NaN``: this + reflects the fact that ``inf / inf`` is not well-defined in the context of + floating-point math. + + See also: + :func:`log_softmax` + """ + return nn_softmax(x, axis=axis) + + +def log_softmax(x: ArrayLike, + /, + *, + axis: int | tuple[int, ...] | None = None, + ) -> Array: + r"""Log-Softmax function. + + JAX implementation of :func:`scipy.special.log_softmax` + + Computes the logarithm of the :code:`softmax` function, which rescales + elements to the range :math:`[-\infty, 0)`. + + .. math :: + \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} + \right) + + Args: + x : input array + axis: the axis or axes along which the :code:`log_softmax` should be + computed. + + Returns: + An array of the same shape as ``x`` + + Note: + If any input values are ``+inf``, the result will be all ``NaN``: this + reflects the fact that ``inf / inf`` is not well-defined in the context of + floating-point math. 
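+ 
+   For example (a small illustrative check), the result matches taking the
+   logarithm of :func:`softmax` directly, up to floating-point tolerance:
+ 
+   >>> x = jnp.array([1.0, 2.0, 3.0])
+   >>> lsm = jax.scipy.special.log_softmax(x)
+   >>> bool(jnp.allclose(lsm, jnp.log(jax.scipy.special.softmax(x))))
+   True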
+ + See also: + :func:`softmax` + """ + return nn_log_softmax(x, axis=axis) diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index 7325b8cfbe83..08d1c0b6b538 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -23,19 +23,64 @@ from jax import jit from jax._src import dtypes from jax._src.api import vmap -from jax._src.numpy.util import check_arraylike, implements, promote_args_inexact +from jax._src.numpy.util import check_arraylike, promote_args_inexact from jax._src.typing import ArrayLike, Array from jax._src.util import canonicalize_axis -import scipy ModeResult = namedtuple('ModeResult', ('mode', 'count')) -@implements(scipy.stats.mode, lax_description="""\ -Currently the only supported nan_policy is 'propagate' -""") @partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims']) def mode(a: ArrayLike, axis: int | None = 0, nan_policy: str = "propagate", keepdims: bool = False) -> ModeResult: + """Compute the mode (most common value) along an axis of an array. + + JAX implementation of :func:`scipy.stats.mode`. + + Args: + a: arraylike + axis: int, default=0. Axis along which to compute the mode. + nan_policy: str. JAX only supports ``"propagate"``. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + + Returns: + A tuple of arrays, ``(mode, count)``. ``mode`` is the array of modal values, + and ``count`` is the number of times each value appears in the input array. + + Examples: + >>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3]) + >>> mode, count = jax.scipy.stats.mode(x) + >>> mode, count + (Array(4, dtype=int32), Array(3, dtype=int32)) + + For multi dimensional arrays, ``jax.scipy.stats.mode`` computes the ``mode`` + and the corresponding ``count`` along ``axis=0``: + + >>> x1 = jnp.array([[1, 2, 1, 3, 2, 1], + ... [3, 1, 3, 2, 1, 3], + ... [1, 2, 2, 3, 1, 2]]) + >>> mode, count = jax.scipy.stats.mode(x1) + >>> mode, count + (Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32)) + + If ``axis=1``, ``mode`` and ``count`` will be computed along ``axis 1``. + + >>> mode, count = jax.scipy.stats.mode(x1, axis=1) + >>> mode, count + (Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32)) + + By default, ``jax.scipy.stats.mode`` reduces the dimension of the result. + To keep the dimensions same as that of the input array, the argument + ``keepdims`` must be set to ``True``. + + >>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True) + >>> mode, count + (Array([[1], + [3], + [2]], dtype=int32), Array([[3], + [3], + [3]], dtype=int32)) + """ check_arraylike("mode", a) x = jnp.atleast_1d(a) @@ -90,9 +135,7 @@ def invert_permutation(i: Array) -> Array: """Helper function that inverts a permutation array.""" return jnp.empty_like(i).at[i].set(jnp.arange(i.size, dtype=i.dtype)) -@implements(scipy.stats.rankdata, lax_description="""\ -Currently the only supported nan_policy is 'propagate' -""") + @partial(jit, static_argnames=["method", "axis", "nan_policy"]) def rankdata( a: ArrayLike, @@ -101,13 +144,39 @@ def rankdata( axis: int | None = None, nan_policy: str = "propagate", ) -> Array: + """Compute the rank of data along an array axis. + + JAX implementation of :func:`scipy.stats.rankdata`. + + Ranks begin at 1, and the *method* argument controls how ties are handled. + + Args: + a: arraylike + method: str, default="average". 
Supported methods are + ``("average", "min", "max", "dense", "ordinal")`` + For details, see the :func:`scipy.stats.rankdata` documentation. + axis: optional integer. If not specified, the input array is flattened. + nan_policy: str, JAX's implementation only supports ``"propagate"``. + + Returns: + array of ranks along the specified axis. + Examples: + + >>> x = jnp.array([10, 30, 20]) + >>> rankdata(x) + Array([1., 3., 2.], dtype=float32) + + >>> x = jnp.array([1, 3, 2, 3]) + >>> rankdata(x) + Array([1. , 3.5, 2. , 3.5], dtype=float32) + """ check_arraylike("rankdata", a) if nan_policy not in ["propagate", "omit", "raise"]: raise ValueError( f"Illegal nan_policy value {nan_policy!r}; expected one of " - "{'propoagate', 'omit', 'raise'}" + "{'propagate', 'omit', 'raise'}" ) if nan_policy == "omit": raise NotImplementedError( @@ -148,19 +217,87 @@ def rankdata( return .5 * (count[dense] + count[dense - 1] + 1).astype(dtypes.canonicalize_dtype(jnp.float_)) raise ValueError(f"unknown method '{method}'") -@implements(scipy.stats.sem, lax_description="""\ -Currently the only supported nan_policies are 'propagate' and 'omit' -""") + @partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims']) def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "propagate", *, keepdims: bool = False) -> Array: + """Compute the standard error of the mean. + + JAX implementation of :func:`scipy.stats.sem`. + + Args: + a: arraylike + axis: optional integer. If not specified, the input array is flattened. + ddof: integer, default=1. The degrees of freedom in the SEM computation. + nan_policy: str, default="propagate". JAX supports only "propagate" and + "omit". + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + + Returns: + array + + Examples: + >>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... jax.scipy.stats.sem(x) + Array(0.41, dtype=float32) + + For multi dimensional arrays, ``sem`` computes standard error of mean along + ``axis=0``: + + >>> x1 = jnp.array([[1, 2, 1, 3, 2, 1], + ... [3, 1, 3, 2, 1, 3], + ... [1, 2, 2, 3, 1, 2]]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... jax.scipy.stats.sem(x1) + Array([0.67, 0.33, 0.58, 0.33, 0.33, 0.58], dtype=float32) + + If ``axis=1``, standard error of mean will be computed along ``axis 1``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jax.scipy.stats.sem(x1, axis=1) + Array([0.33, 0.4 , 0.31], dtype=float32) + + If ``axis=None``, standard error of mean will be computed along all the axes. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jax.scipy.stats.sem(x1, axis=None) + Array(0.2, dtype=float32) + + By default, ``sem`` reduces the dimension of the result. To keep the + dimensions same as that of the input array, the argument ``keepdims`` must + be set to ``True``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jax.scipy.stats.sem(x1, axis=1, keepdims=True) + Array([[0.33], + [0.4 ], + [0.31]], dtype=float32) + + Since, by default, ``nan_policy='propagate'``, ``sem`` propagates the ``nan`` + values in the result. + + >>> nan = jnp.nan + >>> x2 = jnp.array([[1, 2, 3, nan, 4, 2], + ... [4, 5, 4, 3, nan, 1], + ... [7, nan, 8, 7, 9, nan]]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... 
jax.scipy.stats.sem(x2) + Array([1.73, nan, 1.53, nan, nan, nan], dtype=float32) + + If ``nan_policy='omit```, ``sem`` omits the ``nan`` values and computes the error + for the remainging values along the specified axis. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jax.scipy.stats.sem(x2, nan_policy='omit') + Array([1.73, 1.5 , 1.53, 2. , 2.5 , 0.5 ], dtype=float32) + """ b, = promote_args_inexact("sem", a) - if axis is None: - b = b.ravel() - axis = 0 if nan_policy == "propagate": - return b.std(axis, ddof=ddof) / jnp.sqrt(b.shape[axis]).astype(b.dtype) + size = b.size if axis is None else b.shape[axis] + return b.std(axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(size).astype(b.dtype) elif nan_policy == "omit": - count = (~jnp.isnan(b)).sum(axis) - return jnp.nanstd(b, axis, ddof=ddof) / jnp.sqrt(count).astype(b.dtype) + count = (~jnp.isnan(b)).sum(axis, keepdims=keepdims) + return jnp.nanstd(b, axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(count).astype(b.dtype) else: raise ValueError(f"{nan_policy} is not supported") diff --git a/jax/_src/scipy/stats/bernoulli.py b/jax/_src/scipy/stats/bernoulli.py index 94d0a6735210..96e4a68b7697 100644 --- a/jax/_src/scipy/stats/bernoulli.py +++ b/jax/_src/scipy/stats/bernoulli.py @@ -12,19 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import scipy.stats as osp_stats - from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy.special import xlogy, xlog1py -@implements(osp_stats.bernoulli.logpmf, update_doc=False) def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: + r"""Bernoulli log probability mass function. + + JAX implementation of :obj:`scipy.stats.bernoulli` ``logpmf`` + + The Bernoulli probability mass function is defined as + + .. math:: + + f(k) = \begin{cases} + 1 - p, & k = 0 \\ + p, & k = 1 \\ + 0, & \mathrm{otherwise} + \end{cases} + + Args: + k: arraylike, value at which to evaluate the PMF + p: arraylike, distribution shape parameter + loc: arraylike, distribution offset + + Returns: + array of logpmf values + + See Also: + - :func:`jax.scipy.stats.bernoulli.cdf` + - :func:`jax.scipy.stats.bernoulli.pmf` + - :func:`jax.scipy.stats.bernoulli.ppf` + """ k, p, loc = promote_args_inexact("bernoulli.logpmf", k, p, loc) zero = _lax_const(k, 0) one = _lax_const(k, 1) @@ -33,12 +56,65 @@ def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: return jnp.where(jnp.logical_or(lax.lt(x, zero), lax.gt(x, one)), -jnp.inf, log_probs) -@implements(osp_stats.bernoulli.pmf, update_doc=False) + def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: + r"""Bernoulli probability mass function. + + JAX implementation of :obj:`scipy.stats.bernoulli` ``pmf`` + + The Bernoulli probability mass function is defined as + + .. 
math:: + + f(k) = \begin{cases} + 1 - p, & k = 0 \\ + p, & k = 1 \\ + 0, & \mathrm{otherwise} + \end{cases} + + Args: + k: arraylike, value at which to evaluate the PMF + p: arraylike, distribution shape parameter + loc: arraylike, distribution offset + + Returns: + array of pmf values + + See Also: + - :func:`jax.scipy.stats.bernoulli.cdf` + - :func:`jax.scipy.stats.bernoulli.logpmf` + - :func:`jax.scipy.stats.bernoulli.ppf` + """ return jnp.exp(logpmf(k, p, loc)) -@implements(osp_stats.bernoulli.cdf, update_doc=False) + def cdf(k: ArrayLike, p: ArrayLike) -> Array: + r"""Bernoulli cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.bernoulli` ``cdf`` + + The Bernoulli cumulative distribution function is defined as: + + .. math:: + + f_{cdf}(k, p) = \sum_{i=0}^k f_{pmf}(k, p) + + where :math:`f_{pmf}(k, p)` is the Bernoulli probability mass function + :func:`jax.scipy.stats.bernoulli.pmf`. + + Args: + k: arraylike, value at which to evaluate the CDF + p: arraylike, distribution shape parameter + loc: arraylike, distribution offset + + Returns: + array of cdf values + + See Also: + - :func:`jax.scipy.stats.bernoulli.logpmf` + - :func:`jax.scipy.stats.bernoulli.pmf` + - :func:`jax.scipy.stats.bernoulli.ppf` + """ k, p = promote_args_inexact('bernoulli.cdf', k, p) zero, one = _lax_const(k, 0), _lax_const(k, 1) conds = [ @@ -50,8 +126,28 @@ def cdf(k: ArrayLike, p: ArrayLike) -> Array: vals = [jnp.nan, zero, one - p, one] return jnp.select(conds, vals) -@implements(osp_stats.bernoulli.ppf, update_doc=False) + def ppf(q: ArrayLike, p: ArrayLike) -> Array: + """Bernoulli percent point function. + + JAX implementation of :obj:`scipy.stats.bernoulli` ``ppf`` + + The percent point function is the inverse of the cumulative + distribution function, :func:`jax.scipy.stats.bernoulli.cdf`. + + Args: + k: arraylike, value at which to evaluate the PPF + p: arraylike, distribution shape parameter + loc: arraylike, distribution offset + + Returns: + array of ppf values + + See Also: + - :func:`jax.scipy.stats.bernoulli.cdf` + - :func:`jax.scipy.stats.bernoulli.logpmf` + - :func:`jax.scipy.stats.bernoulli.pmf` + """ q, p = promote_args_inexact('bernoulli.ppf', q, p) zero, one = _lax_const(q, 0), _lax_const(q, 1) return jnp.where( diff --git a/jax/_src/scipy/stats/beta.py b/jax/_src/scipy/stats/beta.py index 2b30ed7b824a..19b8400ee29d 100644 --- a/jax/_src/scipy/stats/beta.py +++ b/jax/_src/scipy/stats/beta.py @@ -12,39 +12,128 @@ # See the License for the specific language governing permissions and # limitations under the License. -import scipy.stats as osp_stats - from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy.special import betaln, betainc, xlogy, xlog1py -@implements(osp_stats.beta.logpdf, update_doc=False) def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Beta log probability distribution function. + + JAX implementation of :obj:`scipy.stats.beta` ``logpdf``. + + The pdf of the beta function is: + + .. math:: + + f(x, a, b) = \frac{\Gamma(a + b)}{\Gamma(a)\Gamma(b)} x^{a-1}(1-x)^{b-1} + + where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, + It is defined for :math:`0\le x\le 1` and :math:`b>0`. 
+ + Args: + x: arraylike, value at which to evaluate the PDF + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logpdf values + + See Also: + - :func:`jax.scipy.stats.beta.cdf` + - :func:`jax.scipy.stats.beta.pdf` + - :func:`jax.scipy.stats.beta.sf` + - :func:`jax.scipy.stats.beta.logcdf` + - :func:`jax.scipy.stats.beta.logsf` + """ x, a, b, loc, scale = promote_args_inexact("beta.logpdf", x, a, b, loc, scale) one = _lax_const(x, 1) + zero = _lax_const(a, 0) shape_term = lax.neg(betaln(a, b)) y = lax.div(lax.sub(x, loc), scale) log_linear_term = lax.add(xlogy(lax.sub(a, one), y), xlog1py(lax.sub(b, one), lax.neg(y))) log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale)) - return jnp.where(jnp.logical_or(lax.gt(x, lax.add(loc, scale)), - lax.lt(x, loc)), -jnp.inf, log_probs) + result = jnp.where(jnp.logical_or(lax.gt(x, lax.add(loc, scale)), + lax.lt(x, loc)), -jnp.inf, log_probs) + result_positive_constants = jnp.where(jnp.logical_or(jnp.logical_or(lax.le(a, zero), lax.le(b, zero)), + lax.le(scale, zero)), jnp.nan, result) + return result_positive_constants -@implements(osp_stats.beta.pdf, update_doc=False) def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Beta probability distribution function. + + JAX implementation of :obj:`scipy.stats.beta` ``pdf``. + + The pdf of the beta function is: + + .. math:: + + f(x, a, b) = \frac{\Gamma(a + b)}{\Gamma(a)\Gamma(b)} x^{a-1}(1-x)^{b-1} + + where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. + It is defined for :math:`0\le x\le 1` and :math:`b>0`. + + Args: + x: arraylike, value at which to evaluate the PDF + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values + + See Also: + - :func:`jax.scipy.stats.beta.cdf` + - :func:`jax.scipy.stats.beta.sf` + - :func:`jax.scipy.stats.beta.logcdf` + - :func:`jax.scipy.stats.beta.logpdf` + - :func:`jax.scipy.stats.beta.logsf` + """ return lax.exp(logpdf(x, a, b, loc, scale)) -@implements(osp_stats.beta.cdf, update_doc=False) def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Beta cumulative distribution function + + JAX implementation of :obj:`scipy.stats.beta` ``cdf``. + + The cdf is defined as + + .. math:: + + f_{cdf}(x, a, b) = \int_{-\infty}^x f_{pdf}(y, a, b)\mathrm{d}y + + where :math:`f_{pdf}` is the beta distribution probability density function, + :func:`jax.scipy.stats.beta.pdf`. 
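+
+  In JAX the cdf is evaluated in terms of the regularized incomplete beta
+  function, :func:`~jax.scipy.special.betainc`.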
+ + Args: + x: arraylike, value at which to evaluate the CDF + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of cdf values + + See Also: + - :func:`jax.scipy.stats.beta.pdf` + - :func:`jax.scipy.stats.beta.sf` + - :func:`jax.scipy.stats.beta.logcdf` + - :func:`jax.scipy.stats.beta.logpdf` + - :func:`jax.scipy.stats.beta.logsf` + """ x, a, b, loc, scale = promote_args_inexact("beta.cdf", x, a, b, loc, scale) return betainc( a, @@ -57,15 +146,73 @@ def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, ) -@implements(osp_stats.beta.logcdf, update_doc=False) def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Beta log cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.beta` ``logcdf``. + + The cdf is defined as + + .. math:: + + f_{cdf}(x, a, b) = \int_{-\infty}^x f_{pdf}(y, a, b)\mathrm{d}y + + where :math:`f_{pdf}` is the beta distribution probability density function, + :func:`jax.scipy.stats.beta.pdf`. + + Args: + x: arraylike, value at which to evaluate the CDF + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logcdf values + + See Also: + - :func:`jax.scipy.stats.beta.cdf` + - :func:`jax.scipy.stats.beta.pdf` + - :func:`jax.scipy.stats.beta.sf` + - :func:`jax.scipy.stats.beta.logpdf` + - :func:`jax.scipy.stats.beta.logsf` + """ return lax.log(cdf(x, a, b, loc, scale)) -@implements(osp_stats.beta.sf, update_doc=False) def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Beta distribution survival function. + + JAX implementation of :obj:`scipy.stats.beta` ``sf``. + + The survival function is defined as + + .. math:: + + f_{sf}(x, a, b) = 1 - f_{cdf}(x, a, b) + + where :math:`f_{cdf}(x, a, b)` is the beta cumulative distribution function, + :func:`jax.scipy.stats.beta.cdf`. + + Args: + x: arraylike, value at which to evaluate the SF + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of sf values. + + See Also: + - :func:`jax.scipy.stats.beta.cdf` + - :func:`jax.scipy.stats.beta.pdf` + - :func:`jax.scipy.stats.beta.logcdf` + - :func:`jax.scipy.stats.beta.logpdf` + - :func:`jax.scipy.stats.beta.logsf` + """ x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale) return betainc( b, @@ -78,7 +225,36 @@ def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike, ) -@implements(osp_stats.beta.logsf, update_doc=False) def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Beta distribution log survival function. + + JAX implementation of :obj:`scipy.stats.beta` ``logsf``. + + The survival function is defined as + + .. math:: + + f_{sf}(x, a, b) = 1 - f_{cdf}(x, a, b) + + where :math:`f_{cdf}(x, a, b)` is the beta cumulative distribution function, + :func:`jax.scipy.stats.beta.cdf`. 
+ + Args: + x: arraylike, value at which to evaluate the SF + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logsf values. + + See Also: + - :func:`jax.scipy.stats.beta.cdf` + - :func:`jax.scipy.stats.beta.pdf` + - :func:`jax.scipy.stats.beta.sf` + - :func:`jax.scipy.stats.beta.logcdf` + - :func:`jax.scipy.stats.beta.logpdf` + """ return lax.log(sf(x, a, b, loc, scale)) diff --git a/jax/_src/scipy/stats/betabinom.py b/jax/_src/scipy/stats/betabinom.py index 1c7b1f9bd71c..d526c373f23b 100644 --- a/jax/_src/scipy/stats/betabinom.py +++ b/jax/_src/scipy/stats/betabinom.py @@ -12,21 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License - -import scipy.stats as osp_stats - from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.scipy.special import betaln from jax._src.typing import Array, ArrayLike -@implements(osp_stats.betabinom.logpmf, update_doc=False) def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0) -> Array: - """JAX implementation of scipy.stats.betabinom.logpmf.""" + r"""Beta-binomial log probability mass function. + + JAX implementation of :obj:`scipy.stats.betabinom` ``logpmf`` + + The beta-binomial distribution's probability mass function is defined as + + .. math:: + + f(k, n, a, b) = {n \choose k}\frac{B(k+a,n-k-b)}{B(a,b)} + + where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function. It is + defined for :math:`n\ge 0`, :math:`a>0`, :math:`b>0`, and non-negative integers `k`. + + Args: + k: arraylike, value at which to evaluate the PMF + n: arraylike, distribution shape parameter + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + + Returns: + array of logpmf values + + See Also: + :func:`jax.scipy.stats.betabinom.pmf` + """ k, n, a, b, loc = promote_args_inexact("betabinom.logpmf", k, n, a, b, loc) y = lax.sub(lax.floor(k), loc) one = _lax_const(y, 1) @@ -40,8 +61,32 @@ def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike, return jnp.where(n_a_b_cond, jnp.nan, log_probs) -@implements(osp_stats.betabinom.pmf, update_doc=False) def pmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0) -> Array: - """JAX implementation of scipy.stats.betabinom.pmf.""" + r"""Beta-binomial probability mass function. + + JAX implementation of :obj:`scipy.stats.betabinom` ``pmf``. + + The beta-binomial distribution's probability mass function is defined as + + .. math:: + + f(k, n, a, b) = {n \choose k}\frac{B(k+a,n-k-b)}{B(a,b)} + + where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function. It is + defined for :math:`n\ge 0`, :math:`a>0`, :math:`b>0`, and non-negative integers `k`. 
+ + Args: + k: arraylike, value at which to evaluate the PMF + n: arraylike, distribution shape parameter + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + + Returns: + array of pmf values + + See Also: + :func:`jax.scipy.stats.betabinom.logpmf` + """ return lax.exp(logpmf(k, n, a, b, loc)) diff --git a/jax/_src/scipy/stats/binom.py b/jax/_src/scipy/stats/binom.py index 878fdc744510..e97e84d1c84f 100644 --- a/jax/_src/scipy/stats/binom.py +++ b/jax/_src/scipy/stats/binom.py @@ -12,31 +12,72 @@ # See the License for the specific language governing permissions and # limitations under the License - -import scipy.stats as osp_stats - from jax import lax import jax.numpy as jnp -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.scipy.special import gammaln, xlogy, xlog1py from jax._src.typing import Array, ArrayLike -@implements(osp_stats.nbinom.logpmf, update_doc=False) def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: - """JAX implementation of scipy.stats.binom.logpmf.""" - k, n, p, loc = promote_args_inexact("binom.logpmf", k, n, p, loc) - y = lax.sub(k, loc) - comb_term = lax.sub( - gammaln(n + 1), - lax.add(gammaln(y + 1), gammaln(n - y + 1)) - ) - log_linear_term = lax.add(xlogy(y, p), xlog1py(lax.sub(n, y), lax.neg(p))) - log_probs = lax.add(comb_term, log_linear_term) - return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -jnp.inf) - - -@implements(osp_stats.nbinom.pmf, update_doc=False) + r"""Binomial log probability mass function. + + JAX implementation of :obj:`scipy.stats.binom` ``logpmf``. + + The binomial probability mass function is defined as + + .. math:: + + f(k, n, p) = {n \choose k}p^k(1-p)^{n-k} + + for :math:`0\le p\le 1` and non-negative integers :math:`k`. + + Args: + k: arraylike, value at which to evaluate the PMF + n: arraylike, distribution shape parameter + p: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + + Returns: + array of logpmf values. + + See Also: + :func:`jax.scipy.stats.binom.pmf` + """ + k, n, p, loc = promote_args_inexact("binom.logpmf", k, n, p, loc) + y = lax.sub(k, loc) + comb_term = lax.sub( + gammaln(n + 1), + lax.add(gammaln(y + 1), gammaln(n - y + 1)) + ) + log_linear_term = lax.add(xlogy(y, p), xlog1py(lax.sub(n, y), lax.neg(p))) + log_probs = lax.add(comb_term, log_linear_term) + return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -jnp.inf) + + def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: - """JAX implementation of scipy.stats.binom.pmf.""" - return lax.exp(logpmf(k, n, p, loc)) + r"""Binomial probability mass function. + + JAX implementation of :obj:`scipy.stats.binom` ``pmf``. + + The binomial probability mass function is defined as + + .. math:: + + f(k, n, p) = {n \choose k}p^k(1-p)^{n-k} + + for :math:`0\le p\le 1` and non-negative integers :math:`k`. + + Args: + k: arraylike, value at which to evaluate the PMF + n: arraylike, distribution shape parameter + p: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + + Returns: + array of pmf values. 
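+
+  Examples:
+    An illustrative check that the mass sums to one over the support
+    (assumes ``jax`` and ``jnp`` are imported, as in the surrounding examples):
+
+    >>> k = jnp.arange(5)
+    >>> probs = jax.scipy.stats.binom.pmf(k, n=4, p=0.3)
+    >>> jnp.allclose(probs.sum(), 1.0)
+    Array(True, dtype=bool)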
+ + See Also: + :func:`jax.scipy.stats.binom.logpmf` + """ + return lax.exp(logpmf(k, n, p, loc)) diff --git a/jax/_src/scipy/stats/cauchy.py b/jax/_src/scipy/stats/cauchy.py index 177cc0dcd197..922cdadb669a 100644 --- a/jax/_src/scipy/stats/cauchy.py +++ b/jax/_src/scipy/stats/cauchy.py @@ -14,17 +14,42 @@ import numpy as np -import scipy.stats as osp_stats from jax import lax from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax.numpy import arctan from jax._src.typing import Array, ArrayLike -@implements(osp_stats.cauchy.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Cauchy log probability distribution function. + + JAX implementation of :obj:`scipy.stats.cauchy` ``logpdf``. + + The Cauchy probability distribution function is + + .. math:: + + f(x) = \frac{1}{\pi(1 + x^2)} + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logpdf values + + See Also: + - :func:`jax.scipy.stats.cauchy.cdf` + - :func:`jax.scipy.stats.cauchy.pdf` + - :func:`jax.scipy.stats.cauchy.sf` + - :func:`jax.scipy.stats.cauchy.logcdf` + - :func:`jax.scipy.stats.cauchy.logsf` + - :func:`jax.scipy.stats.cauchy.isf` + - :func:`jax.scipy.stats.cauchy.ppf` + """ x, loc, scale = promote_args_inexact("cauchy.logpdf", x, loc, scale) pi = _lax_const(x, np.pi) scaled_x = lax.div(lax.sub(x, loc), scale) @@ -32,39 +57,203 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x)))) -@implements(osp_stats.cauchy.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - return lax.exp(logpdf(x, loc, scale)) + r"""Cauchy probability distribution function. + JAX implementation of :obj:`scipy.stats.cauchy` ``pdf``. + + The Cauchy probability distribution function is + + .. math:: + + f(x) = \frac{1}{\pi(1 + x^2)} + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values + + See Also: + - :func:`jax.scipy.stats.cauchy.cdf` + - :func:`jax.scipy.stats.cauchy.sf` + - :func:`jax.scipy.stats.cauchy.logcdf` + - :func:`jax.scipy.stats.cauchy.logpdf` + - :func:`jax.scipy.stats.cauchy.logsf` + - :func:`jax.scipy.stats.cauchy.isf` + - :func:`jax.scipy.stats.cauchy.ppf` + """ + return lax.exp(logpdf(x, loc, scale)) -@implements(osp_stats.cauchy.cdf, update_doc=False) def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Cauchy cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.cauchy` ``cdf``. + + The cdf is defined as + + .. math:: + + f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y + + where here :math:`f_{pdf}` is the Cauchy probability distribution function, + :func:`jax.scipy.stats.cauchy.pdf`. + + Args: + x: arraylike, value at which to evaluate the CDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of cdf values. 
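+
+  Examples:
+    An illustrative numerical check that the cdf and the survival function
+    sum to one (assumes ``jax`` and ``jnp`` are imported, as in the
+    surrounding examples):
+
+    >>> x = jnp.array([-1.0, 0.0, 2.0])
+    >>> cdf_vals = jax.scipy.stats.cauchy.cdf(x)
+    >>> sf_vals = jax.scipy.stats.cauchy.sf(x)
+    >>> jnp.allclose(cdf_vals + sf_vals, 1.0)
+    Array(True, dtype=bool)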
+ + See Also: + - :func:`jax.scipy.stats.cauchy.pdf` + - :func:`jax.scipy.stats.cauchy.sf` + - :func:`jax.scipy.stats.cauchy.logcdf` + - :func:`jax.scipy.stats.cauchy.logpdf` + - :func:`jax.scipy.stats.cauchy.logsf` + - :func:`jax.scipy.stats.cauchy.isf` + - :func:`jax.scipy.stats.cauchy.ppf` + """ x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale) pi = _lax_const(x, np.pi) scaled_x = lax.div(lax.sub(x, loc), scale) return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x))) -@implements(osp_stats.cauchy.logcdf, update_doc=False) def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Cauchy log cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.cauchy` ``logcdf`` + + The cdf is defined as + + .. math:: + + f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y + + where here :math:`f_{pdf}` is the Cauchy probability distribution function, + :func:`jax.scipy.stats.cauchy.pdf`. + + Args: + x: arraylike, value at which to evaluate the CDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logcdf values. + + See Also: + - :func:`jax.scipy.stats.cauchy.cdf` + - :func:`jax.scipy.stats.cauchy.pdf` + - :func:`jax.scipy.stats.cauchy.sf` + - :func:`jax.scipy.stats.cauchy.logpdf` + - :func:`jax.scipy.stats.cauchy.logsf` + - :func:`jax.scipy.stats.cauchy.isf` + - :func:`jax.scipy.stats.cauchy.ppf` + """ return lax.log(cdf(x, loc, scale)) -@implements(osp_stats.cauchy.sf, update_doc=False) def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Cauchy distribution log survival function. + + JAX implementation of :obj:`scipy.stats.cauchy` ``sf``. + + The survival function is defined as + + .. math:: + + f_{sf}(x) = 1 - f_{cdf}(x) + + where :math:`f_{cdf}(x)` is the cumulative distribution function, + :func:`jax.scipy.stats.cauchy.cdf`. + + Args: + x: arraylike, value at which to evaluate the SF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of sf values + + See Also: + - :func:`jax.scipy.stats.cauchy.cdf` + - :func:`jax.scipy.stats.cauchy.pdf` + - :func:`jax.scipy.stats.cauchy.logcdf` + - :func:`jax.scipy.stats.cauchy.logpdf` + - :func:`jax.scipy.stats.cauchy.logsf` + - :func:`jax.scipy.stats.cauchy.isf` + - :func:`jax.scipy.stats.cauchy.ppf` + """ x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale) return cdf(-x, -loc, scale) -@implements(osp_stats.cauchy.logsf, update_doc=False) def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Cauchy distribution log survival function. + + JAX implementation of :obj:`scipy.stats.cauchy` ``logsf`` + + The survival function is defined as + + .. math:: + + f_{sf}(x) = 1 - f_{cdf}(x) + + where :math:`f_{cdf}(x)` is the cumulative distribution function, + :func:`jax.scipy.stats.cauchy.cdf`. + + Args: + x: arraylike, value at which to evaluate the SF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logsf values. 
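+
+  Examples:
+    An illustrative check of the relation to :func:`jax.scipy.stats.cauchy.sf`
+    (assumes ``jax`` and ``jnp`` are imported, as in the surrounding examples):
+
+    >>> x = jnp.array([-1.0, 0.0, 2.0])
+    >>> jnp.allclose(jnp.exp(jax.scipy.stats.cauchy.logsf(x)),
+    ...              jax.scipy.stats.cauchy.sf(x))
+    Array(True, dtype=bool)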
+ + See Also: + - :func:`jax.scipy.stats.cauchy.cdf` + - :func:`jax.scipy.stats.cauchy.pdf` + - :func:`jax.scipy.stats.cauchy.sf` + - :func:`jax.scipy.stats.cauchy.logcdf` + - :func:`jax.scipy.stats.cauchy.logpdf` + - :func:`jax.scipy.stats.cauchy.isf` + - :func:`jax.scipy.stats.cauchy.ppf` + """ x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale) return logcdf(-x, -loc, scale) -@implements(osp_stats.cauchy.isf, update_doc=False) def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Cauchy distribution inverse survival function. + + JAX implementation of :obj:`scipy.stats.cauchy` ``isf``. + + Returns the inverse of the survival function, + :func:`jax.scipy.stats.cauchy.sf`. + + Args: + q: arraylike, value at which to evaluate the ISF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of isf values. + + See Also: + - :func:`jax.scipy.stats.cauchy.cdf` + - :func:`jax.scipy.stats.cauchy.pdf` + - :func:`jax.scipy.stats.cauchy.sf` + - :func:`jax.scipy.stats.cauchy.logcdf` + - :func:`jax.scipy.stats.cauchy.logpdf` + - :func:`jax.scipy.stats.cauchy.logsf` + - :func:`jax.scipy.stats.cauchy.ppf` + """ q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale) pi = _lax_const(q, np.pi) half_pi = _lax_const(q, np.pi / 2) @@ -72,8 +261,31 @@ def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.add(lax.mul(unscaled, scale), loc) -@implements(osp_stats.cauchy.ppf, update_doc=False) def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Cauchy distribution percent point function. + + JAX implementation of :obj:`scipy.stats.cauchy` ``ppf``. + + The percent point function is defined as the inverse of the + cumulative distribution function, :func:`jax.scipy.stats.cauchy.cdf`. + + Args: + q: arraylike, value at which to evaluate the PPF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of ppf values. + + See Also: + - :func:`jax.scipy.stats.cauchy.cdf` + - :func:`jax.scipy.stats.cauchy.pdf` + - :func:`jax.scipy.stats.cauchy.sf` + - :func:`jax.scipy.stats.cauchy.logcdf` + - :func:`jax.scipy.stats.cauchy.logpdf` + - :func:`jax.scipy.stats.cauchy.logsf` + - :func:`jax.scipy.stats.cauchy.isf` + """ q, loc, scale = promote_args_inexact("cauchy.ppf", q, loc, scale) pi = _lax_const(q, np.pi) half_pi = _lax_const(q, np.pi / 2) diff --git a/jax/_src/scipy/stats/chi2.py b/jax/_src/scipy/stats/chi2.py index 8058d49d7c9e..6637104e2123 100644 --- a/jax/_src/scipy/stats/chi2.py +++ b/jax/_src/scipy/stats/chi2.py @@ -12,19 +12,48 @@ # See the License for the specific language governing permissions and # limitations under the License - -import scipy.stats as osp_stats - from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy.special import gammainc, gammaincc -@implements(osp_stats.chi2.logpdf, update_doc=False) def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Chi-square log probability distribution function. + + JAX implementation of :obj:`scipy.stats.chi2` ``logpdf``. + + The chi-square probability distribution function is given by: + + .. 
math:: + + f(x, k) = \begin{cases} + \frac{x^{k/2-1}e^{-x/2}}{2^{k/2}\Gamma(k/2)} & x \ge 0 \\ + 0 & \mathrm{otherwise} + \end{cases} + + for :math:`k` degrees of freedom, and where :math:`\Gamma` is the + :func:`~jax.scipy.special.gamma` function. JAX follows the scipy + convention of using ``df`` to denote degrees of freedom. + + Args: + x: arraylike, value at which to evaluate the PDF + df: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logpdf values. + + See Also: + - :func:`jax.scipy.stats.chi2.cdf` + - :func:`jax.scipy.stats.chi2.pdf` + - :func:`jax.scipy.stats.chi2.sf` + - :func:`jax.scipy.stats.chi2.logcdf` + - :func:`jax.scipy.stats.chi2.logsf` + """ x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale) one = _lax_const(x, 1) two = _lax_const(x, 2) @@ -38,13 +67,75 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel) return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) -@implements(osp_stats.chi2.pdf, update_doc=False) + def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Chi-square probability distribution function. + + JAX implementation of :obj:`scipy.stats.chi2` ``pdf``. + + The chi-square probability distribution function is given by: + + .. math:: + + f(x, k) = \begin{cases} + \frac{x^{k/2-1}e^{-x/2}}{2^{k/2}\Gamma(k/2)} & x \ge 0 \\ + 0 & \mathrm{otherwise} + \end{cases} + + for :math:`k` degrees of freedom, and where :math:`\Gamma` is the + :func:`~jax.scipy.special.gamma` function. JAX follows the scipy + convention of using ``df`` to denote degrees of freedom. + + Args: + x: arraylike, value at which to evaluate the PDF + df: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. + + See Also: + - :func:`jax.scipy.stats.chi2.cdf` + - :func:`jax.scipy.stats.chi2.sf` + - :func:`jax.scipy.stats.chi2.logcdf` + - :func:`jax.scipy.stats.chi2.logpdf` + - :func:`jax.scipy.stats.chi2.logsf` + """ return lax.exp(logpdf(x, df, loc, scale)) -@implements(osp_stats.chi2.cdf, update_doc=False) def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Chi-square cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.chi2` ``cdf``. + + The cdf is defined as + + .. math:: + + f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y + + where :math:`f_{pdf}` is the probability density function, + :func:`jax.scipy.stats.chi2.pdf`. JAX follows the scipy + convention of using ``df`` to denote degrees of freedom. + + Args: + x: arraylike, value at which to evaluate the CDF + df: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of cdf values. 
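+
+  Examples:
+    An illustrative check that the cdf and the survival function sum to one
+    (assumes ``jax`` and ``jnp`` are imported, as in the surrounding examples):
+
+    >>> x = jnp.array([0.5, 1.0, 4.0])
+    >>> cdf_vals = jax.scipy.stats.chi2.cdf(x, df=3)
+    >>> sf_vals = jax.scipy.stats.chi2.sf(x, df=3)
+    >>> jnp.allclose(cdf_vals + sf_vals, 1.0)
+    Array(True, dtype=bool)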
+ + See Also: + - :func:`jax.scipy.stats.chi2.pdf` + - :func:`jax.scipy.stats.chi2.sf` + - :func:`jax.scipy.stats.chi2.logcdf` + - :func:`jax.scipy.stats.chi2.logpdf` + - :func:`jax.scipy.stats.chi2.logsf` + """ x, df, loc, scale = promote_args_inexact("chi2.cdf", x, df, loc, scale) two = _lax_const(scale, 2) return gammainc( @@ -60,13 +151,71 @@ def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) - ) -@implements(osp_stats.chi2.logcdf, update_doc=False) def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Chi-square log cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.chi2` ``logcdf``. + + The cdf is defined as + + .. math:: + + f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y + + where :math:`f_{pdf}` is the probability density function, + :func:`jax.scipy.stats.chi2.pdf`. JAX follows the scipy + convention of using ``df`` to denote degrees of freedom. + + Args: + x: arraylike, value at which to evaluate the CDF + df: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logcdf values + + See Also: + - :func:`jax.scipy.stats.chi2.cdf` + - :func:`jax.scipy.stats.chi2.pdf` + - :func:`jax.scipy.stats.chi2.sf` + - :func:`jax.scipy.stats.chi2.logpdf` + - :func:`jax.scipy.stats.chi2.logsf` + """ return lax.log(cdf(x, df, loc, scale)) -@implements(osp_stats.chi2.sf, update_doc=False) def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Chi-square survival function. + + JAX implementation of :obj:`scipy.stats.chi2` ``sf``. + + The survival function is defined as + + .. math:: + + f_{sf}(x, k) = 1 - f_{cdf}(x, k) + + where :math:`f_{cdf}(x, k)` is the cumulative distribution function, + :func:`jax.scipy.stats.chi2.cdf`. JAX follows the scipy + convention of using ``df`` to denote degrees of freedom. + + Args: + x: arraylike, value at which to evaluate the SF + df: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of sf values. + + See Also: + - :func:`jax.scipy.stats.chi2.cdf` + - :func:`jax.scipy.stats.chi2.pdf` + - :func:`jax.scipy.stats.chi2.logcdf` + - :func:`jax.scipy.stats.chi2.logpdf` + - :func:`jax.scipy.stats.chi2.logsf` + """ x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale) two = _lax_const(scale, 2) return gammaincc( @@ -82,6 +231,35 @@ def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> ) -@implements(osp_stats.chi2.logsf, update_doc=False) def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Chi-square log survival function. + + JAX implementation of :obj:`scipy.stats.chi2` ``logsf``. + + The survival function is defined as + + .. math:: + + f_{sf}(x, k) = 1 - f_{cdf}(x, k) + + where :math:`f_{cdf}(x, k)` is the cumulative distribution function, + :func:`jax.scipy.stats.chi2.cdf`. JAX follows the scipy + convention of using ``df`` to denote degrees of freedom. + + Args: + x: arraylike, value at which to evaluate the SF + df: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logsf values. 
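+
+  Examples:
+    An illustrative check of the relation to :func:`jax.scipy.stats.chi2.sf`
+    (assumes ``jax`` and ``jnp`` are imported, as in the surrounding examples):
+
+    >>> x = jnp.array([0.5, 1.0, 4.0])
+    >>> jnp.allclose(jnp.exp(jax.scipy.stats.chi2.logsf(x, df=3)),
+    ...              jax.scipy.stats.chi2.sf(x, df=3))
+    Array(True, dtype=bool)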
+ + See Also: + - :func:`jax.scipy.stats.chi2.cdf` + - :func:`jax.scipy.stats.chi2.pdf` + - :func:`jax.scipy.stats.chi2.sf` + - :func:`jax.scipy.stats.chi2.logcdf` + - :func:`jax.scipy.stats.chi2.logpdf` + """ return lax.log(sf(x, df, loc, scale)) diff --git a/jax/_src/scipy/stats/dirichlet.py b/jax/_src/scipy/stats/dirichlet.py index f8b5705f8118..ee28c7e3ea59 100644 --- a/jax/_src/scipy/stats/dirichlet.py +++ b/jax/_src/scipy/stats/dirichlet.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import scipy.stats as osp_stats - from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import promote_dtypes_inexact, implements +from jax._src.numpy.util import promote_dtypes_inexact from jax.scipy.special import gammaln, xlogy from jax._src.typing import Array, ArrayLike @@ -28,8 +25,30 @@ def _is_simplex(x: Array) -> Array: return jnp.all(x > 0, axis=0) & (abs(x_sum - 1) < 1E-6) -@implements(osp_stats.dirichlet.logpdf, update_doc=False) def logpdf(x: ArrayLike, alpha: ArrayLike) -> Array: + r"""Dirichlet log probability distribution function. + + JAX implementation of :obj:`scipy.stats.dirichlet` ``logpdf``. + + The Dirichlet probability density function is + + .. math:: + + f(\mathbf{x}) = \frac{1}{B(\mathbf{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1} + + where :math:`B(\mathbf{\alpha})` is the :func:`~jax.scipy.special.beta` function + in a :math:`K`-dimensional vector space. + + Args: + x: arraylike, value at which to evaluate the PDF + alpha: arraylike, distribution shape parameter + + Returns: + array of logpdf values. + + See Also: + :func:`jax.scipy.stats.dirichlet.pdf` + """ return _logpdf(*promote_dtypes_inexact(x, alpha)) def _logpdf(x: Array, alpha: Array) -> Array: @@ -52,6 +71,28 @@ def _logpdf(x: Array, alpha: Array) -> Array: return jnp.where(_is_simplex(x), log_probs, -jnp.inf) -@implements(osp_stats.dirichlet.pdf, update_doc=False) def pdf(x: ArrayLike, alpha: ArrayLike) -> Array: + r"""Dirichlet probability distribution function. + + JAX implementation of :obj:`scipy.stats.dirichlet` ``pdf``. + + The Dirichlet probability density function is + + .. math:: + + f(\mathbf{x}) = \frac{1}{B(\mathbf{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1} + + where :math:`B(\mathbf{\alpha})` is the :func:`~jax.scipy.special.beta` function + in a :math:`K`-dimensional vector space. + + Args: + x: arraylike, value at which to evaluate the PDF + alpha: arraylike, distribution shape parameter + + Returns: + array of pdf values. + + See Also: + :func:`jax.scipy.stats.dirichlet.logpdf` + """ return lax.exp(logpdf(x, alpha)) diff --git a/jax/_src/scipy/stats/expon.py b/jax/_src/scipy/stats/expon.py index 0b2ff0ea4058..b09c52e97272 100644 --- a/jax/_src/scipy/stats/expon.py +++ b/jax/_src/scipy/stats/expon.py @@ -12,22 +12,67 @@ # See the License for the specific language governing permissions and # limitations under the License. -import scipy.stats as osp_stats - from jax import lax import jax.numpy as jnp -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike -@implements(osp_stats.expon.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Exponential log probability distribution function. + + JAX implementation of :obj:`scipy.stats.expon` ``logpdf``. 
+ + The Exponential probability distribution function is + + .. math:: + + f(x) = \begin{cases} + e^{-x} & x \ge 0 \\ + 0 & \mathrm{otherwise} + \end{cases} + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logpdf values. + + See Also: + :func:`jax.scipy.stats.expon.pdf` + """ x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale) log_scale = lax.log(scale) linear_term = lax.div(lax.sub(x, loc), scale) log_probs = lax.neg(lax.add(linear_term, log_scale)) return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) -@implements(osp_stats.expon.pdf, update_doc=False) + def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Exponential probability distribution function. + + JAX implementation of :obj:`scipy.stats.expon` ``pdf``. + + The Exponential probability distribution function is + + .. math:: + + f(x) = \begin{cases} + e^{-x} & x \ge 0 \\ + 0 & \mathrm{otherwise} + \end{cases} + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. + + See Also: + :func:`jax.scipy.stats.expon.logpdf` + """ return lax.exp(logpdf(x, loc, scale)) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index d63429021566..f410d08e4f3d 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -12,18 +12,44 @@ # See the License for the specific language governing permissions and # limitations under the License. -import scipy.stats as osp_stats - from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc -@implements(osp_stats.gamma.logpdf, update_doc=False) def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Gamma log probability distribution function. + + JAX implementation of :obj:`scipy.stats.gamma` ``logpdf``. + + The Gamma probability distribution is given by + + .. math:: + + f(x, a) = \frac{1}{\Gamma(a)}x^{a-1}e^{-x} + + Where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function. + It is defined for :math:`x \ge 0` and :math:`a > 0`. + + Args: + x: arraylike, value at which to evaluate the PDF + a: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logpdf values. + + See Also: + - :func:`jax.scipy.stats.gamma.cdf` + - :func:`jax.scipy.stats.gamma.pdf` + - :func:`jax.scipy.stats.gamma.sf` + - :func:`jax.scipy.stats.gamma.logcdf` + - :func:`jax.scipy.stats.gamma.logsf` + """ x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale) one = _lax_const(x, 1) y = lax.div(lax.sub(x, loc), scale) @@ -32,13 +58,70 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) log_probs = lax.sub(log_linear_term, shape_terms) return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) -@implements(osp_stats.gamma.pdf, update_doc=False) + def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Gamma probability distribution function. 
+ + JAX implementation of :obj:`scipy.stats.gamma` ``pdf``. + + The Gamma probability distribution is given by + + .. math:: + + f(x, a) = \frac{1}{\Gamma(a)}x^{a-1}e^{-x} + + Where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function. + It is defined for :math:`x \ge 0` and :math:`a > 0`. + + Args: + x: arraylike, value at which to evaluate the PDF + a: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. + + See Also: + - :func:`jax.scipy.stats.gamma.cdf` + - :func:`jax.scipy.stats.gamma.sf` + - :func:`jax.scipy.stats.gamma.logcdf` + - :func:`jax.scipy.stats.gamma.logpdf` + - :func:`jax.scipy.stats.gamma.logsf` + """ return lax.exp(logpdf(x, a, loc, scale)) -@implements(osp_stats.gamma.cdf, update_doc=False) def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Gamma cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.gamma` ``cdf``. + + The cdf is defined as + + .. math:: + + f_{cdf}(x, a) = \int_{-\infty}^x f_{pdf}(y, a)\mathrm{d}y + + where :math:`f_{pdf}` is the probability density function, + :func:`jax.scipy.stats.gamma.pdf`. + + Args: + x: arraylike, value at which to evaluate the CDF + a: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of cdf values. + + See Also: + - :func:`jax.scipy.stats.gamma.pdf` + - :func:`jax.scipy.stats.gamma.sf` + - :func:`jax.scipy.stats.gamma.logcdf` + - :func:`jax.scipy.stats.gamma.logpdf` + - :func:`jax.scipy.stats.gamma.logsf` + """ x, a, loc, scale = promote_args_inexact("gamma.cdf", x, a, loc, scale) return gammainc( a, @@ -50,17 +133,101 @@ def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> ) -@implements(osp_stats.gamma.logcdf, update_doc=False) def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Gamma log cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.gamma` ``logcdf``. + + The cdf is defined as + + .. math:: + + f_{cdf}(x, a) = \int_{-\infty}^x f_{pdf}(y, a)\mathrm{d}y + + where :math:`f_{pdf}` is the probability density function, + :func:`jax.scipy.stats.gamma.pdf`. + + Args: + x: arraylike, value at which to evaluate the CDF + a: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logcdf values. + + See Also: + - :func:`jax.scipy.stats.gamma.cdf` + - :func:`jax.scipy.stats.gamma.pdf` + - :func:`jax.scipy.stats.gamma.sf` + - :func:`jax.scipy.stats.gamma.logpdf` + - :func:`jax.scipy.stats.gamma.logsf` + """ return lax.log(cdf(x, a, loc, scale)) -@implements(osp_stats.gamma.sf, update_doc=False) def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Gamma survival function. + + JAX implementation of :obj:`scipy.stats.gamma` ``sf``. + + The survival function is defined as + + .. math:: + + f_{sf}(x, k) = 1 - f_{cdf}(x, k) + + where :math:`f_{cdf}(x, k)` is the cumulative distribution function, + :func:`jax.scipy.stats.gamma.cdf`. + + Args: + x: arraylike, value at which to evaluate the SF + a: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of sf values. 
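+
+  Examples:
+    An illustrative check that the survival function and the cdf sum to one
+    (assumes ``jax`` and ``jnp`` are imported, as in the surrounding examples):
+
+    >>> x = jnp.array([0.5, 1.0, 4.0])
+    >>> sf_vals = jax.scipy.stats.gamma.sf(x, a=2.0)
+    >>> cdf_vals = jax.scipy.stats.gamma.cdf(x, a=2.0)
+    >>> jnp.allclose(sf_vals + cdf_vals, 1.0)
+    Array(True, dtype=bool)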
+ + See Also: + - :func:`jax.scipy.stats.gamma.cdf` + - :func:`jax.scipy.stats.gamma.pdf` + - :func:`jax.scipy.stats.gamma.logcdf` + - :func:`jax.scipy.stats.gamma.logpdf` + - :func:`jax.scipy.stats.gamma.logsf` + """ x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale) return gammaincc(a, lax.div(lax.sub(x, loc), scale)) -@implements(osp_stats.gamma.logsf, update_doc=False) def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Gamma log survival function. + + JAX implementation of :obj:`scipy.stats.gamma` ``logsf``. + + The survival function is defined as + + .. math:: + + f_{sf}(x, k) = 1 - f_{cdf}(x, k) + + where :math:`f_{cdf}(x, k)` is the cumulative distribution function, + :func:`jax.scipy.stats.gamma.cdf`. + + Args: + x: arraylike, value at which to evaluate the SF + a: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logsf values. + + See Also: + - :func:`jax.scipy.stats.gamma.cdf` + - :func:`jax.scipy.stats.gamma.pdf` + - :func:`jax.scipy.stats.gamma.sf` + - :func:`jax.scipy.stats.gamma.logcdf` + - :func:`jax.scipy.stats.gamma.logpdf` + """ return lax.log(sf(x, a, loc, scale)) diff --git a/jax/_src/scipy/stats/gennorm.py b/jax/_src/scipy/stats/gennorm.py index 4b89a25289bf..9d24708c066a 100644 --- a/jax/_src/scipy/stats/gennorm.py +++ b/jax/_src/scipy/stats/gennorm.py @@ -12,22 +12,92 @@ # See the License for the specific language governing permissions and # limitations under the License. -import scipy.stats as osp_stats from jax import lax -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike -@implements(osp_stats.gennorm.logpdf, update_doc=False) -def logpdf(x: ArrayLike, p: ArrayLike) -> Array: - x, p = promote_args_inexact("gennorm.logpdf", x, p) - return lax.log(.5 * p) - lax.lgamma(1/p) - lax.abs(x)**p +def logpdf(x: ArrayLike, beta: ArrayLike) -> Array: + r"""Generalized normal log probability distribution function. -@implements(osp_stats.gennorm.cdf, update_doc=False) -def cdf(x: ArrayLike, p: ArrayLike) -> Array: - x, p = promote_args_inexact("gennorm.cdf", x, p) - return .5 * (1 + lax.sign(x) * lax.igamma(1/p, lax.abs(x)**p)) + JAX implementation of :obj:`scipy.stats.gennorm` ``logpdf``. -@implements(osp_stats.gennorm.pdf, update_doc=False) -def pdf(x: ArrayLike, p: ArrayLike) -> Array: - return lax.exp(logpdf(x, p)) + The generalized normal probability distribution function is defined as + + .. math:: + + f(x, \beta) = \frac{\beta}{2\Gamma(1/\beta)}\exp(-|x|^\beta) + + where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and + :math:`\beta > 0`. + + Args: + x: arraylike, value at which to evaluate the PDF + beta: arraylike, distribution shape parameter + + Returns: + array of logpdf values. + + See Also: + - :func:`jax.scipy.stats.gennorm.cdf` + - :func:`jax.scipy.stats.gennorm.pdf` + """ + x, beta = promote_args_inexact("gennorm.logpdf", x, beta) + return lax.log(.5 * beta) - lax.lgamma(1/beta) - lax.abs(x)**beta + + +def cdf(x: ArrayLike, beta: ArrayLike) -> Array: + r"""Generalized normal cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.gennorm` ``cdf``. + + The cdf is defined as + + .. 
math:: + + f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y + + where :math:`f_{pdf}` is the probability density function, + :func:`jax.scipy.stats.gennorm.pdf`. + + Args: + x: arraylike, value at which to evaluate the CDF + beta: arraylike, distribution shape parameter + + Returns: + array of cdf values. + + See Also: + - :func:`jax.scipy.stats.gennorm.pdf` + - :func:`jax.scipy.stats.gennorm.logpdf` + """ + x, beta = promote_args_inexact("gennorm.cdf", x, beta) + return .5 * (1 + lax.sign(x) * lax.igamma(1/beta, lax.abs(x)**beta)) + + +def pdf(x: ArrayLike, beta: ArrayLike) -> Array: + r"""Generalized normal probability distribution function. + + JAX implementation of :obj:`scipy.stats.gennorm` ``pdf``. + + The generalized normal probability distribution function is defined as + + .. math:: + + f(x, \beta) = \frac{\beta}{2\Gamma(1/\beta)}\exp(-|x|^\beta) + + where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and + :math:`\beta > 0`. + + Args: + x: arraylike, value at which to evaluate the PDF + beta: arraylike, distribution shape parameter + + Returns: + array of pdf values. + + See Also: + - :func:`jax.scipy.stats.gennorm.cdf` + - :func:`jax.scipy.stats.gennorm.logpdf` + """ + return lax.exp(logpdf(x, beta)) diff --git a/jax/_src/scipy/stats/geom.py b/jax/_src/scipy/stats/geom.py index 6b59cb31db0a..1d5133f9c3ea 100644 --- a/jax/_src/scipy/stats/geom.py +++ b/jax/_src/scipy/stats/geom.py @@ -12,26 +12,68 @@ # See the License for the specific language governing permissions and # limitations under the License. -import scipy.stats as osp_stats - from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax.scipy.special import xlog1py from jax._src.typing import Array, ArrayLike -@implements(osp_stats.geom.logpmf, update_doc=False) def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: - k, p, loc = promote_args_inexact("geom.logpmf", k, p, loc) - zero = _lax_const(k, 0) - one = _lax_const(k, 1) - x = lax.sub(k, loc) - log_probs = xlog1py(lax.sub(x, one), -p) + lax.log(p) - return jnp.where(lax.le(x, zero), -jnp.inf, log_probs) + r"""Geometric log probability mass function. + + JAX implementation of :obj:`scipy.stats.geom` ``logpmf``. + + The Geometric probability mass function is given by + + .. math:: + + f(k) = (1 - p)^{k-1}p + + for :math:`k\ge 1` and :math:`0 \le p \le 1`. + + Args: + k: arraylike, value at which to evaluate the PMF + p: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + + Returns: + array of logpmf values. + + See Also: + :func:`jax.scipy.stats.geom.pmf` + """ + k, p, loc = promote_args_inexact("geom.logpmf", k, p, loc) + zero = _lax_const(k, 0) + one = _lax_const(k, 1) + x = lax.sub(k, loc) + log_probs = xlog1py(lax.sub(x, one), -p) + lax.log(p) + return jnp.where(lax.le(x, zero), -jnp.inf, log_probs) -@implements(osp_stats.geom.pmf, update_doc=False) def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: + r"""Geometric probability mass function. + + JAX implementation of :obj:`scipy.stats.geom` ``pmf``. + + The Geometric probability mass function is given by + + .. math:: + + f(k) = (1 - p)^{k-1}p + + for :math:`k\ge 1` and :math:`0 \le p \le 1`. 
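+
+  For example, with :math:`p = 1/2` the probability mass at
+  :math:`k = 1, 2, 3` is :math:`1/2, 1/4, 1/8`.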
+ + Args: + k: arraylike, value at which to evaluate the PMF + p: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + + Returns: + array of pmf values. + + See Also: + :func:`jax.scipy.stats.geom.logpmf` + """ return jnp.exp(logpmf(k, p, loc)) diff --git a/jax/_src/scipy/stats/kde.py b/jax/_src/scipy/stats/kde.py index 516935525ce1..c4c6fc10b132 100644 --- a/jax/_src/scipy/stats/kde.py +++ b/jax/_src/scipy/stats/kde.py @@ -17,19 +17,28 @@ from typing import Any import numpy as np -import scipy.stats as osp_stats import jax.numpy as jnp from jax import jit, lax, random, vmap -from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, implements +from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact from jax._src.tree_util import register_pytree_node_class from jax.scipy import linalg, special -@implements(osp_stats.gaussian_kde, update_doc=False) @register_pytree_node_class @dataclass(frozen=True, init=False) class gaussian_kde: + """Gaussian Kernel Density Estimator + + JAX implementation of :class:`scipy.stats.gaussian_kde`. + + Parameters: + dataset: arraylike, real-valued. Data from which to estimate the distribution. + If 1D, shape is (n_data,). If 2D, shape is (n_dimensions, n_data). + bw_method: string, scalar, or callable. Either "scott", "silverman", a scalar + value, or a callable function which takes ``self`` as a parameter. + weights: arraylike, optional. Weights of the same shape as the dataset. + """ neff: Any dataset: Any weights: Any @@ -113,20 +122,19 @@ def d(self): def n(self): return self.dataset.shape[1] - @implements(osp_stats.gaussian_kde.evaluate, update_doc=False) def evaluate(self, points): + """Evaluate the Gaussian KDE on the given points.""" check_arraylike("evaluate", points) points = self._reshape_points(points) result = _gaussian_kernel_eval(False, self.dataset.T, self.weights[:, None], points.T, self.inv_cov) return result[:, 0] - @implements(osp_stats.gaussian_kde.__call__, update_doc=False) def __call__(self, points): return self.evaluate(points) - @implements(osp_stats.gaussian_kde.integrate_gaussian, update_doc=False) def integrate_gaussian(self, mean, cov): + """Integrate the distribution weighted by a Gaussian.""" mean = jnp.atleast_1d(jnp.squeeze(mean)) cov = jnp.atleast_2d(cov) @@ -141,8 +149,8 @@ def integrate_gaussian(self, mean, cov): return _gaussian_kernel_convolve(chol, norm, self.dataset, self.weights, mean) - @implements(osp_stats.gaussian_kde.integrate_box_1d, update_doc=False) def integrate_box_1d(self, low, high): + """Integrate the distribution over the given limits.""" if self.d != 1: raise ValueError("integrate_box_1d() only handles 1D pdfs") if jnp.ndim(low) != 0 or jnp.ndim(high) != 0: @@ -153,8 +161,8 @@ def integrate_box_1d(self, low, high): high = jnp.squeeze((high - self.dataset) / sigma) return jnp.sum(self.weights * (special.ndtr(high) - special.ndtr(low))) - @implements(osp_stats.gaussian_kde.integrate_kde, update_doc=False) def integrate_kde(self, other): + """Integrate the product of two Gaussian KDE distributions.""" if other.d != self.d: raise ValueError("KDEs are not the same dimensionality") @@ -189,12 +197,12 @@ def resample(self, key, shape=()): dtype=self.dataset.dtype).T return self.dataset[:, ind] + eps - @implements(osp_stats.gaussian_kde.pdf, update_doc=False) def pdf(self, x): + """Probability density function""" return self.evaluate(x) - @implements(osp_stats.gaussian_kde.logpdf, update_doc=False) def logpdf(self, x): + """Log probability 
density function""" check_arraylike("logpdf", x) x = self._reshape_points(x) result = _gaussian_kernel_eval(True, self.dataset.T, self.weights[:, None], diff --git a/jax/_src/scipy/stats/laplace.py b/jax/_src/scipy/stats/laplace.py index acd39046dcff..8761a2cb864f 100644 --- a/jax/_src/scipy/stats/laplace.py +++ b/jax/_src/scipy/stats/laplace.py @@ -12,29 +12,93 @@ # See the License for the specific language governing permissions and # limitations under the License. -import scipy.stats as osp_stats - from jax import lax from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike -@implements(osp_stats.laplace.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Laplace log probability distribution function. + + JAX implementation of :obj:`scipy.stats.laplace` ``logpdf``. + + The Laplace probability distribution function is given by + + .. math:: + + f(x) = \frac{1}{2} e^{-|x|} + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logpdf values. + + See Also: + - :func:`jax.scipy.stats.laplace.cdf` + - :func:`jax.scipy.stats.laplace.pdf` + """ x, loc, scale = promote_args_inexact("laplace.logpdf", x, loc, scale) two = _lax_const(x, 2) linear_term = lax.div(lax.abs(lax.sub(x, loc)), scale) return lax.neg(lax.add(linear_term, lax.log(lax.mul(two, scale)))) -@implements(osp_stats.laplace.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Laplace probability distribution function. + + JAX implementation of :obj:`scipy.stats.laplace` ``pdf``. + + The Laplace probability distribution function is given by + + .. math:: + + f(x) = \frac{1}{2} e^{-|x|} + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. + + See Also: + - :func:`jax.scipy.stats.laplace.cdf` + - :func:`jax.scipy.stats.laplace.logpdf` + """ return lax.exp(logpdf(x, loc, scale)) -@implements(osp_stats.laplace.cdf, update_doc=False) def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Laplace cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.laplace` ``cdf``. + + The cdf is defined as + + .. math:: + + f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y + + where :math:`f_{pdf}` is the probability density function, + :func:`jax.scipy.stats.laplace.pdf`. + + Args: + x: arraylike, value at which to evaluate the CDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of cdf values. + + See Also: + - :func:`jax.scipy.stats.laplace.pdf` + - :func:`jax.scipy.stats.laplace.logpdf` + """ x, loc, scale = promote_args_inexact("laplace.cdf", x, loc, scale) half = _lax_const(x, 0.5) one = _lax_const(x, 1) diff --git a/jax/_src/scipy/stats/logistic.py b/jax/_src/scipy/stats/logistic.py index b9f7b37b3a00..5e41e4f7c8e5 100644 --- a/jax/_src/scipy/stats/logistic.py +++ b/jax/_src/scipy/stats/logistic.py @@ -12,18 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. 
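+
+  As an illustrative sketch of the Laplace wrappers documented above (the
+  evaluation points and parameters are arbitrary values chosen for
+  demonstration):
+
+  >>> import jax.numpy as jnp
+  >>> from jax.scipy.stats import laplace
+  >>> x = jnp.linspace(-2.0, 2.0, 5)
+  >>> density = laplace.pdf(x, loc=0.0, scale=1.0)
+  >>> cumulative = laplace.cdf(x, loc=0.0, scale=1.0)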
-import scipy.stats as osp_stats from jax.scipy.special import expit, logit from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike -@implements(osp_stats.logistic.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Logistic log probability distribution function. + + JAX implementation of :obj:`scipy.stats.logistic` ``logpdf``. + + The logistic probability distribution function is given by + + .. math:: + + f(x) = \frac{e^{-x}}{(1 + e^{-x})^2} + + Args: + x: arraylike, value at which to evaluate the PDF + a: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logpdf values. + + See Also: + - :func:`jax.scipy.stats.logistic.cdf` + - :func:`jax.scipy.stats.logistic.pdf` + - :func:`jax.scipy.stats.logistic.sf` + - :func:`jax.scipy.stats.logistic.isf` + - :func:`jax.scipy.stats.logistic.ppf` + """ x, loc, scale = promote_args_inexact("logistic.logpdf", x, loc, scale) x = lax.div(lax.sub(x, loc), scale) two = _lax_const(x, 2) @@ -31,30 +55,150 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale)) -@implements(osp_stats.logistic.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Logistic probability distribution function. + + JAX implementation of :obj:`scipy.stats.logistic` ``pdf``. + + The logistic probability distribution function is given by + + .. math:: + + f(x) = \frac{e^{-x}}{(1 + e^{-x})^2} + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. + + See Also: + - :func:`jax.scipy.stats.logistic.cdf` + - :func:`jax.scipy.stats.logistic.sf` + - :func:`jax.scipy.stats.logistic.isf` + - :func:`jax.scipy.stats.logistic.logpdf` + - :func:`jax.scipy.stats.logistic.ppf` + """ return lax.exp(logpdf(x, loc, scale)) -@implements(osp_stats.logistic.ppf, update_doc=False) def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + """Logistic distribution percent point function. + + JAX implementation of :obj:`scipy.stats.logistic` ``ppf``. + + The percent point function is defined as the inverse of the + cumulative distribution function, :func:`jax.scipy.stats.logistic.cdf`. + + Args: + x: arraylike, value at which to evaluate the PPF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of ppf values. + + See Also: + - :func:`jax.scipy.stats.logistic.cdf` + - :func:`jax.scipy.stats.logistic.pdf` + - :func:`jax.scipy.stats.logistic.sf` + - :func:`jax.scipy.stats.logistic.isf` + - :func:`jax.scipy.stats.logistic.logpdf` + """ x, loc, scale = promote_args_inexact("logistic.ppf", x, loc, scale) return lax.add(lax.mul(logit(x), scale), loc) -@implements(osp_stats.logistic.sf, update_doc=False) def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + """Logistic distribution survival function. + + JAX implementation of :obj:`scipy.stats.logistic` ``sf`` + + The survival function is defined as + + .. 
math:: + + f_{sf}(x, k) = 1 - f_{cdf}(x, k) + + where :math:`f_{cdf}(x, k)` is the cumulative distribution function, + :func:`jax.scipy.stats.logistic.cdf`. + + Args: + x: arraylike, value at which to evaluate the SF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of sf values. + + See Also: + - :func:`jax.scipy.stats.logistic.cdf` + - :func:`jax.scipy.stats.logistic.pdf` + - :func:`jax.scipy.stats.logistic.isf` + - :func:`jax.scipy.stats.logistic.logpdf` + - :func:`jax.scipy.stats.logistic.ppf` + """ x, loc, scale = promote_args_inexact("logistic.sf", x, loc, scale) return expit(lax.neg(lax.div(lax.sub(x, loc), scale))) -@implements(osp_stats.logistic.isf, update_doc=False) def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + """Logistic distribution inverse survival function. + + JAX implementation of :obj:`scipy.stats.logistic` ``isf``. + + Returns the inverse of the survival function, + :func:`jax.scipy.stats.logistic.sf`. + + Args: + x: arraylike, value at which to evaluate the ISF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of isf values. + + See Also: + - :func:`jax.scipy.stats.logistic.cdf` + - :func:`jax.scipy.stats.logistic.pdf` + - :func:`jax.scipy.stats.logistic.sf` + - :func:`jax.scipy.stats.logistic.logpdf` + - :func:`jax.scipy.stats.logistic.ppf` + """ x, loc, scale = promote_args_inexact("logistic.isf", x, loc, scale) return lax.add(lax.mul(lax.neg(logit(x)), scale), loc) -@implements(osp_stats.logistic.cdf, update_doc=False) def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Logistic cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.logistic` ``cdf``. + + The cdf is defined as + + .. math:: + + f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y + + where :math:`f_{pdf}` is the probability density function, + :func:`jax.scipy.stats.logistic.pdf`. + + Args: + x: arraylike, value at which to evaluate the CDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of cdf values. + + See Also: + - :func:`jax.scipy.stats.logistic.pdf` + - :func:`jax.scipy.stats.logistic.sf` + - :func:`jax.scipy.stats.logistic.isf` + - :func:`jax.scipy.stats.logistic.logpdf` + - :func:`jax.scipy.stats.logistic.ppf` + """ x, loc, scale = promote_args_inexact("logistic.cdf", x, loc, scale) return expit(lax.div(lax.sub(x, loc), scale)) diff --git a/jax/_src/scipy/stats/multinomial.py b/jax/_src/scipy/stats/multinomial.py index 150573ad7db0..fe9fd6781423 100644 --- a/jax/_src/scipy/stats/multinomial.py +++ b/jax/_src/scipy/stats/multinomial.py @@ -13,17 +13,37 @@ # limitations under the License. -import scipy.stats as osp_stats from jax import lax import jax.numpy as jnp -from jax._src.numpy.util import implements, promote_args_inexact, promote_args_numeric +from jax._src.numpy.util import promote_args_inexact, promote_args_numeric from jax._src.scipy.special import gammaln, xlogy from jax._src.typing import Array, ArrayLike -@implements(osp_stats.multinomial.logpmf, update_doc=False) def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array: - """JAX implementation of scipy.stats.multinomial.logpmf.""" + r"""Multinomial log probability mass function. + + JAX implementation of :obj:`scipy.stats.multinomial` ``logpdf``. + + The multinomial probability distribution is given by + + .. math:: + + f(x, n, p) = n! 
\prod_{i=1}^k \frac{p_i^{x_i}}{x_i!} + + with :math:`n = \sum_i x_i`. + + Args: + x: arraylike, value at which to evaluate the PMF + n: arraylike, distribution shape parameter + p: arraylike, distribution shape parameter + + Returns: + array of logpmf values. + + See Also: + :func:`jax.scipy.stats.multinomial.pmf` + """ p, = promote_args_inexact("multinomial.logpmf", p) x, n = promote_args_numeric("multinomial.logpmf", x, n) if not jnp.issubdtype(x.dtype, jnp.integer): @@ -34,7 +54,28 @@ def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array: return jnp.where(jnp.equal(jnp.sum(x), n), logprobs, -jnp.inf) -@implements(osp_stats.multinomial.pmf, update_doc=False) def pmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array: - """JAX implementation of scipy.stats.multinomial.pmf.""" + r"""Multinomial probability mass function. + + JAX implementation of :obj:`scipy.stats.multinomial` ``pmf``. + + The multinomial probability distribution is given by + + .. math:: + + f(x, n, p) = n! \prod_{i=1}^k \frac{p_i^{x_i}}{x_i!} + + with :math:`n = \sum_i x_i`. + + Args: + x: arraylike, value at which to evaluate the PMF + n: arraylike, distribution shape parameter + p: arraylike, distribution shape parameter + + Returns: + array of pmf values + + See Also: + :func:`jax.scipy.stats.multinomial.logpmf` + """ return lax.exp(logpmf(x, n, p)) diff --git a/jax/_src/scipy/stats/multivariate_normal.py b/jax/_src/scipy/stats/multivariate_normal.py index e833da0a49c2..8ba34703aada 100644 --- a/jax/_src/scipy/stats/multivariate_normal.py +++ b/jax/_src/scipy/stats/multivariate_normal.py @@ -15,18 +15,39 @@ from functools import partial import numpy as np -import scipy.stats as osp_stats from jax import lax from jax import numpy as jnp -from jax._src.numpy.util import implements, promote_dtypes_inexact +from jax._src.numpy.util import promote_dtypes_inexact from jax._src.typing import Array, ArrayLike -@implements(osp_stats.multivariate_normal.logpdf, update_doc=False, lax_description=""" -In the JAX version, the `allow_singular` argument is not implemented. -""") def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = None) -> ArrayLike: + r"""Multivariate normal log probability distribution function. + + JAX implementation of :obj:`scipy.stats.multivariate_normal` ``logpdf``. + + The multivariate normal PDF is defined as + + .. math:: + + f(x) = \frac{1}{(2\pi)^k\det\Sigma}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right) + + where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covariance matrix (``cov``), and + :math:`k` is the rank of :math:`\Sigma`. + + Args: + x: arraylike, value at which to evaluate the PDF + mean: arraylike, centroid of distribution + cov: arraylike, covariance matrix of distribution + allow_singular: not supported + + Returns: + array of logpdf values. + + See Also: + :func:`jax.scipy.stats.multivariate_normal.pdf` + """ if allow_singular is not None: raise NotImplementedError("allow_singular argument of multivariate_normal.logpdf") x, mean, cov = promote_dtypes_inexact(x, mean, cov) @@ -50,6 +71,31 @@ def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi) - jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1)) -@implements(osp_stats.multivariate_normal.pdf, update_doc=False) + def pdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike) -> Array: + r"""Multivariate normal probability distribution function. 
+ + JAX implementation of :obj:`scipy.stats.multivariate_normal` ``pdf``. + + The multivariate normal PDF is defined as + + .. math:: + + f(x) = \frac{1}{(2\pi)^k\det\Sigma}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right) + + where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covariance matrix (``cov``), and + :math:`k` is the rank of :math:`\Sigma`. + + Args: + x: arraylike, value at which to evaluate the PDF + mean: arraylike, centroid of distribution + cov: arraylike, covariance matrix of distribution + allow_singular: not supported + + Returns: + array of pdf values. + + See Also: + :func:`jax.scipy.stats.multivariate_normal.logpdf` + """ return lax.exp(logpdf(x, mean, cov)) diff --git a/jax/_src/scipy/stats/nbinom.py b/jax/_src/scipy/stats/nbinom.py index 6af74442da10..a8d968526e70 100644 --- a/jax/_src/scipy/stats/nbinom.py +++ b/jax/_src/scipy/stats/nbinom.py @@ -12,32 +12,73 @@ # See the License for the specific language governing permissions and # limitations under the License - -import scipy.stats as osp_stats - from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.scipy.special import gammaln, xlogy from jax._src.typing import Array, ArrayLike -@implements(osp_stats.nbinom.logpmf, update_doc=False) def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: - """JAX implementation of scipy.stats.nbinom.logpmf.""" - k, n, p, loc = promote_args_inexact("nbinom.logpmf", k, n, p, loc) - one = _lax_const(k, 1) - y = lax.sub(k, loc) - comb_term = lax.sub( - lax.sub(gammaln(lax.add(y, n)), gammaln(n)), gammaln(lax.add(y, one)) - ) - log_linear_term = lax.add(xlogy(n, p), xlogy(y, lax.sub(one, p))) - log_probs = lax.add(comb_term, log_linear_term) - return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs) - - -@implements(osp_stats.nbinom.pmf, update_doc=False) + r"""Negative-binomial log probability mass function. + + JAX implementation of :obj:`scipy.stats.nbinom` ``logpmf``. + + The negative-binomial probability mass function is given by + + .. math:: + + f(k) = {{k+n-1} \choose {n-1}}p^n(1-p)^k + + for :math:`k \ge 0` and :math:`0 \le p \le 1`. + + Args: + k: arraylike, value at which to evaluate the PMF + n: arraylike, distribution shape parameter + p: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + + Returns: + array of logpdf values. + + See Also: + :func:`jax.scipy.stats.nbinom.pmf` + """ + k, n, p, loc = promote_args_inexact("nbinom.logpmf", k, n, p, loc) + one = _lax_const(k, 1) + y = lax.sub(k, loc) + comb_term = lax.sub( + lax.sub(gammaln(lax.add(y, n)), gammaln(n)), gammaln(lax.add(y, one)) + ) + log_linear_term = lax.add(xlogy(n, p), xlogy(y, lax.sub(one, p))) + log_probs = lax.add(comb_term, log_linear_term) + return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs) + + def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: - """JAX implementation of scipy.stats.nbinom.pmf.""" - return lax.exp(logpmf(k, n, p, loc)) + r"""Negative-binomial probability mass function. + + JAX implementation of :obj:`scipy.stats.nbinom` ``pmf``. + + The negative-binomial probability mass function is given by + + .. math:: + + f(k) = {{k+n-1} \choose {n-1}}p^n(1-p)^k + + for :math:`k \ge 0` and :math:`0 \le p \le 1`. 
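+
+  For illustration, a minimal sketch of evaluating this wrapper (the shape
+  parameters ``n=5`` and ``p=0.3`` are arbitrary assumptions chosen for
+  demonstration):
+
+  >>> import jax.numpy as jnp
+  >>> from jax.scipy.stats import nbinom
+  >>> k = jnp.arange(0, 4)
+  >>> probs = nbinom.pmf(k, n=5, p=0.3)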
+ + Args: + k: arraylike, value at which to evaluate the PMF + n: arraylike, distribution shape parameter + p: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + + Returns: + array of pmf values. + + See Also: + :func:`jax.scipy.stats.nbinom.logpmf` + """ + return lax.exp(logpmf(k, n, p, loc)) diff --git a/jax/_src/scipy/stats/norm.py b/jax/_src/scipy/stats/norm.py index 1258e8905d89..b222e187f255 100644 --- a/jax/_src/scipy/stats/norm.py +++ b/jax/_src/scipy/stats/norm.py @@ -15,18 +15,43 @@ from typing import cast import numpy as np -import scipy.stats as osp_stats from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy import special -@implements(osp_stats.norm.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Normal log probability distribution function. + + JAX implementation of :obj:`scipy.stats.norm` ``logpdf``. + + The normal distribution pdf is given by + + .. math:: + + f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2} + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logpdf values. + + See Also: + - :func:`jax.scipy.stats.norm.cdf` + - :func:`jax.scipy.stats.norm.pdf` + - :func:`jax.scipy.stats.norm.sf` + - :func:`jax.scipy.stats.norm.logcdf` + - :func:`jax.scipy.stats.norm.logsf` + - :func:`jax.scipy.stats.norm.isf` + - :func:`jax.scipy.stats.norm.ppf` + """ x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale) scale_sqrd = lax.square(scale) log_normalizer = lax.log(lax.mul(_lax_const(x, 2 * np.pi), scale_sqrd)) @@ -34,41 +59,229 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.div(lax.add(log_normalizer, quadratic), _lax_const(x, -2)) -@implements(osp_stats.norm.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Normal probability distribution function. + + JAX implementation of :obj:`scipy.stats.norm` ``pdf``. + + The normal distribution pdf is given by + + .. math:: + + f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2} + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. + + See Also: + - :func:`jax.scipy.stats.norm.cdf` + - :func:`jax.scipy.stats.norm.sf` + - :func:`jax.scipy.stats.norm.logcdf` + - :func:`jax.scipy.stats.norm.logpdf` + - :func:`jax.scipy.stats.norm.logsf` + - :func:`jax.scipy.stats.norm.isf` + - :func:`jax.scipy.stats.norm.ppf` + """ return lax.exp(logpdf(x, loc, scale)) -@implements(osp_stats.norm.cdf, update_doc=False) def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Normal cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.norm` ``cdf``. + + The cdf is defined as + + .. math:: + + f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y + + where :math:`f_{pdf}` is the probability density function, + :func:`jax.scipy.stats.norm.pdf`. + + Args: + x: arraylike, value at which to evaluate the CDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of cdf values. 
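+
+  As a minimal usage sketch (the evaluation points are arbitrary):
+
+  >>> import jax.numpy as jnp
+  >>> from jax.scipy.stats import norm
+  >>> x = jnp.array([-1.0, 0.0, 1.0])
+  >>> cdf_vals = norm.cdf(x, loc=0.0, scale=1.0)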
+ + See Also: + - :func:`jax.scipy.stats.norm.pdf` + - :func:`jax.scipy.stats.norm.sf` + - :func:`jax.scipy.stats.norm.logcdf` + - :func:`jax.scipy.stats.norm.logpdf` + - :func:`jax.scipy.stats.norm.logsf` + - :func:`jax.scipy.stats.norm.isf` + - :func:`jax.scipy.stats.norm.ppf` + """ x, loc, scale = promote_args_inexact("norm.cdf", x, loc, scale) return special.ndtr(lax.div(lax.sub(x, loc), scale)) -@implements(osp_stats.norm.logcdf, update_doc=False) def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Normal log cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.norm` ``logcdf``. + + The cdf is defined as + + .. math:: + + f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y + + where :math:`f_{pdf}` is the probability density function, + :func:`jax.scipy.stats.norm.pdf`. + + Args: + x: arraylike, value at which to evaluate the CDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logcdf values. + + See Also: + - :func:`jax.scipy.stats.norm.cdf` + - :func:`jax.scipy.stats.norm.pdf` + - :func:`jax.scipy.stats.norm.sf` + - :func:`jax.scipy.stats.norm.logpdf` + - :func:`jax.scipy.stats.norm.logsf` + - :func:`jax.scipy.stats.norm.isf` + - :func:`jax.scipy.stats.norm.ppf` + """ x, loc, scale = promote_args_inexact("norm.logcdf", x, loc, scale) # Cast required because custom_jvp return type is broken. return cast(Array, special.log_ndtr(lax.div(lax.sub(x, loc), scale))) -@implements(osp_stats.norm.ppf, update_doc=False) def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + """Normal distribution percent point function. + + JAX implementation of :obj:`scipy.stats.norm` ``ppf``. + + The percent point function is defined as the inverse of the + cumulative distribution function, :func:`jax.scipy.stats.norm.cdf`. + + Args: + q: arraylike, value at which to evaluate the PPF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of ppf values. + + See Also: + - :func:`jax.scipy.stats.norm.cdf` + - :func:`jax.scipy.stats.norm.pdf` + - :func:`jax.scipy.stats.norm.sf` + - :func:`jax.scipy.stats.norm.logcdf` + - :func:`jax.scipy.stats.norm.logpdf` + - :func:`jax.scipy.stats.norm.logsf` + - :func:`jax.scipy.stats.norm.isf` + """ return jnp.asarray(special.ndtri(q) * scale + loc, float) -@implements(osp_stats.norm.logsf, update_doc=False) def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + """Normal distribution log survival function. + + JAX implementation of :obj:`scipy.stats.norm` ``logsf``. + + The survival function is defined as + + .. math:: + + f_{sf}(x) = 1 - f_{cdf}(x) + + where :math:`f_{cdf}(x)` is the cumulative distribution function, + :func:`jax.scipy.stats.norm.cdf`. + + Args: + x: arraylike, value at which to evaluate the SF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logsf values. 
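+
+  A small sketch of the relationship to :func:`jax.scipy.stats.norm.sf`
+  (equality holds up to floating-point tolerance; the inputs are arbitrary):
+
+  >>> import jax.numpy as jnp
+  >>> from jax.scipy.stats import norm
+  >>> x = jnp.array([0.5, 1.5, 2.5])
+  >>> bool(jnp.allclose(jnp.exp(norm.logsf(x)), norm.sf(x)))
+  True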
+ + See Also: + - :func:`jax.scipy.stats.norm.cdf` + - :func:`jax.scipy.stats.norm.pdf` + - :func:`jax.scipy.stats.norm.sf` + - :func:`jax.scipy.stats.norm.logcdf` + - :func:`jax.scipy.stats.norm.logpdf` + - :func:`jax.scipy.stats.norm.isf` + - :func:`jax.scipy.stats.norm.ppf` + """ x, loc, scale = promote_args_inexact("norm.logsf", x, loc, scale) return logcdf(-x, -loc, scale) -@implements(osp_stats.norm.sf, update_doc=False) def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + """Normal distribution survival function. + + JAX implementation of :obj:`scipy.stats.norm` ``sf``. + + The survival function is defined as + + .. math:: + + f_{sf}(x) = 1 - f_{cdf}(x) + + where :math:`f_{cdf}(x)` is the cumulative distribution function, + :func:`jax.scipy.stats.norm.cdf`. + + Args: + x: arraylike, value at which to evaluate the SF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of sf values. + + See Also: + - :func:`jax.scipy.stats.norm.cdf` + - :func:`jax.scipy.stats.norm.pdf` + - :func:`jax.scipy.stats.norm.logcdf` + - :func:`jax.scipy.stats.norm.logpdf` + - :func:`jax.scipy.stats.norm.logsf` + - :func:`jax.scipy.stats.norm.isf` + - :func:`jax.scipy.stats.norm.ppf` + """ x, loc, scale = promote_args_inexact("norm.sf", x, loc, scale) return cdf(-x, -loc, scale) -@implements(osp_stats.norm.isf, update_doc=False) def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + """Normal distribution inverse survival function. + + JAX implementation of :obj:`scipy.stats.norm` ``isf``. + + Returns the inverse of the survival function, + :func:`jax.scipy.stats.norm.sf`. + + Args: + x: arraylike, value at which to evaluate the ISF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of isf values. + + See Also: + - :func:`jax.scipy.stats.norm.cdf` + - :func:`jax.scipy.stats.norm.pdf` + - :func:`jax.scipy.stats.norm.sf` + - :func:`jax.scipy.stats.norm.logcdf` + - :func:`jax.scipy.stats.norm.logpdf` + - :func:`jax.scipy.stats.norm.logsf` + - :func:`jax.scipy.stats.norm.ppf` + """ return ppf(lax.sub(_lax_const(q, 1), q), loc, scale) diff --git a/jax/_src/scipy/stats/pareto.py b/jax/_src/scipy/stats/pareto.py index 0600fba29857..0b0c9e1a4993 100644 --- a/jax/_src/scipy/stats/pareto.py +++ b/jax/_src/scipy/stats/pareto.py @@ -12,18 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import scipy.stats as osp_stats - from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike -@implements(osp_stats.pareto.logpdf, update_doc=False) def logpdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Pareto log probability distribution function. + + JAX implementation of :obj:`scipy.stats.pareto` ``logpdf``. + + The Pareto probability density function is given by + + .. math:: + + f(x, b) = \begin{cases} + bx^{-(b+1)} & x \ge 1\\ + 0 & x < 1 + \end{cases} + + and is defined for :math:`b > 0`. + + Args: + x: arraylike, value at which to evaluate the PDF + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logpdf values. 
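+
+  A minimal usage sketch (``b=2.0`` is an arbitrary shape parameter chosen
+  for illustration):
+
+  >>> import jax.numpy as jnp
+  >>> from jax.scipy.stats import pareto
+  >>> x = jnp.array([1.0, 2.0, 4.0])
+  >>> logp = pareto.logpdf(x, b=2.0)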
+ + See Also: + :func:`jax.scipy.stats.pareto.pdf` + """ x, b, loc, scale = promote_args_inexact("pareto.logpdf", x, b, loc, scale) one = _lax_const(x, 1) scaled_x = lax.div(lax.sub(x, loc), scale) @@ -31,6 +54,33 @@ def logpdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) log_probs = lax.neg(lax.add(normalize_term, lax.mul(lax.add(b, one), lax.log(scaled_x)))) return jnp.where(lax.lt(x, lax.add(loc, scale)), -jnp.inf, log_probs) -@implements(osp_stats.pareto.pdf, update_doc=False) + def pdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Pareto probability distribution function. + + JAX implementation of :obj:`scipy.stats.pareto` ``pdf``. + + The Pareto probability density function is given by + + .. math:: + + f(x, b) = \begin{cases} + bx^{-(b+1)} & x \ge 1\\ + 0 & x < 1 + \end{cases} + + and is defined for :math:`b > 0`. + + Args: + x: arraylike, value at which to evaluate the PDF + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. + + See Also: + :func:`jax.scipy.stats.pareto.logpdf` + """ return lax.exp(logpdf(x, b, loc, scale)) diff --git a/jax/_src/scipy/stats/poisson.py b/jax/_src/scipy/stats/poisson.py index 3d1862031be0..84f4cfe89208 100644 --- a/jax/_src/scipy/stats/poisson.py +++ b/jax/_src/scipy/stats/poisson.py @@ -12,31 +12,101 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import scipy.stats as osp_stats - from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy.special import xlogy, gammaln, gammaincc -@implements(osp_stats.poisson.logpmf, update_doc=False) def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array: + r"""Poisson log probability mass function. + + JAX implementation of :obj:`scipy.stats.poisson` ``logpmf``. + + The Poisson probability mass function is given by + + .. math:: + + f(k) = e^{-\mu}\frac{\mu^k}{k!} + + and is defined for :math:`k \ge 0` and :math:`\mu \ge 0`. + + Args: + k: arraylike, value at which to evaluate the PMF + mu: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + + Returns: + array of logpmf values. + + See Also: + - :func:`jax.scipy.stats.poisson.cdf` + - :func:`jax.scipy.stats.poisson.pmf` + """ k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc) zero = _lax_const(k, 0) x = lax.sub(k, loc) log_probs = xlogy(x, mu) - gammaln(x + 1) - mu - return jnp.where(lax.lt(x, zero), -jnp.inf, log_probs) + return jnp.where(jnp.logical_or(lax.lt(x, zero), + lax.ne(jnp.round(k), k)), -jnp.inf, log_probs) + -@implements(osp_stats.poisson.pmf, update_doc=False) def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array: + r"""Poisson probability mass function. + + JAX implementation of :obj:`scipy.stats.poisson` ``pmf``. + + The Poisson probability mass function is given by + + .. math:: + + f(k) = e^{-\mu}\frac{\mu^k}{k!} + + and is defined for :math:`k \ge 0` and :math:`\mu \ge 0`. + + Args: + k: arraylike, value at which to evaluate the PMF + mu: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + + Returns: + array of pmf values. 
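+
+  For illustration, a small sketch of evaluating the PMF (``mu=2.0`` is an
+  arbitrary rate chosen for demonstration):
+
+  >>> import jax.numpy as jnp
+  >>> from jax.scipy.stats import poisson
+  >>> k = jnp.arange(0, 5)
+  >>> probs = poisson.pmf(k, mu=2.0)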
+ + See Also: + - :func:`jax.scipy.stats.poisson.cdf` + - :func:`jax.scipy.stats.poisson.logpmf` + """ return jnp.exp(logpmf(k, mu, loc)) -@implements(osp_stats.poisson.cdf, update_doc=False) + def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array: + r"""Poisson cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.poisson` ``cdf``. + + The cumulative distribution function is defined as: + + .. math:: + + f_{cdf}(k, p) = \sum_{i=0}^k f_{pmf}(k, p) + + where :math:`f_{pmf}(k, p)` is the probability mass function + :func:`jax.scipy.stats.poisson.pmf`. + + Args: + k: arraylike, value at which to evaluate the CDF + mu: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + + Returns: + array of cdf values. + + See Also: + - :func:`jax.scipy.stats.poisson.pmf` + - :func:`jax.scipy.stats.poisson.logpmf` + """ k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc) zero = _lax_const(k, 0) x = lax.sub(k, loc) diff --git a/jax/_src/scipy/stats/t.py b/jax/_src/scipy/stats/t.py index 742a2e16297c..2e276c831e28 100644 --- a/jax/_src/scipy/stats/t.py +++ b/jax/_src/scipy/stats/t.py @@ -14,16 +14,39 @@ import numpy as np -import scipy.stats as osp_stats from jax import lax from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike -@implements(osp_stats.t.logpdf, update_doc=False) def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Student's T log probability distribution function. + + JAX implementation of :obj:`scipy.stats.t` ``logpdf``. + + The Student's T probability distribution function is given by + + .. math:: + + f(x, \nu) = \frac{\Gamma((\nu + 1)/2)}{\sqrt{\pi\nu}\Gamma(\nu/2)}(1 + x^2/\nu)^{(\nu+1)/2} + + Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and :math:`\nu > 0` + is the degrees of freedom (JAX follows the scipy convention of naming this ``df``). + + Args: + x: arraylike, value at which to evaluate the PDF + df: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logpdf values. + + See Also: + :func:`jax.scipy.stats.t.pdf` + """ x, df, loc, scale = promote_args_inexact("t.logpdf", x, df, loc, scale) two = _lax_const(x, 2) scaled_x = lax.div(lax.sub(x, loc), scale) @@ -37,6 +60,30 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic)))) -@implements(osp_stats.t.pdf, update_doc=False) def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Student's T probability distribution function. + + JAX implementation of :obj:`scipy.stats.t` ``pdf``. + + The Student's T probability distribution function is given by + + .. math:: + + f(x, \nu) = \frac{\Gamma((\nu + 1)/2)}{\sqrt{\pi\nu}\Gamma(\nu/2)}(1 + x^2/\nu)^{(\nu+1)/2} + + Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and :math:`\nu > 0` + is the degrees of freedom (JAX follows the scipy convention of naming this ``df``). 
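+
+  As a minimal sketch of evaluating this wrapper (``df=3`` is an arbitrary
+  choice of degrees of freedom):
+
+  >>> import jax.numpy as jnp
+  >>> from jax.scipy.stats import t
+  >>> x = jnp.linspace(-3.0, 3.0, 7)
+  >>> density = t.pdf(x, df=3)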
+ + Args: + x: arraylike, value at which to evaluate the PDF + df: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array + + See Also: + :func:`jax.scipy.stats.t.logpdf` + """ return lax.exp(logpdf(x, df, loc, scale)) diff --git a/jax/_src/scipy/stats/truncnorm.py b/jax/_src/scipy/stats/truncnorm.py index beadd682da21..a02e07d10480 100644 --- a/jax/_src/scipy/stats/truncnorm.py +++ b/jax/_src/scipy/stats/truncnorm.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import scipy.stats as osp_stats - from jax import lax import jax.numpy as jnp -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.scipy.stats import norm from jax._src.scipy.special import logsumexp, log_ndtr, ndtr @@ -69,8 +66,41 @@ def mass_case_central(a, b): return out -@implements(osp_stats.truncnorm.logpdf, update_doc=False) def logpdf(x, a, b, loc=0, scale=1): + r"""Truncated normal log probability distribution function. + + JAX implementation of :obj:`scipy.stats.truncnorm` ``logpdf``. + + The truncated normal probability distribution is given by + + .. math:: + + f(x, a, b) = \begin{cases} + \frac{1}{\sqrt{2\pi}}e^{-x^2/2} & a \le x \le b \\ + 0 & \mathrm{otherwise} + \end{cases} + + where :math:`a` and :math:`b` are effectively specified in number of + standard deviations from zero. JAX uses the scipy nomenclature + of ``loc`` for the centroid and ``scale`` for the standard deviation. + + Args: + x: arraylike, value at which to evaluate the PDF + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logpdf values. + + See Also: + - :func:`jax.scipy.stats.truncnorm.cdf` + - :func:`jax.scipy.stats.truncnorm.pdf` + - :func:`jax.scipy.stats.truncnorm.sf` + - :func:`jax.scipy.stats.truncnorm.logcdf` + - :func:`jax.scipy.stats.truncnorm.logsf` + """ x, a, b, loc, scale = promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale) val = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b)) @@ -80,24 +110,144 @@ def logpdf(x, a, b, loc=0, scale=1): return val -@implements(osp_stats.truncnorm.pdf, update_doc=False) def pdf(x, a, b, loc=0, scale=1): + r"""Truncated normal probability distribution function. + + JAX implementation of :obj:`scipy.stats.truncnorm` ``pdf``. + + The truncated normal probability distribution is given by + + .. math:: + + f(x, a, b) = \begin{cases} + \frac{1}{\sqrt{2\pi}}e^{-x^2/2} & a \le x \le b \\ + 0 & \mathrm{otherwise} + \end{cases} + + where :math:`a` and :math:`b` are effectively specified in number of + standard deviations from the centroid. JAX uses the scipy nomenclature + of ``loc`` for the centroid and ``scale`` for the standard deviation. + + Args: + x: arraylike, value at which to evaluate the PDF + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. 
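+
+  A minimal usage sketch (the truncation bounds ``a=-1, b=1`` are arbitrary
+  illustrative values, expressed in standard deviations from ``loc``):
+
+  >>> import jax.numpy as jnp
+  >>> from jax.scipy.stats import truncnorm
+  >>> x = jnp.array([-0.5, 0.0, 0.5])
+  >>> density = truncnorm.pdf(x, a=-1.0, b=1.0)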
+ + See Also: + - :func:`jax.scipy.stats.truncnorm.cdf` + - :func:`jax.scipy.stats.truncnorm.sf` + - :func:`jax.scipy.stats.truncnorm.logcdf` + - :func:`jax.scipy.stats.truncnorm.logpdf` + - :func:`jax.scipy.stats.truncnorm.logsf` + """ return lax.exp(logpdf(x, a, b, loc, scale)) -@implements(osp_stats.truncnorm.logsf, update_doc=False) def logsf(x, a, b, loc=0, scale=1): + """Truncated normal distribution log survival function. + + JAX implementation of :obj:`scipy.stats.truncnorm` ``logsf`` + + The survival function is defined as + + .. math:: + + f_{sf}(x) = 1 - f_{cdf}(x) + + where :math:`f_{cdf}(x)` is the cumulative distribution function, + :func:`jax.scipy.stats.truncnorm.cdf`. + + Args: + x: arraylike, value at which to evaluate the SF + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logsf values. + + See Also: + - :func:`jax.scipy.stats.truncnorm.cdf` + - :func:`jax.scipy.stats.truncnorm.pdf` + - :func:`jax.scipy.stats.truncnorm.sf` + - :func:`jax.scipy.stats.truncnorm.logcdf` + - :func:`jax.scipy.stats.truncnorm.logpdf` + """ x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale) return logcdf(-x, -b, -a, -loc, scale) -@implements(osp_stats.truncnorm.sf, update_doc=False) def sf(x, a, b, loc=0, scale=1): + """Truncated normal distribution log survival function. + + JAX implementation of :obj:`scipy.stats.truncnorm` ``logsf`` + + The survival function is defined as + + .. math:: + + f_{sf}(x) = 1 - f_{cdf}(x) + + where :math:`f_{cdf}(x)` is the cumulative distribution function, + :func:`jax.scipy.stats.truncnorm.cdf`. + + Args: + x: arraylike, value at which to evaluate the SF + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of sf values. + + See Also: + - :func:`jax.scipy.stats.truncnorm.cdf` + - :func:`jax.scipy.stats.truncnorm.pdf` + - :func:`jax.scipy.stats.truncnorm.sf` + - :func:`jax.scipy.stats.truncnorm.logcdf` + - :func:`jax.scipy.stats.truncnorm.logpdf` + """ return lax.exp(logsf(x, a, b, loc, scale)) -@implements(osp_stats.truncnorm.logcdf, update_doc=False) def logcdf(x, a, b, loc=0, scale=1): + r"""Truncated normal log cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.truncnorm` ``logcdf``. + + The cdf is defined as + + .. math:: + + f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y + + where here :math:`f_{pdf}` is the probability distribution function, + :func:`jax.scipy.stats.truncnorm.pdf`. + + Args: + x: arraylike, value at which to evaluate the CDF + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logcdf values. 
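+
+  A small sketch relating this to :func:`jax.scipy.stats.truncnorm.cdf`
+  (the inputs are arbitrary illustrative values):
+
+  >>> import jax.numpy as jnp
+  >>> from jax.scipy.stats import truncnorm
+  >>> x = jnp.array([-0.5, 0.0, 0.5])
+  >>> bool(jnp.allclose(jnp.exp(truncnorm.logcdf(x, -1.0, 1.0)),
+  ...                   truncnorm.cdf(x, -1.0, 1.0)))
+  True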
+ + See Also: + - :func:`jax.scipy.stats.truncnorm.cdf` + - :func:`jax.scipy.stats.truncnorm.pdf` + - :func:`jax.scipy.stats.truncnorm.sf` + - :func:`jax.scipy.stats.truncnorm.logpdf` + - :func:`jax.scipy.stats.truncnorm.logsf` + """ x, a, b, loc, scale = promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale) x, a, b = jnp.broadcast_arrays(x, a, b) x = lax.div(lax.sub(x, loc), scale) @@ -113,6 +263,35 @@ def logcdf(x, a, b, loc=0, scale=1): return logcdf -@implements(osp_stats.truncnorm.cdf, update_doc=False) def cdf(x, a, b, loc=0, scale=1): + r"""Truncated normal cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.truncnorm` ``cdf``. + + The cdf is defined as + + .. math:: + + f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y + + where here :math:`f_{pdf}` is the probability distribution function, + :func:`jax.scipy.stats.truncnorm.pdf`. + + Args: + x: arraylike, value at which to evaluate the CDF + a: arraylike, distribution shape parameter + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of cdf values. + + See Also: + - :func:`jax.scipy.stats.truncnorm.pdf` + - :func:`jax.scipy.stats.truncnorm.sf` + - :func:`jax.scipy.stats.truncnorm.logcdf` + - :func:`jax.scipy.stats.truncnorm.logpdf` + - :func:`jax.scipy.stats.truncnorm.logsf` + """ return lax.exp(logcdf(x, a, b, loc, scale)) diff --git a/jax/_src/scipy/stats/uniform.py b/jax/_src/scipy/stats/uniform.py index ba186cc6ca78..8d36e23c1b70 100644 --- a/jax/_src/scipy/stats/uniform.py +++ b/jax/_src/scipy/stats/uniform.py @@ -12,30 +12,104 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import scipy.stats as osp_stats - from jax import lax from jax import numpy as jnp from jax.numpy import where, inf, logical_or from jax._src.typing import Array, ArrayLike -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact -@implements(osp_stats.uniform.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Uniform log probability distribution function. + + JAX implementation of :obj:`scipy.stats.uniform` ``logpdf``. + + The uniform distribution pdf is given by + + .. math:: + + f(x) = \begin{cases} + 1 & 0 \le x \le 1 \\ + 0 & \mathrm{otherwise} + \end{cases} + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logpdf values + + See Also: + - :func:`jax.scipy.stats.uniform.cdf` + - :func:`jax.scipy.stats.uniform.pdf` + - :func:`jax.scipy.stats.uniform.ppf` + """ x, loc, scale = promote_args_inexact("uniform.logpdf", x, loc, scale) log_probs = lax.neg(lax.log(scale)) return where(logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc)), -inf, log_probs) -@implements(osp_stats.uniform.pdf, update_doc=False) + def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Uniform probability distribution function. + + JAX implementation of :obj:`scipy.stats.uniform` ``pdf``. + + The uniform distribution pdf is given by + + .. 
math:: + + f(x) = \begin{cases} + 1 & 0 \le x \le 1 \\ + 0 & \mathrm{otherwise} + \end{cases} + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. + + See Also: + - :func:`jax.scipy.stats.uniform.cdf` + - :func:`jax.scipy.stats.uniform.logpdf` + - :func:`jax.scipy.stats.uniform.ppf` + """ return lax.exp(logpdf(x, loc, scale)) -@implements(osp_stats.uniform.cdf, update_doc=False) + def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Uniform cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.uniform` ``cdf``. + + The cdf is defined as + + .. math:: + + f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y + + where here :math:`f_{pdf}` is the probability distribution function, + :func:`jax.scipy.stats.uniform.pdf`. + + Args: + x: arraylike, value at which to evaluate the CDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of cdf values. + + See Also: + - :func:`jax.scipy.stats.uniform.pdf` + - :func:`jax.scipy.stats.uniform.logpdf` + - :func:`jax.scipy.stats.uniform.ppf` + """ x, loc, scale = promote_args_inexact("uniform.cdf", x, loc, scale) zero, one = jnp.array(0, x.dtype), jnp.array(1, x.dtype) conds = [lax.lt(x, loc), lax.gt(x, lax.add(loc, scale)), lax.ge(x, loc) & lax.le(x, lax.add(loc, scale))] @@ -43,8 +117,28 @@ def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return jnp.select(conds, vals) -@implements(osp_stats.uniform.ppf, update_doc=False) + def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + """Uniform distribution percent point function. + + JAX implementation of :obj:`scipy.stats.uniform` ``ppf``. + + The percent point function is defined as the inverse of the + cumulative distribution function, :func:`jax.scipy.stats.uniform.cdf`. + + Args: + q: arraylike, value at which to evaluate the PPF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of ppf values. + + See Also: + - :func:`jax.scipy.stats.uniform.cdf` + - :func:`jax.scipy.stats.uniform.pdf` + - :func:`jax.scipy.stats.uniform.logpdf` + """ q, loc, scale = promote_args_inexact("uniform.ppf", q, loc, scale) return where( jnp.isnan(q) | (q < 0) | (q > 1), diff --git a/jax/_src/scipy/stats/vonmises.py b/jax/_src/scipy/stats/vonmises.py index b32799c37d3d..631cc8ee2145 100644 --- a/jax/_src/scipy/stats/vonmises.py +++ b/jax/_src/scipy/stats/vonmises.py @@ -12,20 +12,66 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import scipy.stats as osp_stats from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike -@implements(osp_stats.vonmises.logpdf, update_doc=False) + def logpdf(x: ArrayLike, kappa: ArrayLike) -> Array: + r"""von Mises log probability distribution function. + + JAX implementation of :obj:`scipy.stats.vonmises` ``logpdf``. + + The von Mises probability distribution function is given by + + .. 
math:: + + f(x, \kappa) = \frac{1}{2\pi I_0(\kappa)}e^{\kappa\cos x} + + Where :math:`I_0` is the modified Bessel function :func:`~jax.scipy.special.i0` + and :math:`\kappa\ge 0`, and the distribution is normalized in the interval + :math:`-\pi \le x \le \pi`. + + Args: + x: arraylike, value at which to evaluate the PDF + kappa: arraylike, distribution shape parameter + + Returns: + array of logpdf values. + + See Also: + :func:`jax.scipy.stats.vonmises.pdf` + """ x, kappa = promote_args_inexact('vonmises.logpdf', x, kappa) zero = _lax_const(kappa, 0) return jnp.where(lax.gt(kappa, zero), kappa * (jnp.cos(x) - 1) - jnp.log(2 * jnp.pi * lax.bessel_i0e(kappa)), jnp.nan) -@implements(osp_stats.vonmises.pdf, update_doc=False) + def pdf(x: ArrayLike, kappa: ArrayLike) -> Array: + r"""von Mises probability distribution function. + + JAX implementation of :obj:`scipy.stats.vonmises` ``pdf``. + + The von Mises probability distribution function is given by + + .. math:: + + f(x, \kappa) = \frac{1}{2\pi I_0(\kappa)}e^{\kappa\cos x} + + Where :math:`I_0` is the modified Bessel function :func:`~jax.scipy.special.i0` + and :math:`\kappa\ge 0`, and the distribution is normalized in the interval + :math:`-\pi \le x \le \pi`. + + Args: + x: arraylike, value at which to evaluate the PDF + kappa: arraylike, distribution shape parameter + + Returns: + array of pdf values. + + See Also: + :func:`jax.scipy.stats.vonmises.logpdf` + """ return lax.exp(logpdf(x, kappa)) diff --git a/jax/_src/scipy/stats/wrapcauchy.py b/jax/_src/scipy/stats/wrapcauchy.py index f05b4e8606ee..26b24d7da447 100644 --- a/jax/_src/scipy/stats/wrapcauchy.py +++ b/jax/_src/scipy/stats/wrapcauchy.py @@ -13,16 +13,36 @@ # limitations under the License. -import scipy.stats as osp_stats from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import implements, promote_args_inexact +from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike -@implements(osp_stats.wrapcauchy.logpdf, update_doc=False) def logpdf(x: ArrayLike, c: ArrayLike) -> Array: + r"""Wrapped Cauchy log probability distribution function. + + JAX implementation of :obj:`scipy.stats.wrapcauchy` ``logpdf``. + + The wrapped Cauchy probability distribution function is given by + + .. math:: + + f(x, c) = \frac{1-c^2}{2\pi(1+c^2-2c\cos x)} + + for :math:`0 Array: jnp.nan, ) -@implements(osp_stats.wrapcauchy.pdf, update_doc=False) + def pdf(x: ArrayLike, c: ArrayLike) -> Array: + r"""Wrapped Cauchy probability distribution function. + + JAX implementation of :obj:`scipy.stats.wrapcauchy` ``pdf``. + + The wrapped Cauchy probability distribution function is given by + + .. 
math:: + + f(x, c) = \frac{1-c^2}{2\pi(1+c^2-2c\cos x)} + + for :math:`0 Mapping[Device, Index | None]: global_map = sharding.devices_indices_map(global_shape) @@ -39,8 +42,41 @@ def _addressable_devices_indices_map( return {d: ind for d, ind in global_map.items() if d.process_index == d.client.process_index()} - -@util.use_cpp_class(xc.Sharding) +@cache(max_size=4096, trace_context_in_key=False) +def common_devices_indices_map(s, global_shape: Shape) -> Mapping[Device, Index]: + s.shard_shape(global_shape) # raises a good error message + hlo_sharding = s._to_xla_hlo_sharding(len(global_shape)) + indices = op_sharding_to_indices(hlo_sharding, global_shape, + len(s._device_assignment)) + return dict(safe_zip(s._device_assignment, indices)) + + +@cache(max_size=4096, trace_context_in_key=False) +def _common_shard_shape(self, global_shape: Shape) -> Shape: + hlo_sharding = self._to_xla_hlo_sharding(len(global_shape)) + if is_op_sharding_replicated(hlo_sharding): + return global_shape + partitions, _ = get_num_ways_dim_sharded(hlo_sharding) + assert len(partitions) == len(global_shape), (len(partitions), len(global_shape)) + out = [] + for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)): + try: + quotient, remainder = divmod(s, p) + except TypeError: + # TODO Figure out how to partition dynamic shapes + raise NotImplementedError + if remainder != 0: + raise ValueError( + f"Sharding {self} implies that array axis {dim} is partitioned " + f"{p} times, but the dimension size is {s} " + f"(full shape: {global_shape}, " + f"per-dimension tiling factors: {partitions} should evenly divide " + "the shape)") + out.append(quotient) + return tuple(out) + + +@use_cpp_class(xc.Sharding) class Sharding: """Describes how a :class:`jax.Array` is laid out across devices. """ @@ -55,35 +91,6 @@ def device_set(self) -> set[Device]: """ raise NotImplementedError('Subclasses should implement this method.') - def devices_indices_map( - self, global_shape: Shape) -> Mapping[Device, Index | None]: - """Returns a mapping from devices to the array slices each contains. - - The mapping includes all global devices, i.e., including - non-addressable devices from other processes. - """ - raise NotImplementedError('Subclasses should implement this method.') - - def shard_shape(self, global_shape: Shape) -> Shape: - """Returns the shape of the data on each device. - - The shard shape returned by this function is calculated from - ``global_shape`` and the properties of the sharding. - """ - raise NotImplementedError('Subclasses should implement this method.') - - def is_equivalent_to(self, other: Sharding, ndim: int) -> bool: - """Returns ``True`` if two shardings are equivalent. - - Two shardings are equivalent if they place the same logical array shards on - the same devices. - - For example, a :class:`NamedSharding` may be equivalent - to a :class:`PositionalSharding` if both place the same shards of the array - on the same devices. - """ - raise NotImplementedError('Subclasses should implement this method.') - @property def is_fully_replicated(self) -> bool: """Is this sharding fully replicated? 
@@ -112,6 +119,14 @@ def with_memory_kind(self, kind: str) -> Sharding: """Returns a new Sharding instance with the specified memory kind.""" raise NotImplementedError('Subclasses should implement this method') + @property + def _device_assignment(self) -> XLADeviceAssignment: + raise NotImplementedError('Subclasses should implement this method.') + + def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: + raise NotImplementedError('Subclasses should implement this method.') + + ############################################################################# # Default implementations below that all subclasses will inherit. @@ -134,3 +149,49 @@ def addressable_devices_indices_map( ``device_indices_map`` that applies to the addressable devices. """ return _addressable_devices_indices_map(self, global_shape) + + def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: + """Returns a mapping from devices to the array slices each contains. + + The mapping includes all global devices, i.e., including + non-addressable devices from other processes. + """ + return common_devices_indices_map(self, global_shape) + + @functools.cached_property + def _addressable_device_assignment(self) -> XLADeviceAssignment: + if self.is_fully_addressable: + return self._device_assignment + if hasattr(self, '_internal_device_list'): + return tuple(self._internal_device_list.addressable_device_list) + return tuple(d for d in self._device_assignment + if d.process_index == d.client.process_index()) + + def shard_shape(self, global_shape: Shape) -> Shape: + """Returns the shape of the data on each device. + + The shard shape returned by this function is calculated from + ``global_shape`` and the properties of the sharding. + """ + return _common_shard_shape(self, global_shape) + + def is_equivalent_to(self: Sharding, other: Sharding, ndim: int) -> bool: + """Returns ``True`` if two shardings are equivalent. + + Two shardings are equivalent if they place the same logical array shards on + the same devices. + + For example, a :class:`NamedSharding` may be equivalent + to a :class:`PositionalSharding` if both place the same shards of the array + on the same devices. + """ + try: + return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim), + other._to_xla_hlo_sharding(ndim)) + and self._internal_device_list == other._internal_device_list and # type: ignore + self.memory_kind == other.memory_kind) + # NotImplementedError is raised by PmapSharding because it can't lower + # to OpSharding. So if `other` is a PmapSharding, default to a strict + # equality check. 
+ except NotImplementedError: + return self == other diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index a3ff34e8dd01..1ecacdac9b7c 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -25,19 +25,17 @@ from typing import Any, NamedTuple, Union, cast from jax._src import mesh as mesh_lib -from jax._src.op_shardings import ( - is_op_sharding_replicated, are_op_shardings_equal, get_num_ways_dim_sharded, - op_sharding_to_indices) from jax._src import sharding from jax._src import sharding_specs from jax._src import tree_util from jax._src import util from jax._src import xla_bridge -from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method +from jax._src import core from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version +from jax._src.op_shardings import ( + are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) from jax._src.partition_spec import PartitionSpec - +from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method import numpy as np @@ -45,101 +43,25 @@ Device = xc.Device Index = tuple[slice, ...] XLADeviceAssignment = tuple[Device, ...] - +# TODO(yashkatariya): Remove this after 3 months of deprecation. +XLACompatibleSharding = sharding.Sharding @dataclasses.dataclass(frozen=True) class TransferToMemoryKind: memory_kind: str -@functools.lru_cache(maxsize=4096) -def common_devices_indices_map(s, global_shape: Shape) -> Mapping[Device, Index]: - hlo_sharding = s._to_xla_hlo_sharding(len(global_shape)) - gspmd_sharding = GSPMDSharding(s._device_assignment, hlo_sharding) - return gspmd_sharding.devices_indices_map(global_shape) - - -@functools.lru_cache(maxsize=4096) -def _common_shard_shape(self, global_shape: Shape) -> Shape: - hlo_sharding = self._to_xla_hlo_sharding(len(global_shape)) - if is_op_sharding_replicated(hlo_sharding): - return global_shape - partitions, _ = get_num_ways_dim_sharded(hlo_sharding) - assert len(partitions) == len(global_shape), (len(partitions), len(global_shape)) - out = [] - for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)): - try: - quotient, remainder = divmod(s, p) - except TypeError: - # TODO Figure out how to partition dynamic shapes - raise NotImplementedError - if remainder != 0: - raise ValueError( - f"Sharding {self} implies that array axis {dim} is partitioned " - f"{p} times, but the dimension size is {s} " - f"(full shape: {global_shape}, " - f"per-dimension tiling factors: {partitions} should evenly divide " - "the shape)") - out.append(quotient) - return tuple(out) - - -# Shardings that inherit from XLACompatibleSharding should implement the -# `_device_assignment` property and `_to_xla_hlo_sharding` method. -@use_cpp_class(xc.XLACompatibleSharding) -class XLACompatibleSharding(sharding.Sharding): - """A :class:`Sharding` that describes shardings expressible to XLA. - - Subclasses of :class:`XLACompatibleSharding` work with - all JAX APIs and transformations that use XLA. - """ - - # Abstract methods below that subclasses should implement. - - @property - def _device_assignment(self) -> XLADeviceAssignment: - raise NotImplementedError('Subclasses should implement this method.') - - def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: - raise NotImplementedError('Subclasses should implement this method.') - - ############################################################################# - # Default implementations below that all subclasses will inherit. 
- - def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: - return common_devices_indices_map(self, global_shape) - - @functools.cached_property - def _addressable_device_assignment(self) -> XLADeviceAssignment: - if self.is_fully_addressable: - return self._device_assignment - if hasattr(self, '_internal_device_list'): - return tuple(self._internal_device_list.addressable_device_list) - return tuple(d for d in self._device_assignment - if d.process_index == d.client.process_index()) - - def shard_shape(self, global_shape: Shape) -> Shape: - return _common_shard_shape(self, global_shape) - - def is_equivalent_to(self: XLACompatibleSharding, # type: ignore - other: XLACompatibleSharding, ndim: int) -> bool: - try: - return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim), - other._to_xla_hlo_sharding(ndim)) - and self._internal_device_list == other._internal_device_list and # type: ignore - self.memory_kind == other.memory_kind) - # NotImplementedError is raised by PmapSharding because it can't lower - # to OpSharding. So if `other` is a PmapSharding, default to a strict - # equality check. - except NotImplementedError: - return self == other - - -@functools.lru_cache -def _check_mesh_resource_axis(mesh, parsed_pspec): +@util.cache(max_size=128, trace_context_in_key=False) +def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes): try: - [mesh.shape[r] for p in parsed_pspec if p is not None - for r in p] + for p in parsed_pspec: + if p is not None: + for r in p: + mesh.shape[r] + if r in _manual_axes: + raise ValueError( + f"Axis: {r} of {parsed_pspec.get_partition_spec()} " + f"is also found in manual_axes: {_manual_axes}.") from None except KeyError as e: raise ValueError(f"Resource axis: {e.args[0]} of {parsed_pspec.user_spec} is " "undefined.") from None @@ -152,7 +74,7 @@ def hashed_index(x) -> int: return hash(tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x)) -@functools.lru_cache(maxsize=4096) +@util.cache(max_size=4096, trace_context_in_key=False) def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]: try: device_indices_map_fn = sharding.devices_indices_map @@ -172,7 +94,7 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int] return out -@functools.lru_cache(maxsize=4096) +@util.cache(max_size=4096, trace_context_in_key=False) def named_sharding_to_xla_hlo_sharding( self, num_dimensions: int) -> xc.HloSharding: mesh_shape = self.mesh.shape @@ -235,7 +157,7 @@ def named_sharding_to_xla_hlo_sharding( @use_cpp_class(xc.NamedSharding) -class NamedSharding(XLACompatibleSharding): +class NamedSharding(sharding.Sharding): r"""A :class:`NamedSharding` expresses sharding using named axes. A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and @@ -261,7 +183,7 @@ class NamedSharding(XLACompatibleSharding): mesh: A :class:`jax.sharding.Mesh` object. spec: A :class:`jax.sharding.PartitionSpec` object. - Example: + Examples: >>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P @@ -284,24 +206,8 @@ def __init__( self.mesh = mesh self.spec = spec self._memory_kind = memory_kind - self._parsed_pspec = _parsed_pspec self._manual_axes = _manual_axes - self._preprocess() - - def _preprocess(self): - # This split exists because you can pass `_parsed_pspec` that has been - # modified from the original. For example: Adding extra dimension to - # axis_resources for vmap handlers. 
In such cases you need to preserve the - # `sync` attribute of parsed pspecs. - # PartitionSpec is inferred from the parsed pspec in this case. - # TODO(yaskatariya): Remove this and replace this with a normalized - # representation of Parsed Pspec - if self._parsed_pspec is None: - self._parsed_pspec, _, _ = prepare_axis_resources( - PartitionSpec() if self.spec is None else self.spec, - "NamedSharding spec", allow_unconstrained_dims=True) - - _check_mesh_resource_axis(self.mesh, self._parsed_pspec) + self._parsed_pspec = preprocess(self.mesh, self.spec, _parsed_pspec) def __repr__(self): mesh_repr = ", ".join(f"'{k}': {v}" for k, v in self.mesh.shape.items()) @@ -326,15 +232,15 @@ def __hash__(self): def __eq__(self, other): if not isinstance(other, NamedSharding): return False - if id(self) == id(other): + if self is other: return True if (self._parsed_pspec != other._parsed_pspec or self.memory_kind != other.memory_kind or self._manual_axes != other._manual_axes): return False - return id(self.mesh) == id(other.mesh) or self.mesh == other.mesh + return self.mesh is other.mesh or self.mesh == other.mesh - def is_compatible_aval(self, aval_shape: Shape): + def check_compatible_aval(self, aval_shape: Shape) -> None: assert self._parsed_pspec is not None if len(aval_shape) < len(self._parsed_pspec): extra_msg = (' For scalars the PartitionSpec should be P()' @@ -390,19 +296,19 @@ def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) -@functools.lru_cache +@util.cache(max_size=128, trace_context_in_key=False) def get_replicated_hlo_sharding(): return xc.HloSharding.replicate() @use_cpp_class(xc.SingleDeviceSharding) -class SingleDeviceSharding(XLACompatibleSharding): +class SingleDeviceSharding(sharding.Sharding): """A :class:`Sharding` that places its data on a single device. Args: device: A single :py:class:`Device`. - Example: + Examples: >>> single_device_sharding = jax.sharding.SingleDeviceSharding( ... 
jax.devices()[0]) @@ -431,7 +337,7 @@ def __hash__(self): def __eq__(self, other): if not isinstance(other, SingleDeviceSharding): return False - if id(self) == id(other): + if self is other: return True return (self._device == other._device and self.memory_kind == other.memory_kind) @@ -466,7 +372,7 @@ def is_fully_addressable(self) -> bool: return True -@functools.lru_cache(maxsize=4096) +@util.cache(max_size=4096, trace_context_in_key=False) def pmap_sharding_devices_indices_map( self, global_shape: Shape) -> Mapping[Device, Index]: self.shard_shape(global_shape) # raises a good error message @@ -475,10 +381,11 @@ def pmap_sharding_devices_indices_map( @use_cpp_class(xc.PmapSharding) -class PmapSharding(XLACompatibleSharding): +class PmapSharding(sharding.Sharding): """Describes a sharding used by :func:`jax.pmap`.""" devices: np.ndarray sharding_spec: sharding_specs.ShardingSpec + _internal_device_list: xc.DeviceList @use_cpp_method() def __init__(self, devices: Sequence[Device] | np.ndarray, @@ -494,15 +401,15 @@ def __reduce__(self): def __eq__(self, other): if not isinstance(other, PmapSharding): return False - if id(self) == id(other): + if self is other: return True return (self.sharding_spec == other.sharding_spec and self.devices.shape == other.devices.shape and - self._internal_device_list == other._internal_device_list) # type: ignore + self._internal_device_list == other._internal_device_list) def __hash__(self): if not hasattr(self, '_hash'): - self._hash = hash((self._internal_device_list, self.sharding_spec)) # type: ignore + self._hash = hash((self._internal_device_list, self.sharding_spec)) return self._hash def __str__(self): @@ -578,7 +485,7 @@ def _device_assignment(self) -> XLADeviceAssignment: @property def memory_kind(self) -> str | None: try: - return self._internal_device_list.default_memory_kind # type: ignore + return self._internal_device_list.default_memory_kind except: return None @@ -597,7 +504,7 @@ def is_fully_replicated(self) -> bool: @functools.cached_property def is_fully_addressable(self) -> bool: - return self._internal_device_list.is_fully_addressable # type: ignore + return self._internal_device_list.is_fully_addressable def shard_shape(self, global_shape: Shape) -> Shape: sharded_dim = None @@ -630,7 +537,7 @@ def _op_sharding_to_pos_sharding( device_assignment: Sequence[xc.Device], memory_kind: str | None = None) -> PositionalSharding: if isinstance(op_sharding, xc.OpSharding): - op_sharding = xc.HloSharding.from_proto(op_sharding) # type: ignore + op_sharding = xc.HloSharding.from_proto(op_sharding) if op_sharding.is_replicated(): return PositionalSharding( @@ -653,7 +560,7 @@ def _op_sharding_to_pos_sharding( return p -@functools.lru_cache(maxsize=4096) +@util.cache(max_size=4096, trace_context_in_key=False) def _positional_sharding_to_xla_hlo_sharding( self, num_dimensions: int) -> xc.HloSharding: if self.shape == (1,) * self.ndim: @@ -675,15 +582,14 @@ def _positional_sharding_to_xla_hlo_sharding( return xc.HloSharding.from_proto(pbuf) -class PositionalSharding(XLACompatibleSharding): +class PositionalSharding(sharding.Sharding): _devices: tuple[xc.Device, ...] 
_memory_kind: str | None _ids: np.ndarray # dtype DeviceIdSet def __init__(self, devices: Sequence[xc.Device] | np.ndarray, *, memory_kind: str | None = None): - if xla_extension_version >= 235: - super().__init__() + super().__init__() if not isinstance(devices, np.ndarray): devices = np.array(devices, dtype='object') if not devices.size: @@ -728,6 +634,13 @@ def replicate(self, axis=None, keepdims=True) -> PositionalSharding: new_ids = self._ids.sum(axis=axis, keepdims=keepdims) # union return self._remake(self._devices, new_ids) + def check_compatible_aval(self, aval_shape: Shape) -> None: + if len(aval_shape) != len(self.shape) and not self.is_fully_replicated: + raise ValueError( + f"Sharding {self} is only valid for values of rank " + f"{len(self.shape)}, but was applied to a value of rank " + f"{len(aval_shape)}") + @classmethod def _remake( cls, devices: tuple[xc.Device, ...], ids: np.ndarray, @@ -750,12 +663,11 @@ def __hash__(self) -> int: def __eq__(self, other) -> bool: if not isinstance(other, PositionalSharding): return False - if id(self) == id(other): + if self is other: return True all_ids_equal = np.array_equal(self._ids,other._ids) mem_kind_equal = self.memory_kind == other.memory_kind - if (id(self._devices) == id(other._devices) and mem_kind_equal and - all_ids_equal): + if self._devices is other._devices and mem_kind_equal and all_ids_equal: return True return (mem_kind_equal and all_ids_equal and self._internal_device_list == other._internal_device_list) @@ -777,7 +689,7 @@ def with_memory_kind(self, kind: str) -> PositionalSharding: def is_fully_replicated(self) -> bool: return self.shape == (1,) * self.ndim - # XLACompatibleSharding interface + # sharding.Sharding interface @property def _device_assignment(self) -> XLADeviceAssignment: @@ -820,21 +732,13 @@ def __eq__(self, other) -> bool: self._ids == other._ids) -@functools.lru_cache(maxsize=4096) -def gspmd_sharding_devices_indices_map( - self, global_shape: Shape) -> Mapping[Device, Index]: - self.shard_shape(global_shape) # raises a good error message - indices = op_sharding_to_indices(self._hlo_sharding, global_shape, - len(self._devices)) - return dict(safe_zip(self._devices, indices)) - - @use_cpp_class(xc.GSPMDSharding) -class GSPMDSharding(XLACompatibleSharding): +class GSPMDSharding(sharding.Sharding): _devices: tuple[Device, ...] 
_hlo_sharding: xc.HloSharding _memory_kind: str | None _device_list: xc.DeviceList | None + _internal_device_list: xc.DeviceList @use_cpp_method() def __init__(self, devices: Sequence[Device], @@ -861,15 +765,15 @@ def _hlo_sharding_hash(self): def __eq__(self, other): if not isinstance(other, GSPMDSharding): return False - if id(self) == id(other): + if self is other: return True return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding) and self.memory_kind == other.memory_kind - and self._internal_device_list == other._internal_device_list) # type: ignore + and self._internal_device_list == other._internal_device_list) def __hash__(self): if not hasattr(self, '_hash'): - self._hash = hash((self._internal_device_list, self._hlo_sharding_hash, # type: ignore + self._hash = hash((self._internal_device_list, self._hlo_sharding_hash, self.memory_kind)) return self._hash @@ -877,7 +781,7 @@ def __repr__(self): mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}' return f'GSPMDSharding({self._hlo_sharding!r}{mem})' - def is_compatible_aval(self, aval_shape: Shape): + def check_compatible_aval(self, aval_shape: Shape) -> None: num_ways_dim_sharded, _ = get_num_ways_dim_sharded(self._hlo_sharding) if len(aval_shape) < len(num_ways_dim_sharded): raise ValueError( @@ -896,9 +800,6 @@ def memory_kind(self) -> str | None: def with_memory_kind(self, kind: str) -> GSPMDSharding: return GSPMDSharding(self._devices, self._hlo_sharding, memory_kind=kind) - def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: - return gspmd_sharding_devices_indices_map(self, global_shape) - @property def _device_assignment(self) -> XLADeviceAssignment: return self._devices @@ -912,7 +813,7 @@ def is_fully_replicated(self) -> bool: @functools.cached_property def is_fully_addressable(self) -> bool: - return self._internal_device_list.is_fully_addressable # type: ignore + return self._internal_device_list.is_fully_addressable @classmethod def get_replicated(cls, device_assignment, *, memory_kind: str | None = None): @@ -1063,7 +964,9 @@ def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False): else: axis_spec = (axis_spec,) axis_specs.append(axis_spec) - return cls(entry, axis_specs) + new_entry = PartitionSpec( + *[tuple(e) if isinstance(e, (list, tuple)) else e for e in entry]) + return cls(new_entry, axis_specs) def __hash__(self): return hash((self.partitions, self.sync)) @@ -1115,6 +1018,25 @@ def __repr__(self): f"sync={self.sync})") +def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()): + # This split exists because you can pass `_parsed_pspec` that has been + # modified from the original. For example: Adding extra dimension to + # axis_resources for vmap handlers. In such cases you need to preserve the + # `sync` attribute of parsed pspecs. + # PartitionSpec is inferred from the parsed pspec in this case. + # TODO(yaskatariya): Remove this and replace this with a normalized + # representation of Parsed Pspec + if parsed_pspec is None: + parsed_pspec = prepare_axis_resources( + PartitionSpec() if spec is None else spec, + "NamedSharding spec", allow_unconstrained_dims=True) + + _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes) + return parsed_pspec + +# fallback for c++ . 
+preprocess_with_manual = preprocess + def prepare_axis_resources(axis_resources, arg_name, allow_unconstrained_dims=False): @@ -1131,23 +1053,20 @@ def prepare_axis_resources(axis_resources, if isinstance(entry, PmapSharding): raise ValueError(f'One of {what} got sharding {entry} which is not ' 'allowed.') - if not isinstance(entry, XLACompatibleSharding): - raise ValueError(f'One of {what} got sharding {entry} which is not a ' - 'subclass of XLACompatibleSharding.') new_entries.append(entry) else: new_entries.append(ParsedPartitionSpec.from_user_input( entry, what, allow_unconstrained_dims=allow_unconstrained_dims)) _check_unique_resources(new_entries, arg_name) - return tree_util.tree_unflatten(treedef, new_entries), new_entries, treedef + return tree_util.tree_unflatten(treedef, new_entries) def _check_unique_resources(axis_resources, arg_name): for arg_axis_resources in axis_resources: if not arg_axis_resources: continue if (is_unspecified_or_auto(arg_axis_resources) or - isinstance(arg_axis_resources, XLACompatibleSharding)): + isinstance(arg_axis_resources, sharding.Sharding)): continue constrained_dims = [d for d in arg_axis_resources if d is not None] resource_counts = collections.Counter( @@ -1184,11 +1103,6 @@ class SPMDAxisContext: def axis_env(self): # All collectives that touch axis_env should remember to set use_global_device_ids # when this context is enabled! - if self.manual_axes != frozenset(self.mesh.axis_names): - raise NotImplementedError( - "Collectives in manually partitioned computations are only supported " - "when all mesh axes are partitioned manually (no partial automatic sharding). " - "Make sure that you mention all mesh axes in axis_resources!") return self.unsafe_axis_env @property @@ -1334,7 +1248,7 @@ def explode_superdims(sizes, dims): def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding, mesh: mesh_lib.Mesh) -> Sequence[ParsedPartitionSpec]: if isinstance(hlo_sharding, xc.OpSharding): - hlo_sharding = xc.HloSharding.from_proto(hlo_sharding) # type: ignore + hlo_sharding = xc.HloSharding.from_proto(hlo_sharding) if hlo_sharding.tuple_elements(): out: list[ParsedPartitionSpec] = [] for s in hlo_sharding.tuple_elements(): @@ -1370,3 +1284,337 @@ def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding, ParsedPartitionSpec('', partitions))] else: raise AssertionError("Unhandled OpSharding type. Please open a bug report!") + + +def _slice_as_tuple(s: slice): + assert s.step is None + return (s.start, s.stop) + + +class NonUniformShardingError(ValueError): + """Raised when sharding is not uniform across processes.""" + + +def get_process_index_and_count( + tensor_sharding: sharding.Sharding, dim: int, ndims: int) -> tuple[int, int]: + """Get current process index and number of unique processes for given dimension. + + This function facilitates mapping of process-level data to individual + devices. Each process can use its index to obtain the data corresponding + to that index. If process level data is sharded on multiple dimensions + this function can be used to build the cross product of indices in + each sharded axis. Processes that need to load the same data will have + the same index. For shardings whose per-process data is not distributed + on a grid, the number of distinct shards will be such that it is possible to + build the target shape while maintaining a "cube" shape of local-process data. 
+ + For example, in case of 4 hosts with sharding distributed like so: + + 1234 + 2143 + + For dim 0 (rows): all processes need to access all rows, so we return (0, 1) + For dim 1 (cols): + process 1 and 2 returns index 0 out of 2 (need cols 0 and 1), + process 3 and 4 returns index 1 out of 2 (need cols 2 and 3). + + On the other hand, for a sharding like: + + 1212 + 3434 + + Dim 0 (rows): process 1 and 2 returns (0, 2), process 3 and 4 returns (1, 2) + Dim 1 (cols): process 1 and 3 returns (0, 2), process 2 and 4 returns (1, 2) + + Note: This function requires sharding to be process uniform in dimension + `dim`: + each process has the same number of addressable indices in that + dimension and all index sets across processes are either disjoint or the same. + + For sharding to be process uniform the addressable shards doesn't need to + form contiguous subtensor, or even a sparse grid and in case of + interleaved high-dimensional tensor it is possible for sharding to be + process uniform only in some dimensions but not others. + + For example: + 1111 and 12 and 1212 and 1212 + 2222 21 2121 1212 + + are all sharding uniform, in both dimensions. However + + 1122 + 2121 + 1121 + 1222 + + is uniform in dimension 0 (both hosts access all rows), but + is not uniform in dimension 1 (host 1 accesses columns: 0, 1, and 3), + while host 2 accesses (0, 1, 2, 3). + + Returns: + A tuple of (index, num_distinct_shards) for the given dimension. + It is guaranteed that `index` will cover 0 to `num_distinct_shards - 1`, + across all processes. + + Raises: + NonUniformShardingError: if the sharding is not process uniform in dimension + `dim`. + """ + # TODO(sandler, yashkatariya): Consider making this function public. + + if (tensor_sharding.is_fully_addressable or + tensor_sharding.is_fully_replicated): + return (0, 1) + num_devices = len(tensor_sharding.device_set) + # Get device to indices map, we don't care about the concrete + # global shape here, only to get the distribution of shards across the tensor + # using (num_devices, num_devices, ...) This is a universal shape that is + # compatible with any mesh with num_devices. + device_map = tensor_sharding.devices_indices_map((num_devices,) * ndims) + + # Get the slices for 'dim' for all devices. + global_slice = {k: v[dim] for k, v in device_map.items()} + + # Contains mapping from process_index to a set of slices for that process. + process_to_slice = collections.defaultdict(set) + # Contains global set of slices across all processes. + all_slices = set() + + # Compute the set of slices for each process and the global set of slices. + for d, v in global_slice.items(): + key = (v.start, v.stop) + process_to_slice[d.process_index].add(key) + all_slices.add(key) + + # Get the set of slices for the current process which we will use to compute + # the index of the current process. + current_pid = next(iter(tensor_sharding.addressable_devices)).process_index + addressable_slices = frozenset(process_to_slice[current_pid]) + + # Verify that all processes have the same number of slices. + slices_per_process = len(addressable_slices) + if any(len(x) != slices_per_process for x in process_to_slice.values()): + raise NonUniformShardingError( + f'{tensor_sharding=} is non-uniform on {dim=} as some processes have ' + 'different number of slices.' + ) + unique_processes = list({frozenset(x) for x in process_to_slice.values()}) + + # After removing duplicate processes all unique slices should + # cover the dimension exactly once. 
If they don't, it means that
+  # the sharding is not uniform.
+  if sum(len(h) for h in unique_processes) != len(all_slices):
+    raise NonUniformShardingError(
+        f'{tensor_sharding=} is non-uniform on {dim=}'
+    )
+  return (unique_processes.index(addressable_slices), len(unique_processes))
+
+
+def local_to_global_shape(
+    sharding: sharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
+  """Computes the global shape given the per-process shape, if possible.
+
+  The returned shape will have the size of the global tensor in that dimension
+  or None, if it is not computable. The latter can happen when the sharding
+  is not uniform along that dimension, e.g. different hosts require
+  different shapes, or if different processes have partial data overlap.
+
+  If at most one dimension is sharded, the shape is always computable.
+  Generally, the global shape is computable for most practical meshes (including
+  topology-aware meshes such as those returned by mesh_utils.create_device_mesh).
+
+  Some examples: Suppose mesh is {'a': 2, 'b': 2, 'c': 2} with 2 devices
+  per host, 4 hosts total. For different specs we get:
+  - P():
+      global_shape = local_shape
+
+  - P(('a', 'b', 'c'), None):
+      global_shape = (4 * local_shape[0], local_shape[1])
+      Note: per device shape is (local_shape[0] / 2, local_shape[1])
+
+  - P(('a', 'b'), None):
+      global_shape = (4 * local_shape[0], local_shape[1])
+      # NB: the same global shape as above, since sharding along the 'c' dimension
+      # happens to be within process, and thus doesn't affect the global shape.
+      # The underlying difference will be in the per *device* shape, which
+      # would be (local_shape[0], local_shape[1]) in this case.
+
+  - P(None, ('a', 'c')):
+      global_shape = (local_shape[0], 2 * local_shape[1])
+      # Per device shape is (local_shape[0], local_shape[1] / 2)
+  - P(('a', 'c'), 'b'):
+      global_shape = (2 * local_shape[0], 2 * local_shape[1])
+      # Per device shape is (local_shape[0] / 2, local_shape[1])
+  - If devices in the Mesh are randomly permuted: For any partition spec
+    which shards more than 1 axis, e.g. P('a', ('b', 'c')):
+      global_shape = (None, None)
+
+  Args:
+    local_shape: per-process (local) shape of the tensor.
+
+  Returns:
+    global_shape with Nones in non-uniform dimensions.
+  """
+
+  global_shape : list[int | None] = [None] * len(local_shape)
+  for i, local_dim in enumerate(local_shape):
+    try:
+      _, shard_count = get_process_index_and_count(
+          sharding, i, ndims=len(local_shape))
+      global_shape[i] = local_dim * shard_count
+    except NonUniformShardingError:
+      global_shape[i] = None
+      continue
+
+  return tuple(global_shape)
+
+
+def num_addressable_indices(
+    tensor_sharding: sharding.Sharding, dim: int, global_shape: Shape) -> int:
+  """Returns the number of indices for a given dimension this host has access to.
+
+  Each host can have multiple devices that span
+  possibly discontiguous slices of data. This function computes the
+  total number of unique indices for dimension `dim` that any of its
+  addressable devices hold.
+
+  In most cases the addressable indices form a sparse grid (and in some
+  cases a subcube), and thus each host will hold the same number of
+  indices for each dimension. However, it is possible to design a mesh whose
+  addressable shards form a complicated pattern. In that case, the returned
+  value is the number of indices that are addressable by at least one device.
+ + For example, suppose the sharding looks like this: (number indicates + the host index) + + 1221 + 1221 + 0000 + + Then on host 1 and 2, both dim 0 (rows), and dim=1 (cols) will have size 2, + while on host 0, dim 0 will have size 1, and dim 1 will have size 4. + + Args: + tensor_sharding: Sharding of the tensor. + dim: dimension along which to compute the number of addressable indices. + global_shape: global shape of the tensor. + + Returns: + The number of indices for dimension `dim` that this host holds. + """ + # TODO(sandler, yashkatariya): Consider making this function public. + addressables = tensor_sharding.addressable_devices_indices_map(global_shape) + addressables = cast(Mapping[sharding.Device, Index], addressables) + num_unique_slices = len({ + _slice_as_tuple(addressable[dim]) for addressable in addressables.values() + }) + shard_size = tensor_sharding.shard_shape(global_shape)[dim] + return shard_size * num_unique_slices + + +def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + new_op_sharding = hlo_sharding.to_proto().clone() + partitions, num_replicas = get_num_ways_dim_sharded(hlo_sharding) + suffix = [] if num_replicas == 1 else [num_replicas] + tad = partitions + [1] * elt_aval.ndim + suffix + new_op_sharding.tile_assignment_dimensions = tad + return xc.HloSharding.from_proto(new_op_sharding) + +def is_single_device_sharding(sharding: sharding.Sharding) -> bool: + # Special case PmapSharding here because PmapSharding maps away an axis + # and needs to be handled separately.test_pjit_single_device_sharding_add + return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding) + +def make_key_array_phys_sharding(aval, sharding): + if is_single_device_sharding(sharding): + return sharding + elif isinstance(sharding, PmapSharding): + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + trailing_sharding = [sharding_specs.NoSharding()] * elt_aval.ndim + phys_sharding_spec = sharding_specs.ShardingSpec( + sharding=(*sharding.sharding_spec.sharding, *trailing_sharding), + mesh_mapping=sharding.sharding_spec.mesh_mapping) + return PmapSharding(devices=sharding.devices, + sharding_spec=phys_sharding_spec) + elif isinstance(sharding, NamedSharding): + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + trailing_spec = [None] * elt_aval.ndim + return NamedSharding( + sharding.mesh, + PartitionSpec(*sharding.spec, *trailing_spec)) + else: + hlos = sharding._to_xla_hlo_sharding(aval.ndim) + return GSPMDSharding( + sharding._device_assignment, physical_hlo_sharding(aval, hlos)) + + +def physical_sharding( + aval, sharding: sharding.Sharding) -> sharding.Sharding: + return make_key_array_phys_sharding(aval, sharding) + + +def get_logical_gspmd_sharding(aval, phys_sharding): + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding( + aval.ndim + elt_aval.ndim) + partitions, num_replicas = get_num_ways_dim_sharded(phys_hlo_sharding) + suffix = [] if num_replicas == 1 else [num_replicas] + # Create logical sharding by cutting off the replicated trailing dims. 
+ logical_op_sharding = phys_hlo_sharding.to_proto().clone() + tad = partitions[:-elt_aval.ndim] + suffix + logical_op_sharding.tile_assignment_dimensions = tad + return GSPMDSharding(phys_sharding._device_assignment, + xc.HloSharding.from_proto(logical_op_sharding)) + +def check_replicated_trailing_dims(sharding: sharding.Sharding, aval): + if isinstance(sharding, PmapSharding): + return + phys_aval = core.physical_aval(aval) + hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim) + partitions, _ = get_num_ways_dim_sharded(hlo_s) + num_trailing_dims = phys_aval.ndim - aval.ndim + if not all(i == 1 for i in partitions[-num_trailing_dims:]): + raise AssertionError( + "The trailing dims of extended dtypes should be replicated. Got" + f" sharding: {sharding}, partitions: {partitions}, " + f"num_trailing_dims: {num_trailing_dims}") + +def logical_sharding(aval, phys_sharding) -> sharding.Sharding: + # The trailing dims should always be replicated. + check_replicated_trailing_dims(phys_sharding, aval) + + if is_single_device_sharding(phys_sharding): + return phys_sharding + elif isinstance(phys_sharding, PmapSharding): + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + logical_sharding_spec = sharding_specs.ShardingSpec( + sharding=phys_sharding.sharding_spec.sharding[:-elt_aval.ndim], + mesh_mapping=phys_sharding.sharding_spec.mesh_mapping) + return PmapSharding(devices=phys_sharding.devices, + sharding_spec=logical_sharding_spec) + elif isinstance(phys_sharding, NamedSharding): + logical_gs = get_logical_gspmd_sharding(aval, phys_sharding) + return _gspmd_to_named_sharding_via_mesh( + logical_gs, phys_sharding.mesh) + else: + return get_logical_gspmd_sharding(aval, phys_sharding) + + +@util.cache() +def create_mesh_pspec_sharding( + mesh: mesh_lib.Mesh, pspec: PartitionSpec | None, parsed_pspec=None, + memory_kind: str | None = None) -> NamedSharding: + if pspec is None: + pspec, parsed_pspec = PartitionSpec(), None + return NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec, + memory_kind=memory_kind) + + +def _gspmd_to_named_sharding_via_mesh( + out_s: GSPMDSharding, mesh: mesh_lib.Mesh) -> NamedSharding: + parsed_pspec = parse_flatten_op_sharding( + out_s._hlo_sharding, mesh)[0] + return create_mesh_pspec_sharding( + mesh, parsed_pspec.get_partition_spec(), parsed_pspec, + out_s.memory_kind) diff --git a/jax/_src/sharding_specs.py b/jax/_src/sharding_specs.py index 02d6276a513a..7092b51ab894 100644 --- a/jax/_src/sharding_specs.py +++ b/jax/_src/sharding_specs.py @@ -55,22 +55,6 @@ ShardingSpec = pmap_lib.ShardingSpec -def _sharding_spec_mesh_shape(self): - sharded_axis_sizes = [] - for sharding in self.sharding: - if isinstance(sharding, NoSharding): - continue - elif isinstance(sharding, Unstacked): - sharded_axis_sizes.append(sharding.size) - elif isinstance(sharding, Chunked): - sharded_axis_sizes.extend(sharding.chunks) - else: - util.assert_unreachable(sharding) - return tuple(sharded_axis_sizes[a.axis] if isinstance(a, ShardedAxis) - else a.replicas - for a in self.mesh_mapping) - - def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray: """Returns NumPy-style indices corresponding to a sharding spec. 
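The process-level helpers added to `jax/_src/sharding_impls.py` above (`get_process_index_and_count`, `local_to_global_shape`, `num_addressable_indices`) are what allow per-process (e.g. MPI-rank-local) data to be mapped to a global array shape. A hedged single-process sketch using the private module as it stands after this patch; in a real multi-process run the sharded dimensions would instead be scaled by the number of distinct per-process index sets:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src import sharding_impls  # private API, shown for illustration only

mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
s = NamedSharding(mesh, P('x'))

# In a single process every sharding is fully addressable, so each dimension
# reports process index 0 out of 1 distinct index set...
print(sharding_impls.get_process_index_and_count(s, dim=0, ndims=2))  # (0, 1)

# ...and the inferred global shape equals the per-process shape.
print(sharding_impls.local_to_global_shape(s, local_shape=(8, 4)))    # (8, 4)
```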
@@ -134,7 +118,6 @@ def _sharding_spec_repr(self): return f'ShardingSpec({self.sharding}, {self.mesh_mapping})' -ShardingSpec.mesh_shape = property(_sharding_spec_mesh_shape) ShardingSpec.indices = _sharding_spec_indices # mypy raises: error: Cannot assign to a method [assignment] ShardingSpec.__repr__ = _sharding_spec_repr # type: ignore diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index 356e5c0d883f..718c86aa9ab7 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -142,10 +142,17 @@ def new_name_stack(name: str = '') -> NameStack: return name_stack -class SourceInfo(NamedTuple): +class SourceInfo: traceback: Traceback | None name_stack: NameStack + # It's slightly faster to use a class with __slots__ than a NamedTuple. + __slots__ = ['traceback', 'name_stack'] + + def __init__(self, traceback: Traceback | None, name_stack: NameStack): + self.traceback = traceback + self.name_stack = name_stack + def replace(self, *, traceback: Traceback | None = None, name_stack: NameStack | None = None) -> SourceInfo: return SourceInfo( @@ -188,7 +195,7 @@ def user_frames(source_info: SourceInfo) -> Iterator[Frame]: # frames, to allow testing this mechanism from tests. traceback = source_info.traceback code, lasti = traceback.raw_frames() if traceback else ([], []) - return (raw_frame_to_frame(code[i], lasti[i]) for i in range(len(code)) # type: ignore + return (raw_frame_to_frame(code[i], lasti[i]) for i in range(len(code)) if is_user_filename(code[i].co_filename)) @functools.lru_cache(maxsize=64) diff --git a/jax/_src/sourcemap.py b/jax/_src/sourcemap.py new file mode 100644 index 000000000000..b54f2193ff26 --- /dev/null +++ b/jax/_src/sourcemap.py @@ -0,0 +1,236 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +An implementation of sourcemaps following `TC39 `_. +""" + +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from dataclasses import dataclass +import json +from typing import Union + +# A Segment encodes how parts in the generated source relate to the original source. +# Each segment is made up of 1, 4 or 5 variable-length fields. For their semantics see +# https://tc39.es/source-map/#mappings-structure +Segment = Union[ + tuple[int], tuple[int, int, int, int], tuple[int, int, int, int, int] +] + +# Mappings are sequences of segments for each line in the generated source. 
+Mappings = Sequence[Sequence[Segment]] + + +@dataclass(frozen=True) +class SourceMap: + version: int + # file: str + # source_root: str + sources: Sequence[str] + sources_content: Sequence[str] + names: Sequence[str] + mappings: Mappings + + @classmethod + def from_json(cls, json_data: str) -> SourceMap: + """Deserialize a source map from JSON.""" + data = json.loads(json_data) + return cls( + version=data["version"], + sources=data["sources"], + sources_content=data["sourcesContent"], + names=data["names"], + mappings=deserialize_mappings(data["mappings"]), + ) + + def to_json(self) -> str: + """Serialize a source map to JSON.""" + data = { + "version": self.version, + "sources": self.sources, + "sourcesContent": self.sources_content, + "names": self.names, + "mappings": serialize_mappings(self.mappings), + } + return json.dumps(data) + + +VLQ_SIGN_MASK = 0x01 +VLQ_MORE_MASK = 0x20 +VLQ_VALUE_MASK = 0x1F +VLQ_VALUE_BITWIDTH = 5 +VLQ_ALPHABET = ( + list(range(ord("A"), ord("Z") + 1)) + + list(range(ord("a"), ord("z") + 1)) + + list(range(ord("0"), ord("9") + 1)) + + [ord("+"), ord("/")] +) + + +def make_vlq_decode_table(): + lookup = {c: d for d, c in enumerate(VLQ_ALPHABET)} + return [lookup.get(i, None) for i in range(256)] + + +VLQ_DECODE_TABLE = make_vlq_decode_table() + + +def decode_vlq(enc: Iterable[int]) -> int: + """Decode a Base-64-VLQ into an integer.""" + enc_iter = iter(enc) + d = VLQ_DECODE_TABLE[next(enc_iter)] + sign = bool(d & VLQ_SIGN_MASK) + value = (d & VLQ_VALUE_MASK) >> 1 + # Compensate for first quantum containing sign as LSB: + shift = -1 + + while d & VLQ_MORE_MASK: + shift += VLQ_VALUE_BITWIDTH + d = VLQ_DECODE_TABLE[next(enc_iter)] + value |= (d & VLQ_VALUE_MASK) << shift + + return -value if sign else value + + +def encode_vlq(value: int) -> bytes: + """Encode an integer into a Base-64-VLQ.""" + # Move sign to LSB + value = ((-value) << 1 | 1) if value < 0 else value << 1 + buf = [] + + while True: + d = value & VLQ_VALUE_MASK + value >>= VLQ_VALUE_BITWIDTH + more = value > 0 + if more: + d |= VLQ_MORE_MASK + buf.append(VLQ_ALPHABET[d]) + if not more: + break + return bytes(buf) + + +def decode_segment(enc: Iterable[int]) -> Segment: + """Decode a sequence of VLQs into a segment.""" + enc_iter = iter(enc) + col = decode_vlq(enc_iter) + try: + source = decode_vlq(enc_iter) + except StopIteration: + # Stopping here is fine (1-segment). + return (col,) + source_line = decode_vlq(enc_iter) + source_col = decode_vlq(enc_iter) + try: + name = decode_vlq(enc_iter) + except StopIteration: + # Stopping here is fine too (4-segment). + return col, source, source_line, source_col + # (5-segment) + return col, source, source_line, source_col, name + + +def encode_segment(seg: Segment) -> bytes: + """Encode a segment into a sequence of VLQs.""" + return b"".join(encode_vlq(value) for value in seg) + + +def deserialize_mappings(mappings_str: str) -> Mappings: + """Decode a string of TC39 mapping data.""" + mappings_bytes = bytes(mappings_str, encoding="ascii") + return [ + list(map(decode_segment, mapping.split(b","))) if mapping else [] + for mapping in mappings_bytes.split(b";") + ] + + +def serialize_mappings(mappings: Mappings) -> str: + """Encode mappings into a string of TC39 mapping data.""" + enc = b";".join( + b",".join(encode_segment(seg) for seg in segs) for segs in mappings + ) + return enc.decode("ascii") + + +class MappingsGenerator: + """MappingsGenerator is a builder API for mappings. 
+ + TC39 mapping data is inconvenient to emit directly: in an effort to compress + data + it encodes most indices using values _relative_ to the previous element. + MappingsGenerator simplifies things by taking absolute indices everywhere. + """ + + def __init__(self): + self._last_col = None + self._last_source = 0 + self._last_source_line = 0 + self._last_source_col = 0 + self._last_name = 0 + self._mappings = [] + self._cur_group = None + + def new_group(self): + """Start a new group (line).""" + self._last_col = 0 + self._cur_group = [] + self._mappings.append(self._cur_group) + + def new_segment(self, *seg): + """Start a new source mapping segment in the current group. + + Args: + *seg: A segment as in TC39, but all indices are absolute. See + https://tc39.es/source-map/#mappings-structure for details. + + Raises: + RuntimeError: If no current group exists. + """ + assert len(seg) >= 1 + group = self._cur_group + if group is None: + raise RuntimeError("No current group. Forgot to call new_group()?") + + col = seg[0] - self._last_col + self._last_col = seg[0] + + if len(seg) == 1: + group.append((col,)) + return + + source = seg[1] - self._last_source + self._last_source = seg[1] + source_line = seg[2] - self._last_source_line + self._last_source_line = seg[2] + source_col = seg[3] - self._last_source_col + self._last_source_col = seg[3] + + if len(seg) == 4: + group.append((col, source, source_line, source_col)) + return + + name = seg[4] - self._last_name + self._last_name = seg[4] + + if len(seg) == 5: + group.append((col, source, source_line, source_col, name)) + return + + assert False, "invalid segment" + + def mappings(self) -> Mappings: + """Return the mapping as a list of segments per line.""" + return self._mappings diff --git a/jax/_src/stages.py b/jax/_src/stages.py index c05007ccacea..874ef8834557 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -30,9 +30,10 @@ """ from __future__ import annotations +import functools from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, NamedTuple, Protocol, Union +from typing import Any, NamedTuple, Protocol, Union, runtime_checkable import warnings import jax @@ -44,7 +45,8 @@ from jax._src import tree_util from jax._src.tree_util import tree_unflatten, keystr from jax._src import util -from jax._src.layout import SpecifiedLayout +from jax._src.sharding_impls import is_unspecified_or_auto +from jax._src.layout import Layout from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib import xla_client as xc @@ -71,7 +73,7 @@ def call(self, *args_flat) -> Sequence[Any]: # TODO(frostig): improve annotation (sequences of arrays/buffers) raise NotImplementedError - def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]: + def input_shardings(self) -> Sequence[jax.sharding.Sharding]: """Flat sequence of input shardings. May raise ``NotImplementedError`` if unavailable, e.g. based on backend, @@ -79,7 +81,7 @@ def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]: """ raise NotImplementedError - def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]: + def output_shardings(self) -> Sequence[jax.sharding.Sharding]: """Flat sequence of output shardings. May raise ``NotImplementedError`` if unavailable, e.g. 
based on backend, @@ -87,12 +89,10 @@ def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]: """ raise NotImplementedError - # Layouts are exposed via jax.experimental.layouts - # TODO(frostig,yashkatariya): expose here when no longer experimental. - def _input_layouts(self): + def input_layouts(self): raise NotImplementedError - def _output_layouts(self): + def output_layouts(self): raise NotImplementedError def as_text(self) -> str: @@ -219,19 +219,19 @@ def xla_extension_executable(self) -> xc.LoadedExecutable: def call(self, *args_flat) -> Sequence[Any]: raise NotImplementedError("must override") - def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]: + def input_shardings(self) -> Sequence[jax.sharding.Sharding]: raise NotImplementedError( "compiled executable carries no input sharding information") - def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]: + def output_shardings(self) -> Sequence[jax.sharding.Sharding]: raise NotImplementedError( "compiled executable carries no output sharding information") - def _input_layouts(self): + def input_layouts(self): raise NotImplementedError( "compiled executable carries no input layout information") - def _output_layouts(self): + def output_layouts(self): raise NotImplementedError( "compiled executable carries no input layout information") @@ -314,9 +314,11 @@ class XlaLowering(Lowering): def hlo(self) -> xc.XlaComputation: """Return an HLO representation of this computation.""" + hlo = self.stablehlo() + m: str | bytes + m = mlir.module_to_bytecode(hlo) return xla_extension.mlir.mlir_module_to_xla_computation( - mlir.module_to_string(self.stablehlo()), - use_tuple_args=self.compile_args["tuple_args"]) + m, use_tuple_args=self.compile_args["tuple_args"]) def mhlo(self) -> ir.Module: """Return an MHLO representation of this computation.""" @@ -368,11 +370,26 @@ def cost_analysis(self) -> dict[str, float]: # -- Public-facing API, plus helpers -@dataclass +@dataclass(frozen=True) class ArgInfo: - aval: core.AbstractValue + _aval: core.AbstractValue donated: bool + @property + def shape(self): + return self._aval.shape # pytype: disable=attribute-error + + @property + def dtype(self): + return self._aval.dtype # pytype: disable=attribute-error + + +@dataclass(frozen=True) +class OutInfo: + shape: tuple[int, ...] 
+ dtype: jax.typing.DTypeLike + sharding: jax.sharding.Sharding | None = None + class Stage: args_info: Any # PyTree of ArgInfo @@ -385,7 +402,7 @@ def in_tree(self) -> tree_util.PyTreeDef: @property def in_avals(self): """Tree of input avals.""" - return tree_util.tree_map(lambda x: x.aval, self.args_info) + return tree_util.tree_map(lambda x: x._aval, self.args_info) @property def donate_argnums(self): @@ -409,6 +426,37 @@ class CompiledCallParams(NamedTuple): out_tree: tree_util.PyTreeDef +class Traced(Stage): + __slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable", + "_args_flat", "_arg_names", "_num_consts"] + + def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree, + lower_callable, args_flat=None, arg_names=None, + num_consts: int = 0): + self.jaxpr = jaxpr + self.args_info = args_info + self.fun_name = fun_name + self._out_tree = out_tree + self._lower_callable = lower_callable + self._args_flat = args_flat + self._arg_names = arg_names + self._num_consts = num_consts + + @property + def out_info(self): + return self._out_tree.unflatten( + [OutInfo(o.shape, o.dtype) for o in self.jaxpr.out_avals]) + + def lower(self, lowering_platforms: tuple[str, ...] | None = None, + _private_parameters: mlir.LoweringParameters | None = None): + if _private_parameters is None: + _private_parameters = mlir.LoweringParameters() + new_callable = functools.partial( + self._lower_callable, lowering_platforms=lowering_platforms, + lowering_parameters=_private_parameters) + return Lowered(new_callable(), self.args_info, self._out_tree) + + class Compiled(Stage): """Compiled representation of a function specialized to types/values. @@ -496,28 +544,33 @@ def runtime_executable(self) -> Any | None: return self._executable.runtime_executable() @property - def input_shardings(self): # PyTree[sharding.XLACompatibleSharding] + def input_shardings(self): # PyTree[sharding.Sharding] shardings_flat = self._executable.input_shardings() + # Some input shardings got DCE'd + if self.in_tree.num_leaves > len(shardings_flat): + iter_shardings_flat = iter(shardings_flat) + shardings_flat = [next(iter_shardings_flat) if i in self._executable._kept_var_idx + else None for i in range(self.in_tree.num_leaves)] return tree_util.tree_unflatten(self.in_tree, shardings_flat) # pytype: disable=attribute-error @property - def output_shardings(self): # PyTree[sharding.XLACompatibleSharding] + def output_shardings(self): # PyTree[sharding.Sharding] shardings_flat = self._executable.output_shardings() return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error - def _input_layouts(self): + def input_layouts(self): layouts_flat = self._executable.input_layouts() - assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat) + assert all(isinstance(l, Layout) for l in layouts_flat) # Some input layouts got DCE'd if self.in_tree.num_leaves > len(layouts_flat): iter_layouts_flat = iter(layouts_flat) layouts_flat = [next(iter_layouts_flat) if i in self._executable._kept_var_idx - else None for i in range(self.in_tree.num_leaves)] + else Layout() for i in range(self.in_tree.num_leaves)] return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error - def _output_layouts(self): + def output_layouts(self): layouts_flat = self._executable.output_layouts() - assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat) + assert all(isinstance(l, Layout) for l in layouts_flat) return tree_util.tree_unflatten(self.out_tree, 
layouts_flat) # pytype: disable=attribute-error @staticmethod @@ -540,18 +593,18 @@ def call(*args, **kwargs): leaf = PytreeLeaf() this_dummy = tree_unflatten(in_tree, [leaf] * in_tree.num_leaves) other_dummy = tree_unflatten( - params.in_tree, [leaf] * params.in_tree.num_leaves) # type: ignore + params.in_tree, [leaf] * params.in_tree.num_leaves) errs = list(tree_util.equality_errors(this_dummy, other_dummy)) msg = [] msg.append( "Function compiled with input pytree does not match the input pytree" f" it was called with. There are {len(errs)} mismatches, including:") for path, thing1, thing2, explanation in errs: - fst, *rest = path # type: ignore + fst, *rest = path base = ['args', 'kwargs'][fst.idx] msg.append( - f" * at {base}{keystr(rest)}, seen {thing2} but now given" # type: ignore - f" {thing1}, so {explanation}") + f" * at {base}{keystr(tuple(rest))}, seen {thing2} but now" + f" given {thing1}, so {explanation}") raise TypeError('\n'.join(msg)) try: out_flat = params.executable.call(*args_flat) @@ -601,11 +654,10 @@ class Lowered(Stage): querying properties of lowered computations across JAX's various lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.). """ - __slots__ = ["args_info", "out_tree", "_lowering", "_no_kwargs"] - + __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs"] + _lowering: XlaLowering args_info: Any # PyTree of ArgInfo out_tree: tree_util.PyTreeDef - _lowering: XlaLowering _no_kwargs: bool def __init__( @@ -614,10 +666,11 @@ def __init__( args_info, # PyTree of ArgInfo out_tree: tree_util.PyTreeDef, no_kwargs: bool = False): + self._lowering = lowering - self._no_kwargs = no_kwargs self.args_info = args_info self.out_tree = out_tree + self._no_kwargs = no_kwargs @classmethod def from_flat_info(cls, @@ -642,6 +695,14 @@ def from_flat_info(cls, out_tree, no_kwargs=no_kwargs) + @property + def out_info(self): # PyTree of OutInfo + out_avals = self._lowering.compile_args["global_out_avals"] + out_shardings = self._lowering.compile_args["out_shardings"] + return self.out_tree.unflatten( + [OutInfo(o.shape, o.dtype, None if is_unspecified_or_auto(s) else s) + for o, s in zip(out_avals, out_shardings)]) + def compile( self, compiler_options: CompilerOptions | None = None) -> Compiled: """Compile, returning a corresponding ``Compiled`` instance.""" @@ -701,8 +762,9 @@ def cost_analysis(self) -> Any | None: return None +@runtime_checkable class Wrapped(Protocol): - """A function ready to be specialized, lowered, and compiled. + """A function ready to be traced, lowered, and compiled. This protocol reflects the output of functions such as ``jax.jit``. Calling it results in JIT (just-in-time) lowering, @@ -714,6 +776,17 @@ def __call__(self, *args, **kwargs): """Executes the wrapped function, lowering and compiling as needed.""" raise NotImplementedError + def trace(self, *args, **kwargs) -> Traced: + """Trace this function explicitly for the given arguments. + + A traced function is staged out of Python and translated to a jaxpr. It is + ready for lowering but not yet lowered. + + Returns: + A ``Traced`` instance representing the tracing. + """ + raise NotImplementedError + def lower(self, *args, **kwargs) -> Lowered: """Lower this function explicitly for the given arguments. 
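The `Traced` stage and the `Wrapped.trace` method introduced above make tracing an explicit first step, so the staging pipeline reads trace, then lower, then compile. A short sketch, assuming a JAX build that already contains this change; the function `f` and the printed reprs are illustrative:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) * 2.0

traced = f.trace(jnp.arange(4.0))   # staged out to a jaxpr, not yet lowered
print(traced.jaxpr)                 # the ClosedJaxpr for sin(x) * 2.0
print(traced.out_info)              # e.g. OutInfo(shape=(4,), dtype=float32, sharding=None)

lowered = traced.lower()            # lower to StableHLO
compiled = lowered.compile()        # compile for the default backend
print(compiled.output_shardings)    # e.g. SingleDeviceSharding(...)
print(compiled(jnp.arange(4.0)))    # same result as calling f directly
```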
diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index c3cf21a873ba..f3a3e61a2ace 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -14,11 +14,11 @@ """Module for discharging state primitives.""" from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses from functools import partial import operator -from typing import Any, Callable, Protocol +from typing import Any, Protocol import numpy as np @@ -96,9 +96,6 @@ def register(f: DischargeRule): _discharge_rules[prim] = f return register -def _has_refs(eqn: core.JaxprEqn): - return any(isinstance(v.aval, AbstractRef) for v in eqn.invars) - def _eval_jaxpr_discharge_state( jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any], *args: Any): @@ -113,8 +110,12 @@ def _eval_jaxpr_discharge_state( if d and isinstance(v.aval, AbstractRef)} for eqn in jaxpr.eqns: - if _has_refs(eqn) and any(id(v.aval) in refs_to_discharge - for v in eqn.invars): + if eqn.primitive is core.mutable_array_p: + [invar], [outvar] = eqn.invars, eqn.outvars + ans = env.read(invar) + refs_to_discharge.add(id(outvar.aval)) + elif (any(id(v.aval) in refs_to_discharge for v in eqn.invars) + or core.internal_mutable_array_effect in eqn.effects ): if eqn.primitive not in _discharge_rules: raise NotImplementedError("No state discharge rule implemented for " f"primitive: {eqn.primitive}") @@ -188,12 +189,19 @@ def _convert_to_array_indexer(indexer: indexing.NDIndexer def _maybe_convert_to_dynamic_slice( indexer: indexing.NDIndexer, -) -> tuple[tuple[Array | int, ...], tuple[int, ...], tuple[int, ...]] | None: +) -> ( + tuple[tuple[Array | int, ...], tuple[Array | int, ...], tuple[int, ...]] + | None +): # An NDIndexer only corresponds to a `dynamic_slice` or `dynamic_update_slice` # if each of the indexers is a `Slice` or a ()-shaped value. if not all(isinstance(i, indexing.Slice) or not np.shape(i) for i in indexer.indices): return None + # TODO(b/329733289): support strided load/store in interpret mode. + for i in indexer.indices: + if isinstance(i, indexing.Slice) and i.stride > 1: + raise NotImplementedError("Unimplemented stride support.") _convert_i32 = lambda x: lax.convert_element_type(x, np.dtype("int32")) starts = tuple( _convert_i32(i.start) if isinstance(i, indexing.Slice) @@ -234,22 +242,23 @@ def _prepend_scatter(x, indexer, val, *, add=False): def _get_discharge(x, idx, tree): indexers = tree_util.tree_unflatten(tree, idx) - if len(indexers) > 1: - raise NotImplementedError("Only single indexer is supported.") - indexer = indexers[0] - if _is_trivial_indexer(indexer): - return x - # If everything in the indexer is a slice or ()-shaped, we can also - # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. - # We need to squeeze out the the 1-sized slices at the end. - if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): - starts, sizes, squeeze_dims = maybe_slice - y = lax_slicing.dynamic_slice(x, starts, sizes) - return lax.squeeze(y, squeeze_dims) - indexer = _convert_to_array_indexer(indexer) - if indexer is None: - return x - return x[None][(np.array(0, 'int32'), *indexer)] + result = x + for indexer in indexers: + if _is_trivial_indexer(indexer): + continue + if indexer is None: + continue + # If everything in the indexer is a slice or ()-shaped, we can also + # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. + # We need to squeeze out the 1-sized slices at the end. 
+ if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): + starts, sizes, squeeze_dims = maybe_slice + y = lax_slicing.dynamic_slice(result, starts, sizes) + result = lax.squeeze(y, squeeze_dims) + else: + indexer = _convert_to_array_indexer(indexer) + result = result[None][(np.array(0, "int32"), *indexer)] + return result def _indexer(idx, indexed_dims): idx_ = iter(idx) @@ -268,23 +277,28 @@ def _swap_discharge_rule( def _swap_discharge(x, val, idx, tree): indexers = tree_util.tree_unflatten(tree, idx) - if len(indexers) > 1: - raise NotImplementedError("Only single indexer is supported.") - indexer = indexers[0] - if _is_trivial_indexer(indexer): - return x, val - # If everything in the indexer is a slice or ()-shaped, we can also - # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. - # We need to squeeze out the the 1-sized slices at the end. - if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): - starts, sizes, squeeze_dims = maybe_slice - x_old = lax_slicing.dynamic_slice(x, starts, sizes) - val = lax.expand_dims(val, squeeze_dims) - y = lax_slicing.dynamic_update_slice(x, val, starts) - return lax.squeeze(x_old, squeeze_dims), y - indexer = _convert_to_array_indexer(indexer) - x_old = _prepend_gather(x, indexer) - return x_old, _prepend_scatter(x, indexer, val) + + result = x + result_val = val + for indexer in indexers: + if _is_trivial_indexer(indexer): + continue + # If everything in the indexer is a slice or ()-shaped, we can also + # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. + # We need to squeeze out the 1-sized slices at the end. + if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): + starts, sizes, squeeze_dims = maybe_slice + result_old = lax_slicing.dynamic_slice(result, starts, sizes) + result_val = lax.expand_dims(result_val, squeeze_dims) + y = lax_slicing.dynamic_update_slice(result, result_val, starts) + result = lax.squeeze(result_old, squeeze_dims) + result_val = y + else: + indexer = _convert_to_array_indexer(indexer) + result_old = _prepend_gather(result, indexer) + result_val = _prepend_scatter(result, indexer, result_val) + result = result_old + return result, result_val @register_discharge_rule(addupdate_p) def _addupdate_discharge_rule( @@ -304,7 +318,7 @@ def _addupdate_discharge(x, val, idx, tree): return x + val # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. - # We need to squeeze out the the 1-sized slices at the end. + # We need to squeeze out the 1-sized slices at the end. 
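# Illustrative sketch of the swap discharge above for the all-slice case:
# the old contents are read with lax.dynamic_slice, the incoming value is
# expanded back to the sliced shape, and the update is written with
# lax.dynamic_update_slice. The variable names mirror _swap_discharge.
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((4, 6))
val = jnp.ones((2,))  # value being swapped into x_ref[1:3, 4]
starts, sizes, squeeze_dims = (1, 4), (2, 1), (1,)

result_old = lax.dynamic_slice(x, starts, sizes)        # previous contents, (2, 1)
result_val = lax.expand_dims(val, squeeze_dims)         # back to (2, 1)
updated = lax.dynamic_update_slice(x, result_val, starts)
returned = lax.squeeze(result_old, squeeze_dims)        # what the swap returns, (2,)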
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, sizes, squeeze_dims = maybe_slice x_old = lax_slicing.dynamic_slice(x, starts, sizes) @@ -476,7 +490,7 @@ def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, else: raise Exception("Invalid fixpoint") del out_unknowns # redundant since it's the same as `in_unknowns` - tracers = tuple(trace.instantiate_const(t) if uk else t # type: ignore + tracers = tuple(trace.instantiate_const(t) if uk else t for t, uk in zip(tracers, in_unknowns)) # We use `partial_eval_jaxpr_stateful` here because it won't remove effectful diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 175d883980c1..acf1c7216240 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -30,42 +30,81 @@ @tree_util.register_pytree_node_class @dataclasses.dataclass class Slice: - """Represents a slice with a dynamic start index and a fixed size.""" - start: Any - size: int + """A slice with a start index and a size. + + Both start index and size can either be static, i.e. known at tracing + and compilation time, or dynamic. + """ + + start: int | Array + size: int | Array + stride: int = 1 def __post_init__(self): - if self.size < 0: - raise ValueError("`size` must not be negative.") + if self.stride < 1: + raise ValueError("`stride` must be >= 1.") + + @property + def is_dynamic_start(self): + return not isinstance(self.start, int) + + @property + def is_dynamic_size(self): + return not isinstance(self.size, int) def tree_flatten(self): # If `start` is statically known, we treat it as static information - if isinstance(self.start, int): - return (), (self.start, self.size) - return (self.start,), (self.size,) + xs = () + data = () + xs += (self.start,) if self.is_dynamic_start else (None,) + data += (None,) if self.is_dynamic_start else (self.start,) + xs += (self.size,) if self.is_dynamic_size else (None,) + data += (None,) if self.is_dynamic_size else (self.size,) + data += (self.stride,) + return xs, data @classmethod def tree_unflatten(cls, aux_data, children) -> Slice: - return cls(*children, *aux_data) + start, size = ( + a if a is not None else b for a, b in zip(children, aux_data[:2]) + ) + return cls(start, size, aux_data[2]) @classmethod def from_slice(cls, slc: slice, size: int) -> Slice: start, stop, step = slc.indices(size) - if step != 1: - raise ValueError(f"slice must have a step of 1 (found: {step})") - return cls(start, max(stop - start, 0)) + if step < 1: + raise ValueError(f"slice must have a step >= 1 (found: {step})") + return cls(start, max((stop - start + step - 1) // step, 0), step) -def dslice(start: int | Array | None, size: int | None = None - ) -> slice | Slice: - """Constructs a `Slice` from a start and a size.""" +def dslice( + start: int | Array | None, + size: int | Array | None = None, + stride: int | None = None, +) -> slice | Slice: + """Constructs a ``Slice`` from a start index and a size. 
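# Illustrative check of the strided size computed in Slice.from_slice above:
# start:stop:step selects ceil((stop - start) / step) elements (clamped at
# zero), which the patch writes as max((stop - start + step - 1) // step, 0).
# The last element touched is start + (size - 1) * step, which is also what
# the NDIndexer bounds check below compares against the dimension size.
def strided_size(start: int, stop: int, step: int) -> int:
  return max((stop - start + step - 1) // step, 0)

assert strided_size(0, 10, 3) == len(range(0, 10, 3)) == 4
assert strided_size(2, 3, 5) == len(range(2, 3, 5)) == 1
assert strided_size(7, 4, 2) == len(range(7, 4, 2)) == 0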
+ + The semantics of ``dslice`` mirror those of the builtin ``slice`` type: + + * ``dslice(None)`` is ``:`` + * ``dslice(j)`` is ``:j`` + * ``dslice(i, j)`` is ``i:i+j`` + * ``dslice(i, j, stride)`` is ``i:i+j:stride`` + """ if start is None: return slice(None) + if stride is None: + stride = 1 + if not isinstance(stride, int): + raise ValueError("Non-static stride in `dslice`") if size is None: if not isinstance(start, int): raise ValueError("Non-static `dslice`") - return Slice(0, start) - return Slice(start, size) + return Slice(0, start, stride) + return Slice(start, size, stride) + + ds = dslice # Handy alias @@ -97,6 +136,7 @@ class NDIndexer: indices: tuple[DimIndexer, ...] shape: tuple[int, ...] int_indexer_shape: tuple[int, ...] + # Off by default to avoid doing validation during pytree operations. validate: bool = False def __post_init__(self): @@ -113,10 +153,12 @@ def __post_init__(self): if value := _maybe_concretize(start): if value >= s: raise ValueError(f"Out of bound slice: start={value}, dim={s}.") - if value + idx.size > s: - raise ValueError( - f"Out of bound slice: start={value}, size={idx.size}, dim={s}." - ) + if size := _maybe_concretize(idx.size): + if value + (size - 1) * idx.stride >= s: + raise ValueError( + f"Out of bound slice: start={value}, size={size}," + f" stride={idx.stride}, dim={s}." + ) continue # The shape of indexer integers should be broadcastable up to the # int_indexer_shape of the whole NDIndexer @@ -141,6 +183,10 @@ def __post_init__(self): f" {self.int_indexer_shape=}" ) from e + @property + def is_dynamic_size(self): + return any(isinstance(i, Slice) and i.is_dynamic_size for i in self.indices) + def tree_flatten(self): flat_idx, idx_tree = tree_util.tree_flatten(self.indices) return flat_idx, (idx_tree, self.shape, self.int_indexer_shape) @@ -149,47 +195,62 @@ def tree_flatten(self): def tree_unflatten(cls, data, flat_idx): idx_tree, shape, int_indexer_shape = data indices = tree_util.tree_unflatten(idx_tree, flat_idx) - return NDIndexer(tuple(indices), shape, int_indexer_shape) + return cls(tuple(indices), shape, int_indexer_shape) @classmethod def from_indices_shape(cls, indices, shape) -> NDIndexer: if not isinstance(indices, tuple): + # TODO(slebedev): Consider requiring `indices` to be a Sequence. indices = (indices,) - if len(indices) == 1 and indices[0] is ...: - indices = (slice(None),) * len(shape) - if any(idx is ... for idx in indices): - # TODO(sharadmv,mattjj): support patterns that include ellipsis in them - # e.g. x[0, ..., 1]. - raise NotImplementedError("Ellipsis in indexer not supported yet.") + + indices = list(indices) + if num_ellipsis := sum(idx is ... for idx in indices): + if num_ellipsis > 1: + raise ValueError("Only one ellipsis is supported.") + # Expand ... so that `indices` has the same length as `shape`. + ip = indices.index(...) + indices[ip:ip+1] = [slice(None)] * (len(shape) - len(indices) + 1) if len(indices) > len(shape): + indices = tuple(indices) raise ValueError("`indices` must not be longer than `shape`: " f"{indices=}, {shape=}") - # Pad out indices with slice(None) - indices = [*indices, *[slice(None)] * (len(shape) - len(indices))] - # Convert all `slice`s to `Slice`s - indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice) - else i for i, s in zip(indices, shape)) + elif len(indices) < len(shape): + # Pad `indices` to have the same length as `shape`. + indices.extend([slice(None)] * (len(shape) - len(indices))) + + # Promote all builtin `slice`s to `Slice`. 
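# Illustrative, pure-Python sketch of the ellipsis handling added to
# NDIndexer.from_indices_shape: a single `...` expands into enough
# `slice(None)`s to make the indexer as long as the shape, and a short
# indexer is padded with trailing `slice(None)`s.
def expand_indices(indices, shape):
  indices = list(indices)
  if sum(idx is ... for idx in indices) > 1:
    raise ValueError("Only one ellipsis is supported.")
  if any(idx is ... for idx in indices):
    ip = indices.index(...)
    indices[ip:ip + 1] = [slice(None)] * (len(shape) - len(indices) + 1)
  if len(indices) > len(shape):
    raise ValueError("`indices` must not be longer than `shape`")
  indices.extend([slice(None)] * (len(shape) - len(indices)))
  return tuple(indices)

# x[0, ..., 1] on a rank-4 ref becomes (0, :, :, 1):
assert expand_indices((0, ..., 1), (2, 3, 4, 5)) == (0, slice(None), slice(None), 1)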
+ indices = tuple( + Slice.from_slice(i, s) if isinstance(i, slice) else i + for i, s in zip(indices, shape)) + is_int_indexing = [not isinstance(i, Slice) for i in indices] - other_indexers, int_indexers = partition_list(is_int_indexing, indices) - indexer_shapes = [core.get_aval(i).shape for i in int_indexers] - if indexer_shapes: + if any(is_int_indexing): + other_indexers, int_indexers = partition_list(is_int_indexing, indices) + indexer_shapes = tuple(core.get_aval(i).shape for i in int_indexers) try: - bcast_shape = np.broadcast_shapes(*indexer_shapes) + int_indexer_shape = np.broadcast_shapes(*indexer_shapes) except ValueError as e: # Raise a nicer error than the NumPy one. - raise ValueError("Cannot broadcast shapes for indexing: " - f"{tuple(a for a in indexer_shapes)}") from e + raise ValueError( + f"Cannot broadcast shapes for indexing: {indexer_shapes}") from e + + # Here we use the `broadcast_to` primitive instead of composing lax + # primitives together because it is easier to lower in targets like + # Triton/Mosaic. + # + # The local import avoids a circular dependency between primitives + # and this module. + from jax._src.state import primitives as sp # pytype: disable=import-error + int_indexers = [ + sp.broadcast_to(i, int_indexer_shape) for i in int_indexers + ] + indices = tuple(merge_lists(is_int_indexing, other_indexers, int_indexers)) else: - bcast_shape = () - # Here we use the `broadcast_to` primitive instead of composing lax - # primitives together because it is easier to lower in targets like - # Triton/Mosaic. - from jax._src.state import primitives as sp # pytype: disable=import-error - int_indexers = [sp.broadcast_to(i, bcast_shape) for i in int_indexers] - indices = merge_lists(is_int_indexing, other_indexers, int_indexers) - return NDIndexer(tuple(indices), shape, bcast_shape, validate=True) - - def get_indexer_shape(self) -> tuple[int, ...]: + int_indexer_shape = () + + return cls(indices, shape, int_indexer_shape, validate=True) + + def get_indexer_shape(self) -> tuple[int | Array, ...]: _, slice_indexers, _ = unpack_ndindexer(self) slice_shape = [s.size for s in slice_indexers] # In NDIndexers, the int_indexer_shape is *always* at the front of the diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 4b55792f7388..224c2f351ae3 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -148,11 +148,12 @@ def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> No def _shape_after_indexing( - shape: tuple[int, ...], indexers: tuple[indexing.NDIndexer, ...] -) -> tuple[int, ...]: + shape: tuple[int | Array, ...], indexers: tuple[indexing.NDIndexer, ...] +) -> tuple[int | Array, ...]: for indexer in indexers: # Run some simple checks that all the indexers have consistent shapes - assert indexer.shape == shape, (indexer.shape, shape) + if not indexer.is_dynamic_size: + assert indexer.shape == shape, (indexer.shape, shape) shape = indexer.get_indexer_shape() return shape @@ -192,7 +193,7 @@ def _swap_abstract_eval(ref_aval: AbstractRef, if ref_aval.dtype != val_aval.dtype: raise ValueError("Invalid dtype for `swap`. " f"Ref dtype: {ref_aval.dtype}. " - f"Value shape: {val_aval.dtype}. ") + f"Value dtype: {val_aval.dtype}. 
") out_aval = core.ShapedArray(expected_out_shape, ref_aval.dtype) else: if indexers: @@ -239,12 +240,26 @@ def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice start, size = slc.start, slc.size if isinstance(start, core.Var): start_str = core.pp_var(start, context) - end_str = f'{start_str}+{size}' + size_str = ( + core.pp_var(size, context) + if isinstance(size, core.Var) + else str(size) + ) + return f'{start_str}:{start_str}+{size_str}' else: - start_str = '' if start == 0 else str(start) - end = start + size - end_str = '' if end == dim else str(end) - return f'{start_str}:{end_str}' + start_str = str(start) + if start == 0: + start_str = '' + if isinstance(size, core.Var): + size_str = core.pp_var(size, context) + if start_str: + return f'{start_str}:{start_str}+{size_str}' + else: + return f':{size_str}' + else: + end = start + size + end_str = '' if end == dim else str(end) + return f'{start_str}:{end_str}' def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer ) -> pp.Doc: @@ -328,7 +343,7 @@ def _get_jvp(primals: list[Any], tangents: list[Any], **params: Any): ref_tangent, *_ = tangents assert isinstance(ref_tangent.aval, AbstractRef) return (get_p.bind(ref_primal, *idx, **params), - get_p.bind(ref_tangent, *idx, **params)) # type: ignore[arg-type] + get_p.bind(ref_tangent, *idx, **params)) ad.primitive_jvps[get_p] = _get_jvp def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any): @@ -337,8 +352,8 @@ def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any): ref_tangent, x_tangent, *_ = tangents assert isinstance(ref_tangent.aval, AbstractRef) x_tangent = ad_util.instantiate(x_tangent) - return (swap_p.bind(ref_primal, x_primal, *idx, **params), # type: ignore[arg-type] - swap_p.bind(ref_tangent, x_tangent, *idx, **params)) # type: ignore[arg-type] + return (swap_p.bind(ref_primal, x_primal, *idx, **params), + swap_p.bind(ref_tangent, x_tangent, *idx, **params)) ad.primitive_jvps[swap_p] = _swap_jvp def addupdate_jvp_rule(primals: list[Any], tangents: list[Any], **params: Any): diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index dabc567ae6ea..303e4da0b5bf 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -18,7 +18,7 @@ from collections.abc import Sequence import dataclasses import math -from typing import Any, Generic, TypeVar, Union +from typing import Any, Union from jax._src import core from jax._src import effects @@ -74,8 +74,6 @@ class AccumEffect(RefEffect): # ## `Ref`s -Aval = TypeVar("Aval", bound=core.AbstractValue) - @dataclasses.dataclass class RefIndexer: ref_or_view: Any @@ -97,7 +95,11 @@ class RefView: indexers: tuple[indexing.NDIndexer, ...] @property - def shape(self) -> tuple[int, ...]: + def is_dynamic_size(self): + return self.indexers[-1].is_dynamic_size + + @property + def shape(self) -> tuple[int | Array, ...]: assert ( len(self.indexers) > 0 ), "Should not be able to create a trivial RefView" @@ -124,7 +126,7 @@ def __setitem__(self, slc, value): # We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs. 
-class AbstractRef(core.AbstractValue, Generic[Aval]): +class AbstractRef(core.AbstractValue): __slots__ = ["inner_aval"] def __init__(self, inner_aval: core.AbstractValue): @@ -212,6 +214,6 @@ def get_ref_state_effects( def shaped_array_ref(shape: tuple[int, ...], dtype, weak_type: bool = False, - named_shape = None) -> AbstractRef[core.AbstractValue]: + named_shape = None) -> AbstractRef: return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type, named_shape=named_shape)) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index ea5cc76fcf9f..edd769aff5c6 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -11,72 +11,73 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# pyformat: disable from __future__ import annotations -from collections.abc import Generator, Iterable, Sequence -from contextlib import contextmanager, ExitStack +import collections +from collections.abc import Callable, Generator, Iterable, Sequence +from contextlib import ExitStack, contextmanager import datetime -import inspect -import io import functools from functools import partial +import inspect import math -import re import os +import re +import sys import tempfile import textwrap -from typing import Any, Callable +from typing import Any import unittest import warnings import zlib from absl.testing import absltest from absl.testing import parameterized - -import numpy as np -import numpy.random as npr - import jax from jax import lax -from jax.experimental.compilation_cache import compilation_cache -from jax._src.interpreters import mlir -from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten from jax._src import api -from jax._src import pjit as pjit_lib from jax._src import config from jax._src import core from jax._src import dispatch -from jax._src import linear_util as lu from jax._src import dtypes as _dtypes +from jax._src import linear_util as lu from jax._src import monitoring +from jax._src import pjit as pjit_lib from jax._src import stages -from jax._src.lib import xla_client as xc +from jax._src import xla_bridge from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm +from jax._src.interpreters import mlir from jax._src.interpreters import pxla +from jax._src.lib import xla_client as xc from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact -from jax._src.util import unzip2 from jax._src.public_test_util import ( # noqa: F401 _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads, - check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, tolerance) -from jax._src import xla_bridge + check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance) +from jax._src.util import unzip2 +from jax.experimental.compilation_cache import compilation_cache +from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten +import numpy as np +import numpy.random as npr # This submodule includes private test utilities that are not exported to # jax.test_util. Functionality appearing here is for internal use only, and # may be changed or removed at any time and without any deprecation cycle. -_TEST_DUT = config.DEFINE_string( +_TEST_DUT = config.string_flag( 'jax_test_dut', '', help= 'Describes the device under test in case special consideration is required.' 
) -NUM_GENERATED_CASES = config.DEFINE_integer( +NUM_GENERATED_CASES = config.int_flag( 'jax_num_generated_cases', int(os.getenv('JAX_NUM_GENERATED_CASES', '10')), help='Number of generated cases to test') -_MAX_CASES_SAMPLING_RETRIES = config.DEFINE_integer( +_MAX_CASES_SAMPLING_RETRIES = config.int_flag( 'max_cases_sampling_retries', int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')), 'Number of times a failed test sample should be retried. ' @@ -84,23 +85,23 @@ 'sampling process is terminated.' ) -_SKIP_SLOW_TESTS = config.DEFINE_bool( +_SKIP_SLOW_TESTS = config.bool_flag( 'jax_skip_slow_tests', config.bool_env('JAX_SKIP_SLOW_TESTS', False), help='Skip tests marked as slow (> 5 sec).' ) -_TEST_TARGETS = config.DEFINE_string( +_TEST_TARGETS = config.string_flag( 'test_targets', os.getenv('JAX_TEST_TARGETS', ''), 'Regular expression specifying which tests to run, called via re.search on ' 'the test name. If empty or unspecified, run all tests.' ) -_EXCLUDE_TEST_TARGETS = config.DEFINE_string( +_EXCLUDE_TEST_TARGETS = config.string_flag( 'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''), 'Regular expression specifying which tests NOT to run, called via re.search ' 'on the test name. If empty or unspecified, run all tests.' ) -TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.DEFINE_bool( +TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag( 'jax_test_with_persistent_compilation_cache', config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False), help='If enabled, the persistent compilation cache will be enabled for all ' @@ -179,12 +180,39 @@ def check_eq(xs, ys, err_msg=''): tree_all(tree_map(assert_close, xs, ys)) +# TODO(yashkatariya): Make this context manager check for deprecation message +# in OSS. +@contextmanager +def unaccelerate_getattr_deprecation(module, name): + message, prev_attr = module._deprecations[name] + module._deprecations[name] = (message, getattr(module, f"_deprecated_{name}")) + try: + yield + finally: + module._deprecations[name] = (message, prev_attr) + @contextmanager -def capture_stdout() -> Generator[Callable[[], str], None, None]: - with unittest.mock.patch('sys.stdout', new_callable=io.StringIO) as fp: - def _read() -> str: - return fp.getvalue() - yield _read +def capture_stdout() -> Generator[Callable[[], str | None], None, None]: + """Context manager to capture all stdout output.""" + + # The encoding should also work on windows, the default doesn't necessarily. + with tempfile.NamedTemporaryFile(mode="w+", delete=True, encoding='utf-8') as f: + original_stdout = os.dup(sys.stdout.fileno()) + os.dup2(f.fileno(), sys.stdout.fileno()) + + # if get_stdout returns not it means we are not done capturing + # stdout. it should only be used after the context has exited. + captured = None + get_stdout: Callable[[], str | None] = lambda: captured + + try: + yield get_stdout + finally: + # Python also has its own buffers, make sure everything is flushed. 
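# Illustrative, standalone sketch of the file-descriptor-level capture
# technique used by capture_stdout above: output written by C++/XLA code
# bypasses sys.stdout, so the redirection has to happen at the OS level via
# os.dup/os.dup2 rather than by swapping sys.stdout for a StringIO. The
# getter deliberately returns None until the context has exited.
import os
import sys
import tempfile
from contextlib import contextmanager

@contextmanager
def fd_capture_stdout():
  with tempfile.NamedTemporaryFile(mode="w+", delete=True, encoding="utf-8") as f:
    saved_fd = os.dup(sys.stdout.fileno())
    os.dup2(f.fileno(), sys.stdout.fileno())
    captured = None
    try:
      yield lambda: captured
    finally:
      sys.stdout.flush()  # flush Python's own buffer before reading the file
      f.seek(0)
      captured = f.read()
      os.dup2(saved_fd, sys.stdout.fileno())
      os.close(saved_fd)

# Usage: call the getter only after the `with` block has exited.
with fd_capture_stdout() as get_output:
  print("captured at the fd level")
print(get_output())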
+ sys.stdout.flush() + f.seek(0) + captured = f.read() + os.dup2(original_stdout, sys.stdout.fileno()) @contextmanager @@ -228,18 +256,18 @@ def count_primitive_compiles(): @contextmanager def count_device_put_fast_path_hit(): - original_fn = xc.copy_array_to_devices_with_sharding + original_fn = xc.batched_copy_array_to_devices_with_sharding count = [0] - def copy_array_to_devices_with_sharding_and_count(*args, **kwargs): + def batched_copy_array_to_devices_with_sharding_and_count(*args, **kwargs): count[0] += 1 return original_fn(*args, **kwargs) - xc.copy_array_to_devices_with_sharding = copy_array_to_devices_with_sharding_and_count + xc.batched_copy_array_to_devices_with_sharding = batched_copy_array_to_devices_with_sharding_and_count try: yield count finally: - xc.copy_array_to_devices_with_sharding = original_fn + xc.batched_copy_array_to_devices_with_sharding = original_fn @contextmanager @@ -257,6 +285,20 @@ def pjit_lower_and_count(*args, **kwargs): finally: pjit_lib._pjit_lower = original_pjit_lower +@contextmanager +def count_cached_compilation_cache_miss(): + original_cached_compilation = pxla._cached_compilation + count = [0] + + def cached_compilation_and_count(*args, **kwargs): + count[0] += 1 + return original_cached_compilation(*args, **kwargs) + + pxla._cached_compilation = cached_compilation_and_count + try: + yield count + finally: + pxla._cached_compilation = original_cached_compilation @contextmanager def count_jit_tracing_cache_miss(): @@ -274,6 +316,21 @@ def create_pjit_jaxpr_and_count(*args): finally: pjit_lib._create_pjit_jaxpr = original_create_pjit_jaxpr +@contextmanager +def count_jit_infer_params_cache_miss(): + original_infer_params_impl = pjit_lib._infer_params_impl + count = collections.defaultdict(int) + + def infer_params_impl_and_count(fun, *args, **kw): + count[fun] += 1 + return original_infer_params_impl(fun, *args, **kw) + + pjit_lib._infer_params_impl = infer_params_impl_and_count + try: + yield count + finally: + pjit_lib._infer_params_impl = original_infer_params_impl + @contextmanager def count_aot_jit_cpp_cache_miss(): @@ -346,9 +403,8 @@ def supported_dtypes(): if device_under_test() == "tpu": types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16, np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64} - elif device_under_test() == "iree": - types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16, - np.uint32, np.float32} + elif device_under_test() == "METAL": + types = {np.int32, np.uint32, np.float32} else: types = {np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, @@ -359,7 +415,7 @@ def supported_dtypes(): return types def is_device_rocm(): - return xla_bridge.get_backend().platform_version.startswith('rocm') + return 'rocm' in xla_bridge.get_backend().platform_version def is_device_cuda(): return 'cuda' in xla_bridge.get_backend().platform_version @@ -417,12 +473,20 @@ def is_device_tpu(version: int | None = None, variant: str = "") -> bool: return "v5 lite" in device_kind return expected_version in device_kind +def is_cuda_compute_capability_at_least(capability: str) -> bool: + if not is_device_cuda(): + return False + d, *_ = jax.local_devices(backend="gpu") + return d.compute_capability >= capability + def _get_device_tags(): """returns a set of tags defined for the device under test""" if is_device_rocm(): device_tags = {device_under_test(), "rocm"} elif is_device_cuda(): device_tags = {device_under_test(), "cuda"} + elif device_under_test() == "METAL": + 
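# Illustrative sketch of the wrap-and-count pattern that the count_*
# context managers above all share: temporarily replace a module attribute
# with a wrapper that bumps a counter, and restore the original on exit.
# `module` and `name` are placeholders, not a specific JAX internal.
from contextlib import contextmanager

@contextmanager
def count_calls(module, name):
  original = getattr(module, name)
  count = [0]  # one-element list so the nested wrapper can mutate it

  def wrapper(*args, **kwargs):
    count[0] += 1
    return original(*args, **kwargs)

  setattr(module, name, wrapper)
  try:
    yield count
  finally:
    setattr(module, name, original)

# Usage (hypothetical target):
#   with count_calls(some_module, "some_function") as n:
#     ...exercise the code under test...
#   assert n[0] == expected_call_count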
device_tags = {device_under_test(), "gpu"} else: device_tags = {device_under_test()} return device_tags @@ -469,8 +533,13 @@ def device_supports_buffer_donation(): ) +@contextmanager def set_host_platform_device_count(nr_devices: int): - """Returns a closure that undoes the operation.""" + """Context manager to set host platform device count if not specified by user. + + This should only be used by tests at the top level in setUpModule(); it will + not work correctly if applied to individual test cases. + """ prev_xla_flags = os.getenv("XLA_FLAGS") flags_str = prev_xla_flags or "" # Don't override user-specified device count, or other XLA flags. @@ -479,13 +548,14 @@ def set_host_platform_device_count(nr_devices: int): f" --xla_force_host_platform_device_count={nr_devices}") # Clear any cached backends so new CPU backend will pick up the env var. xla_bridge.get_backend.cache_clear() - def undo(): + try: + yield + finally: if prev_xla_flags is None: del os.environ["XLA_FLAGS"] else: os.environ["XLA_FLAGS"] = prev_xla_flags xla_bridge.get_backend.cache_clear() - return undo def skip_on_flag(flag_name, skip_value): @@ -514,6 +584,18 @@ def wrap(func_or_class): return wrap +def is_running_under_pytest(): + return "pytest" in sys.modules + + +def skip_under_pytest(reason: str): + """A decorator for test methods to skip the test when run under pytest.""" + reason = "Running under pytest: " + reason + def skip(test_method): + return unittest.skipIf(is_running_under_pytest(), reason)(test_method) + return skip + + def format_test_name_suffix(opname, shapes, dtypes): arg_descriptions = (format_shape_dtype_string(shape, dtype) for shape, dtype in zip(shapes, dtypes)) @@ -976,6 +1058,38 @@ def wrapper(*args, **kw): return fun(*args, **kw) return wrapper +@contextmanager +def global_config_context(**kwds): + original_config = {} + try: + for key, value in kwds.items(): + original_config[key] = config._read(key) + config.update(key, value) + yield + finally: + for key, value in original_config.items(): + config.update(key, value) + + +class NotPresent: + def __repr__(self): + return "" + + +@contextmanager +def assert_global_configs_unchanged(): + starting_config = jax.config.values.copy() + yield + ending_config = jax.config.values + + if starting_config == ending_config: + return + differing = {k: (starting_config.get(k, NotPresent()), ending_config.get(k, NotPresent())) + for k in (starting_config.keys() | ending_config.keys()) + if (k not in starting_config or k not in ending_config + or starting_config[k] != ending_config[k])} + raise AssertionError(f"Test changed global config values. Differing values are: {differing}") + class JaxTestCase(parameterized.TestCase): """Base class for JAX tests including numerical checks and boilerplate.""" @@ -996,26 +1110,20 @@ class JaxTestCase(parameterized.TestCase): def setUp(self): super().setUp() - self._original_config = {} - for key, value in self._default_config.items(): - self._original_config[key] = config._read(key) - config.update(key, value) + self.enter_context(assert_global_configs_unchanged()) # We use the adler32 hash for two reasons. # a) it is deterministic run to run, unlike hash() which is randomized. # b) it returns values in int32 range, which RandomState requires. 
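# Illustrative sketch of how the context-manager form of
# set_host_platform_device_count above is intended to be used: enter it at
# module scope in setUpModule() and close it in tearDownModule(), so the
# XLA_FLAGS override spans the whole test module. This assumes the
# post-patch context-manager behavior.
from contextlib import ExitStack
from jax._src import test_util as jtu

_exit_stack = ExitStack()

def setUpModule():
  # Pretend the host platform has 8 devices for every test in this module.
  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))

def tearDownModule():
  _exit_stack.close()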
self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode())) - def tearDown(self): - for key, value in self._original_config.items(): - config.update(key, value) - super().tearDown() - @classmethod def setUpClass(cls): + cls._compilation_cache_exit_stack = ExitStack() + stack = cls._compilation_cache_exit_stack + stack.enter_context(global_config_context(**cls._default_config)) + if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: - cls._compilation_cache_exit_stack = ExitStack() - stack = cls._compilation_cache_exit_stack stack.enter_context(config.enable_compilation_cache(True)) stack.enter_context(config.raise_persistent_cache_errors(True)) stack.enter_context(config.persistent_cache_min_compile_time_secs(0)) @@ -1027,8 +1135,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: - cls._compilation_cache_exit_stack.close() + cls._compilation_cache_exit_stack.close() def rng(self): return self._rng @@ -1285,7 +1392,8 @@ def integer(self): @_cached_property def all_integer(self): - return self.supported([np.int8, np.int16, np.int32, np.int64]) + return self.supported([ + _dtypes.int4, np.int8, np.int16, np.int32, np.int64]) @_cached_property def unsigned(self): @@ -1293,7 +1401,8 @@ def unsigned(self): @_cached_property def all_unsigned(self): - return self.supported([np.uint8, np.uint16, np.uint32, np.uint64]) + return self.supported([ + _dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64]) @_cached_property def complex(self): @@ -1408,7 +1517,7 @@ def register_event_duration_listener(callback): def set_env(**kwargs): """Context manager to temporarily set/unset one or more environment variables. - Example: + Examples: >>> import os >>> os.environ['my_var'] = 'original' @@ -1463,11 +1572,11 @@ def complex_plane_sample(dtype, size_re=10, size_im=None): >>> print(complex_plane_sample(np.complex64, 0, 3)) [[-inf -infj 0. -infj inf -infj] [-inf-3.4028235e+38j 0.-3.4028235e+38j inf-3.4028235e+38j] - [-inf-2.0000052e+00j 0.-2.0000052e+00j inf-2.0000052e+00j] + [-inf-2.0000000e+00j 0.-2.0000000e+00j inf-2.0000000e+00j] [-inf-1.1754944e-38j 0.-1.1754944e-38j inf-1.1754944e-38j] [-inf+0.0000000e+00j 0.+0.0000000e+00j inf+0.0000000e+00j] [-inf+1.1754944e-38j 0.+1.1754944e-38j inf+1.1754944e-38j] - [-inf+2.0000052e+00j 0.+2.0000052e+00j inf+2.0000052e+00j] + [-inf+2.0000000e+00j 0.+2.0000000e+00j inf+2.0000000e+00j] [-inf+3.4028235e+38j 0.+3.4028235e+38j inf+3.4028235e+38j] [-inf +infj 0. 
+infj inf +infj]] @@ -1477,16 +1586,18 @@ def complex_plane_sample(dtype, size_re=10, size_im=None): finfo = np.finfo(dtype) def make_axis_points(size): - logmin = np.log10(abs(finfo.min)) - logtiny = np.log10(finfo.tiny) - logmax = np.log10(finfo.max) + prec_dps_ratio = 3.3219280948873626 + logmin = logmax = finfo.maxexp / prec_dps_ratio + logtiny = finfo.minexp / prec_dps_ratio axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype) with warnings.catch_warnings(): # Silence RuntimeWarning: overflow encountered in cast warnings.simplefilter("ignore") - axis_points[1:size + 1] = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype) - axis_points[-size - 1:-1] = np.logspace(logtiny, logmax, size, dtype=finfo.dtype) + half_neg_line = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype) + half_line = -half_neg_line[::-1] + axis_points[-size - 1:-1] = half_line + axis_points[1:size + 1] = half_neg_line if size > 1: axis_points[1] = finfo.min @@ -1508,3 +1619,392 @@ def make_axis_points(size): imag_part = imag_part.reshape((3 + 2 * size_im, -1)).repeat(3 + 2 * size_re, 1) return real_part + imag_part + + +class vectorize_with_mpmath(np.vectorize): + """Same as numpy.vectorize but using mpmath backend for function evaluation. + """ + + map_float_to_complex = dict(float16='complex32', float32='complex64', float64='complex128', float128='complex256', longdouble='clongdouble') + map_complex_to_float = {v: k for k, v in map_float_to_complex.items()} + + float_prec = dict( + # float16=11, + float32=24, + float64=53, + # float128=113, + # longdouble=113 + ) + + float_minexp = dict( + float16=-14, + float32=-126, + float64=-1022, + float128=-16382 + ) + + float_maxexp = dict( + float16=16, + float32=128, + float64=1024, + float128=16384, + ) + + def __init__(self, *args, **kwargs): + mpmath = kwargs.pop('mpmath', None) + if mpmath is None: + raise ValueError('vectorize_with_mpmath: no mpmath argument specified') + self.extra_prec_multiplier = kwargs.pop('extra_prec_multiplier', 0) + self.extra_prec = kwargs.pop('extra_prec', 0) + self.mpmath = mpmath + self.contexts = dict() + self.contexts_inv = dict() + for fp_format, prec in self.float_prec.items(): + ctx = self.mpmath.mp.clone() + ctx.prec = prec + self.contexts[fp_format] = ctx + self.contexts_inv[ctx] = fp_format + + super().__init__(*args, **kwargs) + + def get_context(self, x): + if isinstance(x, (np.ndarray, np.floating, np.complexfloating)): + fp_format = str(x.dtype) + fp_format = self.map_complex_to_float.get(fp_format, fp_format) + return self.contexts[fp_format] + raise NotImplementedError(f'get mpmath context from {type(x).__name__} instance') + + def nptomp(self, x): + """Convert numpy array/scalar to an array/instance of mpmath number type. 
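# Illustrative check of the exponent arithmetic in make_axis_points above:
# dividing a binary exponent by log2(10) (~3.3219, the prec_dps_ratio
# constant) converts it to a decimal exponent, so maxexp / log2(10) is
# approximately log10(finfo.max) and minexp / log2(10) is approximately
# log10(finfo.tiny), without taking log10 of the extreme values themselves.
import numpy as np

LOG2_OF_10 = 3.3219280948873626
finfo = np.finfo(np.float32)
assert abs(finfo.maxexp / LOG2_OF_10 - np.log10(float(finfo.max))) < 0.01
assert abs(finfo.minexp / LOG2_OF_10 - np.log10(float(finfo.tiny))) < 0.01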
+ """ + if isinstance(x, np.ndarray): + return np.fromiter(map(self.nptomp, x.flatten()), dtype=object).reshape(x.shape) + elif isinstance(x, np.floating): + mpmath = self.mpmath + ctx = self.get_context(x) + prec, rounding = ctx._prec_rounding + if np.isposinf(x): + return ctx.make_mpf(mpmath.libmp.finf) + elif np.isneginf(x): + return ctx.make_mpf(mpmath.libmp.fninf) + elif np.isnan(x): + return ctx.make_mpf(mpmath.libmp.fnan) + elif np.isfinite(x): + mantissa, exponent = np.frexp(x) + man = int(np.ldexp(mantissa, prec)) + exp = int(exponent - prec) + r = ctx.make_mpf(mpmath.libmp.from_man_exp(man, exp, prec, rounding)) + assert ctx.isfinite(r), r._mpf_ + return r + elif isinstance(x, np.complexfloating): + re, im = self.nptomp(x.real), self.nptomp(x.imag) + return re.context.make_mpc((re._mpf_, im._mpf_)) + raise NotImplementedError(f'convert {type(x).__name__} instance to mpmath number type') + + def mptonp(self, x): + """Convert mpmath instance to numpy array/scalar type. + """ + if isinstance(x, np.ndarray) and x.dtype.kind == 'O': + x_flat = x.flatten() + item = x_flat[0] + ctx = item.context + fp_format = self.contexts_inv[ctx] + if isinstance(item, ctx.mpc): + dtype = getattr(np, self.map_float_to_complex[fp_format]) + elif isinstance(item, ctx.mpf): + dtype = getattr(np, fp_format) + else: + dtype = None + if dtype is not None: + return np.fromiter(map(self.mptonp, x_flat), dtype=dtype).reshape(x.shape) + elif isinstance(x, self.mpmath.ctx_mp.mpnumeric): + ctx = x.context + if isinstance(x, ctx.mpc): + fp_format = self.contexts_inv[ctx] + dtype = getattr(np, self.map_float_to_complex[fp_format]) + r = dtype().reshape(1).view(getattr(np, fp_format)) + r[0] = self.mptonp(x.real) + r[1] = self.mptonp(x.imag) + return r.view(dtype)[0] + elif isinstance(x, ctx.mpf): + fp_format = self.contexts_inv[ctx] + dtype = getattr(np, fp_format) + if ctx.isfinite(x): + sign, man, exp, bc = self.mpmath.libmp.normalize(*x._mpf_, *ctx._prec_rounding) + assert bc >= 0, (sign, man, exp, bc, x._mpf_) + if exp + bc < self.float_minexp[fp_format]: + return -ctx.zero if sign else ctx.zero + if exp + bc > self.float_maxexp[fp_format]: + return ctx.ninf if sign else ctx.inf + man = dtype(-man if sign else man) + r = np.ldexp(man, exp) + assert np.isfinite(r), (x, r, x._mpf_, man) + return r + elif ctx.isnan(x): + return dtype(np.nan) + elif ctx.isinf(x): + return dtype(-np.inf if x._mpf_[0] else np.inf) + raise NotImplementedError(f'convert {type(x)} instance to numpy floating point type') + + def __call__(self, *args, **kwargs): + mp_args = [] + context = None + for a in args: + if isinstance(a, (np.ndarray, np.floating, np.complexfloating)): + mp_args.append(self.nptomp(a)) + if context is None: + context = self.get_context(a) + else: + assert context is self.get_context(a) + else: + mp_args.append(a) + + extra_prec = int(context.prec * self.extra_prec_multiplier) + self.extra_prec + with context.extraprec(extra_prec): + result = super().__call__(*mp_args, **kwargs) + + if isinstance(result, tuple): + lst = [] + for r in result: + if ((isinstance(r, np.ndarray) and r.dtype.kind == 'O') + or isinstance(r, self.mpmath.ctx_mp.mpnumeric)): + r = self.mptonp(r) + lst.append(r) + return tuple(lst) + + if ((isinstance(result, np.ndarray) and result.dtype.kind == 'O') + or isinstance(result, self.mpmath.ctx_mp.mpnumeric)): + return self.mptonp(result) + + return result + + +class numpy_with_mpmath: + """Namespace of universal functions on numpy arrays that use mpmath + backend for evaluation and return numpy 
arrays as outputs. + """ + + _provides = [ + 'abs', 'absolute', 'sqrt', 'exp', 'expm1', 'exp2', + 'log', 'log1p', 'log10', 'log2', + 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan', + 'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh', + 'square', 'positive', 'negative', 'conjugate', 'sign', 'sinc', + 'normalize', + ] + + _mp_names = dict( + abs='absmin', absolute='absmin', + log='ln', + arcsin='asin', arccos='acos', arctan='atan', + arcsinh='asinh', arccosh='acosh', arctanh='atanh', + ) + + def __init__(self, mpmath, extra_prec_multiplier=0, extra_prec=0): + self.mpmath = mpmath + + for name in self._provides: + mp_name = self._mp_names.get(name, name) + + if hasattr(self, name): + op = getattr(self, name) + else: + + def op(x, mp_name=mp_name): + return getattr(x.context, mp_name)(x) + + setattr(self, name, vectorize_with_mpmath(op, mpmath=mpmath, extra_prec_multiplier=extra_prec_multiplier, extra_prec=extra_prec)) + + # The following function methods operate on mpmath number instances. + # The corresponding function names must be listed in + # numpy_with_mpmath._provides list. + + def square(self, x): + return x * x + + def positive(self, x): + return x + + def negative(self, x): + return -x + + def sqrt(self, x): + ctx = x.context + if isinstance(x, ctx.mpc): + # Workaround mpmath 1.3 bug in sqrt(+-inf+-infj) evaluation (see mpmath/mpmath#776). + # TODO(pearu): remove this function when mpmath 1.4 or newer + # will be the required test dependency. + if ctx.isinf(x.imag): + return ctx.make_mpc((ctx.inf._mpf_, x.imag._mpf_)) + return ctx.sqrt(x) + + def expm1(self, x): + return x.context.expm1(x) + + def log1p(self, x): + ctx = x.context + if isinstance(x, ctx.mpc): + # Workaround mpmath 1.3 bug in log(+-inf+-infj) evaluation (see mpmath/mpmath#774). + # TODO(pearu): remove this function when mpmath 1.4 or newer + # will be the required test dependency. + if ctx.isinf(x.real) and ctx.isinf(x.imag): + pi = ctx.pi + if x.real > 0 and x.imag > 0: + return ctx.make_mpc((x.real._mpf_, (pi / 4)._mpf_)) + if x.real > 0 and x.imag < 0: + return ctx.make_mpc((x.real._mpf_, (-pi / 4)._mpf_)) + if x.real < 0 and x.imag < 0: + return ctx.make_mpc(((-x.real)._mpf_, (-3 * pi / 4)._mpf_)) + if x.real < 0 and x.imag > 0: + return ctx.make_mpc(((-x.real)._mpf_, (3 * pi / 4)._mpf_)) + return ctx.log1p(x) + + def tan(self, x): + ctx = x.context + if isinstance(x, ctx.mpc): + # Workaround mpmath 1.3 bug in tan(+-inf+-infj) evaluation (see mpmath/mpmath#781). + # TODO(pearu): remove this function when mpmath 1.4 or newer + # will be the required test dependency. + if ctx.isinf(x.imag) and (ctx.isinf(x.real) or ctx.isfinite(x.real)): + if x.imag > 0: + return ctx.make_mpc((ctx.zero._mpf_, ctx.one._mpf_)) + return ctx.make_mpc((ctx.zero._mpf_, (-ctx.one)._mpf_)) + if ctx.isinf(x.real) and ctx.isfinite(x.imag): + return ctx.make_mpc((ctx.nan._mpf_, ctx.nan._mpf_)) + return ctx.tan(x) + + def tanh(self, x): + ctx = x.context + if isinstance(x, ctx.mpc): + # Workaround mpmath 1.3 bug in tanh(+-inf+-infj) evaluation (see mpmath/mpmath#781). + # TODO(pearu): remove this function when mpmath 1.4 or newer + # will be the required test dependency. 
+ if ctx.isinf(x.imag) and (ctx.isinf(x.real) or ctx.isfinite(x.real)): + if x.imag > 0: + return ctx.make_mpc((ctx.zero._mpf_, ctx.one._mpf_)) + return ctx.make_mpc((ctx.zero._mpf_, (-ctx.one)._mpf_)) + if ctx.isinf(x.real) and ctx.isfinite(x.imag): + return ctx.make_mpc((ctx.nan._mpf_, ctx.nan._mpf_)) + return ctx.tanh(x) + + def log2(self, x): + return x.context.ln(x) / x.context.ln2 + + def log10(self, x): + return x.context.ln(x) / x.context.ln10 + + def exp2(self, x): + return x.context.exp(x * x.context.ln2) + + def arcsin(self, x): + ctx = x.context + if isinstance(x, ctx.mpc): + # Workaround mpmath 1.3 bug in asin(+-inf+-infj) evaluation (see + # mpmath/mpmath#793). + # TODO(pearu): remove the if-block below when mpmath 1.4 or + # newer will be the required test dependency. + pi = ctx.pi + inf = ctx.inf + zero = ctx.zero + if ctx.isinf(x.real): + sign_real = -1 if x.real < 0 else 1 + real = sign_real * pi / (4 if ctx.isinf(x.imag) else 2) + imag = -inf if x.imag < 0 else inf + return ctx.make_mpc((real._mpf_, imag._mpf_)) + elif ctx.isinf(x.imag): + return ctx.make_mpc((zero._mpf_, x.imag._mpf_)) + + # On branch cut, mpmath.mp.asin returns different value compared + # to mpmath.fp.asin and numpy.arcsin (see + # mpmath/mpmath#786). The following if-block ensures + # compatibiliy with numpy.arcsin. + if x.real > 1 and x.imag == 0: + return ctx.asin(x).conjugate() + + return ctx.asin(x) + + def arcsinh(self, x): + ctx = x.context + + if isinstance(x, ctx.mpc): + # Workaround mpmath 1.3 bug in asinh(+-inf+-infj) evaluation + # (see mpmath/mpmath#749). + # TODO(pearu): remove the if-block below when mpmath 1.4 or + # newer will be the required test dependency. + pi = ctx.pi + inf = ctx.inf + zero = ctx.zero + if ctx.isinf(x.imag): + sign_imag = -1 if x.imag < 0 else 1 + real = -inf if x.real < 0 else inf + imag = sign_imag * pi / (4 if ctx.isinf(x.real) else 2) + return ctx.make_mpc((real._mpf_, imag._mpf_)) + elif ctx.isinf(x.real): + return ctx.make_mpc((x.real._mpf_, zero._mpf_)) + + # On branch cut, mpmath.mp.asinh returns different value + # compared to mpmath.fp.asinh and numpy.arcsinh (see + # mpmath/mpmath#786). The following if-block ensures + # compatibiliy with numpy.arcsinh. + if x.real == 0 and x.imag < -1: + return (-ctx.asinh(x)).conjugate() + return ctx.asinh(x) + + def normalize(self, exact, reference, value): + """Normalize reference and value using precision defined by the + difference of exact and reference. 
+ """ + def worker(ctx, s, e, r, v): + ss, sm, se, sbc = s._mpf_ + es, em, ee, ebc = e._mpf_ + rs, rm, re, rbc = r._mpf_ + vs, vm, ve, vbc = v._mpf_ + + if not (ctx.isfinite(e) and ctx.isfinite(r) and ctx.isfinite(v)): + return r, v + + me = min(se, ee, re, ve) + + # transform mantissa parts to the same exponent base + sm_e = sm << (se - me) + em_e = em << (ee - me) + rm_e = rm << (re - me) + vm_e = vm << (ve - me) + + # find matching higher and non-matching lower bits of e and r + sm_b = bin(sm_e)[2:] if sm_e else '' + em_b = bin(em_e)[2:] if em_e else '' + rm_b = bin(rm_e)[2:] if rm_e else '' + vm_b = bin(vm_e)[2:] if vm_e else '' + + m = max(len(sm_b), len(em_b), len(rm_b), len(vm_b)) + em_b = '0' * (m - len(em_b)) + em_b + rm_b = '0' * (m - len(rm_b)) + rm_b + + c1 = 0 + for b0, b1 in zip(em_b, rm_b): + if b0 != b1: + break + c1 += 1 + c0 = m - c1 + + # truncate r and v mantissa + rm_m = rm_e >> c0 + vm_m = vm_e >> c0 + + # normalized r and v + nr = ctx.make_mpf((rs, rm_m, -c1, len(bin(rm_m)) - 2)) if rm_m else (-ctx.zero if rs else ctx.zero) + nv = ctx.make_mpf((vs, vm_m, -c1, len(bin(vm_m)) - 2)) if vm_m else (-ctx.zero if vs else ctx.zero) + + return nr, nv + + ctx = exact.context + scale = abs(exact) + if isinstance(exact, ctx.mpc): + rr, rv = worker(ctx, scale, exact.real, reference.real, value.real) + ir, iv = worker(ctx, scale, exact.imag, reference.imag, value.imag) + return ctx.make_mpc((rr._mpf_, ir._mpf_)), ctx.make_mpc((rv._mpf_, iv._mpf_)) + elif isinstance(exact, ctx.mpf): + return worker(ctx, scale, exact, reference, value) + else: + assert 0 # unreachable diff --git a/jax/_src/third_party/numpy/LICENSE b/jax/_src/third_party/numpy/LICENSE deleted file mode 100644 index f7a64e5174e4..000000000000 --- a/jax/_src/third_party/numpy/LICENSE +++ /dev/null @@ -1,30 +0,0 @@ -Copyright (c) 2005-2019, NumPy Developers. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - disclaimer in the documentation and/or other materials provided - with the distribution. - - * Neither the name of the NumPy Developers nor the names of any - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file diff --git a/jax/_src/third_party/numpy/__init__.py b/jax/_src/third_party/numpy/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/jax/_src/third_party/numpy/linalg.py b/jax/_src/third_party/numpy/linalg.py deleted file mode 100644 index 7c8ffe9d5276..000000000000 --- a/jax/_src/third_party/numpy/linalg.py +++ /dev/null @@ -1,211 +0,0 @@ -import numpy as np - -import jax.numpy as jnp -import jax.numpy.linalg as la -from jax._src.numpy.util import check_arraylike, implements - - -def _isEmpty2d(arr): - # check size first for efficiency - return arr.size == 0 and np.prod(arr.shape[-2:]) == 0 - - -def _assertNoEmpty2d(*arrays): - for a in arrays: - if _isEmpty2d(a): - raise np.linalg.LinAlgError("Arrays cannot be empty") - - -def _assertRankAtLeast2(*arrays): - for a in arrays: - if a.ndim < 2: - raise np.linalg.LinAlgError( - '%d-dimensional array given. Array must be ' - 'at least two-dimensional' % a.ndim) - - -def _assertNdSquareness(*arrays): - for a in arrays: - m, n = a.shape[-2:] - if m != n: - raise np.linalg.LinAlgError( - 'Last 2 dimensions of the array must be square') - - -def _assert2d(*arrays): - for a in arrays: - if a.ndim != 2: - raise ValueError(f'{a.ndim}-dimensional array given. ' - 'Array must be two-dimensional') - - -@implements(np.linalg.cond) -def cond(x, p=None): - check_arraylike('jnp.linalg.cond', x) - _assertNoEmpty2d(x) - if p in (None, 2): - s = la.svd(x, compute_uv=False) - return s[..., 0] / s[..., -1] - elif p == -2: - s = la.svd(x, compute_uv=False) - r = s[..., -1] / s[..., 0] - else: - _assertRankAtLeast2(x) - _assertNdSquareness(x) - invx = la.inv(x) - r = la.norm(x, ord=p, axis=(-2, -1)) * la.norm(invx, ord=p, axis=(-2, -1)) - - # Convert nans to infs unless the original array had nan entries - orig_nan_check = jnp.full_like(r, ~jnp.isnan(r).any()) - nan_mask = jnp.logical_and(jnp.isnan(r), ~jnp.isnan(x).any(axis=(-2, -1))) - r = jnp.where(orig_nan_check, jnp.where(nan_mask, jnp.inf, r), r) - return r - - -@implements(np.linalg.tensorinv) -def tensorinv(a, ind=2): - check_arraylike('jnp.linalg.tensorinv', a) - a = jnp.asarray(a) - oldshape = a.shape - prod = 1 - if ind > 0: - invshape = oldshape[ind:] + oldshape[:ind] - for k in oldshape[ind:]: - prod *= k - else: - raise ValueError("Invalid ind argument.") - a = a.reshape(prod, -1) - ia = la.inv(a) - return ia.reshape(*invshape) - - -@implements(np.linalg.tensorsolve) -def tensorsolve(a, b, axes=None): - check_arraylike('jnp.linalg.tensorsolve', a, b) - a = jnp.asarray(a) - b = jnp.asarray(b) - an = a.ndim - if axes is not None: - allaxes = list(range(0, an)) - for k in axes: - allaxes.remove(k) - allaxes.insert(an, k) - - a = a.transpose(allaxes) - - Q = a.shape[-(an - b.ndim):] - - prod = 1 - for k in Q: - prod *= k - - a = a.reshape(-1, prod) - b = b.ravel() - - res = jnp.asarray(la.solve(a, b)) - res = res.reshape(Q) - - return res - - -@implements(np.linalg.multi_dot) -def multi_dot(arrays, *, precision=None): - check_arraylike('jnp.linalg.multi_dot', *arrays) - n = len(arrays) - # optimization only makes sense for len(arrays) > 2 - if n < 2: - raise ValueError("Expecting at least two arrays.") - elif n == 2: - return jnp.dot(arrays[0], arrays[1], precision=precision) - - arrays = [jnp.asarray(a) for a in arrays] - - # save original ndim to reshape the result array into the proper form later - ndim_first, ndim_last = arrays[0].ndim, arrays[-1].ndim - # Explicitly convert vectors to 2D arrays to keep the logic of the internal - # 
_multi_dot_* functions as simple as possible. - if arrays[0].ndim == 1: - arrays[0] = jnp.atleast_2d(arrays[0]) - if arrays[-1].ndim == 1: - arrays[-1] = jnp.atleast_2d(arrays[-1]).T - _assert2d(*arrays) - - # _multi_dot_three is much faster than _multi_dot_matrix_chain_order - if n == 3: - result = _multi_dot_three(*arrays, precision) - else: - order = _multi_dot_matrix_chain_order(arrays) - result = _multi_dot(arrays, order, 0, n - 1, precision) - - # return proper shape - if ndim_first == 1 and ndim_last == 1: - return result[0, 0] # scalar - elif ndim_first == 1 or ndim_last == 1: - return result.ravel() # 1-D - else: - return result - - -def _multi_dot_three(A, B, C, precision): - """ - Find the best order for three arrays and do the multiplication. - For three arguments `_multi_dot_three` is approximately 15 times faster - than `_multi_dot_matrix_chain_order` - """ - a0, a1b0 = A.shape - b1c0, c1 = C.shape - # cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1 - cost1 = a0 * b1c0 * (a1b0 + c1) - # cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1 - cost2 = a1b0 * c1 * (a0 + b1c0) - - if cost1 < cost2: - return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision) - else: - return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision) - - -def _multi_dot_matrix_chain_order(arrays, return_costs=False): - """ - Return a jnp.array that encodes the optimal order of mutiplications. - The optimal order array is then used by `_multi_dot()` to do the - multiplication. - Also return the cost matrix if `return_costs` is `True` - The implementation CLOSELY follows Cormen, "Introduction to Algorithms", - Chapter 15.2, p. 370-378. Note that Cormen uses 1-based indices. - cost[i, j] = min([ - cost[prefix] + cost[suffix] + cost_mult(prefix, suffix) - for k in range(i, j)]) - """ - n = len(arrays) - # p stores the dimensions of the matrices - # Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50] - p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]] - # m is a matrix of costs of the subproblems - # m[i,j]: min number of scalar multiplications needed to compute A_{i..j} - m = np.zeros((n, n), dtype=np.double) - # s is the actual ordering - # s[i, j] is the value of k at which we split the product A_i..A_j - s = np.empty((n, n), dtype=np.intp) - - for l in range(1, n): - for i in range(n - l): - j = i + l - m[i, j] = jnp.inf - for k in range(i, j): - q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1] - if q < m[i, j]: - m[i, j] = q - s[i, j] = k # Note that Cormen uses 1-based index - - return (s, m) if return_costs else s - - -def _multi_dot(arrays, order, i, j, precision): - """Actually do the multiplication with the given order.""" - if i == j: - return arrays[i] - else: - return jnp.dot(_multi_dot(arrays, order, i, order[i, j], precision), - _multi_dot(arrays, order, order[i, j] + 1, j, precision), - precision=precision) diff --git a/jax/_src/third_party/scipy/interpolate.py b/jax/_src/third_party/scipy/interpolate.py index 0634eb4fd6a1..1eb726ea863c 100644 --- a/jax/_src/third_party/scipy/interpolate.py +++ b/jax/_src/third_party/scipy/interpolate.py @@ -1,10 +1,9 @@ from itertools import product -import scipy.interpolate as osp_interpolate from jax.numpy import (asarray, broadcast_arrays, can_cast, empty, nan, searchsorted, where, zeros) from jax._src.tree_util import register_pytree_node -from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, implements +from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact def 
_ndim_coords_from_arrays(points, ndim=None): @@ -31,15 +30,30 @@ def _ndim_coords_from_arrays(points, ndim=None): return points -@implements( - osp_interpolate.RegularGridInterpolator, - lax_description=""" -In the JAX version, `bounds_error` defaults to and must always be `False` since no -bound error may be raised under JIT. - -Furthermore, in contrast to SciPy no input validation is performed. -""") class RegularGridInterpolator: + """Interpolate points on a regular rectangular grid. + + JAX implementation of :func:`scipy.interpolate.RegularGridInterpolator`. + + Args: + points: length-N sequence of arrays specifying the grid coordinates. + values: N-dimensional array specifying the grid values. + method: interpolation method, either ``"linear"`` or ``"nearest"``. + bounds_error: not implemented by JAX + fill_value: value returned for points outside the grid, defaults to NaN. + + Returns: + interpolator: callable interpolation object. + + Examples: + >>> points = (jnp.array([1, 2, 3]), jnp.array([4, 5, 6])) + >>> values = jnp.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]]) + >>> interpolate = RegularGridInterpolator(points, values, method='linear') + + >>> query_points = jnp.array([[1.5, 4.5], [2.2, 5.8]]) + >>> interpolate(query_points) + Array([30., 64.], dtype=float32) + """ # Based on SciPy's implementation which in turn is originally based on an # implementation by Johannes Buchner @@ -76,7 +90,6 @@ def __init__(self, self.grid = tuple(asarray(p) for p in points) self.values = values - @implements(osp_interpolate.RegularGridInterpolator.__call__, update_doc=False) def __call__(self, xi, method=None): method = self.method if method is None else method if method not in ("linear", "nearest"): @@ -151,7 +164,7 @@ def _find_indices(self, xi): lambda obj: ((obj.grid, obj.values, obj.fill_value), (obj.method, obj.bounds_error)), lambda aux, children: RegularGridInterpolator( - *children[:2], # type: ignore[index] + *children[:2], *aux, - *children[2:]), # type: ignore[index] + *children[2:]), ) diff --git a/jax/_src/third_party/scipy/linalg.py b/jax/_src/third_party/scipy/linalg.py index 80f0656572c4..dce4df1fb817 100644 --- a/jax/_src/third_party/scipy/linalg.py +++ b/jax/_src/third_party/scipy/linalg.py @@ -1,13 +1,10 @@ from __future__ import annotations -from typing import Callable - -import scipy.linalg +from collections.abc import Callable from jax import jit, lax import jax.numpy as jnp from jax._src.numpy.linalg import norm -from jax._src.numpy.util import implements from jax._src.scipy.linalg import rsf2csf, schur from jax._src.typing import ArrayLike, Array @@ -40,20 +37,51 @@ def _inner_loop(i, p_F_minden): return lax.fori_loop(1, N, _outer_loop, (F, minden)) -_FUNM_LAX_DESCRIPTION = """\ -The array returned by :py:func:`jax.scipy.linalg.funm` may differ in dtype -from the array returned by py:func:`scipy.linalg.funm`. Specifically, in cases -where all imaginary parts of the array values are close to zero, the SciPy -function may return a real-valued array, whereas the JAX implementation will -return a complex-valued array. - -Additionally, unlike the SciPy implementation, when ``disp=True`` no warning -will be printed if the error in the array output is estimated to be large. -""" -@implements(scipy.linalg.funm, lax_description=_FUNM_LAX_DESCRIPTION) def funm(A: ArrayLike, func: Callable[[Array], Array], disp: bool = True) -> Array | tuple[Array, Array]: + """Evaluate a matrix-valued function + + JAX implementation of :func:`scipy.linalg.funm`. 
+ + Args: + A: array of shape ``(N, N)`` for which the function is to be computed. + func: Callable object that takes a scalar argument and returns a scalar result. + Represents the function to be evaluated over the eigenvalues of A. + disp: If True (default), error information is not returned. Unlike SciPy's version, JAX + does not attempt to display information at runtime. + + Returns: + Array of same shape as ``A``, containing the result of ``func`` evaluated on the + eigenvalues of ``A``. + If ``disp`` is False, an estimate of the error is returned alongside the result. + + Notes: + The returned dtype of JAX's implementation may differ from that of SciPy; + specifically, in cases where all imaginary parts of the array values are + close to zero, the SciPy function may return a real-valued array, whereas + the JAX implementation will return a complex-valued array. + + Examples: + Applying an arbitrary matrix function: + + >>> A = jnp.array([[1., 2.], [3., 4.]]) + >>> def func(x): + ... return jnp.sin(x) + 2 * jnp.cos(x) + >>> jax.scipy.linalg.funm(A, func) # doctest: +SKIP + Array([[ 1.2452652 +0.j, -0.3701772 +0.j], + [-0.55526584+0.j, 0.6899995 +0.j]], dtype=complex64) + + Comparing two ways of computing the matrix exponential: + + >>> expA_1 = jax.scipy.linalg.funm(A, jnp.exp) + >>> expA_2 = jax.scipy.linalg.expm(A) + >>> jnp.allclose(expA_1, expA_2, rtol=1E-4) + Array(True, dtype=bool) + """ A_arr = jnp.asarray(A) if A_arr.ndim != 2 or A_arr.shape[0] != A_arr.shape[1]: raise ValueError('expected square array_like input') diff --git a/jax/_src/third_party/scipy/signal_helper.py b/jax/_src/third_party/scipy/signal_helper.py index 5205c9079b61..4a021675804d 100644 --- a/jax/_src/third_party/scipy/signal_helper.py +++ b/jax/_src/third_party/scipy/signal_helper.py @@ -2,7 +2,6 @@ from __future__ import annotations -import scipy.signal as osp_signal from typing import Any import warnings @@ -43,10 +42,20 @@ def _triage_segments(window: ArrayLike | str | tuple[Any, ...], nperseg: int | N if isinstance(window, (str, tuple)): nperseg_int = input_length if nperseg is None else int(nperseg) if nperseg_int > input_length: - warnings.warn(f'nperseg = {nperseg_int} is greater than input length ' - f' = {input_length}, using nperseg = {input_length}') + warnings.warn(f'nperseg={nperseg_int} is greater than {input_length=},' + f' using nperseg={input_length}') nperseg_int = input_length - win = jnp.array(osp_signal.get_window(window, nperseg_int), dtype=dtype) + if window == 'hann': + # Implement the default case without scipy + win = jnp.array([1.0]) if nperseg_int == 1 else jnp.sin(jnp.linspace(0, jnp.pi, nperseg_int, endpoint=False)) ** 2 + else: + # TODO(jakevdp): implement get_window() in JAX to remove optional scipy dependency + try: + from scipy.signal import get_window + except ImportError as err: + raise ImportError(f"scipy must be available to use {window=}") from err + win = get_window(window, nperseg_int) + win = jnp.array(win, dtype=dtype) else: win = jnp.asarray(window) nperseg_int = win.size if nperseg is None else int(nperseg) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index bc5a8c0f2c10..d92c0686ad10 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -19,30 +19,43 @@ import base64 import collections.abc +from collections.abc import Callable, Sequence import dataclasses import 
functools import io import os import time -from typing import Any, Callable +from typing import Any -from absl import flags import jax from jax import core from jax._src import config +from jax._src import sharding_impls +from jax._src.interpreters import mlir from jax._src.lib import tpu from jax._src.lib import xla_client -from jax._src.lib.mlir.dialects import hlo -from jax._src.interpreters import mlir -from jax._src import sharding_impls from jax.interpreters import xla from jaxlib.mlir import ir -from jaxlib.mlir.dialects import stablehlo -import numpy as np +from jaxlib.mlir.dialects import mhlo +from jaxlib.mlir.passmanager import PassManager -FLAGS = flags.FLAGS +try: + from absl import flags + FLAGS = flags.FLAGS +except ImportError: + FLAGS = {} -_MOSAIC_ALLOW_HLO = config.define_bool_state( +_MOSAIC_USE_PYTHON_PIPELINE = config.bool_state( + name="mosaic_use_python_pipeline", + default=False, + help=( + "Run the initial Mosaic MLIR passes from Python, when as_tpu_kernel" + " is called (for Pallas, this happens at JAX lowering time), instead of" + " later within XLA." + ), +) + +_MOSAIC_ALLOW_HLO = config.bool_state( name="jax_mosaic_allow_hlo", default=False, help="Allow hlo dialects in Mosaic", @@ -79,6 +92,9 @@ class CustomCallBackendConfig: needs_layout_passes: bool vmem_limit_bytes: int | None flags: dict[str, bool | int | float] | None + allow_input_fusion: list[bool] | None + serialization_format: int | None + internal_scratch_in_bytes: int | None # We omit the body while printing, because primitive params get embedded # in HLO metadata, and the body blows up its size. @@ -104,10 +120,24 @@ def to_json(self) -> bytes: if self.needs_hlo_passes: config.write(b', "needs_hlo_passes": ') config.write(str(self.needs_hlo_passes).lower().encode("ascii")) + if self.serialization_format is not None: + config.write(b', "serialization_format": ') + config.write(str(self.serialization_format).lower().encode("ascii")) if self.needs_layout_passes: config.write(b', "needs_layout_passes": ') config.write(str(self.needs_layout_passes).lower().encode("ascii")) - config.write(b"}") + if self.allow_input_fusion is not None: + config.write(b', "allow_input_fusion": [') + for i, value in enumerate(self.allow_input_fusion): + config.write(b"true" if value else b"false") + # config.write(str(value).lower().encode("ascii")) + if i + 1 != len(self.allow_input_fusion): + config.write(b",") + config.write(b"]") + if self.internal_scratch_in_bytes is not None: + config.write(b', "internal_scratch_in_bytes": ') + config.write(str(self.internal_scratch_in_bytes).encode("ascii")) + config.write(b"}") # End of custom_call_config. 
if self.device_type is not None: config.write(b', "device_type": ') config.write( @@ -141,6 +171,9 @@ def to_json(self) -> bytes: if i + 1 != len(self.flags): config.write(b",") config.write(b"]") + # Prevent the compiler from sharding the custom call beyond what Mosaic does + # based on user annotations + config.write(b', "implicit_sharding": {"type": "MANUAL"}') config.write(b"}") return config.getvalue() @@ -150,13 +183,8 @@ def _tpu_custom_call_abstract_eval(*_, out_avals, **__): return out_avals -def _aval_to_layout(aval): - arange = np.arange(aval.ndim, dtype=np.dtype(np.int64))[::-1].copy() - return ir.DenseIntElementsAttr.get(arange, type=ir.IndexType.get()) - - -def _avals_to_layouts(avals): - return ir.ArrayAttr.get([_aval_to_layout(a) for a in avals]) +def _avals_to_layouts(avals) -> Sequence[Sequence[int]]: + return [tuple(range(a.ndim - 1, -1, -1)) for a in avals] def _tpu_custom_call_lowering( @@ -164,18 +192,11 @@ def _tpu_custom_call_lowering( *in_nodes, # pylint: disable=missing-function-docstring config: CustomCallBackendConfig, kernel_name: str | None, - kernel_regeneration_metadata: bytes | None, out_avals: Any, input_output_aliases: tuple[tuple[int, int], ...], ) -> ...: i32_type = ir.IntegerType.get_signless(32) - multiple_results = len(out_avals) > 1 - if multiple_results: - result_type = ir.TupleType.get_tuple( - [mlir.aval_to_ir_type(aval) for aval in out_avals] - ) - else: - result_type = mlir.aval_to_ir_type(out_avals[0]) + result_types = [mlir.aval_to_ir_type(aval) for aval in out_avals] axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.SPMDAxisContext): if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names): @@ -193,51 +214,144 @@ def _tpu_custom_call_lowering( raise NotImplementedError( "Replica lowering for Mosaic kernels not implemented." ) - call = stablehlo.CustomCallOp( - [result_type], - in_nodes, - call_target_name=ir.StringAttr.get(b"tpu_custom_call"), - has_side_effect=ir.BoolAttr.get(False), - backend_config=ir.StringAttr.get(config.to_json()), - api_version=ir.IntegerAttr.get(i32_type, 1), - called_computations=ir.ArrayAttr.get([]), + if all(core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out): + result_shapes = None + else: + result_shapes = [ + mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape)) + for aval_out in ctx.avals_out] + extra_attributes = None + # Add kernel_name and kernel_metadata as attributes to the custom call op. + # This is because we do not want to pollute the backend_config with this + # information. + if kernel_name is not None: + extra_attributes = dict(kernel_name=ir.StringAttr.get(kernel_name)) + call = mlir.custom_call( + "tpu_custom_call", + result_types=result_types, + operands=in_nodes, + backend_config=config.to_json(), + api_version=1, + operand_output_aliases=dict(input_output_aliases), operand_layouts=_avals_to_layouts(ctx.avals_in), result_layouts=_avals_to_layouts(ctx.avals_out), - output_operand_aliases=ir.ArrayAttr.get([ - hlo.OutputOperandAlias.get( - # if len(result_types) == 1 then the aliasing refers implicitly to - # the only output. - output_tuple_indices=[output_idx] - if len(out_avals) > 1 - else [], - operand_index=input_idx, - operand_tuple_indices=[], - ) - for input_idx, output_idx in input_output_aliases - ]), - ) + result_shapes=result_shapes, + extra_attributes=extra_attributes) - # Add kernel_name and kernel_regeneration_metadata as attributes to the - # custom call op. 
This is because we do not want to pollute the backend_config - # with this information. - if kernel_name is not None: - call.attributes["kernel_name"] = ir.StringAttr.get(kernel_name) - if kernel_regeneration_metadata is not None: - call.attributes["kernel_regeneration_metadata"] = ir.StringAttr.get( - base64.b64encode(kernel_regeneration_metadata) - ) - if multiple_results: - results = [stablehlo.get_tuple_element(call, mlir.i32_attr(i)) - for i in range(len(out_avals))] - else: - results = call.results - return results + return call.results mlir.register_lowering(tpu_custom_call_p, _tpu_custom_call_lowering, platform="tpu") +def _lower_tpu_kernel( + module: ir.Module, + hardware_generation: int, +) -> ir.Module: + """Runs MLIR passes lowering the given module to an MLIR module. + + Uses Python versions of infer-memref-layout and apply-vector-layout. + + Args: + module: The MLIR module to lower. + hardware_generation: The TPU hardware generation to target. + + Returns: + An MLIR module implementing the kernel. + """ + try: + module.operation.verify() + except ir.MLIRError as e: + raise ValueError("The compiled module fails MLIR verification") from e + + with module.context as ctx, module.operation.location as _: + ctx.append_dialect_registry(mlir.upstream_dialects) + ctx.load_all_available_dialects() + tpu.register_dialect(ctx) + mhlo.register_mhlo_dialect(ctx) + mhlo.register_mhlo_passes() + + dump_mlir(module, "original") + + if _MOSAIC_ALLOW_HLO.value: + # Run hlo dialect conversion: hlo -> linalg -> vector. + pipeline = [ + "hlo-legalize-to-arithmetic", + "func.func(hlo-legalize-to-linalg)", + "func.func(linalg-vectorization)", + ] + pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") + pipeline.run(module.operation) + dump_mlir(module, "post-hlo-conversion") + + pipeline = [ + f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})" + ] + pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") + pipeline.run(module.operation) + dump_mlir(module, "post-infer-memref-layout") + + pipeline = [ + "canonicalize", + "cse", + ] + pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") + pipeline.run(module.operation) + dump_mlir(module, "post-infer-memref-layout-simplify") + + try: + on_device_checks = FLAGS["xla_mosaic_on_device_checks"].value + except KeyError: + on_device_checks = False + + if checks := on_device_checks: + checks = set(checks.split(",")) + if checks == {"bounds"}: # We only support one kind of checks now. 
+ pipeline = PassManager.parse( + "builtin.module(func.func(debug-assert-insertion))" + ) + pipeline.run(module.operation) + dump_mlir(module, "post-assert-insertion") + elif checks: + checks.discard("bounds") + raise ValueError( + f"Unrecognized on-device check categories: {', '.join(checks)}" + ) + + pipeline = [ + "func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})", + ] + pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") + pipeline.run(module.operation) + dump_mlir(module, "post-infer-vector-layout") + + sl_cnt = 8 + l_cnt = 128 + mxu_size = 128 if hardware_generation < 6 else 256 + pipeline = [ + "func.func(tpu-apply-vector-layout{" + f" sublane-count={sl_cnt} lane-count={l_cnt}" + f" hardware-generation={hardware_generation}" + f" mxu-contracting-size={mxu_size} mxu-noncontracting-size={mxu_size}" + f" max-sublanes-in-scratch={sl_cnt * (sl_cnt + 1)}" + "})" + ] + pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") + pipeline.run(module.operation) + dump_mlir(module, "post-apply-vector-layout") + + pipeline = [ + "canonicalize", + "cse", + ] + pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") + pipeline.run(module.operation) + dump_mlir(module, "post-apply-vector-layout-simplify") + + return module + + def as_tpu_kernel( module: ir.Module, out_type: Any, @@ -246,29 +360,52 @@ def as_tpu_kernel( backend: str | xla_client.Client = "tpu", device_type: str | None = None, kernel_name: str | None = None, - kernel_regeneration_metadata: bytes | None = None, vmem_limit_bytes: int | None = None, flags: dict[str, bool | int | float] | None = None, + allow_input_fusion: list[bool] | None = None, input_output_aliases: tuple[tuple[int, int], ...] = (), + internal_scratch_in_bytes: int | None = None, ) -> Callable[..., Any]: """Turns an MLIR Mosaic kernel into a JAX-compatible function.""" # We use jax.jit to make sure we hit the fast compilation cache. - some_tpu = jax.devices(backend)[0] - device_kind = some_tpu.device_kind - if not device_kind.startswith("TPU v"): - raise ValueError(f"Unrecognized TPU device kind: {device_kind}.") + if vmem_limit_bytes is not None and not isinstance(vmem_limit_bytes, int): raise ValueError( "vmem_limit_bytes must be an int: provided with a" f" {type(vmem_limit_bytes)}." ) - hardware_generation = int(device_kind[len("TPU v")]) has_communication, has_custom_barrier = tpu.private_has_communication( module.operation ) - bytecode_buffer = io.BytesIO() - module.operation.write_bytecode(bytecode_buffer, desired_version=0) - asm = bytecode_buffer.getvalue() + needs_hlo_passes = _MOSAIC_ALLOW_HLO.value + needs_layout_passes = not device_type + # We'll mutate the module, so clone it + with module.context as ctx, module.operation.location as _: + module = ir.Module.parse( + module.operation.get_asm(binary=True, enable_debug_info=True) + ) + if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value: + some_tpu = jax.devices(backend)[0] + device_kind = some_tpu.device_kind + if not device_kind.startswith("TPU v"): + raise ValueError( + f"Unrecognized TPU device kind: {device_kind}. 
" + "tpu_custom_call cannot be lowered on a machine without TPUs " + "when mosaic_use_python_pipeline=True.") + hardware_generation = int(device_kind[len("TPU v")]) + module = _lower_tpu_kernel(module, hardware_generation) + needs_hlo_passes = False + needs_layout_passes = False + prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects + ctx.allow_unregistered_dialects = True + try: + pipeline = PassManager.parse("builtin.module(mosaic-serde{serialize=true})") + pipeline.run(module.operation) + finally: + ctx.allow_unregistered_dialects = prev_allow_unregistered_dialects + bytecode_buffer = io.BytesIO() + module.operation.write_bytecode(bytecode_buffer, desired_version=0) + asm = bytecode_buffer.getvalue() # TODO(amagni): Kernel name and regeneration metadata could alternatively be # added as a custom attribute to the MLIR call op rather than including them @@ -276,17 +413,18 @@ def as_tpu_kernel( return _lowered_as_tpu_kernel( asm, out_type, - needs_hlo_passes=_MOSAIC_ALLOW_HLO.value, - needs_layout_passes=not device_type, + needs_hlo_passes=needs_hlo_passes, + needs_layout_passes=needs_layout_passes, device_type=device_type, has_communication=has_communication, has_custom_barrier=has_custom_barrier, kernel_name=kernel_name, - kernel_regeneration_metadata=kernel_regeneration_metadata, cost_estimate=cost_estimate, vmem_limit_bytes=vmem_limit_bytes, flags=flags, + allow_input_fusion=allow_input_fusion, input_output_aliases=input_output_aliases, + internal_scratch_in_bytes=internal_scratch_in_bytes, ) @@ -301,10 +439,12 @@ def _lowered_as_tpu_kernel( has_communication: bool = False, has_custom_barrier: bool = False, kernel_name: str | None = None, - kernel_regeneration_metadata: bytes | None = None, vmem_limit_bytes: int | None = None, flags: dict[str, bool | int | float] | None = None, + allow_input_fusion: list[bool] | None = None, input_output_aliases: tuple[tuple[int, int], ...] 
= (), + serialization_format: int | None = 1, + internal_scratch_in_bytes: int | None = None, ): """Turns a low-level MLIR Mosaic kernel into a JAX-compatible function.""" unpack = False @@ -333,12 +473,14 @@ def apply_kernel(*args, collective_id: int | None = None): needs_layout_passes, vmem_limit_bytes, flags, + allow_input_fusion, + serialization_format, + internal_scratch_in_bytes, ) result = tpu_custom_call_p.bind( *args, config=config, kernel_name=kernel_name, - kernel_regeneration_metadata=kernel_regeneration_metadata, out_avals=out_avals, input_output_aliases=input_output_aliases, ) @@ -355,6 +497,6 @@ def dump_mlir(module: ir.Module, name: str): if should_dump == "sponge": outdir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", None) if outdir: - path = os.path.join(outdir, f"{time.time_ns()}-mosaic-dump-{name}.txt") + path = os.path.join(outdir, f"{time.time_ns()}-mosaic-dump-{name}-py.txt") with open(path, "w") as f: f.write(str(module)) diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index 5ab04cd9ab64..d66cbb912a99 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -14,12 +14,13 @@ from __future__ import annotations +from collections.abc import Callable import functools import os import sys import traceback import types -from typing import Any, Callable, TypeVar, cast +from typing import Any, TypeVar, cast from jax._src import config from jax._src import util @@ -137,7 +138,7 @@ def _running_under_ipython() -> bool: def _ipython_supports_tracebackhide() -> bool: """Returns true if the IPython version supports __tracebackhide__.""" - import IPython # type: ignore + import IPython # pytype: disable=import-error return IPython.version_info[:2] >= (7, 17) def _filtering_mode() -> str: diff --git a/jax/_src/tree.py b/jax/_src/tree.py index f8358670af7b..49faaa774ef2 100644 --- a/jax/_src/tree.py +++ b/jax/_src/tree.py @@ -13,51 +13,145 @@ # limitations under the License. from __future__ import annotations -import functools -from typing import Any, Callable, Iterable, TypeVar, overload +from collections.abc import Callable, Iterable +from typing import Any, TypeVar, overload from jax._src import tree_util T = TypeVar("T") -def _add_doc(docstr): - def wrapper(fun): - doc = fun.__doc__ - firstline, rest = doc.split('\n', 1) - fun.__doc__ = f'{firstline}\n\n {docstr}\n{rest}' - return fun - return wrapper +def all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool: + """Call all() over the leaves of a tree. + Args: + tree: the pytree to evaluate + is_leaf : an optionally specified function that will be called at each + flattening step. It should return a boolean, which indicates whether the + flattening should traverse the current object, or if it should be stopped + immediately, with the whole subtree being treated as a leaf. 
-@_add_doc("Alias of :func:`jax.tree_util.tree_all`.") -@functools.wraps(tree_util.tree_all) -def all(tree: Any) -> bool: - return tree_util.tree_all(tree) + Returns: + result: boolean True or False + + Examples: + >>> import jax + >>> jax.tree.all([True, {'a': True, 'b': (True, True)}]) + True + >>> jax.tree.all([False, (True, False)]) + False + + See Also: + - :func:`jax.tree.reduce` + - :func:`jax.tree.leaves` + """ + return tree_util.tree_all(tree, is_leaf=is_leaf) -@_add_doc("Alias of :func:`jax.tree_util.tree_flatten`.") -@functools.wraps(tree_util.tree_flatten) def flatten(tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> tuple[list[tree_util.Leaf], tree_util.PyTreeDef]: + """Flattens a pytree. + + The flattening order (i.e. the order of elements in the output list) + is deterministic, corresponding to a left-to-right depth-first tree + traversal. + + Args: + tree: a pytree to flatten. + is_leaf: an optionally specified function that will be called at each + flattening step. It should return a boolean, with true stopping the + traversal and the whole subtree being treated as a leaf, and false + indicating the flattening should traverse the current object. + + Returns: + A pair where the first element is a list of leaf values and the second + element is a treedef representing the structure of the flattened tree. + + Examples: + >>> import jax + >>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]]) + >>> vals + [1, 2, 3, 4, 5] + >>> treedef + PyTreeDef([*, (*, *), [*, *]]) + + See Also: + - :func:`jax.tree.leaves` + - :func:`jax.tree.structure` + - :func:`jax.tree.unflatten` + """ return tree_util.tree_flatten(tree, is_leaf) -@_add_doc("Alias of :func:`jax.tree_util.tree_leaves`.") -@functools.wraps(tree_util.tree_leaves) def leaves(tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[tree_util.Leaf]: + """Gets the leaves of a pytree. + + Args: + tree: the pytree for which to get the leaves + is_leaf : an optionally specified function that will be called at each + flattening step. It should return a boolean, which indicates whether the + flattening should traverse the current object, or if it should be stopped + immediately, with the whole subtree being treated as a leaf. + + Returns: + leaves: a list of tree leaves. + + Examples: + >>> import jax + >>> jax.tree.leaves([1, (2, 3), [4, 5]]) + [1, 2, 3, 4, 5] + + See Also: + - :func:`jax.tree.flatten` + - :func:`jax.tree.structure` + - :func:`jax.tree.unflatten` + """ return tree_util.tree_leaves(tree, is_leaf) -@_add_doc("Alias of :func:`jax.tree_util.tree_map`.") -@functools.wraps(tree_util.tree_map) def map(f: Callable[..., Any], tree: Any, *rest: Any, is_leaf: Callable[[Any], bool] | None = None) -> Any: + """Maps a multi-input function over pytree args to produce a new pytree. + + Args: + f: function that takes ``1 + len(rest)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree: a pytree to be mapped over, with each leaf providing the first + positional argument to ``f``. + rest: a tuple of pytrees, each of which has the same structure as ``tree`` + or has ``tree`` as a prefix. + is_leaf: an optionally specified function that will be called at each + flattening step. It should return a boolean, which indicates whether the + flattening should traverse the current object, or if it should be stopped + immediately, with the whole subtree being treated as a leaf. 
+ + Returns: + A new pytree with the same structure as ``tree`` but with the value at each + leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding + leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in + ``rest``. + + Examples: + + >>> import jax + >>> jax.tree.map(lambda x: x + 1, {"x": 7, "y": 42}) + {'x': 8, 'y': 43} + + If multiple inputs are passed, the structure of the tree is taken from the + first input; subsequent inputs need only have ``tree`` as a prefix: + + >>> jax.tree.map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) + [[5, 7, 9], [6, 1, 2]] + + See Also: + - :func:`jax.tree.leaves` + - :func:`jax.tree.reduce` + """ return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf) @@ -73,32 +167,120 @@ def reduce(function: Callable[[T, Any], T], initializer: T, is_leaf: Callable[[Any], bool] | None = None) -> T: ... -@_add_doc("Alias of :func:`jax.tree_util.tree_reduce`.") -@functools.wraps(tree_util.tree_reduce) def reduce(function: Callable[[T, Any], T], tree: Any, initializer: Any = tree_util.no_initializer, is_leaf: Callable[[Any], bool] | None = None) -> T: + """Call reduce() over the leaves of a tree. + + Args: + function: the reduction function + tree: the pytree to reduce over + initializer: the optional initial value + is_leaf : an optionally specified function that will be called at each + flattening step. It should return a boolean, which indicates whether the + flattening should traverse the current object, or if it should be stopped + immediately, with the whole subtree being treated as a leaf. + + Returns: + result: the reduced value. + + Examples: + >>> import jax + >>> import operator + >>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]]) + 21 + + See Also: + - :func:`jax.tree.leaves` + - :func:`jax.tree.map` + """ return tree_util.tree_reduce(function, tree, initializer, is_leaf=is_leaf) -@_add_doc("Alias of :func:`jax.tree_util.tree_structure`.") -@functools.wraps(tree_util.tree_structure) def structure(tree: Any, is_leaf: None | (Callable[[Any], bool]) = None) -> tree_util.PyTreeDef: + """Gets the treedef for a pytree. + + Args: + tree: the pytree for which to get the leaves + is_leaf : an optionally specified function that will be called at each + flattening step. It should return a boolean, which indicates whether the + flattening should traverse the current object, or if it should be stopped + immediately, with the whole subtree being treated as a leaf. + + Returns: + pytreedef: a PyTreeDef representing the structure of the tree. + + Examples: + >>> import jax + >>> jax.tree.structure([1, (2, 3), [4, 5]]) + PyTreeDef([*, (*, *), [*, *]]) + + See Also: + - :func:`jax.tree.flatten` + - :func:`jax.tree.leaves` + - :func:`jax.tree.unflatten` + """ return tree_util.tree_structure(tree, is_leaf) -@_add_doc("Alias of :func:`jax.tree_util.tree_transpose`.") -@functools.wraps(tree_util.tree_transpose) def transpose(outer_treedef: tree_util.PyTreeDef, inner_treedef: tree_util.PyTreeDef, pytree_to_transpose: Any) -> Any: + """Transform a tree having tree structure (outer, inner) into one having structure (inner, outer). + + Args: + outer_treedef: PyTreeDef representing the outer tree. + inner_treedef: PyTreeDef representing the inner tree. + If None, then it will be inferred from outer_treedef and the structure of + pytree_to_transpose. + pytree_to_transpose: the pytree to be transposed. + + Returns: + transposed_pytree: the transposed pytree. 
+ + Examples: + >>> import jax + >>> tree = [(1, 2, 3), (4, 5, 6)] + >>> inner_structure = jax.tree.structure(('*', '*', '*')) + >>> outer_structure = jax.tree.structure(['*', '*']) + >>> jax.tree.transpose(outer_structure, inner_structure, tree) + ([1, 4], [2, 5], [3, 6]) + + Inferring the inner structure: + + >>> jax.tree.transpose(outer_structure, None, tree) + ([1, 4], [2, 5], [3, 6]) + """ return tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose) -@_add_doc("Alias of :func:`jax.tree_util.tree_unflatten`.") -@functools.wraps(tree_util.tree_unflatten) def unflatten(treedef: tree_util.PyTreeDef, leaves: Iterable[tree_util.Leaf]) -> Any: + """Reconstructs a pytree from the treedef and the leaves. + + The inverse of :func:`tree_flatten`. + + Args: + treedef: the treedef to reconstruct + leaves: the iterable of leaves to use for reconstruction. The iterable must + match the leaves of the treedef. + + Returns: + The reconstructed pytree, containing the ``leaves`` placed in the structure + described by ``treedef``. + + Examples: + >>> import jax + >>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]]) + >>> newvals = [100, 200, 300, 400, 500] + >>> jax.tree.unflatten(treedef, newvals) + [100, (200, 300), [400, 500]] + + See Also: + - :func:`jax.tree.flatten` + - :func:`jax.tree.leaves` + - :func:`jax.tree.structure` + """ return tree_util.tree_unflatten(treedef, leaves) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 583f2f55aac5..32f59b1df36e 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -14,25 +14,27 @@ from __future__ import annotations import collections -from collections.abc import Hashable, Iterable +from collections.abc import Callable, Hashable, Iterable, Sequence from dataclasses import dataclass import difflib import functools from functools import partial import operator as op import textwrap -from typing import Any, Callable, NamedTuple, TypeVar, Union, overload +from typing import Any, NamedTuple, TypeVar, Union, overload from jax._src import traceback_util from jax._src.lib import pytree -from jax._src.util import safe_zip +from jax._src.util import safe_zip, set_module from jax._src.util import unzip2 +export = set_module('jax.tree_util') + traceback_util.register_exclusion(__file__) T = TypeVar("T") -U = TypeVar("U", bound=type[Any]) +Typ = TypeVar("Typ", bound=type[Any]) H = TypeVar("H", bound=Hashable) Leaf = Any @@ -44,6 +46,13 @@ default_registry.__module__ = __name__ default_registry.__name__ = "default_registry" +# A copy of the default registry, where None is a leaf. +none_leaf_registry = pytree.PyTreeRegistry( + enable_none=False, enable_tuple=True, enable_namedtuple=True, + enable_list=True, enable_dict=True) +none_leaf_registry.__module__ = __name__ +none_leaf_registry.__name__ = "none_leaf_registry" + # A special, internal pytree registry that includes everything in # `default_registry`, plus internal Python-defined types that we want # to teach the fast dispatch path ("C++ dispatch") how to flatten and @@ -60,149 +69,120 @@ dispatch_registry.__module__ = __name__ dispatch_registry.__name__ = "dispatch_registry" + +@export def tree_flatten(tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> tuple[list[Leaf], PyTreeDef]: - """Flattens a pytree. + """Alias of :func:`jax.tree.flatten`.""" + return default_registry.flatten(tree, is_leaf) - The flattening order (i.e. the order of elements in the output list) - is deterministic, corresponding to a left-to-right depth-first tree - traversal. 
- Args: - tree: a pytree to flatten. - is_leaf: an optionally specified function that will be called at each - flattening step. It should return a boolean, with true stopping the - traversal and the whole subtree being treated as a leaf, and false - indicating the flattening should traverse the current object. +@export +def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any: + """Alias of :func:`jax.tree.unflatten`.""" + return treedef.unflatten(leaves) - Returns: - A pair where the first element is a list of leaf values and the second - element is a treedef representing the structure of the flattened tree. - Example: - >>> import jax - >>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]]) - >>> vals - [1, 2, 3, 4, 5] - >>> treedef - PyTreeDef([*, (*, *), [*, *]]) +@export +def tree_leaves(tree: Any, + is_leaf: Callable[[Any], bool] | None = None + ) -> list[Leaf]: + """Alias of :func:`jax.tree.leaves`.""" + return default_registry.flatten(tree, is_leaf)[0] - See Also: - - :func:`jax.tree.leaves` - - :func:`jax.tree.structure` - - :func:`jax.tree.unflatten` - """ - return default_registry.flatten(tree, is_leaf) +@export +def tree_structure(tree: Any, + is_leaf: None | (Callable[[Any], + bool]) = None) -> PyTreeDef: + """Alias of :func:`jax.tree.structure`.""" + return default_registry.flatten(tree, is_leaf)[1] -def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any: - """Reconstructs a pytree from the treedef and the leaves. - The inverse of :func:`tree_flatten`. +@export +def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef: + """Makes a tuple treedef from an iterable of child treedefs. Args: - treedef: the treedef to reconstruct - leaves: the iterable of leaves to use for reconstruction. The iterable must - match the leaves of the treedef. + treedefs: iterable of PyTree structures Returns: - The reconstructed pytree, containing the ``leaves`` placed in the structure - described by ``treedef``. + a single treedef representing a tuple of the structures - Example: - >>> import jax - >>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]]) - >>> newvals = [100, 200, 300, 400, 500] - >>> jax.tree.unflatten(treedef, newvals) - [100, (200, 300), [400, 500]] + Examples: + >>> import jax + >>> x = [1, 2, 3] + >>> y = {'a': 4, 'b': 5} + >>> x_tree = jax.tree.structure(x) + >>> y_tree = jax.tree.structure(y) + >>> xy_tree = jax.tree_util.treedef_tuple([x_tree, y_tree]) + >>> xy_tree == jax.tree.structure((x, y)) + True See Also: - - :func:`jax.tree.flatten` - - :func:`jax.tree.leaves` - - :func:`jax.tree.structure` + - :func:`jax.tree_util.treedef_children` """ - return treedef.unflatten(leaves) + return pytree.tuple(default_registry, list(treedefs)) -def tree_leaves(tree: Any, - is_leaf: Callable[[Any], bool] | None = None - ) -> list[Leaf]: - """Gets the leaves of a pytree. +@export +def treedef_children(treedef: PyTreeDef) -> list[PyTreeDef]: + """Return a list of treedefs for immediate children Args: - tree: the pytree for which to get the leaves - is_leaf : an optionally specified function that will be called at each - flattening step. It should return a boolean, which indicates whether the - flattening should traverse the current object, or if it should be stopped - immediately, with the whole subtree being treated as a leaf. + treedef: a single PyTreeDef + Returns: - leaves: a list of tree leaves. + a list of PyTreeDefs representing the children of treedef. 
- Example: - >>> import jax - >>> jax.tree.leaves([1, (2, 3), [4, 5]]) - [1, 2, 3, 4, 5] + Examples: + >>> import jax + >>> x = [(1, 2), 3, {'a': 4}] + >>> treedef = jax.tree.structure(x) + >>> jax.tree_util.treedef_children(treedef) + [PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})] + >>> _ == [jax.tree.structure(vals) for vals in x] + True See Also: - - :func:`jax.tree.flatten` - - :func:`jax.tree.structure` - - :func:`jax.tree.unflatten` + - :func:`jax.tree_util.treedef_tuple` """ - return default_registry.flatten(tree, is_leaf)[0] + return treedef.children() -def tree_structure(tree: Any, - is_leaf: None | (Callable[[Any], - bool]) = None) -> PyTreeDef: - """Gets the treedef for a pytree. +@export +def treedef_is_leaf(treedef: PyTreeDef) -> bool: + """Return True if the treedef represents a leaf. Args: - tree: the pytree for which to get the leaves - is_leaf : an optionally specified function that will be called at each - flattening step. It should return a boolean, which indicates whether the - flattening should traverse the current object, or if it should be stopped - immediately, with the whole subtree being treated as a leaf. + treedef: tree to check + Returns: - pytreedef: a PyTreeDef representing the structure of the tree. - Example: - >>> import jax - >>> jax.tree.structure([1, (2, 3), [4, 5]]) - PyTreeDef([*, (*, *), [*, *]]) + True if treedef is a leaf (i.e. has a single node); False otherwise. - See Also: - - :func:`jax.tree.flatten` - - :func:`jax.tree.leaves` - - :func:`jax.tree.unflatten` + Examples: + >>> import jax + >>> tree1 = jax.tree.structure(1) + >>> jax.tree_util.treedef_is_leaf(tree1) + True + >>> tree2 = jax.tree.structure([1, 2]) + >>> jax.tree_util.treedef_is_leaf(tree2) + False """ - return default_registry.flatten(tree, is_leaf)[1] - - -def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef: - """Makes a tuple treedef from an iterable of child treedefs.""" - return pytree.tuple(default_registry, list(treedefs)) # type: ignore - - -def treedef_children(treedef: PyTreeDef) -> list[PyTreeDef]: - return treedef.children() - - -def treedef_is_leaf(treedef: PyTreeDef) -> bool: return treedef.num_nodes == 1 +# treedef_is_strict_leaf is not exported. def treedef_is_strict_leaf(treedef: PyTreeDef) -> bool: return treedef.num_nodes == 1 and treedef.num_leaves == 1 +@export def all_leaves(iterable: Iterable[Any], is_leaf: Callable[[Any], bool] | None = None) -> bool: """Tests whether all elements in the given iterable are all leaves. - >>> tree = {"a": [1, 2, 3]} - >>> assert all_leaves(jax.tree_util.tree_leaves(tree)) - >>> assert not all_leaves([tree]) - This function is useful in advanced cases, for example if a library allows arbitrary map operations on a flat iterable of leaves it may want to check if the result is still a flat iterable of leaves. @@ -212,6 +192,12 @@ def all_leaves(iterable: Iterable[Any], Returns: A boolean indicating if all elements in the input are leaves. 
+ + Examples: + >>> import jax + >>> tree = {"a": [1, 2, 3]} + >>> assert all_leaves(jax.tree_util.tree_leaves(tree)) + >>> assert not all_leaves([tree]) """ if is_leaf is None: return pytree.all_leaves(default_registry, iterable) @@ -223,15 +209,17 @@ def all_leaves(iterable: Iterable[Any], _Children = TypeVar("_Children", bound=Iterable[Any]) _AuxData = TypeVar("_AuxData", bound=Hashable) + +@export def register_pytree_node(nodetype: type[T], flatten_func: Callable[[T], tuple[_Children, _AuxData]], - unflatten_func: Callable[[_AuxData, _Children], T]): + unflatten_func: Callable[[_AuxData, _Children], T]) -> None: """Extends the set of types that are considered internal nodes in pytrees. See :ref:`example usage `. Args: - nodetype: a Python type to treat as an internal pytree node. + nodetype: a Python type to register as a pytree. flatten_func: a function to be used during flattening, taking a value of type ``nodetype`` and returning a pair, with (1) an iterable for the children to be flattened recursively, and (2) some hashable auxiliary data @@ -240,109 +228,156 @@ def register_pytree_node(nodetype: type[T], returned by ``flatten_func`` and stored in the treedef, and the unflattened children. The function should return an instance of ``nodetype``. + + See also: + - :func:`~jax.tree_util.register_static`: simpler API for registering a static pytree. + - :func:`~jax.tree_util.register_dataclass`: simpler API for registering a dataclass. + - :func:`~jax.tree_util.register_pytree_with_keys` + - :func:`~jax.tree_util.register_pytree_node_class` + - :func:`~jax.tree_util.register_pytree_with_keys_class` + + Examples: + First we'll define a custom type: + + >>> class MyContainer: + ... def __init__(self, size): + ... self.x = jnp.zeros(size) + ... self.y = jnp.ones(size) + ... self.size = size + + If we try using this in a JIT-compiled function, we'll get an error because JAX + does not yet know how to handle this type: + + >>> m = MyContainer(size=5) + >>> def f(m): + ... return m.x + m.y + jnp.arange(m.size) + >>> jax.jit(f)(m) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError: Cannot interpret value of type as an abstract array; it does not have a dtype attribute + + In order to make our object recognized by JAX, we must register it as + a pytree: + + >>> def flatten_func(obj): + ... children = (obj.x, obj.y) # children must contain arrays & pytrees + ... aux_data = (obj.size,) # aux_data must contain static, hashable data. + ... return (children, aux_data) + ... + >>> def unflatten_func(aux_data, children): + ... # Here we avoid `__init__` because it has extra logic we don't require: + ... obj = object.__new__(MyContainer) + ... obj.x, obj.y = children + ... obj.size, = aux_data + ... return obj + ... + >>> jax.tree_util.register_pytree_node(MyContainer, flatten_func, unflatten_func) + + Now with this defined, we can use instances of this type in JIT-compiled functions. + + >>> jax.jit(f)(m) + Array([1., 2., 3., 4., 5.], dtype=float32) """ default_registry.register_node(nodetype, flatten_func, unflatten_func) + none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func) dispatch_registry.register_node(nodetype, flatten_func, unflatten_func) _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) -def register_pytree_node_class(cls: U) -> U: +@export +def register_pytree_node_class(cls: Typ) -> Typ: """Extends the set of types that are considered internal nodes in pytrees. 
This function is a thin wrapper around ``register_pytree_node``, and provides - a class-oriented interface:: - - @register_pytree_node_class - class Special: - def __init__(self, x, y): - self.x = x - self.y = y - def tree_flatten(self): - return ((self.x, self.y), None) - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls(*children) - """ - register_pytree_node(cls, op.methodcaller("tree_flatten"), cls.tree_unflatten) - return cls - - -def tree_map(f: Callable[..., Any], - tree: Any, - *rest: Any, - is_leaf: Callable[[Any], bool] | None = None) -> Any: - """Maps a multi-input function over pytree args to produce a new pytree. + a class-oriented interface. Args: - f: function that takes ``1 + len(rest)`` arguments, to be applied at the - corresponding leaves of the pytrees. - tree: a pytree to be mapped over, with each leaf providing the first - positional argument to ``f``. - rest: a tuple of pytrees, each of which has the same structure as ``tree`` - or has ``tree`` as a prefix. - is_leaf: an optionally specified function that will be called at each - flattening step. It should return a boolean, which indicates whether the - flattening should traverse the current object, or if it should be stopped - immediately, with the whole subtree being treated as a leaf. + cls: a type to register as a pytree Returns: - A new pytree with the same structure as ``tree`` but with the value at each - leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding - leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in - ``rest``. + The input class ``cls`` is returned unchanged after being added to JAX's pytree + registry. This return value allows ``register_pytree_node_class`` to be used as + a decorator. - Examples: + See also: + - :func:`~jax.tree_util.register_static`: simpler API for registering a static pytree. + - :func:`~jax.tree_util.register_dataclass`: simpler API for registering a dataclass. + - :func:`~jax.tree_util.register_pytree_node` + - :func:`~jax.tree_util.register_pytree_with_keys` + - :func:`~jax.tree_util.register_pytree_with_keys_class` - >>> import jax.tree_util - >>> jax.tree_util.tree_map(lambda x: x + 1, {"x": 7, "y": 42}) - {'x': 8, 'y': 43} + Examples: + Here we'll define a custom container that will be compatible with :func:`jax.jit` + and other JAX transformations: - If multiple inputs are passed, the structure of the tree is taken from the - first input; subsequent inputs need only have ``tree`` as a prefix: + >>> import jax + >>> @jax.tree_util.register_pytree_node_class + ... class MyContainer: + ... def __init__(self, x, y): + ... self.x = x + ... self.y = y + ... def tree_flatten(self): + ... return ((self.x, self.y), None) + ... @classmethod + ... def tree_unflatten(cls, aux_data, children): + ... return cls(*children) + ... + >>> m = MyContainer(jnp.zeros(4), jnp.arange(4)) + >>> def f(m): + ... 
return m.x + 2 * m.y + >>> jax.jit(f)(m) + Array([0., 2., 4., 6.], dtype=float32) + """ + register_pytree_node(cls, op.methodcaller("tree_flatten"), cls.tree_unflatten) + return cls - >>> jax.tree_util.tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) - [[5, 7, 9], [6, 1, 2]] - See Also: - - :func:`jax.tree.leaves` - - :func:`jax.tree.reduce` - """ +@export +def tree_map(f: Callable[..., Any], + tree: Any, + *rest: Any, + is_leaf: Callable[[Any], bool] | None = None) -> Any: + """Alias of :func:`jax.tree.map`.""" leaves, treedef = tree_flatten(tree, is_leaf) all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest] return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) +@export def build_tree(treedef: PyTreeDef, xs: Any) -> Any: - return treedef.from_iterable_tree(xs) - - -def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef | None, - pytree_to_transpose: Any) -> Any: - """Transform a tree having tree structure (outer, inner) into one having structure (inner, outer). + """Build a treedef from a nested iterable structure Args: - outer_treedef: PyTreeDef representing the outer tree. - inner_treedef: PyTreeDef representing the inner tree. - If None, then it will be inferred from outer_treedef and the structure of - pytree_to_transpose. - pytree_to_transpose: the pytree to be transposed. + treedef: the PyTreeDef structure to build. + xs: nested iterables matching the arity as the treedef Returns: - transposed_pytree: the transposed pytree. + object with structure defined by treedef + + See Also: + - :func:`jax.tree.unflatten` Examples: >>> import jax - >>> tree = [(1, 2, 3), (4, 5, 6)] - >>> inner_structure = jax.tree.structure(('*', '*', '*')) - >>> outer_structure = jax.tree.structure(['*', '*']) - >>> jax.tree.transpose(outer_structure, inner_structure, tree) - ([1, 4], [2, 5], [3, 6]) + >>> tree = [(1, 2), {'a': 3, 'b': 4}] + >>> treedef = jax.tree.structure(tree) - Inferring the inner structure: + Both ``build_tree`` and :func:`jax.tree_util.tree_unflatten` can reconstruct + the tree from new values, but ``build_tree`` takes these values in terms of + a nested rather than flat structure: - >>> jax.tree.transpose(outer_structure, None, tree) - ([1, 4], [2, 5], [3, 6]) + >>> jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]]) + [(10, 11), {'a': 12, 'b': 13}] + >>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13]) + [(10, 11), {'a': 12, 'b': 13}] """ + return treedef.from_iterable_tree(xs) + + +@export +def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef | None, + pytree_to_transpose: Any) -> Any: + """Alias of :func:`jax.tree.transpose`.""" flat, treedef = tree_flatten(pytree_to_transpose) if inner_treedef is None: inner_treedef = tree_structure(outer_treedef.flatten_up_to(pytree_to_transpose)[0]) @@ -374,21 +409,9 @@ def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef | None, def _replace_nones(sentinel, tree): """Replaces ``None`` in ``tree`` with ``sentinel``.""" - if tree is None: - return sentinel - else: - handler = _registry.get(type(tree)) - if handler: - children, metadata = handler.to_iter(tree) - proc_children = [_replace_nones(sentinel, child) for child in children] - return handler.from_iter(metadata, proc_children) - elif isinstance(tree, tuple) and hasattr(tree, "_fields"): - # handle namedtuple as a special case, based on heuristic - children = iter(tree) - proc_children = [_replace_nones(sentinel, child) for child in children] - return type(tree)(*proc_children) - else: - return tree + leaves, 
treedef = none_leaf_registry.flatten(tree) + leaves = map(lambda x: sentinel if x is None else x, leaves) + return treedef.unflatten(leaves) no_initializer = object() @@ -410,57 +433,22 @@ def tree_reduce(function: Callable[[T, Any], T], ... +@export def tree_reduce(function: Callable[[T, Any], T], tree: Any, initializer: Any = no_initializer, is_leaf: Callable[[Any], bool] | None = None) -> T: - """Call reduce() over the leaves of a tree. - - Args: - function: the reduction function - tree: the pytree to reduce over - initializer: the optional initial value - is_leaf : an optionally specified function that will be called at each - flattening step. It should return a boolean, which indicates whether the - flattening should traverse the current object, or if it should be stopped - immediately, with the whole subtree being treated as a leaf. - - Returns: - result: the reduced value. - - Examples: - >>> import jax - >>> import operator - >>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]]) - 21 - - See Also: - - :func:`jax.tree.leaves` - - :func:`jax.tree.map` - """ + """Alias of :func:`jax.tree.reduce`.""" if initializer is no_initializer: return functools.reduce(function, tree_leaves(tree, is_leaf=is_leaf)) else: return functools.reduce(function, tree_leaves(tree, is_leaf=is_leaf), initializer) -def tree_all(tree: Any) -> bool: - """Call all() over the leaves of a tree. - - Args: - tree: the pytree to evaluate - Returns: - result: boolean True or False - - Examples: - - >>> import jax - >>> jax.tree.all([True, {'a': True, 'b': (True, True)}]) - True - >>> jax.tree.all([False, (True, False)]) - False - """ - return all(tree_leaves(tree)) +@export +def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool: + """Alias of :func:`jax.tree.all`.""" + return all(tree_leaves(tree, is_leaf=is_leaf)) register_pytree_node( @@ -475,7 +463,7 @@ def _flatten_defaultdict(d): register_pytree_node( collections.defaultdict, _flatten_defaultdict, - lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values))) # type: ignore[index,call-overload] + lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values))) class _HashableCallableShim: @@ -499,6 +487,7 @@ def __repr__(self): return f'_HashableCallableShim({self.fun!r})' +@export class Partial(functools.partial): """A version of functools.partial that works in pytrees. @@ -535,8 +524,7 @@ class Partial(functools.partial): Array(3, dtype=int32, weak_type=True) Had we passed ``jnp.add`` to ``call_func`` directly, it would have resulted in - a - ``TypeError``. + a ``TypeError``. Note that if the result of ``Partial`` is used in the context where the value is traced, it results in all bound arguments being traced when passed @@ -570,10 +558,11 @@ def __new__(klass, func, *args, **kw): register_pytree_node( Partial, lambda partial_: ((partial_.args, partial_.keywords), partial_.func), - lambda func, xs: Partial(func, *xs[0], **xs[1]), # type: ignore[index] + lambda func, xs: Partial(func, *xs[0], **xs[1]), ) +# broadcast_prefix is not exported. def broadcast_prefix(prefix_tree: Any, full_tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[Any]: @@ -586,35 +575,46 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any, tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf) return result -def flatten_one_level(pytree: Any) -> tuple[list[Any], Hashable]: + +# flatten_one_level is not exported. 
+def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]: """Flatten the given pytree node by one level. Args: pytree: A valid pytree node, either built-in or registered via - ``register_pytree_node`` or ``register_pytree_with_keys``. + :func:`register_pytree_node` or related functions. Returns: - A pair of the pytree's flattened children and its hashable metadata. + A pair of the pytrees flattened children and its hashable metadata. Raises: ValueError: If the given pytree is not a built-in or registered container via ``register_pytree_node`` or ``register_pytree_with_keys``. + + Examples: + >>> import jax + >>> from jax._src.tree_util import flatten_one_level + >>> flattened, meta = flatten_one_level({'a': [1, 2], 'b': {'c': 3}}) + >>> flattened + ([1, 2], {'c': 3}) + >>> meta + ('a', 'b') """ - handler = _registry.get(type(pytree)) - if handler: - children, meta = handler.to_iter(pytree) - return list(children), meta - elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'): - # handle namedtuple as a special case, based on heuristic - return [getattr(pytree, s) for s in pytree._fields], None - else: + out = default_registry.flatten_one_level(pytree) + if out is None: raise ValueError(f"can't tree-flatten type: {type(pytree)}") + else: + return out + +# prefix_errors is not exported def prefix_errors(prefix_tree: Any, full_tree: Any, is_leaf: Callable[[Any], bool] | None = None, ) -> list[Callable[[str], ValueError]]: return list(_prefix_error((), prefix_tree, full_tree, is_leaf)) + +# equality_errors is not exported def equality_errors( tree1: Any, tree2: Any, is_leaf: Callable[[Any], bool] | None = None, ) -> Iterable[tuple[KeyPath, str, str, str]]: @@ -659,6 +659,8 @@ def _equality_errors(path, t1, t2, is_leaf): return # no more errors to find t1_children, t1_meta = flatten_one_level(t1) t2_children, t2_meta = flatten_one_level(t2) + t1_children = tuple(t1_children) + t2_children = tuple(t2_children) t1_keys, t2_keys = _child_keys(t1), _child_keys(t2) try: diff = ' '.join(repr(k.key) for k in @@ -692,26 +694,37 @@ def _equality_errors(path, t1, t2, is_leaf): yield from _equality_errors((*path, k), c1, c2, is_leaf) +@export @dataclass(frozen=True) class SequenceKey(): + """Struct for use with :func:`jax.tree_util.register_pytree_with_keys`.""" idx: int def __str__(self): return f'[{self.idx!r}]' + +@export @dataclass(frozen=True) class DictKey(): + """Struct for use with :func:`jax.tree_util.register_pytree_with_keys`.""" key: Hashable def __str__(self): return f'[{self.key!r}]' + +@export @dataclass(frozen=True) class GetAttrKey(): + """Struct for use with :func:`jax.tree_util.register_pytree_with_keys`.""" name: str def __str__(self): return f'.{self.name}' + +@export @dataclass(frozen=True) class FlattenedIndexKey(): + """Struct for use with :func:`jax.tree_util.register_pytree_with_keys`.""" key: int def __str__(self): return f'[]' @@ -721,6 +734,8 @@ def __str__(self): KeyEntry = TypeVar("KeyEntry", bound=Hashable) KeyPath = tuple[KeyEntry, ...] + +@export def keystr(keys: KeyPath): """Helper to pretty-print a tuple of keys. @@ -729,6 +744,12 @@ def keystr(keys: KeyPath): Returns: A string that joins all string representations of the keys. 
+ + Examples: + >>> import jax + >>> keys = (0, 1, 'a', 'b') + >>> jax.tree_util.keystr(keys) + '01ab' """ return ''.join([str(k) for k in keys]) @@ -769,6 +790,7 @@ def flatten_with_keys(xs): ) +@export def register_pytree_with_keys( nodetype: type[T], flatten_with_keys: Callable[ @@ -799,6 +821,38 @@ def register_pytree_with_keys( in the same order as ``flatten_with_keys``, and return the same aux data. This argument is optional and only needed for faster traversal when calling functions without keys like ``tree_map`` and ``tree_flatten``. + + Examples: + First we'll define a custom type: + + >>> class MyContainer: + ... def __init__(self, size): + ... self.x = jnp.zeros(size) + ... self.y = jnp.ones(size) + ... self.size = size + + Now register it using a key-aware flatten function: + + >>> from jax.tree_util import register_pytree_with_keys, GetAttrKey + >>> def flatten_with_keys(obj): + ... children = [(GetAttrKey('x'), obj.x), + ... (GetAttrKey('y'), obj.y)] # children must contain arrays & pytrees + ... aux_data = (obj.size,) # aux_data must contain static, hashable data. + ... return children, aux_data + ... + >>> def unflatten(aux_data, children): + ... # Here we avoid `__init__` because it has extra logic we don't require: + ... obj = object.__new__(MyContainer) + ... obj.x, obj.y = children + ... obj.size, = aux_data + ... return obj + ... + >>> register_pytree_with_keys(MyContainer, flatten_with_keys, unflatten) + + Now this can be used with functions like :func:`~jax.tree_util.tree_flatten_with_path`: + + >>> m = MyContainer(4) + >>> leaves, treedef = jax.tree_util.tree_flatten_with_path(m) """ if not flatten_func: def flatten_func_impl(tree): @@ -812,25 +866,43 @@ def flatten_func_impl(tree): ) -def register_pytree_with_keys_class(cls: U) -> U: +@export +def register_pytree_with_keys_class(cls: Typ) -> Typ: """Extends the set of types that are considered internal nodes in pytrees. This function is similar to ``register_pytree_node_class``, but requires a class that defines how it could be flattened with keys. It is a thin wrapper around ``register_pytree_with_keys``, and - provides a class-oriented interface:: - - @register_pytree_with_keys_class - class Special: - def __init__(self, x, y): - self.x = x - self.y = y - def tree_flatten_with_keys(self): - return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None) - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls(*children) + provides a class-oriented interface. + + Args: + cls: a type to register as a pytree + + Returns: + The input class ``cls`` is returned unchanged after being added to JAX's pytree + registry. This return value allows ``register_pytree_with_keys_class`` to be used as + a decorator. + + See also: + - :func:`~jax.tree_util.register_static`: simpler API for registering a static pytree. + - :func:`~jax.tree_util.register_dataclass`: simpler API for registering a dataclass. + - :func:`~jax.tree_util.register_pytree_node` + - :func:`~jax.tree_util.register_pytree_with_keys` + - :func:`~jax.tree_util.register_pytree_node_class` + + Examples: + >>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey + >>> @register_pytree_with_keys_class + ... class Special: + ... def __init__(self, x, y): + ... self.x = x + ... self.y = y + ... def tree_flatten_with_keys(self): + ... return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None) + ... @classmethod + ... def tree_unflatten(cls, aux_data, children): + ... 
return cls(*children) """ flatten_func = ( op.methodcaller("tree_flatten") if hasattr(cls, "tree_flatten") else None @@ -842,23 +914,141 @@ def tree_unflatten(cls, aux_data, children): return cls +@export +def register_dataclass( + nodetype: Typ, data_fields: Sequence[str], meta_fields: Sequence[str] +) -> Typ: + """Extends the set of types that are considered internal nodes in pytrees. + + This differs from ``register_pytree_with_keys_class`` in that the C++ + registries use the optimized C++ dataclass builtin instead of the argument + functions. + + See :ref:`extending-pytrees` for more information about registering pytrees. + + Args: + nodetype: a Python type to treat as an internal pytree node. This is assumed + to have the semantics of a :obj:`~dataclasses.dataclass`: namely, class + attributes represent the whole of the object state, and can be passed + as keywords to the class constructor to create a copy of the object. + All defined attributes should be listed among ``meta_fields`` or ``data_fields``. + meta_fields: auxiliary data field names. These fields *must* contain static, + hashable, immutable objects, as these objects are used to generate JIT cache + keys. In particular, ``meta_fields`` cannot contain :class:`jax.Array` or + :class:`numpy.ndarray` objects. + data_fields: data field names. These fields *must* be JAX-compatible objects + such as arrays (:class:`jax.Array` or :class:`numpy.ndarray`), scalars, or + pytrees whose leaves are arrays or scalars. Note that ``data_fields`` may be + ``None``, as this is recognized by JAX as an empty pytree. + + Returns: + The input class ``nodetype`` is returned unchanged after being added to JAX's + pytree registry. This return value allows ``register_dataclass`` to be partially + evaluated and used as a decorator as in the example below. + + Examples: + >>> from dataclasses import dataclass + >>> from functools import partial + >>> + >>> @partial(jax.tree_util.register_dataclass, + ... data_fields=['x', 'y'], + ... meta_fields=['op']) + ... @dataclass + ... class MyStruct: + ... x: jax.Array + ... y: jax.Array + ... op: str + ... + >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') + >>> m + MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add') + + Now that this class is registered, it can be used with functions in :mod:`jax.tree_util`: + + >>> leaves, treedef = jax.tree.flatten(m) + >>> leaves + [Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)] + >>> treedef + PyTreeDef(CustomNode(MyStruct[('add',)], [*, *])) + >>> jax.tree.unflatten(treedef, leaves) + MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add') + + In particular, this registration allows ``m`` to be passed seamlessly through code + wrapped in :func:`jax.jit` and other JAX transformations: + + >>> @jax.jit + ... def compiled_func(m): + ... if m.op == 'add': + ... return m.x + m.y + ... else: + ... raise ValueError(f"{m.op=}") + ... + >>> compiled_func(m) + Array([1., 2., 3.], dtype=float32) + """ + # Store inputs as immutable tuples in this scope, because we close over them + # for later evaluation. This prevents potentially confusing behavior if the + # caller were to pass in lists that are later mutated. 
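For quick reference, here is a minimal runnable sketch of the class-oriented registration documented above (assuming only that `jax` is installed), including the key paths it produces:

```python
import jax
from jax.tree_util import register_pytree_with_keys_class, GetAttrKey

@register_pytree_with_keys_class
class Special:
  def __init__(self, x, y):
    self.x = x
    self.y = y
  def tree_flatten_with_keys(self):
    # Children are (key, value) pairs; aux_data must be hashable.
    return ((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None
  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

s = Special(1.0, 2.0)
path_leaves, treedef = jax.tree_util.tree_flatten_with_path(s)
# path_leaves == [((GetAttrKey(name='x'),), 1.0), ((GetAttrKey(name='y'),), 2.0)]
assert jax.tree_util.tree_unflatten(treedef, [leaf for _, leaf in path_leaves]).x == 1.0
```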
+ meta_fields = tuple(meta_fields) + data_fields = tuple(data_fields) + + def flatten_with_keys(x): + meta = tuple(getattr(x, name) for name in meta_fields) + data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields) + return data, meta + + def unflatten_func(meta, data): + meta_args = tuple(zip(meta_fields, meta)) + data_args = tuple(zip(data_fields, data)) + kwargs = dict(meta_args + data_args) + return nodetype(**kwargs) + + def flatten_func(x): + meta = tuple(getattr(x, name) for name in meta_fields) + data = tuple(getattr(x, name) for name in data_fields) + return data, meta + + default_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) + none_leaf_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) + dispatch_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) + _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) + _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( + flatten_with_keys, unflatten_func + ) + return nodetype + + +@export def register_static(cls: type[H]) -> type[H]: """Registers `cls` as a pytree with no leaves. - Instances are treated as static by `jax.jit`, `jax.pmap`, etc. This can be an - alternative to labeling inputs as static using `jax.jit`'s `static_argnums` - and `static_argnames` kwargs, `jax.pmap`'s `static_broadcasted_argnums`, etc. + Instances are treated as static by :func:`jax.jit`, :func:`jax.pmap`, etc. This can + be an alternative to labeling inputs as static using ``jit``'s ``static_argnums`` + and ``static_argnames`` kwargs, ``pmap``'s ``static_broadcasted_argnums``, etc. - `cls` must be hashable, as defined in - https://docs.python.org/3/glossary.html#term-hashable. + Args: + cls: type to be registered as static. Must be hashable, as defined in + https://docs.python.org/3/glossary.html#term-hashable. - `register_static` can be applied to subclasses of builtin hashable classes - such as `str`, like this: - ``` - @tree_util.register_static - class StaticStr(str): - pass - ``` + Returns: + The input class ``cls`` is returned unchanged after being added to JAX's + pytree registry. This allows ``register_static`` to be used as a decorator. + + Examples: + >>> import jax + >>> @jax.tree_util.register_static + ... class StaticStr(str): + ... pass + + This static string can now be used directly in :func:`jax.jit`-compiled + functions, without marking the variable static using ``static_argnums``: + + >>> @jax.jit + ... def f(x, y, s): + ... return x + y if s == 'add' else x - y + ... + >>> f(1, 2, StaticStr('add')) + Array(3, dtype=int32, weak_type=True) """ flatten = lambda obj: ((), obj) unflatten = lambda obj, empty_iter_children: obj @@ -866,6 +1056,7 @@ class StaticStr(str): return cls +@export def tree_flatten_with_path( tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]: @@ -883,6 +1074,7 @@ def tree_flatten_with_path( return _generate_key_paths(tree, is_leaf), tree_def +@export def tree_leaves_with_path( tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[tuple[KeyPath, Any]]: @@ -893,10 +1085,15 @@ def tree_leaves_with_path( ``register_pytree_with_keys``. Returns: A list of key-leaf pairs, each of which contains a leaf and its key path. + + See Also: + - :func:`jax.tree_util.tree_leaves` + - :func:`jax.tree_util.tree_flatten_with_path` """ return _generate_key_paths(tree, is_leaf) +# generate_key_paths is not exported. 
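A small usage sketch of the key-path helpers registered and exported above; the parameter tree below is made up for illustration:

```python
import jax
import jax.numpy as jnp

params = {'dense': {'bias': jnp.zeros(2), 'kernel': jnp.ones((3, 2))}}
for path, leaf in jax.tree_util.tree_leaves_with_path(params):
  # keystr joins the DictKey/SequenceKey/GetAttrKey entries into one label.
  print(jax.tree_util.keystr(path), leaf.shape)
# ['dense']['bias'] (2,)
# ['dense']['kernel'] (3, 2)
```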
def generate_key_paths( tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[tuple[KeyPath, Any]]: @@ -914,25 +1111,31 @@ def _generate_key_paths_( yield key_path, tree return key_handler = _registry_with_keypaths.get(type(tree)) - handler = _registry.get(type(tree)) if key_handler: key_children, _ = key_handler.flatten_with_keys(tree) for k, c in key_children: yield from _generate_key_paths_((*key_path, k), c, is_leaf) - elif handler: - children, _ = handler.to_iter(tree) - for i, c in enumerate(children): - k = FlattenedIndexKey(i) - yield from _generate_key_paths_((*key_path, k), c, is_leaf) - elif isinstance(tree, tuple) and hasattr(tree, '_fields'): + return + + flat = default_registry.flatten_one_level(tree) + if flat is None: + yield key_path, tree # strict leaf type + return + + if (isinstance(tree, tuple) and hasattr(tree, '_fields') and + flat[1] == type(tree)): # handle namedtuple as a special case, based on heuristic key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields] for k, c in key_children: yield from _generate_key_paths_((*key_path, k), c, is_leaf) - else: - yield key_path, tree # strict leaf type + return + for i, c in enumerate(flat[0]): + k = FlattenedIndexKey(i) + yield from _generate_key_paths_((*key_path, k), c, is_leaf) + +@export def tree_map_with_path(f: Callable[..., Any], tree: Any, *rest: Any, is_leaf: Callable[[Any], bool] | None = None) -> Any: @@ -954,6 +1157,11 @@ def tree_map_with_path(f: Callable[..., Any], leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is the tuple of values at corresponding nodes in ``rest``. + + See Also: + - :func:`jax.tree_util.tree_map` + - :func:`jax.tree_util.tree_flatten_with_path` + - :func:`jax.tree_util.tree_leaves_with_path` """ keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) @@ -1001,6 +1209,8 @@ def _prefix_error( # point, and because prefix_tree is not a leaf, each can be flattened once: prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree) full_tree_children, full_tree_meta = flatten_one_level(full_tree) + prefix_tree_children = tuple(prefix_tree_children) + full_tree_children = tuple(full_tree_children) prefix_tree_keys = _child_keys(prefix_tree) full_tree_keys = _child_keys(full_tree) # First we check special case types (list and tuple, though if they were diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 8ae276b37457..353f63f2a86d 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -29,10 +29,12 @@ from collections.abc import Sequence from typing import Any, Protocol, Union import numpy as np +import enum from jax._src.basearray import ( Array as Array, ArrayLike as ArrayLike, + StaticScalar as StaticScalar, ) DType = np.dtype @@ -77,3 +79,15 @@ def shape(self) -> Shape: ... # JAX array (i.e. not including future non-standard array types like KeyArray and BInt). # It's different than np.typing.ArrayLike in that it doesn't accept arbitrary sequences, # nor does it accept string data. 
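And a companion sketch for `tree_map_with_path`, again over a hypothetical tree:

```python
import jax
import jax.numpy as jnp

tree = {'w': jnp.ones((2, 2)), 'stats': [jnp.zeros(2), jnp.zeros(2)]}
labeled = jax.tree_util.tree_map_with_path(
    lambda path, x: f"{jax.tree_util.keystr(path)}: shape={x.shape}", tree)
# {'stats': ["['stats'][0]: shape=(2,)", "['stats'][1]: shape=(2,)"],
#  'w': "['w']: shape=(2, 2)"}
```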
+ +# We use a class for deprecated args to avoid using Any/object types which can +# introduce complications and mistakes in static analysis +class DeprecatedArg: + def __repr__(self): + return "Deprecated" + +# Mirror of dlpack.h enum +class DLDeviceType(enum.IntEnum): + kDLCPU = 1 + kDLCUDA = 2 + kDLROCM = 10 diff --git a/jax/_src/util.py b/jax/_src/util.py index 3bfea618ceb9..5fb4ea4d7bf9 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -15,13 +15,13 @@ from __future__ import annotations import abc -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence import functools from functools import partial import itertools as it import logging import operator -from typing import (Any, Callable, Generic, TypeVar, overload, TYPE_CHECKING, cast) +from typing import (Any, Generic, TypeVar, overload, TYPE_CHECKING, cast) import weakref import numpy as np @@ -131,6 +131,15 @@ def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: lists.append(args) return lists +def split_list_checked(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: + args = list(args) + assert sum(ns) == len(args) + lists = [] + for n in ns: + lists.append(args[:n]) + args = args[n:] + return lists + def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]]: assert len(bs) == len(l) lists = [], [] # type: ignore @@ -138,10 +147,13 @@ def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T] lists[b].append(x) return lists -def merge_lists(bs: Sequence[bool], l0: Sequence[T], l1: Sequence[T]) -> list[T]: +def merge_lists(bs: Sequence[bool], + l0: Sequence[T1], + l1: Sequence[T2] + ) -> list[T1 | T2]: assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0) i0, i1 = iter(l0), iter(l1) - out = [next(i1) if b else next(i0) for b in bs] + out: list[T1 | T2] = [next(i1) if b else next(i0) for b in bs] sentinel = object() assert next(i0, sentinel) is next(i1, sentinel) is sentinel return out @@ -273,7 +285,11 @@ def merge(new_lhs, new_rhs): return lhs, rhs, merge -def cache(max_size=4096): + +def _ignore(): return None + + +def cache(max_size=4096, trace_context_in_key=True): def wrap(f): @functools.lru_cache(max_size) def cached(_, *args, **kwargs): @@ -283,17 +299,26 @@ def cached(_, *args, **kwargs): def wrapper(*args, **kwargs): if config.check_tracer_leaks.value: return f(*args, **kwargs) - else: - return cached(config.trace_context(), *args, **kwargs) + return cached(config.trace_context() if trace_context_in_key else _ignore(), + *args, **kwargs) wrapper.cache_clear = cached.cache_clear wrapper.cache_info = cached.cache_info + cache_clearing_funs.add(wrapper.cache_clear) return wrapper return wrap +cache_clearing_funs = weakref.WeakSet() # type: ignore + +def clear_all_caches(): + global cache_clearing_funs + for clear in cache_clearing_funs: + clear() + memoize = cache(max_size=None) -def weakref_lru_cache(call: Callable, maxsize=2048): +def weakref_lru_cache(call: Callable, maxsize=2048, + trace_context_in_key: bool = True): """ Least recently used cache decorator with weakref support. @@ -302,7 +327,9 @@ def weakref_lru_cache(call: Callable, maxsize=2048): behave similar to `functools.lru_cache`. 
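The `trace_context_in_key` plumbing in `cache` above follows a simple pattern: an extra leading argument is folded into the `functools.lru_cache` key so cached results are not reused across differing trace contexts. A simplified standalone sketch (the context tuple here is only a stand-in for `config.trace_context()`):

```python
import functools

def cache(max_size=4096, context_in_key=True):
  def wrap(f):
    @functools.lru_cache(max_size)
    def cached(_, *args, **kwargs):
      return f(*args, **kwargs)

    def current_context():
      return ("jax_enable_x64", False)  # stand-in for config.trace_context()

    def wrapper(*args, **kwargs):
      # The context becomes part of the cache key only when requested.
      key = current_context() if context_in_key else None
      return cached(key, *args, **kwargs)

    wrapper.cache_clear = cached.cache_clear
    wrapper.cache_info = cached.cache_info
    return wrapper
  return wrap

@cache()
def square(n):
  return n * n

square(3); square(3)
assert square.cache_info().hits == 1
```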
""" global _weakref_lru_caches - cached_call = xc.weakref_lru_cache(config.trace_context, call, maxsize) + cached_call = xc.weakref_lru_cache( + config.trace_context if trace_context_in_key else _ignore, + call, maxsize) _weakref_lru_caches.add(cached_call) return cached_call @@ -348,6 +375,9 @@ def __eq__(self, other): def wrap_name(name, transform_name): return transform_name + '(' + name + ')' +def fun_name(fun: Callable): + return getattr(fun, "__name__", "") + def canonicalize_axis(axis, num_dims) -> int: """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" axis = operator.index(axis) @@ -387,7 +417,7 @@ def wraps( """ def wrapper(fun: T) -> T: try: - name = getattr(wrapped, "__name__", "") + name = fun_name(wrapped) doc = getattr(wrapped, "__doc__", "") or "" fun.__dict__.update(getattr(wrapped, "__dict__", {})) fun.__annotations__ = getattr(wrapped, "__annotations__", {}) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index eda42fbb864d..631ebb77685d 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -20,9 +20,8 @@ """ from __future__ import annotations -from __future__ import annotations - -from collections.abc import Mapping +import atexit +from collections.abc import Callable, Mapping import dataclasses from functools import lru_cache, partial import importlib @@ -31,27 +30,26 @@ import os import pkgutil import platform as py_platform -import sys import threading -from typing import Any, Callable, Union +import traceback +from typing import Any, Union import warnings from jax._src import config from jax._src import distributed +from jax._src import hardware_utils from jax._src import traceback_util from jax._src import util -from jax._src import hardware_utils -from jax._src.cloud_tpu_init import maybe_import_libtpu +from jax._src.cloud_tpu_init import get_tpu_library_path from jax._src.lib import cuda_versions from jax._src.lib import xla_client from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version logger = logging.getLogger(__name__) jax_plugins: Any | None try: - import jax_plugins # type: ignore + import jax_plugins # pytype: disable=import-error except ModuleNotFoundError: jax_plugins = None except ImportError as e: @@ -65,43 +63,59 @@ MIN_COMPUTE_CAPABILITY = 52 # TODO(phawkins): Remove jax_xla_backend. -_XLA_BACKEND = config.DEFINE_string( +_XLA_BACKEND = config.string_flag( 'jax_xla_backend', '', 'Deprecated, please use --jax_platforms instead.') -BACKEND_TARGET = config.DEFINE_string( +BACKEND_TARGET = config.string_flag( 'jax_backend_target', os.getenv('JAX_BACKEND_TARGET', '').lower(), 'Either "local" or "rpc:address" to connect to a remote service target.') # TODO(skye): warn when this is used once we test out --jax_platforms a bit -_PLATFORM_NAME = config.DEFINE_string( +_PLATFORM_NAME = config.string_flag( 'jax_platform_name', os.getenv('JAX_PLATFORM_NAME', '').lower(), 'Deprecated, please use --jax_platforms instead.') -CUDA_VISIBLE_DEVICES = config.DEFINE_string( +CUDA_VISIBLE_DEVICES = config.string_flag( 'jax_cuda_visible_devices', 'all', 'Restricts the set of CUDA devices that JAX will use. Either "all", or a ' 'comma-separate list of integer device IDs.') -_ROCM_VISIBLE_DEVICES = config.DEFINE_string( +_ROCM_VISIBLE_DEVICES = config.string_flag( 'jax_rocm_visible_devices', 'all', 'Restricts the set of ROCM devices that JAX will use. 
Either "all", or a ' 'comma-separate list of integer device IDs.') -_USE_MOCK_GPU_CLIENT = config.DEFINE_bool( +_USE_MOCK_GPU_CLIENT = config.bool_flag( name="use_mock_gpu_client", default=False, help="If True, use a mock GPU client instead of a real one.", ) -_MOCK_NUM_GPUS = config.DEFINE_integer( +_MOCK_NUM_GPUS = config.int_flag( name="mock_num_gpus", default=1, help="Mock GPU client number of gpus.", ) -_CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool( +_CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag( name="jax_cpu_enable_gloo_collectives", default=False, - help="If True, enable cross-process collectives on CPU using Gloo.", + help="Deprecated, please use jax_cpu_collectives_implementation instead.", +) + +_CPU_COLLECTIVES_IMPLEMENTATION = config.string_flag( + name='jax_cpu_collectives_implementation', + default='none', + help='Cross-process collective implementation used on CPU. Either "none", ' + '"gloo" or "mpi"' +) + +# TODO(yueshengys): turn default back to True after resolving memory increase +# issue. +_CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag( + name="jax_cpu_enable_async_dispatch", + default=False, + help="Only applies to non-parallel computations. If False, run computations" + "inline without async dispatch.", ) @@ -112,35 +126,11 @@ def _at_fork(): "and JAX is multithreaded, so this will likely lead to a deadlock.", RuntimeWarning, stacklevel=2) -# os.register_at_fork only exists on Unix. -if hasattr(os, "register_at_fork"): - os.register_at_fork(before=_at_fork) - +_at_fork_handler_installed = False # Backends -def _get_tpu_library_path() -> str | None: - path_from_env = os.getenv("TPU_LIBRARY_PATH") - if path_from_env is not None: - return path_from_env - - libtpu_module = maybe_import_libtpu() - if libtpu_module is not None: - if hasattr(libtpu_module, "get_library_path"): - if xla_extension_version < 212: - # xla_extension_version < 212 uses tpu_tracer which requires calling - # configure_library_path. - libtpu_module.configure_library_path() - return libtpu_module.get_library_path() - else: - # TODO(b/305803029): Remove this branch around 01/2024 after the oldest - # supported TPU has get_library_path. - libtpu_module.configure_library_path() - return os.getenv("TPU_LIBRARY_PATH", None) - - return None - def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None: def _log_warning(): @@ -155,7 +145,9 @@ def _log_warning(): t.start() try: - client = xla_client.make_tpu_client(_get_tpu_library_path()) + client = xla_client.make_tpu_client( # type: ignore + get_tpu_library_path(), + _options_from_jax_configs("tpu")) finally: t.cancel() @@ -190,6 +182,9 @@ class BackendRegistration: # a buggy plugin. experimental: bool = False + # The C API (`PJRT_Api*`) if this backend is a plugin. + c_api: Any | None = None + _backend_factories: dict[str, BackendRegistration] = {} _default_backend: xla_client.Client | None = None _backends : dict[str, xla_client.Client] = {} @@ -198,6 +193,8 @@ class BackendRegistration: _plugins_registered: bool = False _plugin_lock = threading.Lock() _topology_factories: dict[str, TopologyFactory] = {} +_plugin_callbacks: list[Any] = [] +_plugin_callback_lock = threading.Lock() # The set of known non-experimental plugins. 
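As the deprecation message above spells out, the new CPU collectives flag is meant to be set through `jax.config`; for example (only the documented values "none", "gloo" and "mpi" are accepted by `make_cpu_client`):

```python
import jax

# New style: pick the CPU cross-process collectives backend explicitly.
jax.config.update("jax_cpu_collectives_implementation", "gloo")

# Old style still works for now, but maps to "gloo" and emits a
# DeprecationWarning:
# jax.config.update("jax_cpu_enable_gloo_collectives", True)
```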
# @@ -213,39 +210,48 @@ def register_backend_factory(name: str, factory: BackendFactory, *, priority: int = 0, fail_quietly: bool = True, experimental: bool = False, - make_topology: TopologyFactory | None = None) -> None: + make_topology: TopologyFactory | None = None, + c_api: Any | None = None) -> None: with _backend_lock: if name in _backends: raise RuntimeError(f"Backend {name} already initialized") _backend_factories[name] = BackendRegistration( - factory, priority, fail_quietly, experimental) + factory, priority, fail_quietly, experimental, c_api) if make_topology is not None: _topology_factories[name] = make_topology def make_cpu_client() -> xla_client.Client: - if xla_extension_version >= 223: - collectives: xla_client._xla.CpuCollectives | None = None - if _CPU_ENABLE_GLOO_COLLECTIVES.value: - collectives = xla_client._xla.make_gloo_tcp_collectives( # type: ignore - distributed_client=distributed.global_state.client, - ) - return xla_client.make_cpu_client( # type: ignore - distributed_client=distributed.global_state.client, - node_id=distributed.global_state.process_id, - num_nodes=distributed.global_state.num_processes, - collectives=collectives, - ) - elif xla_extension_version >= 216: - # TODO(phawkins): remove type: ignore after updating jaxlib version used for - # mypy checks. - return xla_client.make_cpu_client( # type: ignore + collectives: xla_client._xla.CpuCollectives | None = None + + collectives_impl = _CPU_COLLECTIVES_IMPLEMENTATION.value + if _CPU_ENABLE_GLOO_COLLECTIVES.value: + collectives_impl = 'gloo' + warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is deprecated. ' + 'Please use `jax.config.update(' + '"jax_cpu_collectives_implementation", "gloo")` instead.', + DeprecationWarning, + ) + if collectives_impl == 'gloo': + collectives = xla_client._xla.make_gloo_tcp_collectives( distributed_client=distributed.global_state.client, - node_id=distributed.global_state.process_id, - num_nodes=distributed.global_state.num_processes, ) - else: - return xla_client.make_cpu_client() + elif collectives_impl == 'mpi': + collectives = xla_client._xla.make_mpi_collectives() + collectives.Init() + atexit.register(collectives.Finalize) + elif collectives_impl != 'none': + collectives_impls = ['none', 'gloo', 'mpi'] + raise RuntimeError(f"Unknown collectives implementation " + f"{collectives_impl}. 
Available implementations are " + f"{collectives_impls}.") + return xla_client.make_cpu_client( + asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value, + distributed_client=distributed.global_state.client, + node_id=distributed.global_state.process_id, + num_nodes=distributed.global_state.num_processes, + collectives=collectives, + ) register_backend_factory( @@ -266,33 +272,101 @@ def _check_cuda_compute_capability(devices_to_check): RuntimeWarning ) -def _check_cuda_versions(): + +def _check_cuda_versions(raise_on_first_error: bool = False, + debug: bool = False): assert cuda_versions is not None + results: list[dict[str, Any]] = [] + + def _make_msg(name: str, + runtime_version: int, + build_version: int, + min_supported: int, + debug_msg: bool = False): + if debug_msg: + return (f"Package: {name}\n" + f"Version JAX was built against: {build_version}\n" + f"Minimum supported: {min_supported}\n" + f"Installed version: {runtime_version}") + if min_supported: + req_str = (f"The local installation version must be no lower than " + f"{min_supported}.") + else: + req_str = ("The local installation must be the same version as " + "the version against which JAX was built.") + msg = (f"Outdated {name} installation found.\n" + f"Version JAX was built against: {build_version}\n" + f"Minimum supported: {min_supported}\n" + f"Installed version: {runtime_version}\n" + f"{req_str}") + return msg + + def _version_check(name: str, + get_version, + get_build_version, + scale_for_comparison: int = 1, + min_supported_version: int = 0): + """Checks the runtime CUDA component version against the JAX one. + + Args: + name: Of the CUDA component. + get_version: A function to get the local runtime version of the component. + get_build_version: A function to get the build version of the component. + scale_for_comparison: For rounding down a version to ignore patch/minor. + min_supported_version: An absolute minimum version required. Must be + passed without rounding down. + + Raises: + RuntimeError: If the component is not found, or is of unsupported version, + and if raising the error is not deferred till later. + """ - def _version_check(name, get_version, get_build_version, - scale_for_comparison=1): build_version = get_build_version() try: version = get_version() except Exception as e: - raise RuntimeError(f"Unable to load {name}. Is it installed?") from e - if build_version // scale_for_comparison > version // scale_for_comparison: - raise RuntimeError( - f"Found {name} version {version}, but JAX was built against version " - f"{build_version}, which is newer. The copy of {name} that is " - "installed must be at least as new as the version against which JAX " - "was built." - ) + err_msg = f"Unable to load {name}. Is it installed?" 
+ if raise_on_first_error: + raise RuntimeError(err_msg) from e + err_msg += f"\n{traceback.format_exc()}" + results.append({"name": name, "installed": False, "msg": err_msg}) + return + + if not min_supported_version: + min_supported_version = build_version // scale_for_comparison + passed = min_supported_version <= version + + if not passed or debug: + msg = _make_msg(name=name, + runtime_version=version, + build_version=build_version, + min_supported=min_supported_version, + debug_msg=passed) + if not passed and raise_on_first_error: + raise RuntimeError(msg) + else: + record = {"name": name, + "installed": True, + "msg": msg, + "passed": passed, + "build_version": build_version, + "version": version, + "minimum_supported": min_supported_version} + results.append(record) _version_check("CUDA", cuda_versions.cuda_runtime_get_version, - cuda_versions.cuda_runtime_build_version) + cuda_versions.cuda_runtime_build_version, + scale_for_comparison=10, + min_supported_version=12010) _version_check( "cuDNN", cuda_versions.cudnn_get_version, cuda_versions.cudnn_build_version, # NVIDIA promise both backwards and forwards compatibility for cuDNN patch - # versions: https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat + # versions: + # https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat scale_for_comparison=100, + min_supported_version=9000 ) _version_check("cuFFT", cuda_versions.cufft_get_version, cuda_versions.cufft_build_version, @@ -301,24 +375,43 @@ def _version_check(name, get_version, get_build_version, _version_check("cuSOLVER", cuda_versions.cusolver_get_version, cuda_versions.cusolver_build_version, # Ignore patch versions. - scale_for_comparison=100) + scale_for_comparison=100, + min_supported_version=11400) _version_check("cuPTI", cuda_versions.cupti_get_version, - cuda_versions.cupti_build_version) - # TODO(jakevdp) remove these checks when minimum jaxlib is v0.4.21 - if hasattr(cuda_versions, "cublas_get_version"): - _version_check("cuBLAS", cuda_versions.cublas_get_version, - cuda_versions.cublas_build_version, - # Ignore patch versions. - scale_for_comparison=100) - if hasattr(cuda_versions, "cusparse_get_version"): - _version_check("cuSPARSE", cuda_versions.cusparse_get_version, - cuda_versions.cusparse_build_version, - # Ignore patch versions. - scale_for_comparison=100) + cuda_versions.cupti_build_version, + min_supported_version=18) + _version_check("cuBLAS", cuda_versions.cublas_get_version, + cuda_versions.cublas_build_version, + # Ignore patch versions. + scale_for_comparison=100, + min_supported_version=120100) + _version_check("cuSPARSE", cuda_versions.cusparse_get_version, + cuda_versions.cusparse_build_version, + # Ignore patch versions. 
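A condensed sketch of the comparison rule `_version_check` now applies (the version numbers below are made up): an absolute `min_supported_version` is compared directly against the installed version, and otherwise the build version rounded down by `scale_for_comparison` acts as the minimum.

```python
def version_ok(installed, build, scale_for_comparison=1, min_supported=0):
  if not min_supported:
    # No absolute floor given: fall back to the (rounded-down) build version.
    min_supported = build // scale_for_comparison
  return min_supported <= installed

# Hypothetical numbers: with an absolute floor of 11400, an installed 11500
# passes even though JAX was built against a newer 11602.
assert version_ok(installed=11500, build=11602, scale_for_comparison=100,
                  min_supported=11400)
assert not version_ok(installed=11302, build=11602, scale_for_comparison=100,
                      min_supported=11400)
```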
+ scale_for_comparison=100, + min_supported_version=12100) + + errors = [] + debug_results = [] + for result in results: + message: str = result['msg'] + if not result['installed'] or not result['passed']: + errors.append(message) + else: + debug_results.append(message) + + join_str = f'\n{"-" * 50}\n' + if debug_results: + print(f'CUDA components status (debug):\n' + f'{join_str.join(debug_results)}') + if errors: + raise RuntimeError(f'Unable to use CUDA because of the ' + f'following issues with CUDA components:\n' + f'{join_str.join(errors)}') def make_gpu_client( - *, platform_name: str, visible_devices_flag: config.FlagHolder[str] + *, platform_name: str, visible_devices_flag: config.Flag[str] ) -> xla_client.Client: visible_devices = visible_devices_flag.value allowed_devices = None @@ -332,8 +425,17 @@ def make_gpu_client( else distributed.global_state.num_processes ) if platform_name == "cuda": - _check_cuda_versions() - devices_to_check = allowed_devices if allowed_devices else range(cuda_versions.cuda_device_count()) + if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"): + _check_cuda_versions() + else: + print('Skipped CUDA versions constraints check due to the ' + 'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.') + + devices_to_check = ( + allowed_devices + if allowed_devices + else range(cuda_versions.cuda_device_count()) + ) _check_cuda_compute_capability(devices_to_check) return xla_client.make_gpu_client( @@ -342,7 +444,7 @@ def make_gpu_client( num_nodes=num_nodes, platform_name=platform_name, allowed_devices=allowed_devices, - mock=use_mock_gpu_client, # type: ignore[call-arg] + mock=use_mock_gpu_client, ) @@ -467,12 +569,7 @@ def discover_pjrt_plugins() -> None: logger.debug("No jax_plugins namespace packages available") # Augment with advertised entrypoints. - if sys.version_info < (3, 10): - # Use the backport library because it provides a forward-compatible - # implementation. - from importlib_metadata import entry_points - else: - from importlib.metadata import entry_points + from importlib.metadata import entry_points for entry_point in entry_points(group="jax_plugins"): logger.debug("Discovered entry-point based JAX plugin: %s", @@ -501,16 +598,30 @@ def discover_pjrt_plugins() -> None: def _options_from_jax_configs(plugin_name): - if plugin_name != "cuda": - return {} - options = {} - visible_devices = CUDA_VISIBLE_DEVICES.value - if visible_devices != 'all': - options['visible_devices'] = [int(x) for x in visible_devices.split(',')] - options['enable_mock_nccl'] = _USE_MOCK_GPU_CLIENT.value - if options['enable_mock_nccl']: - options['num_nodes'] = _MOCK_NUM_GPUS.value + + pjrt_client_options = config.jax_pjrt_client_create_options.value + pjrt_client_option_list = [] + if pjrt_client_options: + pjrt_client_option_list = pjrt_client_options.split(";") + + for option in pjrt_client_option_list: + option_list = option.split(":") + if (len(option_list) != 2): + raise RuntimeError( + "Multiple ':' separators for option in " + f"jax_pjrt_client_create_options: '{option}'. 
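The `jax_pjrt_client_create_options` string handled by `_options_from_jax_configs` is a ';'-separated list of 'key:value' pairs; a standalone sketch of that parsing logic, mirroring the loop shown here:

```python
def parse_pjrt_create_options(option_string: str) -> dict[str, str]:
  options = {}
  for option in option_string.split(";") if option_string else []:
    key_value = option.split(":")
    if len(key_value) != 2:
      raise RuntimeError(f"Should be in format 'key:value', got: {option!r}")
    options[key_value[0]] = key_value[1]
  return options

assert parse_pjrt_create_options("") == {}
assert parse_pjrt_create_options("a:1;b:2") == {"a": "1", "b": "2"}
```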
" + "Should be in format 'key:value'") + options[option_list[0]] = option_list[1] + + if plugin_name == "cuda": + visible_devices = CUDA_VISIBLE_DEVICES.value + if visible_devices != 'all': + options['visible_devices'] = [int(x) for x in visible_devices.split(',')] + options['enable_mock_nccl'] = _USE_MOCK_GPU_CLIENT.value + if options['enable_mock_nccl']: + options['num_nodes'] = _MOCK_NUM_GPUS.value + return options @@ -574,20 +685,17 @@ def factory(): 'registering PJRT plugin %s from %s', plugin_name, library_path ) if library_path is not None: - c_api = xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path) # type: ignore + c_api = xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path) xla_client.profiler.register_plugin_profiler(c_api) else: - if xla_extension_version >= 236: - assert c_api is not None - xla_client.load_pjrt_plugin_with_c_api(plugin_name, c_api) - if xla_extension_version >= 239: - make_topology = partial(xla_client.make_c_api_device_topology, c_api) - else: - make_topology = None + assert c_api is not None + xla_client.load_pjrt_plugin_with_c_api(plugin_name, c_api) + + make_topology = partial(xla_client.make_c_api_device_topology, c_api) experimental = plugin_name not in _nonexperimental_plugins register_backend_factory(plugin_name, factory, priority=priority, fail_quietly=False, experimental=experimental, - make_topology=make_topology) + make_topology=make_topology, c_api=c_api) return c_api @@ -634,6 +742,11 @@ def _discover_and_register_pjrt_plugins(): # PJRT_NAMES_AND_LIBRARY_PATHS, in the format of 'name1:path1,name2:path2' # ('name1;path1,name2;path2' for windows). register_pjrt_plugin_factories_from_env() + with _plugin_callback_lock: + for factory in _backend_factories.values(): + if factory.c_api is not None: + for callback in _plugin_callbacks: + callback(c_api=factory.c_api) _plugins_registered = True @@ -694,16 +807,44 @@ def backends_are_initialized() -> bool: return len(_backends) != 0 +def register_plugin_callbacks(callback): + """Registers a callback to be called with c_api after plugins discovery. + + The callback will be called on all discovered PJRT C API plugins. If + `register_plugin_callbacks` is called before the plugins are discovered, the + callback will be called right after the plugins are discovered. Otherwise, the + callback will be called immediately when `register_plugin_callbacks` is + called. + + Args: + callback: the callback to be called with c_api. + """ + with _plugin_callback_lock: + if _plugins_registered: + for factory in _backend_factories.values(): + if factory.c_api is not None: + callback(c_api=factory.c_api) + else: + _plugin_callbacks.append(callback) + + def backends() -> dict[str, xla_client.Client]: global _backends global _backend_errors global _default_backend + global _at_fork_handler_installed _discover_and_register_pjrt_plugins() with _backend_lock: if _backends: return _backends + + # os.register_at_fork only exists on Unix. + if not _at_fork_handler_installed and hasattr(os, "register_at_fork"): + os.register_at_fork(before=_at_fork) + _at_fork_handler_installed = True + if jax_platforms := config.jax_platforms.value: platforms = [] # Allow platform aliases in the list of platforms. 
@@ -769,8 +910,17 @@ def _suggest_missing_backends(): any(os.path.exists(d) for d in nvidia_gpu_devices)): if hasattr(xla_extension, "GpuAllocatorConfig") and "cuda" in _backend_errors: err = _backend_errors["cuda"] - logger.warning(f"CUDA backend failed to initialize: {err} (Set " - "TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)") + warning_msg = f"CUDA backend failed to initialize: {err}." + if "no supported devices found for platform CUDA." in err: + warning_msg += ( + "This may be due to JAX pre-allocating too much device " + "memory, leaving too little for CUDA library initialization. See " + "https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html " + "for more details and potential workarounds." + ) + warning_msg += "(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)" + + logger.warning(warning_msg) else: logger.warning("An NVIDIA GPU may be present on this machine, but a " "CUDA-enabled jaxlib is not installed. Falling back to " @@ -840,7 +990,8 @@ def _get_backend_uncached( raise RuntimeError(f"Backend '{platform}' failed to initialize: " f"{_backend_errors[platform]}. " f'Available backends are {list(bs)}') - raise RuntimeError(f"Unknown backend {platform}") + raise RuntimeError( + f"Unknown backend {platform}. Available backends are {list(bs)}") return backend else: assert _default_backend is not None @@ -938,6 +1089,18 @@ def backend_pjrt_c_api_version(platform=None) -> tuple[int, int] | None: return None +def backend_xla_version(platform=None) -> int | None: + """Returns the XLA version of the backend. + + Returns None if the backend does not use PJRT C API or does not have + xla_version in the plugin attributes. This methon can be used to skip features + that are not available before certain xla_version if the backend is a + plugin and uses xla_version. + """ + backend = get_backend(platform) + return getattr(backend, "xla_version", None) + + @lru_cache def local_devices(process_index: int | None = None, backend: str | xla_client.Client | None = None, @@ -1038,12 +1201,13 @@ def make_pjrt_topology(platform: str, topology_name='', **kwargs): # TODO(parkers): Get rid of this in favor of a generic way to get topologies. def make_pjrt_tpu_topology(topology_name='', **kwargs): if not xla_client.pjrt_plugin_loaded("tpu"): - library_path = _get_tpu_library_path() + library_path = get_tpu_library_path() if library_path is None: raise RuntimeError( "JAX TPU support not installed; cannot generate TPU topology. See" " https://github.com/google/jax#installation") - xla_client.load_pjrt_plugin_dynamically("tpu", library_path) + c_api = xla_client.load_pjrt_plugin_dynamically("tpu", library_path) + xla_client.profiler.register_plugin_profiler(c_api) assert xla_client.pjrt_plugin_loaded("tpu") if not xla_client.pjrt_plugin_initialized("tpu"): xla_client.initialize_pjrt_plugin("tpu") diff --git a/jax/config.py b/jax/config.py deleted file mode 100644 index 9ea911f4eb92..000000000000 --- a/jax/config.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2018 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
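A short usage sketch for the `register_plugin_callbacks` hook added above (the callback body is hypothetical; the hook itself lives in `jax._src.xla_bridge`):

```python
from jax._src import xla_bridge

def on_plugin_ready(c_api=None):
  # Called once per discovered PJRT C-API plugin, either immediately (if
  # plugins are already registered) or right after discovery completes.
  print("Got PJRT plugin C API object:", c_api)

xla_bridge.register_plugin_callbacks(on_plugin_ready)
```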
-# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from jax._src import deprecations - -# Added February 16, 2024. -_msg = ("Importing the jax.config submodule via `import jax.config` is deprecated." - " To configure JAX use `import jax` and then reference the config object" - " via `jax.config`.") -if deprecations.is_accelerated("jax.config", "config-module"): - raise ImportError(_msg) -else: - warnings.warn(_msg, DeprecationWarning, stacklevel=2) -del deprecations, warnings, _msg diff --git a/jax/core.py b/jax/core.py index c9cbd3310a52..b023d2daf163 100644 --- a/jax/core.py +++ b/jax/core.py @@ -68,7 +68,6 @@ call_bind_with_continuation as call_bind_with_continuation, call_impl as call_impl, call_p as call_p, - canonicalize_shape as _deprecated_canonicalize_shape, check_eqn as check_eqn, check_jaxpr as check_jaxpr, check_type as check_type, @@ -80,8 +79,6 @@ cur_sublevel as cur_sublevel, custom_typechecks as custom_typechecks, dedup_referents as dedup_referents, - definitely_equal as _deprecated_definitely_equal, - dimension_as_value as _deprecated_dimension_as_value, do_subst_axis_names_jaxpr as do_subst_axis_names_jaxpr, ensure_compile_time_eval as ensure_compile_time_eval, escaped_tracer_error as escaped_tracer_error, @@ -118,18 +115,6 @@ no_effects as no_effects, non_negative_dim as _deprecated_non_negative_dim, outfeed_primitives as outfeed_primitives, - pp_aval as pp_aval, - pp_eqn as pp_eqn, - pp_eqn_rules as pp_eqn_rules, - pp_eqns as pp_eqns, - pp_jaxpr as pp_jaxpr, - pp_jaxpr_eqn_range as pp_jaxpr_eqn_range, - pp_jaxpr_skeleton as pp_jaxpr_skeleton, - pp_jaxprs as pp_jaxprs, - pp_kv_pair as pp_kv_pair, - pp_kv_pairs as pp_kv_pairs, - pp_var as pp_var, - pp_vars as pp_vars, primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, primitive_uses_outfeed as primitive_uses_outfeed, process_env_traces_call as process_env_traces_call, @@ -148,7 +133,6 @@ subst_axis_names_var as subst_axis_names_var, substitute_vars_in_output_ty as substitute_vars_in_output_ty, thread_local_state as thread_local_state, - token as token, trace_state_clean as trace_state_clean, traverse_jaxpr_params as traverse_jaxpr_params, typecheck as typecheck, @@ -163,27 +147,40 @@ from jax._src import core as _src_core _deprecations = { - # Added Oct 11, 2023: + # Added 2024-06-12 + "pp_aval": ("jax.core.pp_aval is deprecated.", _src_core.pp_aval), + "pp_eqn": ("jax.core.pp_eqn is deprecated.", _src_core.pp_eqn), + "pp_eqn_rules": ("jax.core.pp_eqn_rules is deprecated.", _src_core.pp_eqn_rules), + "pp_eqns": ("jax.core.pp_eqns is deprecated.", _src_core.pp_eqns), + "pp_jaxpr": ("jax.core.pp_jaxpr is deprecated.", _src_core.pp_jaxpr), + "pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range is deprecated.", _src_core.pp_jaxpr_eqn_range), + "pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton is deprecated.", _src_core.pp_jaxpr_skeleton), + "pp_jaxprs": ("jax.core.pp_jaxprs is deprecated.", _src_core.pp_jaxprs), + "pp_kv_pair": ("jax.core.pp_kv_pair is deprecated.", _src_core.pp_kv_pair), + "pp_kv_pairs": ("jax.core.pp_kv_pairs is deprecated.", _src_core.pp_kv_pairs), + "pp_var": ("jax.core.pp_var is deprecated.", _src_core.pp_var), + "pp_vars": ("jax.core.pp_vars is deprecated.", _src_core.pp_vars), + # Finalized 2024-05-13; remove after 2024-08-13 "DimSize": ( "jax.core.DimSize is deprecated. Use DimSize = int | Any.", - _src_core.DimSize, + None, ), "Shape": ( "jax.core.Shape is deprecated. 
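The deprecation shim being deleted here points users at attribute access on the `jax` package instead of a submodule import; a minimal before/after sketch (the config option shown is just an example):

```python
# Supported: configure JAX via the config object on the jax package itself.
import jax
jax.config.update("jax_numpy_rank_promotion", "warn")  # example option

# Deprecated (warns today, eventually an ImportError per the shim above):
# import jax.config
```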
Use Shape = Sequence[int | Any].", - _src_core.Shape, + None, ), - # Added Dec 15, 2023 + # Finalized 2024-06-24; remove after 2024-09-24 "canonicalize_shape": ( - "jax.core.canonicalize_shape is deprecated.", _deprecated_canonicalize_shape, + "jax.core.canonicalize_shape is deprecated.", None, ), "dimension_as_value": ( - "jax.core.dimension_as_value is deprecated. Use jnp.array.", _deprecated_dimension_as_value, + "jax.core.dimension_as_value is deprecated. Use jnp.array.", None, ), "definitely_equal": ( - "jax.core.definitely_equal is deprecated. Use ==.", _deprecated_definitely_equal, + "jax.core.definitely_equal is deprecated. Use ==.", None, ), "symbolic_equal_dim": ( - "jax.core.symbolic_equal_dim is deprecated. Use ==.", _deprecated_definitely_equal, + "jax.core.symbolic_equal_dim is deprecated. Use ==.", None, ), # Added Jan 8, 2024 "non_negative_dim": ( @@ -193,13 +190,19 @@ import typing if typing.TYPE_CHECKING: - DimSize = _src_core.DimSize - Shape = _src_core.Shape - canonicalize_shape = _deprecated_canonicalize_shape - dimension_as_value = _deprecated_dimension_as_value - definitely_equal = _deprecated_definitely_equal non_negative_dim = _deprecated_non_negative_dim - symbolic_equal_dim = _deprecated_definitely_equal + pp_aval = _src_core.pp_aval + pp_eqn = _src_core.pp_eqn + pp_eqn_rules = _src_core.pp_eqn_rules + pp_eqns = _src_core.pp_eqns + pp_jaxpr = _src_core.pp_jaxpr + pp_jaxpr_eqn_range = _src_core.pp_jaxpr_eqn_range + pp_jaxpr_skeleton = _src_core.pp_jaxpr_skeleton + pp_jaxprs = _src_core.pp_jaxprs + pp_kv_pair = _src_core.pp_kv_pair + pp_kv_pairs = _src_core.pp_kv_pairs + pp_var = _src_core.pp_var + pp_vars = _src_core.pp_vars else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) diff --git a/jax/errors.py b/jax/errors.py index 4b8a0cf7547e..15a6654fa32d 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -24,5 +24,6 @@ TracerBoolConversionError as TracerBoolConversionError, TracerIntegerConversionError as TracerIntegerConversionError, UnexpectedTracerError as UnexpectedTracerError, + KeyReuseError as KeyReuseError, ) from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback diff --git a/jax/example_libraries/README.md b/jax/example_libraries/README.md index edcfcb506f8e..349b7e12e56b 100644 --- a/jax/example_libraries/README.md +++ b/jax/example_libraries/README.md @@ -44,7 +44,7 @@ net_init, net_apply = stax.serial( ) # Initialize parameters, not committing to a batch shape -rng = random.PRNGKey(0) +rng = random.key(0) in_shape = (-1, 28, 28, 1) out_shape, net_params = net_init(rng, in_shape) diff --git a/jax/example_libraries/optimizers.py b/jax/example_libraries/optimizers.py index 15537e54e3f3..71680ca61b96 100644 --- a/jax/example_libraries/optimizers.py +++ b/jax/example_libraries/optimizers.py @@ -91,7 +91,8 @@ def step(step, opt_state): from __future__ import annotations -from typing import Any, Callable, NamedTuple +from collections.abc import Callable +from typing import Any, NamedTuple from collections import namedtuple import functools @@ -119,7 +120,7 @@ def step(step, opt_state): register_pytree_node( OptimizerState, lambda xs: ((xs.packed_state,), (xs.tree_def, xs.subtree_defs)), - lambda data, xs: OptimizerState(xs[0], data[0], data[1])) # type: ignore[index] + lambda data, xs: OptimizerState(xs[0], data[0], data[1])) Array = Any diff --git a/jax/example_libraries/stax.py b/jax/example_libraries/stax.py index e5bf38c6a69f..476252d92d5d 100644 
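The `_deprecations` table above is consumed by a module-level `__getattr__`; a simplified sketch of that mechanism (the real helper is `jax._src.deprecations.deprecation_getattr`):

```python
import warnings

def deprecation_getattr(module, deprecations):
  def __getattr__(name):
    if name in deprecations:
      message, fwd = deprecations[name]
      if fwd is None:  # deprecation has been finalized
        raise AttributeError(message)
      warnings.warn(message, DeprecationWarning, stacklevel=2)
      return fwd
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
  return __getattr__

_deprecations = {
    "pp_aval": ("jax.core.pp_aval is deprecated.", lambda *a: None),  # still forwarded
    "canonicalize_shape": ("jax.core.canonicalize_shape is deprecated.", None),
}
__getattr__ = deprecation_getattr(__name__, _deprecations)
```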
--- a/jax/example_libraries/stax.py +++ b/jax/example_libraries/stax.py @@ -268,7 +268,7 @@ def apply_fun(params, inputs, **kwargs): msg = ("Dropout layer requires apply_fun to be called with a PRNG key " "argument. That is, instead of `apply_fun(params, inputs)`, call " "it like `apply_fun(params, inputs, rng)` where `rng` is a " - "jax.random.PRNGKey value.") + "PRNG key (e.g. from `jax.random.key`).") raise ValueError(msg) if mode == 'train': keep = random.bernoulli(rng, rate, inputs.shape) diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index 2addcc29b695..caf27ec7a8ca 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -22,3 +22,6 @@ from jax._src.callback import ( io_callback as io_callback ) +from jax._src.earray import ( + EArray as EArray +) diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 3169f9667256..e0d8c4ee67f5 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -21,7 +21,7 @@ >>> from jax.experimental import array_api as xp >>> xp.__array_api_version__ - '2022.12' + '2023.12' >>> arr = xp.arange(1000) @@ -38,68 +38,20 @@ from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__ -from jax.experimental.array_api import ( - fft as fft, - linalg as linalg, -) +from jax.experimental.array_api import fft as fft +from jax.experimental.array_api import linalg as linalg -from jax.experimental.array_api._constants import ( - e as e, - inf as inf, - nan as nan, - newaxis as newaxis, - pi as pi, -) - -from jax.experimental.array_api._creation_functions import ( - arange as arange, - asarray as asarray, - empty as empty, - empty_like as empty_like, - eye as eye, - from_dlpack as from_dlpack, - full as full, - full_like as full_like, - linspace as linspace, - meshgrid as meshgrid, - ones as ones, - ones_like as ones_like, - tril as tril, - triu as triu, - zeros as zeros, - zeros_like as zeros_like, -) - -from jax.experimental.array_api._data_type_functions import ( - astype as astype, - can_cast as can_cast, - finfo as finfo, - iinfo as iinfo, - isdtype as isdtype, - result_type as result_type, -) - -from jax.experimental.array_api._dtypes import ( - bool as bool, - int8 as int8, - int16 as int16, - int32 as int32, - int64 as int64, - uint8 as uint8, - uint16 as uint16, - uint32 as uint32, - uint64 as uint64, - float32 as float32, - float64 as float64, - complex64 as complex64, - complex128 as complex128, -) - -from jax.experimental.array_api._elementwise_functions import ( +from jax.numpy import ( abs as abs, acos as acos, acosh as acosh, add as add, + all as all, + any as any, + arange as arange, + argmax as argmax, + argmin as argmin, + argsort as argsort, asin as asin, asinh as asinh, atan as atan, @@ -111,19 +63,46 @@ bitwise_or as bitwise_or, bitwise_right_shift as bitwise_right_shift, bitwise_xor as bitwise_xor, + bool as bool, + broadcast_arrays as broadcast_arrays, + broadcast_to as broadcast_to, + can_cast as can_cast, ceil as ceil, + complex128 as complex128, + complex64 as complex64, + concat as concat, conj as conj, + copysign as copysign, cos as cos, cosh as cosh, + cumulative_sum as cumulative_sum, divide as divide, + e as e, + empty as empty, + empty_like as empty_like, equal as equal, exp as exp, + expand_dims as expand_dims, expm1 as expm1, + eye as eye, + flip as flip, + float32 as float32, + float64 as float64, floor as floor, floor_divide as floor_divide, + from_dlpack as 
from_dlpack, + full as full, + full_like as full_like, greater as greater, greater_equal as greater_equal, + iinfo as iinfo, imag as imag, + inf as inf, + int16 as int16, + int32 as int32, + int64 as int64, + int8 as int8, + isdtype as isdtype, isfinite as isfinite, isinf as isinf, isnan as isnan, @@ -138,81 +117,95 @@ logical_not as logical_not, logical_or as logical_or, logical_xor as logical_xor, + matmul as matmul, + matrix_transpose as matrix_transpose, + max as max, + maximum as maximum, + mean as mean, + meshgrid as meshgrid, + min as min, + minimum as minimum, + moveaxis as moveaxis, multiply as multiply, + nan as nan, negative as negative, + newaxis as newaxis, + nonzero as nonzero, not_equal as not_equal, + ones as ones, + ones_like as ones_like, + permute_dims as permute_dims, + pi as pi, positive as positive, pow as pow, + prod as prod, real as real, remainder as remainder, + repeat as repeat, + result_type as result_type, + roll as roll, round as round, + searchsorted as searchsorted, sign as sign, + signbit as signbit, sin as sin, sinh as sinh, + sort as sort, sqrt as sqrt, square as square, + squeeze as squeeze, + stack as stack, subtract as subtract, + sum as sum, + take as take, tan as tan, tanh as tanh, + tensordot as tensordot, + tile as tile, + tril as tril, + triu as triu, trunc as trunc, -) - -from jax.experimental.array_api._indexing_functions import ( - take as take, + uint16 as uint16, + uint32 as uint32, + uint64 as uint64, + uint8 as uint8, + unique_all as unique_all, + unique_counts as unique_counts, + unique_inverse as unique_inverse, + unique_values as unique_values, + unstack as unstack, + vecdot as vecdot, + where as where, + zeros as zeros, + zeros_like as zeros_like, ) from jax.experimental.array_api._manipulation_functions import ( - broadcast_arrays as broadcast_arrays, - broadcast_to as broadcast_to, - concat as concat, - expand_dims as expand_dims, - flip as flip, - permute_dims as permute_dims, reshape as reshape, - roll as roll, - squeeze as squeeze, - stack as stack, ) -from jax.experimental.array_api._searching_functions import ( - argmax as argmax, - argmin as argmin, - nonzero as nonzero, - where as where, +from jax.experimental.array_api._creation_functions import ( + asarray as asarray, + linspace as linspace, ) -from jax.experimental.array_api._set_functions import ( - unique_all as unique_all, - unique_counts as unique_counts, - unique_inverse as unique_inverse, - unique_values as unique_values, +from jax.experimental.array_api._data_type_functions import ( + astype as astype, + finfo as finfo, ) -from jax.experimental.array_api._sorting_functions import ( - argsort as argsort, - sort as sort, +from jax.experimental.array_api._elementwise_functions import ( + clip as clip, + hypot as hypot, ) from jax.experimental.array_api._statistical_functions import ( - max as max, - mean as mean, - min as min, - prod as prod, std as std, - sum as sum, - var as var + var as var, ) from jax.experimental.array_api._utility_functions import ( - all as all, - any as any, -) - -from jax.experimental.array_api._linear_algebra_functions import ( - matmul as matmul, - matrix_transpose as matrix_transpose, - tensordot as tensordot, - vecdot as vecdot, + __array_namespace_info__ as __array_namespace_info__, ) from jax.experimental.array_api import _array_methods diff --git a/jax/experimental/array_api/_array_methods.py b/jax/experimental/array_api/_array_methods.py index 5a73b8a2fe1a..2b071db573a8 100644 --- a/jax/experimental/array_api/_array_methods.py +++ 
b/jax/experimental/array_api/_array_methods.py @@ -14,11 +14,12 @@ from __future__ import annotations -from typing import Any, Callable +from typing import Any import jax from jax._src.array import ArrayImpl from jax.experimental.array_api._version import __array_api_version__ +from jax.sharding import Sharding from jax._src.lib import xla_extension as xe @@ -30,16 +31,15 @@ def _array_namespace(self, /, *, api_version: None | str = None): return jax.experimental.array_api -def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *, +def _to_device(self, device: xe.Device | Sharding | None, *, stream: int | Any | None = None): if stream is not None: raise NotImplementedError("stream argument of array.to_device()") - # The type of device is defined by Array.device. In JAX, this is a callable that - # returns a device, so we must handle this case to satisfy the API spec. - return jax.device_put(self, device() if callable(device) else device) + return jax.device_put(self, device) def add_array_object_methods(): # TODO(jakevdp): set on tracers as well? setattr(ArrayImpl, "__array_namespace__", _array_namespace) setattr(ArrayImpl, "to_device", _to_device) + setattr(ArrayImpl, "device", property(lambda self: self.sharding)) diff --git a/jax/experimental/array_api/_creation_functions.py b/jax/experimental/array_api/_creation_functions.py index 0fcde42d58bb..5b9789ed732d 100644 --- a/jax/experimental/array_api/_creation_functions.py +++ b/jax/experimental/array_api/_creation_functions.py @@ -12,54 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import jax import jax.numpy as jnp -def arange(start, /, stop=None, step=1, *, dtype=None, device=None): - return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device) - def asarray(obj, /, *, dtype=None, device=None, copy=None): return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device) -def empty(shape, *, dtype=None, device=None): - return jax.device_put(jnp.empty(shape, dtype=dtype), device=device) - -def empty_like(x, /, *, dtype=None, device=None): - return jax.device_put(jnp.empty_like(x, dtype=dtype), device=device) - -def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None): - return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device) - -def from_dlpack(x, /): - return jnp.from_dlpack(x) - -def full(shape, fill_value, *, dtype=None, device=None): - return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device) - -def full_like(x, /, fill_value, *, dtype=None, device=None): - return jax.device_put(jnp.full_like(x, fill_value=fill_value, dtype=dtype), device=device) - def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True): return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device) - -def meshgrid(*arrays, indexing='xy'): - return jnp.meshgrid(*arrays, indexing=indexing) - -def ones(shape, *, dtype=None, device=None): - return jax.device_put(jnp.ones(shape, dtype=dtype), device=device) - -def ones_like(x, /, *, dtype=None, device=None): - return jax.device_put(jnp.ones_like(x, dtype=dtype), device=device) - -def tril(x, /, *, k=0): - return jnp.tril(x, k=k) - -def triu(x, /, *, k=0): - return jnp.triu(x, k=k) - -def zeros(shape, *, dtype=None, device=None): - return jax.device_put(jnp.zeros(shape, dtype=dtype), device=device) - -def zeros_like(x, /, *, dtype=None, device=None): - return 
jax.device_put(jnp.zeros_like(x, dtype=dtype), device=device) diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py index d2bb032b85ab..248c1c6dd0fe 100644 --- a/jax/experimental/array_api/_data_type_functions.py +++ b/jax/experimental/array_api/_data_type_functions.py @@ -12,135 +12,48 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations -import functools +import builtins from typing import NamedTuple -import jax -import jax.numpy as jnp - - -from jax.experimental.array_api._dtypes import ( - bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, - float32, float64, complex64, complex128 -) - -_valid_dtypes = { - bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, - float32, float64, complex64, complex128 -} - -_promotion_table = { - (bool, bool): bool, - (int8, int8): int8, - (int8, int16): int16, - (int8, int32): int32, - (int8, int64): int64, - (int8, uint8): int16, - (int8, uint16): int32, - (int8, uint32): int64, - (int16, int8): int16, - (int16, int16): int16, - (int16, int32): int32, - (int16, int64): int64, - (int16, uint8): int16, - (int16, uint16): int32, - (int16, uint32): int64, - (int32, int8): int32, - (int32, int16): int32, - (int32, int32): int32, - (int32, int64): int64, - (int32, uint8): int32, - (int32, uint16): int32, - (int32, uint32): int64, - (int64, int8): int64, - (int64, int16): int64, - (int64, int32): int64, - (int64, int64): int64, - (int64, uint8): int64, - (int64, uint16): int64, - (int64, uint32): int64, - (uint8, int8): int16, - (uint8, int16): int16, - (uint8, int32): int32, - (uint8, int64): int64, - (uint8, uint8): uint8, - (uint8, uint16): uint16, - (uint8, uint32): uint32, - (uint8, uint64): uint64, - (uint16, int8): int32, - (uint16, int16): int32, - (uint16, int32): int32, - (uint16, int64): int64, - (uint16, uint8): uint16, - (uint16, uint16): uint16, - (uint16, uint32): uint32, - (uint16, uint64): uint64, - (uint32, int8): int64, - (uint32, int16): int64, - (uint32, int32): int64, - (uint32, int64): int64, - (uint32, uint8): uint32, - (uint32, uint16): uint32, - (uint32, uint32): uint32, - (uint32, uint64): uint64, - (uint64, uint8): uint64, - (uint64, uint16): uint64, - (uint64, uint32): uint64, - (uint64, uint64): uint64, - (float32, float32): float32, - (float32, float64): float64, - (float32, complex64): complex64, - (float32, complex128): complex128, - (float64, float32): float64, - (float64, float64): float64, - (float64, complex64): complex128, - (float64, complex128): complex128, - (complex64, float32): complex64, - (complex64, float64): complex128, - (complex64, complex64): complex64, - (complex64, complex128): complex128, - (complex128, float32): complex128, - (complex128, float64): complex128, - (complex128, complex64): complex128, - (complex128, complex128): complex128, -} - - -def _is_valid_dtype(t): - try: - return t in _valid_dtypes - except TypeError: - return False - - -def _promote_types(t1, t2): - if not _is_valid_dtype(t1): - raise ValueError(f"{t1} is not a valid dtype") - if not _is_valid_dtype(t2): - raise ValueError(f"{t2} is not a valid dtype") - if result := _promotion_table.get((t1, t2), None): - return result - else: - raise ValueError("No promotion path for {t1} & {t2}") - - -def astype(x, dtype, /, *, copy=True): - return jnp.array(x, dtype=dtype, copy=copy) +import numpy as np +import jax.numpy as jnp -def can_cast(from_, to, /): - if 
isinstance(from_, jax.Array): - from_ = from_.dtype - if not _is_valid_dtype(from_): - raise ValueError(f"{from_} is not a valid dtype") - if not _is_valid_dtype(to): - raise ValueError(f"{to} is not a valid dtype") - try: - result = _promote_types(from_, to) - except ValueError: - return False - else: - return result == to +from jax._src.lib import xla_client as xc +from jax._src.sharding import Sharding +from jax._src import dtypes as _dtypes + +# TODO(micky774): Update jax.numpy dtypes to dtype *objects* +bool = np.dtype('bool') +int8 = np.dtype('int8') +int16 = np.dtype('int16') +int32 = np.dtype('int32') +int64 = np.dtype('int64') +uint8 = np.dtype('uint8') +uint16 = np.dtype('uint16') +uint32 = np.dtype('uint32') +uint64 = np.dtype('uint64') +float32 = np.dtype('float32') +float64 = np.dtype('float64') +complex64 = np.dtype('complex64') +complex128 = np.dtype('complex128') + + +# TODO(micky774): Remove when jax.numpy.astype is deprecation is completed +def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None): + src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x) + if ( + src_dtype is not None + and _dtypes.isdtype(src_dtype, "complex floating") + and _dtypes.isdtype(dtype, ("integral", "real floating")) + ): + raise ValueError( + "Casting from complex to non-complex dtypes is not permitted. Please " + "first use jnp.real or jnp.imag to take the real/imaginary component of " + "your input." + ) + return jnp.astype(x, dtype, copy=copy, device=device) class FInfo(NamedTuple): @@ -151,14 +64,8 @@ class FInfo(NamedTuple): smallest_normal: float dtype: jnp.dtype - -class IInfo(NamedTuple): - bits: int - max: int - min: int - dtype: jnp.dtype - - +# TODO(micky774): Update jax.numpy.finfo so that its attributes are python +# floats def finfo(type, /) -> FInfo: info = jnp.finfo(type) return FInfo( @@ -169,43 +76,3 @@ def finfo(type, /) -> FInfo: smallest_normal=float(info.smallest_normal), dtype=jnp.dtype(type) ) - - -def iinfo(type, /) -> IInfo: - info = jnp.iinfo(type) - return IInfo(bits=info.bits, max=info.max, min=info.min, dtype=jnp.dtype(type)) - - -def isdtype(dtype, kind): - return jax.numpy.isdtype(dtype, kind) - - -def result_type(*arrays_and_dtypes): - dtypes = [] - for val in arrays_and_dtypes: - if isinstance(val, jax.Array): - val = val.dtype - if _is_valid_dtype(val): - dtypes.append(val) - else: - raise ValueError(f"{val} is not a valid dtype") - if len(dtypes) == 0: - raise ValueError("result_type requires at least one argument") - if len(dtypes) == 1: - return dtypes[0] - return functools.reduce(_promote_types, dtypes) - - -def _promote_to_default_dtype(x): - if x.dtype.kind == 'b': - return x - elif x.dtype.kind == 'i': - return x.astype(jnp.int_) - elif x.dtype.kind == 'u': - return x.astype(jnp.uint) - elif x.dtype.kind == 'f': - return x.astype(jnp.float_) - elif x.dtype.kind == 'c': - return x.astype(jnp.complex_) - else: - raise ValueError(f"Unrecognized {x.dtype=}") diff --git a/jax/experimental/array_api/_elementwise_functions.py b/jax/experimental/array_api/_elementwise_functions.py index 1352cd5b0b3e..103f8ab7d1ef 100644 --- a/jax/experimental/array_api/_elementwise_functions.py +++ b/jax/experimental/array_api/_elementwise_functions.py @@ -13,377 +13,36 @@ # limitations under the License. 
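A minimal usage sketch of the stricter astype wrapper above, importing it from the private module path used in this patch (the public re-export may differ); the behavior shown is exactly the complex-to-real guard added here.

# Sketch: complex -> real casts now raise; take the real/imag component first.
# Assumes the private module path from this patch.
import jax.numpy as jnp
from jax.experimental.array_api._data_type_functions import astype

z = jnp.array([1 + 2j, 3 - 4j], dtype=jnp.complex64)
try:
    astype(z, jnp.float32)                    # raises ValueError under the new guard
except ValueError:
    r = astype(jnp.real(z), jnp.float32)      # explicit real component, then cast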
import jax -from jax.experimental.array_api._data_type_functions import ( - result_type as _result_type, - isdtype as _isdtype, -) -import numpy as np +from jax.numpy import isdtype +from jax._src.dtypes import issubdtype +from jax._src.numpy.util import promote_args -def _promote_dtypes(name, *args): - assert isinstance(name, str) - if not all(isinstance(arg, jax.Array) for arg in args): - raise ValueError(f"{name}: inputs must be arrays; got types {[type(arg) for arg in args]}") - dtype = _result_type(*args) - return [arg.astype(dtype) for arg in args] - - -def abs(x, /): - """Calculates the absolute value for each element x_i of the input array x.""" - x, = _promote_dtypes("abs", x) - return jax.numpy.abs(x) - - -def acos(x, /): - """Calculates an implementation-dependent approximation of the principal value of the inverse cosine for each element x_i of the input array x.""" - x, = _promote_dtypes("acos", x) - return jax.numpy.acos(x) - -def acosh(x, /): - """Calculates an implementation-dependent approximation to the inverse hyperbolic cosine for each element x_i of the input array x.""" - x, = _promote_dtypes("acos", x) - return jax.numpy.acosh(x) - - -def add(x1, x2, /): - """Calculates the sum for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("add", x1, x2) - return jax.numpy.add(x1, x2) - - -def asin(x, /): - """Calculates an implementation-dependent approximation of the principal value of the inverse sine for each element x_i of the input array x.""" - x, = _promote_dtypes("asin", x) - return jax.numpy.asin(x) - - -def asinh(x, /): - """Calculates an implementation-dependent approximation to the inverse hyperbolic sine for each element x_i in the input array x.""" - x, = _promote_dtypes("asinh", x) - return jax.numpy.asinh(x) - - -def atan(x, /): - """Calculates an implementation-dependent approximation of the principal value of the inverse tangent for each element x_i of the input array x.""" - x, = _promote_dtypes("atan", x) - return jax.numpy.atan(x) - - -def atan2(x1, x2, /): - """Calculates an implementation-dependent approximation of the inverse tangent of the quotient x1/x2, having domain [-infinity, +infinity] x [-infinity, +infinity] (where the x notation denotes the set of ordered pairs of elements (x1_i, x2_i)) and codomain [-π, +π], for each pair of elements (x1_i, x2_i) of the input arrays x1 and x2, respectively.""" - x1, x2 = _promote_dtypes("atan2", x1, x2) - return jax.numpy.arctan2(x1, x2) - - -def atanh(x, /): - """Calculates an implementation-dependent approximation to the inverse hyperbolic tangent for each element x_i of the input array x.""" - x, = _promote_dtypes("atanh", x) - return jax.numpy.atanh(x) - - -def bitwise_and(x1, x2, /): - """Computes the bitwise AND of the underlying binary representation of each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("bitwise_and", x1, x2) - return jax.numpy.bitwise_and(x1, x2) - - -def bitwise_left_shift(x1, x2, /): - """Shifts the bits of each element x1_i of the input array x1 to the left by appending x2_i (i.e., the respective element in the input array x2) zeros to the right of x1_i.""" - x1, x2 = _promote_dtypes("bitwise_left_shift", x1, x2) - return jax.numpy.bitwise_left_shift(x1, x2) - - -def bitwise_invert(x, /): - """Inverts (flips) each bit for each element x_i of the input array x.""" - x, = _promote_dtypes("bitwise_invert", x) - return 
jax.numpy.bitwise_invert(x) - - -def bitwise_or(x1, x2, /): - """Computes the bitwise OR of the underlying binary representation of each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("bitwise_or", x1, x2) - return jax.numpy.bitwise_or(x1, x2) - - -def bitwise_right_shift(x1, x2, /): - """Shifts the bits of each element x1_i of the input array x1 to the right according to the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("bitwise_right_shift", x1, x2) - return jax.numpy.bitwise_right_shift(x1, x2) - - -def bitwise_xor(x1, x2, /): - """Computes the bitwise XOR of the underlying binary representation of each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("bitwise_xor", x1, x2) - return jax.numpy.bitwise_xor(x1, x2) - - -def ceil(x, /): - """Rounds each element x_i of the input array x to the smallest (i.e., closest to -infinity) integer-valued number that is not less than x_i.""" - x, = _promote_dtypes("ceil", x) - if _isdtype(x.dtype, "integral"): - return x - return jax.numpy.ceil(x) - - -def conj(x, /): +# TODO(micky774): Remove when jnp.clip deprecation is completed +# (began 2024-4-2) and default behavior is Array API 2023 compliant +def clip(x, /, min=None, max=None): """Returns the complex conjugate for each element x_i of the input array x.""" - x, = _promote_dtypes("conj", x) - return jax.numpy.conj(x) - - -def cos(x, /): - """Calculates an implementation-dependent approximation to the cosine for each element x_i of the input array x.""" - x, = _promote_dtypes("cos", x) - return jax.numpy.cos(x) - - -def cosh(x, /): - """Calculates an implementation-dependent approximation to the hyperbolic cosine for each element x_i in the input array x.""" - x, = _promote_dtypes("cosh", x) - return jax.numpy.cosh(x) - - -def divide(x1, x2, /): - """Calculates the division of each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("divide", x1, x2) - return jax.numpy.divide(x1, x2) - - -def equal(x1, x2, /): - """Computes the truth value of x1_i == x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("equal", x1, x2) - return jax.numpy.equal(x1, x2) - - -def exp(x, /): - """Calculates an implementation-dependent approximation to the exponential function for each element x_i of the input array x (e raised to the power of x_i, where e is the base of the natural logarithm).""" - x, = _promote_dtypes("exp", x) - return jax.numpy.exp(x) - - -def expm1(x, /): - """Calculates an implementation-dependent approximation to exp(x)-1 for each element x_i of the input array x.""" - x, = _promote_dtypes("expm1", x) - return jax.numpy.expm1(x) - - -def floor(x, /): - """Rounds each element x_i of the input array x to the greatest (i.e., closest to +infinity) integer-valued number that is not greater than x_i.""" - x, = _promote_dtypes("floor", x) - if _isdtype(x.dtype, "integral"): - return x - return jax.numpy.floor(x) - - -def floor_divide(x1, x2, /): - """Rounds the result of dividing each element x1_i of the input array x1 by the respective element x2_i of the input array x2 to the greatest (i.e., closest to +infinity) integer-value number that is not greater than the division result.""" - x1, x2 = _promote_dtypes("floor_divide", x1, x2) - return jax.numpy.floor_divide(x1, x2) - - 
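The removed _promote_dtypes helper above is superseded by jax._src.numpy.util.promote_args, now imported at the top of this file. A small sketch of the equivalent call (private API, subject to change):

# Sketch only: promote_args applies the same rank/dtype promotion that the old
# per-function _promote_dtypes wrapper performed, via JAX's internal utility.
import jax.numpy as jnp
from jax._src.numpy.util import promote_args

x1 = jnp.arange(3, dtype=jnp.int8)
x2 = jnp.ones(3, dtype=jnp.float32)
a, b = promote_args("add", x1, x2)   # both promoted to float32
out = jnp.add(a, b)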
-def greater(x1, x2, /): - """Computes the truth value of x1_i > x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("greater", x1, x2) - return jax.numpy.greater(x1, x2) - - -def greater_equal(x1, x2, /): - """Computes the truth value of x1_i >= x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("greater_equal", x1, x2) - return jax.numpy.greater_equal(x1, x2) - - -def imag(x, /): - """Returns the imaginary component of a complex number for each element x_i of the input array x.""" - x, = _promote_dtypes("imag", x) - return jax.numpy.imag(x) - - -def isfinite(x, /): - """Tests each element x_i of the input array x to determine if finite.""" - x, = _promote_dtypes("isfinite", x) - return jax.numpy.isfinite(x) - - -def isinf(x, /): - """Tests each element x_i of the input array x to determine if equal to positive or negative infinity.""" - x, = _promote_dtypes("isinf", x) - return jax.numpy.isinf(x) - - -def isnan(x, /): - """Tests each element x_i of the input array x to determine whether the element is NaN.""" - x, = _promote_dtypes("isnan", x) - return jax.numpy.isnan(x) - - -def less(x1, x2, /): - """Computes the truth value of x1_i < x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("less", x1, x2) - return jax.numpy.less(x1, x2) - - -def less_equal(x1, x2, /): - """Computes the truth value of x1_i <= x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("less_equal", x1, x2) - return jax.numpy.less_equal(x1, x2) - - -def log(x, /): - """Calculates an implementation-dependent approximation to the natural (base e) logarithm for each element x_i of the input array x.""" - x, = _promote_dtypes("log", x) - return jax.numpy.log(x) - - -def log1p(x, /): - """Calculates an implementation-dependent approximation to log(1+x), where log refers to the natural (base e) logarithm, for each element x_i of the input array x.""" - x, = _promote_dtypes("log", x) - return jax.numpy.log1p(x) - - -def log2(x, /): - """Calculates an implementation-dependent approximation to the base 2 logarithm for each element x_i of the input array x.""" - x, = _promote_dtypes("log2", x) - return jax.numpy.log2(x) - - -def log10(x, /): - """Calculates an implementation-dependent approximation to the base 10 logarithm for each element x_i of the input array x.""" - x, = _promote_dtypes("log10", x) - return jax.numpy.log10(x) - - -def logaddexp(x1, x2, /): - """Calculates the logarithm of the sum of exponentiations log(exp(x1) + exp(x2)) for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("logaddexp", x1, x2) - return jax.numpy.logaddexp(x1, x2) - - -def logical_and(x1, x2, /): - """Computes the logical AND for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("logical_and", x1, x2) - return jax.numpy.logical_and(x1, x2) - - -def logical_not(x, /): - """Computes the logical NOT for each element x_i of the input array x.""" - x, = _promote_dtypes("logical_not", x) - return jax.numpy.logical_not(x) - - -def logical_or(x1, x2, /): - """Computes the logical OR for each element x1_i of the input array x1 with the respective element x2_i of the input 
array x2.""" - x1, x2 = _promote_dtypes("logical_or", x1, x2) - return jax.numpy.logical_or(x1, x2) - - -def logical_xor(x1, x2, /): - """Computes the logical XOR for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("logical_xor", x1, x2) - return jax.numpy.logical_xor(x1, x2) - - -def multiply(x1, x2, /): - """Calculates the product for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("multiply", x1, x2) - return jax.numpy.multiply(x1, x2) - - -def negative(x, /): - """Computes the numerical negative of each element x_i (i.e., y_i = -x_i) of the input array x.""" - x, = _promote_dtypes("negative", x) - return jax.numpy.negative(x) - - -def not_equal(x1, x2, /): - """Computes the truth value of x1_i != x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("not_equal", x1, x2) - return jax.numpy.not_equal(x1, x2) - - -def positive(x, /): - """Computes the numerical positive of each element x_i (i.e., y_i = +x_i) of the input array x.""" - x, = _promote_dtypes("positive", x) - return x - - -def pow(x1, x2, /): - """Calculates an implementation-dependent approximation of exponentiation by raising each element x1_i (the base) of the input array x1 to the power of x2_i (the exponent), where x2_i is the corresponding element of the input array x2.""" - x1, x2 = _promote_dtypes("pow", x1, x2) - return jax.numpy.pow(x1, x2) - - -def real(x, /): - """Returns the real component of a complex number for each element x_i of the input array x.""" - x, = _promote_dtypes("real", x) - return jax.numpy.real(x) - - -def remainder(x1, x2, /): - """Returns the remainder of division for each element x1_i of the input array x1 and the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("remainder", x1, x2) - return jax.numpy.remainder(x1, x2) - - -def round(x, /): - """Rounds each element x_i of the input array x to the nearest integer-valued number.""" - x, = _promote_dtypes("round", x) - return jax.numpy.round(x) - - -def sign(x, /): - """Returns an indication of the sign of a number for each element x_i of the input array x.""" - x, = _promote_dtypes("sign", x) - if _isdtype(x.dtype, "complex floating"): - return x / abs(x) - return jax.numpy.sign(x) - - -def sin(x, /): - """Calculates an implementation-dependent approximation to the sine for each element x_i of the input array x.""" - x, = _promote_dtypes("sin", x) - return jax.numpy.sin(x) - - -def sinh(x, /): - """Calculates an implementation-dependent approximation to the hyperbolic sine for each element x_i of the input array x.""" - x, = _promote_dtypes("sin", x) - return jax.numpy.sinh(x) - - -def square(x, /): - """Squares each element x_i of the input array x.""" - x, = _promote_dtypes("square", x) - return jax.numpy.square(x) - - -def sqrt(x, /): - """Calculates the principal square root for each element x_i of the input array x.""" - x, = _promote_dtypes("sqrt", x) - return jax.numpy.sqrt(x) - - -def subtract(x1, x2, /): - """Calculates the difference for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("subtract", x1, x2) - return jax.numpy.subtract(x1, x2) - - -def tan(x, /): - """Calculates an implementation-dependent approximation to the tangent for each element x_i of the input array x.""" - x, = 
_promote_dtypes("tan", x) - return jax.numpy.tan(x) - - -def tanh(x, /): - """Calculates an implementation-dependent approximation to the hyperbolic tangent for each element x_i of the input array x.""" - x, = _promote_dtypes("tanh", x) - return jax.numpy.tanh(x) - - -def trunc(x, /): - """Rounds each element x_i of the input array x to the nearest integer-valued number that is closer to zero than x_i.""" - x, = _promote_dtypes("trunc", x) - if _isdtype(x.dtype, "integral"): - return x - return jax.numpy.trunc(x) + x, = promote_args("clip", x) + + if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)): + raise ValueError( + "Clip received a complex value either through the input or the min/max " + "keywords. Complex values have no ordering and cannot be clipped. " + "Please convert to a real value or array by taking the real or " + "imaginary components via jax.numpy.real/imag respectively." + ) + return jax.numpy.clip(x, min=min, max=max) + + +# TODO(micky774): Remove when jnp.hypot deprecation is completed +# (began 2024-4-14) and default behavior is Array API 2023 compliant +def hypot(x1, x2, /): + """Computes the square root of the sum of squares for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = promote_args("hypot", x1, x2) + + if issubdtype(x1.dtype, jax.numpy.complexfloating): + raise ValueError( + "hypot does not support complex-valued inputs. Please convert to real " + "values first, such as by using jnp.real or jnp.imag to take the real " + "or imaginary components respectively.") + return jax.numpy.hypot(x1, x2) diff --git a/jax/experimental/array_api/_fft_functions.py b/jax/experimental/array_api/_fft_functions.py index d1e737a424ac..9b51dd628484 100644 --- a/jax/experimental/array_api/_fft_functions.py +++ b/jax/experimental/array_api/_fft_functions.py @@ -14,47 +14,8 @@ import jax.numpy as jnp - -def fft(x, /, *, n=None, axis=-1, norm='backward'): - """Computes the one-dimensional discrete Fourier transform.""" - return jnp.fft.fft(x, n=n, axis=axis, norm=norm) - -def ifft(x, /, *, n=None, axis=-1, norm='backward'): - """Computes the one-dimensional inverse discrete Fourier transform.""" - return jnp.fft.ifft(x, n=n, axis=axis, norm=norm) - -def fftn(x, /, *, s=None, axes=None, norm='backward'): - """Computes the n-dimensional discrete Fourier transform.""" - return jnp.fft.fftn(x, s=s, axes=axes, norm=norm) - -def ifftn(x, /, *, s=None, axes=None, norm='backward'): - """Computes the n-dimensional inverse discrete Fourier transform.""" - return jnp.fft.ifftn(x, s=s, axes=axes, norm=norm) - -def rfft(x, /, *, n=None, axis=-1, norm='backward'): - """Computes the one-dimensional discrete Fourier transform for real-valued input.""" - return jnp.fft.rfft(x, n=n, axis=axis, norm=norm) - -def irfft(x, /, *, n=None, axis=-1, norm='backward'): - """Computes the one-dimensional inverse of rfft for complex-valued input.""" - return jnp.fft.irfft(x, n=n, axis=axis, norm=norm) - -def rfftn(x, /, *, s=None, axes=None, norm='backward'): - """Computes the n-dimensional discrete Fourier transform for real-valued input.""" - return jnp.fft.rfftn(x, s=s, axes=axes, norm=norm) - -def irfftn(x, /, *, s=None, axes=None, norm='backward'): - """Computes the n-dimensional inverse of rfftn for complex-valued input.""" - return jnp.fft.irfftn(x, s=s, axes=axes, norm=norm) - -def hfft(x, /, *, n=None, axis=-1, norm='backward'): - """Computes the one-dimensional discrete Fourier transform of a signal with Hermitian symmetry.""" - return 
jnp.fft.hfft(x, n=n, axis=axis, norm=norm) - -def ihfft(x, /, *, n=None, axis=-1, norm='backward'): - """Computes the one-dimensional inverse discrete Fourier transform of a signal with Hermitian symmetry.""" - return jnp.fft.ihfft(x, n=n, axis=axis, norm=norm) - +# TODO(micky774): Remove after adding device parameter to corresponding jnp.fft +# functions. def fftfreq(n, /, *, d=1.0, device=None): """Returns the discrete Fourier transform sample frequencies.""" return jnp.fft.fftfreq(n, d=d).to_device(device) @@ -62,11 +23,3 @@ def fftfreq(n, /, *, d=1.0, device=None): def rfftfreq(n, /, *, d=1.0, device=None): """Returns the discrete Fourier transform sample frequencies (for rfft and irfft).""" return jnp.fft.rfftfreq(n, d=d).to_device(device) - -def fftshift(x, /, *, axes=None): - """Shift the zero-frequency component to the center of the spectrum.""" - return jnp.fft.fftshift(x, axes=axes) - -def ifftshift(x, /, *, axes=None): - """Inverse of fftshift.""" - return jnp.fft.ifftshift(x, axes=axes) diff --git a/jax/experimental/array_api/_linear_algebra_functions.py b/jax/experimental/array_api/_linear_algebra_functions.py index 478af83092d7..cc488d721218 100644 --- a/jax/experimental/array_api/_linear_algebra_functions.py +++ b/jax/experimental/array_api/_linear_algebra_functions.py @@ -13,140 +13,16 @@ # limitations under the License. import jax -from jax.experimental.array_api._data_type_functions import ( - _promote_to_default_dtype, -) - -def cholesky(x, /, *, upper=False): - """ - Returns the lower (upper) Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix x. - """ - return jax.numpy.linalg.cholesky(x, upper=upper) - -def cross(x1, x2, /, *, axis=-1): - """ - Returns the cross product of 3-element vectors. - """ - return jax.numpy.linalg.cross(x1, x2, axis=axis) - -def det(x, /): - """ - Returns the determinant of a square matrix (or a stack of square matrices) x. - """ - return jax.numpy.linalg.det(x) - -def diagonal(x, /, *, offset=0): - """ - Returns the specified diagonals of a matrix (or a stack of matrices) x. - """ - return jax.numpy.linalg.diagonal(x, offset=offset) - -def eigh(x, /): - """ - Returns an eigenvalue decomposition of a complex Hermitian or real symmetric matrix (or a stack of matrices) x. - """ - return jax.numpy.linalg.eigh(x) - -def eigvalsh(x, /): - """ - Returns the eigenvalues of a complex Hermitian or real symmetric matrix (or a stack of matrices) x. - """ - return jax.numpy.linalg.eigvalsh(x) - -def inv(x, /): - """ - Returns the multiplicative inverse of a square matrix (or a stack of square matrices) x. - """ - return jax.numpy.linalg.inv(x) - -def matmul(x1, x2, /): - """Computes the matrix product.""" - return jax.numpy.linalg.matmul(x1, x2) - -def matrix_norm(x, /, *, keepdims=False, ord='fro'): - """ - Computes the matrix norm of a matrix (or a stack of matrices) x. - """ - return jax.numpy.linalg.matrix_norm(x, ord=ord, keepdims=keepdims) - -def matrix_power(x, n, /): - """ - Raises a square matrix (or a stack of square matrices) x to an integer power n. - """ - return jax.numpy.linalg.matrix_power(x, n) +# TODO(micky774): Remove after deprecation is completed (began 2024-5-14) def matrix_rank(x, /, *, rtol=None): """ Returns the rank (i.e., number of non-zero singular values) of a matrix (or a stack of matrices). 
""" - return jax.numpy.linalg.matrix_rank(x, tol=rtol) - -def matrix_transpose(x, /): - """Transposes a matrix (or a stack of matrices) x.""" - return jax.numpy.linalg.matrix_transpose(x) - -def outer(x1, x2, /): - """ - Returns the outer product of two vectors x1 and x2. - """ - return jax.numpy.linalg.outer(x1, x2) + return jax.numpy.linalg.matrix_rank(x, rtol) def pinv(x, /, *, rtol=None): """ Returns the (Moore-Penrose) pseudo-inverse of a matrix (or a stack of matrices) x. """ - return jax.numpy.linalg.pinv(x, rcond=rtol) - -def qr(x, /, *, mode='reduced'): - """ - Returns the QR decomposition of a full column rank matrix (or a stack of matrices). - """ - return jax.numpy.linalg.qr(x, mode=mode) - -def slogdet(x, /): - """ - Returns the sign and the natural logarithm of the absolute value of the determinant of a square matrix (or a stack of square matrices) x. - """ - return jax.numpy.linalg.slogdet(x) - -def solve(x1, x2, /): - """ - Returns the solution of a square system of linear equations with a unique solution. - """ - if x2.ndim == 1: - signature = "(m,m),(m)->(m)" - else: - signature = "(m,m),(m,n)->(m,n)" - return jax.numpy.vectorize(jax.numpy.linalg.solve, signature=signature)(x1, x2) - - -def svd(x, /, *, full_matrices=True): - """ - Returns a singular value decomposition (SVD) of a matrix (or a stack of matrices) x. - """ - return jax.numpy.linalg.svd(x, full_matrices=full_matrices) - -def svdvals(x, /): - """ - Returns the singular values of a matrix (or a stack of matrices) x. - """ - return jax.numpy.linalg.svdvals(x) - -def tensordot(x1, x2, /, *, axes=2): - """Returns a tensor contraction of x1 and x2 over specific axes.""" - return jax.numpy.linalg.tensordot(x1, x2, axes=axes) - -def trace(x, /, *, offset=0, dtype=None): - """ - Returns the sum along the specified diagonals of a matrix (or a stack of matrices) x. - """ - x = _promote_to_default_dtype(x) - return jax.numpy.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1) - -def vecdot(x1, x2, /, *, axis=-1): - """Computes the (vector) dot product of two arrays.""" - return jax.numpy.linalg.vecdot(x1, x2, axis=axis) - -def vector_norm(x, /, *, axis=None, keepdims=False, ord=2): - """Computes the vector norm of a vector (or batch of vectors) x.""" - return jax.numpy.linalg.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord) + return jax.numpy.linalg.pinv(x, rtol) diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py index bac3afd67472..c364b9f5b79c 100644 --- a/jax/experimental/array_api/_manipulation_functions.py +++ b/jax/experimental/array_api/_manipulation_functions.py @@ -16,59 +16,10 @@ import jax from jax import Array -from jax.experimental.array_api._data_type_functions import result_type as _result_type - - -def broadcast_arrays(*arrays: Array) -> list[Array]: - """Broadcasts one or more arrays against one another.""" - return jax.numpy.broadcast_arrays(*arrays) - - -def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array: - """Broadcasts an array to a specified shape.""" - return jax.numpy.broadcast_to(x, shape=shape) - - -def concat(arrays: tuple[Array, ...] 
| list[Array], /, *, axis: int | None = 0) -> Array: - """Joins a sequence of arrays along an existing axis.""" - dtype = _result_type(*arrays) - return jax.numpy.concat([arr.astype(dtype) for arr in arrays], axis=axis) - - -def expand_dims(x: Array, /, *, axis: int = 0) -> Array: - """Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by axis.""" - if axis < -x.ndim - 1 or axis > x.ndim: - raise IndexError(f"{axis=} is out of bounds for array of dimension {x.ndim}") - return jax.numpy.expand_dims(x, axis=axis) - - -def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array: - """Reverses the order of elements in an array along the given axis.""" - return jax.numpy.flip(x, axis=axis) - - -def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array: - """Permutes the axes (dimensions) of an array x.""" - return jax.numpy.permute_dims(x, axes=axes) +# TODO(micky774): Implement copy def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array: """Reshapes an array without changing its data.""" del copy # unused return jax.numpy.reshape(x, shape) - - -def roll(x: Array, /, shift: int | tuple[int, ...], *, axis: int | tuple[int, ...] | None = None) -> Array: - """Rolls array elements along a specified axis.""" - return jax.numpy.roll(x, shift=shift, axis=axis) - - -def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array: - """Removes singleton dimensions (axes) from x.""" - return jax.numpy.squeeze(x, axis=axis) - - -def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array: - """Joins a sequence of arrays along a new axis.""" - dtype = _result_type(*arrays) - return jax.numpy.stack(arrays, axis=axis, dtype=dtype) diff --git a/jax/experimental/array_api/_searching_functions.py b/jax/experimental/array_api/_searching_functions.py deleted file mode 100644 index 8357ae3eae86..000000000000 --- a/jax/experimental/array_api/_searching_functions.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
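A usage sketch of the trimmed linear-algebra wrappers earlier in this patch, where matrix_rank and pinv now forward the Array API rtol argument to jax.numpy.linalg instead of the deprecated tol/rcond names (import path assumed from this patch):

# Sketch: rtol controls which singular values are treated as zero.
import jax.numpy as jnp
from jax.experimental.array_api._linear_algebra_functions import matrix_rank, pinv

a = jnp.eye(4) * jnp.array([1.0, 1.0, 1e-8, 0.0])   # diag(1, 1, 1e-8, 0)
rank = matrix_rank(a, rtol=1e-6)                     # -> 2
a_pinv = pinv(a, rtol=1e-6)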
- -import jax -from jax.experimental.array_api._data_type_functions import result_type as _result_type - - -def argmax(x, /, *, axis=None, keepdims=False): - """Returns the indices of the maximum values along a specified axis.""" - return jax.numpy.argmax(x, axis=axis, keepdims=keepdims) - - -def argmin(x, /, *, axis=None, keepdims=False): - """Returns the indices of the minimum values along a specified axis.""" - return jax.numpy.argmin(x, axis=axis, keepdims=keepdims) - - -def nonzero(x, /): - """Returns the indices of the array elements which are non-zero.""" - if jax.numpy.ndim(x) == 0: - raise ValueError("inputs to nonzero() must have at least one dimension.") - return jax.numpy.nonzero(x) - - -def where(condition, x1, x2, /): - """Returns elements chosen from x1 or x2 depending on condition.""" - dtype = _result_type(x1, x2) - return jax.numpy.where(condition, x1.astype(dtype), x2.astype(dtype)) diff --git a/jax/experimental/array_api/_set_functions.py b/jax/experimental/array_api/_set_functions.py deleted file mode 100644 index c9f539d5ec06..000000000000 --- a/jax/experimental/array_api/_set_functions.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax - - -def unique_all(x, /): - """Returns the unique elements of an input array x, the first occurring indices for each unique element in x, the indices from the set of unique elements that reconstruct x, and the corresponding counts for each unique element in x.""" - return jax.numpy.unique_all(x) - - -def unique_counts(x, /): - """Returns the unique elements of an input array x and the corresponding counts for each unique element in x.""" - return jax.numpy.unique_counts(x) - - -def unique_inverse(x, /): - """Returns the unique elements of an input array x and the indices from the set of unique elements that reconstruct x.""" - return jax.numpy.unique_inverse(x) - - -def unique_values(x, /): - """Returns the unique elements of an input array x.""" - return jax.numpy.unique_values(x) diff --git a/jax/experimental/array_api/_sorting_functions.py b/jax/experimental/array_api/_sorting_functions.py deleted file mode 100644 index 4c64480d39a6..000000000000 --- a/jax/experimental/array_api/_sorting_functions.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
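The searching- and set-function wrappers deleted above were one-line passthroughs; the jax.numpy functions they called can be used directly, for example:

# Sketch: the deleted wrappers simply forwarded to these jax.numpy calls.
import jax.numpy as jnp

x = jnp.array([3, 1, 3, 2, 1])
values, counts = jnp.unique_counts(x)            # was unique_counts(x)
idx = jnp.argmax(x, axis=None, keepdims=False)   # was argmax(x)
nz = jnp.nonzero(x - 1)                          # was nonzero(x - 1)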
- -import jax -from jax import Array - - -def argsort(x: Array, /, *, axis: int = -1, descending: bool = False, - stable: bool = True) -> Array: - """Returns the indices that sort an array x along a specified axis.""" - return jax.numpy.argsort(x, axis=axis, descending=descending, stable=stable) - - -def sort(x: Array, /, *, axis: int = -1, descending: bool = False, - stable: bool = True) -> Array: - """Returns a sorted copy of an input array x.""" - return jax.numpy.sort(x, axis=axis, descending=descending, stable=stable) diff --git a/jax/experimental/array_api/_statistical_functions.py b/jax/experimental/array_api/_statistical_functions.py index 2e1333317605..8ee6a39198ee 100644 --- a/jax/experimental/array_api/_statistical_functions.py +++ b/jax/experimental/array_api/_statistical_functions.py @@ -13,43 +13,13 @@ # limitations under the License. import jax -from jax.experimental.array_api._data_type_functions import ( - _promote_to_default_dtype, -) - - -def max(x, /, *, axis=None, keepdims=False): - """Calculates the maximum value of the input array x.""" - return jax.numpy.max(x, axis=axis, keepdims=keepdims) - - -def mean(x, /, *, axis=None, keepdims=False): - """Calculates the arithmetic mean of the input array x.""" - return jax.numpy.mean(x, axis=axis, keepdims=keepdims) - - -def min(x, /, *, axis=None, keepdims=False): - """Calculates the minimum value of the input array x.""" - return jax.numpy.min(x, axis=axis, keepdims=keepdims) - - -def prod(x, /, *, axis=None, dtype=None, keepdims=False): - """Calculates the product of input array x elements.""" - x = _promote_to_default_dtype(x) - return jax.numpy.prod(x, axis=axis, dtype=dtype, keepdims=keepdims) def std(x, /, *, axis=None, correction=0.0, keepdims=False): """Calculates the standard deviation of the input array x.""" - return jax.numpy.std(x, axis=axis, ddof=correction, keepdims=keepdims) - - -def sum(x, /, *, axis=None, dtype=None, keepdims=False): - """Calculates the sum of the input array x.""" - x = _promote_to_default_dtype(x) - return jax.numpy.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) + return jax.numpy.std(x, axis=axis, correction=correction, keepdims=keepdims) def var(x, /, *, axis=None, correction=0.0, keepdims=False): """Calculates the variance of the input array x.""" - return jax.numpy.var(x, axis=axis, ddof=correction, keepdims=keepdims) + return jax.numpy.var(x, axis=axis, correction=correction, keepdims=keepdims) diff --git a/jax/experimental/array_api/_utility_functions.py b/jax/experimental/array_api/_utility_functions.py index 60d739277627..f75b2e2e29af 100644 --- a/jax/experimental/array_api/_utility_functions.py +++ b/jax/experimental/array_api/_utility_functions.py @@ -12,14 +12,74 @@ # See the License for the specific language governing permissions and # limitations under the License. 
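A sketch of the std/var change above: the wrappers now pass the Array API correction argument straight through to jax.numpy, which (per this patch) accepts correction in place of ddof.

# Sketch: correction is the Array API name for the ddof-style divisor offset.
import jax.numpy as jnp
from jax.experimental.array_api._statistical_functions import std, var

x = jnp.array([1.0, 2.0, 3.0, 4.0])
s_pop = std(x)                    # population std, correction=0
s_smp = std(x, correction=1.0)    # sample std, divides by N - 1
v_smp = var(x, correction=1.0)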
+from __future__ import annotations + import jax +from jax._src.sharding import Sharding +from jax._src.lib import xla_client as xc +from jax._src import dtypes as _dtypes, config + +# TODO(micky774): Add to jax.numpy.util when finalizing jax.experimental.array_api +# deprecation +class __array_namespace_info__: + + def __init__(self): + self._capabilities = { + "boolean indexing": True, + "data-dependent shapes": False, + } + + + def _build_dtype_dict(self): + array_api_types = { + "bool", "int8", "int16", + "int32", "uint8", "uint16", + "uint32", "float32", "complex64" + } + if config.enable_x64.value: + array_api_types |= {"int64", "uint64", "float64", "complex128"} + return {category: {t.name: t for t in types if t.name in array_api_types} + for category, types in _dtypes._dtype_kinds.items()} + + def default_device(self): + # By default JAX arrays are uncommitted (device=None), meaning that + # JAX is free to choose the most efficient device placement. + return None + def devices(self): + return jax.devices() -def all(x, /, *, axis=None, keepdims=False): - """Tests whether all input array elements evaluate to True along a specified axis.""" - return jax.numpy.all(x, axis=axis, keepdims=keepdims) + def capabilities(self): + return self._capabilities + def default_dtypes(self, *, device: xc.Device | Sharding | None = None): + # Array API supported dtypes are device-independent in JAX + del device + default_dtypes = { + "real floating": "f", + "complex floating": "c", + "integral": "i", + "indexing": "i", + } + return { + dtype_name: _dtypes.canonicalize_dtype( + _dtypes._default_types.get(kind) + ) for dtype_name, kind in default_dtypes.items() + } -def any(x, /, *, axis=None, keepdims=False): - """Tests whether any input array element evaluates to True along a specified axis.""" - return jax.numpy.any(x, axis=axis, keepdims=keepdims) + def dtypes( + self, *, + device: xc.Device | Sharding | None = None, + kind: str | tuple[str, ...] | None = None): + # Array API supported dtypes are device-independent in JAX + del device + data_types = self._build_dtype_dict() + if kind is None: + out_dict = data_types["numeric"] | data_types["bool"] + elif isinstance(kind, tuple): + out_dict = {} + for _kind in kind: + out_dict |= data_types[_kind] + else: + out_dict = data_types[kind] + return out_dict diff --git a/jax/experimental/array_api/_version.py b/jax/experimental/array_api/_version.py index 4936af86da4c..104df73c77b9 100644 --- a/jax/experimental/array_api/_version.py +++ b/jax/experimental/array_api/_version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__array_api_version__ = '2022.12' +__array_api_version__ = '2023.12' diff --git a/jax/experimental/array_api/fft.py b/jax/experimental/array_api/fft.py index f83d45401d20..bb3721118a47 100644 --- a/jax/experimental/array_api/fft.py +++ b/jax/experimental/array_api/fft.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
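A usage sketch of the new __array_namespace_info__ inspection class defined above; since the package-level re-export is still pending (see the test skips in this patch), the sketch instantiates the class from its private module.

# Sketch: Array API 2023.12 inspection API, taken from the private module.
from jax.experimental.array_api._utility_functions import __array_namespace_info__

info = __array_namespace_info__()
caps = info.capabilities()          # {'boolean indexing': True, 'data-dependent shapes': False}
dev = info.default_device()         # None: JAX leaves placement uncommitted by default
defaults = info.default_dtypes()    # e.g. {'real floating': dtype('float32'), ...}
ints_and_floats = info.dtypes(kind=("integral", "real floating"))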
-from jax.experimental.array_api._fft_functions import ( +from jax.numpy.fft import ( fft as fft, - fftfreq as fftfreq, fftn as fftn, fftshift as fftshift, hfft as hfft, @@ -25,6 +24,10 @@ irfft as irfft, irfftn as irfftn, rfft as rfft, - rfftfreq as rfftfreq, rfftn as rfftn, ) + +from jax.experimental.array_api._fft_functions import ( + fftfreq as fftfreq, + rfftfreq as rfftfreq, +) diff --git a/jax/experimental/array_api/linalg.py b/jax/experimental/array_api/linalg.py index 49c93c5b1908..6494884135fe 100644 --- a/jax/experimental/array_api/linalg.py +++ b/jax/experimental/array_api/linalg.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax.experimental.array_api._linear_algebra_functions import ( +from jax.numpy.linalg import ( cholesky as cholesky, cross as cross, det as det, @@ -23,17 +23,21 @@ matmul as matmul, matrix_norm as matrix_norm, matrix_power as matrix_power, - matrix_rank as matrix_rank, matrix_transpose as matrix_transpose, outer as outer, - pinv as pinv, qr as qr, slogdet as slogdet, solve as solve, svd as svd, svdvals as svdvals, tensordot as tensordot, - trace as trace, vecdot as vecdot, vector_norm as vector_norm, ) + +from jax.numpy.linalg import trace as trace + +from jax.experimental.array_api._linear_algebra_functions import ( + matrix_rank as matrix_rank, + pinv as pinv, +) diff --git a/jax/experimental/array_api/skips.txt b/jax/experimental/array_api/skips.txt index 76149272ae83..c865fabcfb55 100644 --- a/jax/experimental/array_api/skips.txt +++ b/jax/experimental/array_api/skips.txt @@ -1,11 +1,36 @@ # Known failures for the array api tests. # Test suite attempts in-place mutation: -array_api_tests/test_special_cases.py::test_binary array_api_tests/test_special_cases.py::test_iop array_api_tests/test_special_cases.py::test_nan_propagation -array_api_tests/test_special_cases.py::test_unary array_api_tests/test_array_object.py::test_setitem -# fft test suite is buggy as of 83f0bcdc -array_api_tests/test_fft.py +# Raises NonInteractiveExampleWarning +array_api_tests/test_special_cases.py::test_binary +array_api_tests/test_special_cases.py::test_unary + +# Pending implementation update for proper dtype promotion behavior, +# see https://github.com/data-apis/array-api-tests/issues/234 +array_api_tests/test_statistical_functions.py::test_sum +array_api_tests/test_statistical_functions.py::test_prod + +# Pending bugfix, see https://github.com/data-apis/array-api-tests/issues/256 +array_api_tests/test_signatures.py::test_func_signature[logical_and] +array_api_tests/test_signatures.py::test_func_signature[logical_or] +array_api_tests/test_signatures.py::test_func_signature[logical_xor] + +# Returns int32 when int64 is expected +array_api_tests/test_searching_functions.py::test_searchsorted + +# Various info functions not yet defined +# Pending bugfix, see https://github.com/data-apis/array-api-tests/pull/262 +array_api_tests/test_has_names.py::test_has_names[info-capabilities] +array_api_tests/test_has_names.py::test_has_names[info-default_device] +array_api_tests/test_has_names.py::test_has_names[info-default_dtypes] +array_api_tests/test_has_names.py::test_has_names[info-devices] +array_api_tests/test_has_names.py::test_has_names[info-dtypes] +array_api_tests/test_signatures.py::test_func_signature[capabilities] +array_api_tests/test_signatures.py::test_func_signature[default_device] +array_api_tests/test_signatures.py::test_func_signature[default_dtypes] 
+array_api_tests/test_signatures.py::test_func_signature[devices] +array_api_tests/test_signatures.py::test_func_signature[dtypes] diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index 347dc65d3055..c7aa8b590412 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -17,22 +17,22 @@ import abc import asyncio -from collections.abc import Awaitable, Sequence +from collections.abc import Awaitable, Callable, Sequence from functools import partial import itertools import logging import os import re -import sys import threading import time -from typing import Any, Callable, Optional, Union +from typing import Any import jax from jax._src import array from jax._src import distributed from jax._src import sharding from jax._src import sharding_impls +from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import typing from jax._src import util from jax._src.lib import xla_extension as xe @@ -62,9 +62,6 @@ class BarrierTimeoutException(Exception): "Suggestions for possible fixes:\n" "* Check the logs to see if one or more processes failed.\n" "* Make sure the training and checkpointing endpoints are close geographically.\n" - "* Try setting these environment variables: " - "`TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS=60` " - "`TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES=256` which will force a http retry\n" "* Try increasing the timeout you pass to GlobalAsyncCheckpointManager.") logger = logging.getLogger(__name__) @@ -72,12 +69,12 @@ class BarrierTimeoutException(Exception): async def create_async_array_from_callback( global_shape: array.Shape, - inp_sharding: sharding_impls.XLACompatibleSharding, + inp_sharding: jax.sharding.Sharding, data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]], ): device_to_index_map = inp_sharding.devices_indices_map(global_shape) addressable_da = inp_sharding._addressable_device_assignment - future_arrays = [data_callback(device_to_index_map[d], d) # type: ignore + future_arrays = [data_callback(device_to_index_map[d], d) for d in addressable_da] dbs = await asyncio.gather(*future_arrays) return array.make_array_from_single_device_arrays( @@ -85,19 +82,11 @@ async def create_async_array_from_callback( def _get_metadata(arr): - if arr.dtype == jnp.bfloat16: - # Tensorstore uses 'bfloat16', not ' bool: +def is_remote_storage(tspec: dict[str, Any] | str) -> bool: """Detect if user is using cloud storages. This can detect common defines and unable to detect some corner cases such as @@ -180,7 +169,7 @@ def __init__(self, num_bytes): self._cv = asyncio.Condition(lock=asyncio.Lock()) async def wait_for_bytes(self, requested_bytes): - if requested_bytes >= self._max_bytes: + if requested_bytes > self._max_bytes: raise ValueError('Requested more bytes than we reserved space for: ' f'{requested_bytes} > {self._max_bytes}') async with self._cv: @@ -200,9 +189,24 @@ async def async_serialize( tensorstore_spec, commit_future=None, context=TS_CONTEXT, - primary_host: Optional[int] = 0, + primary_host: int | None = 0, replica_id: int = 0, ): + """Serialize an array using TensorStore. + + Args: + arr_inp: The array to serialize. + tensorstore_spec: The tensorstore spec to use. + commit_future: A list of futures that will be appended to. The futures can + be awaited asynchronously. If None, the futures will be awaited + synchronously by this method. + context: ts.Context instance. 
+ primary_host: Primary host, which indicates the host that will be treated as + the "leader". If None, all hosts are treated as the primary. DO NOT USE + unless you are sure you know what you are doing. + replica_id: Allows overriding the shard replica id that will be saved. + DO NOT USE unless you are sure you know what you are doing. + """ if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and arr_inp.is_fully_addressable): raise ValueError( @@ -211,17 +215,15 @@ async def async_serialize( f'between processes. Serialization have failed for the array with ' f'the path "{tensorstore_spec["kvstore"]["path"]}".') - if primary_host is None and is_remote_storage(tensorstore_spec): - raise ValueError( - 'When primary_host is set to None and remote storage is used,' - ' serialization is not allowed, as this may lead to a race condition' - ' between processes.' - ) # 'metadata' may not be present at the top level (for example, if we are using # a 'cast' driver). if not _spec_has_metadata(tensorstore_spec): tensorstore_spec['metadata'] = _get_metadata(arr_inp) + # Set dtype if it's not in spec + if 'dtype' not in tensorstore_spec: + tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name + # If primary_host is None, all hosts will checkpoint. This is used # for checkpointing to local filesystem. if primary_host is None or jax.process_index() == primary_host: @@ -238,11 +240,12 @@ async def async_serialize( else: await open_future - # `ts.open` runs twice for process 0 because for the first time, we just get - # the future to be awaited upon in the background thread. The second one runs - # with `assume_metadata=True` which does no I/O operation and returns the - # tensorstore object. - # For every process other than `0`, we open with `assume_metadata=True`. + # `ts.open` runs twice for process `primary_host` because for the first time, + # we just get the future to be awaited upon in the background thread. The + # second one runs with `assume_metadata=True` which does no I/O operation and + # returns the tensorstore object. + # For every process other than `primary_host`, we open with + # `assume_metadata=True`. t = await ts.open( ts.Spec(tensorstore_spec), open=True, @@ -260,10 +263,7 @@ async def _write_array(shard): else: await write_future.commit - if isinstance(arr_inp, array.ArrayImpl): - local_shards = arr_inp.addressable_shards - else: - local_shards = arr_inp.addressable_shards + local_shards = arr_inp.addressable_shards future_write_state = jax.tree_util.tree_map(_write_array, local_shards) return await asyncio.gather(*future_write_state) @@ -309,7 +309,7 @@ def estimate_read_memory_footprint(t: ts.TensorStore, async def async_deserialize( - in_sharding: sharding_impls.XLACompatibleSharding, + user_in_sharding: jax.sharding.Sharding | Layout, tensorstore_spec: ts.Spec | dict[str, Any], global_shape: Sequence[int] | None = None, dtype=None, @@ -317,6 +317,14 @@ async def async_deserialize( context=TS_CONTEXT, assume_metadata: bool = False, ): + in_sharding = (user_in_sharding.sharding + if isinstance(user_in_sharding, Layout) else user_in_sharding) + if not isinstance(in_sharding, jax.sharding.Sharding): + raise ValueError( + 'sharding passed to deserialization should be specified, concrete and' + f' an instance of `jax.sharding.Sharding`. 
Got {in_sharding}') + dll = (user_in_sharding.device_local_layout + if isinstance(user_in_sharding, Layout) else None) t = await ts.open( tensorstore_spec, open=True, @@ -343,7 +351,14 @@ async def cb(index: array.Index, device: jax.Device): # Cast while reloading on process to avoid 2 copies on device if the # casting is done on device. out = out.astype(dtype) - result = jax.device_put(out, device) + # Convert to jnp array so that layouts are initialized properly for + # sub-byte dtypes. + # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to + # make this work. + if out.dtype == jnp.int4: + out = jnp.asarray(out) # type: ignore + result = jax.device_put( + out, Layout(dll, jax.sharding.SingleDeviceSharding(device))) if byte_limiter is not None: # NB: `out` actually might not be ready for garbage collection by the # time we call release_bytes . Thus peak memory usage still might grow @@ -361,7 +376,7 @@ async def cb(index: array.Index, device: jax.Device): return await create_async_array_from_callback(tuple(shape), in_sharding, cb) -def run_deserialization(shardings: Sequence[sharding.Sharding], +def run_deserialization(shardings: Sequence[sharding.Sharding | Layout], tensorstore_specs: Sequence[dict[str, Any]], global_shapes: Sequence[array.Shape] | None = None, dtypes: Sequence[typing.DTypeLike] | None = None, @@ -396,7 +411,7 @@ class GlobalAsyncCheckpointManagerBase(util.StrictABC): is finished, checkpoint for step 2 will need to be blocked. Maintaining a class allows to maintain that state. - Example: + Examples: Below is a simplified training loop: @@ -599,7 +614,7 @@ def serialize_with_paths(self, arrays: Sequence[jax.Array], tspecs = jax.tree.map(get_tensorstore_spec, paths) self.serialize(arrays, tspecs, on_commit_callback=on_commit_callback) - def deserialize(self, shardings: Sequence[sharding.Sharding], + def deserialize(self, shardings: Sequence[sharding.Sharding | Layout], tensorstore_specs: Sequence[dict[str, Any]], global_shapes: Sequence[array.Shape] | None = None, dtypes: Sequence[typing.DTypeLike] | None = None, diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 6a0828968ea6..ccf2d05467a9 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -14,6 +14,7 @@ """Tests for serialization and deserialization of GDA.""" import asyncio +import contextlib import math from functools import partial import os @@ -23,28 +24,25 @@ from absl.testing import absltest from absl.testing import parameterized import jax +import jax.numpy as jnp from jax._src import test_util as jtu -from jax import config from jax._src import array +from jax._src import xla_bridge as xb from jax.sharding import NamedSharding, GSPMDSharding from jax.sharding import PartitionSpec as P from jax.experimental.array_serialization import serialization +from jax.experimental.layout import Layout, DeviceLocalLayout as DLL import numpy as np import tensorstore as ts -import unittest -config.parse_flags_with_absl() - -prev_xla_flags = None +jax.config.parse_flags_with_absl() +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(8) + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) -# Reset to previous configuration in case other test modules will be run. 
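A condensed sketch of the serialize/deserialize round trip that the tests below exercise, using only calls that appear in this patch; the checkpoint location and single-device mesh are illustrative.

# Sketch: checkpoint a sharded array via TensorStore, then read it back with an
# explicit sharding, shape, and dtype, mirroring the test pattern in this patch.
import tempfile
import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P
from jax.experimental.array_serialization import serialization

mesh = jax.sharding.Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np.arange(16, dtype=np.int32).reshape(8, 2), sharding)

ckpt_path = tempfile.mkdtemp()                      # illustrative location
tspecs = [serialization.get_tensorstore_spec(ckpt_path)]

manager = serialization.GlobalAsyncCheckpointManager()
manager.serialize([arr], tspecs, on_commit_callback=lambda: None)
manager.wait_until_finished()

restored, = serialization.run_deserialization(
    shardings=[sharding],
    tensorstore_specs=tspecs,
    global_shapes=[(8, 2)],
    dtypes=[np.int32],
)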
def tearDownModule(): - prev_xla_flags() + _exit_stack.close() class CheckpointTest(jtu.JaxTestCase): @@ -59,7 +57,7 @@ def test_memory_consumption(self): pspec = P('x', 'y') num = math.prod(inp_shape) sharding = NamedSharding(global_mesh, pspec) - src = jax.numpy.arange(num, dtype=np.int32).reshape(inp_shape) # 8e9 + src = jnp.arange(num, dtype=np.int32).reshape(inp_shape) # 8e9 inp = array.make_array_from_callback( inp_shape, sharding, lambda idx: src[idx]) @@ -198,12 +196,15 @@ def cb3(_): self.assertArraysEqual(np.asarray(s.data), np.array([], dtype=np.float32)) self.assertEqual(m3.dtype, np.float32) - def test_checkpointing_with_bigger_shape_jax_array(self): + @parameterized.product(input_dtype=[np.int32, jnp.bfloat16]) + def test_checkpointing_with_bigger_shape_jax_array(self, input_dtype): global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) global_input_shape = (8, 2) num = math.prod(global_input_shape) - global_input_data1 = np.arange(num, dtype=np.int32).reshape(global_input_shape) + global_input_data1 = np.arange(num, dtype=input_dtype).reshape( + global_input_shape + ) def cb1(index): return global_input_data1[index] arr = array.make_array_from_callback( @@ -243,6 +244,56 @@ def cb1(index): for l in m2.addressable_shards: self.assertArraysEqual(l.data, global_input_data1.astype('float32')) + @parameterized.product(input_dtype=[jnp.int4, jnp.int8]) + def test_checkpointing_with_int4(self, input_dtype): + global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + global_input_shape = (8, 2) + num = math.prod(global_input_shape) + + global_input_data = np.arange(num, dtype=input_dtype).reshape( + global_input_shape + ) + def cb(index): + return global_input_data[index] + arr = array.make_array_from_callback( + global_input_shape, NamedSharding(global_mesh, P('x', 'y')), cb) + ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path) + + ckpt_paths = [str(ckpt_dir)] + tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + + manager = serialization.GlobalAsyncCheckpointManager() + manager.serialize( + [arr], tspecs, + on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + manager.wait_until_finished() + + ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y')) + + target_dtype = jnp.dtype('int4') + m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)], + [target_dtype]) + + # values bigger than 7 are converted properly. 
+ expected_data = { + 0: jnp.array([[0], [2], [4]], dtype=target_dtype), + 1: jnp.array([[1], [3], [5]], dtype=target_dtype), + 2: jnp.array([[6], [8], [10]], dtype=target_dtype), + 3: jnp.array([[7], [9], [11]], dtype=target_dtype), + 4: jnp.array([[12], [14], [0]], dtype=target_dtype), + 5: jnp.array([[13], [15], [0]], dtype=target_dtype), + 6: jnp.array([[0], [0], [0]], dtype=target_dtype), + 7: jnp.array([[0], [0], [0]], dtype=target_dtype), + } + + for l in m1.addressable_shards: + self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id]) + + new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat)) + m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [target_dtype]) + for l in m2.addressable_shards: + self.assertArraysEqual(l.data, global_input_data.astype(target_dtype)) + def test_checkpointing_scalar_jax_array(self): global_mesh = jtu.create_global_mesh((2,), ('x')) global_input_shape = () @@ -366,5 +417,69 @@ def test_maybe_cloud_storage(self): } self.assertTrue(serialization.is_remote_storage(nested_tspec)) + def test_load_with_layout(self): + if not jtu.test_device_matches(['tpu']): + self.skipTest('Layouts are only supported on TPUs') + + mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + np_inp = np.arange(32).reshape(8, 4) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower( + arr).compile().output_layouts() + self.assertEqual(arr.layout.device_local_layout.major_to_minor, + out_layout.device_local_layout.major_to_minor[::-1]) + + ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path) + ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path) + tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, [ckpt_path]) + + manager = serialization.GlobalAsyncCheckpointManager() + manager.serialize( + [arr], tspecs, + on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + manager.wait_until_finished() + + out, = serialization.run_deserialization([out_layout], tspecs) + + self.assertEqual(out.layout, out_layout) + self.assertIsInstance(out, array.ArrayImpl) + self.assertArraysEqual(out, np_inp) + for s in out.addressable_shards: + self.assertArraysEqual(s.data, np_inp[s.index]) + + def test_deserialization_with_int4(self): + dtype = jnp.int4 + shape = (8, 2) + arr = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype) + + ckpt_dir = pathlib.Path(self.create_tempdir('test_ckpt').full_path) + + # Run serialization. + sharding = jax.sharding.GSPMDSharding.get_replicated(jax.devices()) + tspecs = jax.tree_util.tree_map( + serialization.get_tensorstore_spec, [ckpt_dir] + ) + manager = serialization.GlobalAsyncCheckpointManager() + manager.serialize( + [arr], + tspecs, + on_commit_callback=lambda: None, + ) + manager.wait_until_finished() + + # Run deserialization. 
+ deserialized_arr, = serialization.run_deserialization( + shardings=[sharding], + tensorstore_specs=tspecs, + global_shapes=[shape], + dtypes=[dtype], + ) + + out = deserialized_arr.astype(jnp.int8) # doesn't crash + self.assertEqual(out.dtype, jnp.int8) + self.assertArraysEqual(out + out, out * 2) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index ac076a4d57ea..8176465c1470 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -14,6 +14,7 @@ from __future__ import annotations +from contextlib import contextmanager from typing import Any from jax._src import core @@ -30,59 +31,64 @@ zip, unsafe_zip = safe_zip, zip JaxVal = Any +Pytree = Any register = api_util.register_class_with_attrs -class GetAttrPrimitive(core.Primitive): - def bind_with_trace(self, trace, args, params): - () = args - return trace.process_getattr(**params) -getattr_p = GetAttrPrimitive('getattr') - -class SetAttrPrimitive(core.Primitive): - def bind_with_trace(self, trace, args, params): - val, = args - return trace.process_setattr(trace.full_raise(val), **params) -setattr_p = SetAttrPrimitive('setattr') +@contextmanager +def top_trace(): + stack = core.thread_local_state.trace_state.trace_stack.stack + main = stack.pop() + try: + trace = main.with_cur_sublevel() + yield trace + finally: + stack.append(main) def jax_getattr(obj: Any, attr: str): - return getattr_p.bind(obj=obj, attr=attr) - -def jax_setattr(obj: Any, attr: str, val: JaxVal): - setattr_p.bind(val, obj=obj, attr=attr) + with top_trace() as trace: + return trace.process_getattr(obj, attr) +def jax_setattr(obj: Any, attr: str, val: Pytree): + with top_trace() as trace: + return trace.process_setattr(obj, attr, val) -def _getattr_impl(_, *, obj, attr): +def _getattr_impl(_, obj, attr): return getattr(obj, attr) core.EvalTrace.process_getattr = _getattr_impl -def _setattr_impl(_, val, *, obj, attr): +def _setattr_impl(_, obj, attr, val): setattr(obj, attr, val) core.EvalTrace.process_setattr = _setattr_impl - def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): frame = trace.main.jaxpr_stack[-1] # type: ignore - if (obj, attr) not in frame.attrs_tracked: - init_val = getattr(obj, attr) - aval = core.raise_to_shaped(core.get_aval(init_val)) + + def new_tracer(x): + aval = core.raise_to_shaped(core.get_aval(x)) tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current()) var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval) - setattr(obj, attr, tracer) - frame.attrs_tracked.append((obj, attr)) - frame.attrs_inits.append(init_val) frame.attrs_vars.append(var) frame.tracers.append(tracer) + return tracer + + if (obj, attr) not in frame.attrs_tracked: + init_val = getattr(obj, attr) + frame.attrs_inits.append(init_val) + init_vals, init_tree = tree_flatten(init_val) + tracers = map(new_tracer, init_vals) + setattr(obj, attr, tree_unflatten(init_tree, tracers)) + frame.attrs_tracked.append((obj, attr)) pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked -def _getattr_staging(trace, *, obj, attr): +def _getattr_staging(trace, obj, attr): trace._ensure_tracked(obj, attr) return getattr(obj, attr) pe.DynamicJaxprTrace.process_getattr = _getattr_staging -def _setattr_staging(trace, tracer, *, obj, attr): +def _setattr_staging(trace, obj, attr, val): trace._ensure_tracked(obj, attr) - setattr(obj, attr, tracer) + setattr(obj, attr, val) pe.DynamicJaxprTrace.process_setattr = _setattr_staging @@ -134,12 +140,19 @@ 
def jvp_subtrace2(main, primals, tangents): del main.attrs_tracked yield out_primals, out_tangents, tangent_attrs_out -def _setattr_jvp(trace, tracer, *, obj, attr): +def _setattr_jvp(trace, obj, attr, maybe_tracer): + tracer = trace.full_raise(maybe_tracer) + if isinstance(tracer.tangent, ad.Zero): + return setattr(obj, attr, tracer.primal) if (obj, attr) not in trace.main.attrs_tracked: trace.main.attrs_tracked.append((obj, attr)) - setattr(obj, attr, tracer) + return setattr(obj, attr, tracer) ad.JVPTrace.process_setattr = _setattr_jvp +def _getattr_jvp(trace, obj, attr): + return getattr(obj, attr) +ad.JVPTrace.process_getattr = _getattr_jvp + def linearize(f, *primals, attrs: list[tuple[Any, str]] = []): attr_primals = [jax_getattr(o, a) for o, a in attrs] diff --git a/jax/experimental/compute_on.py b/jax/experimental/compute_on.py new file mode 100644 index 000000000000..dac3540c10c5 --- /dev/null +++ b/jax/experimental/compute_on.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the ific language governing permissions and +# limitations under the License. + +from jax._src.compute_on import ( + compute_on as compute_on, +) diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index c71eaf19c095..aa138fe88993 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -14,6 +14,7 @@ from __future__ import annotations +from functools import partial import inspect from typing import Optional import weakref @@ -32,8 +33,8 @@ from jax._src.interpreters import partial_eval as pe from jax._src.sharding_impls import _op_sharding_to_pos_sharding from jax._src import custom_api_util +from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.api_util import flatten_fun_nokwargs, argnums_partial @@ -132,10 +133,8 @@ def _custom_partitioning_propagate_user_sharding(user_sharding, shape, def _to_hlo_sharding(sharding, num_dimensions): - if not isinstance(sharding, jax.sharding.XLACompatibleSharding): - raise ValueError( - "Custom Partitioning rules must return XLACompatibleShardings." 
- ) + if not isinstance(sharding, jax.sharding.Sharding): + raise ValueError("Custom Partitioning rules must return Sharding.") return sharding._to_xla_hlo_sharding(num_dimensions) @@ -181,18 +180,15 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, % (repr(closed_jaxpr.out_avals), repr(tiled_results)) ) axis_context = sharding_impls.SPMDAxisContext(mesh) - module = mlir.build_mlir_module_helper( - closed_jaxpr, - name="tmp_xla_computation", - platforms=module_context.platforms, - backend_or_name=module_context.backend_or_name, - axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)), - ) + with core.extend_axis_env_nd(mesh.shape.items()): + module = mlir.build_mlir_module_helper( + closed_jaxpr, + name="tmp_xla_computation", + platforms=module_context.platforms, + backend_or_name=module_context.backend_or_name, + axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)), + ) result_sharding = _pack_result_sharding(result_shape, result_shardings) - if xla_extension_version < 232: - built = xc._xla.mlir.mlir_module_to_xla_computation( - mlir.module_to_string(module), use_tuple_args=False, return_tuple=False) - return built, arg_shardings, result_sharding return mlir.module_to_bytecode(module), arg_shardings, result_sharding @@ -303,7 +299,7 @@ def infer_sharding_from_operands(mesh, arg_shapes, shape): Positional arguments can be specified as static using static_argnums. JAX uses :code:`inspect.signature(fun)` to resolve these positional arguments. - Example: + Examples: As an example, assume we want to enhance the existing ``jax.numpy.fft.fft``. This function computes the discrete Fourier transform of an N-dimensional input along the last dimension, and is batched @@ -518,8 +514,6 @@ def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim): partition, to_mesh_pspec_sharding, in_tree, out_tree, infer_sharding_from_operands, ctx.module_context, mesh, static_args) key = str(id(sharding_callback_info)) - # TODO(parkers): Remove bytes registration when xla_extension_version > 211 - _sharding_callbacks[key] = sharding_callback_info _sharding_callbacks[bytes(key, 'utf8')] = sharding_callback_info # We need to make sure `sharding_callback_info` is still alive when the SPMD # partitioner runs so we keep it alive by attaching it to the executable. @@ -541,8 +535,18 @@ def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim): mlir.register_lowering(custom_partitioning_p, _custom_partitioning_lowering_rule) -xc.register_custom_call_partitioner( # pytype: disable=module-attr +xc.register_custom_call_partitioner( _CUSTOM_PARTITIONING_CALL_NAME, _custom_partitioning_propagate_user_sharding, _custom_partitioning_partition, - _custom_partitioning_infer_sharding_from_operands, True) + _custom_partitioning_infer_sharding_from_operands, True) # type: ignore +xb.register_plugin_callbacks( + partial( + xc.register_custom_call_partitioner, + name=_CUSTOM_PARTITIONING_CALL_NAME, + prop_user_sharding=_custom_partitioning_propagate_user_sharding, + partition=_custom_partitioning_partition, + infer_sharding_from_operands=_custom_partitioning_infer_sharding_from_operands, + can_side_effecting_have_replicated_sharding=True, + ) +) diff --git a/jax/experimental/export/BUILD b/jax/experimental/export/BUILD index 436b89c8bee9..1246b0d407af 100644 --- a/jax/experimental/export/BUILD +++ b/jax/experimental/export/BUILD @@ -14,11 +14,11 @@ # JAX-export provides APIs for exporting StableHLO for serialization purposes. 
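The _to_hlo_sharding change above relaxes the check so that a custom-partitioning rule may return any jax.sharding.Sharding, not only the removed XLACompatibleSharding. A minimal sketch of the kind of sharding object such a rule can now hand back; the one-device mesh and axis name are illustrative and not part of the patch:

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Build a trivial mesh over whatever devices are available (illustrative).
    mesh = Mesh(np.array(jax.devices()[:1]), ("x",))
    sharding = NamedSharding(mesh, P("x"))   # any jax.sharding.Sharding is accepted
    hlo = sharding._to_xla_hlo_sharding(2)   # what _to_hlo_sharding extracts
    print(type(hlo))                         # an XLA HloSharding
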
+load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "py_deps", ) -load("@rules_python//python:defs.bzl", "py_library") licenses(["notice"]) @@ -31,11 +31,6 @@ py_library( name = "export", srcs = [ "__init__.py", - "_export.py", - "_serialization.py", - "_shape_poly.py", - "_shape_poly_decision.py", - "serialization_generated.py", ], srcs_version = "PY3", # TODO: b/255503696: enable pytype diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py index ac4fd199e37f..b67354bb4248 100644 --- a/jax/experimental/export/__init__.py +++ b/jax/experimental/export/__init__.py @@ -13,26 +13,61 @@ # limitations under the License. # ============================================================================== -from jax.experimental.export._export import ( - minimum_supported_serialization_version, - maximum_supported_serialization_version, - Exported, - export, - call_exported, # TODO: deprecate - call, - DisabledSafetyCheck, - default_lowering_platform, - - args_specs, # TODO: deprecate +_deprecation_message = ( + "The jax.experimental.export module is deprecated. " + "Use jax.export instead. " + "See the migration guide at https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export." ) -from jax.experimental.export._shape_poly import ( - is_symbolic_dim, - symbolic_shape, - symbolic_args_specs, - SymbolicScope, -) -from jax.experimental.export._serialization import ( - serialize, - deserialize, -) -from jax.experimental.export import _shape_poly_decision + +from jax._src.export import _export as _src_export +from jax._src.export import shape_poly as _src_shape_poly +from jax._src.export import serialization as _src_serialization +# Import only to set the shape poly decision procedure +from jax._src.export import shape_poly_decision +del shape_poly_decision + +# All deprecations added Jun 14, 2024 +_deprecations = { + # Added Jun 13, 2024 + "Exported": (_deprecation_message, _src_export.Exported), + "DisabledSafetyCheck": (_deprecation_message, _src_export.DisabledSafetyCheck), + "export": (_deprecation_message, _src_export.export_back_compat), + "call": (_deprecation_message, _src_export.call), + "call_exported": (_deprecation_message, _src_export.call_exported), + "default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform), + "minimum_supported_serialization_version" : (_deprecation_message, _src_export.minimum_supported_calling_convention_version), + "maximum_supported_serialization_version" : (_deprecation_message, _src_export.maximum_supported_calling_convention_version), + + "serialize": (_deprecation_message, _src_serialization.serialize), + "deserialize": (_deprecation_message, _src_serialization.deserialize), + + "SymbolicScope": (_deprecation_message, _src_shape_poly.SymbolicScope), + "is_symbolic_dim": (_deprecation_message, _src_shape_poly.is_symbolic_dim), + "symbolic_shape": (_deprecation_message, _src_shape_poly.symbolic_shape), + "symbolic_args_specs": (_deprecation_message, _src_shape_poly.symbolic_args_specs), +} + +import typing +if typing.TYPE_CHECKING: + Exported = _src_export.Exported + DisabledSafetyCheck = _src_export.DisabledSafetyCheck + export = _src_export.export_back_compat + call = _src_export.call + call_exported = _src_export.call_exported + default_lowering_platform = _src_export.default_lowering_platform + + serialize = _src_serialization.serialize + deserialize = _src_serialization.deserialize + + SymbolicScope = 
_src_shape_poly.SymbolicScope + is_symbolic_dim = _src_shape_poly.is_symbolic_dim + symbolic_shape = _src_shape_poly.symbolic_shape + symbolic_args_specs = _src_shape_poly.symbolic_args_specs +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing +del _src_export +del _src_serialization +del _src_shape_poly diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index b4ab6628a55c..43e9813d7fac 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -13,7 +13,11 @@ # limitations under the License. """Primitives for calling Python functions on the host from JAX accelerator code. -**Experimental: please give feedback, and expect changes.** +.. warning:: + The host_callback APIs are deprecated as of March 20, 2024. + The functionality is subsumed by the + `new JAX external callbacks `_ + See https://github.com/google/jax/issues/20385. This module introduces the host callback functions :func:`call`, :func:`id_tap`, and :func:`id_print`, that send their arguments from the device @@ -498,15 +502,17 @@ def power3_with_cotangents(x): from __future__ import annotations import atexit -from collections.abc import Sequence +import enum +from collections.abc import Callable, Sequence import functools import itertools import logging import math import threading import traceback -from typing import Any, Callable, Optional, cast +from typing import Any, cast +import jax from jax._src import api from jax._src import core from jax._src import config @@ -514,6 +520,7 @@ def power3_with_cotangents(x): from jax._src import dtypes from jax import lax from jax.experimental import pjit +from jax.experimental import io_callback from jax._src.interpreters import ad, batching, pxla from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe @@ -534,12 +541,12 @@ def power3_with_cotangents(x): import numpy as np -_HOST_CALLBACK_INLINE = config.DEFINE_bool( +_HOST_CALLBACK_INLINE = config.bool_flag( 'jax_host_callback_inline', config.bool_env('JAX_HOST_CALLBACK_INLINE', False), help='Inline the host_callback, if not in a staged context.' ) -_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.DEFINE_integer( +_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.int_flag( 'jax_host_callback_max_queue_byte_size', config.int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)), help=('The size in bytes of the buffer used to hold outfeeds from each ' @@ -548,7 +555,7 @@ def power3_with_cotangents(x): 'until the Python callback consume more outfeeds.'), lower_bound=int(16 * 1e6) ) -_HOST_CALLBACK_OUTFEED = config.DEFINE_bool( +_HOST_CALLBACK_OUTFEED = config.bool_flag( 'jax_host_callback_outfeed', config.bool_env('JAX_HOST_CALLBACK_OUTFEED', False), help=( @@ -557,6 +564,15 @@ def power3_with_cotangents(x): 'Has no effect on TPU, since only the outfeed mechanism is implemented.' ) ) +_HOST_CALLBACK_LEGACY = config.bool_flag( + 'jax_host_callback_legacy', + config.bool_env('JAX_HOST_CALLBACK_LEGACY', True), + help=( + 'Use old implementation of host_callback, documented in the module docstring.' + 'If False, use the jax.experimental.io_callback implementation. ' + 'See https://github.com/google/jax/issues/20385.' 
+ ) +) logger = logging.getLogger(__name__) @@ -570,9 +586,8 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend): """Should be called whenever outfeed (or infeed) will be used.""" if xb.using_pjrt_c_api(backend): raise NotImplementedError( - "host_callback functionality isn't supported with the new Cloud TPU " - "runtime. See https://jax.readthedocs.io/en/latest/debugging/index.html" - " and " + "host_callback functionality isn't supported with PJRT C API. " + "See https://jax.readthedocs.io/en/latest/debugging/index.html and " "https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html" " for alternatives. Please file a feature request at " "https://github.com/google/jax/issues if none of the alternatives are " @@ -588,17 +603,31 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend): XlaLocalClient = xla_client.Client DType = Any +class CallbackFlavor(enum.Enum): + """Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False. + + See https://github.com/google/jax/issues/20385. + """ + IO_CALLBACK = 1 # uses jax.experimental.io_callback + PURE = 2 # uses jax.pure_callback + DEBUG = 3 # uses jax.debug.callback, valid only when there are no results + -def id_tap(tap_func, +def _deprecated_id_tap(tap_func, arg, *, result=None, tap_with_device=False, device_index=0, + callback_flavor=CallbackFlavor.IO_CALLBACK, **kwargs): """Host-callback tap primitive, like identity function with a call to ``tap_func``. - **Experimental: please give feedback, and expect changes!** + .. warning:: + The host_callback APIs are deprecated as of March 20, 2024. + The functionality is subsumed by the + `new JAX external callbacks `_ + See https://github.com/google/jax/issues/20385. ``id_tap`` behaves semantically like the identity function but has the side-effect that a user-defined Python function is called with the runtime @@ -622,6 +651,9 @@ def id_tap(tap_func, device_index: specifies from which device the tap function is invoked in a SPMD program. Works only when using the outfeed implementation mechanism, i.e., does not work on CPU unless --jax_host_callback_outfeed=True. + callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies + the flavor of callback to use. + See https://github.com/google/jax/issues/20385. Returns: ``arg``, or ``result`` if given. @@ -654,7 +686,8 @@ def id_tap(tap_func, call_with_device=tap_with_device, result_shape=None, identity=True, - device_index=device_index) + device_index=device_index, + callback_flavor=callback_flavor) if result is not None: return result @@ -662,17 +695,22 @@ def id_tap(tap_func, return call_res -def id_print(arg, +def _deprecated_id_print(arg, *, result=None, tap_with_device=False, device_index=0, output_stream=None, threshold=None, + callback_flavor=CallbackFlavor.IO_CALLBACK, **kwargs): """Like :func:`id_tap` with a printing tap function. - **Experimental: please give feedback, and expect changes!** + .. warning:: + The host_callback APIs are deprecated as of March 20, 2024. + The functionality is subsumed by the + `new JAX external callbacks `_ + See https://github.com/google/jax/issues/20385. On each invocation of the printing tap, the ``kwargs`` if present will be printed first (sorted by keys). Then arg will be printed, @@ -688,27 +726,36 @@ def id_print(arg, built-in ``print``. The string will be passed as ``output_stream.write(s)``. * ``threshold`` is passed to ``numpy.array2string``. 
+ * ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies + the flavor of callback to use. + See https://github.com/google/jax/issues/20385. For more details see the :mod:`jax.experimental.host_callback` module documentation. """ printer = functools.partial(_print_tap_func, output_stream=output_stream, threshold=threshold, **kwargs) - return id_tap( + return _deprecated_id_tap( printer, arg, result=result, tap_with_device=tap_with_device, - device_index=device_index) + device_index=device_index, + callback_flavor=callback_flavor) -def call(callback_func: Callable, arg, *, +def _deprecated_call(callback_func: Callable, arg, *, result_shape=None, call_with_device=False, - device_index=0): + device_index=0, + callback_flavor=CallbackFlavor.IO_CALLBACK): """Make a call to the host, and expect a result. - **Experimental: please give feedback, and expect changes!** + .. warning:: + The host_callback APIs are deprecated as of March 20, 2024. + The functionality is subsumed by the + `new JAX external callbacks `_ + See https://github.com/google/jax/issues/20385. Args: callback_func: The Python function to invoke on the host as @@ -736,14 +783,26 @@ def call(callback_func: Callable, arg, *, device_index: specifies from which device the tap function is invoked in a SPMD program. Works only when using the outfeed implementation mechanism, i.e., does not work on CPU unless --jax_host_callback_outfeed=True. + callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies + the flavor of callback to use. + See https://github.com/google/jax/issues/20385. + Returns: the result of the ``callback_func`` invocation. For more details see the :mod:`jax.experimental.host_callback` module documentation. """ + if (not _HOST_CALLBACK_LEGACY.value and + callback_flavor is CallbackFlavor.DEBUG and + result_shape is not None): + raise NotImplementedError( + "When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` " + "flavor of callback only when the `result_shape` is None. " + "See https://github.com/google/jax/issues/20385." + ) return _call(callback_func, arg, result_shape=result_shape, call_with_device=call_with_device, identity=False, - device_index=device_index) + device_index=device_index, callback_flavor=callback_flavor) # We need the wrapper function to have hash and equality defined since it is @@ -754,6 +813,11 @@ def __init__(self, callback_func, identity, call_with_device): self.callback_func = callback_func self.identity = identity self.call_with_device = call_with_device + if not _HOST_CALLBACK_LEGACY.value and call_with_device: + raise NotImplementedError( + "When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs" + " do not support `tap_with_device` and `call_with_device`. 
" + "See https://github.com/google/jax/issues/20385.") def __hash__(self): return hash((self.callback_func, self.identity, self.call_with_device)) @@ -763,7 +827,16 @@ def __eq__(self, other): self.identity == other.identity and self.call_with_device == other.call_with_device) - def __call__(self, arg, device, transforms): + def __call__(self, *args, **kwargs): + if _HOST_CALLBACK_LEGACY.value: + return self._call_legacy(*args, **kwargs) + else: + if self.identity: + # For id_tap, we pass empty transforms, for backwards compatibility + return self.callback_func(args[0], ()) + return self.callback_func(*args, **kwargs) + + def _call_legacy(self, arg, device, transforms): if self.identity: # For id_tap, we pass the transforms, for backwards compatibility if self.call_with_device: @@ -785,14 +858,16 @@ def _call(callback_func: Callable, result_shape=None, call_with_device=False, device_index=0, - identity=False): - # Lazy initialization - _initialize_outfeed_receiver( - max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value) + identity=False, + callback_flavor=CallbackFlavor.IO_CALLBACK): + if _HOST_CALLBACK_LEGACY.value: + # Lazy initialization + _initialize_outfeed_receiver( + max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value) api.check_callable(callback_func) flat_args, arg_treedef = tree_util.tree_flatten(arg) - for arg in flat_args: - dispatch.check_arg(arg) + for arg_ in flat_args: + dispatch.check_arg(arg_) # See definition of outside_call_p for what parameters it takes params: dict[str, Any] = {} # TODO: wrap function @@ -817,8 +892,27 @@ def _call(callback_func: Callable, params["result_treedef"] = result_treedef params["flat_results_aval"] = tuple(flat_results_aval) - flat_results = outside_call_p.bind(*flat_args, **params) - return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results) + + if _HOST_CALLBACK_LEGACY.value: + flat_results = outside_call_p.bind(*flat_args, **params) + return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results) + else: + callback_device = jax.local_devices()[device_index] + sharding = jax.sharding.SingleDeviceSharding(callback_device) + callback_func = _CallbackWrapper(callback_func, identity, + call_with_device) + if callback_flavor is CallbackFlavor.DEBUG: + assert identity + jax.debug.callback(callback_func, arg) + return arg + elif callback_flavor is CallbackFlavor.PURE: + call_res = jax.pure_callback(callback_func, result_shape, arg, + sharding=sharding) + else: + call_res = io_callback(callback_func, result_shape, arg, + sharding=sharding, + ordered=True) + return call_res if not identity else arg # We need the lock for when we use the CustomCall implementation of callbacks. @@ -843,7 +937,6 @@ def _print_tap_func( threshold: the value of numpy.array2string threshold parameter. **kwargs: all other keyword args are printed before printing `arg`. 
""" - def emit_str(s: str): if output_stream is not None: output_stream.write(s + "\n") @@ -1435,7 +1528,7 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: list[core.JaxprEqn], body_jaxpr=_rewrite_closed_jaxpr(body_jaxpr, True, True), cond_jaxpr=_rewrite_closed_jaxpr(cond_jaxpr, True, False)))) elif eqn.primitive is lax.cond_p: - branches, linear = util.split_dict(eqn.params, ["branches", "linear"]) + branches, = util.split_dict(eqn.params, ["branches"]) index, *operands = eqn.invars new_invars = [index, *operands, input_token_var, input_itoken_var] eqns.append( @@ -1445,13 +1538,12 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: list[core.JaxprEqn], eqn.params, branches=tuple( _rewrite_closed_jaxpr(jaxpr, True, True) - for jaxpr in branches), - linear=(*linear, False, False)))) + for jaxpr in branches)))) elif eqn.primitive is lax.scan_p: - num_consts, num_carry, carry_jaxpr, linear, _, _, _ = util.split_dict( + num_consts, num_carry, carry_jaxpr, linear, _, _, _, _ = util.split_dict( eqn.params, ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length", - "unroll"]) + "unroll", "_split_transpose"]) # We add the tokens right at the end of carry nr_const_and_carry = num_consts + num_carry new_invars = eqn.invars[0:nr_const_and_carry] + [ @@ -1549,6 +1641,8 @@ def unreachable_thunk(): eqn.params["out_shardings"] + (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED) ), + in_layouts=(eqn.params["in_layouts"] + (None, None)), + out_layouts=(eqn.params["out_layouts"] + (None, None)), ), ) ) @@ -1799,20 +1893,20 @@ def _initialize_outfeed_receiver( _callback_handler_data.receiver = outfeed_receiver_module.start( _callback_input_received, tuple(clients_with_outfeed), max_callback_queue_size_bytes, - compiler.get_compile_options(1, 1).executable_build_options) # type:ignore + compiler.get_compile_options(1, 1).executable_build_options) def exit_handler(): # Prevent logging usage during compilation, gives errors under pytest - dispatch._on_exit = True # type: ignore[protected-access] + dispatch._on_exit = True if not _callback_handler_data.on_exit: _callback_handler_data.on_exit = True - barrier_wait("at_exit") + _deprecated_barrier_wait("at_exit") atexit.register(exit_handler) # We wait as long as we have callbacks _callback_handler_data.initialized = True -def barrier_wait(logging_name: str | None = None): +def _deprecated_barrier_wait(logging_name: str | None = None): """Blocks the calling thread until all current outfeed is processed. Waits until all callbacks from computations already running on all devices @@ -1832,6 +1926,10 @@ def barrier_wait(logging_name: str | None = None): For more details see the :mod:`jax.experimental.host_callback` module documentation. 
""" + if not _HOST_CALLBACK_LEGACY.value: + jax.effects_barrier() + return + logging_name = logging_name or "" logger.debug("barrier_wait[%s]: start", logging_name) @@ -1858,7 +1956,7 @@ def barrier_tap_received(dev_idx, _): for d_idx, d in enumerate(_callback_handler_data.devices): logger.debug("barrier_wait[%s]: enqueueing barrier on device %s", logging_name, d) x_on_dev = api.device_put(d_idx, device=d) - api.jit(lambda x: id_tap(barrier_tap_received, x), device=d)(x_on_dev) + api.jit(lambda x: _deprecated_id_tap(barrier_tap_received, x), device=d)(x_on_dev) logger.debug("barrier_wait[%s]: waiting for callbacks", logging_name) @@ -1875,9 +1973,14 @@ def barrier_tap_received(dev_idx, _): f"Last one was: {formatted_last_exception}") from last_exception -def stop_outfeed_receiver(): +def _deprecated_stop_outfeed_receiver(): """Stops the outfeed receiver runtime. + .. warning:: + The host_callback APIs are deprecated as of March 20, 2024. + The functionality is subsumed by the + `new JAX external callbacks `_ + This waits for all outfeeds from computations already running on all devices, and then stops the outfeed receiver runtime. The runtime will be restarted next time you use a tap function. @@ -1886,3 +1989,30 @@ def stop_outfeed_receiver(): using lax.outfeed directly after having used host callbacks. """ _callback_handler_data.stop() + +_deprecation_msg = ( + "The host_callback APIs are deprecated as of March 20, 2024. The functionality " + "is subsumed by the new JAX external callbacks. " + "See https://github.com/google/jax/issues/20385.") + +_deprecations = { + # Added March 20, 2024 + "id_tap": (_deprecation_msg, _deprecated_id_tap), + "id_print": (_deprecation_msg, _deprecated_id_print), + "call": (_deprecation_msg, _deprecated_call), + "barrier_wait": (_deprecation_msg, _deprecated_barrier_wait), + "stop_outfeed_receiver": (_deprecation_msg, _deprecated_stop_outfeed_receiver), +} + +import typing +if typing.TYPE_CHECKING: + id_tap = _deprecated_id_tap + id_print = _deprecated_id_print + call = _deprecated_call + barrier_wait = _deprecated_barrier_wait + stop_outfeed_receiver = _deprecated_stop_outfeed_receiver +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing diff --git a/jax/experimental/jax2tf/BUILD b/jax/experimental/jax2tf/BUILD index a4c52bae4e83..d60b4c333a4a 100644 --- a/jax/experimental/jax2tf/BUILD +++ b/jax/experimental/jax2tf/BUILD @@ -48,6 +48,5 @@ py_library( visibility = jax_visibility("jax2tf_internal"), deps = [ "//jax", - "//jax/experimental/export", ] + py_deps("numpy") + py_deps("tensorflow_core") + jax2tf_deps, ) diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 5e4f462a9791..265cd120cb13 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -1004,6 +1004,7 @@ We list here a history of the serialization version numbers: Supported by XlaCallModule since October 27th, 2023, available in JAX since October 20th, 2023 (JAX 0.4.20), and the default since February 1st, 2024 (JAX 0.4.24). + This is the only supported version as of 27th of March, 2024. 
## Known issues diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 077ae796e3e2..adf43b6b94c0 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -25,10 +25,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools -from typing import Any, Callable, Optional +from typing import Any from absl import logging import jax @@ -40,7 +40,6 @@ from jax._src import core from jax._src import effects from jax._src import util -from jax._src import xla_bridge from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect @@ -334,7 +333,7 @@ def _arg_jax_to_tf(arg_jax): if (isinstance(arg_jax, jax.Array) and list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES): - arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False) + arg_dlpack = jax.dlpack.to_dlpack(arg_jax) return tf.experimental.dlpack.from_dlpack(arg_dlpack) # The following avoids copies to the host on CPU, always for Array # and even for ndarray if they are sufficiently aligned. @@ -354,10 +353,7 @@ def _res_tf_to_jax(res_tf: TfVal): if isinstance(res_tf, tf.Tensor) and jax_dtype.type in dlpack.SUPPORTED_DTYPES: res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type res_jax_platform = res_tf_platform.lower() - # Skip using dlpack in PJRT C API runtime, because it currently fails - # with "PJRT C API does not support GetDefaultLayout". - # https://github.com/openxla/xla/blob/762bde36adf22792e91c38fe87cabe5af05bfadc/xla/pjrt/pjrt_c_api_client.h#L285-L289 - if res_jax_platform in _DLPACK_PLATFORMS and not xla_bridge.using_pjrt_c_api(): + if res_jax_platform in _DLPACK_PLATFORMS: res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf) return jax.dlpack.from_dlpack(res_dlpack) @@ -564,14 +560,15 @@ def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray: else: result_shapes = result_shape.tuple_shapes() # type: ignore - result_avals = tuple(map(canonical_res_aval, result_shapes)) # type: ignore + result_avals = tuple(map(canonical_res_aval, result_shapes)) submodule = mlir.xla_computation_to_mlir_module(xla_comp) symtab = ir.SymbolTable(submodule.operation) callee_result_types = symtab["main"].type.results fn = mlir.merge_mlir_modules(ctx.module_context.module, f"call_tf_{function_flat_tf.name}", - submodule) + submodule, + dst_symtab=ctx.module_context.symbol_table) call = func_dialect.CallOp(callee_result_types, ir.FlatSymbolRefAttr.get(fn), tuple(args_op) + captured_ops) diff --git a/jax/experimental/jax2tf/examples/keras_reuse_main.py b/jax/experimental/jax2tf/examples/keras_reuse_main.py index 1f8fbea5b44f..77f882af6850 100644 --- a/jax/experimental/jax2tf/examples/keras_reuse_main.py +++ b/jax/experimental/jax2tf/examples/keras_reuse_main.py @@ -20,9 +20,9 @@ import logging from absl import app from absl import flags -from jax.experimental.jax2tf.examples import mnist_lib # type: ignore -from jax.experimental.jax2tf.examples import saved_model_main # type: ignore -import tensorflow as tf # type: ignore +from jax.experimental.jax2tf.examples import mnist_lib +from jax.experimental.jax2tf.examples import saved_model_main +import tensorflow as tf import tensorflow_datasets as tfds # type: ignore import tensorflow_hub as hub # type: ignore diff --git 
a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py index 562369cdb9df..2934842912f0 100644 --- a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py +++ b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py @@ -16,13 +16,13 @@ from absl import flags from absl.testing import absltest from absl.testing import parameterized +import jax from jax._src import test_util as jtu -from jax import config from jax.experimental.jax2tf.examples import keras_reuse_main from jax.experimental.jax2tf.tests import tf_test_util -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() FLAGS = flags.FLAGS diff --git a/jax/experimental/jax2tf/examples/mnist_lib.py b/jax/experimental/jax2tf/examples/mnist_lib.py index d69f55278dc0..41173c79a5b9 100644 --- a/jax/experimental/jax2tf/examples/mnist_lib.py +++ b/jax/experimental/jax2tf/examples/mnist_lib.py @@ -21,24 +21,24 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools import logging import re import time -from typing import Any, Callable, Optional +from typing import Any from absl import flags -import flax # type: ignore[import] +import flax from flax import linen as nn import jax import jax.numpy as jnp -from matplotlib import pyplot as plt # type: ignore +from matplotlib import pyplot as plt import numpy as np import optax -import tensorflow as tf # type: ignore +import tensorflow as tf import tensorflow_datasets as tfds # type: ignore _MOCK_DATA = flags.DEFINE_boolean("mock_data", False, @@ -128,7 +128,7 @@ def predict(params: Sequence[tuple[Any, Any]], inputs, with_classifier=True): final_w, final_b = params[-1] logits = jnp.dot(x, final_w) + final_b return logits - jax.scipy.special.logsumexp( - logits, axis=1, keepdims=True) # type: ignore[attr-defined] + logits, axis=1, keepdims=True) @staticmethod def loss(params, inputs, labels): diff --git a/jax/experimental/jax2tf/examples/saved_model_lib.py b/jax/experimental/jax2tf/examples/saved_model_lib.py index eb7dbeadd0aa..8f2f0982fd3d 100644 --- a/jax/experimental/jax2tf/examples/saved_model_lib.py +++ b/jax/experimental/jax2tf/examples/saved_model_lib.py @@ -26,11 +26,11 @@ from __future__ import annotations -from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Sequence +from typing import Any -from jax.experimental import jax2tf # type: ignore[import] -import tensorflow as tf # type: ignore[import] +from jax.experimental import jax2tf +import tensorflow as tf def convert_and_save_model( diff --git a/jax/experimental/jax2tf/examples/saved_model_main.py b/jax/experimental/jax2tf/examples/saved_model_main.py index 0dfd9382157d..27fffdf94d41 100644 --- a/jax/experimental/jax2tf/examples/saved_model_main.py +++ b/jax/experimental/jax2tf/examples/saved_model_main.py @@ -30,11 +30,11 @@ from absl import app from absl import flags -from jax.experimental.jax2tf.examples import mnist_lib # type: ignore -from jax.experimental.jax2tf.examples import saved_model_lib # type: ignore +from jax.experimental.jax2tf.examples import mnist_lib +from jax.experimental.jax2tf.examples import saved_model_lib import numpy as np -import tensorflow as tf # type: ignore +import tensorflow as tf import tensorflow_datasets as tfds # type: ignore _MODEL = flags.DEFINE_enum( diff --git a/jax/experimental/jax2tf/examples/serving/model_server_request.py 
b/jax/experimental/jax2tf/examples/serving/model_server_request.py index 4319c1026886..0c64b7b55f8e 100644 --- a/jax/experimental/jax2tf/examples/serving/model_server_request.py +++ b/jax/experimental/jax2tf/examples/serving/model_server_request.py @@ -15,20 +15,20 @@ See README.md for instructions. """ -import grpc # type: ignore[import] +import grpc # type: ignore import json import logging -import requests # type: ignore[import] +import requests from absl import app from absl import flags -from jax.experimental.jax2tf.examples import mnist_lib # type: ignore +from jax.experimental.jax2tf.examples import mnist_lib import numpy as np -import tensorflow as tf # type: ignore[import] -import tensorflow_datasets as tfds # type: ignore[import] -from tensorflow_serving.apis import predict_pb2 # type: ignore[import] +import tensorflow as tf +import tensorflow_datasets as tfds # type: ignore[import-not-found] +from tensorflow_serving.apis import predict_pb2 # type: ignore[import-not-found] from tensorflow_serving.apis import prediction_service_pb2_grpc diff --git a/jax/experimental/jax2tf/examples/tf_js/quickdraw/input_pipeline.py b/jax/experimental/jax2tf/examples/tf_js/quickdraw/input_pipeline.py index 19c564bdac70..1aeb18a08151 100644 --- a/jax/experimental/jax2tf/examples/tf_js/quickdraw/input_pipeline.py +++ b/jax/experimental/jax2tf/examples/tf_js/quickdraw/input_pipeline.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from jax import numpy as jnp # type: ignore +from jax import numpy as jnp -import numpy as np # type: ignore +import numpy as np import os -import requests # type: ignore +import requests def download_dataset(dir_path, nb_classes): diff --git a/jax/experimental/jax2tf/examples/tf_js/quickdraw/quickdraw.py b/jax/experimental/jax2tf/examples/tf_js/quickdraw/quickdraw.py index 08a207e1293b..0bc7592045d7 100644 --- a/jax/experimental/jax2tf/examples/tf_js/quickdraw/quickdraw.py +++ b/jax/experimental/jax2tf/examples/tf_js/quickdraw/quickdraw.py @@ -21,7 +21,6 @@ from flax.training import train_state import jax -from jax import lax from jax import numpy as jnp import optax @@ -30,7 +29,7 @@ import tensorflow as tf import tensorflowjs as tfjs -import input_pipeline # type: ignore[import] +from . import input_pipeline _NUM_EPOCHS = flags.DEFINE_integer( diff --git a/jax/experimental/jax2tf/examples/tf_js/quickdraw/third_party/zaidalyafeai.github.io/LICENSE b/jax/experimental/jax2tf/examples/tf_js/quickdraw/third_party/zaidalyafeai.github.io/LICENSE index 8c1ffd15dcb3..f1c1c6534356 100644 --- a/jax/experimental/jax2tf/examples/tf_js/quickdraw/third_party/zaidalyafeai.github.io/LICENSE +++ b/jax/experimental/jax2tf/examples/tf_js/quickdraw/third_party/zaidalyafeai.github.io/LICENSE @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +SOFTWARE. 
\ No newline at end of file diff --git a/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py b/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py index 6a2494d40fdd..71f2eebee2a4 100644 --- a/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py +++ b/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py @@ -21,8 +21,8 @@ import numpy as np -import tensorflow as tf # type: ignore[import] -import tensorflow_datasets as tfds # type: ignore[import] +import tensorflow as tf +import tensorflow_datasets as tfds # type: ignore[import-not-found] _TFLITE_FILE_PATH = flags.DEFINE_string( 'tflite_file_path', diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index f3afe10f7ec4..5ecde602cdaa 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -16,12 +16,12 @@ from __future__ import annotations import builtins -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses from functools import partial, wraps import math import string -from typing import Any, Callable, Optional +from typing import Any from jax._src import core from jax import lax @@ -32,7 +32,7 @@ from jax.experimental.jax2tf import jax2tf import numpy as np -import tensorflow as tf # type: ignore[import] +import tensorflow as tf # Implementation rules for primitives when XLA is not linked in. These diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 343a9ac9e323..b14883b426ff 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -15,7 +15,7 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from functools import partial import contextlib import math @@ -23,7 +23,7 @@ import os import re import threading -from typing import Any, Callable, Union +from typing import Any, Union import warnings from absl import logging @@ -36,10 +36,7 @@ from jax import numpy as jnp from jax import tree_util from jax import sharding -from jax.experimental import maps -from jax.experimental import export -from jax.experimental.export import _export -from jax.experimental.export import _shape_poly +from jax import export from jax.experimental.jax2tf import impl_no_xla from jax.interpreters import xla @@ -54,12 +51,16 @@ from jax._src import linear_util as lu from jax._src import op_shardings from jax._src import sharding_impls +from jax._src import maps +from jax._src import mesh from jax._src import pjit from jax._src import prng from jax._src import random as random_internal from jax._src import source_info_util from jax._src import util from jax._src import shard_alike +from jax._src.export import _export +from jax._src.export import shape_poly from jax._src.interpreters import ad from jax._src.interpreters import mlir from jax._src.lax import control_flow as lax_control_flow @@ -70,24 +71,24 @@ from jax._src.lib import xla_client from jax._src.numpy.ufuncs import logaddexp -import tensorflow as tf # type: ignore[import] +import tensorflow as tf # These don't have public equivalents. 
# pylint: disable=g-direct-tensorflow-import -from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import] -from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import] -from tensorflow.core.framework import attr_value_pb2 # type: ignore[import] +from tensorflow.compiler.tf2xla.python import xla as tfxla +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.core.framework import attr_value_pb2 try: - from tensorflow.python.compiler.xla.experimental import xla_sharding # type: ignore[import] + from tensorflow.python.compiler.xla.experimental import xla_sharding except ModuleNotFoundError: # This can be removed when TF 2.10 support is no longer needed. - from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding # type: ignore[import] -from tensorflow.python.framework import ops as tf_ops # type: ignore[import] -from tensorflow.python.eager import context as tf_context # type: ignore[import] + from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.eager import context as tf_context # pylint: enable=g-direct-tensorflow-import NameStack = source_info_util.NameStack -PolyShape = _shape_poly.PolyShape # TODO: deprecate +PolyShape = shape_poly.PolyShape # TODO: deprecate DType = Any DisabledSafetyCheck = export.DisabledSafetyCheck @@ -387,13 +388,13 @@ def jax_arg_spec_from_tf(a: TfVal) -> jax.ShapeDtypeStruct: args_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, args_tf) args_specs = export.symbolic_args_specs( - args_jax_specs, polymorphic_shapes=polymorphic_shapes, - symbolic_constraints=polymorphic_constraints) + args_jax_specs, polymorphic_shapes, + constraints=polymorphic_constraints) # The polymorphic_shapes argument refers to positional arguments only. # We assume None for the kwargs. 
kwargs_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, kwargs_tf) kwargs_specs = export.symbolic_args_specs( - kwargs_jax_specs, polymorphic_shapes=None) + kwargs_jax_specs, None) combined_args_tf = (args_tf, kwargs_tf) args_flat_tf: Sequence[TfVal] args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf) @@ -513,11 +514,15 @@ def _restore_context(): _thread_local_state.call_tf_concrete_function_list = _prev_func_list self._restore_context = _restore_context - self.exported = export.export( + _exported_device_assignment = [None] + self.exported = _export.export_back_compat( self.fun_jax, lowering_platforms=self.native_serialization_platforms, - disabled_checks=self.native_serialization_disabled_checks + disabled_checks=self.native_serialization_disabled_checks, + _device_assignment_for_internal_jax2tf_use_only=_exported_device_assignment, )(*self.args_specs, **self.kwargs_specs) + assert(_exported_device_assignment[0] is not None) + self.device_assignment = _exported_device_assignment[0] def after_conversion(self): self._restore_context() @@ -533,10 +538,10 @@ def get_vjp_fun(self) -> tuple[Callable, return _export._get_vjp_fun(self.fun_jax, in_tree=self.exported.in_tree, in_avals=self.exported.in_avals, - in_shardings=self.exported.in_shardings, + in_shardings_hlo=self.exported.in_shardings_hlo, out_avals=self.exported.out_avals, - out_shardings=self.exported.out_shardings, - nr_devices=self.exported.nr_devices, + out_shardings_hlo=self.exported.out_shardings_hlo, + device_assignment=self.device_assignment, apply_jit=True) class GraphSerializationImpl(SerializationImpl): @@ -574,9 +579,9 @@ def _restore_context(): (self.args_specs, self.kwargs_specs)) self.args_avals_flat = tuple( map(lambda a: core.raise_to_shaped(core.get_aval(a)), args_specs_flat)) - dim_vars = _shape_poly.all_dim_vars(self.args_avals_flat) + dim_vars = shape_poly.all_dim_vars(self.args_avals_flat) dim_values, _ = _interpret_fun_jax( - partial(_shape_poly.compute_dim_vars_from_arg_shapes, + partial(shape_poly.compute_dim_vars_from_arg_shapes, self.args_avals_flat, args_kwargs_tree=self.in_tree), self.args_flat_tf, self.args_avals_flat, self.name_stack) @@ -605,10 +610,10 @@ def get_vjp_fun(self) -> tuple[Callable, return _export._get_vjp_fun(self.fun_jax, in_tree=self.in_tree, in_avals=self.args_avals_flat, - in_shardings=(None,) * len(self.args_avals_flat), + in_shardings_hlo=(None,) * len(self.args_avals_flat), out_avals=self.outs_avals, - out_shardings=(None,) * len(self.outs_avals), - nr_devices=1, # Does not matter for unspecified shardings + out_shardings_hlo=(None,) * len(self.outs_avals), + device_assignment=None, # Not used when apply_jit = False apply_jit=False) @@ -672,7 +677,7 @@ def eval_polymorphic_shape(fun_jax: Callable, """ def do_eval_polymorphic_shape(*args_specs) -> Any: args_poly_specs = export.symbolic_args_specs( - args_specs, polymorphic_shapes=polymorphic_shapes) + args_specs, polymorphic_shapes) res_poly_spec = jax.eval_shape(fun_jax, *args_poly_specs) # TODO(necula): For now we export the polymorphic shapes using `str`. 
res_polymorphic_shape = tree_util.tree_map(lambda r: str(r.shape), res_poly_spec) @@ -777,7 +782,7 @@ def fix_out_ct(out_ct_tf, out_ct_aval: core.ShapedArray, out_tf: TfVal): vjp_polymorphic_shapes = tuple( str(a.shape) # Note: may be _DimExpr, not just DimVar - for a in vjp_in_avals) # type: ignore + for a in vjp_in_avals) in_cts_flat = convert( fun_vjp_jax, with_gradient=with_gradient, @@ -807,7 +812,7 @@ def _interpret_fun_jax( extra_name_stack: str | None, fresh_constant_cache: bool = False, ) -> tuple[tuple[TfVal, ...], tuple[core.ShapedArray, ...]]: - with core.new_base_main(TensorFlowTrace) as main: # type: ignore + with core.new_base_main(TensorFlowTrace) as main: subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, args_avals) with _extended_name_stack(extra_name_stack): with core.new_sublevel(): @@ -852,7 +857,7 @@ def _convert_value(val, aval): kept_args_avals = [aval for i, aval in enumerate(exported.in_avals) if i in exported.module_kept_var_idx] kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx] - version = exported.mlir_module_serialization_version + version = exported.calling_convention_version try: get_max_supported_version = tfxla.call_module_maximum_supported_version @@ -885,7 +890,7 @@ def _convert_value(val, aval): has_token_input_output=False ) - call_module_attrs["platforms"] = tuple(p.upper() for p in exported.lowering_platforms) + call_module_attrs["platforms"] = tuple(p.upper() for p in exported.platforms) if version >= 6: call_module_attrs["disabled_checks"] = tuple( str(dc) @@ -911,7 +916,7 @@ def _convert_value(val, aval): # See b/255511660. kept_in_shardings = [] for i in exported.module_kept_var_idx: - kept_in_shardings.append(exported.in_shardings[i]) + kept_in_shardings.append(exported.in_shardings_hlo[i]) args_flat_tf = tuple( map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()), @@ -928,7 +933,7 @@ def _convert_value(val, aval): res = list(map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()), - res, exported.out_shardings)) + res, exported.out_shardings_hlo)) res = tuple(map(_convert_value, res, exported.out_avals)) return res @@ -1008,7 +1013,7 @@ def _interpret_subtrace(main: core.MainTrace, for val, aval in zip(in_vals, in_avals)) outs = yield in_tracers, {} # type: Sequence[TfVal] out_tracers: Iterable[TensorFlowTracer] = ( - map(trace.full_raise, outs)) # type: ignore + map(trace.full_raise, outs)) out_vals_with_avals: Sequence[tuple[TfVal, core.ShapedArray]] = ( tuple((t.val, t.aval) for t in out_tracers)) yield out_vals_with_avals @@ -1055,7 +1060,7 @@ def _aval_to_tf_shape(aval: core.ShapedArray) -> tuple[int | None, ...]: """Generate a TF shape, possibly containing None for polymorphic dimensions.""" aval = _jax_physical_aval(aval) return tuple(map(lambda d: None if export.is_symbolic_dim(d) else d, - aval.shape)) # type: ignore[attr-defined] + aval.shape)) # In the TF world, we represent float0 as zeros of this type. # We pick bool because this is what JAX uses when it lowers float0 to HLO. 
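The symbolic-shape helpers used above (export.symbolic_args_specs, export.is_symbolic_dim, Exported.calling_convention_version) come from the public jax.export package that this patch migrates jax2tf to. A minimal usage sketch under that assumption; the function, the "b, 3" shape spec, and the printed values are illustrative:

    import jax
    import jax.numpy as jnp
    from jax import export

    def f(x):
        return jnp.sin(x).sum(axis=-1)

    # A concrete spec plus a symbolic batch dimension "b" (illustrative).
    concrete = jax.ShapeDtypeStruct((8, 3), jnp.float32)
    spec = export.symbolic_args_specs(concrete, "b, 3")
    exp = export.export(jax.jit(f))(spec)

    print(exp.in_avals)                                      # (ShapedArray(float32[b,3]),)
    print(export.is_symbolic_dim(exp.in_avals[0].shape[0]))  # True
    restored = export.deserialize(exp.serialize())           # round-trip through bytes
    print(restored.call(jnp.ones((5, 3), jnp.float32)).shape)  # (5,)
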
@@ -1143,8 +1148,8 @@ def _tfval_to_tensor_jax_dtype(val: TfVal, return tf_val, jax_dtype -def _eval_shape(shape: Sequence[_shape_poly.DimSize], dtype=None) -> Sequence[TfVal]: - # Returns a tuple of _shape_poly.dim_as_value_dtype +def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]: + # Returns a tuple of shape_poly.dim_as_value_dtype # Used only for non-native lowering assert all(map(lambda x: x is not None, shape)), ( f"Argument shape should be a valid JAX shape but got {shape}") @@ -1169,7 +1174,7 @@ def _ensure_tf_shape_if_dynamic(x: TfVal, shape): return tf.ensure_shape(x, shape) -def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[_shape_poly.DimSize]): +def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize]): """Asserts that shape matches x.shape in the known dimensions and has dimension polynomials elsewhere.""" # Ensures that the shape does not contain None; it should contain symbolic expressions. @@ -1225,22 +1230,22 @@ def __init__(self, trace: TensorFlowTrace, val: TfVal, else: assert phys_aval.dtype == _to_jax_dtype(val.dtype), f"expected {phys_aval.dtype} == {val.dtype}" - for aval_dim, val_dim in zip(phys_aval.shape, val_shape): # type: ignore[attr-defined] + for aval_dim, val_dim in zip(phys_aval.shape, val_shape): if val_dim is None: - assert export.is_symbolic_dim(aval_dim), f"expected {phys_aval.shape} == {val_shape}" # type: ignore[attr-defined] + assert export.is_symbolic_dim(aval_dim), f"expected {phys_aval.shape} == {val_shape}" elif not export.is_symbolic_dim(aval_dim): - assert aval_dim == val_dim, f"expected {phys_aval.shape} == {val_shape}" # type: ignore[attr-defined] + assert aval_dim == val_dim, f"expected {phys_aval.shape} == {val_shape}" else: # We have a TF value with known shape, and the abstract shape is a shape variable. try: aval_int = int(_eval_shape([aval_dim])) # type: ignore except (TypeError, KeyError): continue - assert aval_int == val_dim, f"expected {phys_aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." # type: ignore + assert aval_int == val_dim, f"expected {phys_aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." 
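+      # Canonicalize the TF value to the JAX physical dtype; memoize_constants=True
+      # lets repeated constant values reuse the current constant cache.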
self.val = _tfval_to_tensor_jax_dtype(val, phys_aval.dtype, - memoize_constants=True)[0] # type: ignore[attr-defined] + memoize_constants=True)[0] @property def aval(self): @@ -1325,7 +1330,7 @@ def invoke_impl() -> TfVal: if impl_needs_avals: return impl( *args_tf, - _in_avals=args_avals, # type: ignore + _in_avals=args_avals, _out_aval=out_aval, **params) else: @@ -1367,7 +1372,7 @@ def invoke_impl() -> TfVal: out = [ TensorFlowTracer(self, v, a) for v, a in zip(val_out, out_aval) - ] # type: ignore + ] else: out = TensorFlowTracer(self, val_out, out_aval) # type: ignore @@ -1375,13 +1380,13 @@ def invoke_impl() -> TfVal: # TODO: adapt this to match polymorphic shapes if config.enable_checks.value: if primitive.multiple_results: - for o, expected_aval in zip(out, out_aval): # type: ignore + for o, expected_aval in zip(out, out_aval): assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), ( f"{primitive}: out.aval = {o.aval}; expected {expected_aval}") else: assert out.aval == out_aval, ( # type: ignore f"{primitive}: out.aval = {out.aval}; expected {out_aval}" - ) # type: ignore + ) return out # type: ignore def process_call(self, call_primitive: core.Primitive, fun: lu.WrappedFun, @@ -1464,11 +1469,13 @@ def get_primitive_impl(self, p: core.Primitive) -> tuple[Callable, bool]: def _unexpected_primitive(p: core.Primitive, *args, **kwargs): assert False, f"Encountered unexpected primitive {p}" - -# Call primitives are inlined for unexpected in [core.call_p, maps.xmap_p]: tf_impl[unexpected] = partial(_unexpected_primitive, unexpected) +tf_impl[lax_control_flow.loops.eval_jaxpr_p] = \ + lambda *args, jaxpr: _interpret_jaxpr( + jaxpr, *args, fresh_constant_cache=False, extra_name_stack=None) + # Primitives that are not yet implemented must be explicitly declared here. tf_not_yet_impl = [ "clz", @@ -1526,9 +1533,11 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "platform_index", "assert_consumed_value", "consume", + "ragged_dot", + "cholesky_update", ] -tf_impl[prng.reuse_key_p] = lambda x: x +tf_impl[random_internal.random_clone_p] = lambda x: x tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient @@ -1538,7 +1547,7 @@ def _add(x: TfVal, y: TfVal) -> TfVal: tf_impl[ad_util.add_jaxvals_p] = _add -tf_impl[dispatch.device_put_p] = lambda x, device=None, src=None: x +tf_impl[dispatch.device_put_p] = lambda *xs, devices=None, srcs=None: xs tf_impl[lax_internal.copy_p] = lambda x: x def _shard_alike(*args: TfVal, **_): @@ -2620,16 +2629,21 @@ def reducer_computation(*args: TfVal) -> TfVal: def _cumred(lax_reduce_fn: Callable, lax_reduce_window_fn: Callable, extra_name_stack: str): - if config.jax2tf_associative_scan_reductions.value: - return _convert_jax_impl(partial(lax_control_flow.associative_scan, - lax_reduce_fn), - multiple_results=False, - extra_name_stack=extra_name_stack) - else: - return _convert_jax_impl(partial(lax_control_flow.cumred_reduce_window_impl, - lax_reduce_window_fn), - multiple_results=False, - extra_name_stack=extra_name_stack) + associative_scan = partial(lax_control_flow.associative_scan, lax_reduce_fn) + reduce_window = partial( + lax_control_flow.cumred_reduce_window_impl, lax_reduce_window_fn + ) + + def _call_impl(*args, **kwargs): + # Vary which implementation to use when cumulation is called. This cannot be + # done during import time because the caller may later use a python context + # to switch the implementation to use. 
+ associative = config.jax2tf_associative_scan_reductions.value + return (associative_scan if associative else reduce_window)(*args, **kwargs) + + return _convert_jax_impl( + _call_impl, multiple_results=False, extra_name_stack=extra_name_stack + ) tf_impl_with_avals[lax.cummax_p] = _cumred( @@ -3005,9 +3019,9 @@ def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal: tf_impl_with_avals[lax.scatter_add_p] = _scatter -def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr], - linear: Sequence[bool]) -> Sequence[TfVal]: - del linear +def _cond( + index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr] +) -> Sequence[TfVal]: # tf.cond needs lambdas with no arguments. branches_tf = [ partial(_interpret_jaxpr, jaxpr, *operands, @@ -3017,7 +3031,7 @@ def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr], ] # Same name stack as XLA translation of cond_p # Note: extend_name_stack is a contextmanager, which is callable as a decorator. - branches_tf = list(map(source_info_util.extend_name_stack("cond"), # type: ignore[arg-type] + branches_tf = list(map(source_info_util.extend_name_stack("cond"), branches_tf)) if len(branches) == 2: # `index` comes with tf.int32 type of casted boolean parameter. @@ -3033,7 +3047,7 @@ def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr, body_nconsts: int, body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]: cond_consts, body_consts, init_carry = util.split_list( args, [cond_nconsts, body_nconsts]) - if cond_jaxpr.out_avals[0].shape: # type: ignore[attr-defined] + if cond_jaxpr.out_avals[0].shape: # The conditional is not a scalar, this must be a batched while return _batched_cond_while( *args, @@ -3118,7 +3132,7 @@ def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal: extra_name_stack="scan") tf_impl_with_avals[ad_checkpoint.remat_p] = \ - _convert_jax_impl(partial(ad_checkpoint.remat_lowering, + _convert_jax_impl(partial(ad_checkpoint.remat_expansion, # TODO: jax2tf cannot discriminate by platform is_gpu_platform=False), multiple_results=True, @@ -3449,11 +3463,11 @@ def split_to_logical_devices(tensor: TfVal, def _xla_compatible_sharding_to_hlo_sharding( - s: sharding.XLACompatibleSharding, + s: sharding.Sharding, aval: core.ShapedArray) -> xla_client.HloSharding | None: if sharding_impls.is_unspecified(s): return None - return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] + return s._to_xla_hlo_sharding(aval.ndim) def _shard_value(val: TfVal, sd: xla_client.HloSharding | None, *, @@ -3501,9 +3515,10 @@ def _shard_value(val: TfVal, def _pjit(*args: TfVal, jaxpr: core.ClosedJaxpr, - in_shardings: Sequence[sharding.XLACompatibleSharding], - out_shardings: Sequence[sharding.XLACompatibleSharding], - resource_env: maps.ResourceEnv, + in_shardings: Sequence[sharding.Sharding], + out_shardings: Sequence[sharding.Sharding], + in_layouts, out_layouts, + resource_env: mesh.ResourceEnv, donated_invars, name: str, keep_unused: bool, @@ -3534,8 +3549,8 @@ def _pjit(*args: TfVal, def _pjit_sharding_constraint(arg: TfVal, *, - sharding: sharding.XLACompatibleSharding, - resource_env: maps.ResourceEnv, + sharding: sharding.Sharding, + resource_env: mesh.ResourceEnv, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray, **kwargs) -> TfVal: @@ -3554,13 +3569,13 @@ def _dimension_size_jax2tf(op: TfVal, *, dimension, _in_avals, _out_aval): else: return dim_tf -tf_impl_with_avals[_shape_poly.dimension_size_p] = _dimension_size_jax2tf 
+tf_impl_with_avals[shape_poly.dimension_size_p] = _dimension_size_jax2tf -def _dim_as_value_jax2tf(dim: _shape_poly.DimSize): +def _dim_as_value_jax2tf(dim: shape_poly.DimSize): dim_tf, = _eval_shape((dim,)) return dim_tf -tf_impl[_shape_poly.dim_as_value_p] = _dim_as_value_jax2tf +tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf def _shape_assertion_jax2tf(assert_what, *error_message_inputs, error_message: str): @@ -3570,7 +3585,7 @@ def _shape_assertion_jax2tf(assert_what, *error_message_inputs, message=error_message.format(*error_message_inputs)) return [] -tf_impl[_shape_poly.shape_assertion_p] = _shape_assertion_jax2tf +tf_impl[shape_poly.shape_assertion_p] = _shape_assertion_jax2tf def _reduce_precision(x, *, exponent_bits, mantissa_bits): return tfxla.reduce_precision(x, exponent_bits=exponent_bits, diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py index 47c5c8360cf5..7f903b70d987 100644 --- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py @@ -20,14 +20,13 @@ from __future__ import annotations import base64 -from collections.abc import Sequence +from collections.abc import Callable, Sequence import io import os import tarfile -from typing import Callable, Optional from absl.testing import absltest -from jax import config +import jax from jax._src import test_util as jtu from jax._src.internal_test_util import export_back_compat_test_util as bctu from jax._src.lib import xla_extension @@ -37,7 +36,7 @@ import tensorflow as tf -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def serialize_directory(directory_path): @@ -78,7 +77,7 @@ def tf_func(the_input): # Use recognizable names for input and result return tf.identity(res, name="the_result") self.tf_func = tf_func - return tf_func(*data.inputs) # type: ignore + return tf_func(*data.inputs) def serialize( self, diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 57eea5f6a35d..5740b76038d8 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -13,36 +13,35 @@ # limitations under the License. 
"""Tests for call_tf.""" +from collections.abc import Callable import contextlib from functools import partial import os -from typing import Callable import unittest from absl import logging from absl.testing import absltest from absl.testing import parameterized import jax -from jax import config from jax import dlpack from jax import dtypes +from jax import export from jax import lax from jax import numpy as jnp +from jax._src import config from jax._src import test_util as jtu -from jax._src import xla_bridge from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax.experimental import export from jax.experimental import jax2tf from jax.experimental.jax2tf.tests import tf_test_util import numpy as np try: - import tensorflow as tf # type: ignore[import] + import tensorflow as tf except ImportError: tf = None -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _maybe_jit(with_jit: bool, func: Callable) -> Callable: @@ -696,7 +695,6 @@ def cos_tf_sin_jax(x): jax.grad(cos_tf_sin_jax)(x) logging.info(jax.make_jaxpr(cos_tf_sin_jax)(x)) - logging.info(jax.xla_computation(cos_tf_sin_jax)(x).as_hlo_text()) def test_tf_gather(self): """tf_gather gradient output is tf.IndexSlices.""" @@ -779,7 +777,7 @@ def f_jax(x): lowering_platforms = ("tpu", "cpu", "cuda") - exp = export.export(f_jax, + exp = export.export(jax.jit(f_jax), lowering_platforms=lowering_platforms)(x) for jax_platform in jax_and_tf_platforms: with self.subTest(jax_platform): @@ -788,7 +786,7 @@ def f_jax(x): logging.info("Running harness natively on %s", jax_device) native_res = f_jax(x_device) logging.info("Running exported harness on %s", jax_device) - exported_res = export.call_exported(exp)(x_device) + exported_res = exp.call(x_device) self.assertAllClose(native_res, exported_res) def test_multi_platform_call_tf_graph(self): @@ -854,7 +852,7 @@ def _transfer_guard(guard_level): with contextlib.ExitStack() as stack: stack.enter_context(jax.transfer_guard_device_to_device(guard_level)) stack.enter_context(jax.transfer_guard_device_to_host(guard_level)) - if not (type_ == jnp.int32 or xla_bridge.using_pjrt_c_api()): + if type_ != jnp.int32: stack.enter_context(jax.transfer_guard_host_to_device(guard_level)) yield @@ -1150,17 +1148,6 @@ def setUp(self): _ = tf.add(1, 1) super().setUp() - def override_serialization_version(self, version_override: int): - version = config.jax_serialization_version - if version != version_override: - self.addCleanup(partial(config.update, - "jax_serialization_version", - version_override)) - config.update("jax_serialization_version", version_override) - logging.info( - "Using JAX serialization version %s", - config.jax_serialization_version) - def test_alternate(self): # Alternate sin/cos with sin in TF and cos in JAX f_tf_inner = tf.math.sin @@ -1275,7 +1262,7 @@ def fun_tf(x): # x:i32[3] @_parameterized_jit def test_shape_poly_static_output_shape(self, with_jit=True): - if config.jax2tf_default_native_serialization: + if jax.config.jax2tf_default_native_serialization: raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.") x = np.array([0.7, 0.8], dtype=np.float32) @@ -1289,7 +1276,7 @@ def fun_tf(x): @_parameterized_jit def test_shape_poly(self, with_jit=False): - if config.jax2tf_default_native_serialization: + if jax.config.jax2tf_default_native_serialization: raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.") x = np.array([7, 8, 9, 10], 
dtype=np.float32) def fun_jax(x): @@ -1308,7 +1295,7 @@ def fun_jax(x): @_parameterized_jit def test_shape_poly_pytree_result(self, with_jit=True): - if config.jax2tf_default_native_serialization: + if jax.config.jax2tf_default_native_serialization: raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.") x = np.array([7, 8, 9, 10], dtype=np.float32) def fun_jax(x): @@ -1394,7 +1381,7 @@ def fun_jax(x): if kind == "bad_dim" and with_jit: # TODO: in jit more the error pops up later, at AddV2 expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2" - if kind == "bad_dim" and config.jax2tf_default_native_serialization: + if kind == "bad_dim" and jax.config.jax2tf_default_native_serialization: # TODO(b/268386622): call_tf with shape polymorphism and native serialization. expect_error = "Error compiling TensorFlow function" fun_tf_rt = _maybe_tf_jit(with_jit, @@ -1432,7 +1419,7 @@ def test_several_round_trips(self, f4_function=False, f4_saved_model=False): if (f2_saved_model and f4_saved_model and - not config.jax2tf_default_native_serialization): + not jax.config.jax2tf_default_native_serialization): # TODO: Getting error Found invalid capture Tensor("jax2tf_vjp/jax2tf_arg_0:0", shape=(), dtype=float32) when saving custom gradients # when saving f4, but only with non-native serialization. raise unittest.SkipTest("TODO: error invalid capture when saving custom gradients") @@ -1661,116 +1648,127 @@ def tf_f_2(): _, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt_2, input_args=[]) @jtu.parameterized_filterable( - kwargs=[dict(version=version) for version in [8, 9]] + kwargs=[dict(version=version) for version in [9]] ) def test_call_tf_graph_ordered(self, *, version: int): - self.override_serialization_version(version) - @tf.function - def tf_print(x): - tf.print(x) - - call_tf_print = jax2tf.call_tf( - tf_print, - call_tf_graph=True, - ordered=True, - ) - - x = jnp.array(1.0, dtype=jnp.float32) + with config.jax_export_calling_convention_version(version): + logging.info( + "Using JAX serialization version %s", + jax.config.jax_export_calling_convention_version) - def body(i, x): - call_tf_print(x) - return x + 1 + @tf.function + def tf_print(x): + tf.print(x) - @jax.jit - def f_jax(x): - return jax.lax.fori_loop(0, 4, body, x) + call_tf_print = jax2tf.call_tf( + tf_print, + call_tf_graph=True, + ordered=True, + ) - num_custom_calls = 0 + x = jnp.array(1.0, dtype=jnp.float32) - def _check_mlir_ops(op): - nonlocal num_custom_calls + def body(i, x): + call_tf_print(x) + return x + 1 - if ( - op.operation.name == "stablehlo.custom_call" - and ir.StringAttr(op.attributes["call_target_name"]).value - == "tf.call_tf_function" + @jax.jit + def f_jax(x): + return jax.lax.fori_loop(0, 4, body, x) + + num_custom_calls = 0 + + def _check_mlir_ops(op): + nonlocal num_custom_calls + + if ( + op.operation.name == "stablehlo.custom_call" + and ir.StringAttr(op.attributes["call_target_name"]).value + == "tf.call_tf_function" + ): + num_custom_calls += 1 + + # The custom call op must have `has_token_input_output` attribute. + tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"]) + self.assertTrue( + ir.BoolAttr(tf_backend_config["has_token_input_output"]).value + ) + + # Verify that the first argument/result of the custom call op is a token + # type. This is a calling convention defined by `has_token_input_output`. 
+ self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type)) + self.assertTrue(hlo.TokenType.isinstance(op.results[0].type)) + + stablehlo_module = None + with self.assertRaisesRegex( + ValueError, + "call_tf_graph=True only support exporting by jax2tf.convert currently", ): - num_custom_calls += 1 - - # The custom call op must have `has_token_input_output` attribute. - tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"]) - self.assertTrue( - ir.BoolAttr(tf_backend_config["has_token_input_output"]).value - ) - - # Verify that the first argument/result of the custom call op is a token - # type. This is a calling convention defined by `has_token_input_output`. - self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type)) - self.assertTrue(hlo.TokenType.isinstance(op.results[0].type)) - - stablehlo_module = None - with self.assertRaisesRegex( - ValueError, - "call_tf_graph=True only support exporting by jax2tf.convert currently", - ): - lower = f_jax.lower(x) - self.assertNotEmpty(lower._lowering.compile_args["ordered_effects"]) - stablehlo_module = lower.compiler_ir("stablehlo") - if stablehlo_module: - self._walk_stablehlo_operations(stablehlo_module, _check_mlir_ops) - self.assertEqual(num_custom_calls, 1) - - f_tf = jax2tf.convert( - f_jax, - native_serialization=True, - with_gradient=False, - ) - _, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x]) + lower = f_jax.lower(x) + self.assertNotEmpty(lower._lowering.compile_args["ordered_effects"]) + stablehlo_module = lower.compiler_ir("stablehlo") + if stablehlo_module: + self._walk_stablehlo_operations(stablehlo_module, _check_mlir_ops) + self.assertEqual(num_custom_calls, 1) + + f_tf = jax2tf.convert( + f_jax, + native_serialization=True, + with_gradient=False, + ) + _, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x]) @jtu.parameterized_filterable( kwargs=[dict(poly=poly, version=version) for poly in [True, False] - for version in [8, 9]] + for version in [9]] ) def test_call_tf_ordered_dead_inputs(self, *, poly: bool, version: int): - self.override_serialization_version(version) - def f_jax(x1, x_dead, x3): - return (x1, jax2tf.call_tf(lambda x: tf.math.sin(x), ordered=True, - call_tf_graph=True)(x3)) - if poly: - polymorphic_shapes = ["b", None, None] - else: - polymorphic_shapes = None - f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes) - x1 = np.arange(3, dtype=np.float32) - x_dead = np.arange(4, dtype=np.float32) - x3 = np.arange(5, dtype=np.float32) - self.assertAllClose(f_jax(x1, x_dead, x3), - f_tf(x1, x_dead, x3)) + with config.jax_export_calling_convention_version(version): + logging.info( + "Using JAX serialization version %s", + jax.config.jax_export_calling_convention_version) + def f_jax(x1, x_dead, x3): + return (x1, jax2tf.call_tf(lambda x: tf.math.sin(x), ordered=True, + call_tf_graph=True)(x3)) + if poly: + polymorphic_shapes = ["b", None, None] + else: + polymorphic_shapes = None + f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes) + x1 = np.arange(3, dtype=np.float32) + x_dead = np.arange(4, dtype=np.float32) + x3 = np.arange(5, dtype=np.float32) + self.assertAllClose(f_jax(x1, x_dead, x3), + f_tf(x1, x_dead, x3)) @jtu.parameterized_filterable( kwargs=[dict(ordered=ordered, version=version) for ordered in [True, False] - for version in [8, 9] + for version in [9] ] ) def test_call_tf_graph_polymorphic(self, ordered: bool, version: int): - self.override_serialization_version(version) - @tf.function(jit_compile=True, 
autograph=False) - @partial(jax2tf.convert, - with_gradient=False, - native_serialization=True, - polymorphic_shapes=["(b)"]) - @jax.jit - def tf_f_2(x): - tf_f = lambda x: print(tf.strings.length(tf.constant("hello, world"))) - jax2tf.call_tf(tf_f, - call_tf_graph=True, - ordered=ordered)(x) - return x + with config.jax_export_calling_convention_version(version): + logging.info( + "Using JAX serialization version %s", + jax.config.jax_export_calling_convention_version) + + @tf.function(jit_compile=True, autograph=False) + @partial(jax2tf.convert, + with_gradient=False, + native_serialization=True, + polymorphic_shapes=["(b)"]) + @jax.jit + def tf_f_2(x): + tf_f = lambda x: print(tf.strings.length(tf.constant("hello, world"))) + jax2tf.call_tf(tf_f, + call_tf_graph=True, + ordered=ordered)(x) + return x - x = np.arange(3, dtype=np.int32) - _ = tf.function(tf_f_2, autograph=False).get_concrete_function(x) + x = np.arange(3, dtype=np.int32) + _ = tf.function(tf_f_2, autograph=False).get_concrete_function(x) # TODO(b/293927250): call_tf_graph=True only accept concrete_function. The # workaround here is to set `module.call=concrete_fn.`. diff --git a/jax/experimental/jax2tf/tests/control_flow_ops_test.py b/jax/experimental/jax2tf/tests/control_flow_ops_test.py index 253a5ffc68a7..c66a6d696e89 100644 --- a/jax/experimental/jax2tf/tests/control_flow_ops_test.py +++ b/jax/experimental/jax2tf/tests/control_flow_ops_test.py @@ -23,8 +23,7 @@ from jax.experimental.jax2tf.tests import tf_test_util -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase): diff --git a/jax/experimental/jax2tf/tests/converters.py b/jax/experimental/jax2tf/tests/converters.py index f0a293ca52d5..1ed017fc0819 100644 --- a/jax/experimental/jax2tf/tests/converters.py +++ b/jax/experimental/jax2tf/tests/converters.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Converters for jax2tf.""" + +from collections.abc import Callable import dataclasses import functools import tempfile -from typing import Any, Callable +from typing import Any + from jax.experimental import jax2tf import tensorflow as tf import tensorflowjs as tfjs diff --git a/jax/experimental/jax2tf/tests/cross_compilation_check.py b/jax/experimental/jax2tf/tests/cross_compilation_check.py index 63e8928ee371..cc34d78e88d4 100644 --- a/jax/experimental/jax2tf/tests/cross_compilation_check.py +++ b/jax/experimental/jax2tf/tests/cross_compilation_check.py @@ -26,12 +26,11 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import contextlib import dataclasses import os import re -from typing import Callable, Optional import zlib from absl import app @@ -39,12 +38,11 @@ import numpy.random as npr -import jax -from jax import config # Must import before TF +import jax # Must import before TF from jax.experimental import jax2tf # Defines needed flags from jax._src import test_util # Defines needed flags -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # Import after parsing flags from jax.experimental.jax2tf.tests import primitive_harness diff --git a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py index 9bb7466125c5..5b1169224ed9 100644 --- a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py +++ b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py @@ -18,8 +18,9 @@ from __future__ import annotations +from collections.abc import Callable import functools -from typing import Any, Callable, Optional +from typing import Any from flax import linen as nn import jax diff --git a/jax/experimental/jax2tf/tests/flax_models/gnn.py b/jax/experimental/jax2tf/tests/flax_models/gnn.py index a94153d9115e..4a74be446ba1 100644 --- a/jax/experimental/jax2tf/tests/flax_models/gnn.py +++ b/jax/experimental/jax2tf/tests/flax_models/gnn.py @@ -16,8 +16,7 @@ https://github.com/google/flax/tree/main/examples/ogbg_molpcba """ -from collections.abc import Sequence -from typing import Callable +from collections.abc import Callable, Sequence from flax import linen as nn @@ -152,7 +151,7 @@ def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: axis=0, total_repeat_length=sum_n_node) # We use the aggregation function to pool the nodes per graph. - pooled = self.pooling_fn(graphs.nodes, node_graph_indices, n_graph) # type: ignore[call-arg] + pooled = self.pooling_fn(graphs.nodes, node_graph_indices, n_graph) return graphs._replace(globals=pooled) @nn.compact diff --git a/jax/experimental/jax2tf/tests/flax_models/resnet.py b/jax/experimental/jax2tf/tests/flax_models/resnet.py index bb6e519deceb..48829127b304 100644 --- a/jax/experimental/jax2tf/tests/flax_models/resnet.py +++ b/jax/experimental/jax2tf/tests/flax_models/resnet.py @@ -19,9 +19,9 @@ # See issue #620. 
# pytype: disable=wrong-arg-count -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial -from typing import Any, Callable +from typing import Any from flax import linen as nn import jax.numpy as jnp diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py b/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py index 334248219962..27535c784e89 100644 --- a/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py +++ b/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py @@ -24,7 +24,8 @@ from __future__ import annotations -from typing import Callable, Any, Optional +from collections.abc import Callable +from typing import Any from flax import linen as nn from flax import struct diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py b/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py index 04111e6c4d5b..cc78b5a41496 100644 --- a/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py +++ b/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py @@ -18,7 +18,8 @@ from __future__ import annotations -from typing import Callable, Any, Optional +from collections.abc import Callable +from typing import Any from flax import linen as nn from flax import struct diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py b/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py index 58e50dacd914..1cdeffeb6ea9 100644 --- a/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py +++ b/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py @@ -24,7 +24,8 @@ from __future__ import annotations -from typing import Callable, Any, Optional +from collections.abc import Callable +from typing import Any from flax import linen as nn from flax import struct diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index c5a583dca80a..03e6086a4924 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -15,9 +15,9 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import itertools -from typing import Any, Callable, Optional, Union +from typing import Any import jax from jax import lax @@ -281,6 +281,10 @@ def asinh(cls, harness: test_harnesses.Harness): custom_numeric(dtypes=[np.complex64], devices=("cpu", "gpu", "tpu"), tol=1e-3), custom_numeric(dtypes=[np.complex128], devices=("cpu", "gpu"), tol=1e-12), + custom_numeric(dtypes=[np.complex128], devices=("cpu",), + modes=("eager", "compiled", "graph"), + tol=1e-13, + native_serialization=Jax2TfLimitation.FOR_NATIVE | Jax2TfLimitation.FOR_NON_NATIVE), cls.helper_get_trig_custom_limitation(np.sinh) ] @@ -389,35 +393,34 @@ def conv_general_dilated(cls, harness: test_harnesses.Harness): @classmethod def cumlogsumexp(cls, harness): return [ - # JAX uses a different lowering for CPU and GPU. custom_numeric( - dtypes=(np.float16, jnp.bfloat16), - devices=("cpu", "gpu"), + dtypes=(np.float16, jnp.bfloat16, np.float32), + devices=("cpu", "gpu", "tpu"), modes=("eager", "graph", "compiled"), - tol=5e-1) + tol=5e-1, + ) ] - @classmethod def cumprod(cls, harness): return [ - # JAX uses a different lowering for CPU and GPU. 
custom_numeric( dtypes=(np.float16, jnp.bfloat16), - devices=("cpu", "gpu"), + devices=("cpu", "gpu", "tpu"), modes=("eager", "graph", "compiled"), - tol=5e-1) + tol=5e-1, + ) ] @classmethod def cumsum(cls, harness): return [ - # JAX uses a different lowering for CPU and GPU. custom_numeric( dtypes=(np.float16, jnp.bfloat16), - devices=("cpu", "gpu"), + devices=("cpu", "gpu", "tpu"), modes=("eager", "graph", "compiled"), - tol=5e-1) + tol=5e-1, + ) ] @classmethod @@ -547,6 +550,12 @@ def dot_general(cls, harness: test_harnesses.Harness): # may be more precise. custom_numeric(dtypes=[np.float16], devices=["cpu"], tol=1e-2, modes=("eager", "graph", "compiled")), + # Flakiness on different_dtypes_lhs_int16_4_3_rhs_float16_3_6_dimensionnumbers_1_0_enable_xla_True + # Strangely, we only see the flakiness in primitives_graph_serialization_test_gpu_pjrt_c_api + custom_numeric(dtypes=[np.int16], devices=["gpu"], tol=1e-2, + modes=("eager", "graph", "compiled"), + enabled=(harness.params["enable_xla"] and + harness.dtype != harness.params["rhs_dtype"])), ] @classmethod @@ -747,7 +756,11 @@ def fft(cls, harness): enabled=(str(harness.params["fft_type"]) in ["FftType.IFFT", "FftType.IRFFT"])), # TODO: very high tolerance - custom_numeric(tol=1e-3, modes=("eager", "graph", "compiled")), + custom_numeric(tol=1e-3, modes=("eager", "graph", "compiled"), + native_serialization=Jax2TfLimitation.FOR_NON_NATIVE), + custom_numeric(tol=1e-5, modes=("eager", "graph", "compiled"), + native_serialization=Jax2TfLimitation.FOR_NATIVE, + devices=("cpu",)), ] @classmethod @@ -1277,7 +1290,7 @@ def dot_column_wise(a, b): # values like 1.0000001 on float32, which are clipped to 1.0. It is # possible that anything other than `cos_angular_diff` can be outside # the interval [0, 1] due to roundoff. 
- cos_angular_diff = jnp.clip(cos_angular_diff, a_min=0.0, a_max=1.0) + cos_angular_diff = jnp.clip(cos_angular_diff, min=0.0, max=1.0) angular_diff = jnp.arccos(cos_angular_diff) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index ffd492cf0168..26266f67d4f2 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -15,9 +15,7 @@ Specific JAX primitive conversion tests are in primitives_test.""" import collections -from collections.abc import Sequence import contextlib -import functools import math import os import re @@ -29,29 +27,38 @@ import jax from jax import ad_checkpoint from jax import dtypes +from jax import export from jax import lax from jax import numpy as jnp from jax import sharding from jax._src import config from jax._src import core +from jax._src.maps import xmap from jax._src import source_info_util from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax.experimental import jax2tf -from jax.experimental import export from jax.experimental.jax2tf.tests import tf_test_util -from jax.experimental.maps import xmap from jax.experimental.shard_map import shard_map from jax.experimental import pjit from jax.sharding import PartitionSpec as P import numpy as np -import tensorflow as tf # type: ignore[import] +import tensorflow as tf # pylint: disable=g-direct-tensorflow-import -from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import] +from tensorflow.compiler.tf2xla.python import xla as tfxla # pylint: enable=g-direct-tensorflow-import config.parse_flags_with_absl() +_exit_stack = contextlib.ExitStack() + +# TODO(necula): Remove once tensorflow is 2.10.0 everywhere. +def setUpModule(): + if not hasattr(tfxla, "optimization_barrier"): + _exit_stack.enter_context(jtu.global_config_context(jax_remat_opt_barrier=False)) + +def tearDownModule(): + _exit_stack.close() class Jax2TfTest(tf_test_util.JaxToTfTestCase): @@ -970,8 +977,8 @@ def caller_jax(x): self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf)) else: graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def()) - if "my_test_function_jax/pjit_fn_/Mul" not in graph_def: - self.assertIn("my_test_function_jax/jit_fn_/Mul", graph_def) + if "my_test_function_jax/pjit_multiply_/Mul" not in graph_def: + self.assertIn("my_test_function_jax/jit_multiply_/Mul", graph_def) def test_bfloat16_constant(self): # Re: https://github.com/google/jax/issues/3942 @@ -1307,7 +1314,7 @@ def body_fun(carry): shape = (3, 2) x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) - jax_comp = jax.xla_computation(f_while)(x) + jax_comp = jax.jit(f_while).lower(x).compiler_ir('hlo') backend = xb.get_backend() modules = backend.compile(jax_comp).hlo_modules() jax_opt_hlo = modules[0].to_string() @@ -1496,10 +1503,10 @@ def apply_transform(func, transform: str): transformed_func = dict( none=func, jit=jax.jit(func), - jit_in_shardings_None=jax.jit(func, in_shardings=None), # type: ignore - jit_in_shardings_P=jax.jit(func, in_shardings=(P("a"),)), # type: ignore + jit_in_shardings_None=jax.jit(func, in_shardings=None), + jit_in_shardings_P=jax.jit(func, in_shardings=(P("a"),)), jit_in_shardings_Sharding=jax.jit( - func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)), # type: ignore + func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)), pjit=pjit.pjit(func), pjit_in_shardings_None=pjit.pjit(func, in_shardings=None, 
out_shardings=None), @@ -1552,7 +1559,7 @@ def apply_transform(func, transform: str): # Run the JAX native version, to check it works, and to fill caches. _ = func_to_convert(*args) exported = export.export( - func_to_convert, + (jax.jit(func_to_convert) if not hasattr(func_to_convert, "trace") else func_to_convert), lowering_platforms=("tpu",) )(*(core.ShapedArray(a.shape, a.dtype) for a in args)) @@ -1730,6 +1737,15 @@ def f_switch(x): f_switch_tf = jax2tf.convert(f_switch, enable_xla=False) self.assertIn("switch_case", self.TfToHlo(f_switch_tf, np.pi)) + @jtu.skip_on_flag("jax2tf_default_native_serialization", False) + def test_ragged_dot(self): + dtype = np.float32 + m, k, n, num_groups = 5, 4, 3, 2 + lhs = np.arange(m * k, dtype=dtype).reshape((m, k)) + rhs = np.arange(num_groups * k * n, dtype=dtype).reshape((num_groups, k, n)) + group_sizes = np.array([3, 2], dtype=np.int32) + self.ConvertAndCompare(jax.lax.ragged_dot, lhs, rhs, group_sizes) + @jtu.with_config(jax_enable_custom_prng=True) class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase): @@ -1770,7 +1786,4 @@ def test_simple(self): if __name__ == "__main__": - # TODO: Remove once tensorflow is 2.10.0 everywhere. - if not hasattr(tfxla, "optimization_barrier"): - jax.config.update("jax_remat_opt_barrier", False) absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py b/jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py index ec0123324e60..c6e89a6d85db 100644 --- a/jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py +++ b/jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py @@ -42,7 +42,7 @@ @jtu.with_config(jax_legacy_prng_key='allow', - jax_enable_key_reuse_checks=False) + jax_debug_key_reuse=False) class JaxPrimitiveTest(jtu.JaxTestCase): # This test runs for all primitive harnesses. For each primitive "xxx" the diff --git a/jax/experimental/jax2tf/tests/model_harness.py b/jax/experimental/jax2tf/tests/model_harness.py index 9af7229c0530..91aacf2f596f 100644 --- a/jax/experimental/jax2tf/tests/model_harness.py +++ b/jax/experimental/jax2tf/tests/model_harness.py @@ -15,10 +15,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools -from typing import Any, Callable, Optional, Union +from typing import Any import re import numpy as np diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 755a5e9c8cb3..22315f04c881 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -70,7 +70,7 @@ from jax._src.interpreters import xla import numpy as np -import tensorflow as tf # type: ignore[import] +import tensorflow as tf config.parse_flags_with_absl() @@ -179,9 +179,14 @@ def test_primitive_coverage(self): # TODO: Remove once tensorflow is 2.10.0 everywhere. if p.name == "optimization_barrier": continue - if p.name == "debug_callback": + if p.name == "debug_callback" or p.name == "debug_print": # TODO(sharadmv,necula): enable debug callbacks in TF continue + if p.name in ("max_contiguous", "multiple_of"): + # Pallas-specific primitives are not supported. 
+ continue + if p.name == "pallas_call": + continue if p.name in tf_not_yet_impl: self.assertNotIn( p, tf_impl) # Should not be in both tf_impl and tf_not_yet_impl diff --git a/jax/experimental/jax2tf/tests/savedmodel_test.py b/jax/experimental/jax2tf/tests/savedmodel_test.py index 86ae81f8373a..8b71de7db30c 100644 --- a/jax/experimental/jax2tf/tests/savedmodel_test.py +++ b/jax/experimental/jax2tf/tests/savedmodel_test.py @@ -19,14 +19,13 @@ from jax import lax import jax.numpy as jnp import numpy as np -import tensorflow as tf # type: ignore[import] +import tensorflow as tf from jax.experimental import jax2tf from jax.experimental.jax2tf.tests import tf_test_util from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class SavedModelTest(tf_test_util.JaxToTfTestCase): diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 1ceb87e695a9..83aac43f2d9d 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -15,10 +15,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import contextlib import math -from typing import Any, Callable +from typing import Any import unittest from absl import logging @@ -32,9 +32,8 @@ import jax from jax.experimental import jax2tf -from jax.experimental import export -from jax.experimental.export import _shape_poly as shape_poly from jax.experimental import pjit +from jax import export from jax import lax import jax.numpy as jnp from jax import random @@ -43,6 +42,7 @@ from jax._src import core from jax._src import test_util as jtu from jax._src import util +from jax._src.export import shape_poly from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow from jax._src.lib import xla_client @@ -50,7 +50,7 @@ from jax.experimental.jax2tf.tests import tf_test_util -import tensorflow as tf # type: ignore[import] +import tensorflow as tf config.parse_flags_with_absl() @@ -218,7 +218,7 @@ def log_message(extra: str): if not self.skip_jax_run: res_jax = f_jax(*args) if self.check_result: - res_tf = tf.nest.map_structure(lambda t: t.numpy(), res_tf) # type: ignore + res_tf = tf.nest.map_structure(lambda t: t.numpy(), res_tf) custom_assert_lims = [ l for l in self.limitations if l.custom_assert is not None] assert len(custom_assert_lims) <= 1, custom_assert_lims @@ -615,7 +615,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer poly_spec="(a + 2*b, a, a + b + c)", @@ -624,7 +624,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Division had remainder 1 when computing the value of 'b'. " "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). 
" "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(8, 2, 6), # a = 2, b = 3 - inconsistency poly_spec="(a + 2*b, a, a + b)", @@ -634,7 +634,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c poly_spec="(2 * a + b, a, c * c)", @@ -643,7 +643,7 @@ def conv_and_run(*, arg_shape: core.Shape, "We can only solve linear uni-variate constraints. " "Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). " "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." )), ]) def test_shape_constraints_errors(self, *, @@ -2377,7 +2377,7 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10] PolyHarness("scatter_grad", "", lambda *args: jax.grad( lambda *args: - jnp.sum(lax.scatter( # type: ignore + jnp.sum(lax.scatter( *args, indices_are_sorted=False, unique_indices=False, @@ -2392,7 +2392,7 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10] PolyHarness("scatter_grad", "poly_indices", lambda *args: jax.grad( lambda *args: - jnp.sum(lax.scatter( # type: ignore + jnp.sum(lax.scatter( *args, indices_are_sorted=False, unique_indices=False)) @@ -2523,6 +2523,12 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10] lambda x: jnp.tri(x.shape[0], M=x.shape[0] + 2) + x, arg_descriptors=[RandArg((3, 1), _f32)], polymorphic_shapes=["b, ..."]), + PolyHarness("tril", "", + lambda x: jnp.tril(jnp.ones((x.shape[0], x.shape[0] + x.shape[1]), + dtype=_f32), + k=x.shape[1]), + arg_descriptors=[RandArg((3, 4), _f32)], + polymorphic_shapes=["m, n"]), [ PolyHarness("triangular_solve", f"shape={jtu.format_shape_dtype_string(a_shape, dtype)}_{left_side=}_{a_poly=}_{b_poly=}", @@ -2809,17 +2815,8 @@ def test_harness(self, harness: PolyHarness): if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]): raise unittest.SkipTest("JAX implements eig only on CPU.") - prev_jax_config_flags = { - fname: getattr(jax.config, fname) - for fname, fvalue in harness.override_jax_config_flags.items() - } - try: - for fname, fvalue in harness.override_jax_config_flags.items(): - jax.config.update(fname, fvalue) + with jtu.global_config_context(**harness.override_jax_config_flags): harness.run_test(self) - finally: - for fname, _ in harness.override_jax_config_flags.items(): - jax.config.update(fname, prev_jax_config_flags[fname]) if 
__name__ == "__main__": diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 9a674e79119d..b6750133090e 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -39,7 +39,6 @@ from jax import lax from jax.experimental import jax2tf from jax.experimental import pjit -from jax.experimental.maps import xmap from jax.experimental.shard_map import shard_map from jax.sharding import NamedSharding from jax.sharding import Mesh @@ -48,20 +47,20 @@ import numpy as np -import tensorflow as tf # type: ignore[import] +import tensorflow as tf config.parse_flags_with_absl() # Must come after initializing the flags from jax.experimental.jax2tf.tests import tf_test_util -prev_xla_flags = None -prev_spmd_lowering_flag = None - +_exit_stack = contextlib.ExitStack() topology = None def setUpModule(): - global prev_xla_flags, topology + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + + global topology if jtu.test_device_matches(["tpu"]): resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') tf.config.experimental_connect_to_cluster(resolver) @@ -70,26 +69,8 @@ def setUpModule(): else: topology = None - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - global prev_spmd_lowering_flag - prev_spmd_lowering_flag = maps.SPMD_LOWERING.value - config.update('experimental_xmap_spmd_lowering', True) - - def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() - config.update('experimental_xmap_spmd_lowering', prev_spmd_lowering_flag) + _exit_stack.close() class ShardingTest(tf_test_util.JaxToTfTestCase): @@ -454,54 +435,39 @@ def f_grad_tf(x_v, res_ct): (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P), ]) - @jtu.parameterized_filterable( - kwargs=[ - dict(testcase_name=f"_kind={kind}_in_shardings={in_shardings}_out_shardings={out_shardings}", - kind=kind, in_shardings=in_shardings, out_shardings=out_shardings) - for kind in ("pjit", "jit", "sharding_constraint") - for in_shardings in ( - ("none", "P") if kind == "sharding_constraint" else - ("unspecified",) if kind == "jit" else - ("unspecified", "none", "P")) - for out_shardings in ( - ("unspecified",) if kind in ["sharding_constraint", "jit"] else - ("unspecified", "none", "P")) - ]) - def test_pjit_error_inner_sharding(self, kind="pjit", in_shardings="P", - out_shardings="none"): - # Check that we raise an error if there is no top-level pjit but we convert - # a function with non-replicated shardings (with native lowering). - shardings_map = dict(none=None, P=P("x")) - + def test_grad_sharding_different_mesh(self): + # Convert with two similar meshes, the only difference being + # the order of the devices. grad should not fail. 
+ # https://github.com/google/jax/issues/21314 + devices = jax.local_devices()[:2] + if len(devices) < 2: + raise unittest.SkipTest("Test requires 2 local devices") def f_jax(x): - if kind == "pjit": - pjit_kwargs = {} - if in_shardings != "unspecified": - pjit_kwargs["in_shardings"] = shardings_map[in_shardings] - if out_shardings != "unspecified": - pjit_kwargs["out_shardings"] = shardings_map[out_shardings] - res = pjit.pjit(lambda x: x * 2., **pjit_kwargs)(x) - elif kind == "jit": - res = jax.jit(lambda x: x * 2.)(x) - elif kind == "sharding_constraint": - res = jax.lax.with_sharding_constraint(x * 2., shardings_map[in_shardings]) - else: - assert False - return res - - expect_error = (in_shardings == "P" or out_shardings == "P") - shape = (8, 10) - x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - - f_tf = tf.function(jax2tf.convert(f_jax, native_serialization=True), - autograph=False, jit_compile=True) - with contextlib.ExitStack() as stack: - if expect_error: - stack.enter_context(self.assertRaisesRegex( - ValueError, - "Lowered function does not have a top-level pjit but it has non-replicated sharding annotations")) - with Mesh(self.devices, axis_names=("x",)): - f_tf(x) + return jnp.sum(x * 2.) + + mesh = Mesh(devices, "i") + # The same mesh with reversed order of devices + mesh_rev = Mesh(list(reversed(devices)), "i") + shardings = NamedSharding(mesh, jax.sharding.PartitionSpec(("i",))) + shardings_rev = NamedSharding(mesh_rev, jax.sharding.PartitionSpec(("i",))) + + f_tf = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings)), + autograph=False) + f_tf_rev = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings_rev)), + autograph=False) + inp = np.ones((2, 4), dtype=np.float32) + + input_v = tf.Variable(inp) + with tf.GradientTape(persistent=True) as tape: + tape.watch(input_v) + res_tf = f_tf(input_v) + g = tape.gradient(res_tf, input_v) + + with tf.GradientTape(persistent=True) as tape: + tape.watch(input_v) + res_tf_rev = f_tf_rev(input_v) + g_rev = tape.gradient(res_tf_rev, input_v) + self.assertAllClose(g, g_rev) @jtu.parameterized_filterable( kwargs=[ @@ -551,115 +517,6 @@ def f_nested_pjit_replicated(a): "function with sharded arguments or results must be used under a `tf.function` context"): jax2tf.convert(f_jax)(a) - def test_xmap_basic(self): - devices = np.reshape(self.devices, (1, 2)) - ashape = (16, 8, 5) - a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape) - bshape = (2, 7) - b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape) - - # f_jax: f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28] - # lambda ...: f32[5], f32[7] -> f32[10], f32[28] - f_jax = xmap(lambda a, b: (jnp.concatenate([a, a], axis=0) * 2., - jnp.concatenate([b, b, b, b], axis=0) * 4.), - in_axes=({0: 'a', 1: 'b'}, ['c', ...]), - out_axes=({0: 'a', 1: 'b'}, ['c', ...]), - axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) - - @tf.function(autograph=False, jit_compile=True) - def f_tf(a, b): - # xmap works only with native serialization - f_converted = jax2tf.convert(f_jax, native_serialization=True) - if jtu.test_device_matches(["tpu"]): - res = tf.compat.v1.tpu.rewrite( - f_converted, [tf.convert_to_tensor(a), tf.convert_to_tensor(b)], - device_assignment=self.device_assignment( - computation_shape=[1, 1, 1, 2]) - ) - return (res[0], res[1]) - else: - return f_converted(a, b) - - with Mesh(devices, ('x', 'y')): - res_jax = f_jax(a, b) - self.assertAllClose(res_jax, (jnp.concatenate([a, a], axis=2) * 2., - jnp.concatenate([b, b, b, b], 
axis=1) * 4.)) - res_tf = f_tf(a, b) - self.assertAllClose(res_tf, res_jax) - - self.check_sharding( - jax2tf.convert(f_jax, native_serialization=True), [a, b], - checks=[ - (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", 1), - # The output sharding - (r"f32\[2,7\].*custom_call_target.*Sharding.*sharding.*replicated", 1), - (r"f32\[2,28\].*custom_call_target.*Sharding.*sharding.*replicated", 1), - ]) - - def test_xmap_collective_reduce(self): - devices = np.reshape(self.devices, (1, 2)) - ashape = (16, 8, 5) - a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape) - bshape = (2, 7) - b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape) - f_jax = xmap(lambda a, b: (lax.psum(a * 2., 'a'), b * 4.), - in_axes=(['a', 'b', ...], {0: 'c'}), - out_axes=(['b', ...], {0: 'c'}), - axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) - - @tf.function(autograph=False, jit_compile=True) - def f_tf(a, b): - f_converted = jax2tf.convert(f_jax, native_serialization=True) - if jtu.test_device_matches(["tpu"]): - res = tf.compat.v1.tpu.rewrite( - f_converted, [tf.convert_to_tensor(a), tf.convert_to_tensor(b)], - device_assignment=self.device_assignment( - computation_shape=[1, 1, 1, 2]) - ) - return (res[0], res[1]) - else: - return f_converted(a, b) - - with Mesh(devices, ('x', 'y')): - res_jax = f_jax(a, b) - self.assertAllClose(res_jax, ((a * 2.).sum(0), b * 4.)) - res_tf = f_tf(a, b) - self.assertAllClose(res_tf, res_jax) - self.check_sharding( - jax2tf.convert(f_jax, native_serialization=True), [a, b], - checks=[ - (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", 1), - (r"f32\[2,7\].*custom_call_target.*Sharding.*sharding.*replicated", 2), - (r"f32\[8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1), - ]) - - def test_grad_xmap(self): - devices = np.reshape(self.devices, (1, 2)) - ashape = (16, 8, 5) - a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape) - - # f_jax: f32[16,8,5]-> f32[16,8,10] - # lambda ...: f32[5]-> f32[10] - f_jax = xmap(lambda a: jnp.concatenate([a, a], axis=0) * 2., - in_axes=({0: 'a', 1: 'b'}), - out_axes={0: 'a', 1: 'b'}, - axis_resources={'a': 'x', 'b': 'y'}) - - def f_grad_tf(a, res_ct): - with tf.GradientTape(persistent=True) as tape: - tape.watch(a) - res_tf = jax2tf.convert(f_jax, native_serialization=True)(a) - return tape.gradient(res_tf, a, output_gradients=res_ct) - - with Mesh(devices, ('x', 'y')): - self.check_sharding(f_grad_tf, [a, np.concatenate([a, a], axis=2)], - checks=[ - # Primal input and grad output - (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(2)), - # Input cotangent - (r"f32\[16,8,10\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(1)), - ]) - @jtu.ignore_warning(category=UserWarning, message="all_to_all .* are only implemented properly for TPUs and GPUs .*") def test_shmap_all_to_all(self): diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 51314638ea25..32f89e533daf 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -14,12 +14,12 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import contextlib import dataclasses import re import os -from typing import Any, Callable, Optional +from typing import Any from absl.testing import absltest from absl import logging @@ -31,13 +31,13 @@ from jax 
import tree_util from jax.experimental import jax2tf -from jax.experimental import export +from jax import export from jax._src import config from jax._src import xla_bridge import numpy as np -import tensorflow as tf # type: ignore[import] -from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import] -from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import] +import tensorflow as tf +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.tf2xla.python import xla as tfxla DType = Any @@ -158,7 +158,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence, @jtu.with_config(jax_numpy_rank_promotion="allow", jax_numpy_dtype_promotion='standard', jax_legacy_prng_key="allow", - jax_enable_key_reuse_checks=False) + jax_debug_key_reuse=False) class JaxToTfTestCase(jtu.JaxTestCase): # We want most tests to use the maximum available version, from the locally # installed tfxla module and export. @@ -180,18 +180,18 @@ def setUp(self): # We run the tests using the maximum version supported, even though # the default serialization version may be held back for a while to # ensure compatibility - version = config.jax_serialization_version.value + version = config.jax_export_calling_convention_version.value if self.use_max_serialization_version: # Use the largest supported by both export and tfxla.call_module - version = min(export.maximum_supported_serialization_version, + version = min(export.maximum_supported_calling_convention_version, tfxla.call_module_maximum_supported_version()) self.assertGreaterEqual(version, - export.minimum_supported_serialization_version) - self.enter_context(config.jax_serialization_version(version)) + export.minimum_supported_calling_convention_version) + self.enter_context(config.jax_export_calling_convention_version(version)) logging.info( "Using JAX serialization version %s (export.max_version %s, tf.XlaCallModule max version %s)", version, - export.maximum_supported_serialization_version, + export.maximum_supported_calling_convention_version, tfxla.call_module_maximum_supported_version()) with contextlib.ExitStack() as stack: @@ -293,7 +293,7 @@ def log_message(extra): logging.info(log_message(f"Using tol={max_tol} due to {max_tol_lim}")) # Convert results to np.arrays - result_tf = tf.nest.map_structure(lambda t: t.numpy(), result_tf) # type: ignore + result_tf = tf.nest.map_structure(lambda t: t.numpy(), result_tf) custom_assert_lim = [l for l in jax2tf_limits if l.custom_assert] assert len(custom_assert_lim) <= 1, f"Expecting at most one applicable limitation with custom_assert, found {custom_assert_lim}" diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index ed4ecd5995cc..1ed6183b1229 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -52,7 +52,8 @@ `outstanding primitive rules `__. 
""" -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from functools import partial @@ -490,24 +491,22 @@ def _pow_taylor(primals_in, series_in): return primal_out, series_out jet_rules[lax.pow_p] = _pow_taylor +def _pow_by_squaring(x, n): + if n < 0: + return _pow_by_squaring(1 / x, -n) + elif n == 0: + return 1 + elif n % 2 == 0: + return _pow_by_squaring(x * x, n / 2) + elif n % 2 == 1: + return x * _pow_by_squaring(x * x, (n - 1) / 2) + def _integer_pow_taylor(primals_in, series_in, *, y): if y == 0: return jet(jnp.ones_like, primals_in, series_in) - elif y == 1: - return jet(lambda x: x, primals_in, series_in) - elif y == 2: - return jet(lambda x: x * x, primals_in, series_in) - x, = primals_in - series, = series_in - u = [x] + series - v = [lax.integer_pow(x, y)] + [None] * len(series) - for k in range(1, len(v)): - vu = sum(_scale(k, j) * v[k-j] * u[j] for j in range(1, k + 1)) - uv = sum(_scale(k, j) * u[k-j] * v[j] for j in range(1, k)) - v[k] = jnp.where(x == 0, 0, fact(k-1) * (y * vu - uv) / x) - primal_out, *series_out = v + else: + return jet(lambda x: _pow_by_squaring(x, y), primals_in, series_in) - return primal_out, series_out jet_rules[lax.integer_pow_p] = _integer_pow_taylor @@ -741,6 +740,8 @@ def _pjit_jet_rule(primals_in, series_in, **params): params['out_shardings'] + (sharding_impls.UNSPECIFIED,) * num_series_out ), + 'in_layouts': params['in_layouts'] + (None,) * num_series_in, + 'out_layouts': params['out_layouts'] + (None,) * num_series_out, 'donated_invars': params['donated_invars'] + (False,) * num_series_in, } result = pjit.pjit_p.bind(*primals_and_series, **new_params) diff --git a/jax/experimental/key_reuse/__init__.py b/jax/experimental/key_reuse/__init__.py index 72d9a861eacf..121f52006390 100644 --- a/jax/experimental/key_reuse/__init__.py +++ b/jax/experimental/key_reuse/__init__.py @@ -16,33 +16,27 @@ Experimental Key Reuse Checking ------------------------------- -This module contains **experimental** functionality for detecting re-use of random -keys within JAX programs. It is under active development and the APIs here are likely -to change. The usage below requires JAX version 0.4.26 or newer. +This module contains **experimental** functionality for detecting reuse of random +keys within JAX programs. It is under active development and the APIs here are +likely to change. The usage below requires JAX version 0.4.26 or newer. -Key reuse checking can be enabled using the `jax_enable_key_reuse_checks` configuration:: +Key reuse checking can be enabled using the ``jax_debug_key_reuse`` configuration. +This can be set globally using:: + + >>> jax.config.update('jax_debug_key_reuse', True) # doctest: +SKIP + +Or it can be enabled locally with the :func:`jax.debug_key_reuse` context manager. +When enabled, using the same key twice will result in a :class:`~jax.errors.KeyReuseError`:: >>> import jax - >>> jax.config.update('jax_enable_key_reuse_checks', True) - >>> key = jax.random.key(0) - >>> jax.random.normal(key) - Array(-0.20584226, dtype=float32) - >>> jax.random.normal(key) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> with jax.debug_key_reuse(True): + ... key = jax.random.key(0) + ... val1 = jax.random.normal(key) + ... val2 = jax.random.normal(key) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... 
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0 -This flag can also be controlled locally using the :func:`jax.enable_key_reuse_checks` -context manager:: - - >>> with jax.enable_key_reuse_checks(False): - ... print(jax.random.normal(key)) - -0.20584226 +The key reuse checker is currently experimental, but in the future we will likely +enable it by default. """ -from jax._src.prng import ( - reuse_key as reuse_key, -) - -from jax.experimental.key_reuse._core import ( - KeyReuseError as KeyReuseError, -) diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index 489fcc14e8fa..b4989e151a53 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -15,12 +15,14 @@ from __future__ import annotations from collections import defaultdict -from functools import partial, reduce, wraps -from typing import Any, Callable, NamedTuple +from collections.abc import Callable, Iterator +from functools import partial, reduce, total_ordering, wraps +from typing import Any, NamedTuple import jax from jax import lax from jax import tree_util +from jax.errors import KeyReuseError from jax.interpreters import batching, mlir from jax._src import api_util from jax._src import config @@ -30,6 +32,7 @@ from jax._src import prng from jax._src import random from jax._src import source_info_util +from jax._src import traceback_util from jax._src import util from jax._src.ad_checkpoint import remat_p from jax._src.debugging import debug_callback_p @@ -40,52 +43,148 @@ import numpy as np -class Sink(NamedTuple): - idx: int - mask: bool | np.ndarray = True +traceback_util.register_exclusion(__file__) - def __repr__(self): - if isinstance(self.mask, bool) and self.mask: - return f"Sink({self.idx})" - else: - return f"Sink({self.idx}, mask={self.mask})" +_source_context_message = ( + 'PRNG key first used at the above location was subsequently reused' + ' at the following location:') + +def key_reuse_error_with_source_traceback( + message: str, traceback: source_info_util.Traceback | None) -> KeyReuseError: + err = KeyReuseError(message) + if traceback is not None: + filtered_tb = traceback_util.filter_traceback(traceback.as_python_traceback()) + if filtered_tb: + context_err = KeyReuseError(_source_context_message).with_traceback(filtered_tb) + context_err.__context__ = err.__context__ + context_err.__cause__ = err.__cause__ + context_err.__suppress_context__ = err.__suppress_context__ + err.__context__ = None + err.__cause__ = context_err + return err -class Source(NamedTuple): +# Create Source() and Sink() objects which validate inputs, have +# correct equality semantics, and are hashable & immutable. 
+@total_ordering +class _SourceSinkBase: idx: int - mask: bool | np.ndarray = True + mask: bool | np.ndarray + + def __init__(self, idx: int, mask: bool | np.bool_ | np.ndarray = True): + assert isinstance(idx, int) + if isinstance(mask, np.ndarray): + assert mask.dtype == np.dtype('bool') + if np.all(mask): + mask = True + elif not np.any(mask): + mask = False + elif mask.flags.writeable: + mask = np.array(mask, copy=True) + mask.flags.writeable = False + elif isinstance(mask, np.bool_): + mask = bool(mask) + else: + assert isinstance(mask, bool) + super().__setattr__("idx", idx) + super().__setattr__("mask", mask) - def __repr__(self): - if isinstance(self.mask, bool) and self.mask: - return f"Source({self.idx})" + def __setattr__(self, *args, **kwargs): + raise ValueError(f"{self.__class__.__name__} is immutable") + + def __eq__(self, other): + return (self.__class__ == other.__class__ + and self.idx == other.idx + and np.shape(self.mask) == np.shape(other.mask) + and np.all(self.mask == other.mask)) + + def __lt__(self, other): + if isinstance(other, Forward): + return True + elif isinstance(other, _SourceSinkBase): + return ((self.__class__.__name__, self.idx) + < (other.__class__.__name__, other.idx)) + else: + return NotImplemented + + def __hash__(self): + if isinstance(self.mask, bool): + return hash((self.__class__, self.idx, self.mask)) else: - return f"Source({self.idx}, mask={self.mask})" + mask = np.asarray(self.mask) + return hash((self.__class__, self.idx, mask.shape, + tuple(mask.flatten().tolist()))) + + def __repr__(self): + if self.mask is True: + return f"{self.__class__.__name__}({self.idx})" + return f"{self.__class__.__name__}({self.idx}, {self.mask})" + + +class Sink(_SourceSinkBase): + pass + + +class Source(_SourceSinkBase): + pass + class Forward(NamedTuple): in_idx: int out_idx: int + def __repr__(self): + return f"Forward({self.in_idx}, {self.out_idx})" + + +# KeyReuseSignature is essentially a frozen set of Source/Sink/Forward +# objects, with a few convenience methods related to key reuse checking. 
+class KeyReuseSignature: + _args: frozenset[Source | Sink | Forward] + + def __init__(self, *args): + self._args = frozenset(args) + + def __repr__(self): + return f"KeyReuseSignature{tuple(sorted(self._args))}" + + def __eq__(self, other): + return isinstance(other, KeyReuseSignature) and self._args == other._args + + def __hash__(self): + return hash(self._args) -class KeyReuseSignature(NamedTuple): - sinks: list[Sink] - sources: list[Source] - forwards: list[Forward] = [] + @property + def sinks(self) -> Iterator[Sink]: + yield from (s for s in self._args if isinstance(s, Sink)) + + @property + def sources(self) -> Iterator[Source]: + yield from (s for s in self._args if isinstance(s, Source)) + + @property + def forwards(self) -> Iterator[Forward]: + yield from (s for s in self._args if isinstance(s, Forward)) def check_signature(self, *args, funcname="function", context=None): for sink in self.sinks: - if not isinstance(args[sink.idx], prng.PRNGKeyArray): + key = args[sink.idx] + if not isinstance(key, prng.PRNGKeyArray): continue - if np.any(args[sink.idx]._consumed & sink.mask): + if np.any(key._consumed & sink.mask): msg = f"Previously-consumed key passed to {funcname} at index {sink.idx}" if context: msg += " {context}" - raise KeyReuseError(msg) + raise key_reuse_error_with_source_traceback( + msg, key._source_info and key._source_info.traceback) def update_consumption(self, args_in, args_out): for sink in self.sinks: arg = args_in[sink.idx] if isinstance(arg, prng.PRNGKeyArray): arg._consumed = arg._consumed | sink.mask + if np.any(sink.mask): + arg._source_info = source_info_util.current() for arg in args_out: if isinstance(arg, prng.PRNGKeyArray): arg._consumed = True @@ -99,8 +198,42 @@ def update_consumption(self, args_in, args_out): arg_out._consumed = arg_in._consumed -class KeyReuseError(RuntimeError): - pass +class DynamicKeyReuseSignature(NamedTuple): + signature: Callable[[core.JaxprEqn], KeyReuseSignature] + +def dynamic_key_reuse_signature(f: Callable[[core.JaxprEqn], KeyReuseSignature]) -> DynamicKeyReuseSignature: + return DynamicKeyReuseSignature(f) + +def key_reuse_signature_from_eqn(eqn: core.JaxprEqn) -> KeyReuseSignature: + if eqn.primitive in key_reuse_signatures: + sig = key_reuse_signatures[eqn.primitive] + if isinstance(sig, KeyReuseSignature): + return sig + elif isinstance(sig, DynamicKeyReuseSignature): + return sig.signature(eqn) + else: + raise TypeError( + f"Unrecognized key reuse sigature of type {type(sig)}: {sig}") + else: + return unknown_signature(eqn) + + +def key_reuse_signature_from_primitive(prim, *args, **params): + if prim == pjit.pjit_p: + return jaxpr_type_signature(params['jaxpr'].jaxpr) + if prim not in key_reuse_signatures: + # TODO(jakevdp) should we generate an unknown signature here? + raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}") + sig = key_reuse_signatures[prim] + if isinstance(sig, KeyReuseSignature): + return sig + elif isinstance(sig, DynamicKeyReuseSignature): + jaxpr = jax.make_jaxpr(partial(prim.bind, **params))(*args).jaxpr + return jaxpr_type_signature(jaxpr) + else: + raise TypeError( + f"Unrecognized key reuse sigature of type {type(sig)}: {sig}") + consume_p = core.Primitive("consume") consume_p.def_impl(lambda x: x) @@ -145,51 +278,46 @@ def _check_consumed_value(eqn, consumed): # The behavior of most primitives can be described via simple signatures. 
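As an editorial aside (not part of the patch itself), here is a minimal sketch of how the reworked variadic `KeyReuseSignature` reads, assuming `Sink`, `Source`, and `Forward` remain importable from the private `jax.experimental.key_reuse._core` module shown in this hunk:

```python
# Hypothetical illustration; mirrors the per-primitive signatures declared below.
from jax.experimental.key_reuse._core import (
    KeyReuseSignature, Sink, Source, Forward)

# random_split consumes its input key (Sink(0)) and emits fresh keys (Source(0)).
split_sig = KeyReuseSignature(Sink(0), Source(0))

# A reshape neither consumes nor creates keys; it forwards input 0 to output 0.
reshape_sig = KeyReuseSignature(Forward(0, 0))

assert Sink(0) in set(split_sig.sinks)   # sinks/sources/forwards are now generator properties
assert not list(reshape_sig.sources)
```

Compared with the old `KeyReuseSignature(sinks, sources, forwards)` list form, the frozen-set representation makes argument order irrelevant and the signature hashable.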
-key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {} +key_reuse_signatures: dict[core.Primitive, KeyReuseSignature | DynamicKeyReuseSignature] = {} -key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [], [Forward(0, 0)]) -key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignature([], [Source(0)]) -key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], []) +key_reuse_signatures[consume_p] = KeyReuseSignature(Sink(0), Forward(0, 0)) +key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[random.random_clone_p] = KeyReuseSignature(Source(0)) +key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature(Sink(0)) # TODO(jakevdp): should fold_in sink its input key? -# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)]) -key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([], [Source(0)]) -key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature([], [Source(0)]) -key_reuse_signatures[prng.random_split_p] = KeyReuseSignature([Sink(0)], [Source(0)]) -key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature([Sink(0)], []) +key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature(Source(0)) +key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature(Source(0)) +key_reuse_signatures[prng.random_split_p] = KeyReuseSignature(Sink(0), Source(0)) +key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature(Sink(0)) # TODO(jakevdp): broadcast should probably consume the input to avoid implicit duplication -key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.copy_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.device_put_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.reshape_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.squeeze_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature([], [Source(0)], []) +key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.copy_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.device_put_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.reshape_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.squeeze_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature(Source(0)) # TODO(jakevdp): should unwrap sink its input key? 
-key_reuse_signatures[prng.random_unwrap_p] = KeyReuseSignature([], [], []) -key_reuse_signatures[debug_callback_p] = KeyReuseSignature([], []) -key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature([Sink(1)], [], [Forward(0, 0)]) -key_reuse_signatures[lax.gather_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.scatter_p] = KeyReuseSignature([Sink(2)], [], [Forward(0, 0)]) +key_reuse_signatures[prng.random_unwrap_p] = KeyReuseSignature() +key_reuse_signatures[debug_callback_p] = KeyReuseSignature() +key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature(Sink(1), Forward(0, 0)) +key_reuse_signatures[lax.gather_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.scatter_p] = KeyReuseSignature(Sink(2), Forward(0, 0)) # Equality checks don't consume -key_reuse_signatures[lax.eq_p] = KeyReuseSignature([], [], []) -key_reuse_signatures[lax.ne_p] = KeyReuseSignature([], [], []) - -# Rules which require more dynamic logic. -key_reuse_signatures_dynamic: dict[core.Primitive, Callable[..., KeyReuseSignature]] = {} +key_reuse_signatures[lax.eq_p] = KeyReuseSignature() +key_reuse_signatures[lax.ne_p] = KeyReuseSignature() # The default signature will Sink all key inputs, and not Source any. def unknown_signature(eqn): def is_key(var: core.Atom): return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key) return KeyReuseSignature( - sinks=[Sink(idx, True) for idx, var in enumerate(eqn.invars) if is_key(var)], - sources=[], + *(Sink(idx) for idx, var in enumerate(eqn.invars) if is_key(var)) ) @weakref_lru_cache -def get_jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature: +def jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature: """Parse the jaxpr to determine key reuse signature""" consumed: dict[core.Atom, bool | np.ndarray] = {} forwards: dict[core.Atom, core.Atom] = {} # map forwarded outputs to inputs. @@ -218,7 +346,6 @@ def sink(var: core.Atom, mask=True): return True consumed[var] = np.logical_or(consumed.get(var, False), mask) - def source(var: core.Atom, mask=False): if not is_key(var): return @@ -236,13 +363,7 @@ def is_consumed(var: core.Atom): traceback = eqn.source_info.traceback name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack with source_info_util.user_context(traceback, name_stack=name_stack): - if eqn.primitive in key_reuse_signatures: - signature = key_reuse_signatures[eqn.primitive] - elif eqn.primitive in key_reuse_signatures_dynamic: - signature = key_reuse_signatures_dynamic[eqn.primitive](eqn) - else: - signature = unknown_signature(eqn) - + signature = key_reuse_signature_from_eqn(eqn) if eqn.primitive == assert_consumed_value_p: # This is a special case that goes beyond normal key reuse logic. 
_check_consumed_value(eqn, is_consumed(eqn.invars[0])) @@ -263,62 +384,82 @@ def is_consumed(var: core.Atom): raise KeyReuseError(f"In {eqn.primitive}, source {src.idx} out of range [0, {len(eqn.outvars)}]") source(eqn.outvars[src.idx]) + all_inputs = [*jaxpr.invars, *jaxpr.constvars] return KeyReuseSignature( - sinks=[Sink(i, consumed[v]) for i, v in enumerate(jaxpr.invars) - if is_key(v) and np.any(consumed.get(v, False))], - sources=[Source(i) for i, v in enumerate(jaxpr.outvars) - if is_key(v) and resolve_forwards(v) not in jaxpr.invars and not consumed.get(v, False)], - forwards=[Forward(jaxpr.invars.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type] - for idx_out, outvar in enumerate(jaxpr.outvars) - if is_key(outvar) and resolve_forwards(outvar) in jaxpr.invars] + *(Sink(i, consumed[v]) for i, v in enumerate(all_inputs) + if is_key(v) and np.any(consumed.get(v, False))), + *(Source(i) for i, v in enumerate(jaxpr.outvars) + if is_key(v) and resolve_forwards(v) not in all_inputs and not consumed.get(v, False)), + *(Forward(all_inputs.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type] + for idx_out, outvar in enumerate(jaxpr.outvars) + if is_key(outvar) and resolve_forwards(outvar) in all_inputs) ) +def function_type_signature(fun: Callable[..., Any], *args: Any) -> KeyReuseSignature: + args_flat, in_tree = tree_util.tree_flatten(args) + in_avals_flat = [core.get_aval(arg) for arg in args_flat] + wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) + return jaxpr_type_signature(jaxpr) + + def check_key_reuse_jaxpr(jaxpr: core.Jaxpr) -> None: """Check the jaxpr for key reuse.""" - get_jaxpr_type_signature(jaxpr) + jaxpr_type_signature(jaxpr) def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None: """Function to statically check key reuse.""" - args_flat, in_tree = tree_util.tree_flatten(args) - in_avals_flat = [core.get_aval(arg) for arg in args_flat] - wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) - check_key_reuse_jaxpr(jaxpr) + function_type_signature(fun, *args) #---------------------------------------------------------------------------------- # key reuse rules for particular primitives: +@dynamic_key_reuse_signature def _slice_signature(eqn): in_aval = eqn.invars[0].aval if not jax.dtypes.issubdtype(in_aval.dtype, jax.dtypes.prng_key): - return KeyReuseSignature([], [], [Forward(0, 0)]) + return KeyReuseSignature(Forward(0, 0)) if any(core.is_symbolic_dim(s) for s in in_aval.shape): - return KeyReuseSignature([], [], [Forward(0, 0)]) + return KeyReuseSignature(Forward(0, 0)) start_indices = eqn.params['start_indices'] limit_indices = eqn.params['limit_indices'] strides = eqn.params['strides'] or (1,) * len(start_indices) idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides)) sink = np.zeros(in_aval.shape, dtype=bool) sink[idx] = True - return KeyReuseSignature([Sink(0, sink)], [Source(0)]) + return KeyReuseSignature(Sink(0, sink), Source(0)) -key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature +key_reuse_signatures[lax.slice_p] = _slice_signature +@dynamic_key_reuse_signature +def _concatenate_signature(eqn): + num_vals = len(eqn.invars) + # TODO(jakevdp): should this signature be more granular? 
+ if num_vals == 1: + return KeyReuseSignature(Forward(0, 0)) + else: + return KeyReuseSignature(*(Sink(i) for i in range(num_vals)), Source(0)) + +key_reuse_signatures[lax.concatenate_p] = _concatenate_signature + +@dynamic_key_reuse_signature def _pjit_key_type_signature(eqn): - return get_jaxpr_type_signature(eqn.params['jaxpr'].jaxpr) + return jaxpr_type_signature(eqn.params['jaxpr'].jaxpr) -key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature +key_reuse_signatures[pjit.pjit_p] = _pjit_key_type_signature +@dynamic_key_reuse_signature def _shard_map_type_signature(eqn): - return get_jaxpr_type_signature(eqn.params['jaxpr']) + return jaxpr_type_signature(eqn.params['jaxpr']) -key_reuse_signatures_dynamic[shard_map_p] = _shard_map_type_signature +key_reuse_signatures[shard_map_p] = _shard_map_type_signature +@dynamic_key_reuse_signature def _cond_key_type_signature(eqn): - signatures = [get_jaxpr_type_signature(branch.jaxpr) for branch in eqn.params['branches']] + signatures = [jaxpr_type_signature(branch.jaxpr) for branch in eqn.params['branches']] sinks = defaultdict(list) sources = defaultdict(list) for sig in signatures: @@ -331,15 +472,16 @@ def _cond_key_type_signature(eqn): combined_sources = [Source(i, reduce(np.logical_and, m)) for i, m in sources.items()] combined_forwards = [Forward(f.in_idx + 1, f.out_idx) for f in set.intersection(*(set(sig.forwards) for sig in signatures))] - return KeyReuseSignature(combined_sinks, combined_sources, combined_forwards) + return KeyReuseSignature(*combined_sinks, *combined_sources, *combined_forwards) -key_reuse_signatures_dynamic[lax.cond_p] = _cond_key_type_signature +key_reuse_signatures[lax.cond_p] = _cond_key_type_signature +@dynamic_key_reuse_signature def _scan_key_type_signature(eqn): jaxpr = eqn.params['jaxpr'].jaxpr num_consts = eqn.params['num_consts'] num_carry = eqn.params['num_carry'] - signature = get_jaxpr_type_signature(jaxpr) + signature = jaxpr_type_signature(jaxpr) # scan body should not consume key in constants if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts): @@ -362,16 +504,17 @@ def _scan_key_type_signature(eqn): f" {jaxpr=}") return signature -key_reuse_signatures_dynamic[jax.lax.scan_p] = _scan_key_type_signature +key_reuse_signatures[jax.lax.scan_p] = _scan_key_type_signature +@dynamic_key_reuse_signature def _while_key_type_signature(eqn): cond_jaxpr = eqn.params['cond_jaxpr'].jaxpr cond_nconsts = eqn.params['cond_nconsts'] body_jaxpr = eqn.params['body_jaxpr'].jaxpr body_nconsts = eqn.params['body_nconsts'] - cond_signature = get_jaxpr_type_signature(cond_jaxpr) - body_signature = get_jaxpr_type_signature(body_jaxpr) + cond_signature = jaxpr_type_signature(cond_jaxpr) + body_signature = jaxpr_type_signature(body_jaxpr) # Error if there are sinks among consts. if any(np.any(s.mask) for s in cond_signature.sinks if s.idx < cond_nconsts): @@ -403,8 +546,9 @@ def _while_key_type_signature(eqn): f" {eqn=}") return body_signature -key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature +key_reuse_signatures[jax.lax.while_p] = _while_key_type_signature +@dynamic_key_reuse_signature def _remat_key_type_signature(eqn): # The assumption here is that the non-differentiated pass contains all relevant # key usage, and the differentiated pass @@ -412,40 +556,20 @@ def _remat_key_type_signature(eqn): # 2) will never create keys # Therefore, the differentiated pass is a no-op. 
if eqn.params['differentiated']: - return KeyReuseSignature([], []) - return get_jaxpr_type_signature(eqn.params['jaxpr']) - -key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature - - -# TODO(jakevdp): when we integrate key reuse checks more tightly with JAX, -# we should move this logic directly into each primitive impl. -def key_reuse_impl_rule(prim, original_rule): - @wraps(original_rule) - def key_reuse_impl(*args, **kwargs): - if config.enable_key_reuse_checks.value: - if prim == pjit.pjit_p: - funcname = "jit-compiled function" - jaxpr = kwargs['jaxpr'].jaxpr - signature = get_jaxpr_type_signature(jaxpr) - elif prim in key_reuse_signatures: - funcname = str(prim) - jaxpr = None - signature = key_reuse_signatures[prim] - elif prim in key_reuse_signatures_dynamic: - funcname = str(prim) - jaxpr = jax.make_jaxpr(partial(prim.bind, **kwargs))(*args).jaxpr - signature = get_jaxpr_type_signature(jaxpr) - else: - raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}") - signature.check_signature(*args, funcname=funcname) - result = original_rule(*args, **kwargs) - signature.update_consumption(args, result if prim.multiple_results else [result]) - return result - else: - return original_rule(*args, **kwargs) - return key_reuse_impl - - -for prim in (*key_reuse_signatures, *key_reuse_signatures_dynamic): - prim.impl = key_reuse_impl_rule(prim, prim.impl) # type: ignore[method-assign] + return KeyReuseSignature() + return jaxpr_type_signature(eqn.params['jaxpr']) + +key_reuse_signatures[remat_p] = _remat_key_type_signature + + +def call_impl_with_key_reuse_checks(prim: core.Primitive, raw_impl: Callable[..., Any], *args, **kwargs) -> Any: + if prim not in key_reuse_signatures: + # TODO(jakevdp): should we use an unknown signature here? + return raw_impl(*args, **kwargs) + signature = key_reuse_signature_from_primitive(prim, *args, **kwargs) + funcname = "jit-compiled function" if prim == pjit.pjit_p else str(prim) + consts = kwargs['jaxpr'].consts if prim == pjit.pjit_p else [] + signature.check_signature(*args, *consts, funcname=funcname) + result = raw_impl(*args, **kwargs) + signature.update_consumption([*args, *consts], result if prim.multiple_results else [result]) + return result diff --git a/jax/experimental/layout.py b/jax/experimental/layout.py index 716d73e2f569..ed9f8931938e 100644 --- a/jax/experimental/layout.py +++ b/jax/experimental/layout.py @@ -13,6 +13,6 @@ # limitations under the License. from jax._src.layout import ( - SpecifiedLayout as SpecifiedLayout, - AUTO as AUTO, + DeviceLocalLayout as DeviceLocalLayout, + Layout as Layout ) diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index afdd11dbb3f7..e378a6fc2499 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -12,18 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings + +from jax._src import deprecations from jax._src.maps import ( AxisName as AxisName, ResourceSet as ResourceSet, SerialLoop as SerialLoop, + _prepare_axes as _prepare_axes, make_xmap_callable as make_xmap_callable, serial_loop as serial_loop, - xmap as xmap, xmap_p as xmap_p, - _prepare_axes as _prepare_axes, + xmap as xmap, ) from jax._src.mesh import ( EMPTY_ENV as EMPTY_ENV, ResourceEnv as ResourceEnv, thread_resources as thread_resources, ) + +# Added March 7, 2024. +_msg = ( + "jax.experimental.maps and jax.experimental.maps.xmap are deprecated and" + " will be removed in a future release. 
Use jax.experimental.shard_map or" + " jax.vmap with the spmd_axis_name argument for expressing SPMD" + " device-parallel computations. Please file an issue on" + " https://github.com/google/jax/issues if neither" + " jax.experimental.shard_map nor jax.vmap are suitable for your use case." +) + +if deprecations.is_accelerated("jax-experimental-maps-module"): + raise ImportError(_msg) +else: + warnings.warn(_msg, DeprecationWarning, stacklevel=2) + +del deprecations, warnings, _msg diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index 89ef0159d74c..dd112db3b269 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -17,27 +17,30 @@ from __future__ import annotations import collections -from collections.abc import Sequence +from collections.abc import Callable, Generator, MutableMapping, Sequence import itertools import logging -from typing import Any, Callable, Optional +import math +from typing import Any -import numpy as np from jax._src import xla_bridge as xb +import numpy as np logger = logging.getLogger(__name__) _TPU_V2 = 'TPU v2' _TPU_V3 = 'TPU v3' _TPU_V4 = 'TPU v4' +_TPU_V5_LITE = "TPU v5 lite" # Maps physical topology -> mesh shape -> transpose to use for jekbradbury's # famous contiguous mesh trick. # # The trick only works for certain topologies and mesh shapes. Trivial dims of # size 1 can be added to the shapes listed, and they are also supported. -_TRANSPOSE_TRICKS: dict[tuple[int, ...], - dict[tuple[int, ...], tuple[int, ...]]] = { +_TRANSPOSE_TRICKS: dict[ + tuple[int, ...], dict[tuple[int, ...], tuple[int, ...]] +] = { (2, 2, 1): { (2, 2): (0, 1, 2), }, @@ -62,7 +65,8 @@ # Physical ordering of core IDs in a tray that creates a ring _TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5) - +_TRAY_2x2_RING_ORDER = (0, 1, 3, 2) +_TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4) def _tpu_v2_v3_create_device_mesh( mesh_shape: Sequence[int], @@ -92,6 +96,45 @@ def _tpu_v2_v3_create_device_mesh( return np.asarray(devices).reshape(mesh_shape) +def _vlc_create_device_mesh( + mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs +) -> np.ndarray | None: + """Creates rotated pincer device assignment for selected topologies. + + Args: + mesh_shape: Logical mesh shape used by the model. + devices: TPU devices. + **unused_kwargs: ... + + Returns: + None or reordered devices reshaped as `mesh_shape`. + """ + max_x, max_y, max_z = max(getattr(d, "coords", (0, 0, 0)) for d in devices) + bound_x, bound_y, bound_z = max_x + 1, max_y + 1, max_z + 1 + # Our ring re-ordering makes sense only if the passed-in devices are + # sequential, which may not always be the case. reversed() changes z-minor to + # x-minor. + sequential_devices = sorted( + devices, + key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0))))) + + if bound_x == bound_y == 2 and bound_z == 1 and len(devices) == 4: # VLC2x2 + device_mesh = np.asarray(sequential_devices) + device_mesh = device_mesh[np.array(_TRAY_2x2_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + + if bound_x == bound_y == 4 and bound_z == 1 and len(devices) == 16: # VLP4x4 + # Only uses ring order if the whole mesh is a replica group. + if max(mesh_shape) == len(devices): + device_mesh = np.asarray(sequential_devices) + device_mesh = device_mesh[np.array(_TRAY_4x4_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + + return None + + # Registers functions to create device mesh for specific device kinds. 
Takes # precedence over the more general logic in create_device_mesh(). Handler may # return None; in that case, it will fall back to using the default logic. @@ -101,12 +144,16 @@ def _tpu_v2_v3_create_device_mesh( ] = { _TPU_V2: _tpu_v2_v3_create_device_mesh, _TPU_V3: _tpu_v2_v3_create_device_mesh, + _TPU_V5_LITE: _vlc_create_device_mesh, } def _create_device_mesh_for_nd_torus( - physical_mesh: np.ndarray, mesh_shape: Sequence[int], -) -> tuple[np.ndarray, list[tuple[int, ...]]]: + physical_mesh: np.ndarray, + mesh_shape: Sequence[int], + *, + allow_split_physical_axes: bool = False, +) -> tuple[np.ndarray, np.ndarray]: """Assigns logical parallelism axes to physical axes of an N-D torus network. Given logical parallelism axes with sizes in `mesh_shape` and devices in an @@ -117,10 +164,9 @@ def _create_device_mesh_for_nd_torus( axes of the same size (e.g., a 2D square) rather than multiple physical axes of different sizes when possible. - Note that this routine will never split a physical axis over more than one - logical axis (which would reduce total usable bandwidth but may sometimes be - desired anyway). As a result, it will error out in cases where this is - necessary to produce a valid mapping. + If allow_split_physical_axes = False (default), this routine will error out + instead of splitting a physical axis over more than one logical axis (which + would reduce total usable bandwidth). Let's use a concrete example to explain the concepts and considerations. @@ -139,12 +185,16 @@ def _create_device_mesh_for_nd_torus( physical topology. mesh_shape: shape of the logical mesh (size of the various logical parallelism axes), with axes ordered by increasing network intensity. + allow_split_physical_axes: If True, we would split physical axes if + necessary to fit the desired mesh shape. Returns: An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with each logical parallelism axis mapped to one or more physical mesh axes. - The axis assignment (a list of length num_logical_axes, whose elements - are tuples representing physical axis indices). + The axis assignment matrix, which is a 2-d array mapping from + (physical_axis, logical_axis) to the size assigned, with the invariant + np.prod(assignment, axis=1) = physical_mesh_shape, and + np.prod(assignment, axis=0) = mesh_shape. """ # Remaining physical axes to be assigned to logical axes. assignable_physical_mesh = list(physical_mesh.shape) @@ -155,14 +205,16 @@ def _create_device_mesh_for_nd_torus( # `mesh_shape` is assumed to ordered by lowest network intensity first, so # reverse it first. for logical_axis_index, logical_axis_size in reversed( - list(enumerate(mesh_shape))): + list(enumerate(mesh_shape)) + ): # Preferentially map to more physical axes first for higher bandwidth. for num_axes in range(3, 0, -1): # Try assign to any subset of size num_axes. Generate all candidates. - axes = itertools.combinations(assignable_physical_mesh, num_axes) - indices = itertools.combinations( - range(len(assignable_physical_mesh)), num_axes) - for c_axes, c_indices in zip(axes, indices): + indices_and_axes = itertools.combinations( + enumerate(assignable_physical_mesh), num_axes + ) + for elem in indices_and_axes: + c_indices, c_axes = zip(*elem) # TODO(zhangqiaorjc): Due to limitations in XLA, 2D collectives only # implemented for square 2D plane. 
Mapping a physical axis to two # logical axes might be slower for non-square 2D plane, e.g., map 32 to @@ -184,20 +236,328 @@ def _create_device_mesh_for_nd_torus( # If the num_axes for loop did not break, i.e. none of the candidates work # goto here with this while-else construct. if logical_axis_size > 1: - raise NotImplementedError( - 'Failed to find assignment for logical_axis_index' - f' {logical_axis_index} of size {logical_axis_size} with remaining' - f' assignable mesh {assignable_physical_mesh}. The size of each' - ' axis in your logical mesh must be equal to the product of' - ' some subset of the physical mesh axis sizes. E.g logical mesh (4,' - ' 16) is compatible with physical mesh 4x4x4 since 4=4 and 16=4x4.' - ) + if not allow_split_physical_axes: + # Although this is now implemented, there are downstream tasks + # counting on this being a NotImplementedError. + raise NotImplementedError( + 'Failed to find assignment for logical_axis_index' + f' {logical_axis_index} of size {logical_axis_size} with' + f' remaining assignable mesh {assignable_physical_mesh}. The size' + ' of each axis in your logical mesh must be equal to the product' + ' of some subset of the physical mesh axis sizes. E.g. logical' + ' mesh (4, 16) is compatible with physical mesh 4x4x4 since 4=4' + ' and 16=4x4. If you want to split physical axes, set ' + ' allow_split_physical_axes to True.' + ) + else: + # We will try finding an assignment, even if that means splitting the + # physical axes, which requires a more sophisticated implementation. + return _create_device_mesh_for_nd_torus_splitting_axes( + physical_mesh, mesh_shape + ) + # Flatten the assignment, e.g., [(), (2,), (0, 1)] -> (2, 0, 1). transpose: list[int] = [] - for x in assignment: + assignment_array = np.ones( + [len(physical_mesh.shape), len(mesh_shape)], dtype=np.int64 + ) + for i, x in enumerate(assignment): for y in x: - transpose.append(int(y)) - return physical_mesh.transpose(transpose).reshape(mesh_shape), assignment + physical_mesh_axis = int(y) + assignment_array[physical_mesh_axis, i] = physical_mesh.shape[ + physical_mesh_axis + ] + transpose.append(physical_mesh_axis) + return ( + physical_mesh.transpose(transpose).reshape(mesh_shape), + assignment_array, + ) + + +def _create_device_mesh_for_nd_torus_splitting_axes( + physical_mesh: np.ndarray, + mesh_shape: Sequence[int], +) -> tuple[np.ndarray, np.ndarray]: + """Assigns logical parallelism axes to physical axes of an N-D torus network. + + This implementation allows creating meshes that requires splitting physical + axes, and thus one could produce logical mesh of any shape, as long as the + number of devices matches, e.g., + + - Creating 2x2x4 from 4x4; + + - Creating 2x2x16 from 8x8; + + Args: + physical_mesh: a np.ndarray of devices in the shape of the N-D torus + physical topology. + mesh_shape: shape of the logical mesh (size of the various logical + parallelism axes), with axes ordered by increasing network intensity. + + Returns: + An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with + each logical parallelism axis mapped to one or more physical mesh axes. + The axis assignment matrix, which is a 2-d array mapping from + (physical_axis, logical_axis) to the size assigned, with the invariant + np.prod(assignment, axis=1) = physical_mesh_shape, and + np.prod(assignment, axis=0) = mesh_shape. 
+ """ + if np.prod(physical_mesh.shape) != np.prod(mesh_shape): + raise ValueError( + 'The number of devices in physical mesh' + f' {physical_mesh.shape} does not match the number of devices' + f' in logical mesh {mesh_shape}.' + ) + + physical_mesh_shape = physical_mesh.shape + logical_mesh_shape = tuple(mesh_shape) + + # (Partial) assignment map as an 2-d array [p_axis, l_axis] -> size. + assignment = np.ones( + [len(physical_mesh_shape), len(logical_mesh_shape)], dtype=np.int64 + ) + + # Process logical axes from highest network intensity to lowest. + # `mesh_shape` is assumed to ordered by lowest network intensity first, so + # reverse it. + for logical_axis, logical_axis_size in reversed( + list(enumerate(logical_mesh_shape)) + ): + # Go over all the possible assignment for the logical axis, including the + # one that splits multiple physical axes. + best_logical_axis_assignment = None + for logical_axis_assignment in _enumerate_feasible_logical_axis_assignments( + physical_mesh_shape, assignment, logical_axis_size + ): + # TODO(rosun): Instead of using heuristics, replace this with a proper + # scoring function reflecting the underlying hardware properties. + if ( + best_logical_axis_assignment is None + or _prefer_first_logical_axis_assignment( + logical_axis_assignment, + best_logical_axis_assignment, + physical_mesh_shape=physical_mesh_shape, + assignment=assignment, + ) + ): + best_logical_axis_assignment = logical_axis_assignment + assignment[:, logical_axis] = best_logical_axis_assignment + + # Read out the assignment. + logical_mesh = _generate_logical_mesh( + physical_mesh, logical_mesh_shape, assignment + ) + + return logical_mesh, assignment + + +def _get_prime_factors(x: int) -> list[int]: + """Returns a sorted list of prime factors for the given number.""" + assert x > 0 + factors = [] + for p in range(2, math.isqrt(x) + 2): + while x % p == 0: + factors.append(p) + x //= p + if x == 1: + return factors + else: + return [x] # x is a prime number. + + +def _enumerate_feasible_logical_axis_assignments( + physical_mesh_shape: Sequence[int], + assignment: np.ndarray, + logical_axis_size: int, +) -> Generator[np.ndarray, None, None]: + """Yields feasible assignments for a single logical axis. + + For a physical mesh of shape [x_1, ..., x_n], and the product of all previous + assignments on each physical axes [y_1, ..., y_n], this function yields all + possible assignments for the axis as 1-d arrays [z_1, ..., z_n], so that: + + - prod(z_1, ..., z_n) = logical_axis_size + + - x_i % (z_i * y_i) = 0 + + Args: + physical_mesh_shape: Physical mesh shape. + assignment: Existing assignment matrix. + logical_axis_size: Size of the logical axis to assign. + + Yields: + All valid assignments for the logical axis. Each assignment is represented + as an integer array of length len(physical_mesh_shape). + """ + logical_axis_factors: MutableMapping[int, int] = collections.defaultdict(int) + for factor in _get_prime_factors(logical_axis_size): + logical_axis_factors[factor] += 1 + + available_physical_mesh_shape = np.array(physical_mesh_shape) // np.prod( + assignment, axis=-1 + ) + + # To enable efficient enumerations, we first index physical axes by their + # prime factors. Since we know the prime factorization of the logical axis + # size, we could simply enumerate by picking the correct count for each + # prime factor. 
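To make the feasibility conditions above concrete, here is a small editorial sketch (not part of the patch) that brute-forces the same per-axis splits for a hypothetical logical axis of size 12 on a 4x3x2 physical mesh with nothing assigned yet (so every y_i = 1):

```python
# Brute-force check of the conditions prod(z) == logical_axis_size and
# x_i % (z_i * y_i) == 0; the real code enumerates via prime factors instead.
import itertools
import numpy as np

physical_shape = (4, 3, 2)
logical_size = 12
feasible = [
    z for z in itertools.product(*(range(1, s + 1) for s in physical_shape))
    if np.prod(z) == logical_size
    and all(s % zi == 0 for s, zi in zip(physical_shape, z))
]
assert sorted(feasible) == [(2, 3, 2), (4, 3, 1)]
```

The heuristics in `_prefer_first_logical_axis_assignment` then choose between such candidates, here preferring `(4, 3, 1)`, which occupies whole physical axes.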
+ physical_axes_by_factor: MutableMapping[int, list[int]] = ( + collections.defaultdict(list) + ) + for physical_axis, physical_axis_size in enumerate( + available_physical_mesh_shape + ): + for factor in _get_prime_factors(physical_axis_size): + if factor not in logical_axis_factors: + continue + physical_axes_by_factor[factor].append(physical_axis) + + factors = [] + assignments_by_factor = [] + for factor, multiplicity in logical_axis_factors.items(): + factors.append(factor) + assignments_by_factor.append( + set( + itertools.combinations( + physical_axes_by_factor[factor], multiplicity + ) + ) + ) + + for axis_assignment in itertools.product(*assignments_by_factor): + result = np.ones([len(physical_mesh_shape)], dtype=np.int64) + for factor_index, per_factor_assignment in enumerate(axis_assignment): + for physical_axis in per_factor_assignment: + result[physical_axis] *= factors[factor_index] + yield result + + +def _prefer_first_logical_axis_assignment( + x: np.ndarray, + y: np.ndarray, + *, + physical_mesh_shape: Sequence[int], + assignment: np.ndarray, +) -> bool: + """Returns True if the first axis assignment is preferred over the second. + + For now, this is implemented with some very simple heuristics. However, + it is possible to introduce e.g., a value function here based on a more + precise model of the underlying hardware. + + TODO(rosun): Use a proxy of network capacity to select the partitions. + + Args: + x: Logical axis assignment as [len(physical_mesh_shape)] array. + y: Logical axis assignment as [len(physical_mesh_shape)] array. + physical_mesh_shape: Physical mesh shape. + assignment: Assignment matrix. + + Returns: + True if x is preferred over y. + """ + # Prefer occupying complete physical axes. I don't have a good reason for + # this, except that it is compatible with the existing behavior. + # + # E.g., on 4 x 4 x 8, [4, 4, -] will be preferred over [4, -, 4], and then + # over [2, 2, 4]. + x_whole_axis_size = np.prod( + [s for i, s in enumerate(x) if s == physical_mesh_shape[i]] + ) + y_whole_axis_size = np.prod( + [s for i, s in enumerate(y) if s == physical_mesh_shape[i]] + ) + + if x_whole_axis_size != y_whole_axis_size: + return x_whole_axis_size > y_whole_axis_size + + # Prefer occupying more whole physical axes for better bandwidth. + # + # This is consistent with existing logic, i.e., 2 x 2 is preferred over 4. + x_num_whole_axes = len( + [1 for i, s in enumerate(x) if s == physical_mesh_shape[i] and s > 1] + ) + y_num_whole_axes = len( + [1 for i, s in enumerate(y) if s == physical_mesh_shape[i] and s > 1] + ) + + if x_num_whole_axes != y_num_whole_axes: + return x_num_whole_axes > y_num_whole_axes + + # Prefer taking physical axes that are not taken by logical axes of higher + # network intensity. E.g., for a 4 x 4 x 4, suppose that the previous + # assignments are 1 x 2 x 4, and we want to place a new logical axis of size + # 2, we will go for [2, 1, 1] instead of [1, 2, 1], as the latter choice will + # tap into bandwidth already taken by the higher intensity axis. 
+ assigned_physical_mesh_shape = np.prod(assignment, axis=-1) + + x_non_overlapping_axis_size = np.prod( + [s for i, s in enumerate(x) if assigned_physical_mesh_shape[i] > 1] + ) + y_non_overlapping_axis_size = np.prod( + [s for i, s in enumerate(y) if assigned_physical_mesh_shape[i] > 1] + ) + + if x_non_overlapping_axis_size != y_non_overlapping_axis_size: + return x_non_overlapping_axis_size > y_non_overlapping_axis_size + + # Otherwise sort by reverse lexical graphical order, to be consistent with + # existing behavior. + return tuple(x) > tuple(y) + + +def _generate_logical_mesh( + physical_mesh: np.ndarray, + logical_mesh_shape: Sequence[int], + assignment: np.ndarray, +) -> np.ndarray: + """Compute the logical mesh from assignment map. + + Args: + physical_mesh: Physical device mesh. + logical_mesh_shape: Logical mesh shape. + assignment: 2-d assignment matrix shape [physical_dims, logical_dims]. + + Returns: + Logical mesh reshaped from physical mesh. + """ + physical_indices = np.broadcast_to( + np.expand_dims( + np.arange(len(physical_mesh.shape), dtype=np.int64), axis=-1 + ), + assignment.shape, + ).reshape([-1]) + + logical_indices = np.broadcast_to( + np.expand_dims( + np.arange(len(logical_mesh_shape), dtype=np.int64), axis=0 + ), + assignment.shape, + ).reshape([-1]) + + # Axes of logical mesh is ordered by (physical_axis, logical_axis). + # + # Note that we sort for each physical_axis the logical_axis, so that higher + # intensity logical axes are replicated at inner (minor) dimensions. + # + # E.g., if a dimension size is 12 = 3x4, where 3 is higher intensity and 4 + # is lower, we want to reshape so that it becomes 12 = 4x3. Imagine in the + # 1-d case, this will allow more connections between the higher intensity + # axes. + logical_mesh = np.reshape(physical_mesh, assignment.reshape([-1])) + + # We will then group by l_axis as this is what is expected from output. + _, _, transpose_axes = zip( + *sorted( + zip(logical_indices, physical_indices, range(len(logical_indices))) + ) + ) + logical_mesh = np.transpose(logical_mesh, transpose_axes) + + # Reshape to add the trivial dimensions back. + logical_mesh = np.reshape(logical_mesh, logical_mesh_shape) + + return logical_mesh def _bounds_from_last_device(last_device) -> Sequence[int]: @@ -245,14 +605,16 @@ def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray: # jekbradbury's famous trick for creating contiguous submeshes (where available) -def _transpose_trick(physical_mesh: np.ndarray, - mesh_shape: Sequence[int]) -> np.ndarray: +def _transpose_trick( + physical_mesh: np.ndarray, mesh_shape: Sequence[int] +) -> np.ndarray: mesh_shape = tuple(mesh_shape) topology = physical_mesh.shape if topology not in _TRANSPOSE_TRICKS: raise ValueError( - f"create_device_mesh cannot create contiguous submeshes for " - f"physical mesh topology {topology}") + 'create_device_mesh cannot create contiguous submeshes for ' + f'physical mesh topology {topology}' + ) mesh_shape_no_trivial_dims: tuple[int, ...] = () for dim_size in mesh_shape: @@ -261,18 +623,23 @@ def _transpose_trick(physical_mesh: np.ndarray, if mesh_shape_no_trivial_dims not in _TRANSPOSE_TRICKS[topology]: raise ValueError( - f"create_device_mesh cannot create contiguous submeshes for " - f"mesh_shape {mesh_shape} and physical mesh topology {topology}. " - f"Available mesh_shapes: {list(_TRANSPOSE_TRICKS[topology].keys())}") + 'create_device_mesh cannot create contiguous submeshes for ' + f'mesh_shape {mesh_shape} and physical mesh topology {topology}. 
' + f'Available mesh_shapes: {list(_TRANSPOSE_TRICKS[topology].keys())}' + ) return physical_mesh.transpose( - *_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims]) + *_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims] + ) def create_device_mesh( mesh_shape: Sequence[int], - devices: Sequence[Any] | None = None, *, - contiguous_submeshes: bool = False) -> np.ndarray: + devices: Sequence[Any] | None = None, + *, + contiguous_submeshes: bool = False, + allow_split_physical_axes: bool = False, +) -> np.ndarray: """Creates a performant device mesh for jax.sharding.Mesh. Args: @@ -287,6 +654,8 @@ def create_device_mesh( setting was sometimes necessary before the introduction of jax.Array to ensure non-ragged local arrays; if using jax.Arrays, it's better to keep this set to False. + allow_split_physical_axes: If True, we will split physical axes if necessary + to produce the desired device mesh. Raises: ValueError: if the number of devices doesn't equal the product of @@ -299,8 +668,10 @@ def create_device_mesh( if devices is None: devices = xb.devices() if np.prod(mesh_shape) != len(devices): - raise ValueError(f'Number of devices {len(devices)} must equal the product ' - f'of mesh_shape {mesh_shape}') + raise ValueError( + f'Number of devices {len(devices)} must equal the product ' + f'of mesh_shape {mesh_shape}' + ) last_device = devices[-1] handler = device_kind_handler_dict.get(last_device.device_kind, None) @@ -315,18 +686,25 @@ def create_device_mesh( physical_mesh = _get_physical_tpu_mesh(devices) if contiguous_submeshes: physical_mesh = _transpose_trick(physical_mesh, mesh_shape) - device_mesh, _ = _create_device_mesh_for_nd_torus(physical_mesh, mesh_shape) + device_mesh, _ = _create_device_mesh_for_nd_torus( + physical_mesh, + mesh_shape, + allow_split_physical_axes=allow_split_physical_axes, + ) return device_mesh else: device_mesh = np.asarray(devices).reshape(mesh_shape) return device_mesh + def create_hybrid_device_mesh( mesh_shape: Sequence[int], dcn_mesh_shape: Sequence[int], - devices: Sequence[Any] | None = None, *, + devices: Sequence[Any] | None = None, + *, process_is_granule: bool = False, should_sort_granules_by_key: bool = True, + allow_split_physical_axes: bool = False, ) -> np.ndarray: """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism. @@ -341,10 +719,12 @@ def create_hybrid_device_mesh( process_is_granule: if True, this function will treat processes as the units of the slower/outer network. Otherwise it will look for slice_index attributes on devices and use slices as the units. Enabling this is meant - as a fallback for platforms (e.g., GPU) that don't set slice_index. + as a fallback for platforms that don't set slice_index. should_sort_granules_by_key: Whether device granules should be sorted by the granule key, either slice or process index, depending on process_is_granule. + allow_split_physical_axes: If True, we will split physical axes if necessary + to produce the desired device mesh. 
Raises: ValueError: if the number of slices to which the `devices` belong doesn't @@ -365,16 +745,25 @@ def create_hybrid_device_mesh( granules = ( [granule_dict[key] for key in sorted(granule_dict.keys())] if should_sort_granules_by_key - else granule_dict.values()) + else granule_dict.values() + ) if np.prod(dcn_mesh_shape) != len(granules): raise ValueError( f'Number of slices {len(granules)} must equal the product of ' - f'dcn_mesh_shape {dcn_mesh_shape}') - per_granule_meshes = [create_device_mesh(mesh_shape, granule) - for granule in granules] + f'dcn_mesh_shape {dcn_mesh_shape}' + ) + per_granule_meshes = [ + create_device_mesh( + mesh_shape, + granule, + allow_split_physical_axes=allow_split_physical_axes, + ) + for granule in granules + ] # TODO(jekbradbury): handle non-uniform DCN topologies granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape) - blocks = np.vectorize( - lambda i: per_granule_meshes[i], otypes=[object])(granule_mesh) + blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])( + granule_mesh + ) device_mesh = np.block(blocks.tolist()) return device_mesh diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py new file mode 100644 index 000000000000..c58b3b24efa0 --- /dev/null +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -0,0 +1,704 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from collections.abc import Callable, Sequence +import contextlib +import ctypes +import dataclasses +import functools +import itertools +import os +import pathlib +import subprocess +import tempfile +import time +from typing import Any, Generic, TypeVar + +import jax +from jax._src import config +from jax._src import core as jax_core +from jax._src.interpreters import mlir +from jax._src.lib import xla_client +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import builtin +from jaxlib.mlir.dialects import func +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import llvm +from jaxlib.mlir.dialects import memref +from jaxlib.mlir.dialects import nvvm +from jaxlib.mlir.passmanager import PassManager +import numpy as np + +from . import dsl as mgpu +from . import profiler +from . import utils + +# mypy: ignore-errors + +# MLIR can't find libdevice unless we point it to the CUDA path +# TODO(apaszke): Unify with jax._src.lib.cuda_path +CUDA_ROOT = "/usr/local/cuda" +if os.environ.get("CUDA_ROOT") is None: + os.environ["CUDA_ROOT"] = CUDA_ROOT +else: + CUDA_ROOT = os.environ["CUDA_ROOT"] + +PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas") +NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm") + +TMA_DESCRIPTOR_BYTES = 128 +TMA_DESCRIPTOR_ALIGNMENT = 64 + + +c = mgpu.c # This is too common to fully qualify. 
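Stepping back to the `mesh_utils` hunks above, a minimal usage sketch of the new `allow_split_physical_axes` flag (editorial, not part of the patch; it assumes 16 TPU devices arranged in a 4x4 physical torus, one of the cases called out in the docstring):

```python
# Hypothetical example: build a (2, 2, 4) logical mesh on a 4x4 physical
# topology, which previously raised NotImplementedError.
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

devices = jax.devices()  # assumed: 16 devices, 4x4 torus
device_array = mesh_utils.create_device_mesh(
    (2, 2, 4), devices, allow_split_physical_axes=True)
mesh = Mesh(device_array, axis_names=("replica", "fsdp", "model"))
```

The axis names here are placeholders; the flag only affects how logical axes are mapped onto the physical torus, not the resulting `Mesh` API.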
+ + +RUNTIME_PATH = None +try: + from jax._src.lib import mosaic_gpu as mosaic_gpu_lib + + RUNTIME_PATH = ( + pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent + / "libmosaic_gpu_runtime.so" + ) +except ImportError: + pass + +if RUNTIME_PATH and RUNTIME_PATH.exists(): + # Set this so that the custom call can find it + os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) + + +mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p") +mosaic_gpu_p.multiple_results = True + + +@mosaic_gpu_p.def_abstract_eval +def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes): + del module, gmem_scratch_bytes # Unused. + return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] + +# TODO(apaszke): Implement a proper system for managing kernel lifetimes +kernel_idx = itertools.count() + +def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes): + del out_types # Unused. + idx_bytes = next(kernel_idx).to_bytes(8, byteorder="little") + op = mlir.custom_call( + "mosaic_gpu", + result_types=[ + *(mlir.aval_to_ir_type(aval) for aval in ctx.avals_out), + mlir.aval_to_ir_type( + jax_core.ShapedArray((gmem_scratch_bytes,), np.uint8) + ), + ], + operands=args, + backend_config=idx_bytes + + module.operation.get_asm(binary=True, enable_debug_info=True), + ) + return op.results[:-1] # Skip the scratch space. + +mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda") + + +@dataclasses.dataclass(frozen=True) +class MemRefTransform: + def apply(self, ref: ir.Value) -> ir.Value: + raise NotImplementedError("Subclasses should override this method") + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + raise NotImplementedError("Subclasses should override this method") + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + raise NotImplementedError("Subclasses should override this method") + + +@dataclasses.dataclass(frozen=True) +class TileTransform(MemRefTransform): + """Tiles a suffix of memref dimensions. + + For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32), + the shape of the result will be (5, 2, 4, 64, 32). The shape always ends with + the tile shape, and the size of tiled dimensions is divided by the tile size. + This is especially useful for swizzled WGMMA, which expect tiled layouts in + shared memory. + """ + tiling: tuple[int, ...] + + def apply(self, ref: ir.Value) -> ir.Value: + untiled_rank = ir.MemRefType(ref.type).rank + tiling_rank = len(self.tiling) + tiled_rank = untiled_rank + tiling_rank + for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]): + ref = mgpu.memref_unfold(ref, d, (None, t)) + permutation = ( + *range(untiled_rank - tiling_rank), + *range(untiled_rank - tiling_rank, tiled_rank, 2), + *range(untiled_rank - tiling_rank + 1, tiled_rank, 2), + ) + return mgpu.memref_transpose(ref, permutation) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + index = ir.IndexType.get() + tiling_rank = len(self.tiling) + return ( + *idx[:-tiling_rank], + *( + arith.divui(i, c(t, index)) + for i, t in zip(idx[-tiling_rank:], self.tiling) + ), + *( + arith.remui(i, c(t, index)) + for i, t in zip(idx[-tiling_rank:], self.tiling) + ), + ) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + # Note that this also checks that tiled dims are not squeezed. Their slice + # size would be 1 if so. 
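As a quick editorial check of the `TileTransform` docstring above (not part of the patch), the tiled shape can be reproduced with plain Python, assuming the trailing dimensions divide evenly by the tiling:

```python
# Same arithmetic as TileTransform.transform_shape, for the docstring example.
def tiled_shape(shape, tiling):
    r = len(tiling)
    assert all(s % t == 0 for s, t in zip(shape[-r:], tiling))
    return (*shape[:-r], *(s // t for s, t in zip(shape[-r:], tiling)), *tiling)

assert tiled_shape((5, 128, 128), (64, 32)) == (5, 2, 4, 64, 32)
```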
+ tiling_rank = len(self.tiling) + for size, tile_size in zip(shape[-tiling_rank:], self.tiling): + if size % tile_size: + raise ValueError( + f"Expected GMEM slice shape {shape} suffix to be a multiple" + f" of tiling {self.tiling}" + ) + return ( + *shape[:-tiling_rank], + *(s // t for s, t in zip(shape[-tiling_rank:], self.tiling)), + *self.tiling, + ) + + +@dataclasses.dataclass(frozen=True) +class TransposeTransform(MemRefTransform): + """Transposes memref dimensions.""" + permutation: tuple[int, ...] + + def __post_init__(self): + if len(self.permutation) != len(set(self.permutation)): + raise ValueError("Permutation must be a permutation") + + def apply(self, ref: ir.Value) -> ir.Value: + return mgpu.memref_transpose(ref, self.permutation) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + return tuple(idx[p] for p in self.permutation) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + return tuple(shape[p] for p in self.permutation) + + +OnDeviceProfiler = profiler.OnDeviceProfiler + + +@dataclasses.dataclass() +class LaunchContext: + launch_op: gpu.LaunchOp + profiler: OnDeviceProfiler | None = None + next_scratch_offset: int = 0 + host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field( + default_factory=list, init=False + ) + tma_descriptors: dict[ + tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]], + ir.Value, + ] = dataclasses.field(default_factory=dict, init=False) + + @contextlib.contextmanager + def named_region(self, *args, **kwargs): + if self.profiler is not None: + with self.profiler.record(*args, **kwargs): + yield + else: + yield + + def _alloc_scratch( + self, + size: int, + alignment: int | None = None, + host_init: Callable[[ir.Value], None] = lambda _: None, + device_init: Callable[[ir.Value], Any] = lambda x: x, + ) -> ir.Value: + """Allocates a GMEM scratch buffer. + + The buffer is initialized on the host and then copied to GMEM before the + kernel launch. 
+ """ + i8 = ir.IntegerType.get_signless(8) + ptr_ty = ir.Type.parse("!llvm.ptr") + if alignment is None: + alignment = size + if self.next_scratch_offset % alignment: + raise NotImplementedError # TODO(apaszke): Pad to match alignment + alloc_base = self.next_scratch_offset + self.next_scratch_offset += size + def host_init_wrapped(host_ptr): + host_init( + llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8) + ) + self.host_scratch_init.append(host_init_wrapped) + + with ir.InsertionPoint.at_block_begin(self.launch_op.body.blocks[0]): + ptr_ty = ir.Type.parse("!llvm.ptr") + const_ptr_ty = ir.Type.parse("!llvm.ptr<4>") + gmem_scratch_ptr = llvm.call_intrinsic( + ptr_ty, + "llvm.nvvm.ptr.constant.to.gen.p0.p4", + [llvm.mlir_addressof(const_ptr_ty, "global_scratch")], + ) + return device_init(llvm.getelementptr( + ptr_ty, gmem_scratch_ptr, [], [alloc_base], i8 + )) + + def _get_tma_desc( + self, + ref, + gmem_transform: tuple[MemRefTransform, ...], + transformed_slice_shape: tuple[int, ...], + swizzle: int | None, + ): + tma_desc_key = (ref, transformed_slice_shape, swizzle, gmem_transform) + if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: + with ir.InsertionPoint(self.launch_op): + for t in gmem_transform: + ref = t.apply(ref) + ref_ty = ir.MemRefType(ref.type) + + i64 = ir.IntegerType.get_signless(64) + ptr_ty = ir.Type.parse("!llvm.ptr") + def init_tma_desc(host_ptr): + _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) + aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) + as_i64 = lambda i: arith.index_cast(i64, i) + alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx)) + llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings... + base_ptr = llvm.getelementptr( + ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, + ) + rank = ref_ty.rank + assert rank * 2 == len(sizes_and_strides) + args = [ + host_ptr, + base_ptr, + c(utils.bytewidth(ref_ty.element_type), i64), + c(rank, i64), + utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), + utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), + c(0 if swizzle is None else swizzle, i64), + utils.pack_array([c(v, i64) for v in transformed_slice_shape]), + ] + func.call([], "mosaic_gpu_init_tma_desc", args) + def cast_tma_desc(device_ptr): + # TODO(apaszke): Investigate why prefetching can cause launch failures + # nvvm.prefetch_tensormap(device_ptr) + return device_ptr + tma_desc = self._alloc_scratch( + TMA_DESCRIPTOR_BYTES, + alignment=TMA_DESCRIPTOR_ALIGNMENT, + host_init=init_tma_desc, + device_init=cast_tma_desc, + ) + self.tma_descriptors[tma_desc_key] = tma_desc + return tma_desc + + def async_copy( + self, + *, + src_ref, + dst_ref, + gmem_slice: Any = (), + gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] 
= (),
+      barrier: mgpu.Barrier | None = None,
+      swizzle: int | None = None,
+      arrive: bool | None = None,
+      uniform: bool = True,
+  ):
+    index = ir.IndexType.get()
+    i32 = ir.IntegerType.get_signless(32)
+    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+    src_ref_ty = ir.MemRefType(src_ref.type)
+    dst_ref_ty = ir.MemRefType(dst_ref.type)
+    element_type = src_ref_ty.element_type
+    if element_type != dst_ref_ty.element_type:
+      raise ValueError(
+          f"Expected same element type, got {element_type} and"
+          f" {dst_ref_ty.element_type}"
+      )
+    if not isinstance(gmem_transform, tuple):
+      gmem_transform = (gmem_transform,)
+
+    if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem:
+      gmem_ref, smem_ref = src_ref, dst_ref
+      if barrier is None:
+        raise ValueError("Barriers are required for GMEM -> SMEM copies")
+      if arrive is None:
+        arrive = True  # Arrive by default
+    elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None:
+      gmem_ref, smem_ref = dst_ref, src_ref
+      if barrier is not None:
+        raise ValueError("Barriers are unsupported for SMEM -> GMEM copies")
+      if arrive is not None:
+        raise ValueError("arrive is unsupported for SMEM -> GMEM copies")
+    else:
+      raise ValueError("Only SMEM <-> GMEM copies supported")
+    # TODO(apaszke): This is a very approximate check. Improve it!
+    expected_name = "builtin.unrealized_conversion_cast"
+    if (
+        gmem_ref.owner is None
+        or gmem_ref.owner.opview.OPERATION_NAME != expected_name
+    ):
+      raise ValueError("GMEM reference in async_copy must be a kernel argument")
+
+    base_indices, slice_shape, is_squeezed = utils.parse_indices(
+        gmem_slice, ir.MemRefType(gmem_ref.type).shape
+    )
+    dyn_base_indices = tuple(
+        c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices
+    )
+    slice_shape = tuple(slice_shape)
+    for t in gmem_transform:
+      dyn_base_indices = t.transform_index(dyn_base_indices)
+      slice_shape = t.transform_shape(slice_shape)
+    for dim, squeezed in enumerate(is_squeezed):
+      if squeezed:
+        smem_ref = mgpu.memref_unsqueeze(smem_ref, dim)
+    smem_ref_ty = ir.MemRefType(smem_ref.type)
+
+    if slice_shape != tuple(smem_ref_ty.shape):
+      raise ValueError(
+          "Expected the SMEM reference to have the same shape as the tiled"
+          f" slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
+      )
+    tma_desc = self._get_tma_desc(
+        gmem_ref, gmem_transform, slice_shape, swizzle,
+    )
+
+    # We construct TMA descriptors in column-major order.
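+    # Since the descriptor is built in column-major order, the base indices
+    # computed above (which follow the row-major GMEM shape) are reversed
+    # below, so that the minor-most dimension of the slice comes first.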
+ rev_dyn_base_indices = [ + arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices) + ] + + uniform_ctx = ( + functools.partial(mgpu.single_thread, per_block=False) + if uniform + else contextlib.nullcontext + ) + + rank = len(slice_shape) + if rank > 5: # TODO: apaszke - Implement stride compression + raise ValueError("Async copies only support striding up to 5 dimensions") + smem_ptr = utils.memref_ptr(smem_ref, memory_space=3) + if gmem_ref is src_ref: + assert barrier is not None # for pytype + slice_bytes = c(np.prod(slice_shape) * mgpu.bytewidth(element_type), i32) + barrier_ptr = barrier.get_ptr() + with uniform_ctx(): + if arrive: + nvvm.mbarrier_arrive_expect_tx_shared(barrier_ptr, slice_bytes) + nvvm.cp_async_bulk_tensor_shared_cluster_global( + smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [] + ) + else: + with uniform_ctx(): + nvvm.cp_async_bulk_tensor_global_shared_cta( + tma_desc, smem_ptr, rev_dyn_base_indices + ) + nvvm.cp_async_bulk_commit_group() + + def await_async_copy( + self, allow_groups: int, await_read_only: bool = False + ): + nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only) + # TODO(apaszke): Use a warpgroup barrier!!! + gpu.barrier() # Groups are supposedly tracked per-thread + + +# ShapeTrees currently can not contain unions. +ShapeTree = Any +RefTree = Any +T = TypeVar('T') + + +@dataclasses.dataclass(frozen=True) +class Union(Generic[T]): + members: Sequence[T] + + +def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int: + return np.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize + + +def _construct_smem_reftree( + dynamic_smem: ir.Value, smem_buffers: ShapeTree) -> RefTree: + index = ir.IndexType.get() + smem = ir.Attribute.parse("#gpu.address_space") + flat_ref_tys, smem_buffer_tree = jax.tree.flatten(smem_buffers) + smem_refs = [] + dynamic_smem_offset = 0 + for ref_ty in flat_ref_tys: + mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype) + tile_smem = memref.view( + ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem), + dynamic_smem, c(dynamic_smem_offset, index), [], + ) + dynamic_smem_offset += _count_buffer_bytes(ref_ty) + smem_refs.append(tile_smem) + return jax.tree.unflatten(smem_buffer_tree, smem_refs) + + +# TODO(apaszke): Inline this +@contextlib.contextmanager +def _launch( + token, + grid, + block, + smem_buffers: ShapeTree | Union[ShapeTree], + profiler_spec: profiler.ProfilerSpec | None = None, + maybe_prof_buffer: ir.Value | None = None, +): + if (profiler_spec is None) != (maybe_prof_buffer is None): + raise ValueError + index = ir.IndexType.get() + i32 = ir.IntegerType.get_signless(32) + i8 = ir.IntegerType.get_signless(8) + grid_vals = [c(i, index) for i in grid] + block_vals = [c(i, index) for i in block] + + if isinstance(smem_buffers, Union): + smem_disjoint_live_buffers_collections = smem_buffers.members + compute_smem_bytes = max( + sum(_count_buffer_bytes(l) for l in jax.tree.leaves(s)) + for s in smem_buffers.members) + else: + smem_disjoint_live_buffers_collections = [smem_buffers] + compute_smem_bytes = sum( + _count_buffer_bytes(l) for l in jax.tree.leaves(smem_buffers)) + + smem_bytes = compute_smem_bytes + if profiler_spec is not None: + smem_bytes += profiler_spec.smem_bytes(block=block) + + # TODO(cperivol): Query the shared memory size programmatically. 
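+  # The hard-coded bound below is the dynamic shared memory budget per block
+  # that this launcher assumes for the Hopper GPUs it targets; both the
+  # compute scratch and the optional profiler buffer have to fit within it.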
+ if smem_bytes > 228 * 1024: + raise ValueError(f"Mosaic GPU kernel exceeds available shared memory {smem_bytes=} > 228000") + launch_op = gpu.LaunchOp( + token.type, [token], *grid_vals, *block_vals, + dynamicSharedMemorySize=c(smem_bytes, i32)) + launch_op.body.blocks.append(*([index] * 12)) # Append an empty block + smem = ir.Attribute.parse("#gpu.address_space") + with ir.InsertionPoint(launch_op.body.blocks[0]): + dynamic_smem = gpu.dynamic_shared_memory( + ir.MemRefType.get( + (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem + ) + ) + smem_ref_trees = [] + + for smem_live_buffers_collection in smem_disjoint_live_buffers_collections: + smem_ref_tree = _construct_smem_reftree( + dynamic_smem, smem_live_buffers_collection) + smem_ref_trees.append(smem_ref_tree) + + if profiler_spec: + prof_smem = memref.view( + ir.MemRefType.get( + (profiler_spec.smem_i32_elements(block=block),), + i32, memory_space=smem, + ), + dynamic_smem, c(compute_smem_bytes, index), [], + ) + prof = profiler.OnDeviceProfiler( + profiler_spec, prof_smem, maybe_prof_buffer + ) + else: + prof = None + + if isinstance(smem_buffers, Union): + smem_ref_tree: Union[RefTree] = Union(smem_ref_trees) + else: + smem_ref_tree: RefTree = smem_ref_trees[0] if smem_ref_trees else [] + + yield LaunchContext(launch_op, prof), smem_ref_tree + if prof is not None: + prof.finalize(grid=grid, block=block) + gpu.terminator() + + +def _lower_as_gpu_kernel( + body, + grid: tuple[int, ...], + block: tuple[int, ...], + in_shapes: tuple[Any, ...], + out_shape, + smem_scratch_shape: ShapeTree | Union[ShapeTree], + prof_spec: profiler.ProfilerSpec | None = None, +): + ptr_ty = ir.Type.parse("!llvm.ptr") + token_ty = ir.Type.parse("!gpu.async.token") + i8 = ir.IntegerType.get_signless(8) + i32 = ir.IntegerType.get_signless(32) + i64 = ir.IntegerType.get_signless(64) + + def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: + return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype)) + + in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes] + + unwrap_output_tuple = False + if isinstance(out_shape, list): + out_shape = tuple(out_shape) + elif not isinstance(out_shape, tuple): + out_shape = (out_shape,) + unwrap_output_tuple = True + out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape] + if prof_spec is not None: + out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block)) + out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block)) + + module = ir.Module.create() + with ir.InsertionPoint(module.body): + _declare_runtime_functions() + gmem_scratch_bytes = 0 + global_scratch = llvm.GlobalOp( + ir.Type.parse("!llvm.array<0 x i8>"), # We don't know the shape yet. + "global_scratch", + ir.Attribute.parse("#llvm.linkage"), + addr_space=ir.IntegerAttr.get(i32, 4), # GPU constant memory. 
+ ) + @func.FuncOp.from_py_func(ptr_ty, ptr_ty, ptr_ty) + def main(token_ptr, buffers, gmem_scratch_ptr): + nonlocal gmem_scratch_bytes + token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) + arg_refs = [] + for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): + ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty)) + arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))) + in_refs = arg_refs[:len(in_ref_tys)] + out_refs = arg_refs[len(in_ref_tys):] + prof_buffer = out_refs.pop() if prof_spec is not None else None + with _launch( + token, grid, block, smem_scratch_shape, + prof_spec, prof_buffer + ) as (launch_ctx, smem_refs): + body(launch_ctx, *in_refs, *out_refs, smem_refs) + gmem_scratch_bytes = launch_ctx.next_scratch_offset + # Allocate and initialize the host buffer right before the launch. + # Note that we couldn't do that before, because we had to run the body + # to learn what the scratch contains. + with ir.InsertionPoint(launch_ctx.launch_op): + host_scratch_ptr = llvm.alloca(ptr_ty, c(gmem_scratch_bytes, i64), i8) + for init_callback in launch_ctx.host_scratch_init: + init_callback(host_scratch_ptr) + global_scratch.global_type = ir.TypeAttr.get( + ir.Type.parse("!llvm.array<" + str(gmem_scratch_bytes) + " x i8>") + ) + func.call( + [], + "mosaic_gpu_memcpy_async_h2d", + [ + gmem_scratch_ptr, + host_scratch_ptr, + c(gmem_scratch_bytes, i64), + token_ptr, + ], + ) + main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + sym_tab = ir.SymbolTable(module.operation) + sym_tab.insert(main.func_op) + sym_tab.insert(global_scratch) + module.operation.verify() + + return module, out_shape, gmem_scratch_bytes, unwrap_output_tuple + + +def as_gpu_kernel( + body, + grid: tuple[int, ...], + block: tuple[int, ...], + in_shape, + out_shape, + smem_scratch_shape: ShapeTree | Union[ShapeTree], + prof_spec: profiler.ProfilerSpec | None = None, +): + if isinstance(in_shape, list): + in_shape = tuple(in_shape) + elif not isinstance(in_shape, tuple): + in_shape = (in_shape,) + + module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = ( + _lower_as_gpu_kernel( + body, grid, block, in_shape, out_shape, smem_scratch_shape, prof_spec + ) + ) + + expected_arg_treedef = jax.tree.structure(in_shape) + def _check_args(*args): + arg_treedef = jax.tree.structure(args) + if arg_treedef != expected_arg_treedef: + raise ValueError( + f"Invalid argument structure: expected {expected_arg_treedef}, got" + f" {arg_treedef}, ({args=})" + ) + + def bind(*args): + return mosaic_gpu_p.bind( + *args, + out_types=out_shape, + module=module, + gmem_scratch_bytes=gmem_scratch_bytes, + ) + + if prof_spec is not None: + @jax.jit + def prof_kernel(*args): + _check_args(*args) + *results, prof_buffer = bind(*args) + def dump_profile(prof_buffer): + out_file = os.path.join( + os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"), + f"{time.time_ns()}-trace.json", + ) + try: + with open(out_file, "x") as f: + prof_spec.dump(prof_buffer, f, grid=grid, block=block) + except FileExistsError: + pass # TODO: Retry + jax.debug.callback(dump_profile, prof_buffer) + return results[0] if unwrap_output_tuple else results + return prof_kernel + else: + @jax.jit + def kernel(*args): + _check_args(*args) + results = bind(*args) + return results[0] if unwrap_output_tuple else results + return kernel + + +def _declare_runtime_functions(): + """Declares the runtime functions that can be used by the generated code.""" + ptr_ty = ir.Type.parse("!llvm.ptr") + i64 = 
ir.IntegerType.get_signless(64) + arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] + init_tma_desc_type = ir.FunctionType.get(arg_tys, []) + func.FuncOp( + "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" + ) + memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], []) + func.FuncOp( + "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" + ) diff --git a/jax/experimental/mosaic/gpu/dsl.py b/jax/experimental/mosaic/gpu/dsl.py new file mode 100644 index 000000000000..bd8960c74c2b --- /dev/null +++ b/jax/experimental/mosaic/gpu/dsl.py @@ -0,0 +1,50 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from .fragmented_array import ( + FragmentedArray, + FragmentedLayout, + WGMMA_LAYOUT, + WGMMA_ROW_LAYOUT, + WGStridedFragLayout, +) +from .utils import ( + Barrier, + BarrierArray, + DynamicSlice, + Partition, + Partition1D, + bytewidth, + c, + commit_shared, + debug_print, + ds, + fori, + memref_fold, + memref_slice, + memref_transpose, + memref_unfold, + memref_unsqueeze, + single_thread, + thread_idx, + tile_shape, + warp_idx, + warpgroup_idx, +) +from .wgmma import ( + WGMMAAccumulator, + WGMMALayout, + wgmma, +) diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD new file mode 100644 index 000000000000..3f9496b38376 --- /dev/null +++ b/jax/experimental/mosaic/gpu/examples/BUILD @@ -0,0 +1,65 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("//jaxlib:jax.bzl", "py_deps") +load("@rules_python//python:defs.bzl", "py_library", "py_test") + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//third_party/py/jax:mosaic_gpu_users"], +) + +exports_files( + srcs = [ + "flash_attention.py", + "matmul.py", + ], + visibility = ["//third_party/py/jax:internal"], +) + +py_library( + name = "matmul", + srcs = ["matmul.py"], + deps = [ + "//third_party/py/jax", + "//third_party/py/jax:mosaic_gpu", + ], +) + +py_library( + name = "flash_attention", + srcs = ["flash_attention.py"], + deps = [ + "//third_party/py/jax", + "//third_party/py/jax:mosaic_gpu", + ], +) + +py_test( + name = "run_matmul", + srcs = ["matmul.py"], + main = "matmul.py", + tags = [ + "manual", + "notap", + "requires-gpu-sm90-only", + ], + deps = [ + "//learning/brain/research/jax:gpu_support", + "//third_party/py/jax", + "//third_party/py/jax:mosaic_gpu", + ] + py_deps("numpy"), +) diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py new file mode 100644 index 000000000000..e680faffee6e --- /dev/null +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -0,0 +1,650 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import contextlib +import dataclasses +import enum +import itertools +import os + +from absl import app +import jax +from jax import random +from jax._src.interpreters import mlir +from jax._src import test_util as jtu +from jax.experimental.mosaic import gpu as mosaic_gpu +from jax.experimental.mosaic.gpu import profiler +from jax.experimental.mosaic.gpu.dsl import * # noqa: F403 +import jax.numpy as jnp +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import nvgpu +from jaxlib.mlir.dialects import nvvm +from jaxlib.mlir.dialects import scf +import numpy as np + +# mypy: ignore-errors +# ruff: noqa: F405 + +@dataclasses.dataclass(frozen=True) +class BlockSizes: + q: int + kv: int + stages: int + +_utils_c = c + + +# TODO(apaszke): Implement a Q-scaled, base2 exp implementation. 
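+# ExpImplementation selects how the softmax exponential is lowered: EXACT goes
+# through the regular math.exp lowering, while APPROX uses the faster but less
+# accurate ex2.approx PTX path (see FragmentedArray.exp(approx=...)).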
+class ExpImplementation(enum.Enum): + EXACT = enum.auto() + APPROX = enum.auto() + + +class Implementation(enum.Enum): + TWO_COMPUTE_WG = enum.auto() + TWO_COMPUTE_ONE_TMA_WG = enum.auto() + + +def build_kernel( + batch_size: int, + q_heads: int, + kv_heads: int, + q_seq_len: int, + kv_seq_len: int, + head_dim: int, + blocks: BlockSizes, + prof_spec: profiler.ProfilerSpec | None = None, + exp_impl: ExpImplementation = ExpImplementation.EXACT, + impl: Implementation = Implementation.TWO_COMPUTE_WG, +): + compute_wgs_per_block = 2 + match impl: + case Implementation.TWO_COMPUTE_WG: + wgs_per_block = 2 + case Implementation.TWO_COMPUTE_ONE_TMA_WG: + wgs_per_block = 3 + + if batch_size != 1: + raise NotImplementedError + if blocks.stages < 2: + raise ValueError("Kernel requires at least 2 stages.") + if q_heads % kv_heads: + raise ValueError("kv_heads must divide q_heads.") + if q_seq_len % (blocks.q * compute_wgs_per_block): + raise ValueError + if kv_seq_len % blocks.kv: + raise ValueError + if blocks.q % 64: + raise NotImplementedError + if blocks.kv % 64: + raise NotImplementedError + if head_dim % 64: + raise NotImplementedError + if blocks.stages * blocks.kv > kv_seq_len: + raise NotImplementedError + + q_shape = jax.ShapeDtypeStruct( + (q_heads, q_seq_len, head_dim), jnp.float16 + ) + kv_shape = jax.ShapeDtypeStruct( + (kv_heads, kv_seq_len, head_dim), jnp.float16 + ) + q_heads_per_kv_head = q_heads // kv_heads + + def exp(x: FragmentedArray) -> FragmentedArray: + return x.exp(approx=exp_impl == ExpImplementation.APPROX) + + block_partition = Partition( + elements=(batch_size, q_seq_len, q_heads), + partition=(0, 1, 2), + chunk_size=(1, blocks.q * compute_wgs_per_block, 1), + ) + + index = ir.IndexType.get() + i32 = ir.IntegerType.get_signless(32) + f16 = ir.F16Type.get() + f32 = ir.F32Type.get() + + grid = block_partition.num_chunks + block = (wgs_per_block * 128, 1, 1) + tiling = (64, 64) + qo_scratch = jax.ShapeDtypeStruct( + (compute_wgs_per_block, *tile_shape((blocks.q, head_dim), tiling)), + jnp.float16, + ) + k_scratch = jax.ShapeDtypeStruct( + tile_shape((blocks.stages, head_dim, blocks.kv), tiling), jnp.float16 + ) + v_scratch = jax.ShapeDtypeStruct( + tile_shape((blocks.stages, blocks.kv, head_dim), tiling), jnp.float16 + ) + smem_scratch_shape = [ + qo_scratch, + k_scratch, + v_scratch, + ] + in_shape = (q_shape, kv_shape, kv_shape) + out_shape = q_shape + + def c(value, ty=index): + return _utils_c(value, ty) + + def tma_wg_kernel( + ctx: mosaic_gpu.LaunchContext, + q_gmem, + k_gmem, + v_gmem, + out_gmem, + smem_scratch, + ): + k_barriers = BarrierArray(blocks.stages) + v_barriers = BarrierArray(blocks.stages) + q_barriers = BarrierArray(compute_wgs_per_block) + k_consumed_barrier, v_consumed_barrier = BarrierArray(2, arrival_count=256) + schedule_barrier = BarrierArray(1, arrival_count=256)[0] + @ctx.named_region("Schedule barrier") + def perform_schedule_barrier(): + schedule_barrier.arrive() + schedule_barrier.wait() + wg_idx = warpgroup_idx(sync=True) + qo_smem, k_smem, v_smem = smem_scratch + qo_smem = memref_slice(qo_smem, arith.index_cast(index, wg_idx)) + + @contextlib.contextmanager + def only_wg(idx): + is_wg = arith.cmpi(arith.CmpIPredicate.eq, wg_idx, c(idx, i32)) + with ir.InsertionPoint(scf.IfOp(is_wg).then_block): + yield + scf.yield_([]) + + batch_idx, q_seq_base, q_head_idx = block_partition.get_base( + gpu.block_id(gpu.Dimension.x), + gpu.block_id(gpu.Dimension.y), + gpu.block_id(gpu.Dimension.z), + ) + q_seq_base = arith.addi( + q_seq_base, 
arith.muli(arith.index_cast(index, wg_idx), c(blocks.q)) + ) + del batch_idx + + loop_partition = Partition1D(kv_seq_len, chunk_size=blocks.kv) + if_compute = scf.IfOp( + arith.cmpi(arith.CmpIPredicate.ne, wg_idx, c(2, i32)), hasElse=True + ) + with ir.InsertionPoint(if_compute.then_block): + nvvm.setmaxregister(232, nvvm.SetMaxRegisterAction.increase) + with ctx.named_region("Q TMA start"): + ctx.async_copy( + src_ref=q_gmem, + gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), + gmem_transform=mosaic_gpu.TileTransform(tiling), + dst_ref=qo_smem, + barrier=q_barriers[wg_idx], + swizzle=128, + ) + + with ctx.named_region("Q TMA wait"): + q_barriers[wg_idx].wait() + + m_i = FragmentedArray.splat( + c(-jnp.inf, f32), shape=(blocks.q,), layout=WGMMA_ROW_LAYOUT + ) + l_i = FragmentedArray.splat( + c(0, f32), shape=(blocks.q,), layout=WGMMA_ROW_LAYOUT + ) + acc = FragmentedArray.splat( + c(0, f32), shape=(blocks.q, head_dim), layout=WGMMA_LAYOUT + ) + + k_barriers[c(0)].wait() + + with only_wg(1): + perform_schedule_barrier() + + @fori(c(loop_partition.num_chunks), (acc, m_i, l_i)) + def kv_loop(kv_step, carry): + acc, m_i, l_i = carry + slot = arith.remui(kv_step, c(blocks.stages)) + + with ctx.named_region("QK issue"): + # TODO(apaszke): Support WGMMA without an initial accumulator. + qk_acc = WGMMAAccumulator.zero(blocks.q, blocks.kv) + q, k = qo_smem, memref_slice(k_smem, slot) + qk_acc = wgmma(qk_acc, q, k, b_order=WGMMALayout.COL_MAJOR) + nvvm.wgmma_commit_group_sync_aligned() + + perform_schedule_barrier() + + with ctx.named_region("QK wait"): + nvvm.wgmma_wait_group_sync_aligned(0) + k_consumed_barrier.arrive() + qk = qk_acc.value + + with ctx.named_region("Softmax"): + m_ij = m_i.max(qk.reduce(arith.maximumf, axis=1)) + alpha = exp(m_i - m_ij) + m_i = m_ij + p = exp(qk - m_ij.broadcast_minor(blocks.kv)) + acc *= alpha.broadcast_minor(head_dim) + l_i *= alpha + p16 = p.astype(f16) + + with ctx.named_region("V TMA wait"): + v_barriers[slot].wait() + + perform_schedule_barrier() + + # This is quite suprising, but it seems like warp shuffles cannot + # run simutaneously with the WGMMA. For that reason we include it as + # part of the TensorCore critical section and not the ALU section. + with ctx.named_region("Softmax reduction"): + l_i += p.reduce(arith.addf, axis=1) + + with ctx.named_region("PV issue"): + v = memref_slice(v_smem, slot) + acc_update = WGMMAAccumulator.from_registers(acc) + acc_update = wgmma(acc_update, p16, v) + nvvm.wgmma_commit_group_sync_aligned() + + # We hide the barrier overhead by overlapping it with the PV matmul. + with ctx.named_region("K TMA wait"): + wait_step = arith.addi(kv_step, c(1)) + wait_slot = arith.remui(wait_step, c(blocks.stages)) + wait_step_in_bounds = arith.cmpi( + arith.CmpIPredicate.slt, wait_step, c(loop_partition.num_chunks) + ) + with ir.InsertionPoint(scf.IfOp(wait_step_in_bounds).then_block): + k_barriers[wait_slot].wait() + scf.yield_([]) + + with ctx.named_region("PV wait"): + nvvm.wgmma_wait_group_sync_aligned(0) + v_consumed_barrier.arrive() + acc = acc_update.value + + return acc, m_i, l_i + + with only_wg(0): + perform_schedule_barrier() + + acc, m_i, l_i = kv_loop.results + del m_i + # TODO(apaszke): Invert and multiply to avoid expensive divisions. + acc /= l_i.broadcast_minor(head_dim) + + with ctx.named_region("Acc store"): + acc.astype(f16).store_tiled(qo_smem, swizzle=128) + gpu.barrier() + nvvm.fence_proxy( + nvvm.ProxyKind.async_shared, space=nvvm.SharedSpace.shared_cta + ) # Make sure the store is visible to the TMA. 
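+      # Epilogue: the accumulator was downcast to f16 and staged through the
+      # Q/O SMEM tile above; the TMA copy below writes that tile back to GMEM
+      # and await_async_copy(0) blocks until the copy has completed.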
+ + with ctx.named_region("GMEM store"): + ctx.async_copy( + src_ref=qo_smem, + dst_ref=out_gmem, + gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), + gmem_transform=mosaic_gpu.TileTransform(tiling), + swizzle=128, + ) + ctx.await_async_copy(0) + + scf.yield_([]) + with ir.InsertionPoint(if_compute.else_block): + nvvm.setmaxregister(40, nvvm.SetMaxRegisterAction.decrease) + with single_thread(per_block=False): + k_tr = ( + mosaic_gpu.TileTransform(tiling), + mosaic_gpu.TransposeTransform((0, 2, 1, 3, 4)), + ) + v_tr = mosaic_gpu.TileTransform(tiling) + kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) + def start_kv_copy(slot, kv_seq_base, smem, gmem, barrier, transform): + ctx.async_copy( + dst_ref=memref_slice(smem, slot), + src_ref=gmem, + gmem_slice=(kv_head_idx, ds(kv_seq_base, blocks.kv)), + gmem_transform=transform, + barrier=barrier, + uniform=False, + swizzle=128, + ) + def start_k_copy(slot, kv_seq_base): + return start_kv_copy( + slot, kv_seq_base, k_smem, k_gmem, k_barriers[slot], k_tr + ) + def start_v_copy(slot, kv_seq_base): + return start_kv_copy( + slot, kv_seq_base, v_smem, v_gmem, v_barriers[slot], v_tr + ) + + with ctx.named_region("KV TMA warmup"): + for i in range(blocks.stages): + start_k_copy(c(i), loop_partition.get_base(c(i))) + start_v_copy(c(i), loop_partition.get_base(c(i))) + + @fori(c(loop_partition.num_chunks - blocks.stages), None) + def _kv_loop_memory(kv_step, _): + tma_step = arith.addi(kv_step, c(blocks.stages)) + tma_slot = arith.remui(kv_step, c(blocks.stages)) + with ctx.named_region("K consumed barrier"): + k_consumed_barrier.wait() + start_k_copy(tma_slot, loop_partition.get_base(tma_step)) + with ctx.named_region("V consumed barrier"): + v_consumed_barrier.wait() + start_v_copy(tma_slot, loop_partition.get_base(tma_step)) + @fori(c(blocks.stages), None) + def _kv_loop_memory(i, _): + k_consumed_barrier.wait() + v_consumed_barrier.wait() + scf.yield_([]) + + def compute_only_kernel( + ctx: mosaic_gpu.LaunchContext, + q_gmem, + k_gmem, + v_gmem, + out_gmem, + smem_scratch, + ): + barriers = BarrierArray(blocks.stages + wgs_per_block) + schedule_barrier = BarrierArray(1, arrival_count=256)[0] + def perform_schedule_barrier(): + schedule_barrier.arrive() + schedule_barrier.wait() + wg_idx = warpgroup_idx(sync=True) + qo_smem, k_smem, v_smem = smem_scratch + qo_smem = memref_slice(qo_smem, arith.index_cast(index, wg_idx)) + + @contextlib.contextmanager + def only_wg(idx): + i32 = ir.IntegerType.get_signless(32) + is_wg = arith.cmpi(arith.CmpIPredicate.eq, wg_idx, c(idx, i32)) + with ir.InsertionPoint(scf.IfOp(is_wg).then_block): + yield + scf.yield_([]) + + batch_idx, q_seq_base, q_head_idx = block_partition.get_base( + gpu.block_id(gpu.Dimension.x), + gpu.block_id(gpu.Dimension.y), + gpu.block_id(gpu.Dimension.z), + ) + q_seq_base = arith.addi( + q_seq_base, arith.muli(arith.index_cast(index, wg_idx), c(blocks.q)) + ) + del batch_idx + + q_barrier = arith.addi(c(blocks.stages), arith.index_cast(index, wg_idx)) + with ctx.named_region("Q TMA start"): + ctx.async_copy( + src_ref=q_gmem, + gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), + gmem_transform=mosaic_gpu.TileTransform(tiling), + dst_ref=qo_smem, + barrier=barriers[q_barrier], + swizzle=128, + ) + + kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) + + def kv_copy_init(slot, kv_seq_base): + with single_thread(per_block=False): + txcount = c(2 * blocks.kv * head_dim * bytewidth(f16)) + nvgpu.mbarrier_arrive_expect_tx(barriers.value, txcount, slot) + k_tr = ( + 
mosaic_gpu.TileTransform(tiling), + mosaic_gpu.TransposeTransform((0, 2, 1, 3, 4)), + ) + v_tr = mosaic_gpu.TileTransform(tiling) + for smem, gmem, t in ((k_smem, k_gmem, k_tr), (v_smem, v_gmem, v_tr)): + ctx.async_copy( + dst_ref=memref_slice(smem, slot), + src_ref=gmem, + gmem_slice=(kv_head_idx, ds(kv_seq_base, blocks.kv)), + gmem_transform=t, + barrier=barriers[slot], + arrive=False, + uniform=False, + swizzle=128, + ) + + loop_partition = Partition1D(kv_seq_len, chunk_size=blocks.kv) + with only_wg(1), ctx.named_region("KV TMA warmup"): + for i in range(blocks.stages - 1): + kv_copy_init(c(i), loop_partition.get_base(c(i))) + + with ctx.named_region("Q TMA wait"): + barriers[q_barrier].wait() + + m_i = FragmentedArray.splat( + c(-jnp.inf, f32), shape=(blocks.q,), layout=WGMMA_ROW_LAYOUT + ) + l_i = FragmentedArray.splat( + c(0, f32), shape=(blocks.q,), layout=WGMMA_ROW_LAYOUT + ) + acc = FragmentedArray.splat( + c(0, f32), shape=(blocks.q, head_dim), layout=WGMMA_LAYOUT + ) + + with only_wg(1): + perform_schedule_barrier() + + with only_wg(0): + barriers[c(0)].wait() + + @fori(c(loop_partition.num_chunks), (acc, m_i, l_i)) + def kv_loop(kv_step, carry): + acc, m_i, l_i = carry + slot = arith.remui(kv_step, c(blocks.stages)) + + with ctx.named_region("QK issue"): + # TODO(apaszke): Support WGMMA without an initial accumulator. + qk_acc = WGMMAAccumulator.zero(blocks.q, blocks.kv) + q, k = qo_smem, memref_slice(k_smem, slot) + qk_acc = wgmma(qk_acc, q, k, b_order=WGMMALayout.COL_MAJOR) + nvvm.wgmma_commit_group_sync_aligned() + + # We hide the TMA overhead by overlapping it with the QK matmul. + with only_wg(1), ctx.named_region("KV TMA start"): + tma_step = arith.addi(kv_step, c(blocks.stages - 1)) + tma_slot = arith.remui(tma_step, c(blocks.stages)) + tma_step_in_bounds = arith.cmpi( + arith.CmpIPredicate.slt, tma_step, c(loop_partition.num_chunks) + ) + if_op = scf.IfOp(tma_step_in_bounds) + with ir.InsertionPoint(if_op.then_block): + kv_copy_init(tma_slot, loop_partition.get_base(tma_step)) + scf.yield_([]) + + perform_schedule_barrier() + + with ctx.named_region("QK wait"): + nvvm.wgmma_wait_group_sync_aligned(0) + qk = qk_acc.value + + with ctx.named_region("Softmax"): + m_ij = m_i.max(qk.reduce(arith.maximumf, axis=1)) + alpha = exp(m_i - m_ij) + m_i = m_ij + p = exp(qk - m_ij.broadcast_minor(blocks.kv)) + acc *= alpha.broadcast_minor(head_dim) + l_i *= alpha + l_i += p.reduce(arith.addf, axis=1) + p = p.astype(f16) + + perform_schedule_barrier() + + with ctx.named_region("PV issue"): + v = memref_slice(v_smem, slot) + acc_update = WGMMAAccumulator.from_registers(acc) + acc_update = wgmma(acc_update, p, v) + nvvm.wgmma_commit_group_sync_aligned() + + # We hide the barrier overhead by overlapping it with the PV matmul. + with only_wg(0), ctx.named_region("KV TMA wait"): + wait_step = arith.addi(kv_step, c(1)) + wait_slot = arith.remui(wait_step, c(blocks.stages)) + wait_step_in_bounds = arith.cmpi( + arith.CmpIPredicate.slt, wait_step, c(loop_partition.num_chunks) + ) + with ir.InsertionPoint(scf.IfOp(wait_step_in_bounds).then_block): + barriers[wait_slot].wait() + scf.yield_([]) + + with ctx.named_region("PV wait"): + nvvm.wgmma_wait_group_sync_aligned(0) + acc = acc_update.value + + return acc, m_i, l_i + + with only_wg(0): + perform_schedule_barrier() + + acc, m_i, l_i = kv_loop.results + del m_i + # TODO(apaszke): Invert and multiply to avoid expensive divisions. 
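+    # At this point acc holds the running sum over keys of exp(logit - m_i) * v
+    # and l_i the matching sum of exp(logit - m_i), so dividing by l_i yields
+    # the softmax-weighted average, i.e. the attention output for this block.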
+ acc /= l_i.broadcast_minor(head_dim) + + with ctx.named_region("Acc store"): + acc.astype(f16).store_tiled(qo_smem, swizzle=128) + gpu.barrier() + nvvm.fence_proxy( + nvvm.ProxyKind.async_shared, space=nvvm.SharedSpace.shared_cta + ) # Make sure the store is visible to the TMA. + + with ctx.named_region("GMEM store"): + ctx.async_copy( + src_ref=qo_smem, + dst_ref=out_gmem, + gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), + gmem_transform=mosaic_gpu.TileTransform(tiling), + swizzle=128, + ) + ctx.await_async_copy(0) + + match impl: + case Implementation.TWO_COMPUTE_WG: + kernel = compute_only_kernel + case Implementation.TWO_COMPUTE_ONE_TMA_WG: + kernel = tma_wg_kernel + return mosaic_gpu.as_gpu_kernel( + kernel, grid, block, in_shape, out_shape, smem_scratch_shape, prof_spec + ) + + +def benchmark_and_verify( + batch_size, + q_seq_len, + kv_seq_len, + num_q_heads, + num_kv_heads, + head_dim, + **kwargs, +) -> float: + with mlir.make_ir_context(), ir.Location.unknown(): + kq, kk, kv = random.split(random.key(1234), 3) + q = random.normal( + kq, (batch_size, num_q_heads, q_seq_len, head_dim), dtype=jnp.float16 + ) + k = random.normal( + kk, (batch_size, num_kv_heads, kv_seq_len, head_dim), dtype=jnp.float16 + ) + v = random.normal( + kv, (batch_size, num_kv_heads, kv_seq_len, head_dim), dtype=jnp.float16 + ) + f = build_kernel( + batch_size=batch_size, + q_heads=num_q_heads, + kv_heads=num_kv_heads, + q_seq_len=q_seq_len, + kv_seq_len=kv_seq_len, + head_dim=head_dim, + **kwargs, + ) + out, runtime = profiler.measure(f, q[0], k[0], v[0]) + out = out[None] + + @jax.jit + def ref(q, k, v): + q = q.astype(jnp.float32) + k = k.astype(jnp.float32) + v = v.astype(jnp.float32) + q_reshaped = q.reshape( + batch_size, num_kv_heads, num_q_heads // num_kv_heads, q_seq_len, + head_dim) + logits = jnp.einsum("bxhqc,bxkc->bxhqk", q_reshaped, k) + m = logits.max(axis=-1) + unnormalized = jnp.exp(logits - m[..., None]) + l = unnormalized.sum(axis=-1) + weights = unnormalized / l[..., None] + return jnp.einsum("bxhqk,bxkc->bxhqc", weights, v).reshape(*q.shape) + expected = ref(q, k, v) + np.testing.assert_allclose(out, expected, atol=2e-3, rtol=2e-3) + return runtime + + +if __name__ == "__main__": + batch_size = 1 + num_q_heads = 4 + num_kv_heads = 1 + prof_spec = None + seq_lens = (4096, 32768) + problem_it = itertools.product(seq_lens, (64, 128, 256,)) + for seq_len, head_dim in problem_it: + q_seq_len = kv_seq_len = seq_len + print( + "====" + f" {kv_seq_len=:<6} {q_seq_len=:<6} {num_q_heads=:<4} {head_dim=:<6} ====" + ) + param_it = itertools.product( + (ExpImplementation.APPROX,), Implementation, (64,), (64, 128, 256), + ) + best = None + for exp_impl, impl, block_q, block_kv in param_it: + try: + runtime_ms = benchmark_and_verify( + batch_size, + q_seq_len, + kv_seq_len, + num_q_heads, + num_kv_heads, + head_dim, + prof_spec=prof_spec, + exp_impl=exp_impl, + blocks=BlockSizes(q=block_q, kv=block_kv, stages=2), + impl=impl, + ) + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: + continue + raise + runtime_us = runtime_ms * 1e3 + matmul_flops = ( + 4 * q_seq_len * kv_seq_len * head_dim * num_q_heads * batch_size + ) + peak_flops = 1e15 # f16 TensorCore peak = 1000TFLOPS + optimal_time = matmul_flops / peak_flops * 1e6 # us + achieved_tc_util = optimal_time / runtime_us * 100 + has_tma_warp = impl == Implementation.TWO_COMPUTE_ONE_TMA_WG + print( + f"exp_impl={exp_impl.name:<6} block_q={block_q:<4}block_kv={block_kv:<4}tma_warp={has_tma_warp:<1}: {runtime_us:<7.1f}us" + 
f" = {achieved_tc_util:4.1f}% TC utilization" + ) + if best is None or runtime_us < best[0]: + best = (runtime_us, achieved_tc_util) + if best is not None: + print(f"Best: {best[0]:<7.1f}us = {best[1]:4.1f}% TC utilization") diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py new file mode 100644 index 000000000000..53818f2ef39d --- /dev/null +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -0,0 +1,530 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Matmul kernels for H100.""" + +import dataclasses +import enum +import functools + +import jax +from jax import random +from jax._src.interpreters import mlir +from jax.experimental.mosaic import gpu as mosaic_gpu +from jax.experimental.mosaic.gpu import profiler +from jax.experimental.mosaic.gpu.dsl import * # noqa: F403 +import jax.numpy as jnp +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import memref +from jaxlib.mlir.dialects import nvgpu +from jaxlib.mlir.dialects import nvvm +from jaxlib.mlir.dialects import scf +from jaxlib.mlir.dialects import vector +import numpy as np + +# mypy: ignore-errors +# ruff: noqa: F405 +# pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access + +SmemRef = ir.Value + + +@dataclasses.dataclass(frozen=True) +class Tiling: + m: int + n: int + k: int + + @property + def mk(self): + return (self.m, self.k) + + @property + def kn(self): + return (self.k, self.n) + + @property + def nk(self): + return (self.n, self.k) + + @property + def mn(self): + return (self.m, self.n) + + +class F32Precision(enum.Enum): + DEFAULT = enum.auto() + TF32_X3 = enum.auto() + + +class WGMMADefaultImpl: + """Default WGMMA implementation.""" + + @staticmethod + def zero_accs(tile_m: int, tile_n: int) -> WGMMAAccumulator: + return WGMMAAccumulator.zero(tile_m, tile_n) + + @staticmethod + def smem_shape_extra( + block_tiling: Tiling, + tma_tiling: Tiling, + lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, + rhs_transpose: WGMMALayout, + ) -> dict[str, jax.ShapeDtypeStruct]: + del block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose + return {} + + @staticmethod + def get_result_tile(acc: WGMMAAccumulator) -> FragmentedArray: + return acc.value + + @staticmethod + def wgmma( + smem_scratch: dict[str, SmemRef], # pylint: disable=unused-argument + acc: WGMMAAccumulator, + b_order: WGMMALayout, + a_slice: SmemRef, + b_slice: SmemRef, + ) -> dict[str, WGMMAAccumulator]: + acc = wgmma(acc, a_slice, b_slice, b_order=b_order) + nvvm.wgmma_commit_group_sync_aligned() + nvvm.wgmma_wait_group_sync_aligned(1) + return acc + + +class WGMMATF32x3Impl: + """WGMMA implementation for 3xTF32 precision.""" + + @staticmethod + def zero_accs(tile_m, tile_n) -> dict[str, WGMMAAccumulator]: + zero_acc 
= WGMMADefaultImpl.zero_accs(tile_m, tile_n) + return {"main": zero_acc, "errs": zero_acc} + + @staticmethod + def smem_shape_extra( + block_tiling: Tiling, + tma_tiling: Tiling, + lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, + rhs_transpose: bool, + ) -> dict[str, jax.ShapeDtypeStruct]: + del rhs_transpose + lhs_err = jax.ShapeDtypeStruct(shape=tile_shape(block_tiling.mk, tma_tiling.mk), dtype=lhs_dtype) + rhs_err = jax.ShapeDtypeStruct(shape=tile_shape(block_tiling.kn, tma_tiling.kn), dtype=rhs_dtype) + return {"lhs_err": lhs_err, "rhs_err": rhs_err} + + @staticmethod + def get_result_tile(accs) -> FragmentedArray: + return accs["main"].value + accs["errs"].value + + @staticmethod + def rounding_error(x_ref, err_ref): + """Store the TF32 rounding error of x_ref in err_ref.""" + f32 = ir.F32Type.get() + i32 = ir.IntegerType.get_signless(32) + t = FragmentedArray.load_strided(x_ref) + tf32_mask = FragmentedArray.splat(c(0xFFFFE000, i32), t.shape, t.layout) + t_tf32 = (t.bitcast(i32) & tf32_mask).bitcast(f32) + (t - t_tf32).store_untiled(err_ref) + + @staticmethod + def wgmma( + smem_scratch: dict[str, SmemRef], + accs: dict[str, WGMMAAccumulator], + b_order: WGMMALayout, + a_slice: SmemRef, + b_slice: SmemRef, + ) -> dict[str, WGMMAAccumulator]: + acc = wgmma(accs["main"], a_slice, b_slice, b_order=b_order) + nvvm.wgmma_commit_group_sync_aligned() + # Note: we assert that only the slice_ab and err_b mmas are still running + # which are unaffected by writing to the err_a shared memory. + # After nvvm.wgmma_wait_group_sync_aligned(2) there are no wgmmas + # accessing err_a so we can safely write to it. + nvvm.wgmma_wait_group_sync_aligned(2) + WGMMATF32x3Impl.rounding_error(a_slice, smem_scratch["lhs_err"]) + commit_shared() + acc_err = wgmma(accs["errs"], smem_scratch["lhs_err"], b_slice, b_order=b_order) + nvvm.wgmma_commit_group_sync_aligned() + # Note: similar to the above we wait for the last wgmma access to + # err_b which was 2 wgmmas ago. + nvvm.wgmma_wait_group_sync_aligned(2) + WGMMATF32x3Impl.rounding_error(b_slice, smem_scratch["rhs_err"]) + commit_shared() + acc_err = wgmma(acc_err, a_slice, smem_scratch["rhs_err"], b_order=b_order) + nvvm.wgmma_commit_group_sync_aligned() + nvvm.wgmma_wait_group_sync_aligned(2) + return {"main": acc, "errs": acc_err} + +class WGMMACvtRhsImpl: + """Mixed WGMMA implementation where B is converted to A.""" + + @staticmethod + def zero_accs(tile_m: int, tile_n: int) -> WGMMAAccumulator: + return WGMMADefaultImpl.zero_accs(tile_m, tile_n) + + @staticmethod + def smem_shape_extra( + block_tiling: Tiling, + tma_tiling: Tiling, + lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, + rhs_transpose: bool, + ) -> dict[str, jax.ShapeDtypeStruct]: + del rhs_dtype + if rhs_transpose: + raise NotImplementedError("Transpose requires more elaborate handling of tiling.") + + if tma_tiling.k != 64: + raise ValueError(f"WGMMA layout needs the left tiling dimension to be 64 {tma_tiling.k=}") + + # The second dim needs to be tma_tiling.k so it is 128b wide and + # the first dim needs to line up with the lhs dimension. That's + # why we have a strange (k, k) here. 
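+    # For instance, with f16 operands (tma_tiling.k == 64) and tile_n == 128,
+    # block_tiling.kn == (64, 128), so the conversion scratch below becomes a
+    # (1, 2, 64, 64) tile in the lhs dtype.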
+ cvt_shape = tile_shape(block_tiling.kn, (tma_tiling.k, tma_tiling.k)) + return {"cvt": jax.ShapeDtypeStruct(shape=cvt_shape, dtype=lhs_dtype)} + + @staticmethod + def get_result_tile(acc: WGMMAAccumulator) -> FragmentedArray: + return WGMMADefaultImpl.get_result_tile(acc) + + @staticmethod + def wgmma( + smem_scratch: dict[str, SmemRef], # pylint: disable=unused-argument + acc: WGMMAAccumulator, + b_order: WGMMALayout, + a_slice: SmemRef, + b_slice: SmemRef, + ) -> dict[str, WGMMAAccumulator]: + # Convert the load + arr = FragmentedArray.load_tiled(b_slice, swizzle=128) + cvt_ty = ir.MemRefType(smem_scratch["cvt"].type) + # TODO(cperivol): https://research.google/blog/mixed-input-matrix-multiplication-performance-optimizations/ + arr = arr.astype(cvt_ty.element_type) + # Make sure no wgmma is running. + # TODO(cperivol): double buffer. + nvvm.wgmma_wait_group_sync_aligned(0) + arr.store_tiled(smem_scratch["cvt"], swizzle=128) + commit_shared() + acc = wgmma(acc, a_slice, smem_scratch["cvt"], b_order=b_order) + nvvm.wgmma_commit_group_sync_aligned() + return acc + + +def mlir_context(f): + def wrap(*args, **kw): + with mlir.make_ir_context(), ir.Location.unknown(): + return f(*args, **kw) + + return wrap + +@mlir_context +def build_kernel( + m, n, k, + lhs_dtype, rhs_dtype, + stages: int = 2, + tile_m: int = 128, + tile_n: int = 128, + rhs_transpose: bool = False, + wgmma_impl=WGMMADefaultImpl, + profiler_spec: profiler.ProfilerSpec | None = None, +): + f32 = ir.F32Type.get() + out_128b_elems = 128 // bytewidth(f32) + out_tiling = (64, out_128b_elems) + out_tile = jax.ShapeDtypeStruct(tile_shape((tile_m, tile_n), out_tiling), jnp.float32) + if tile_m % 64 != 0: + raise ValueError(f"{tile_m=} must be divisible by 64") + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % 64 != 0: + raise ValueError(f"n must be divisible by 64, but got {n=}") + if stages < 2: + raise ValueError(f"Need at least 2 stages, but got {stages=}") + + lhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(lhs_dtype)) + rhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(rhs_dtype)) + lhs_128b_elems = 128 // lhs_elem_bytes + rhs_128b_elems = 128 // rhs_elem_bytes + tile_k = max(lhs_128b_elems, rhs_128b_elems) + + if tile_n % rhs_128b_elems != 0: + raise ValueError( + f"{tile_n=} must be divisible by 128 bytes =" + f" {((lhs_128b_elems, lhs_dtype), (rhs_128b_elems, rhs_dtype))}" + ) + + if k % tile_k != 0: + raise ValueError(f"k must be divisible by {tile_k=}, but got {k=}") + + block_tiling = Tiling(m=tile_m, n=tile_n, k=tile_k) + tma_tiling = Tiling(m=64, n=rhs_128b_elems, k=lhs_128b_elems) + k_steps = k // block_tiling.k + stages = min(stages, k_steps) + + def safe_div(x, y): + assert x % y == 0, (x, y) + return x // y + + grid = (safe_div(m, block_tiling.m), safe_div(n, block_tiling.n), 1) + block = (128, 1, 1) + + c = arith.ConstantOp.create_index + divmod = lambda x, y: (arith.divui(x, c(y)), arith.remui(x, c(y))) + + compute_scratch_shapes = { + "lhs": jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.mk, tma_tiling.mk)), lhs_dtype), + "rhs": jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.kn, tma_tiling.kn)), rhs_dtype), + } + compute_scratch_shapes |= wgmma_impl.smem_shape_extra(block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose) + + epilogue_scratch_shapes = { + "acc": jax.ShapeDtypeStruct(out_tile.shape, out_tile.dtype), + } + + smem_shape = mosaic_gpu.Union( + [compute_scratch_shapes, epilogue_scratch_shapes]) + + def _main(ctx, a_device, b_device, 
c_device, + smem_union: mosaic_gpu.Union[mosaic_gpu.RefTree]): + compute_smem, epilogue_smem = smem_union.members + + memref.assume_alignment(c_device, 16) + + barrier_group = BarrierArray(stages) + m_start = arith.muli(c(block_tiling.m), gpu.block_id(gpu.Dimension.x)) + n_start = arith.muli(c(block_tiling.n), gpu.block_id(gpu.Dimension.y)) + + def fetch(slot, ki): + barrier = barrier_group[slot] + k_start = arith.muli(c(block_tiling.k), ki) + lhs_tma_tile_bytes = int(np.prod(block_tiling.mk) * lhs_elem_bytes) + rhs_tma_tile_bytes = int(np.prod(block_tiling.kn) * rhs_elem_bytes) + txcount = c(lhs_tma_tile_bytes + rhs_tma_tile_bytes) + common_copy_args = dict( + swizzle=128, barrier=barrier, arrive=False, uniform=False, + ) + with single_thread(): + nvgpu.mbarrier_arrive_expect_tx(barrier_group.value, txcount, slot) + ctx.async_copy( + src_ref=a_device, + dst_ref=memref_slice(compute_smem["lhs"], slot), + gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)), + gmem_transform=mosaic_gpu.TileTransform(tma_tiling.mk), + **common_copy_args, + ) + rhs_slice = (ds(k_start, block_tiling.k), ds(n_start, block_tiling.n)) + rhs_transform = (mosaic_gpu.TileTransform(tma_tiling.kn),) + if rhs_transpose: + rhs_slice = rhs_slice[::-1] + rhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + assert tma_tiling.n == tma_tiling.k, block_tiling # No need to flip the tiling. + ctx.async_copy( + src_ref=b_device, + dst_ref=memref_slice(compute_smem["rhs"], slot), + gmem_slice=rhs_slice, + gmem_transform=rhs_transform, + **common_copy_args, + ) + + accs = wgmma_impl.zero_accs(block_tiling.m, block_tiling.n) + + with ctx.named_region("TMA warmup"): + for i in range(stages): + fetch(c(i), c(i)) + + @fori(c(k_steps), accs) + def stage_loop_body(ki, accs): + si = arith.remui(ki, c(stages)) + + with ctx.named_region("TMA wait"): + barrier_group[si].wait() + + with ctx.named_region("WGMMA"): + a_slice = memref_slice(compute_smem["lhs"], si) + b_slice = memref_slice(compute_smem["rhs"], si) + rhs_smem_order = ( + WGMMALayout.COL_MAJOR if rhs_transpose else WGMMALayout.ROW_MAJOR + ) + accs = wgmma_impl.wgmma( + compute_smem, accs, rhs_smem_order, a_slice, b_slice) + + with ctx.named_region("TMA start"): + tma_ki = arith.addi(ki, c(stages - 1)) + do_tma = arith.cmpi(arith.CmpIPredicate.slt, tma_ki, c(k_steps)) + not_first_step = arith.cmpi(arith.CmpIPredicate.ne, ki, c(0)) + if_op = scf.IfOp(arith.andi(not_first_step, do_tma)) + with ir.InsertionPoint(if_op.then_block): + tma_si = arith.remui(tma_ki, c(stages)) + fetch(tma_si, tma_ki) + scf.yield_([]) + + return accs + + # Wait until everyone is done with their WMMA + with ctx.named_region("WGMMA drain"): + nvvm.wgmma_wait_group_sync_aligned(0) + + with ctx.named_region("SMEM store"): + acc_val = wgmma_impl.get_result_tile(stage_loop_body.result) + acc_smem = epilogue_smem["acc"] + acc_val.store_tiled(acc_smem, swizzle=128) + gpu.barrier() + + with ctx.named_region("GMEM store"): + # Vectorized epilogue to move results from SMEM to GMEM + # TODO(apaszke): Make this into a proper copy function. + warps_per_warpgroup = 4 + lanes_per_warp = 32 + m_out_tiling, n_out_tiling = out_tiling[-2:] + warp_id, lane_id = divmod(gpu.thread_id(gpu.Dimension.x), lanes_per_warp) + # We store 4 f32 numbers for a block of 16B. + vector_len = 4 + num_vectors_per_row = safe_div(tile_n, vector_len) + # Process several rows at once if it is necessary to fully exploit each + # warp. 
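+      # For example, with tile_n == 128 each row needs 32 vectors of 4 floats,
+      # so a full warp stores exactly one row per iteration; with tile_n == 64
+      # each warp covers two rows at a time instead.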
+ if tile_n < lanes_per_warp * vector_len: + num_rows_per_warp = min( + safe_div(lanes_per_warp * vector_len, tile_n), + safe_div(tile_m, warps_per_warpgroup)) + else: + num_rows_per_warp = 1 + lanes_per_row = safe_div(lanes_per_warp, num_rows_per_warp) + lane_row_offset, lane_col_offset = divmod(lane_id, lanes_per_row) + warp_for_op = scf.ForOp(arith.muli(warp_id, c(num_rows_per_warp)), + c(tile_m), + c(warps_per_warpgroup * num_rows_per_warp)) + with ir.InsertionPoint(warp_for_op.body): + start_row = warp_for_op.induction_variable + m_row_idx = arith.addi(start_row, lane_row_offset) + vector_for_op = scf.ForOp(lane_col_offset, c(num_vectors_per_row), + c(lanes_per_row)) + with ir.InsertionPoint(vector_for_op.body): + vector_idx = vector_for_op.induction_variable + n_store = arith.muli(vector_idx, c(vector_len)) + col_group, n_load = divmod(n_store, n_out_tiling) + m_tile, m_within_tile = divmod(m_row_idx, m_out_tiling) + swizzle_source = arith.shli(arith.remui(m_row_idx, c(8)), c(2)) + n_acc = arith.xori(n_load, swizzle_source) + acc_part = vector.load( + ir.VectorType.get((vector_len,), f32), + acc_smem, + [m_tile, col_group, m_within_tile, n_acc], + ) + vector.store( + acc_part, + c_device, + [arith.addi(m_start, m_row_idx), arith.addi(n_start, n_store)], + ) + scf.yield_([]) + scf.yield_([]) + + return mosaic_gpu.as_gpu_kernel( + _main, + grid, + block, + ( + jax.ShapeDtypeStruct((m, k), lhs_dtype), + jax.ShapeDtypeStruct((n, k) if rhs_transpose else (k, n), rhs_dtype), + ), + jax.ShapeDtypeStruct((m, n), jnp.float32), + smem_shape, + profiler_spec, + ) + + +def random_array(key, shape: tuple[int, ...], dtype: jnp.dtype): + if jax.dtypes.issubdtype(dtype, np.floating): + return random.uniform(key, shape, dtype=dtype) + elif jax.dtypes.issubdtype(dtype, np.integer): + return random.randint(key, shape, -127, 127, dtype) + else: + raise NotImplementedError(dtype) + +def verify( + m=(33 * 128), + k=2048, + n=(4 * 128), + stages=4, + tile_m=128, + tile_n=128, + profile=False, + lhs_dtype=jnp.float16, + rhs_dtype=jnp.float16, + rhs_transpose=False, + precision: F32Precision = F32Precision.DEFAULT, +): + # TODO(cperivol): Transpose is only supported for 16bit wgmma. ATM + # that means bf16 x bf16, f16 x f16 and bf16 x s8. When we get more + # general mixed precision this check will need to be more nuanced. + if not rhs_transpose and jnp.dtype(lhs_dtype).itemsize != 2: + raise ValueError( + "Implicit transpose can only happen for 16bit types (or mixed precision" + " that is underpinned by 16bit operations)." 
+    )
+
+  kx, ky = random.split(random.key(1234))
+  x = random_array(kx, (m, k), lhs_dtype)
+  y = random_array(ky, (n, k) if rhs_transpose else (k, n), rhs_dtype)
+
+  if lhs_dtype != rhs_dtype:
+    impl = WGMMACvtRhsImpl
+  else:
+    match precision:
+      case F32Precision.DEFAULT:
+        impl = WGMMADefaultImpl
+      case F32Precision.TF32_X3:
+        impl = WGMMATF32x3Impl
+
+  prof_spec = profiler.ProfilerSpec(4096) if profile else None
+  f = build_kernel(
+      m, n, k,
+      jnp.dtype(lhs_dtype), jnp.dtype(rhs_dtype),
+      stages=stages,
+      tile_m=tile_m,
+      tile_n=tile_n,
+      rhs_transpose=rhs_transpose,
+      wgmma_impl=impl,
+      profiler_spec=prof_spec,
+  )
+  z, runtime = profiler.measure(f, x, y)
+
+  if rhs_transpose:
+    dimension_numbers = ((1,), (1,)), ((), ())
+  else:
+    dimension_numbers = ((1,), (0,)), ((), ())
+  if lhs_dtype == jnp.dtype(jnp.float32):  # Account for the tf32 precision
+    exponent_bits, mantissa_bits = 8, 10
+    x, y = (
+        jax.lax.reduce_precision(v, exponent_bits, mantissa_bits)
+        for v in (x, y)
+    )
+
+  ref_f = functools.partial(
+      jax.lax.dot_general,
+      dimension_numbers=dimension_numbers,
+      preferred_element_type=jnp.float32,
+  )
+
+  ref, ref_runtime = profiler.measure(ref_f, x, y)
+  np.testing.assert_allclose(z, ref, atol=1e-3, rtol=1e-3)
+  return runtime, ref_runtime
+
+
+if __name__ == "__main__":
+  m, k, n = 33 * 128, 2048, 4 * 128
+  runtime, ref_runtime = verify(m=m, k=k, n=n)
+  tflops = float(2 * k * m * n) / (runtime / 1e3) / 1e12
+  ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
+  print(f"Kernel: {runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
+  print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")
diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
new file mode 100644
index 000000000000..ff7d1e0591b0
--- /dev/null
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -0,0 +1,661 @@
+# Copyright 2024 The JAX Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for code generator."""
+
+import dataclasses
+
+import jax
+from jaxlib.mlir import ir
+from jaxlib.mlir.dialects import arith
+from jaxlib.mlir.dialects import gpu
+from jaxlib.mlir.dialects import llvm
+from jaxlib.mlir.dialects import math as mlir_math
+from jaxlib.mlir.dialects import memref
+from jaxlib.mlir.dialects import nvvm
+from jaxlib.mlir.dialects import vector
+import numpy as np
+
+from . import dsl as mgpu
+from . import utils
+
+# mypy: ignore-errors
+
+WARPGROUP_SIZE = utils.WARPGROUP_SIZE
+c = utils.c
+
+
+@dataclasses.dataclass(frozen=True)
+class WGSplatFragLayout:
+  """A fragmented array in which all values are equal, represented as one register per thread.
+
+  FragmentedArrays in this layout are always the result of a splat: each
+  thread in the warpgroup holds a single copy of the value, while the
+  FragmentedArray pretends it has whatever shape the user wants. This means
+  we can trivially broadcast, reshape and do elementwise operations with all
+  other layouts.
+
+  Examples:
+
+  To splat a scalar loaded from memory across a (10, 20, 2) array:
+  ```
+  FragmentedArray.splat(memref.load(ref_1d, [1]), (10, 20, 2))
+  ```
+
+  A shape is always provided for sanity check reasons.
+
+  """
+
+  shape: tuple[int, ...] = ()
+
+  def can_broadcast_to(self, shape) -> bool:
+    """Check that the shape can be broadcast.
+
+    Only dimensions of size 1 can be broadcast. All other dimensions
+    must be the same as the argument shape.
+    """
+    return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1]))
+
+
+@dataclasses.dataclass(frozen=True)
+class WGMMAFragLayout:
+  """[m, n] matrix, where m % 64 == 0 == n % 8."""
+
+
+@dataclasses.dataclass(frozen=True)
+class WGMMARowFragLayout:
+  """[m] matrix, where m % 64 == 0."""
+
+
+@dataclasses.dataclass(frozen=True)
+class WGStridedFragLayout:
+  """Convert the array to 1D and then shard across threads."""
+
+  shape: tuple[int, ...]
+  vec_size: int
+
+  def __post_init__(self):
+    if np.prod(self.shape) % (self.vec_size * WARPGROUP_SIZE) != 0:
+      raise ValueError((self, WARPGROUP_SIZE))
+
+  @classmethod
+  def from_memref_type(cls, memref_ty: ir.Type):
+    if not ir.MemRefType.isinstance(memref_ty):
+      raise TypeError(memref_ty)
+
+    memref_type = ir.MemRefType(memref_ty)
+    bw = mgpu.bytewidth(memref_type.element_type)
+    assert 8 % bw == 0 and 8 // bw != 0, bw
+    if np.prod(memref_type.shape) % WARPGROUP_SIZE != 0:
+      raise ValueError(
+          "Ref must have a number of elements that is a multiple of"
+          f" {WARPGROUP_SIZE}"
+      )
+    max_vec_size = np.prod(memref_type.shape) // WARPGROUP_SIZE
+    return cls(
+        shape=tuple(memref_type.shape), vec_size=min(8 // bw, max_vec_size)
+    )
+
+  def thread_vec_idxs(self):
+    """The indices to be used for vector load/store in WGStridedFragLayout.
+
+    Yields:
+      The indices of the vector that correspond to the current thread.
+    """
+    index = ir.IndexType.get()
+    cardinality = np.prod(self.shape)
+    assert cardinality % (WARPGROUP_SIZE * self.vec_size) == 0
+    reg_num = cardinality // (WARPGROUP_SIZE * self.vec_size)
+    tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE, index))
+    off = arith.muli(tidx, c(self.vec_size, tidx.type))
+    for i in range(reg_num):
+      yield [arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type))]
+
+
+FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout
+
+
+WGMMA_LAYOUT = WGMMAFragLayout()
+WGMMA_ROW_LAYOUT = WGMMARowFragLayout()
+
+
+@jax.tree_util.register_pytree_node_class
+class FragmentedArray:
+  registers: np.ndarray  # of ir.Value, see checks in init for shapes.
+ layout: FragmentedLayout + + def __init__(self, *, _registers: np.ndarray, _layout: FragmentedLayout): + self.registers = _registers + self.layout = _layout + + match self.layout: + # Registers are [m_tiles, n_tiles, 2 rows, 1 cols] in WGMMA layout + # Each element is a vector<2xdtype> + case WGMMAFragLayout(): + if self.registers.ndim != 4 or self.registers.shape[2:] != (2, 1): + raise ValueError("Invalid register array shape") + + # Registers are [m_tiles, 2 rows] in WGMMA_ROW layout + # Each element is a dtype scalar + case WGMMARowFragLayout(): + if self.registers.ndim != 2 or self.registers.shape[-1] != 2: + raise ValueError("Invalid register array shape") + + # Registers are flat + case WGStridedFragLayout(shape): + (reg_size,) = ir.VectorType(_registers.flat[0].type).shape + if np.prod(shape) != np.prod(_registers.shape) * WARPGROUP_SIZE * reg_size: + raise ValueError((reg_size, shape, _registers.shape, WARPGROUP_SIZE), _registers.flat[0].type) + + # Just a single register + case WGSplatFragLayout(): + if _registers.size != 1: + raise ValueError(f"WGStridedFragLayout requires a single value {_registers.shape} ({_registers.size})") + + case _: + raise NotImplementedError + + @classmethod + def load_strided(cls, ref: ir.Value): + if not ir.MemRefType.isinstance(ref.type): + raise TypeError(ref.type) + + ref_ty = ir.MemRefType(ref.type) + ref_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) + layout = WGStridedFragLayout.from_memref_type(ref_ty) + vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type) + vecs = [vector.load(vec_ty, ref_1d, vec_idx) for vec_idx in layout.thread_vec_idxs()] + return cls(_registers=np.array(vecs), _layout=layout) + + @classmethod + def splat(cls, value, shape, layout=None): + layout = layout or WGSplatFragLayout(shape) + match layout: + case WGMMARowFragLayout(): + if len(shape) != 1: + raise ValueError + if shape[0] % 64: + raise ValueError + reg_shape = (shape[0] // 64, 2) + case WGMMAFragLayout(): + if len(shape) != 2: + raise ValueError + if shape[0] % 64 or shape[1] % 8: + raise ValueError + reg_shape = (shape[0] // 64, shape[1] // 8, 2, 1) + value = vector.splat(ir.VectorType.get((2,), value.type), value) + case WGStridedFragLayout(vec_size=vec_size): + assert shape == layout.shape + elems = np.prod(shape) + reg_shape = (elems // (WARPGROUP_SIZE * vec_size),) + value = vector.splat(ir.VectorType.get((vec_size,), value.type), value) + case WGSplatFragLayout(): + assert shape == layout.shape + reg_shape = () + case _: + raise NotImplementedError(layout) + + return cls( + _registers=np.full(reg_shape, value, dtype=object), + _layout=layout, + ) + + @property + def shape(self): + match self.layout: + case WGMMAFragLayout(): + row_tiles, col_tiles = self.registers.shape[:2] + return (row_tiles * 64, col_tiles * 8) + case WGMMARowFragLayout(): + row_tiles = self.registers.shape[0] + return (row_tiles * 64,) + case WGStridedFragLayout(shape): + return shape + case WGSplatFragLayout(shape=shape): + return shape + + @property + def mlir_dtype(self): + reg_ty = self.registers.flat[0].type + match self.layout: + case WGMMAFragLayout() | WGStridedFragLayout(): + return ir.VectorType(reg_ty).element_type + case WGMMARowFragLayout() | WGSplatFragLayout(): + return reg_ty + + def _pointwise(self, op, *other): + other_arrs = [] + for o in other: + if not isinstance(o, FragmentedArray): + if not isinstance(o, ir.Value): + raise NotImplementedError(o) + + o = FragmentedArray.splat(o, shape=self.shape, layout=self.layout) + + if isinstance(o.layout, 
WGSplatFragLayout): + if not o.layout.can_broadcast_to(self.shape): + raise ValueError("Can't broadcast shape.") + o = FragmentedArray.splat(o.registers.flat[0], shape=self.shape, layout=self.layout) + else: + if self.layout != o.layout: + raise ValueError("Incompatible FragmentedArray layouts") + if self.registers.shape != o.registers.shape: + raise ValueError("Incompatible FragmentedArray shapes") + + other_arrs.append(o) + new_regs = np.empty_like(self.registers) + + for idx, reg in np.ndenumerate(self.registers): + new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs)) + return FragmentedArray(_registers=new_regs, _layout=self.layout) + + def __add__(self, other): + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise(arith.addf, other) + elif ir.IntegerType.isinstance(self.mlir_dtype): + return self._pointwise(arith.addi, other) + else: + raise NotImplementedError(self.mlir_dtype) + + def __mul__(self, other): + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise(arith.mulf, other) + elif ir.IntegerType.isinstance(self.mlir_dtype): + return self._pointwise(arith.muli, other) + else: + raise NotImplementedError(self.mlir_dtype) + + def __sub__(self, other): + if not ir.FloatType.isinstance(self.mlir_dtype): + raise NotImplementedError + return self._pointwise(arith.subf, other) + + def __truediv__(self, other): + if not ir.FloatType.isinstance(self.mlir_dtype): + raise NotImplementedError + return self._pointwise(arith.divf, other) + + def max(self, other): + if not ir.FloatType.isinstance(self.mlir_dtype): + raise NotImplementedError + return self._pointwise(arith.maximumf, other) + + def exp(self, approx: bool = False): + if not ir.FloatType.isinstance(self.mlir_dtype): + raise NotImplementedError + def fast_exp(x): + f32 = ir.F32Type.get() + if self.mlir_dtype != f32: + raise NotImplementedError + log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634)) + if x.type == f32: + scaled = arith.mulf(x, log2e) + return llvm.inline_asm( + f32, [scaled], "ex2.approx.f32 $0,$1;", "=f,f", asm_dialect=0 + ) + elif ir.VectorType.isinstance(x.type): + index = ir.IndexType.get() + result = llvm.mlir_undef(x.type) + for i in range(2): + v = vector.extractelement(x, position=c(i, index)) + vr = fast_exp(v) + result = vector.insertelement(vr, result, position=c(i, index)) + return result + else: + raise NotImplementedError(x.type) + return self._pointwise(fast_exp if approx else mlir_math.exp) + + def rsqrt(self): + return self._pointwise(mlir_math.rsqrt) + + def __and__(self, other): + if not ir.IntegerType.isinstance(self.mlir_dtype): + raise ValueError( + "Bitwise operations only defined for integer types, not" + f" {self.mlir_dtype}" + ) + + return self._pointwise(arith.andi, other) + + def bitcast(self, elt: ir.Type): + reg_type = self.registers.flat[0].type + if ir.VectorType.isinstance(reg_type): + reg_shape = ir.VectorType(reg_type).shape + ty = ir.VectorType.get(reg_shape, elt) + else: + ty = elt + + return self._pointwise(lambda x: arith.bitcast(ty, x)) + + def __getitem__(self, idx): + if self.layout != WGMMA_LAYOUT: + raise NotImplementedError("Only WGMMA layouts support slicing") + base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape) + if any(is_squeezed): + raise NotImplementedError("Only slicing implemented") + if ( + base_idx[0] % 64 + or slice_shape[0] % 64 + or base_idx[1] % 8 + or slice_shape[1] % 8 + ): + raise NotImplementedError("Only tile aligned slicing supported") + base_idx[0] //= 64 + 
slice_shape[0] //= 64 + base_idx[1] //= 8 + slice_shape[1] //= 8 + new_regs = self.registers[ + base_idx[0] : base_idx[0] + slice_shape[0], + base_idx[1] : base_idx[1] + slice_shape[1], + ] + return FragmentedArray(_registers=new_regs, _layout=self.layout) + + # TODO(apaszke): Support JAX dtypes here as well? + def astype(self, new_dtype: ir.Type): + cur_dtype = self.mlir_dtype + if cur_dtype == new_dtype: + return self + from_float = ir.FloatType.isinstance(cur_dtype) + to_float = ir.FloatType.isinstance(new_dtype) + from_integer = ir.IntegerType.isinstance(cur_dtype) + to_integer = ir.IntegerType.isinstance(new_dtype) + if from_float and to_float: + if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width: + convert = arith.truncf + else: + convert = arith.extf + elif from_integer and to_integer: + if ir.IntegerType(cur_dtype).width > ir.IntegerType(new_dtype).width: + convert = arith.trunci + else: + convert = arith.extsi + elif from_integer and to_float: + convert = arith.sitofp + elif from_float and to_integer: + convert = arith.fptosi + new_registers = np.empty_like(self.registers) + match self.layout: + case WGMMAFragLayout(): + new_reg_ty = ir.VectorType.get((2,), new_dtype) + case WGStridedFragLayout(vec_size=vec_size): + new_reg_ty = ir.VectorType.get((vec_size,), new_dtype) + case WGMMARowFragLayout() | WGSplatFragLayout(): + new_reg_ty = new_dtype + case _: + raise NotImplementedError(f"Unsupported layout {self.layout}") + for idx, reg in np.ndenumerate(self.registers): + new_registers[idx] = convert(new_reg_ty, reg) + return FragmentedArray(_registers=new_registers, _layout=self.layout) + + def reduce_sum(self, scratch) -> ir.Value: + index = ir.IndexType.get() + if not isinstance(self.layout, WGStridedFragLayout): + raise NotImplementedError(f"Unsupported layout {self.layout}") + result = c(0, self.mlir_dtype) + for reg in self.registers: + result = arith.addf( + result, + vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg), + ) + scratch_ty = ir.MemRefType(scratch.type) + if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]: + raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})") + + if ir.FloatType.isinstance(self.mlir_dtype): + op = arith.addf + elif ir.IntegerType.isinstance(self.mlir_dtype): + op = arith.addi + else: + raise NotImplementedError(self.mlir_dtype) + + warp_result = utils.warp_tree_reduce(result, op, 32) + warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index)) + memref.store(warp_result, scratch, [warp_id]) + utils.commit_shared() + zero_index = c(0, index) + with mgpu.single_thread(): + scratch_vec = vector.load( + ir.VectorType.get((4,), self.mlir_dtype), + scratch, + [zero_index], + ) + scratch_sum = vector.reduction( + self.mlir_dtype, vector.CombiningKind.ADD, scratch_vec + ) + memref.store(scratch_sum, scratch, [zero_index]) + utils.commit_shared() + return memref.load(scratch, [zero_index]) + + def reduce(self, op, axis): + if self.layout != WGMMA_LAYOUT: + raise NotImplementedError(self.layout) + if axis != 1: + raise NotImplementedError + index = ir.IndexType.get() + i32 = ir.IntegerType.get_signless(32) + new_regs = np.empty(self.registers.shape[::2], dtype=object) + assert self.registers.shape[-1] == 1 + for row_tile, row_subtile in np.ndindex(new_regs.shape): + # Reduce the registers owned by the current thread over n tiles + thread_result_vec = self.registers[row_tile, 0, row_subtile, 0] + for n_tile in range(1, self.registers.shape[1]): + thread_result_vec = op( + 
thread_result_vec, self.registers[row_tile, n_tile, row_subtile, 0] + ) + thread_result = op( + vector.extractelement(thread_result_vec, position=c(0, index)), + vector.extractelement(thread_result_vec, position=c(1, index)), + ) + # Do a shuffle to reduce in groups of 4 consecutive threads. + result = thread_result + for i in (1, 2): + other_result = nvvm.shfl_sync( + result.type, + c(0xFFFFFFFF, i32), + result, + c(i, i32), + c(0x1F, i32), + nvvm.ShflKind.bfly, + ) + result = op(result, other_result) + new_regs[row_tile, row_subtile] = result + return FragmentedArray(_registers=new_regs, _layout=WGMMA_ROW_LAYOUT) + + def broadcast(self, shape): + if not isinstance(self.layout, WGSplatFragLayout): + raise NotImplementedError(self.layout) + + if self.shape == shape: + return self + + if not self.layout.can_broadcast_to(shape): + raise ValueError(f"Can't broadcast {self.shape} to {shape}") + + return FragmentedArray(_registers=self.registers, _layout=WGSplatFragLayout(shape)) + + def reshape(self, shape): + if self.shape == shape: + return self + + if not isinstance(self.layout, WGSplatFragLayout): + raise NotImplementedError(self.layout) + + if np.prod(shape) != np.prod(self.shape): + raise ValueError(f"Can't reshape {self.shape} to {shape}") + + return FragmentedArray(_registers=self.registers, _layout=WGSplatFragLayout(shape)) + + def broadcast_minor(self, n): + if self.layout != WGMMA_ROW_LAYOUT: + raise NotImplementedError + num_row_tiles = self.registers.shape[0] + num_col_tiles, rem = divmod(n, 8) + if rem: + raise ValueError("Number of columns must be divisible by 8") + new_regs = np.empty((num_row_tiles, num_col_tiles, 2, 1), dtype=object) + dtype = self.mlir_dtype + for (row_tile, row_subtile), reg in np.ndenumerate(self.registers): + new_regs[row_tile, :, row_subtile, :] = vector.splat( + ir.VectorType.get((2,), dtype), reg + ) + return FragmentedArray(_registers=new_regs, _layout=WGMMA_LAYOUT) + + def store_untiled(self, ref: ir.Value): + if not ir.MemRefType.isinstance(ref.type): + raise ValueError(ref) + + match self.layout: + case WGMMAFragLayout(): + self._store_untiled_wgmma(ref) + case WGStridedFragLayout(): + self._store_untiled_wg_strided(ref) + case _: + raise NotImplementedError(self.layout) + + def _store_untiled_wg_strided(self, ref: ir.Value): + ref_ty = ir.MemRefType(ref.type) + ref_shape = tuple(ref_ty.shape) + if ref_shape != self.shape: + raise ValueError((ref_shape, self.shape)) + smem_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) + for idx, reg in zip(self.layout.thread_vec_idxs(), self.registers.flat): + vector.store(reg, smem_1d, idx) + + def _store_untiled_wgmma(self, ref: ir.Value): + """Stores accumulator to a 2D memref. 
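+    Each register holds two elements, which are written back one at a time
+    with scalar memref.store ops.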
Not optimized at the moment.""" + assert self.layout == WGMMA_LAYOUT + index = ir.IndexType.get() + m, n = self.shape + ref_ty = ir.MemRefType(ref.type) + if ref_ty.shape != [m, n]: + raise ValueError(ref.type, (m, n)) + + def c(x): + return arith.ConstantOp(index, ir.IntegerAttr.get(index, x)) + + tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE)) + lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31} + warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3} + row_base = arith.addi( + arith.divui(lane_id, c(4)), arith.muli(warp_id, c(16)) + ) + col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6} + it = np.ndenumerate(self.registers) + for (row_tile, col_tile, row_idx, col_zero), elem in it: + del col_zero + row = arith.addi(row_base, c(row_tile * 64 + row_idx * 8)) + for col_idx in range(2): + value = vector.extractelement(elem, position=c(col_idx)) + col = arith.addi(col_base, c(col_tile * 8 + col_idx)) + memref.store(value, ref, [row, col]) + + def store_tiled(self, ref, swizzle: int | None): + if self.layout != WGMMA_LAYOUT: + raise NotImplementedError + dtype = self.mlir_dtype + bw = mgpu.bytewidth(dtype) + m, n = self.shape + assert m % 64 == 0 # This is implied by the layout. + cols_per_tile = 128 // bw + expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile] + if ir.MemRefType(ref.type).shape != expected_shape: + raise ValueError(ref.type, (m, n)) + for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle): + vector.store(get(self.registers), ref, idxs) + + @classmethod + def load_tiled(cls, ref, swizzle: int | None): + ref_ty = ir.MemRefType(ref.type) + dtype = ref_ty.element_type + bw = mgpu.bytewidth(dtype) + m_tiles, n_tiles, m_tile_size, n_tile_size = ref_ty.shape + if m_tile_size != 64 or n_tile_size != (128 // bw): + raise ValueError + m, n = m_tiles * m_tile_size, n_tiles * n_tile_size + assert m % 64 == 0 # This is implied by the layout. + registers = np.full( + (m_tiles, n // 8, 2, 1), + vector.splat(ir.VectorType.get((2,), dtype), c(0, dtype)), + dtype=object, + ) + for _, update, idxs in cls.transfer_tiled((m, n), dtype, swizzle): + update(registers, vector.load(ir.VectorType.get((2,), dtype), ref, idxs)) + return cls(_registers=registers, _layout=WGMMA_LAYOUT) + + @staticmethod + def transfer_tiled(shape, dtype, swizzle: int | None): + bw = mgpu.bytewidth(dtype) + m, n = shape + if n % 32 != 0: + raise NotImplementedError + cols_per_tile = 128 // bw + if swizzle != 128: + raise NotImplementedError("Only 128B swizzle supported") + + c = arith.ConstantOp.create_index + tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE)) + lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31} + warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3} + sub_row_base = arith.divui(lane_id, c(4)) # {0, 1, ..., 7} + if bw > 2: # Stagger is only necessary for values larger than 16bit. + is_even_row = arith.cmpi( + arith.CmpIPredicate.eq, arith.remui(sub_row_base, c(2)), c(0) + ) + else: + # We rely on canonicalization to clean up the selects. + i1 = ir.IntegerType.get_signless(1) + is_even_row = arith.constant(i1, ir.BoolAttr.get(True)) + row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16))) + col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6} + # The swizzle pattern is constant for a given thread. 
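+    # col_swizzle_bits = sub_row_base * (16 // bw): each of the 8 sub-rows XORs
+    # its column index with a different 16-byte chunk, as required by the
+    # 128-byte swizzled SMEM layout.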
+ col_swizzle_bits = arith.muli(sub_row_base, c(16 // bw)) + for row_group in range(m // 64): + for col_group in range(n // cols_per_tile): + for row_subidx in range(2): + row = arith.addi(row_base, c(row_subidx * 8)) + for col_subidx in range(cols_per_tile // 8): + # We stagger the even and odd rows a little to avoid bank conflicts. + # It seems that the STS.64 is 2x faster (and the hardware reports no + # conflicts) when the conflicts are split between half-warps, as + # opposed to having them within the half-warp. This requires a + # little more work for the selects, but is ultimately worth it. + col_subidx_even = col_subidx + col_subidx_odd = col_subidx ^ 2 + col_off = arith.select( + is_even_row, c(col_subidx_even * 8), c(col_subidx_odd * 8) + ) + col = arith.addi(col_base, col_off) + col = arith.xori(col, col_swizzle_bits) + reg_idx_even = col_subidx_even + col_group * (cols_per_tile // 8) + reg_idx_odd = col_subidx_odd + col_group * (cols_per_tile // 8) + even_idx = row_group, reg_idx_even, row_subidx, 0 + odd_idx = row_group, reg_idx_odd, row_subidx, 0 + idx = c(row_group), c(col_group), row, col + def get_register(regs, even_idx=even_idx, odd_idx=odd_idx): + value_even = regs[even_idx] + value_odd = regs[odd_idx] + return arith.select(is_even_row, value_even, value_odd) + def update_registers(regs, new, even_idx=even_idx, odd_idx=odd_idx): + regs[even_idx] = arith.select(is_even_row, new, regs[even_idx]) + regs[odd_idx] = arith.select(is_even_row, regs[odd_idx], new) + yield get_register, update_registers, idx + + def tree_flatten(self): + return list(self.registers.flat), (self.layout, self.registers.shape) + + @classmethod + def tree_unflatten(cls, aux, flat_registers): + layout, reg_shape = aux + registers = np.asarray(flat_registers, dtype=object).reshape(reg_shape) + return cls(_registers=registers, _layout=layout) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py new file mode 100644 index 000000000000..4ad0e6117c79 --- /dev/null +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -0,0 +1,300 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import contextlib +import ctypes +import functools +import json +import math + +import jax +from jax._src.interpreters import mlir +from jax._src.lib import xla_client +import jax.numpy as jnp +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import memref +from jaxlib.mlir.dialects import scf +import numpy as np + +from .utils import * # noqa: F403 + + +try: + from jax._src.lib import mosaic_gpu as mosaic_gpu_lib + + xla_client.register_custom_call_target( + "mosaic_gpu_record_event", + mosaic_gpu_lib._mosaic_gpu_ext._record_event_capsule(), + platform="CUDA", + ) +except ImportError: + pass + +# ruff: noqa: F405 +# mypy: ignore-errors + + +record_event_p = jax.core.Primitive("record_event") +record_event_p.multiple_results = True + +@record_event_p.def_abstract_eval +def _record_event_abstract_eval(*args, event): + del event # Unused. + return args + +@functools.partial(mlir.register_lowering, record_event_p, platform="cuda") +def _record_event_lowering_rule(ctx, *args, event): + ptr_bytes = ctypes.cast(event, ctypes.c_void_p).value.to_bytes( + 8, byteorder="little" + ) # pytype: disable=attribute-error + op = mlir.custom_call( + "mosaic_gpu_record_event", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=args, + backend_config=ptr_bytes, + operand_output_aliases={i: i for i in range(len(args))}, + ) + return op.results + +def _record_event(args, event): + flat_args, treedef = jax.tree.flatten(args) + return jax.tree.unflatten( + treedef, record_event_p.bind(*flat_args, event=event) + ) + +def measure(f, *args, **kwargs): + # TODO(apaszke): Raise if this is called under jit. + start_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create() + end_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create() + try: + + @jax.jit + def run(*args, **kwargs): + flat_args, treedef = jax.tree.flatten((args, kwargs)) + flat_args = _record_event(flat_args, start_event) + args, kwargs = jax.tree.unflatten(treedef, flat_args) + return _record_event(f(*args, **kwargs), end_event) + + jax.block_until_ready(run(*args, **kwargs)) # Warmup. + results = jax.block_until_ready(run(*args, **kwargs)) + elapsed = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_elapsed( + start_event, end_event + ) + finally: + mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_destroy(start_event) + mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_destroy(end_event) + return results, elapsed + + +class ProfilerSpec: + ENTER = 0 + EXIT = 1 << 31 + + def __init__(self, entries_per_warpgroup: int): + self.entries_per_warpgroup = entries_per_warpgroup + self.interned_names = {} + + def _num_warpgroups( + self, grid: tuple[int, ...], block: tuple[int, ...] + ) -> int: + if math.prod(block) % WARPGROUP_SIZE: + raise ValueError("Block size is not a multiple of warpgroup size") + return math.prod(grid) * math.prod(block) // WARPGROUP_SIZE + + def mlir_buffer_type( + self, grid: tuple[int, ...], block: tuple[int, ...] + ) -> ir.Type: + return ir.MemRefType.get( + (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,), + ir.IntegerType.get_signless(32), + ) + + def jax_buffer_type( + self, grid: tuple[int, ...], block: tuple[int, ...] 
+ ) -> ir.Type: + return jax.ShapeDtypeStruct( + (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,), + jnp.uint32, + ) + + def smem_i32_elements(self, block: tuple[int, ...]): + num_warpgroups = self._num_warpgroups((), block) + return int(num_warpgroups * self.entries_per_warpgroup) + + def smem_bytes(self, block: tuple[int, ...]): + bytes_per_entry = 4 + return self.smem_i32_elements(block) * bytes_per_entry + + def intern_name(self, name: str) -> int: + if name_id := self.interned_names.get(name, None): + return name_id + name_id = self.interned_names[name] = len(self.interned_names) + if name_id & self.EXIT: + raise RuntimeError("Allocated too many names") + return name_id + + def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]): + buffer = np.asarray(buffer) + num_blocks = math.prod(grid) + warpgroups_per_block = self._num_warpgroups((), block) + entries = buffer.reshape( + num_blocks, warpgroups_per_block, self.entries_per_warpgroup + ) + start_times = entries[..., :2].astype(np.int64) + start_times = (start_times[..., 0] << 32) + start_times[..., 1] + start_times -= start_times.min() # Normalize + entries_used = entries[..., 2] + if np.any(entries_used > self.entries_per_warpgroup - 2): + raise RuntimeError("Insufficient space to capture a full trace") + traces = entries[..., 3:] + unintern = {v: k for k, v in self.interned_names.items()} + events = [] + for block_idx, wg_idx in np.ndindex(num_blocks, warpgroups_per_block): + valid_entries = entries_used[block_idx, wg_idx] - 3 + local_clock_offset = None + assert valid_entries % 2 == 0, valid_entries + start_time = start_times[block_idx, wg_idx] + block_events = [] + for i in range(0, valid_entries, 2): + tag = traces[block_idx, wg_idx, i] + time = traces[block_idx, wg_idx, i + 1] + if local_clock_offset is None: + local_clock_offset = time + time -= local_clock_offset + time -= i * 6 # Account for the overhead of profiling. + if time < 0: + break # Detect a timer wraparound + name_id = tag + begin = True + if name_id & ProfilerSpec.EXIT: + name_id = name_id ^ ProfilerSpec.EXIT + begin = False + name = unintern[name_id] + block_events.append({ + "name": name, + "ph": "B" if begin else "E", + "ts": float(start_time + time) / 1e3, + "pid": 1 + block_idx, + "tid": 1 + wg_idx, + }) + else: # If we didn't break + events.extend(block_events) + return json.dump({"displayTimeUnit": "ns", "traceEvents": events}, f) + + +class OnDeviceProfiler: + + def __init__(self, spec: ProfilerSpec, smem_buffer: ir.Value, gmem_buffer: ir.Value): + self.spec = spec + # self.should_store = gpu.thread_id(gpu.Dimension.x) + i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() + self.entries_per_wg = spec.entries_per_warpgroup + wg_idx = warpgroup_idx(sync=False) + self.smem_buffer = memref_slice( + smem_buffer, + ds( + arith.index_cast( + index, arith.muli(wg_idx, c(self.entries_per_wg, i32)) + ), + self.entries_per_wg, + ), + ) + self.gmem_buffer = gmem_buffer + # Hopefully mem2reg will remove the allocation. 
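+    # `offset` is the per-warpgroup write cursor into `smem_buffer`; record()
+    # advances it by 2 (tag + timestamp) for every enter/exit event.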
+ self.offset = memref.alloca(ir.MemRefType.get((), i32), [], []) + memref.store(c(0, i32), self.offset, []) + + @contextlib.contextmanager + def record(self, name: str): + i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() + name_id = self.spec.intern_name(name) + def store(modifier): + cur = arith.index_cast(index, memref.load(self.offset, [])) + # TODO(apaszke): Clamp indices + # bound = arith.subi(self.entries_per_block, c(2, index)) + # cur = arith.select( + # arith.cmpi(arith.CmpIPredicate.ult, cur, bound), cur, bound + # ) + memref.store(c(modifier | name_id, i32), self.smem_buffer, [cur]) + memref.store( + clock(), self.smem_buffer, [arith.addi(cur, c(1, cur.type))] + ) + memref.store( + arith.index_cast(i32, arith.addi(cur, c(2, cur.type))), + self.offset, + [], + ) + store(ProfilerSpec.ENTER) + yield + store(ProfilerSpec.EXIT) + + def finalize(self, grid: tuple[int, ...], block: tuple[int, ...]): + index = ir.IndexType.get() + i32 = ir.IntegerType.get_signless(32) + + gpu.barrier() # Make sure all warpgroups are done. + + block_idx = c(0, index) + for dim in gpu.Dimension: # pytype: disable=wrong-arg-types + block_idx = arith.addi( + arith.muli(block_idx, gpu.grid_dim(dim)), gpu.block_id(dim) + ) + wg_idx = warpgroup_idx(sync=False) + wg_per_block = math.prod(block) // WARPGROUP_SIZE + global_wg_idx = arith.addi( + arith.muli(block_idx, c(wg_per_block, index)), + arith.index_cast(index, wg_idx), + ) + start_offset = arith.muli(global_wg_idx, c(self.entries_per_wg, index)) + wg_gmem_buffer = memref.subview( + self.gmem_buffer, [start_offset], [self.entries_per_wg], [1], + result_type=ir.Type.parse( + f"memref<{self.entries_per_wg}xi32, strided<[1], offset: ?>>" + ), + ) + thread_in_wg = arith.remui(thread_idx(), c(128, i32)) + if_first = scf.IfOp( + arith.cmpi(arith.CmpIPredicate.eq, thread_in_wg, c(0, i32)) + ) + with ir.InsertionPoint(if_first.then_block): + # TODO(apaszke): Either use globaltimer or delete + # memref.store(globaltimer("high"), block_gmem_buffer, [c(0, index)]) + # memref.store(globaltimer("low"), block_gmem_buffer, [c(1, index)]) + memref.store(c(0, i32), wg_gmem_buffer, [c(0, index)]) + memref.store(c(0, i32), wg_gmem_buffer, [c(1, index)]) + memref.store( + arith.addi(memref.load(self.offset, []), c(3, i32)), + wg_gmem_buffer, + [c(2, index)], + ) + + for_op = scf.ForOp( + c(0, index), + c(self.entries_per_wg - 3, index), + c(1, index), + ) + with ir.InsertionPoint(for_op.body): + x = memref.load(self.smem_buffer, [for_op.induction_variable]) + memref.store( + x, + wg_gmem_buffer, + [arith.addi(for_op.induction_variable, c(3, index))], + ) + scf.yield_([]) + scf.yield_([]) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py new file mode 100644 index 000000000000..95577bb28999 --- /dev/null +++ b/jax/experimental/mosaic/gpu/utils.py @@ -0,0 +1,784 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utilities for code generator.""" + +from collections.abc import Iterator, Sequence +import contextlib +import dataclasses +import enum +import functools +from typing import Any, Literal + +import jax +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import builtin +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import llvm +from jaxlib.mlir.dialects import memref +from jaxlib.mlir.dialects import nvgpu +from jaxlib.mlir.dialects import nvvm +from jaxlib.mlir.dialects import scf +from jaxlib.mlir.dialects import vector +import numpy as np + +# mypy: ignore-errors + +WARPGROUP_SIZE: int = 128 +DYNAMIC = -9223372036854775808 + +# pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes + + +def ptr_as_memref(ptr, memref_ty: ir.MemRefType): + if len(memref_ty.shape) == 0: + raise NotImplementedError + i64 = ir.IntegerType.get_signless(64) + rank = len(memref_ty.shape) + desc_ty = ir.Type.parse( + f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>" + ) + desc = llvm.UndefOp(desc_ty) + desc = llvm.InsertValueOp(desc, ptr, [0]) # Allocation + desc = llvm.InsertValueOp(desc, ptr, [1]) # Aligned Base + desc = llvm.InsertValueOp( + desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, 0)), [2] + ) + for i, s in enumerate(memref_ty.shape): + desc = llvm.InsertValueOp( + desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [3, i] + ) + for i, s in enumerate(get_contiguous_strides(memref_ty.shape)): + desc = llvm.InsertValueOp( + desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [4, i] + ) + return builtin.unrealized_conversion_cast([memref_ty], [desc]) + + +def pack_array(values): + if not values: + raise ValueError("Empty array") + elem_ty = values[0].type + i64 = ir.IntegerType.get_signless(64) + ptr_ty = ir.Type.parse("!llvm.ptr") + arr_ptr = llvm.alloca(ptr_ty, c(len(values), i64), elem_ty) + for i, v in enumerate(values): + elem_ptr = llvm.getelementptr(ptr_ty, arr_ptr, [], [i], elem_ty) + llvm.store(v, elem_ptr) + return arr_ptr + + +def get_contiguous_strides(xs): + strides_ret = [] + stride = 1 + for x in xs[::-1]: + strides_ret.append(stride) + stride *= x + return strides_ret[::-1] + + +def c(val: int | float, ty): + if ir.IntegerType.isinstance(ty) or ir.IndexType.isinstance(ty): + if not isinstance(val, (int, np.integer)): + raise TypeError(type(val)) + attr = ir.IntegerAttr.get(ty, val) + elif ir.FloatType.isinstance(ty): + attr = ir.FloatAttr.get(ty, val) + elif ir.VectorType.isinstance(ty): + return vector.splat(ty, c(val, ir.VectorType(ty).element_type)) + else: + raise NotImplementedError(ty) + return arith.constant(ty, attr) + + +def debug_print(fmt, *args, uniform=True): + type_formats = [] + new_args = [] + for arg in args: + ty_format = None + if ir.IndexType.isinstance(arg.type): + ty_format = "%llu" + if ir.IntegerType.isinstance(arg.type): + width = ir.IntegerType(arg.type).width + ty_format = "%llu" + if width < 64: + arg = arith.extui(ir.IntegerType.get_signless(64), arg) + if ir.F32Type.isinstance(arg.type): + ty_format = "%f" + if ir.F16Type.isinstance(arg.type): + ty_format = "%f" + arg = arith.extf(ir.F32Type.get(), arg) + if ty_format is None: + raise NotImplementedError(arg.type) + type_formats.append(ty_format) + new_args.append(arg) + 
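+  # With uniform=True the print is wrapped in single_thread(per_block=False),
+  # so only one thread per warpgroup emits the message; uniform=False prints
+  # from every thread.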
ctx = ( + functools.partial(single_thread, per_block=False) + if uniform + else contextlib.nullcontext + ) + with ctx(): + gpu.printf(fmt.format(*type_formats) + "\n", new_args) + + +@dataclasses.dataclass(frozen=True) +class ForResult: + op: scf.ForOp + results: tuple[Any, ...] + + @property + def result(self): + if len(self.results) != 1: + raise ValueError + return self.results[0] + + +def fori(bound, carrys): + unwrap = False + if not isinstance(carrys, (list, tuple)): + carrys = [carrys] + unwrap = True + flat_carrys, carry_treedef = jax.tree.flatten(carrys) + + def wrapper(f): + index = ir.IndexType.get() + c0 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 0)) + c1 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 1)) + for_op = scf.ForOp(c0, bound, c1, flat_carrys) + with ir.InsertionPoint(for_op.body): + i = for_op.induction_variable + inner_carrys = jax.tree.unflatten(carry_treedef, for_op.inner_iter_args) + if unwrap: + [inner_carrys] = inner_carrys + new_carrys = f(i, inner_carrys) + if unwrap: + new_carrys = [new_carrys] + new_flat_carrys, new_carry_treedef = jax.tree.flatten(new_carrys) + if new_carry_treedef != carry_treedef: + raise ValueError(new_carry_treedef, carry_treedef) + scf.YieldOp(new_flat_carrys) + final_flat_carrys = for_op.results + return ForResult( + for_op, jax.tree.unflatten(carry_treedef, final_flat_carrys) + ) + + return wrapper + + +def thread_idx(): + i32 = ir.IntegerType.get_signless(32) + as_i32 = lambda x: arith.index_cast(i32, x) + tidx = as_i32(gpu.thread_id(gpu.Dimension.x)) + stride = as_i32(gpu.block_dim(gpu.Dimension.x)) + for dim in (gpu.Dimension.y, gpu.Dimension.z): + tidx = arith.addi(tidx, arith.muli(as_i32(gpu.thread_id(dim)), stride)) + stride = arith.muli(stride, as_i32(gpu.block_dim(dim))) + return tidx + + +def _warp_bcast(val, lane_idx=0): + i32 = ir.IntegerType.get_signless(32) + mask = c(0xFFFFFFFF, i32) + return nvvm.shfl_sync( + val.type, mask, val, c(lane_idx, i32), c(0x1F, i32), nvvm.ShflKind.idx + ) + + +def warp_idx(sync=True): + i32 = ir.IntegerType.get_signless(32) + warp_idx = arith.shrui(thread_idx(), c(5, i32)) + # Performing a warp broadcast improves performance as compiler understands + # that the value is uniform across the warp. + return _warp_bcast(warp_idx) if sync else warp_idx + + +def warpgroup_idx(sync=True): + i32 = ir.IntegerType.get_signless(32) + wg_idx = arith.shrui(thread_idx(), c(7, i32)) + # Performing a warp broadcast improves performance as compiler understands + # that the value is uniform across the warp. + return _warp_bcast(wg_idx) if sync else wg_idx + + +class ThreadSubset(enum.IntEnum): + WARPGROUP = enum.auto() + BLOCK = enum.auto() + + +# True withon `once()` contexts. +_ONCE_PER: ThreadSubset | None = None + + +@contextlib.contextmanager +def single_thread(per_block=True): + """Runs the context only from a single thread. + + Args: + per_block: If True, only one thread per block will run the context. + Otherwise, only one thread per warp group will run the context. + """ + global _ONCE_PER + scope = ThreadSubset.BLOCK if per_block else ThreadSubset.WARPGROUP + # If we're already in a single-thread context, we don't have to do anything. 
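+  # For example, in
+  #   with single_thread(per_block=True):
+  #     with single_thread(per_block=False):
+  #       ...
+  # the inner context emits no extra predication, because the enclosing
+  # block-level scope already implies warpgroup-level uniqueness.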
+ if _ONCE_PER is not None and _ONCE_PER >= scope: + yield + return + + warp = warp_idx() + if not per_block: + warp = arith.remui(warp, c(4, warp.type)) + first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type)) + elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) + should_run = arith.andi(first_warp, elected) + if_op = scf.IfOp(should_run) + prev_scope = _ONCE_PER + _ONCE_PER = scope + try: + with ir.InsertionPoint(if_op.then_block): + yield + scf.YieldOp([]) + finally: + _ONCE_PER = prev_scope + + +def clock(): + i32 = ir.IntegerType.get_signless(32) + return llvm.inline_asm( + i32, [], "mov.u32 $0,%clock;", "=r", asm_dialect=0, has_side_effects=True + ) + + +def globaltimer(kind: Literal["low", "high"] | None = None): + if kind is None: + i64 = ir.IntegerType.get_signless(64) + return llvm.inline_asm( + i64, [], "mov.u32 $0,%globaltimer;", + "=l", asm_dialect=0, has_side_effects=True, + ) + i32 = ir.IntegerType.get_signless(32) + return llvm.inline_asm( + i32, [], f"mov.u32 $0,%globaltimer_{kind[:2]};", + "=r", asm_dialect=0, has_side_effects=True, + ) + + +def bytewidth(ty: ir.Type): + if ir.IntegerType.isinstance(ty): + return ir.IntegerType(ty).width // 8 + if ir.FloatType.isinstance(ty): + return ir.FloatType(ty).width // 8 + raise NotImplementedError(ty) + + +@dataclasses.dataclass(frozen=True) +class DynamicSlice: + base: ir.Value | int + length: int + + +ds = DynamicSlice + + +def memref_slice(ref: ir.Value, index) -> ir.Value: + ref_ty = ir.MemRefType(ref.type) + base_indices, slice_shape, is_squeezed = parse_indices(index, ref_ty.shape) + + memref_strides, offset = ref_ty.get_strides_and_offset() + new_offset = offset + for idx, stride in zip(base_indices, memref_strides): + if isinstance(idx, int): + new_offset += idx * stride + else: + new_offset = ir.ShapedType.get_dynamic_stride_or_offset() + break + new_strides = [ + s for s, squeeze in zip(memref_strides, is_squeezed) if not squeeze + ] + new_shape = [s for s, squeeze in zip(slice_shape, is_squeezed) if not squeeze] + new_layout = ir.StridedLayoutAttr.get(new_offset, new_strides) + + ref_slice = memref.subview( + ref, base_indices, slice_shape, [1] * len(ref_ty.shape), + result_type=ir.MemRefType.get( + new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space + ), + ) + return ref_slice + + +def _is_contiguous_shape_slice( + ref_ty: ir.MemRefType, dim_slice: slice | None = slice(None) +): + # If it's not a strided layout then we are definitely contiguous. + if not ir.StridedLayoutAttr.isinstance(ref_ty.layout): + return True + + strides = ir.StridedLayoutAttr(ref_ty.layout).strides[dim_slice] + shape = ref_ty.shape[dim_slice] + + # Check that each dimension fits exactly it the immediately larger stride. + ss = sorted(zip(strides, shape), key=lambda x: x[0], reverse=True) + for (prev_stride, _), (stride, shape) in zip(ss, ss[1:]): + if stride * shape != prev_stride: + return False + + return True + + +def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value: + ref_ty = ir.MemRefType(ref.type) + new_shape = list(ref_ty.shape) + new_shape[dim : dim + fold_rank] = [np.prod(new_shape[dim : dim + fold_rank])] + identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) + contig_strided_1d = ir.Attribute.parse("strided<[1]>") + # Not sure why but MLIR expects the strided 1D layout to disappear in this op. 
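+  # Folding contiguous dims of an identity (or trivially strided 1D) layout
+  # yields an identity layout of the lower rank; otherwise the strided layout
+  # is recomputed below.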
+ if ref_ty.layout == identity or ref_ty.layout == contig_strided_1d: + new_layout = ir.AffineMapAttr.get( + ir.AffineMap.get_identity(ref_ty.rank - fold_rank + 1) + ) + elif _is_contiguous_shape_slice(ref_ty, slice(dim, dim + fold_rank)): + new_strides, offset = ref_ty.get_strides_and_offset() + new_strides[dim : dim + fold_rank] = [new_strides[dim + fold_rank - 1]] + new_layout = ir.StridedLayoutAttr.get(offset, new_strides) + else: + raise NotImplementedError( + f"strides={ref_ty.get_strides_and_offset()[0]}, {ref_ty.shape=}," + f" {dim=}, {fold_rank=}" + ) + + new_ty = ir.MemRefType.get( + new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space + ) + assoc = [[d] for d in range(dim)] + assoc.append([dim + i for i in range(fold_rank)]) + assoc.extend([d] for d in range(dim + fold_rank, ref_ty.rank)) + assert len(assoc) == new_ty.rank + return memref.collapse_shape(new_ty, ref, assoc) + + +def memref_unfold(ref: ir.Value, dim, factors) -> ir.Value: + """Unfolds dim into two dimensions, the size of leading one given be major_factor.""" + ref_ty = ir.MemRefType(ref.type) + new_shape = list(ref_ty.shape) + if sum(f is None for f in factors) > 1: + raise ValueError("Can only infer one dimension") + known_factor_prod = np.prod([f for f in factors if f is not None]) + if new_shape[dim] % known_factor_prod: + raise ValueError("Non-divisible unfold:", new_shape[dim], factors) + factors = tuple( + new_shape[dim] // known_factor_prod if f is None else f for f in factors + ) + new_shape[dim : dim + 1] = factors + identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) + if ref_ty.layout == identity: + new_layout = ir.AffineMapAttr.get( + ir.AffineMap.get_identity(ref_ty.rank + len(factors) - 1) + ) + else: + new_strides, offset = ref_ty.get_strides_and_offset() + prev_stride = new_strides[dim] + inserted_strides = [] + for f in reversed(factors): + inserted_strides.append(prev_stride) + prev_stride *= f + new_strides[dim : dim + 1] = reversed(inserted_strides) + new_layout = ir.StridedLayoutAttr.get(offset, new_strides) + new_ty = ir.MemRefType.get( + new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space + ) + if dim == ref_ty.rank: + assoc = [[d] for d in range(ref_ty.rank)] + assoc[-1].extend(range(ref_ty.rank, ref_ty.rank + len(factors) - 1)) + else: + assoc = [[d] for d in range(dim)] + assoc.append(list(range(dim, dim + len(factors)))) + assoc.extend([d + len(factors) - 1] for d in range(dim + 1, ref_ty.rank)) + assert len(assoc) == ref_ty.rank + return memref.expand_shape(new_ty, ref, assoc, [], new_ty.shape) + + +def memref_unsqueeze(ref: ir.Value, dim) -> ir.Value: + """Inserts a singleton dimension.""" + ref_ty = ir.MemRefType(ref.type) + if dim == ref_ty.rank: + new_shape = list(ref_ty.shape) + new_shape.append(1) + identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) + if ref_ty.layout == identity: + new_layout = ir.AffineMapAttr.get( + ir.AffineMap.get_identity(ref_ty.rank + 1) + ) + else: + new_strides, offset = ref_ty.get_strides_and_offset() + new_strides.append(1) + new_layout = ir.StridedLayoutAttr.get(offset, new_strides) + new_ty = ir.MemRefType.get( + new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space + ) + assoc = [[d] for d in range(ref_ty.rank)] + assoc[-1].append(ref_ty.rank) + return memref.expand_shape(new_ty, ref, assoc, [], new_ty.shape) + else: + return memref_unfold(ref, dim, (1, None)) + + +def memref_transpose(ref: ir.Value, permutation: Sequence[int]) -> ir.Value: + ref_ty = ir.MemRefType(ref.type) + 
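+  # A memref transpose only permutes the shape and strides; no data is moved.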
strides, offset = ref_ty.get_strides_and_offset() + new_strides = [strides[p] for p in permutation] + new_shape = [ref_ty.shape[p] for p in permutation] + new_layout = ir.StridedLayoutAttr.get(offset, new_strides) + new_ty = ir.MemRefType.get( + new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space + ) + return memref.transpose( + new_ty, ref, ir.AffineMap.get_permutation(permutation) + ) + + +def parse_indices( + index, shape: tuple[int, ...] +) -> tuple[list[ir.Value | int], list[int], list[bool]]: + if not isinstance(index, tuple): + index = (index,) + if trailing_dims := len(shape) - len(index): + index += (slice(None),) * trailing_dims + base_indices = [] + slice_shape = [] + is_squeezed = [] + for idx, bound in zip(index, shape): + if isinstance(idx, (ir.Operation, ir.OpView)): + idx = idx.result + if isinstance(idx, int): + base_indices.append(idx) + slice_shape.append(1) + is_squeezed.append(True) + elif isinstance(idx, slice): + if idx.step is not None: + raise NotImplementedError("Strided slices not implemented") + base_indices.append(idx.start or 0) + slice_shape.append((idx.stop or bound) - (idx.start or 0)) + is_squeezed.append(False) + elif isinstance(idx, DynamicSlice): + base_indices.append(idx.base) + slice_shape.append(idx.length) + is_squeezed.append(False) + elif isinstance(idx, ir.Value): + if not ir.IndexType.isinstance(idx.type): + raise ValueError("Expected an index-typed index") + base_indices.append(idx) + slice_shape.append(1) + is_squeezed.append(True) + else: + raise NotImplementedError(type(idx)) + assert len(base_indices) == len(slice_shape) == len(is_squeezed) == len(shape) + return base_indices, slice_shape, is_squeezed + + +def commit_shared(): + gpu.barrier() + nvvm.fence_proxy( + nvvm.ProxyKind.async_shared, space=nvvm.SharedSpace.shared_cta + ) + + +class BarrierArray: + + def __init__(self, num_barriers: int, arrival_count: int = 1): + barrier_group_ty = ir.Type.parse( + "!nvgpu.mbarrier.group," + f" num_barriers={num_barriers}>" + ) + + self.num_barriers = num_barriers + self.value = nvgpu.mbarrier_create(barrier_group_ty) + self.num_barriers = num_barriers + index = ir.IndexType.get() + if num_barriers > 32: + raise NotImplementedError("Only up to 32 barriers per group supported") + i32 = ir.IntegerType.get_signless(32) + self.phases = memref.alloca(ir.MemRefType.get((), i32), [], []) + memref.store(c(0, i32), self.phases, []) + with single_thread(per_block=True): + for i in range(num_barriers): + nvgpu.mbarrier_init(self.value, c(arrival_count, index), c(i, index)) + gpu.barrier() + + def __iter__(self) -> Iterator["Barrier"]: + for offset in range(self.num_barriers): + yield self[offset] + + def __getitem__(self, offset: ir.Value | int): + index = ir.IndexType.get() + if isinstance(offset, int): + offset = c(offset, index) + if ir.IntegerType.isinstance(offset.type): + offset = arith.index_castui(index, offset) + return Barrier(self, offset) + + +@dataclasses.dataclass(frozen=True) +class Barrier: + barrier_array: BarrierArray + offset: ir.Value + + def wait_parity(self, parity, expect_wait=False): + i1 = ir.IntegerType.get_signless(1) + index = ir.IndexType.get() + if expect_wait: + nvgpu.mbarrier_try_wait_parity( + self.barrier_array.value, parity, c(10000000, index), self.offset, + ) + return + barrier_ptr = self.get_ptr() + barrier_ready = llvm.inline_asm( + i1, + [barrier_ptr, parity], + "mbarrier.test_wait.parity.shared.b64 $0, [$1], $2;", + "=b,l,r", + asm_dialect=0, + has_side_effects=True, + ) + should_wait = 
arith.xori(barrier_ready, c(1, i1)) + should_wait = llvm.intr_expect(should_wait, c(0, i1)) + with ir.InsertionPoint(scf.IfOp(should_wait).then_block): + nvgpu.mbarrier_try_wait_parity( + self.barrier_array.value, parity, c(10000000, index), self.offset, + ) + scf.yield_([]) + + def wait(self, expect_wait=False): + i32 = ir.IntegerType.get_signless(32) + parities = memref.load(self.barrier_array.phases, []) + offset_i32 = arith.index_castui(i32, self.offset) + bitmask = arith.shli(c(1, i32), offset_i32) + parity = arith.cmpi( + arith.CmpIPredicate.ne, arith.andi(parities, bitmask), c(0, i32) + ) + new_parities = arith.xori(parities, bitmask) + memref.store(new_parities, self.barrier_array.phases, []) + self.wait_parity(parity, expect_wait=expect_wait) + + def arrive(self): + token_ty = ir.Type.parse("!nvgpu.mbarrier.token") + nvgpu.mbarrier_arrive(token_ty, self.barrier_array.value, self.offset) + + def get_ptr(self): + i32 = ir.IntegerType.get_signless(32) + i64 = ir.IntegerType.get_signless(64) + ptr_ty = ir.Type.parse("!llvm.ptr<3>") + smem = ir.IntegerAttr.get(i64, 3) + num_barriers = self.barrier_array.num_barriers + mbarrier_ref_ty = ir.MemRefType.get((num_barriers,), i64, memory_space=smem) + mbarrier_ref = builtin.unrealized_conversion_cast( + [mbarrier_ref_ty], [self.barrier_array.value], + ) + mbarrier_ref_ptr = memref.extract_aligned_pointer_as_index(mbarrier_ref) + barrier_arr_ptr = llvm.inttoptr( + ptr_ty, arith.index_cast(i64, mbarrier_ref_ptr), + ) + offset_i32 = arith.index_cast(i32, self.offset) + return llvm.getelementptr( + ptr_ty, barrier_arr_ptr, [offset_i32], [-2147483648], i64, + ) + + +class Partition: + source_bounds: tuple[int, ...] + target_bounds: tuple[int, ...] + partition: tuple[int | None, ...] + base_offset: tuple[ir.Value, ...] | None + + def __init__( + self, + elements: tuple[int, ...], + *, + partition: tuple[int | None, ...], + base_offset: tuple[ir.Value, ...] | None = None, + num_chunks: tuple[int, ...] | None = None, + chunk_size: tuple[int, ...] 
| None = None, + ): + self.target_bounds = elements + self.partition = partition + self.base_offset = base_offset + if len(self.target_bounds) != len(self.partition): + raise ValueError + if num_chunks is None == chunk_size is None: + raise ValueError( + "Exactly one of num_chunks and chunk_size must be specified" + ) + if num_chunks is not None: + self.source_bounds = num_chunks + else: + if len(chunk_size) != len(self.target_bounds): + raise ValueError + source_bounds = [] + for els, chunk in zip(elements, chunk_size): + if els % chunk: + raise ValueError("Non-divisible partition", elements, chunk_size) + source_bounds.append(els // chunk) + self.source_bounds = tuple(source_bounds) + + seen_dims = set() + for p in self.partition: + if p is None: + continue + if not (0 <= p < len(self.source_bounds)): + raise ValueError + if p in seen_dims: + raise ValueError + seen_dims.add(p) + for tb, p in zip(self.target_bounds, self.partition): + if p is not None and tb % self.source_bounds[p]: + raise ValueError("Non-divisible partitioning") + + @property + def num_chunks(self) -> tuple[int, ...]: + return self.source_bounds + + @property + def target_block_shape(self): + return tuple(tb if p is None else tb // self.source_bounds[p] + for tb, p in zip(self.target_bounds, self.partition)) + + def get_base(self, *source_coords: ir.Value | int) -> list[ir.Value]: + coords = [] + index = ir.IndexType.get() + for i, (tbs, p) in enumerate(zip(self.target_block_shape, self.partition)): + if p is None: + dim_base = c(0, index) + else: + dim_base = arith.muli(c(tbs, index), source_coords[p]) + if self.base_offset is not None: + dim_base = arith.addi(self.base_offset[i], dim_base) + coords.append(dim_base) + return coords + + +class Partition1D: + partition: Partition + + def __init__( + self, + elements: int, + *, + base_offset: ir.Value | None = None, + num_chunks: int | None = None, + chunk_size: int | None = None, + ): + self.base_offset = base_offset + if num_chunks is None == chunk_size is None: + raise ValueError( + "Exactly one of num_chunks and chunk_size must be specified" + ) + common_kwargs = dict(elements=(elements,), partition=(0,)) + if base_offset is not None: + common_kwargs["base_offset"] = (base_offset,) + if num_chunks is not None: + self.partition = Partition(num_chunks=(num_chunks,), **common_kwargs) + else: + self.partition = Partition(chunk_size=(chunk_size,), **common_kwargs) + + @property + def num_chunks(self) -> int: + return self.partition.source_bounds[0] + + def get_base(self, source_coords: ir.Value) -> ir.Value: + return self.partition.get_base(source_coords)[0] + + def refine( + self, + *, + chunk: ir.Value | None = None, + num_chunks: int | None = None, + chunk_size: int | None = None, + ): + return Partition1D( + self.partition.target_block_shape[0], + num_chunks=num_chunks, + chunk_size=chunk_size, + base_offset=self.get_base(chunk) if chunk is not None else None, + ) + + +def tile_shape(shape, tiling): + if len(tiling) > len(shape): + raise ValueError + if not tiling: + return shape + tiling_rank = len(tiling) + for s, t in zip(shape[-tiling_rank:], tiling): + if s % t: + raise ValueError("Non-divisible tiling:", shape, tiling) + return ( + *shape[:-tiling_rank], + *(s // t for s, t in zip(shape[-tiling_rank:], tiling)), + *tiling, + ) + + +def warp_tree_reduce(value, op, group_size): + """Reduce a value across the warpgroup.""" + assert 32 % group_size == 0 and group_size <= 32 + i32 = ir.IntegerType.get_signless(32) + result = value + iters = np.log2(group_size) + if 
not iters.is_integer(): + raise ValueError(f"Warp reduction group size should be a power of 2 (got {group_size})") + iters = int(iters) + for i in range(iters): + other_result = nvvm.shfl_sync( + result.type, + c(0xFFFFFFFF, i32), + result, + c(1 << i, i32), + c(0x1F, i32), + nvvm.ShflKind.bfly, + ) + result = op(result, other_result) + + return result + + +def memref_ptr(memref_arg, memory_space=None): + i64 = ir.IntegerType.get_signless(64) + memref_ty = ir.MemRefType(memref_arg.type) + if len(memref_ty.shape) == 0: + raise NotImplementedError + elem_bytewidth = bytewidth(memref_ty.element_type) + rank = len(memref_ty.shape) + # TODO: Read out memory space from memref + space = "" if memory_space is None else "<" + str(memory_space) + ">" + ptr_ty = ir.Type.parse("!llvm.ptr" + space) + desc_ty = ir.Type.parse( + f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>," + f" array<{rank} x i64>)>" + ) + desc = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg]) + aligned_ptr = llvm.extractvalue(ptr_ty, desc, [1]) + offset_elems = llvm.extractvalue(i64, desc, [2]) + offset_bytes = llvm.mul( + offset_elems, + c(elem_bytewidth, i64), + overflow_flags=llvm.IntegerOverflowFlags.none, + ) + return llvm.inttoptr( + ptr_ty, + llvm.add( + llvm.ptrtoint(i64, aligned_ptr), + offset_bytes, + overflow_flags=llvm.IntegerOverflowFlags.none, + ), + ) diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py new file mode 100644 index 000000000000..4d3db294143a --- /dev/null +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -0,0 +1,502 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import dataclasses +import enum +import functools +import itertools + +import jax +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import builtin +from jaxlib.mlir.dialects import llvm +from jaxlib.mlir.dialects import nvvm +from jaxlib.mlir.dialects import vector +import numpy as np + +from . import dsl as mgpu +from . import utils + +# mypy: ignore-errors + +c = mgpu.c +bytewidth = mgpu.bytewidth + + +@jax.tree_util.register_pytree_node_class +@dataclasses.dataclass +class WGMMAAccumulator: + """A FragmentedArray that has is synchronized with the async proxy. + + This implies that it requires no additional synchronization when passed in + as a WGMMA accumulator. In particular, when created from a + FragmentedArray, the necessary synchronization is inserted at construction. 
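+
+  Passing _sync=False skips that synchronization; it is only safe when the
+  value is already visible to the WGMMA async proxy (e.g. in tree_unflatten).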
+ """ + value: mgpu.FragmentedArray + + def __init__(self, *, _value: mgpu.FragmentedArray, _sync: bool = True): + if _value.layout != mgpu.WGMMA_LAYOUT: + raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator") + self.value = _value + if _sync: + self.value = wgmma_fence(_value) + + @classmethod + def zero(cls, m, n, dtype=None): + if m % 64 or n % 8: + raise ValueError + f32 = ir.F32Type.get() + if dtype is None: + dtype = f32 + zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) + return cls( + _value=mgpu.FragmentedArray.splat(zero, (m, n), mgpu.WGMMA_LAYOUT) + ) + + @classmethod + def from_registers(cls, registers): + return cls(_value=registers) + + def tree_flatten(self): + return (self.value,), () + + @classmethod + def tree_unflatten(cls, aux, value): + del aux + return cls(_value=value[0], _sync=False) + + +def wgmma_encode(x: int): + result = (x & 0x3FFFF) >> 4 + if result << 4 != x: + raise ValueError("Cannot encode value in a WGMMA descriptor") + return result + + +def llvm_add(x, y): + return llvm.add(x, y, overflow_flags=llvm.IntegerOverflowFlags.none) + + +def create_descriptor( + memref_arg, + leading_byte_offset: int, + stride_byte_offset: int, + swizzle: int | None, + memory_space: int | None = None, +): + i64 = ir.IntegerType.get_signless(64) + ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, memory_space)) + if swizzle is None: + swizzle_encoding = 0 + elif swizzle == 128: + swizzle_encoding = 1 + elif swizzle == 64: + swizzle_encoding = 2 + elif swizzle == 32: + swizzle_encoding = 3 + else: + raise NotImplementedError(swizzle) + encoded_base_addr = llvm.LShrOp( + llvm.AndOp(ptr_val, c(0x3FFFF, i64)), c(4, i64) + ) + # We ignore the offset + desc_const = ( + (wgmma_encode(leading_byte_offset) << 16) + | (wgmma_encode(stride_byte_offset) << 32) + ) + desc = llvm.or_( + arith.shli(c(swizzle_encoding, i64), c(62, i64)), c(desc_const, i64) + ) + desc = llvm.or_(encoded_base_addr, desc) + return desc + + +def _unpack_i32(vec_ty, r): + i32 = ir.IntegerType.get_signless(32) + return vector.bitcast( + vec_ty, vector.splat(ir.VectorType.get((1,), i32), r) + ) + + +def _supported_wgmma_types(dtype, abtype) -> bool: + input_types_are = lambda ty: ty.isinstance(abtype) + if ir.F32Type.isinstance(dtype): + return any(input_types_are(ty) for ty in (ir.FloatTF32Type, ir.BF16Type, ir.F16Type)) + elif ir.F16Type.isinstance(dtype): + return input_types_are(ir.F16Type) + else: + return False + + +def wgmma_m64( + acc: np.ndarray, # of register Values + a, + b_descriptor: ir.Value, + a_transpose: bool | None, + b_transpose: bool, + a_k_stride: int | None, + b_k_stride: int, + n: int, + swizzle: int, + element_type: ir.Type, +): + out_ty = ir.VectorType(acc.flat[0].type).element_type + if not _supported_wgmma_types(out_ty, element_type): + raise ValueError(f"Usupported wgmma types {(out_ty, element_type)=}") + + i32 = ir.IntegerType.get_signless(32) + i64 = ir.IntegerType.get_signless(64) + index = ir.IndexType.get() + if b_k_stride % 16: + raise ValueError + if n % (swizzle // bytewidth(element_type)): + raise ValueError + # Only 16-bit types support transposes + supports_transpose = bytewidth(element_type) == 2 + if not supports_transpose and (a_transpose or b_transpose): + raise ValueError("Only f16 WGMMA supports transposes") + if a_in_regs := isinstance(a, mgpu.FragmentedArray): + if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get(): + raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}") + if a.layout != 
+      raise ValueError("Unsupported A register array layout")
+    if a_k_stride is not None or a_transpose is not None:
+      raise ValueError("Unsupported WGMMA features with A in registers")
+  else:
+    if a_k_stride is None or a_k_stride % 16:
+      raise ValueError
+    if a_transpose is None:
+      raise ValueError
+
+  if ir.F32Type.isinstance(out_ty):
+    num_acc_regs = n // 2
+    out_ty_field = out_ty
+    acc_regs = [  # pylint: disable=g-complex-comprehension
+        vector.extractelement(reg, position=c(pos, index))
+        for reg in acc.flat
+        for pos in range(2)
+    ]
+    to_acc_vec_regs = functools.partial(_as_fragmented_reg_ndarray, dtype=out_ty, shape=acc.shape)
+    acc_constraint = "f"
+  elif ir.F16Type.isinstance(out_ty):
+    num_acc_regs = n // 4
+    out_ty_field = i32
+    acc_regs = [_as_i32_reg(reg) for reg in acc.flat]
+    vec_ty = ir.VectorType(acc.flat[0].type)
+    to_acc_vec_regs = lambda regs : np.array([_unpack_i32(vec_ty, reg) for reg in regs]).reshape(acc.shape)
+    acc_constraint = "r"
+  else:
+    raise ValueError(f"WGMMA instruction only supports f32 and f16 out (got {out_ty})")
+
+  num_imm_regs = 4 if supports_transpose else 2
+
+  if a_in_regs:
+    a_reg_constraints = ["r"] * 4  # 4x f16x2 registers
+    num_imm_regs -= 1  # transpose not supported for a in registers
+  else:
+    a_reg_constraints = ["l"]  # descriptor
+  # Reference for i/o aliasing: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html
+  # Seems like it's not actually documented in LLVM IR docs.
+  reg_constraints_list = (
+      [f"={acc_constraint}"] * num_acc_regs  # accumulator registers
+      + [str(i) for i in range(num_acc_regs)]  # we alias outputs as inputs, too.
+      + a_reg_constraints  # a descriptor / registers
+      + ["l"] * 1  # b descriptor
+      + ["n"] * (1 + num_imm_regs)  # literal constants
+  )
+  reg_constraints = ",".join(reg_constraints_list)
+
+  reg_count = itertools.count()
+
+  def take_regs(n):
+    return (f"${i}" for i in itertools.islice(reg_count, n))
+
+  acc_reg_vector = "{" + ",".join(take_regs(num_acc_regs)) + "}"
+  for _ in take_regs(num_acc_regs):  # Ignore next entries: aliasing.
+    pass
+  if a_in_regs:
+    a_regs = "{" + ",".join(take_regs(len(a_reg_constraints))) + "}"
+  else:
+    a_regs, = take_regs(1)
+  b_desc_reg, use_out_reg = take_regs(2)
+  imm_regs = ", ".join(take_regs(num_imm_regs))  # Immediate regs (scale, ...).
+  assert next(reg_count) == len(reg_constraints_list)
+  el_ty = element_type
+  k_instr = 32 // bytewidth(element_type)
+  wgmma_instr = (
+      f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty}.{el_ty}.{el_ty} "
+      f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p, {imm_regs};"
+  )
+  ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n"
+
+  def lc(x):
+    return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result
+
+  use_out = scale_a = scale_b = lc(1)
+  imms = [use_out, scale_a, scale_b]
+  if supports_transpose and a_transpose is not None:
+    imms += [lc(int(a_transpose)), lc(int(b_transpose))]
+  elif supports_transpose:
+    imms += [lc(int(b_transpose))]
+  if acc.ndim != 4 or acc.shape[0] != 1 or acc.shape[2:] != (2, 1):
+    raise ValueError(acc.shape)
+  acc_struct_type = ir.Type.parse(
+      f"!llvm.struct<({','.join(str(out_ty_field) for _ in acc_regs)})>"
+  )
+  for i in range((swizzle // bytewidth(element_type)) // k_instr):
+    # Slice out the relevant part of A or advance the A descriptor.
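+    # Each iteration issues one wgmma.mma_async covering k_instr elements of
+    # the contraction dimension, so a swizzled tile of `swizzle` bytes along K
+    # takes swizzle // 32 instructions. Descriptor offsets are encoded in units
+    # of 16 bytes, hence the >> 4 applied to the byte strides below.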
+ if a_in_regs: + a_slice = a[:, (i * 16) : ((i + 1) * 16)] + a_args = [_as_i32_reg(v) for v in a_slice.registers.flat] + else: + if i > 0: + a = llvm_add( + a, + llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)), + ) + a_args = [a] + # Advance the B descriptor. + if i > 0: + b_descriptor = llvm_add( + b_descriptor, + llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)), + ) + assert len(a_args) == len(a_reg_constraints) + acc_struct = llvm.inline_asm( + acc_struct_type, + [*acc_regs, *a_args, b_descriptor, *imms], + ptx, + reg_constraints, + asm_dialect=0, + has_side_effects=True, + ) + acc_regs = [ + llvm.extractvalue(out_ty_field, acc_struct, [i]) for i in range(len(acc_regs)) + ] + return to_acc_vec_regs(acc_regs) + + +class WGMMALayout(enum.Enum): + ROW_MAJOR = enum.auto() + COL_MAJOR = enum.auto() + + +# TODO(apaszke): Remove WGMMALayout. Make input shapes logical and infer +# transpositions from memref strides. +def wgmma( + acc: WGMMAAccumulator, + a, + b, + *, + swizzle: int = 128, + # Order only applies within each tile! + a_order: WGMMALayout | None = None, + b_order: WGMMALayout = WGMMALayout.ROW_MAJOR, +): + if a_in_regs := isinstance(a, mgpu.FragmentedArray): + a_element_type = a.mlir_dtype + a_shape = a.shape + else: + a_ty = ir.MemRefType(a.type) + a_element_type = a_ty.element_type + a_shape = a_ty.shape + b_ty = ir.MemRefType(b.type) + supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()} + if a_element_type not in supported_types: + raise ValueError(a_element_type) + if b_ty.element_type not in supported_types: + raise ValueError(b_ty.element_type) + if (element_type := a_element_type) != b_ty.element_type: + raise ValueError + element_bytewidth = bytewidth(element_type) + kn_tile = swizzle // element_bytewidth + + groups_k, groups_n = b_ty.shape[:2] + if b_ty.shape[2:] != [kn_tile, kn_tile]: + raise ValueError(b_ty.shape) + + if a_in_regs: + if a_element_type != ir.F16Type.get() and a_element_type != ir.BF16Type.get(): + raise ValueError(a_element_type) + if a_shape[0] % 64 or a_shape[1] % kn_tile: + raise ValueError(a_shape) + if a_shape[1] // kn_tile != groups_k: + raise ValueError(a_shape[1] // kn_tile, groups_k) + groups_m = a_shape[0] // 64 + if a_order is not None: + raise ValueError( + "a_order can only be specified when A is in shared memory" + ) + else: + groups_m = a_shape[0] + if a_shape[1] != groups_k: + raise ValueError(a_shape[1], groups_k) + if a_shape[2:] != [64, kn_tile]: + raise ValueError(a_shape) + if a_order is None: + a_order = WGMMALayout.ROW_MAJOR + + if a_order == WGMMALayout.COL_MAJOR and swizzle != 128: + # Not sure what the layout is like, since the tiles aren't square. 
+ raise NotImplementedError + + row_major = WGMMALayout.ROW_MAJOR + col_major = WGMMALayout.COL_MAJOR + tnsp_lbo = swizzle * (swizzle // 32) + sbo = swizzle // 2 + a_desc_fields = dict( + leading_byte_offset=(1 if a_order == row_major else tnsp_lbo) << 4, + stride_byte_offset=sbo << 4, + swizzle=swizzle, + memory_space=3, + ) + b_desc_fields = dict( + leading_byte_offset=(tnsp_lbo if b_order == row_major else 1) << 4, + stride_byte_offset=sbo << 4, + swizzle=swizzle, + memory_space=3, + ) + wgmma_params = dict( + a_transpose=a_order == col_major, + b_transpose=b_order == row_major, + a_k_stride=(2 if a_order == row_major else 128) << 4, + b_k_stride=(swizzle if b_order == row_major else 2) << 4, + n=(groups_n * kn_tile), + swizzle=swizzle, + element_type=ir.FloatTF32Type.get() + if ir.F32Type.isinstance(element_type) + else element_type, + ) + if a_in_regs: + wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None + + if a_in_regs: + a = wgmma_fence(a) # Make sure the registers are ready. + a_m_byte_stride = a_k_byte_stride = a_desc_base = None # Silence pytype. + else: + a_desc_base = create_descriptor(a, **a_desc_fields) + a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset() + a_byte_strides = [s * element_bytewidth for s in a_strides] + a_m_byte_stride, a_k_byte_stride = a_byte_strides[:2] + if a_byte_strides[2:] != [swizzle, element_bytewidth]: + raise ValueError(a_byte_strides) + b_desc_base = create_descriptor(b, **b_desc_fields) + b_strides, _ = b_ty.get_strides_and_offset() + b_byte_strides = [s * element_bytewidth for s in b_strides] + b_k_byte_stride = b_byte_strides[0] + if b_byte_strides[1:] != [swizzle * kn_tile, swizzle, element_bytewidth]: + raise ValueError(b_byte_strides) + + i64 = ir.IntegerType.get_signless(64) + new_acc_regs = acc.value.registers.copy() + for mi in range(groups_m): + for ki in range(groups_k): + if a_in_regs: + a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tile : (ki + 1) * kn_tile] + else: + a_mk = llvm_add( + a_desc_base, + c(wgmma_encode(mi * a_m_byte_stride + ki * a_k_byte_stride), i64), + ) + b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_byte_stride), i64)) + new_acc_regs[mi : mi + 1] = wgmma_m64( + new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params + ) + return WGMMAAccumulator( + _value=mgpu.FragmentedArray( + _registers=new_acc_regs, _layout=mgpu.WGMMA_LAYOUT + ), + _sync=False, + ) + + +def wgmma_fence(array: mgpu.FragmentedArray): + """Fences the array construction from WGMMA instructions. + + This is a little workaround to force LLVM to initialize the PTX registers + before the wgmma.fence.sync.aligned instruction. Otherwise, LLVM treats + in-register computation as pure and can move it after the fence, which is + explicitly disallowed by the PTX programming model. 
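+
+  The workaround copies each register of the array through a trivial inline
+  assembly mov and issues wgmma.fence.sync.aligned from the same asm block,
+  so the compiler cannot sink the register initialization past the fence.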
+ """ + i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() + dtype = array.mlir_dtype + src_vec_ty = ir.VectorType(array.registers.flat[0].type) + assert src_vec_ty.shape == [2] + + if dtype == ir.F32Type.get(): + regs = [ # pylint: disable=g-complex-comprehension + vector.extractelement(reg, position=c(pos, index)) + for reg in array.registers.flat + for pos in range(2) + ] + reg_dtype = dtype + reg_constraints_list = ["=f"] * len(regs) + ["f"] * len(regs) + ptx_lines = [f"mov.f32 ${i}, ${len(regs)+i}" for i in range(len(regs))] + elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get(): + regs = [_as_i32_reg(reg) for reg in array.registers.flat] + reg_dtype = i32 + reg_constraints_list = ["=r"] * len(regs) + ["r"] * len(regs) + ptx_lines = [f"mov.b32 ${i}, ${len(regs)+i}" for i in range(len(regs))] + else: + raise NotImplementedError(dtype) + reg_constraints = ",".join(reg_constraints_list) + # Copy over the registers. ptxas should be able to remove the moves. + ptx_lines.append("wgmma.fence.sync.aligned") + ptx = ";\n".join(ptx_lines) + ";\n" + dtype_str = str(reg_dtype) + struct_ty = ir.Type.parse( + f"!llvm.struct<({','.join(dtype_str for _ in regs)})>" + ) + acc_struct = llvm.inline_asm( + struct_ty, regs, ptx, reg_constraints, + asm_dialect=0, has_side_effects=True, + ) + regs = [ + llvm.extractvalue(reg_dtype, acc_struct, [i]) for i in range(len(regs)) + ] + if dtype == ir.F32Type.get(): + registers = _as_fragmented_reg_ndarray( + regs, array.mlir_dtype, array.registers.shape + ) + elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get(): + regs = [_unpack_i32(src_vec_ty, r) for r in regs] + registers = np.asarray(regs, dtype=object).reshape(array.registers.shape) + else: + raise NotImplementedError(dtype) + return mgpu.FragmentedArray(_registers=registers, _layout=array.layout) + + +def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]): + vec_regs = [] + for first, second in zip(flat_regs[::2], flat_regs[1::2]): + vec = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) + vec = llvm.insertelement(vec, first, position=_lc(0)) + vec = llvm.insertelement(vec, second, position=_lc(1)) + vec_regs.append(vec) + return np.asarray(vec_regs, dtype=object).reshape(shape) + + +def _as_i32_reg(v): + i32 = ir.IntegerType.get_signless(32) + return llvm.extractelement( + vector.bitcast(ir.VectorType.get((1,), i32), v), _lc(0) + ) + + +def _lc(x): + i32 = ir.IntegerType.get_signless(32) + return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 8b9e52bdf7f8..8d3331d774f9 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -366,7 +366,7 @@ def ltg_batcher(insert_axis, spmd_axis_name, axis_size, new_parts = None if spmd_axis_name is None else spmd_axis_name new_pspec = list(pspec) new_pspec.insert(d, new_parts) - new_pspec = P(*new_pspec) # type: ignore + new_pspec = P(*new_pspec) y = host_local_array_to_global_array_p.bind( x, global_mesh=global_mesh, pspec=new_pspec) return y, d @@ -424,21 +424,27 @@ def global_array_to_host_local_array( You can use this function to convert the globally shaped `jax.Array` output from pjit to host local values again so that the transition to jax.Array can - be a mechanical change. Example usage + be a mechanical change. 
- >> from jax.experimental import multihost_utils # doctest: +SKIP - >> - >> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP - >> - >> with mesh: # doctest: +SKIP - >> global_out = pjitted_fun(global_inputs) # doctest: +SKIP - >> - >> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP + Example usage: + + >>> from jax.experimental import multihost_utils # doctest: +SKIP + >>> + >>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP + >>> + >>> with mesh: # doctest: +SKIP + ... global_out = pjitted_fun(global_inputs) # doctest: +SKIP + >>> + >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP Args: global_inputs: A Pytree of global jax.Array's. - global_mesh: A jax.sharding.Mesh object. - pspecs: A Pytree of jax.sharding.PartitionSpec's. + global_mesh: A :class:`jax.sharding.Mesh` object. The mesh must be contiguous + meaning all local devices of the host must form a subcube. + pspecs: A Pytree of :class:`jax.sharding.PartitionSpec` objects. + + Returns: + A Pytree of host local arrays. """ flat_inps, out_tree = tree_flatten(global_inputs) out_pspecs = _flatten_pspecs('output pspecs', out_tree, diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py index 81a9630cbdfe..b8e3daee48c8 100644 --- a/jax/experimental/ode.py +++ b/jax/experimental/ode.py @@ -201,7 +201,7 @@ def body_fun(state): next_t = t + dt error_ratio = mean_error_ratio(next_y_error, rtol, atol, y, next_y) new_interp_coeff = interp_fit_dopri(y, next_y, k, dt) - dt = jnp.clip(optimal_step_size(dt, error_ratio), a_min=0., a_max=hmax) + dt = jnp.clip(optimal_step_size(dt, error_ratio), min=0., max=hmax) new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff] old = [i + 1, y, f, t, dt, last_t, interp_coeff] @@ -214,7 +214,7 @@ def body_fun(state): return carry, y_target f0 = func_(y0, ts[0]) - dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), a_min=0., a_max=hmax) + dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), min=0., max=hmax) interp_coeff = jnp.array([y0] * 5) init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff] _, ys = lax.scan(scan_fun, init_carry, ts[1:]) diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index d0ad9b864b6c..65fd4f466b10 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Module for pallas, a JAX extension for custom kernels.""" +"""Module for Pallas, a JAX extension for custom kernels. + +See the Pallas documentation at https://jax.readthedocs.io/en/latest/pallas.html. 
+""" from jax._src import pallas from jax._src.pallas.core import BlockSpec @@ -29,6 +32,7 @@ from jax._src.pallas.primitives import atomic_or from jax._src.pallas.primitives import atomic_xchg from jax._src.pallas.primitives import atomic_xor +from jax._src.pallas.primitives import debug_print from jax._src.pallas.primitives import dot from jax._src.pallas.primitives import load from jax._src.pallas.primitives import max_contiguous @@ -41,17 +45,7 @@ from jax._src.pallas.utils import next_power_of_2 from jax._src.pallas.utils import strides_from_shape from jax._src.pallas.utils import when -from jax._src.state.primitives import broadcast_to from jax._src.state.indexing import ds -from jax._src.state.indexing import dslice +from jax._src.state.indexing import dslice from jax._src.state.indexing import Slice - -try: - from jax.experimental.pallas import gpu # pytype: disable=import-error -except (ImportError, ModuleNotFoundError): - pass - -try: - from jax.experimental.pallas import tpu # pytype: disable=import-error -except (ImportError, ModuleNotFoundError): - pass +from jax._src.state.primitives import broadcast_to diff --git a/jax/experimental/pallas/gpu.py b/jax/experimental/pallas/gpu.py index 1a6c5ece6537..adade4e8a72c 100644 --- a/jax/experimental/pallas/gpu.py +++ b/jax/experimental/pallas/gpu.py @@ -12,11 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Contains Triton specific Pallas functions.""" -try: - from jax._src.pallas import triton - get_compute_capability = triton.get_compute_capability - del triton -except ImportError as e: - raise ImportError("Cannot import Pallas Triton backend. " - "Make sure you've installed jax-triton.") from e +"""Triton-specific Pallas APIs.""" + +from jax._src.pallas.triton.primitives import approx_tanh +from jax._src.pallas.triton.primitives import elementwise_inline_asm diff --git a/jax/experimental/pallas/ops/__init__.py b/jax/experimental/pallas/ops/__init__.py index 017372046b07..132d3839e212 100644 --- a/jax/experimental/pallas/ops/__init__.py +++ b/jax/experimental/pallas/ops/__init__.py @@ -12,13 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from jax.experimental.pallas.ops import attention -from jax.experimental.pallas.ops import layer_norm -from jax.experimental.pallas.ops import rms_norm -from jax.experimental.pallas.ops import softmax - - # All files within ops should be treated as user code. import os import jax._src.source_info_util diff --git a/jax/experimental/pallas/ops/gpu/__init__.py b/jax/experimental/pallas/ops/gpu/__init__.py new file mode 100644 index 000000000000..862a661e24b9 --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
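+
+# The GPU kernels in this package construct pl.BlockSpec with the block shape
+# first and the index_map second, e.g.
+#   pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0))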
diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/gpu/attention.py similarity index 88% rename from jax/experimental/pallas/ops/attention.py rename to jax/experimental/pallas/ops/gpu/attention.py index a96d3e1cc8dd..a0221ebf6f74 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -16,7 +16,7 @@ from __future__ import annotations import functools -from typing import Any, Optional +from typing import Any import jax from jax import lax @@ -119,7 +119,7 @@ def body(start_k, carry): # Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q) upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k) else: - upper_bound = pl.cdiv(seq_len, block_k) # type: ignore + upper_bound = pl.cdiv(seq_len, block_k) o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i)) if residual_refs: @@ -198,19 +198,19 @@ def mha( in_specs = [ pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), ] in_specs.append( None # type: ignore[arg-type] if segment_ids is None - else pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len)) + else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0)) ) out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype) return pl.pallas_call( @@ -218,7 +218,7 @@ def mha( grid=grid_, in_specs=in_specs, out_specs=pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), compiler_params=dict( triton=dict(num_warps=num_warps_, num_stages=num_stages) @@ -270,19 +270,19 @@ def _mha_forward( ] in_specs = [ pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), ] in_specs.append( None # type: ignore[arg-type] if segment_ids is None - else pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len)) + else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0)) ) out, l, m = pl.pallas_call( kernel, @@ -290,10 +290,10 @@ def _mha_forward( in_specs=in_specs, out_specs=[ pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), + pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), ], compiler_params=dict( triton=dict(num_warps=num_warps_, num_stages=num_stages) @@ -336,13 +336,13 @@ def _preprocess_backward(out, do, l, block_q: int, functools.partial(_preprocess_backward_kernel, block_q=block_q), grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads), in_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: 
(j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)), + pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)), + pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), ], out_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)), + pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), ], compiler_params=dict( triton=dict(num_warps=4, num_stages=3) @@ -483,32 +483,32 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, in_specs = [ pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) ), pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) ), pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) ), pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) ), pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) ), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)), + pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)), + pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)), pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) ), ] if segment_ids is None: in_specs.insert(3, None) # type: ignore[arg-type] input_output_aliases = {8: 0} else: - in_specs.insert(3, pl.BlockSpec(lambda j, k: (j, 0), (None, seq_len))) + in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda j, k: (j, 0))) input_output_aliases = {9: 0} grid = (batch_size, num_heads) # TODO(sharadmv): figure out why num_warps=8 doesn't work! 
@@ -527,13 +527,13 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, in_specs=in_specs, out_specs=[ pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) ), pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) ), pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) ), ], name="mha_backward", diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py index 7e08836b0f6c..9be724a1f42c 100644 --- a/jax/experimental/pallas/ops/gpu/decode_attention.py +++ b/jax/experimental/pallas/ops/gpu/decode_attention.py @@ -82,7 +82,7 @@ def body(start_k, carry): o_next = correction[:, None] * o_prev + o_curr return o_next, m_next, l_next - upper_bound = pl.cdiv(k_seq_len, block_k) # type: ignore + upper_bound = pl.cdiv(k_seq_len, block_k) # o is left unscaled; it will be scaled in the final reduction step o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i)) @@ -144,29 +144,18 @@ def attn_unbatched( kernel, grid=grid_, in_specs=[ - pl.BlockSpec(lambda i, j: (i, 0), (block_h, head_dim)), - pl.BlockSpec(lambda i, j: (j, 0, 0), (None, k_seq_len, head_dim)), - pl.BlockSpec(lambda i, j: (j, 0, 0), (None, k_seq_len, head_dim)), + pl.BlockSpec((block_h, head_dim), lambda i, j: (i, 0)), + pl.BlockSpec((None, k_seq_len, head_dim), lambda i, j: (j, 0, 0)), + pl.BlockSpec((None, k_seq_len, head_dim), lambda i, j: (j, 0, 0)), ], out_specs=[ - pl.BlockSpec(lambda i, j: (j, i, 0), (None, block_h, head_dim)), # o - pl.BlockSpec( - lambda i, j: (j, i), - ( - None, - block_h, - ), - ), # l - pl.BlockSpec( - lambda i, j: (j, i), - ( - None, - block_h, - ), - ), # m + pl.BlockSpec((None, block_h, head_dim), lambda i, j: (j, i, 0)), # o + pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l + pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m ], - num_warps=num_warps_, - num_stages=num_stages, + compiler_params=dict( + triton=dict(num_warps=num_warps_, num_stages=num_stages) + ), out_shape=[ jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o jax.ShapeDtypeStruct( diff --git a/jax/experimental/pallas/ops/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py similarity index 99% rename from jax/experimental/pallas/ops/layer_norm.py rename to jax/experimental/pallas/ops/gpu/layer_norm.py index 269f29dc71b7..0c39a9bf6e0d 100644 --- a/jax/experimental/pallas/ops/layer_norm.py +++ b/jax/experimental/pallas/ops/gpu/layer_norm.py @@ -18,8 +18,6 @@ import functools -from typing import Optional - import jax from jax import lax import jax.numpy as jnp diff --git a/jax/experimental/pallas/ops/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py similarity index 100% rename from jax/experimental/pallas/ops/rms_norm.py rename to jax/experimental/pallas/ops/gpu/rms_norm.py diff --git a/jax/experimental/pallas/ops/softmax.py b/jax/experimental/pallas/ops/gpu/softmax.py similarity index 100% rename from jax/experimental/pallas/ops/softmax.py rename to jax/experimental/pallas/ops/gpu/softmax.py diff --git a/jax/experimental/pallas/ops/tpu/all_gather.py b/jax/experimental/pallas/ops/tpu/all_gather.py index 979513d0f414..e121db894122 100644 --- a/jax/experimental/pallas/ops/tpu/all_gather.py +++ b/jax/experimental/pallas/ops/tpu/all_gather.py @@ -63,18 +63,18 @@ 
def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str, my_id = lax.axis_index(axis_name) # TODO(sharadmv): could speed this up having the first remote DMA go from # x_ref->o_ref immediately instead of a blocking HBM copy. - with pltpu.trace("initial_copy"): + with jax.named_scope("initial_copy"): pltpu.async_copy(x_ref, o_ref.at[my_id], recv_sem[0]).wait() - with pltpu.trace("neighbour_lookup"): + with jax.named_scope("neighbour_lookup"): axis_size = lax.psum(1, axis_name) left_neighbor = get_neighbor(my_id, mesh, axis_name, direction="left") right_neighbor = get_neighbor(my_id, mesh, axis_name, direction="right") - with pltpu.trace("main_barrier"): + with jax.named_scope("main_barrier"): sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal(sem, 2, device_id=left_neighbor) - pltpu.semaphore_signal(sem, 2, device_id=right_neighbor) + pltpu.semaphore_signal(sem, 1, device_id=left_neighbor) + pltpu.semaphore_signal(sem, 1, device_id=right_neighbor) pltpu.semaphore_wait(sem, 2) shard_size = x_ref.shape[0] @@ -86,7 +86,7 @@ def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str, right_slice = pl.ds(shard_size // 2, shard_size // 2) slot = jnp.where(right_slot < 0, axis_size + right_slot, right_slot) if right_dma: - with pltpu.trace("wait_right_dma"): + with jax.named_scope("wait_right_dma"): right_dma.wait() right_dma = pltpu.async_remote_copy( o_ref.at[slot, right_slice], @@ -100,7 +100,7 @@ def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str, left_slice = pl.ds(0, shard_size // 2) slot = lax.rem(left_slot, axis_size) if left_dma: - with pltpu.trace("wait_left_dma"): + with jax.named_scope("wait_left_dma"): left_dma.wait() left_dma = pltpu.async_remote_copy( o_ref.at[slot, left_slice], @@ -109,7 +109,7 @@ def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str, recv_sem[0], device_id=left_neighbor, ) - with pltpu.trace("wait_all_dma"): + with jax.named_scope("wait_all_dma"): assert right_dma is not None assert left_dma is not None right_dma.wait() @@ -136,7 +136,7 @@ def ag_local(x_shard): out = pl.pallas_call( functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh), out_shape=out_shape, - mosaic_params=dict(collective_id=0), + compiler_params=dict(mosaic=dict(collective_id=0)), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, scratch_shapes=( diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index c63929bb6355..f3b09c96486b 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -910,8 +910,8 @@ def run(): @pl.when(q_seq_index == q_seq_len // block_q_major - 1) def end_of_q_sequence(): - dv_tile_ref[0, 0, :, :] = dv_scratch_ref[...].astype(dv_tile_ref) - dk_tile_ref[0, 0, :, :] = dk_scratch_ref[...].astype(dk_tile_ref) + dv_tile_ref[0, 0, :, :] = dv_scratch_ref[...].astype(dv_tile_ref.dtype) + dk_tile_ref[0, 0, :, :] = dk_scratch_ref[...].astype(dk_tile_ref.dtype) def _flash_attention_bwd_dkv( @@ -1099,7 +1099,7 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, grid=grid, - in_specs=in_specs, # type: ignore + in_specs=in_specs, out_specs=out_specs, scratch_shapes=scratch_shapes, ), @@ -1266,7 +1266,7 @@ def zero_out_ds(): @pl.when(kv_seq_index == kv_seq_len // block_k_major - 1) def end_of_kv_sequence(): - dq_tile_ref[0, 0, :, :] = dq_scratch_ref[...].astype(dq_tile_ref) + dq_tile_ref[0, 0, :, :] = 
dq_scratch_ref[...].astype(dq_tile_ref.dtype) dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref) @@ -1444,8 +1444,8 @@ def kv_segment_ids_index_map( grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, grid=grid, - in_specs=in_specs, # type: ignore - out_specs=out_specs, # type: ignore + in_specs=in_specs, + out_specs=out_specs, scratch_shapes=scratch_shapes, ), out_shape=out_shapes, diff --git a/jax/experimental/array_api/_constants.py b/jax/experimental/pallas/ops/tpu/megablox/__init__.py similarity index 81% rename from jax/experimental/array_api/_constants.py rename to jax/experimental/pallas/ops/tpu/megablox/__init__.py index e6f0d542ae79..2c7391a18173 100644 --- a/jax/experimental/array_api/_constants.py +++ b/jax/experimental/pallas/ops/tpu/megablox/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The JAX Authors. +# Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np - -e = np.e -inf = np.inf -nan = np.nan -newaxis = np.newaxis -pi = np.pi +from jax.experimental.pallas.ops.tpu.megablox.ops import gmm diff --git a/jax/experimental/pallas/ops/tpu/megablox/common.py b/jax/experimental/pallas/ops/tpu/megablox/common.py new file mode 100644 index 000000000000..bd843cf46ca4 --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/megablox/common.py @@ -0,0 +1,63 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common utilities for GMM kernels.""" + +import re + +import jax +import jax.numpy as jnp + + +def is_tpu() -> bool: + return "TPU" in jax.devices()[0].device_kind + + +def tpu_kind() -> str: + """Query identification string for the currently attached TPU.""" + return jax.devices()[0].device_kind + + +_TPU_KIND_PATTERN = re.compile(r"TPU v(\d+)") + + +def tpu_generation() -> int: + """Generation number of the currently attached TPU.""" + if version := _TPU_KIND_PATTERN.match(tpu_kind()): + return int(version[1]) + raise NotImplementedError("only TPU devices are supported") + + +def supports_bfloat16_matmul() -> bool: + """Does the currently attached CPU support bfloat16 inputs?""" + return not is_tpu() or tpu_generation() >= 4 + + +def assert_is_supported_dtype(dtype: jnp.dtype) -> None: + if dtype != jnp.bfloat16 and dtype != jnp.float32: + raise ValueError(f"Expected bfloat16 or float32 array but got {dtype}.") + + +def select_input_dtype(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.dtype: + """A type to which both input should be adapted to before dot product.""" + # bf16xbf16 matmul is only supported since TPUv4 generation. In case of mixed + # input precision, we need to convert bf16 argument to fp32 beforehand. 
+ if ( + supports_bfloat16_matmul() + and lhs.dtype == jnp.bfloat16 + and rhs.dtype == jnp.bfloat16 + ): + return jnp.bfloat16 + else: + return jnp.float32 diff --git a/jax/experimental/pallas/ops/tpu/megablox/gmm.py b/jax/experimental/pallas/ops/tpu/megablox/gmm.py new file mode 100644 index 000000000000..ba8ca6c1b617 --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/megablox/gmm.py @@ -0,0 +1,799 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Grouped matrix multiplication kernels for TPU written in Pallas.""" + +from collections.abc import Callable +import functools +from typing import Any, Optional + +import jax +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu.megablox import common +import jax.numpy as jnp + + +partial = functools.partial + + +def _validate_args( + *, + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + expected_rhs_dims: int = 3, +) -> tuple[jnp.ndarray, jnp.ndarray, jnp.dtype]: + """Validates the arguments for the gmm function.""" + # Validate 'lhs'. + if lhs.ndim != 2: + raise ValueError(f"Expected 2-tensor for 'lhs' but got {lhs.ndim}-tensor.") + common.assert_is_supported_dtype(lhs.dtype) + + # Validate 'rhs'. + if rhs.ndim != expected_rhs_dims: + raise ValueError( + f"Expected {expected_rhs_dims}-tensor for 'rhs' but got" + f" {rhs.ndim}-tensor." + ) + common.assert_is_supported_dtype(rhs.dtype) + + # Validate 'group_sizes'. + if group_sizes.dtype != jnp.int32: + raise ValueError( + f"Expected 32-bit integer 'group_sizes' but got {group_sizes.dtype}." + ) + + return lhs, group_sizes, common.select_input_dtype(lhs, rhs) + + +def _calculate_num_tiles(x: int, tx: int) -> int: + tiles, rem = divmod(x, tx) + if rem: + raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).") + return tiles + + +def _calculate_irregular_num_tiles(x: int, tx: int) -> tuple[int, int]: + tiles, rem = divmod(x, tx) + if rem: + tiles += 1 + return tiles, rem + + +GroupMetadata = Any # TODO(enriqueps): Clean this up and use a namedtuple + + +def make_group_metadata( + *, + group_sizes: jnp.ndarray, + m: int, + tm: int, + start_group: jnp.ndarray, + num_nonzero_groups: int, + visit_empty_groups: bool = True, +) -> GroupMetadata: + """Create the metadata needed for grouped matmul computation. + + Args: + group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype. + m: The number of rows in lhs. + tm: The m-dimension tile size being used. + start_group: The group in group sizes to start computing from. This is + particularly useful for when rhs num_groups is sharded. + num_nonzero_groups: Number of groups in group sizes to compute on. Useful in + combination with group_offset. + visit_empty_groups: If True, do not squeeze tiles for empty groups out of + the metadata. This is necessary for tgmm, where we at least need to zero + the output for each group. 
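+
+  For example (illustrative values), group_sizes=[3, 5] with m=8 and tm=4
+  yields group_offsets=[0, 3, 8], group_ids=[0, 1, 1], m_tile_ids=[0, 0, 1]
+  and num_tiles=3: group 1 starts inside the first m-dimension tile, so that
+  tile is visited once for group 0 and once more for group 1.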
+
+  Returns:
+    tuple of:
+      group_offsets: A 1d, jnp.ndarray with shape [num_groups+1] and jnp.int32
+        dtype. group_offsets[i] indicates the row at which group [i] starts in
+        the lhs matrix and group_offsets[num_groups] = m.
+      group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and
+        jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will
+        work on.
+      m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and
+        jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i'
+        will work on.
+      num_tiles: The number of m-dimension tiles to execute.
+  """
+  num_groups = group_sizes.shape[0]
+  end_group = start_group + num_nonzero_groups - 1
+
+  # Calculate the offset of each group, starting at zero. This metadata is
+  # similar to row offsets in a CSR matrix. The following properties hold:
+  #
+  # group_offsets.shape = [num_groups + 1]
+  # group_offsets[0] = 0
+  # group_offsets[num_groups] = m
+  #
+  # The row at which group 'i' starts is group_offsets[i].
+  group_ends = jnp.cumsum(group_sizes)
+  group_offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), group_ends])
+
+  # Assign a group id to each grid index.
+  #
+  # If a group starts somewhere other than the start of a tile or ends somewhere
+  # other than the end of a tile we need to compute that full tile. Calculate
+  # the number of tiles for each group by rounding their end up to the nearest
+  # 'tm' and their start down to the nearest 'tm'.
+
+  # (1) Round the group_ends up to the nearest multiple of 'tm'.
+  #
+  # NOTE: This does not change group_offsets[num_groups], which is m
+  # (because we enforce m is divisible by tm).
+  rounded_group_ends = ((group_ends + tm - 1) // tm * tm).astype(jnp.int32)
+
+  # (2) Round the group_starts down to the nearest multiple of 'tm'.
+  group_starts = jnp.concatenate(
+      [jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]]
+  )
+  rounded_group_starts = group_starts // tm * tm
+
+  # (3) Calculate the number of rows in each group.
+  #
+  # NOTE: Handle zero-sized groups as a special case. If the start for a
+  # zero-sized group is not divisible by 'tm' its start will be rounded down and
+  # its end will be rounded up such that its size will become 1 tile here.
+  rounded_group_sizes = rounded_group_ends - rounded_group_starts
+  rounded_group_sizes = jnp.where(group_sizes == 0, 0, rounded_group_sizes)
+
+  # (4) Convert the group sizes from units of rows to units of 'tm' sized tiles.
+  #
+  # An m-dimension tile is 'owned' by group 'i' if the first row of the tile
+  # belongs to group 'i'. In addition to owned tiles, each group can have 0 or 1
+  # initial partial tiles if its first row does not occur in the first row of a
+  # tile. The '0-th' group never has a partial tile because it always starts at
+  # the 0-th row.
+  #
+  # If no group has a partial tile, the total number of tiles is equal to
+  # 'm // tm'. If every group has a partial tile except the 0-th group, the
+  # total number of tiles is equal to 'm // tm + num_groups - 1'. Thus we know
+  # that
+  #
+  # tiles_m <= group_tiles.sum() <= tiles_m + num_groups - 1
+  #
+  # Where tiles_m = m // tm.
+  #
+  # NOTE: All group sizes are divisible by 'tm' because of the rounding in steps
+  # (1) and (2) so this division is exact.
+  group_tiles = rounded_group_sizes // tm
+
+  if visit_empty_groups:
+    # Insert one tile for empty groups.
+    group_tiles = jnp.where(group_sizes == 0, 1, group_tiles)
+
+  # Create the group ids for each grid index based on the tile counts for each
+  # group.
+  #
+  # NOTE: This repeat(...)
will pad group_ids with the final group id if + # group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized + # such that we only execute the necessary number of tiles. + tiles_m = _calculate_num_tiles(m, tm) + group_ids = jnp.repeat( + jnp.arange(num_groups, dtype=jnp.int32), + group_tiles, + total_repeat_length=tiles_m + num_groups - 1, + ) + + # Assign an m-dimension tile id to each grid index. + # + # NOTE: Output tiles can only be re-visited consecutively. The following + # procedure guarantees that m-dimension tile indices respect this. + + # (1) Calculate how many times each m-dimension tile will be visited. + # + # Each tile is guaranteed to be visited once by the group that owns the tile. + # The remaining possible visits occur when a group starts inside of a tile at + # a position other than the first row. We can calculate which m-dimension tile + # each group starts in by floor-dividing its offset with `tm` and then count + # tile visits with a histogram. + # + # To avoid double counting tile visits from the group that owns the tile, + # filter these out by assigning their tile id to `tile_m` (one beyond the max) + # such that they're ignored by the subsequent histogram. Also filter out any + # group which is empty. + # + # TODO(tgale): Invert the 'partial_tile_mask' predicates to be more clear. + partial_tile_mask = jnp.logical_or( + (group_offsets[:-1] % tm) == 0, group_sizes == 0 + ) + + # Explicitly enable tiles for zero sized groups, if specified. This covers + # zero sized groups that start on a tile-aligned row and those that do not. + if visit_empty_groups: + partial_tile_mask = jnp.where(group_sizes == 0, 0, partial_tile_mask) + + partial_tile_ids = jnp.where( + partial_tile_mask, tiles_m, group_offsets[:-1] // tm + ) + + tile_visits = ( + jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + + 1 + ) + + # Create the m-dimension tile ids for each grid index based on the visit + # counts for each tile. + m_tile_ids = jnp.repeat( + jnp.arange(tiles_m, dtype=jnp.int32), + tile_visits.astype(jnp.int32), + total_repeat_length=tiles_m + num_groups - 1, + ) + + # Account for sharding. + # + # Find the start of the groups owned by our shard and shift the group_ids and + # m_tile_ids s.t. the metadata for our tiles are at the front of the arrays. + # + # TODO(tgale): Move this offset into the kernel to avoid these rolls. + first_tile_in_shard = (group_ids < start_group).sum() + group_ids = jnp.roll(group_ids, shift=-first_tile_in_shard, axis=0) + m_tile_ids = jnp.roll(m_tile_ids, shift=-first_tile_in_shard, axis=0) + + # Calculate the number of tiles we need to compute for our shard. + # + # Remove tile visits that belong to a group not in our shard. 
+ iota = jnp.arange(num_groups, dtype=jnp.int32) + active_group_mask = jnp.logical_and(iota <= end_group, iota >= start_group) + group_tiles = jnp.where(active_group_mask, group_tiles, 0) + num_tiles = group_tiles.sum() + return (group_offsets, group_ids, m_tile_ids), num_tiles + + +def _get_group_size( + *, grid_id: jnp.ndarray, group_metadata: GroupMetadata +) -> jnp.ndarray: + """Calculate the number of rows in the current group.""" + group_offsets, group_ids = group_metadata[:2] + group_id = group_ids[grid_id] + group_start = group_offsets[group_id] + group_end = group_offsets[group_id + 1] + return group_end - group_start + + +def _get_store_mask( + *, + grid_id: jnp.ndarray, + group_metadata: GroupMetadata, + tm: int, + tn: int, +) -> jnp.ndarray: + """Mask for rows that belong to the current group in the current tile.""" + group_offsets, group_ids, m_tile_ids = group_metadata[:3] + group_id = group_ids[grid_id] + group_start = group_offsets[group_id] + group_end = group_offsets[group_id + 1] + m_id = m_tile_ids[grid_id] * tm + iota = jax.lax.broadcasted_iota(jnp.int32, (tm, tn), 0) + m_id + return jnp.logical_and(iota >= group_start, iota < group_end) + + +def _zero_uninitialized_memory( + out: jnp.ndarray, + *, + start_group: jnp.ndarray, + num_nonzero_groups: int, + group_metadata: GroupMetadata, +) -> jnp.ndarray: + """Zero out uninitialized memory from output.""" + group_offsets = group_metadata[0] + group_start = group_offsets[start_group] + group_end = group_offsets[start_group + num_nonzero_groups] + valid_mask = jax.lax.broadcasted_iota(jnp.int32, (out.shape[0],), 0) + valid_mask = (valid_mask >= group_start) & (valid_mask < group_end) + return jnp.where(valid_mask[:, None], out, 0) + + +LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]] + + +@functools.partial( + jax.jit, + static_argnames=[ + "preferred_element_type", + "tiling", + "transpose_rhs", + "interpret", + ], +) +def gmm( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + preferred_element_type: jnp.dtype = jnp.float32, + tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128), + group_offset: jnp.ndarray | None = None, + existing_out: jnp.ndarray | None = None, + transpose_rhs: bool = False, + interpret: bool = False, +) -> jnp.ndarray: + """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'. + + Args: + lhs: A 2d, jnp.ndarray with shape [m, k]. + rhs: A 3d, jnp.ndarray with shape [num_groups, k, n]. + group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype. + preferred_element_type: jnp.dtype, the element type for the output matrix. + tiling: 3-tuple of ints. The m, k and n-dimension tile sizes. + group_offset: The group in group sizes to start computing from. This is + particularly useful for when rhs num_groups is sharded. + existing_out: Existing output to write to. + transpose_rhs: True if the rhs needs to be transposed. + interpret: Whether or not to run the kernel in interpret mode, helpful for + testing and debugging. + + Returns: + A 2d, jnp.ndarray with shape [m, n]. + """ + + if existing_out is not None: + assert isinstance(existing_out, jax.Array) + expected_dtype = existing_out.dtype + if expected_dtype != preferred_element_type: + raise ValueError( + "Existing output dtype must match preferred_element_type." + ) + if group_offset is None: + group_offset = jnp.array([0], dtype=jnp.int32) + else: + if group_offset.shape: + raise ValueError( + f"group_offset must be a ()-shaped array. Got: {group_offset.shape}." 
+ ) + group_offset = group_offset[None] + num_current_groups = rhs.shape[0] + num_total_groups = group_sizes.shape[0] + lhs, group_sizes, input_dtype = _validate_args( + lhs=lhs, rhs=rhs, group_sizes=group_sizes + ) + + # Gather shape information. + m, k, n = (lhs.shape[0], lhs.shape[1], rhs.shape[2]) + if transpose_rhs: + n = rhs.shape[1] + + # If tiling is callable, look up the problem dimensions in the LUT. If no tuned + # tile dimensions are available throw an error. + if callable(tiling): + tiling = tiling(m, k, n) + + if tiling is None: + raise ValueError(f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})") + + tm, tk, tn = tiling + tiles_k, k_rem = _calculate_irregular_num_tiles(k, tk) + tiles_n, n_rem = _calculate_irregular_num_tiles(n, tn) + del n_rem + + # Create the metadata we need for computation. + group_metadata, num_active_tiles = make_group_metadata( # pylint: disable=unbalanced-tuple-unpacking + group_sizes=group_sizes, + m=m, + tm=tm, + start_group=group_offset[0], + num_nonzero_groups=rhs.shape[0], + visit_empty_groups=False, + ) + + def kernel( + group_metadata, + group_offset, + lhs, + rhs, + existing_out, + out, + acc_scratch, + ): + group_offsets, group_ids, m_tile_ids = group_metadata + del group_offsets, group_ids, group_offset + + grid_id = pl.program_id(1) + k_i = pl.program_id(2) + + @pl.when(k_i == 0) + def _zero_acc(): + acc_scratch[...] = jnp.zeros_like(acc_scratch) + + if existing_out is not None: + prev_grid_id = jnp.where(grid_id > 0, grid_id - 1, 0) + is_first_processed_group = grid_id == 0 + m_tile_changed = m_tile_ids[grid_id] != m_tile_ids[prev_grid_id] + first_time_seeing_out = jnp.logical_or( + is_first_processed_group, m_tile_changed + ) + + @pl.when(first_time_seeing_out) + def _init_out(): + out[...] = existing_out[...] + + def mask_k_rem(x, *, dim): + if k_rem == 0: + return x + + orig_dtype = x.dtype + iota = lax.broadcasted_iota(jnp.int32, x.shape, dim) + x = x.astype(jnp.float32) + return jnp.where(iota < k_rem, x, 0).astype(orig_dtype) + + def _store_accum(): + mask = _get_store_mask( + grid_id=grid_id, + group_metadata=group_metadata, + tm=tm, + tn=tn, + ) + to_store = acc_scratch[...] + out[...] = jax.lax.select( + mask[...], to_store, out[...].astype(jnp.float32) + ).astype(preferred_element_type) + + def _accum(is_last_k_tile): + if is_last_k_tile: + mask_k_rem_lhs = partial(mask_k_rem, dim=1) + mask_k_rem_rhs = partial(mask_k_rem, dim=int(transpose_rhs)) + else: + mask_k_rem_lhs = lambda x: x + mask_k_rem_rhs = lambda x: x + + if transpose_rhs: + dot_general_dims = (((1,), (1,)), ((), ())) + else: + dot_general_dims = (((1,), (0,)), ((), ())) + + loaded_lhs = lhs[...] + loaded_rhs = rhs[...] + acc_scratch[...] += lax.dot_general( + mask_k_rem_lhs(loaded_lhs).astype(input_dtype), + mask_k_rem_rhs(loaded_rhs).astype(input_dtype), + preferred_element_type=jnp.float32, + dimension_numbers=dot_general_dims, + ) + + if is_last_k_tile: + _store_accum() + + lax.cond( + k_i == tiles_k - 1, + partial(_accum, True), + partial(_accum, False), + ) + + def lhs_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): + # lhs is (m, k). Load the [tm, tk] matrix for this m-tile. + group_offsets, group_ids, m_tile_ids = group_metadata + del n_i, group_offsets, group_ids, group_offset + return m_tile_ids[grid_id], k_i + + def rhs_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): + # rhs is (num_groups, k, n). Load the [tk, tn] matrix based on the group id + # for this m-tile. 
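+    # When transpose_rhs is set, rhs is laid out as (num_groups, n, k) and the
+    # block shape is (None, tn, tk), so the n and k tile indices are swapped
+    # below before indexing.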
+ group_offsets, group_ids, m_tile_ids = group_metadata + del group_offsets, m_tile_ids + if transpose_rhs: + k_i, n_i = n_i, k_i + + # NOTE: If we're working on only a shard of the rhs we need to adjust the + # group index we load from to account for this. The group_ids are in the + # "unsharded" domain. + return group_ids[grid_id] - group_offset[0], k_i, n_i + + def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): + # out is (m, n). Load the [tm, tn] matrix for this m-tile. + group_offsets, group_ids, m_tile_ids = group_metadata + del k_i, group_offsets, group_ids, group_offset + return m_tile_ids[grid_id], n_i + + out_block_spec = pl.BlockSpec(out_transform_indices, (tm, tn)) + if existing_out is None: + in_out_block_spec: Any = None + input_output_aliases = {} + else: + in_out_block_spec = out_block_spec + input_output_aliases = {6: 0} + + lhs_block_spec = pl.BlockSpec(lhs_transform_indices, (tm, tk)) + if transpose_rhs: + rhs_block_spec = pl.BlockSpec(rhs_transform_indices, (None, tn, tk)) + else: + rhs_block_spec = pl.BlockSpec(rhs_transform_indices, (None, tk, tn)) + + lhs_bytes = lhs.size * lhs.itemsize + rhs_bytes = (k * n) * rhs.itemsize # We don't read all of rhs + out_bytes = (m * n) * jnp.dtype(preferred_element_type).itemsize + max_active_tiles = group_metadata[1].size + bytes_accessed = ( + (lhs_bytes * tiles_n) + (rhs_bytes * max_active_tiles) + out_bytes + ) + flops = 2 * m * k * n + cost_estimate = pltpu.CostEstimate( + flops=flops, bytes_accessed=bytes_accessed, transcendentals=0 + ) + call_gmm = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=2, + in_specs=[ + lhs_block_spec, + rhs_block_spec, + in_out_block_spec, + ], + out_specs=out_block_spec, + grid=(tiles_n, num_active_tiles, tiles_k), + scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)], + ), + input_output_aliases=input_output_aliases, + compiler_params=dict( + mosaic=dict( + dimension_semantics=("parallel", "arbitrary", "arbitrary"), + cost_estimate=cost_estimate, + ) + ), + interpret=interpret, + ) + + out = call_gmm( + group_metadata, + group_offset, + lhs, + rhs, + existing_out, + ) + if existing_out is None and num_current_groups < num_total_groups: + out = _zero_uninitialized_memory( + out, + start_group=group_offset[0], + num_nonzero_groups=rhs.shape[0], + group_metadata=group_metadata, + ) + return out + + +@functools.partial( + jax.jit, + static_argnames=[ + "preferred_element_type", + "tiling", + "num_actual_groups", + "interpret", + ], +) +def tgmm( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + preferred_element_type: jnp.dtype = jnp.float32, + tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128), + group_offset: jnp.ndarray | None = None, + num_actual_groups: int | None = None, + existing_out: jnp.ndarray | None = None, + interpret: bool = False, +) -> jnp.ndarray: + """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :]. + + Args: + lhs: A 2d, jnp.ndarray with shape [k, m]. + rhs: A 2d, jnp.ndarray with shape [m, n]. + group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype. + preferred_element_type: jnp.dtype, the element type for the output matrix. + tiling: 3-tuple of ints. The m, k and n-dimension tile sizes. + group_offset: The group in group sizes to start computing from. This is + particularly useful for when rhs num_groups is sharded. 
+ num_actual_groups: For when num_groups is sharded and we should only compute + the groups that are local, starting from group_offset. + existing_out: Existing output to write to. + interpret: Whether or not to run the kernel in interpret mode, helpful for + testing and debugging. + + Returns: + A 3d, jnp.ndarray with shape [num_groups, k, n]. + """ + if group_offset is None: + group_offset = jnp.array([0], dtype=jnp.int32) + else: + group_offset = group_offset[None] + lhs, group_sizes, input_dtype = _validate_args( + lhs=lhs, rhs=rhs, group_sizes=group_sizes, expected_rhs_dims=2 + ) + + # Gather shape information. + k, m, n = (lhs.shape[0], lhs.shape[1], rhs.shape[1]) + num_groups = group_sizes.shape[0] + num_actual_groups = ( + num_actual_groups if num_actual_groups is not None else num_groups + ) + + # If tiling is callable, look up the problem dimensions in the LUT. If no tuned + # tile dimensions are available throw an error. + if callable(tiling): + tiling = tiling(m, k, n) + + if tiling is None: + raise ValueError(f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})") + + tm, tk, tn = tiling + tiles_k, k_rem = _calculate_irregular_num_tiles(k, tk) + del k_rem + tiles_n, n_rem = _calculate_irregular_num_tiles(n, tn) + del n_rem + + # Create the metadata we need for computation. + group_metadata, num_active_tiles = make_group_metadata( + group_sizes=group_sizes, + m=m, + tm=tm, + start_group=group_offset[0], + num_nonzero_groups=num_actual_groups, + visit_empty_groups=True, + ) + + def kernel( + group_metadata, + group_offset, + lhs, + rhs, + existing_out, + out, + acc_scratch, + ): + grid_id = pl.program_id(2) + group_offsets, group_ids, m_tile_ids = group_metadata + del group_offsets, group_offset, m_tile_ids + + group = group_ids[grid_id] + prev_grid_id = jnp.where(grid_id > 0, grid_id - 1, 0) + prev_group = group_ids[prev_grid_id] + + group_has_changed = jnp.logical_or(grid_id == 0, prev_group != group) + + @pl.when(group_has_changed) + def _zero_acc(): + acc_scratch[...] = jnp.zeros_like(acc_scratch) + + # We'll only do computation if our group has a nonzero number of rows in it. + dont_skip = ( + _get_group_size(grid_id=grid_id, group_metadata=group_metadata) > 0 + ) + + @pl.when(dont_skip) + def _do(): + rhs_mask = _get_store_mask( + grid_id=grid_id, + group_metadata=group_metadata, + tm=tm, + tn=tn, + ) + lhs_mask = _get_store_mask( + grid_id=grid_id, + group_metadata=group_metadata, + tm=tm, + tn=tk, + ) + + loaded_lhs = lhs[...] + loaded_rhs = rhs[...] + loaded_lhs = lax.select( + lhs_mask[...], + loaded_lhs.astype(jnp.float32), + jnp.zeros_like(lhs, jnp.float32), + ).swapaxes(0, 1) + loaded_rhs = lax.select( + rhs_mask[...], + loaded_rhs.astype(jnp.float32), + jnp.zeros_like(rhs, jnp.float32), + ) + + acc_scratch[...] += lax.dot( + loaded_lhs.astype(input_dtype), + loaded_rhs.astype(input_dtype), + preferred_element_type=jnp.float32, + ) + + is_end_of_grid = grid_id == (pl.num_programs(2) - 1) + next_grid_id = jnp.where(is_end_of_grid, grid_id, grid_id + 1) + next_group = group_ids[next_grid_id] + + group_is_changing = jnp.logical_or(is_end_of_grid, group != next_group) + + @pl.when(group_is_changing) + def _store_accum(): + to_store = acc_scratch[...] + if existing_out is not None: + to_store += existing_out[...].astype(jnp.float32) + out[...] = to_store.astype(preferred_element_type) + + def lhs_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset): + # lhs is (m, k). Load the [tm, tk] matrix for this m-tile. 
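+    # (The caller swaps lhs from its documented (k, m) layout to (m, k) just
+    # before the pallas_call.)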
+ group_offsets, group_ids, m_tile_ids = group_metadata + del n_i, group_offsets, group_ids, group_offset + return m_tile_ids[grid_id], k_i + + def rhs_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset): + # rhs is (m, n). Load the [tm, tn] matrix for this m-tile. + group_offsets, group_ids, m_tile_ids = group_metadata + del k_i, group_offsets, group_ids, group_offset + return m_tile_ids[grid_id], n_i + + def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset): + # out is (num_groups, k, n). Load the [tk, tn] matrix based on the group id + # for this m-tile. + group_offsets, group_ids, m_tile_ids = group_metadata + del group_offsets, m_tile_ids + + # NOTE: If we're working on only a shard of the output we need to adjust the + # group index we load from to account for this. The group_ids are in the + # "unsharded" domain. + return group_ids[grid_id] - group_offset[0], k_i, n_i + + out_block_spec = pl.BlockSpec(out_transform_indices, (None, tk, tn)) + if existing_out is None: + in_out_block_spec: Any = None + input_output_aliases = {} + else: + in_out_block_spec = out_block_spec + input_output_aliases = {6: 0} + + lhs_block_spec = pl.BlockSpec(lhs_transform_indices, (tm, tk)) + rhs_block_spec = pl.BlockSpec(rhs_transform_indices, (tm, tn)) + + lhs_bytes = lhs.size * lhs.itemsize + rhs_bytes = rhs.size * rhs.itemsize + out_bytewidth = jnp.dtype(preferred_element_type).itemsize + out_bytes = (num_actual_groups * k * n) * out_bytewidth + bytes_accessed = ( + (lhs_bytes * tiles_n) + (rhs_bytes * tiles_k) + out_bytes + ) + flops = 2 * m * k * n + cost_estimate = pltpu.CostEstimate( + flops=flops, bytes_accessed=bytes_accessed, transcendentals=0 + ) + lhs = lhs.swapaxes(0, 1) + call_gmm = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct( + (num_actual_groups, k, n), preferred_element_type + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=2, + in_specs=[ + lhs_block_spec, + rhs_block_spec, + in_out_block_spec, + ], + out_specs=out_block_spec, + grid=(tiles_n, tiles_k, num_active_tiles), + scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)], + ), + input_output_aliases=input_output_aliases, + compiler_params=dict( + mosaic=dict( + dimension_semantics=("parallel", "arbitrary", "arbitrary"), + cost_estimate=cost_estimate, + ) + ), + interpret=interpret, + ) + + out = call_gmm( + group_metadata, + group_offset, + lhs, + rhs, + existing_out, + ) + return out diff --git a/jax/experimental/pallas/ops/tpu/megablox/ops.py b/jax/experimental/pallas/ops/tpu/megablox/ops.py new file mode 100644 index 000000000000..015c6b3ade67 --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/megablox/ops.py @@ -0,0 +1,109 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
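The `tgmm` kernel above implements the grouped transposed matmul described in its docstring. As a reading aid, here is a plain-jnp reference of those semantics that loops over groups on the host; `reference_tgmm` is a hypothetical name used only for illustration, and unlike the Pallas kernel it needs concrete (non-traced) group sizes, so it is not jittable.

```python
import jax.numpy as jnp
import numpy as np

def reference_tgmm(lhs, rhs, group_sizes, preferred_element_type=jnp.float32):
  """lhs: [k, m], rhs: [m, n], group_sizes: [num_groups] -> [num_groups, k, n]."""
  bounds = np.concatenate([[0], np.cumsum(np.asarray(group_sizes))])
  outs = []
  for g in range(len(group_sizes)):
    lhs_g = lhs[:, bounds[g]:bounds[g + 1]]   # [k, rows_g]
    rhs_g = rhs[bounds[g]:bounds[g + 1], :]   # [rows_g, n]
    outs.append((lhs_g @ rhs_g).astype(preferred_element_type))
  return jnp.stack(outs)                      # empty groups contribute zeros
```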
+ +"""Grouped matrix multiplication operations with custom VJPs.""" + +import jax +from jax.experimental.pallas.ops.tpu.megablox import gmm as backend +import jax.numpy as jnp + + +gmm = jax.custom_vjp( + backend.gmm, + nondiff_argnums=(3, 4, 7, 8), +) + + +def _gmm_fwd( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + preferred_element_type: jnp.dtype = jnp.float32, + tiling: tuple[int, int, int] = (128, 128, 128), + group_offset: jnp.ndarray | None = None, + existing_out: jnp.ndarray | None = None, + transpose_rhs: bool = False, + interpret: bool = False, +) -> tuple[ + jnp.ndarray, + tuple[ + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray | None, + int, + ], +]: + """Forward function for GMM VJP.""" + out = backend.gmm( + lhs, + rhs, + group_sizes, + preferred_element_type, + tiling, + group_offset, + existing_out, + transpose_rhs=transpose_rhs, + interpret=interpret, + ) + return out, (lhs, rhs, group_sizes, group_offset, rhs.shape[0]) + + +def _gmm_bwd( + preferred_element_type: jnp.dtype, + tiling: tuple[int, int, int], + transpose_rhs: bool, + interpret: bool, + residual: tuple[ + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray | None, + int, + ], + grad: jnp.ndarray, +) -> tuple[jnp.ndarray, jnp.ndarray, None, None, jnp.ndarray]: + """Backward function for throughput GMM VJP.""" + del preferred_element_type + lhs, rhs, group_sizes, group_offset, num_actual_groups = residual + grad_lhs = backend.gmm( + grad, + rhs, + group_sizes, + lhs[0].dtype, + tiling, + group_offset, + transpose_rhs=not transpose_rhs, + interpret=interpret, + ) + grad_rhs = backend.tgmm( + lhs.swapaxes(0, 1), + grad, + group_sizes, + rhs.dtype, + tiling, + group_offset, + num_actual_groups, + interpret=interpret, + ) + + # NOTE: If the rhs transposition is fused into the forward pass we need to + # return the transpose of the rhs gradient that we calculated above. + # + # TODO(tgale, enriqueps, apaske): Fuse this transposition into the tgmm. + grad_rhs = grad_rhs.swapaxes(1, 2) if transpose_rhs else grad_rhs + return grad_lhs, grad_rhs, None, None, grad + + +gmm.defvjp(_gmm_fwd, _gmm_bwd) diff --git a/jax/experimental/array_api/_indexing_functions.py b/jax/experimental/pallas/ops/tpu/paged_attention/__init__.py similarity index 85% rename from jax/experimental/array_api/_indexing_functions.py rename to jax/experimental/pallas/ops/tpu/paged_attention/__init__.py index 261c81b20351..1cce79926e8a 100644 --- a/jax/experimental/array_api/_indexing_functions.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/__init__.py @@ -12,7 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -import jax - -def take(x, indices, /, *, axis): - return jax.numpy.take(x, indices, axis=axis) +from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py new file mode 100644 index 000000000000..979c0e7e1de9 --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -0,0 +1,649 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PagedAttention TPU kernel.""" + +import functools + +import jax +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils +import jax.numpy as jnp +import numpy as np + + +DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + + +class MultiPageAsyncCopyDescriptor: + """Descriptor for async copy of multiple K/V pages from HBM.""" + + def __init__( + self, + pages_hbm_ref, + scales_pages_hbm_ref, + vmem_buffer, + scales_vmem_buffer, + sem, + page_indices, + page_indices_start_offset, + num_pages_to_load, + head_index, + ): + self._vmem_buffer = vmem_buffer + self._scales_vmem_buffer = scales_vmem_buffer + self._num_pages_to_load = num_pages_to_load + if head_index is not None: + self._pages_hbm_ref = pages_hbm_ref.at[head_index] + if scales_pages_hbm_ref is not None: + self._scales_pages_hbm_ref = scales_pages_hbm_ref.at[head_index] + else: + self._scales_pages_hbm_ref = None + else: + self._pages_hbm_ref = pages_hbm_ref + self._scales_pages_hbm_ref = scales_pages_hbm_ref + self._sem = sem + self._page_indices = page_indices + self._page_indices_start_offset = page_indices_start_offset + self._async_copies = [ + self._make_async_copy(i) for i in range(self._num_pages_to_load) + ] + if ( + self._scales_pages_hbm_ref is not None + and self._scales_vmem_buffer is not None + ): + self._async_copies += [ + self._make_scales_async_copy(i) + for i in range(self._num_pages_to_load) + ] + + def _make_async_copy(self, i): + page_index = self._page_indices[self._page_indices_start_offset + i] + return pltpu.make_async_copy( + self._pages_hbm_ref.at[page_index], self._vmem_buffer.at[i], self._sem + ) + + def _make_scales_async_copy(self, i): + page_index = self._page_indices[self._page_indices_start_offset + i] + return pltpu.make_async_copy( + self._scales_pages_hbm_ref.at[page_index], # pytype: disable=attribute-error + self._scales_vmem_buffer.at[i], # pytype: disable=attribute-error + self._sem, + ) + + def start(self): + """Starts the async copies.""" + for async_copy in self._async_copies: + async_copy.start() + + def _maybe_dequantize(self, x, x_scale, dtype=jnp.bfloat16): + if x_scale is None: + return x.astype(dtype) + return quantization_utils.from_int8(x, x_scale, dtype=dtype) + + def wait_and_get_loaded(self) -> jax.Array: + """Wait async copies and gets the loaded buffer as a jax.Array.""" + for async_copy in self._async_copies: + async_copy.wait() + head_dim = self._vmem_buffer.shape[-1] + jax_array = self._vmem_buffer[...].astype(jnp.float32) + if self._scales_vmem_buffer is not None: + scales_jax_array = self._scales_vmem_buffer[...].astype(jnp.float32) + else: + scales_jax_array = None + jax_array = self._maybe_dequantize(jax_array, scales_jax_array) + return jax_array.reshape(-1, head_dim) + + +def paged_flash_attention_kernel( + lengths_ref, + page_indices_ref, + buffer_index_ref, + step_ref, + q_ref, + k_pages_hbm_ref, + k_scales_pages_hbm_ref, + v_pages_hbm_ref, + v_scales_pages_hbm_ref, + o_ref, + m_ref, + l_ref, + 
k_vmem_buffer, + k_scales_vmem_buffer, + v_vmem_buffer, + v_scales_vmem_buffer, + sem, + *, + batch_size: int, + pages_per_compute_block: int, + pages_per_sequence: int, + mask_value: float, + megacore_mode: str, + program_ids=(), +): + """Pallas kernel for paged attention.""" + if program_ids: + core_index, b, h, i = program_ids + else: + core_index, b, h, i = ( + pl.program_id(0), + pl.program_id(1), + pl.program_id(2), + pl.program_id(3), + ) + num_kv_heads, _, page_size, _ = k_pages_hbm_ref.shape + bk = page_size * pages_per_compute_block + num_cores = pl.num_programs(0) + + b_step = num_cores if megacore_mode == "batch" else 1 + b_start = core_index if megacore_mode == "batch" else 0 + h_step = num_cores if megacore_mode == "kv_head" else 1 + h_start = core_index if megacore_mode == "kv_head" else 0 + + h = h * h_step + h_start + b = b * b_step + b_start + length = lengths_ref[b] + + def compute_block_indices(b, h, i): + + def advance_b(): + next_b = b + b_step + + def advance_to_next_non_zero_length(): + next_next_b = next_b + b_step + return lax.fori_loop( + lax.div(next_next_b, b_step), + lax.div(batch_size, b_step), + lambda _, b: jnp.where(lengths_ref[b] == 0, b + b_step, b), + next_next_b, + ) + + return ( + lax.cond( + jnp.logical_and(next_b < batch_size, lengths_ref[next_b] == 0), + advance_to_next_non_zero_length, + lambda: next_b, + ), + h_start, + 0, + ) + + def advance_h(): + next_h = h + h_step + return lax.cond(next_h < num_kv_heads, lambda: (b, next_h, 0), advance_b) + + return lax.cond(i * bk < lengths_ref[b], lambda: (b, h, i), advance_h) + + def create_kv_async_copy_descriptors(b, h, i, buffer_index): + page_offset = b * pages_per_sequence + i * pages_per_compute_block + pages_to_load = pages_per_compute_block + async_copy_k = MultiPageAsyncCopyDescriptor( + k_pages_hbm_ref, + k_scales_pages_hbm_ref, + k_vmem_buffer.at[buffer_index], + k_scales_vmem_buffer.at[buffer_index] + if k_scales_vmem_buffer is not None + else None, + sem, + page_indices_ref, + page_offset, + pages_to_load, + h, + ) + async_copy_v = MultiPageAsyncCopyDescriptor( + v_pages_hbm_ref, + v_scales_pages_hbm_ref, + v_vmem_buffer.at[buffer_index], + v_scales_vmem_buffer.at[buffer_index] + if v_scales_vmem_buffer is not None + else None, + sem, + page_indices_ref, + page_offset, + pages_to_load, + h, + ) + return async_copy_k, async_copy_v + + @pl.when(i * bk < length) + def flash_attention(): + step = step_ref[0] + buffer_index = buffer_index_ref[0] + + @pl.when(i == 0) + def init(): + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] 
= jnp.zeros_like(o_ref) + + @pl.when(step == 0) + def prefetch_first_block(): + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + b, h, i, buffer_index + ) + async_copy_k.start() + async_copy_v.start() + + next_b, next_h, next_i = compute_block_indices(b, h, i + 1) + + @pl.when(next_b < batch_size) + def prefetch_next_block(): + next_buffer_index = jnp.where(buffer_index == 0, 1, 0) + async_copy_next_k, async_copy_next_v = create_kv_async_copy_descriptors( + next_b, next_h, next_i, next_buffer_index + ) + async_copy_next_k.start() + async_copy_next_v.start() + buffer_index_ref[0] = next_buffer_index + + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + b, h, i, buffer_index + ) + q = q_ref[...].astype(jnp.float32) + k = async_copy_k.wait_and_get_loaded() + qk = jnp.einsum('hd,td->ht', q, k, preferred_element_type=jnp.float32) + mask = i * bk + jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) < length + qk = qk + jnp.where(mask, 0.0, mask_value) + m_curr = qk.max(axis=-1) + + s_curr = jnp.exp(qk - m_curr[..., None]) + m_prev, l_prev = m_ref[...], l_ref[...] + l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) + m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) + m_next = jnp.maximum(m_prev, m_curr) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + + v = async_copy_v.wait_and_get_loaded() + o_curr_times_l_curr = jnp.dot(s_curr, v) + + m_ref[...], l_ref[...] = m_next, l_next_safe + o_ref[...] = ( + (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe + ).astype(o_ref.dtype) + + step_ref[0] = step + 1 + + +def paged_flash_attention_kernel_inline_seq_dim( + lengths_ref, + page_indices_ref, + buffer_index_ref, + step_ref, + q_ref, + k_pages_hbm_ref, + k_scales_pages_hbm_ref, + v_pages_hbm_ref, + v_scales_pages_hbm_ref, + o_ref, + m_ref, + l_ref, + k_vmem_buffer, + k_scales_vmem_buffer, + v_vmem_buffer, + v_scales_vmem_buffer, + sem, + *, + batch_size: int, + pages_per_compute_block: int, + pages_per_sequence: int, + mask_value: float, + megacore_mode: str, +): + core_index, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2) + + # Initialize the output HBM buffers to avoid accessing garbage memory inside + # the kernel body below. + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] 
= jnp.zeros_like(o_ref) + + def body(i, _): + paged_flash_attention_kernel( + lengths_ref, + page_indices_ref, + buffer_index_ref, + step_ref, + q_ref, + k_pages_hbm_ref, + k_scales_pages_hbm_ref, + v_pages_hbm_ref, + v_scales_pages_hbm_ref, + o_ref, + m_ref, + l_ref, + k_vmem_buffer, + k_scales_vmem_buffer, + v_vmem_buffer, + v_scales_vmem_buffer, + sem, + batch_size=batch_size, + pages_per_compute_block=pages_per_compute_block, + pages_per_sequence=pages_per_sequence, + mask_value=mask_value, + megacore_mode=megacore_mode, + program_ids=(core_index, b, h, i), + ) + return () + + bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2] + + if megacore_mode == "batch": + num_cores = pl.num_programs(0) + length = lengths_ref[b * num_cores + core_index] + else: + length = lengths_ref[b] + + lax.fori_loop(0, lax.div(length + bk - 1, bk), body, ()) + + +@functools.partial( + jax.jit, + static_argnames=[ + "pages_per_compute_block", + "mask_value", + "megacore_mode", + "inline_seq_dim", + ], +) +def paged_attention( + q: jax.Array, + k_pages: jax.Array | quantization_utils.QuantizedTensor, + v_pages: jax.Array | quantization_utils.QuantizedTensor, + lengths: jax.Array, + page_indices: jax.Array, + *, + mask_value: float = DEFAULT_MASK_VALUE, + pages_per_compute_block: int, + megacore_mode: str | None = None, + inline_seq_dim: bool = True, +) -> jax.Array: + """Paged grouped query attention. + + Args: + q: A [batch_size, num_heads, head_dim] jax.Array. + k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. + v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. + lengths: A i32[batch_size] jax.Array the length of each example. + page_indices: A i32[batch_size, pages_per_sequence] jax.Array. Each entry + should be in the range of [0, total_num_pages), indicating where to locate + the page in `k_pages` or `v_pages`. + mask_value: The value used for padding in attention. By default it is a very + negative floating point number. + pages_per_compute_block: how many pages to be processed in one flash + attention block in the pallas kernel. + megacore_mode: if set, enable megacore to parallelize the computation. Must + be one of ['kv_head', 'batch', None]. Caveat: set this only if megacore is + enabled, otherwise the kernel may hang. If you are not sure, leave it to + None. + * None: disable megacore parallelism. + * kv_head: megacore parallelism on KV heads; requires number of KV heads + divisible by 2. + * batch: megacore parallelism on batch dimension; requires batch divisible + by 2. + inline_seq_dim: whether to fuse kernel instances along the sequence dim into + one kernel. + + Returns: + The output of attention([batch_size, num_heads, head_dim]). + """ + if isinstance(k_pages, quantization_utils.QuantizedTensor): + k_pages, k_scales_pages = k_pages.weight, k_pages.scales + assert isinstance(k_scales_pages, jax.Array) # For typing. + k_scales_pages = jnp.broadcast_to( + k_scales_pages, (*k_scales_pages.shape[:-1], k_pages.shape[-1]) + ) + else: + k_scales_pages = None + if isinstance(v_pages, quantization_utils.QuantizedTensor): + v_pages, v_scales_pages = v_pages.weight, v_pages.scales + assert isinstance(v_scales_pages, jax.Array) # For typing. 
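The running `(m, l, o)` update in `flash_attention` above is the standard streaming-softmax recurrence: each KV block rescales the previous accumulator by `alpha = exp(m_prev - m_next)` and the new block by `beta = exp(m_curr - m_next)`. Below is a plain-jnp sketch of that recurrence (a hypothetical helper with a single query block, no masking or paging) that matches a full softmax up to rounding.

```python
import jax
import jax.numpy as jnp

def online_softmax_attention(q, k, v, block=128):
  """q: [nq, d], k/v: [t, d] -> [nq, d]; processes k/v one block at a time."""
  nq, t = q.shape[0], k.shape[0]
  m = jnp.full((nq, 1), -jnp.inf)
  l = jnp.zeros((nq, 1))
  o = jnp.zeros((nq, v.shape[-1]))
  for start in range(0, t, block):
    kb, vb = k[start:start + block], v[start:start + block]
    qk = q @ kb.T                                    # [nq, block]
    m_curr = qk.max(axis=-1, keepdims=True)
    s = jnp.exp(qk - m_curr)
    l_curr = s.sum(axis=-1, keepdims=True)
    m_next = jnp.maximum(m, m_curr)
    alpha, beta = jnp.exp(m - m_next), jnp.exp(m_curr - m_next)
    l_next = alpha * l + beta * l_curr
    o = (l * alpha * o + beta * (s @ vb)) / l_next   # same update as the kernel
    m, l = m_next, l_next
  return o

q = jax.random.normal(jax.random.PRNGKey(0), (4, 64))
k = jax.random.normal(jax.random.PRNGKey(1), (512, 64))
v = jax.random.normal(jax.random.PRNGKey(2), (512, 64))
expected = jax.nn.softmax(q @ k.T, axis=-1) @ v
assert jnp.allclose(online_softmax_attention(q, k, v), expected, atol=1e-4)
```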
+ v_scales_pages = jnp.broadcast_to( + v_scales_pages, (*v_scales_pages.shape[:-1], v_pages.shape[-1]) + ) + else: + v_scales_pages = None + + batch_size, num_heads, head_dim = q.shape + num_kv_heads, _, page_size, head_dim_k = k_pages.shape + batch_size_paged_indices, pages_per_sequence = page_indices.shape + + if k_pages.shape != v_pages.shape: + raise ValueError( + f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and" + f" {v_pages.shape}" # pytype: disable=attribute-error + ) + if num_heads % num_kv_heads != 0: + raise ValueError( + "Number of Q heads must be divisible by number of KV heads. Got" + f" {num_heads} and {num_kv_heads}." + ) + if head_dim_k != head_dim: + raise ValueError( + "head_dim of Q must be the same as that of K/V. Got" + f" {head_dim} and {head_dim_k}." + ) + if pages_per_sequence % pages_per_compute_block != 0: + raise ValueError( + "pages_per_compute_block must be divisible by pages per sequence. Got" + f" {pages_per_compute_block} and {pages_per_sequence}." + ) + if lengths.shape != (batch_size,): + raise ValueError("`lengths` and `q` must have the same batch size") + if batch_size_paged_indices != batch_size: + raise ValueError("`page_indices` and `q` must have the same batch size") + if lengths.dtype != jnp.int32: + raise ValueError( + "The dtype of `lengths` must be int32. Got {lengths.dtype}" + ) + + # TODO(dinghua): get the actual cores per chip once there's an official API. + if megacore_mode == "kv_head": + if num_kv_heads % 2 != 0: + raise ValueError( + "number of KV heads must be even when megacore_mode is 'kv_head'" + ) + num_cores = 2 + elif megacore_mode == "batch": + if batch_size % 2 != 0: + raise ValueError("batch size must be even when megacore_mode is 'batch'") + num_cores = 2 + elif megacore_mode is None: + num_cores = 1 + else: + raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]") + + if (num_heads // num_kv_heads) % 8 != 0: + # Reshape q to hint XLA to pick a <1x128> layout otherwise it will pick a + # <8x128> layout for a <1x128> memref inside the kernel and error out. 
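To make the shape contract of `paged_attention` concrete, here is a hypothetical call with dummy data that satisfies the checks performed in this function (head counts, page counts, int32 lengths). It assumes the module added by this patch is importable and a TPU backend is available; all sizes are arbitrary.

```python
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention

batch, num_heads, num_kv_heads, head_dim = 2, 16, 2, 128
page_size, pages_per_sequence, pages_per_compute_block = 16, 8, 4
total_num_pages = batch * pages_per_sequence

q = jnp.zeros((batch, num_heads, head_dim), jnp.float32)
k_pages = jnp.zeros((num_kv_heads, total_num_pages, page_size, head_dim), jnp.float32)
v_pages = jnp.zeros_like(k_pages)
lengths = jnp.array([48, 100], dtype=jnp.int32)   # tokens per sequence, <= 8 * 16
page_indices = jnp.arange(total_num_pages, dtype=jnp.int32).reshape(
    batch, pages_per_sequence)

out = paged_attention(
    q, k_pages, v_pages, lengths, page_indices,
    pages_per_compute_block=pages_per_compute_block)
assert out.shape == (batch, num_heads, head_dim)
```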
+ q = q.reshape(batch_size, num_heads, 1, head_dim) + if megacore_mode == "kv_head": + q_block_spec = pl.BlockSpec( + (None, num_heads // num_kv_heads, None, head_dim), + lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0, 0), + ) + elif megacore_mode == "batch": + q_block_spec = pl.BlockSpec( + (None, num_heads // num_kv_heads, None, head_dim), + lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0, 0), + ) + else: + q_block_spec = pl.BlockSpec( + (None, num_heads // num_kv_heads, None, head_dim), + lambda core_index, b, h, *_: (b, h, 0, 0), + ) + q_dtype_for_kernel_launch = jnp.float32 + else: + if megacore_mode == "kv_head": + q_block_spec = pl.BlockSpec( + (None, num_heads // num_kv_heads, head_dim), + lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0), + ) + elif megacore_mode == "batch": + q_block_spec = pl.BlockSpec( + (None, num_heads // num_kv_heads, head_dim), + lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0), + ) + else: + q_block_spec = pl.BlockSpec( + (None, num_heads // num_kv_heads, head_dim), + lambda core_index, b, h, *_: (b, h, 0), + ) + q_dtype_for_kernel_launch = q.dtype + + if inline_seq_dim: + kernel = paged_flash_attention_kernel_inline_seq_dim + grid = ( + num_cores, + batch_size // num_cores if megacore_mode == "batch" else batch_size, + num_kv_heads // num_cores + if megacore_mode == "kv_head" + else num_kv_heads, + ) + dimension_sematics = ("parallel", "arbitrary", "arbitrary") + else: + kernel = paged_flash_attention_kernel + grid = ( + num_cores, + batch_size // num_cores if megacore_mode == "batch" else batch_size, + num_kv_heads // num_cores + if megacore_mode == "kv_head" + else num_kv_heads, + pages_per_sequence // pages_per_compute_block, + ) # type: ignore + dimension_sematics = ("parallel", "arbitrary", "arbitrary", "arbitrary") # type: ignore + + if k_scales_pages is not None and v_scales_pages is not None: + in_specs = [ + q_block_spec, + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ] + scratch_shapes = ( + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + pages_per_compute_block, + page_size, + head_dim, + ), + k_pages.dtype, + ), # k_pages buffer + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + pages_per_compute_block, + page_size, + head_dim, + ), + k_scales_pages.dtype, # pytype: disable=attribute-error + ), # k_scales_pages buffer + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + pages_per_compute_block, + page_size, + head_dim, + ), + v_pages.dtype, + ), # v_pages buffer + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + pages_per_compute_block, + page_size, + head_dim, + ), + v_scales_pages.dtype, # pytype: disable=attribute-error + ), # v_scales_pages buffer + pltpu.SemaphoreType.DMA, + ) + else: + in_specs = [ + q_block_spec, + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + None, # type: ignore[list-item] + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + None, # type: ignore[list-item] + ] + scratch_shapes = ( + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + pages_per_compute_block, + page_size, + head_dim, + ), + k_pages.dtype, + ), # k_pages buffer + None, + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. 
+ pages_per_compute_block, + page_size, + head_dim, + ), + v_pages.dtype, + ), # v_pages buffer + None, + pltpu.SemaphoreType.DMA, + ) + + out, _, _ = pl.pallas_call( + functools.partial( + kernel, + pages_per_sequence=pages_per_sequence, + batch_size=batch_size, + pages_per_compute_block=pages_per_compute_block, + mask_value=mask_value, + megacore_mode=megacore_mode, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + # There are 4 scalars prefetched per kernel call: `lengths_ref`, + # `page_indices_ref`, `buffer_index_ref`, `step_ref` + num_scalar_prefetch=4, + in_specs=in_specs, + out_specs=[ + q_block_spec, + q_block_spec, + q_block_spec, + ], + grid=grid, + scratch_shapes=scratch_shapes, + ), + compiler_params=dict(mosaic=dict(dimension_semantics=dimension_sematics)), + out_shape=[ + jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch), + jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32), + jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32), + ], + )( + lengths, + page_indices.reshape(-1), + jnp.zeros((1,), jnp.int32), # buffer index + jnp.zeros((1,), jnp.int32), # step + q.astype(q_dtype_for_kernel_launch), + k_pages, + k_scales_pages, + v_pages, + v_scales_pages, + ) + return out.reshape(batch_size, num_heads, head_dim).astype(q.dtype) diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/quantization_utils.py b/jax/experimental/pallas/ops/tpu/paged_attention/quantization_utils.py new file mode 100644 index 000000000000..81466059986d --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/paged_attention/quantization_utils.py @@ -0,0 +1,107 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import NamedTuple +import jax +from jax import numpy as jnp + +P = jax.sharding.PartitionSpec +MAX_INT8 = 127.5 + + +class QuantizedTensor(NamedTuple): + """A tensor which has been quantized to int8 and its scales. + + Attributes: + weight: Weight + scales: Scales + """ + + weight: jnp.ndarray + scales: jnp.ndarray + + +def to_int8(x: jnp.ndarray, h: jnp.ndarray) -> jnp.ndarray: + """Converts a float array to an int8 array with a scale. + + Args: + x: Float array. + h: Quantization scale. + + Returns: + Int8 array. + """ + return jnp.int8(jnp.rint(x * (MAX_INT8 / h))) + + +def from_int8( + x: jnp.ndarray, h: jnp.ndarray, dtype: jnp.dtype = jnp.bfloat16 +) -> jnp.ndarray: + """Converts an int8 array to a float array with a scale. + + Args: + x: Int8 array. + h: Quantization scale. + dtype: Float dtype to convert to. + + Returns: + Float array. + """ + return x.astype(dtype) * h / MAX_INT8 + + +def get_quantization_scales(x: jnp.ndarray) -> jnp.ndarray: + """Computes the quantization scales for a float array. + + These are the maximum values of the trailing dimension. + + Args: + x: Float array to quantize. + + Returns: + Array of the same shape as input but with the trailing dimension reduced to + a size 1 absolute max value. 
+ """ + return jnp.max(jnp.abs(x), axis=-1, keepdims=True) + + +def quantize_to_int8( + x: jnp.ndarray, +) -> QuantizedTensor: + """Quantizes a float array to an int8 QuantizedTensor. + + Args: + x: Float array to quantize. + + Returns: + Int8 QuantizedTensor. + """ + x_scales = get_quantization_scales(x) + return QuantizedTensor(weight=to_int8(x, x_scales), scales=x_scales) + + +def unquantize_from_int8( + x: QuantizedTensor, + dtype: jnp.dtype = jnp.bfloat16, +) -> jnp.ndarray: + """Unquantizes an int8 QuantizedTensor to a float array. + + Args: + x: Int8 QuantizedTensor to unquantize. + dtype: Float dtype to unquantize to. + + Returns: + Float array. + """ + return from_int8(x.weight, x.scales, dtype) diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index 313b2a677c13..bebeb7551b7a 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -16,10 +16,11 @@ from __future__ import annotations +from collections.abc import Callable, Mapping import dataclasses import enum import functools -from typing import Any, Callable, Literal, NamedTuple, Union, Optional, overload +from typing import Any, Literal, NamedTuple, Optional, Union, overload import jax from jax import ad_checkpoint @@ -89,10 +90,13 @@ class SegmentIds(NamedTuple): def get_kernel_name( - is_mqa: bool, save_residuals: bool, is_segmented: bool, phase: str + block_metadata: Mapping[str, Any], + is_mqa: bool, + save_residuals: bool, + is_segmented: bool, + phase: str, ) -> str: """Returns a unique name for all SplashAttention kernel variants.""" - assert phase == "dq" or phase == "dkv" or phase == "fwd" # Saving residuals is supported only for the fwd phase. assert not save_residuals or phase == "fwd" @@ -103,7 +107,9 @@ def get_kernel_name( residuals = "_no_residuals" attention_type = "mqa" if is_mqa else "mha" segments = "_segmented" if is_segmented else "" - return f"splash_{attention_type}_{phase}{segments}{residuals}" + return f"splash_{attention_type}_{phase}{segments}{residuals}_" + "_".join( + f"{k}={v}" for k, v in sorted(block_metadata.items()) + ) # Reference attention implementations @@ -418,9 +424,9 @@ def _wrapped( if is_grouped: def reshape_activations(activations): - if activations.ndim == 4: - kv_heads, q_heads_per_kv_head, q_seq_len, head_dim = activations.shape - return activations.reshape( + if activations.ndim == 4: # pytype: disable=attribute-error + kv_heads, q_heads_per_kv_head, q_seq_len, head_dim = activations.shape # pytype: disable=attribute-error + return activations.reshape( # pytype: disable=attribute-error kv_heads * q_heads_per_kv_head, q_seq_len, head_dim ) return activations @@ -582,7 +588,7 @@ def _apply_mask_and_soft_cap( *, attn_logits_soft_cap: float, k_slice: pl.Slice, - k_offset: int, + k_offset: int | jax.Array, bq: int, k_in_lanes=True, mask_function=None, @@ -598,7 +604,7 @@ def _apply_mask_and_soft_cap( mask = pl.load(mask_ref, (k_slice, slice(None))) snm = jnp.where(should_not_mask, 1, 0) - masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape))) + masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0) if mask_function is not None: # Compute the mask using the given q_sequence indices. 
@@ -628,7 +634,13 @@ def _apply_mask_and_soft_cap( q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq)) assert q_sequence.shape == k_sequence.shape - masks.append(mask_function(q_sequence, k_sequence)) # pytype: disable=wrong-arg-count + computed_mask = mask_function(q_sequence, k_sequence) # pytype: disable=wrong-arg-count + if computed_mask.dtype != jnp.dtype(jnp.bool_): + raise ValueError( + "Mask function must return a boolean-valued array, but got:" + f" {computed_mask.dtype}" + ) + masks.append(computed_mask) if q_segment_ids_ref is not None: if k_in_lanes: @@ -755,7 +767,7 @@ def body(kv_compute_index, _): qk = apply_mask_and_soft_cap() - m_curr = qk.max(axis=-1)[:, None] + m_curr = qk.max(axis=-1)[:, None] # pytype: disable=attribute-error assert m_curr.shape == (bq, 1) m_next = jnp.maximum(m_prev, m_curr) assert m_next.shape == (bq, NUM_LANES) @@ -910,7 +922,7 @@ def _splash_attention_forward( if bkv % bkv_compute: raise ValueError(f"{bkv=} must be a multiple of {bkv_compute=}.") if bkv_compute % NUM_LANES: - raise ValueError(f"{bkv=} must be a multiple of {NUM_LANES}.") + raise ValueError(f"{bkv_compute=} must be a multiple of {NUM_LANES}.") kv_seq_len = k.shape[kv_seq_len_dimension] @@ -973,25 +985,25 @@ def kv_segment_ids_index_map(h, i, j, data_next_ref, block_mask_ref, # Convert the logical shape from head-minor to sequence-minor. in_specs = [ pl.BlockSpec( - q_index_map, from_head_minor((None, bq, head_dim), q_layout) + from_head_minor((None, bq, head_dim), q_layout), q_index_map ), pl.BlockSpec( - k_index_map, from_head_minor( (bkv, head_dim) if is_mqa else (None, bkv, head_dim), k_layout ), + k_index_map, ), pl.BlockSpec( - v_index_map, from_head_minor( (bkv, head_dim) if is_mqa else (None, bkv, head_dim), v_layout ), + v_index_map, ), ] if segment_ids is not None: in_specs += [ - pl.BlockSpec(q_segment_ids_index_map, (bq, NUM_LANES)), - pl.BlockSpec(kv_segment_ids_index_map, (NUM_SUBLANES, bkv)), + pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map), + pl.BlockSpec((NUM_SUBLANES, bkv), kv_segment_ids_index_map), ] q_segment_ids = jax.lax.broadcast_in_dim( segment_ids.q, (q_seq_len, NUM_LANES), (0,) @@ -1004,7 +1016,7 @@ def kv_segment_ids_index_map(h, i, j, data_next_ref, block_mask_ref, q_segment_ids = kv_segment_ids = None if fwd_mask_info.partial_mask_blocks is not None: - in_specs.append(pl.BlockSpec(mask_index_map, (None, bq, bkv))) + in_specs.append(pl.BlockSpec((None, bq, bkv), mask_index_map)) else: in_specs.append(None) @@ -1017,7 +1029,7 @@ def kv_segment_ids_index_map(h, i, j, data_next_ref, block_mask_ref, q_sequence = jax.lax.broadcast_in_dim( fwd_mask_info.q_sequence, (q_seq_len, NUM_LANES), (0,) ) - in_specs.append(pl.BlockSpec(q_segment_ids_index_map, (bq, NUM_LANES))) + in_specs.append(pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map)) else: q_sequence = None in_specs.append(None) @@ -1032,10 +1044,10 @@ def kv_segment_ids_index_map(h, i, j, data_next_ref, block_mask_ref, ] out_specs = [ # TODO(sharadmv): convert m/l to be scratch - pl.BlockSpec(lambda h, i, j, *_: (0, 0), (bq, NUM_LANES)), - pl.BlockSpec(lambda h, i, j, *_: (0, 0), (bq, NUM_LANES)), - pl.BlockSpec(lambda h, i, j, *_: (0, 0), (bq, head_dim)), - pl.BlockSpec(out_index_map, (None, bq, head_dim)), + pl.BlockSpec((bq, NUM_LANES), lambda h, i, j, *_: (0, 0)), + pl.BlockSpec((bq, NUM_LANES), lambda h, i, j, *_: (0, 0)), + pl.BlockSpec((bq, head_dim), lambda h, i, j, *_: (0, 0)), + pl.BlockSpec((None, bq, head_dim), out_index_map), ] if save_residuals: out_shapes += [ @@ 
-1048,34 +1060,23 @@ def logsumexp_index_map(h, i, *_): return h, i, 0 out_specs += [ - pl.BlockSpec(logsumexp_index_map, (None, bq, NUM_LANES)), + pl.BlockSpec((None, bq, NUM_LANES), logsumexp_index_map), ] else: out_shapes += [None] out_specs += [None] - # Attach useful metadata to the custom-call HLO op. - # Having this information available in an HLO-dump or xprof is valuable for - # debugging and performance investigation. - metadata_dict = dict( - block_sizes=dataclasses.asdict(block_sizes), - is_mqa=is_mqa, - save_residuals=save_residuals, - mask_value=mask_value, - is_segmented=segment_ids is not None, - attn_logits_soft_cap=attn_logits_soft_cap, - residual_checkpoint_name=residual_checkpoint_name, - ) - - mosaic_params = pltpu.encode_kernel_regeneration_metadata(metadata_dict) - - mosaic_params.update( + mosaic_params = dict( dimension_semantics=("parallel", "arbitrary", "arbitrary"), flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True}, ) kernel_name = get_kernel_name( - is_mqa, save_residuals, segment_ids is not None, "fwd" + dataclasses.asdict(block_sizes), + is_mqa=is_mqa, + save_residuals=save_residuals, + is_segmented=segment_ids is not None, + phase="fwd", ) if fwd_mask_info.data_next is not None: @@ -1394,13 +1395,13 @@ def _splash_attention_bwd_dq( def o_index_map(h, i, *_): return h, i, 0 - o_spec = pl.BlockSpec(o_index_map, (None, bq, head_dim)) + o_spec = pl.BlockSpec((None, bq, head_dim), o_index_map) def q_index_map(h, i, *_): return from_head_minor((h, i, 0), q_layout) q_spec = pl.BlockSpec( - q_index_map, from_head_minor((None, bq, head_dim), q_layout) + from_head_minor((None, bq, head_dim), q_layout), q_index_map ) def k_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref, *_): @@ -1411,9 +1412,10 @@ def k_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref, *_): return from_head_minor((*prefix, next_j, 0), k_layout) k_spec = pl.BlockSpec( - k_index_map, from_head_minor( - (bkv, head_dim) if is_mqa else (None, bkv, head_dim), k_layout), + (bkv, head_dim) if is_mqa else (None, bkv, head_dim), k_layout + ), + k_index_map, ) def v_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref, *_): @@ -1424,9 +1426,10 @@ def v_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref, *_): return from_head_minor((*prefix, next_j, 0), v_layout) v_spec = pl.BlockSpec( - v_index_map, from_head_minor( - (bkv, head_dim) if is_mqa else (None, bkv, head_dim), v_layout), + (bkv, head_dim) if is_mqa else (None, bkv, head_dim), v_layout + ), + v_index_map, ) def mask_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref, *_): @@ -1435,7 +1438,7 @@ def mask_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref, *_): ) return next_m, 0, 0 - mask_spec = pl.BlockSpec(mask_index_map, (None, bq, bkv)) + mask_spec = pl.BlockSpec((None, bq, bkv), mask_index_map) def q_segment_ids_index_map(h, i, j, *_): del h, j # Unused. 
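Many hunks in this file (and in the `dq`/`dkv` backward kernels below) switch `pl.BlockSpec` to its newer argument order: the block shape comes first and the index map second. A minimal illustration with placeholder sizes:

```python
from jax.experimental import pallas as pl

block_q, head_dim = 128, 128   # placeholder block sizes

# New order: block shape first, then index map (the reverse of the old
# pl.BlockSpec(index_map, block_shape) form being removed in these hunks).
q_spec = pl.BlockSpec(
    (None, block_q, head_dim),         # None: size-1 block, squeezed in the kernel
    lambda h, i, j, *_: (h, i, 0),     # grid coordinates -> block coordinates
)
```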
@@ -1451,9 +1454,10 @@ def kv_segment_ids_index_map( ) return 0, next_j - q_segment_spec = pl.BlockSpec(q_segment_ids_index_map, (bq, NUM_LANES)) - kv_segment_spec = pl.BlockSpec(kv_segment_ids_index_map, - (NUM_SUBLANES, bkv)) + q_segment_spec = pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map) + kv_segment_spec = pl.BlockSpec( + (NUM_SUBLANES, bkv), kv_segment_ids_index_map + ) q_segment_ids = jax.lax.broadcast_in_dim( segment_ids.q, (q_seq_len, NUM_LANES), (0,) ) @@ -1470,11 +1474,11 @@ def logsumexp_index_map(h, i, *_): return h, 0, i logsumexp = jnp.expand_dims(logsumexp, axis=-2) - logsumexp_spec = pl.BlockSpec(logsumexp_index_map, (None, 1, bq)) + logsumexp_spec = pl.BlockSpec((None, 1, bq), logsumexp_index_map) assert logsumexp.ndim == len(logsumexp_spec.block_shape) di = jnp.expand_dims(di, axis=-2) - di_spec = pl.BlockSpec(logsumexp_index_map, (None, 1, bq)) + di_spec = pl.BlockSpec((None, 1, bq), logsumexp_index_map) assert di.ndim == len(di_spec.block_shape) in_specs = [ @@ -1498,7 +1502,7 @@ def logsumexp_index_map(h, i, *_): q_sequence = jax.lax.broadcast_in_dim( mask_info.q_sequence, (q_seq_len, NUM_LANES), (0,) ) - in_specs.append(pl.BlockSpec(q_segment_ids_index_map, (bq, NUM_LANES))) + in_specs.append(pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map)) else: q_sequence = None in_specs.append(None) @@ -1508,7 +1512,7 @@ def logsumexp_index_map(h, i, *_): jax.ShapeDtypeStruct(q.shape, q.dtype), ] out_specs = [ - pl.BlockSpec(lambda *_: (0, 0), (bq, head_dim)), + pl.BlockSpec((bq, head_dim), lambda *_: (0, 0)), dq_spec, ] @@ -1526,28 +1530,24 @@ def logsumexp_index_map(h, i, *_): ) num_scalar_prefetch = 3 - # Attach useful metadata to the custom-call HLO op. - # Having this information available in an HLO-dump or xprof is valuable for - # debugging and performance investigation. 
- metadata_dict = dict( - block_q_dq=bq, - block_kv_dq=bkv, - q_layout=q_layout, - k_layout=k_layout, - v_layout=v_layout, - is_mqa=is_mqa, - mask_value=mask_value, - is_segmented=segment_ids is not None, - attn_logits_soft_cap=attn_logits_soft_cap, - ) - - mosaic_params = pltpu.encode_kernel_regeneration_metadata(metadata_dict) - mosaic_params.update( + mosaic_params = dict( dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True}, ) - kernel_name = get_kernel_name(is_mqa, False, segment_ids is not None, "dq") + kernel_name = get_kernel_name( + dict( + block_q_dq=bq, + block_kv_dq=bkv, + q_layout=q_layout, + k_layout=k_layout, + v_layout=v_layout, + ), + is_mqa=is_mqa, + save_residuals=False, + is_segmented=segment_ids is not None, + phase="dq", + ) with jax.named_scope(kernel_name): _, dq = pl.pallas_call( kernel, @@ -1853,7 +1853,7 @@ def o_index_map( ) return head_index, next_i, 0 - o_spec = pl.BlockSpec(o_index_map, (None, bq, head_dim)) + o_spec = pl.BlockSpec((None, bq, head_dim), o_index_map) def q_index_map( kv_index, @@ -1875,18 +1875,18 @@ def q_index_map( return from_head_minor((head_index, next_i, 0), q_layout) q_spec = pl.BlockSpec( - q_index_map, from_head_minor((None, bq, head_dim), q_layout)) + from_head_minor((None, bq, head_dim), q_layout), q_index_map) def k_index_map(kv_index, head_index, *_): prefix = () if is_mqa else (_div(head_index, q_heads_per_kv_head),) return from_head_minor((*prefix, kv_index, 0), k_layout) k_spec = pl.BlockSpec( - k_index_map, from_head_minor( (bkv, head_dim) if is_mqa else (None, bkv, head_dim), k_layout, ), + k_index_map, ) def v_index_map(kv_index, head_index, *_): @@ -1894,22 +1894,22 @@ def v_index_map(kv_index, head_index, *_): return from_head_minor((*prefix, kv_index, 0), v_layout) v_spec = pl.BlockSpec( - v_index_map, from_head_minor( (bkv, head_dim) if is_mqa else (None, bkv, head_dim), v_layout, ), + v_index_map, ) if use_fused_bwd_kernel: def dq_index_map(kv_index, head_index, q_index, *_): return (kv_index, head_index, q_index, 0) - dq_spec = pl.BlockSpec(dq_index_map, (None, None, bq, head_dim)) + dq_spec = pl.BlockSpec((None, None, bq, head_dim), dq_index_map) dq_shape = jax.ShapeDtypeStruct((kv_seq_len // bkv, *q.shape), q.dtype) if bkv == bkv_compute: dq_scratch_spec = dq_scratch_shape = None else: - dq_scratch_spec = pl.BlockSpec(lambda *_: (0, 0), (bq, head_dim)) + dq_scratch_spec = pl.BlockSpec((bq, head_dim), lambda *_: (0, 0)) dq_scratch_shape = jax.ShapeDtypeStruct((bq, head_dim), jnp.float32) else: dq_spec = dq_shape = dq_scratch_spec = dq_scratch_shape = None @@ -1919,8 +1919,8 @@ def dkv_index_map(kv_index, head_index, *_): return (*prefix, kv_index, 0) dk_spec = dv_spec = pl.BlockSpec( - dkv_index_map, (bkv, head_dim) if is_mqa else (None, bkv, head_dim), + dkv_index_map, ) def mask_index_map( @@ -1942,7 +1942,7 @@ def mask_index_map( ) return next_m, 0, 0 - mask_spec = pl.BlockSpec(mask_index_map, (None, bkv, bq)) + mask_spec = pl.BlockSpec((None, bkv, bq), mask_index_map) def q_segment_ids_index_map( kv_index, @@ -1967,9 +1967,8 @@ def q_segment_ids_index_map( def kv_segment_ids_index_map(kv_index, *_): return kv_index, 0 - q_segment_spec = pl.BlockSpec(q_segment_ids_index_map, (NUM_SUBLANES, bq)) - kv_segment_spec = pl.BlockSpec(kv_segment_ids_index_map, - (bkv, NUM_LANES)) + q_segment_spec = pl.BlockSpec((NUM_SUBLANES, bq), q_segment_ids_index_map) + kv_segment_spec = pl.BlockSpec((bkv, NUM_LANES), kv_segment_ids_index_map) q_segment_ids = 
jax.lax.broadcast_in_dim( segment_ids.q, (NUM_SUBLANES, q_seq_len), (1,) ) @@ -2005,12 +2004,12 @@ def logsumexp_index_map( # TODO(apaszke): Remove the sublane expansion once Mosaic has all retilings logsumexp_shape = (num_q_heads, NUM_SUBLANES, q_seq_len) logsumexp = jnp.broadcast_to(jnp.expand_dims(logsumexp, -2), logsumexp_shape) - logsumexp_spec = pl.BlockSpec(logsumexp_index_map, (None, NUM_SUBLANES, bq)) + logsumexp_spec = pl.BlockSpec((None, NUM_SUBLANES, bq), logsumexp_index_map) assert logsumexp.ndim == len(logsumexp_spec.block_shape) # TODO(apaszke): Remove the sublane expansion once Mosaic has all retilings di = jnp.broadcast_to(jnp.expand_dims(di, -2), logsumexp_shape) - di_spec = pl.BlockSpec(logsumexp_index_map, (None, NUM_SUBLANES, bq)) + di_spec = pl.BlockSpec((None, NUM_SUBLANES, bq), logsumexp_index_map) assert di.ndim == len(di_spec.block_shape) in_specs = [ @@ -2029,7 +2028,7 @@ def logsumexp_index_map( in_specs.append(None) if mask_info.q_sequence is not None: - in_specs.append(pl.BlockSpec(q_segment_ids_index_map, (NUM_SUBLANES, bq))) + in_specs.append(pl.BlockSpec((NUM_SUBLANES, bq), q_segment_ids_index_map)) q_sequence = jax.lax.broadcast_in_dim( mask_info.q_sequence, (NUM_SUBLANES, q_seq_len), (1,) ) @@ -2047,8 +2046,8 @@ def logsumexp_index_map( ] out_specs = [ dq_scratch_spec, - pl.BlockSpec(lambda *_: (0, 0), (bkv, head_dim)), - pl.BlockSpec(lambda *_: (0, 0), (bkv, head_dim)), + pl.BlockSpec((bkv, head_dim), lambda *_: (0, 0)), + pl.BlockSpec((bkv, head_dim), lambda *_: (0, 0)), dq_spec, dk_spec, dv_spec, @@ -2072,35 +2071,30 @@ def logsumexp_index_map( ) num_scalar_prefetch = 3 - # Attach useful metadata to the custom-call HLO op. - # Having this information available in an HLO-dump or xprof is valuable for - # debugging and performance investigation. 
- metadata_dict = dict( - block_q_dkv=bq, - block_kv_dkv=bkv, - block_kv_dkv_compute=bkv_compute, - q_layout=q_layout, - k_layout=k_layout, - v_layout=v_layout, - use_fused_bwd_kernel=use_fused_bwd_kernel, - is_mqa=is_mqa, - mask_value=mask_value, - is_segmented=segment_ids is not None, - attn_logits_soft_cap=attn_logits_soft_cap, - ) - - mosaic_params = pltpu.encode_kernel_regeneration_metadata(metadata_dict) # We set all dimensions to arbitrary because: # 1) for kv_seq_len, the splash attention prefetch schedule assumes no # megacore # 2) for heads, we are reducing over heads # 3) for q_seq_len, we are reducing over it to compute dkv - mosaic_params.update( + mosaic_params = dict( dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True}, ) - kernel_name = get_kernel_name(is_mqa, False, segment_ids is not None, "dkv") + kernel_name = get_kernel_name( + dict( + block_q_dkv=bq, + block_kv_dkv=bkv, + block_kv_dkv_compute=bkv_compute, + q_layout=q_layout, + k_layout=k_layout, + v_layout=v_layout, + ), + is_mqa=is_mqa, + save_residuals=False, + is_segmented=segment_ids is not None, + phase="dkv", + ) with jax.named_scope(kernel_name): _, _, _, dq_unreduced, dk, dv = pl.pallas_call( kernel, diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py index e65d9b073a18..eab2a695dc02 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py @@ -16,8 +16,9 @@ from __future__ import annotations +from collections.abc import Callable, Sequence import dataclasses -from typing import Any, Callable, Sequence, Tuple +from typing import Any import numpy as np # mypy: ignore-errors @@ -26,7 +27,7 @@ class Mask: """A base class for splash attention masks.""" @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: raise NotImplementedError def __getitem__(self, idx) -> np.ndarray: @@ -38,14 +39,14 @@ def __bool__(self) -> bool: ' instead of bitwise operations on masks.' ) - def __or__(self, other: 'Mask') -> 'Mask': + def __or__(self, other: Mask) -> Mask: if self.shape != other.shape: raise ValueError( f'Invalid shape for other: {other.shape}, expected: {self.shape}' ) return LogicalOr(self, other) - def __and__(self, other: 'Mask') -> 'Mask': + def __and__(self, other: Mask) -> Mask: if self.shape != other.shape: raise ValueError( f'Invalid shape for other: {other.shape}, expected: {self.shape}' @@ -53,7 +54,7 @@ def __and__(self, other: 'Mask') -> 'Mask': return LogicalAnd(self, other) -def make_causal_mask(shape: Tuple[int, int], offset: int = 0) -> np.ndarray: +def make_causal_mask(shape: tuple[int, int], offset: int = 0) -> np.ndarray: """Makes a causal attention mask. 
Args: @@ -73,8 +74,8 @@ def make_causal_mask(shape: Tuple[int, int], offset: int = 0) -> np.ndarray: def make_local_attention_mask( - shape: Tuple[int, int], - window_size: Tuple[int | None, int | None], + shape: tuple[int, int], + window_size: tuple[int | None, int | None], *, offset: int = 0, ) -> np.ndarray: @@ -92,7 +93,7 @@ def make_local_attention_mask( def make_random_mask( - shape: Tuple[int, int], sparsity: float, seed: int + shape: tuple[int, int], sparsity: float, seed: int ) -> np.ndarray: """Makes a random attention mask.""" np.random.seed(seed) @@ -111,7 +112,7 @@ def __init__(self, left: Mask, right: Mask): self.right = right @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: return self.left.shape def __getitem__(self, idx) -> np.ndarray: @@ -133,7 +134,7 @@ def __init__(self, left: Mask, right: Mask): self.right = right @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: return self.left.shape def __getitem__(self, idx) -> np.ndarray: @@ -167,7 +168,7 @@ def __post_init__(self): raise ValueError('Nesting MultiHeadMasks is not supported') @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: return (len(self.masks),) + self.masks[0].shape def __getitem__(self, idx) -> np.ndarray: @@ -208,13 +209,13 @@ class _ComputableMask(Mask): mask rather than loading it. """ - _shape: Tuple[int, int] + _shape: tuple[int, int] q_sequence: np.ndarray mask_function: Callable[..., Any] def __init__( self, - shape: Tuple[int, int], + shape: tuple[int, int], mask_function: Callable[..., Any], shard_count: int = 1, ): @@ -231,7 +232,7 @@ def __init__( self.q_sequence = np.arange(q_seq_len, dtype=np.int32) @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: return self._shape def __getitem__(self, idx) -> np.ndarray: @@ -271,7 +272,7 @@ class CausalMask(_ComputableMask): def __init__( self, - shape: Tuple[int, int], + shape: tuple[int, int], offset: int = 0, shard_count: int = 1, ): @@ -329,15 +330,15 @@ class LocalMask(Mask): # TODO(amagni): Transform LocalMask into a _ComputableMask. 
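The helpers touched above build plain boolean numpy masks that can be combined elementwise; for instance (sizes here are arbitrary):

```python
import numpy as np
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib

causal = mask_lib.make_causal_mask((512, 512))
local = mask_lib.make_local_attention_mask((512, 512), window_size=(128, None))
combined = np.logical_and(causal, local)   # causal AND within a 128-token left window
assert combined.shape == (512, 512)
```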
- _shape: Tuple[int, int] - window_size: Tuple[int | None, int | None] + _shape: tuple[int, int] + window_size: tuple[int | None, int | None] offset: int _q_sequence: np.ndarray | None = None def __init__( self, - shape: Tuple[int, int], - window_size: Tuple[int | None, int | None], + shape: tuple[int, int], + window_size: tuple[int | None, int | None], offset: int, shard_count: int = 1, ): @@ -352,7 +353,7 @@ def __init__( ) @property - def shape(self) -> Tuple[int, int]: + def shape(self) -> tuple[int, int]: return self._shape def __getitem__(self, idx) -> np.ndarray: @@ -429,7 +430,7 @@ def __post_init__(self): raise ValueError('Mask must be a boolean array') @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: return self.array.shape def __getitem__(self, idx) -> np.ndarray: @@ -467,7 +468,7 @@ def __post_init__(self): raise ValueError(f'Unsupported shape type: {type(self.shape)}') @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: return self._shape def __getitem__(self, idx) -> np.ndarray: diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py index 051c367b2e2c..3c672b8dbe88 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py @@ -16,8 +16,9 @@ from __future__ import annotations import collections +from collections.abc import Callable import functools -from typing import Callable, Dict, List, NamedTuple, Set, Tuple +from typing import Dict, List, NamedTuple, Set, Tuple from jax import util as jax_util from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib import numpy as np @@ -161,11 +162,11 @@ def __eq__(self, other: object) -> bool: def _get_mask_info_for_shard( - output_shape: Tuple[int, int, int], + output_shape: tuple[int, int, int], has_mask_next: bool, mask: mask_lib.MultiHeadMask, - block_shape: Tuple[int, int], - coords_to_partial_mask_block_index: Dict[Tuple[int, int, int], int], + block_shape: tuple[int, int], + coords_to_partial_mask_block_index: dict[tuple[int, int, int], int], masks_per_head_shard: int, head_start: int, num_heads: int, @@ -173,7 +174,7 @@ def _get_mask_info_for_shard( q_seq_shard_size: int, blocked_q_seq_start: int, is_dkv: bool, -) -> Tuple[np.ndarray, np.ndarray | None]: +) -> tuple[np.ndarray, np.ndarray | None]: """Process a slice of the mask to compute data_next and mask_next. Args: @@ -310,7 +311,7 @@ def _get_mask_info_for_shard( @functools.lru_cache(maxsize=12) def _process_mask( mask: mask_lib.MultiHeadMask, # [num_heads, q_seq_len, kv_seq_len] - block_shape: Tuple[int, int], + block_shape: tuple[int, int], is_dkv: bool, *, downcast_smem_data: bool = True, @@ -394,18 +395,18 @@ def assign_unique_ids(objects): id_map = collections.defaultdict(lambda: len(id_map)) return {obj: id_map[obj] for obj in objects} - unique_masks_dict: Dict[mask_lib.Mask, int] = assign_unique_ids( + unique_masks_dict: dict[mask_lib.Mask, int] = assign_unique_ids( head_mask for head_mask in mask.masks ) # Build a mapping of heads to unique masks and masks to unique masks. 
- head_to_mask_id: List[int] = [0] * head_count - head_shard_to_mask_ids: List[Set[int]] = [set() for _ in range(head_shards)] - mask_id_to_heads: List[List[int]] = [ + head_to_mask_id: list[int] = [0] * head_count + head_shard_to_mask_ids: list[set[int]] = [set() for _ in range(head_shards)] + mask_id_to_heads: list[list[int]] = [ [] for _ in range(len(unique_masks_dict)) ] - mask_id_to_head_shards: List[Set[int]] = [ + mask_id_to_head_shards: list[set[int]] = [ set() for _ in range(len(unique_masks_dict)) ] @@ -426,7 +427,7 @@ def assign_unique_ids(objects): # MaskInfo class and runtime overhead to perform an indirect lookup. Since # having multiple masks per head-shard is not a common case we leave this for # future work. - max_masks_per_head_shard = max([len(x) for x in head_shard_to_mask_ids]) + max_masks_per_head_shard = max(len(x) for x in head_shard_to_mask_ids) masks_per_head_shard = 1 if max_masks_per_head_shard == 1 else heads_per_shard unique_masks = [ @@ -436,10 +437,10 @@ def assign_unique_ids(objects): # TODO(amagni): checking the validity of the masks is slow for large masks. # Disable it for now, reevalute in the future. - partial_mask_block_ids: Dict[_HashableNDArray, int] = collections.defaultdict( + partial_mask_block_ids: dict[_HashableNDArray, int] = collections.defaultdict( lambda: len(partial_mask_block_ids) ) - block_id_to_block_coords: Dict[int, List[Tuple[int, ...]]] = ( + block_id_to_block_coords: dict[int, list[tuple[int, ...]]] = ( collections.defaultdict(list) ) @@ -697,7 +698,7 @@ def set_block_mask(mask_id: int, q_index: int, kv_index: int, value: int): # maintain the SPMD paradigm. padding_axis = 1 if is_dkv else 2 - max_size = max([x.shape[padding_axis] for x in block_mask_shards]) + max_size = max(x.shape[padding_axis] for x in block_mask_shards) padded_block_mask_shards = [] padded_data_next_shards = [] padded_mask_next_shards = [] @@ -791,7 +792,7 @@ def _shrink_mask_info( # Pad each row in the non-zero indices to match the width of the longest # row. This avoids having jagged rows. - max_non_zero_cols = max([len(x) for x in grouped_non_zero_cols]) + max_non_zero_cols = max(len(x) for x in grouped_non_zero_cols) padded_non_zero_cols = [] padding = -1 for row in grouped_non_zero_cols: @@ -856,7 +857,7 @@ def _shrink_mask_info_dkv( # Pad each col in the non-zero indices to match the height of the longest # col. This avoids having jagged cols. - max_non_zero_rows = max([len(x) for x in grouped_non_zero_rows]) + max_non_zero_rows = max(len(x) for x in grouped_non_zero_rows) padded_non_zero_rows = [] padding = -1 for col in grouped_non_zero_rows: diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 937de43eeb27..ad5fb92719d0 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -12,35 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Contains Mosaic specific Pallas functions.""" -from jax._src.pallas.mosaic import ANY -from jax._src.pallas.mosaic import CMEM -from jax._src.pallas.mosaic import PrefetchScalarGridSpec -from jax._src.pallas.mosaic import SMEM -from jax._src.pallas.mosaic import SemaphoreType -from jax._src.pallas.mosaic import TPUMemorySpace -from jax._src.pallas.mosaic import VMEM -from jax._src.pallas.mosaic import DeviceIdType -from jax._src.pallas.mosaic import async_copy -from jax._src.pallas.mosaic import async_remote_copy -from jax._src.pallas.mosaic import bitcast -from jax._src.pallas.mosaic import dma_semaphore -from jax._src.pallas.mosaic import device_id -from jax._src.pallas.mosaic import emit_pipeline_with_allocations -from jax._src.pallas.mosaic import emit_pipeline -from jax._src.pallas.mosaic import PipelineCallbackArgs -from jax._src.pallas.mosaic import PipelinePrefetchArgs -from jax._src.pallas.mosaic import ManualPrefetchArgs -from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata -from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata -from jax._src.pallas.mosaic import get_barrier_semaphore -from jax._src.pallas.mosaic import make_async_copy -from jax._src.pallas.mosaic import make_async_remote_copy -from jax._src.pallas.mosaic import repeat -from jax._src.pallas.mosaic import roll -from jax._src.pallas.mosaic import run_scoped -from jax._src.pallas.mosaic import semaphore -from jax._src.pallas.mosaic import semaphore_signal -from jax._src.pallas.mosaic import semaphore_wait -from jax._src.pallas.mosaic import trace +"""Mosaic-specific Pallas APIs.""" + +from jax._src.pallas.mosaic import core +from jax._src.pallas.mosaic.core import dma_semaphore +from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec +from jax._src.pallas.mosaic.core import semaphore +from jax._src.pallas.mosaic.core import SemaphoreType +from jax._src.pallas.mosaic.core import TPUMemorySpace +from jax._src.pallas.mosaic.lowering import LoweringException +from jax._src.pallas.mosaic.pipeline import BufferedRef +from jax._src.pallas.mosaic.pipeline import emit_pipeline +from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations +from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule +from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations +from jax._src.pallas.mosaic.pipeline import ARBITRARY +from jax._src.pallas.mosaic.pipeline import PARALLEL +from jax._src.pallas.mosaic.primitives import async_copy +from jax._src.pallas.mosaic.primitives import async_remote_copy +from jax._src.pallas.mosaic.primitives import bitcast +from jax._src.pallas.mosaic.primitives import delay +from jax._src.pallas.mosaic.primitives import device_id +from jax._src.pallas.mosaic.primitives import DeviceIdType +from jax._src.pallas.mosaic.primitives import get_barrier_semaphore +from jax._src.pallas.mosaic.primitives import make_async_copy +from jax._src.pallas.mosaic.primitives import make_async_remote_copy +from jax._src.pallas.mosaic.primitives import repeat +from jax._src.pallas.mosaic.primitives import roll +from jax._src.pallas.mosaic.primitives import run_scoped +from jax._src.pallas.mosaic.primitives import semaphore_read +from jax._src.pallas.mosaic.primitives import semaphore_signal +from jax._src.pallas.mosaic.primitives import semaphore_wait +from jax._src.pallas.mosaic.primitives import prng_seed +from jax._src.pallas.mosaic.primitives import prng_random_bits from jax._src.tpu_custom_call import CostEstimate + +ANY = TPUMemorySpace.ANY +CMEM 
= TPUMemorySpace.CMEM +SMEM = TPUMemorySpace.SMEM +VMEM = TPUMemorySpace.VMEM diff --git a/jax/experimental/rnn.py b/jax/experimental/rnn.py index 961bee9b1ef1..f5131365cb50 100644 --- a/jax/experimental/rnn.py +++ b/jax/experimental/rnn.py @@ -162,7 +162,7 @@ def get_num_params_in_lstm(input_size: int, hidden_size: int, num_layers: int, """Get param count in LSTM.""" layer_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers, bidirectional) - param_count = sum([math.prod(shape) for shape in layer_shapes]) + param_count = sum(math.prod(shape) for shape in layer_shapes) return param_count @@ -466,7 +466,7 @@ def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float, return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths)) -def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval, # type: ignore +def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval, w_aval, y_aval, reserve_space_aval, seq_lengths_aval, input_size: int, hidden_size: int, num_layers: int, dropout: float, bidirectional: bool, diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 71250a34cf4c..4957df4866f0 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -13,14 +13,14 @@ # limitations under the License. from __future__ import annotations -from collections.abc import Hashable, Sequence +from collections.abc import Callable, Hashable, Sequence import enum from functools import partial import inspect import itertools as it from math import prod import operator as op -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, TypeVar, Union import numpy as np @@ -29,7 +29,6 @@ from jax.sharding import NamedSharding, PartitionSpec, Mesh from jax._src import ad_checkpoint from jax._src import ad_util -from jax._src import array from jax._src import callback from jax._src import core from jax._src import custom_derivatives @@ -40,6 +39,7 @@ from jax._src import ops from jax._src import pjit from jax._src import prng +from jax._src import random from jax._src import sharding_impls from jax._src import source_info_util from jax._src import traceback_util @@ -51,9 +51,8 @@ special, control_flow, ann) from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3, as_hashable_function, memoize, partition_list, - merge_lists, split_list, subs_list2, - weakref_lru_cache) -from jax.api_util import flatten_fun_nokwargs, shaped_abstractify + merge_lists, split_list, subs_list2) +from jax.api_util import flatten_fun_nokwargs, shaped_abstractify, argnums_partial from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe @@ -81,6 +80,57 @@ @traceback_util.api_boundary def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs, check_rep: bool = True, auto: frozenset[AxisName] = frozenset()): + """Map a function over shards of data. + + Note: + ``shard_map`` is an experimental API, and still subject to change. For an + introduction to sharded data, refer to :ref:`sharded-computation`. For a more + in-depth look at using ``shard_map``, refer to `SPMD multi-device parallelism with shard_map`_. + + Args: + f: callable to be mapped. Each application of ``f``, or "instance" of ``f``, + takes as input a shard of the mapped-over arguments and produces a shard + of the output. 
+ mesh: a ``jax.sharding.Mesh`` representing the array of devices over which + to shard the data and on which to execute instances of ``f``. The names of + the ``Mesh`` can be used in collective communication operations in ``f``. + This is typically created by a utility function like + :func:`jax.experimental.mesh_utils.create_device_mesh`. + in_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves, + with a tree structure that is a tree prefix of the args tuple to be mapped + over. Similar to :class:`~jax.sharding.NamedSharding`, each ``PartitionSpec`` + represents how the corresponding argument (or subtree of arguments) should + be sharded along the named axes of ``mesh``. In each ``PartitionSpec``, + mentioning a ``mesh`` axis name at a position expresses sharding the + corresponding argument array axis along that positional axis; not + mentioning an axis name expresses replication. If an argument, or argument + subtree, has a corresponding spec of None, that argument is not sharded. + out_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves, + with a tree structure that is a tree prefix of the output of ``f``. Each + ``PartitionSpec`` represents how the corresponding output shards should be + concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at + a position expresses concatenation of that mesh axis's shards along the + corresponding positional axis. Not mentioning a ``mesh`` axis name + expresses a promise that the output values are equal along that mesh axis, + and that, rather than concatenating, only a single value should be produced. + check_rep: If True (default), enable additional validity checks and automatic + differentiation optimizations. The validity checks concern whether any mesh + axis names not mentioned in ``out_specs`` are consistent with how the outputs + of ``f`` are replicated. Must be set to False if using a Pallas kernel in ``f``. + auto: (experimental) an optional set of axis names from ``mesh`` over which we + do not shard the data or map the function, but rather we allow the + compiler to control sharding. These names cannot be used in ``in_specs``, + ``out_specs``, or in communication collectives in ``f``. + + Returns: + A callable that applies the input function ``f`` across data sharded according to + the ``mesh`` and ``in_specs``. + + Examples: + For examples, refer to :ref:`sharded-computation` or `SPMD multi-device parallelism with shard_map`_. + + .. 
_SPMD multi-device parallelism with shard_map: https://jax.readthedocs.io/en/latest/notebooks/shard_map.html + """ return _shard_map(f, mesh, in_specs, out_specs, check_rep, auto) def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs, @@ -92,28 +142,35 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs, if not isinstance(mesh, Mesh): raise TypeError("shard_map requires a `jax.sharding.Mesh` instance for its " f"second argument, but got {mesh} of type {type(mesh)}.") - _check_specs(SpecErrorType.input, in_specs) + if not auto.issubset(mesh.axis_names): + raise ValueError(f"shard_map requires auto={auto} to be a subset of " + f"mesh.axis_names={mesh.axis_names}") + _check_specs(SpecErrorType.input, in_specs, auto) if not callable(out_specs): - _check_specs(SpecErrorType.out, out_specs) + _check_specs(SpecErrorType.out, out_specs, auto) @util.wraps(f) @traceback_util.api_boundary def wrapped(*args): fun = lu.wrap_init(f) args_flat, in_tree = tree_flatten(args) - try: in_specs_flat = broadcast_prefix(in_specs, args) + fun, out_tree = flatten_fun_nokwargs(fun, in_tree) + try: in_specs_flat = broadcast_prefix(in_specs, args, + is_leaf=lambda x: x is None) except ValueError: e, *_ = prefix_errors(in_specs, args) raise e('shard_map in_specs') from None - _check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, args_flat) + dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat) + if s is not None) + fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat) + _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat) in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) - fun, out_tree = flatten_fun_nokwargs(fun, in_tree) @memoize def out_names_thunk(): if callable(out_specs): out_specs_ = out_specs() - _check_specs(SpecErrorType.out, out_specs_) + _check_specs(SpecErrorType.out, out_specs_, auto) else: out_specs_ = out_specs dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves) @@ -161,17 +218,40 @@ def _canonicalize_spec(spec: PartitionSpec) -> AxisNames: SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) -def _check_specs(error_type: SpecErrorType, specs: Any) -> None: +def _check_specs(error_type: SpecErrorType, specs: Any, auto) -> None: if error_type == SpecErrorType.input and specs is None: raise TypeError( "shard_map in_specs argument must be a pytree of " "`jax.sharding.PartitionSpec` instances, but it was None.\n" "Instead of `in_specs=None`, did you mean `in_specs=P()`, " "where `P = jax.sharding.PartitionSpec`?") - if all(isinstance(p, PartitionSpec) for p in tree_leaves(specs)): return + def check_spec(p): + if not isinstance(p, PartitionSpec): + return False + for names in p: + if not isinstance(names, tuple): + names = (names,) + for name in names: + if name in auto: + return False + return True + if all(check_spec(p) for p in tree_leaves(specs)): return prefix = 'in' if error_type == SpecErrorType.input else 'out' msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, " for key, x in generate_key_paths(specs) if not isinstance(x, P)] + if not msgs: + for key, p in generate_key_paths(specs): + for names in p: + if not isinstance(names, tuple): + names = (names,) + for name in names: + if name in auto: + msgs.append(f" {prefix}_specs{keystr(key)} refers to {repr(name)}") + raise ValueError( + f"shard_map {prefix}_specs argument cannot refer to an axis " + f"marked auto ({auto}), but:\n\n" + + '\n\n'.join(msgs) + '\n\n' + f"Check the {prefix}_specs values 
passed to shard_map.") raise TypeError( f"shard_map {prefix}_specs argument must be a pytree of " f"`jax.sharding.PartitionSpec` instances, but:\n\n" @@ -183,11 +263,13 @@ class NoFail: pass def _check_specs_vs_args( f: Callable, mesh: Mesh, in_tree: PyTreeDef, in_specs: Specs, - in_specs_flat: list[P], xs: list) -> None: + dyn_argnums: Sequence[int], in_specs_flat: Sequence[P], + xs: Sequence) -> None: in_avals = map(shaped_abstractify, xs) fail = [a if not len(p) <= a.ndim else no_fail for p, a in zip(in_specs_flat, in_avals)] if any(f is not no_fail for f in fail): + fail = _expand_fail(in_tree, dyn_argnums, fail) msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail) raise ValueError(msg) in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) @@ -195,9 +277,18 @@ def _check_specs_vs_args( for d, ns in names.items()) else no_fail for a, names in zip(in_avals, in_names_flat)] if any(f is not no_fail for f in fail): + fail = _expand_fail(in_tree, dyn_argnums, fail) msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail) raise ValueError(msg) +def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int], + fail: Sequence[core.ShapedArray | NoFail] + ) -> list[core.ShapedArray | NoFail]: + fail_: list[core.ShapedArray | NoFail] = [no_fail] * in_tree.num_leaves + for i, f in zip(dyn_argnums, fail): + fail_[i] = f + return fail_ + def _spec_rank_error( error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs, fails: list[core.ShapedArray | NoFail]) -> str: @@ -329,6 +420,7 @@ def _unmentioned(mesh: Mesh, names: AxisNames) -> list[AxisName]: name_set = {n for ns in names.values() for n in ns} return [n for n in mesh.axis_names if n not in name_set] + def _try_infer_args(f, tree): dummy_args = tree_unflatten(tree, [False] * tree.num_leaves) try: @@ -342,11 +434,11 @@ def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail] failures = tree_unflatten(tree, fails) failures_aug = generate_key_paths(failures) specs_ = tree_unflatten(tree_structure(specs), generate_key_paths(specs)) - leaf = lambda x: type(x) is tuple and len(x) == 2 and type(x[1]) is P + leaf = lambda x: x is None or type(x) is tuple and len(x) == 2 and type(x[1]) is P specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf) - return [((spec_key, spec), (fail_key, fail_data)) - for (spec_key, spec), (fail_key, fail_data) - in zip(specs_aug, failures_aug) if fail_data is not no_fail] + return [(s, (fail_key, fail_data)) for s, (fail_key, fail_data) + in zip(specs_aug, failures_aug) + if s is not None and fail_data is not no_fail] # Primitive @@ -426,9 +518,7 @@ def _shard_map_staging( in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) main = trace.main with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()): - jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic( - f, main, in_avals_ - ) + jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_) out_avals_ = map(_check_shapedarray, genavals) _check_names(out_names_thunk(), out_avals_) in_rep = map(partial(_in_names_to_rep, mesh), in_names) @@ -548,7 +638,7 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in, in_avals_, in_nodes) new_axis_context = sharding_impls.SPMDAxisContext( - mesh, frozenset(mesh.axis_names) + mesh, frozenset(mesh.axis_names) - auto ) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) with 
core.extend_axis_env_nd(tuple(mesh.shape.items())): @@ -562,34 +652,44 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, ctx.avals_out, out_nodes_) mlir.register_lowering(shard_map_p, _shard_map_lowering) +def _make_scoped_manual_sharding(ctx, mesh, axes): + axis_ctx = ctx.module_context.axis_context + if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): + manual_axes = axis_ctx.manual_axes + else: + manual_axes = frozenset({}) + return NamedSharding( + mesh, sharding_impls.array_mapping_to_axis_resources(axes), # pytype: disable=wrong-arg-types + _manual_axes=manual_axes) + def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in, aval_out, x): manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) axes = {name: i for i, ns in names.items() for name in ns} - shard_proto = NamedSharding( - mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore - )._to_xla_hlo_sharding(aval_in.ndim) + ns = _make_scoped_manual_sharding(ctx, mesh, axes) if dtypes.issubdtype(aval_in.dtype, dtypes.extended): - shard_proto = aval_in.dtype._rules.physical_hlo_sharding(aval_in, shard_proto) + ns = sharding_impls.physical_sharding(aval_in, ns) + aval_in = core.physical_aval(aval_in) + shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto() unspecified = set(range(aval_in.ndim)) if auto else set() - sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto.to_proto(), # type: ignore + sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto, unspecified_dims=unspecified) - return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, set())] + return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified)] def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in, aval_out, xs): x, = xs - manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) - sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=set()) axes = {name: i for i, ns in names.items() for name in ns} - shard_proto = NamedSharding( - mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore - )._to_xla_hlo_sharding(aval_out.ndim) + ns = _make_scoped_manual_sharding(ctx, mesh, axes) if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - shard_proto = aval_out.dtype._rules.physical_hlo_sharding(aval_out, shard_proto) + ns = sharding_impls.physical_sharding(aval_out, ns) + aval_out = core.physical_aval(aval_out) unspecified = set(range(aval_out.ndim)) if auto else set() - return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto.to_proto(), - unspecified) # type: ignore + manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) + sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified) + shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto() + return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto, + unspecified) def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str: if isinstance(aval, core.ShapedArray): @@ -811,13 +911,13 @@ def _debug_callback_eager_rule(mesh, *args, callback: Callable[..., Any], return [] eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule -def _device_put_eager_rule(mesh, x, *, src, device): - del mesh, src - if device is None: - return x - else: - raise ValueError("device_put with explicit device not allowed within " - f"shard_map-decorated functions, but got device {device}") 
+def _device_put_eager_rule(mesh, *xs, srcs, devices): + del mesh, srcs + for device in devices: + if device is not None: + raise ValueError("device_put with explicit device not allowed within " + f"shard_map-decorated functions, but got device {device}") + return xs eager_rules[dispatch.device_put_p] = _device_put_eager_rule # New primitives for efficient transposition @@ -941,7 +1041,8 @@ def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params): special.__dict__.values(), convolution.__dict__.values(), fft.__dict__.values(), linalg.__dict__.values(), ops.__dict__.values(), ad_util.__dict__.values(), - prng.__dict__.values(), ann.__dict__.values()): + prng.__dict__.values(), ann.__dict__.values(), + random.__dict__.values()): if isinstance(o, core.Primitive): register_standard_check(o) register_standard_rewrite(o) @@ -1059,8 +1160,8 @@ def _io_callback_rule(mesh, *_, result_avals, **__): @register_check(dispatch.device_put_p) -def _device_put_rule(mesh, x, **_): - return x +def _device_put_rule(mesh, *xs, **_): + return list(xs) register_norewrite(dispatch.device_put_p) @@ -1188,6 +1289,9 @@ def _shard_map_batch( for ax in names} for names, d in zip(in_names, in_dims)] spmd_axis_name = trace.spmd_axis_name if spmd_axis_name is not None: + used = {n for names in in_names for ns in names.values() for n in ns} + if set(spmd_axis_name) & used: + raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs") new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped # type: ignore else ns for ns, d in zip(new_in_names, in_dims)] @as_hashable_function(closure=out_names_thunk) @@ -1220,6 +1324,9 @@ def _batch_out_names(spmd_axis_name, dims, out_names): out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] for ax in names} for names, d in zip(out_names, dims)] if spmd_axis_name is not None: + used = {n for names in out_names for ns in names.values() for n in ns} + if set(spmd_axis_name) & used: + raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs") out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped else ns for ns, d in zip(out_names_, dims)] return out_names_ @@ -1272,6 +1379,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names, in_pvals = [t.pval for t in tracers] in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) + all_names = _all_mesh_names(mesh) in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals) f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.main, False) f = _promote_scalar_residuals(f) @@ -1283,7 +1391,7 @@ def known_out_names(): in_fwd, out_fwd, out_knowns, _, jaxpr, _ = aux() _, out_known_names = pe.partition_list(out_knowns, out_names_thunk()) num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - return (*out_known_names, *({0: (*mesh.axis_names,)},) * num_res) + return (*out_known_names, *({0: all_names},) * num_res) known_params = dict(mesh=mesh, in_names=(*known_in_names,), out_names_thunk=known_out_names, check_rep=check_rep, @@ -1298,7 +1406,7 @@ def known_out_names(): res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res) res_names = [known_in_names[f1] if f1 is not None else known_out_names_[f2] if f2 is not None else - {0: (*mesh.axis_names,)} for f1, f2 in zip(in_fwd, out_fwd)] + {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)] unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) 
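# A minimal sketch (hypothetical mesh, axis name, and shapes) of the
# restriction enforced by the new spmd_axis_name checks above: the axis name
# passed to vmap(..., spmd_axis_name=...) must not also appear in the
# shard_map in_specs or out_specs of the function being vmapped.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ('x',))
f = shard_map(lambda a: a * 2, mesh, in_specs=P('x'), out_specs=P('x'))
xs = jnp.zeros((4, jax.device_count()))
# Expected to fail with:
#   ValueError: vmap spmd_axis_name cannot appear in shard_map in_specs
jax.vmap(f, spmd_axis_name='x')(xs)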
const_tracers = map(trace.new_instantiated_const, res) env_tracers = map(trace.full_raise, env) @@ -1310,7 +1418,7 @@ def known_out_names(): out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) for a in out_avals] effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers), # type: ignore[arg-type] + eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers), out_tracers, shard_map_p, unk_params, effs, source_info_util.current()) for t in out_tracers: t.recipe = eqn @@ -1320,6 +1428,7 @@ def known_out_names(): def _shard_map_partial_eval_post_process( trace, tracers, mesh, in_names, out_names_thunk, check_rep, rewrite, auto): del check_rep + all_names = _all_mesh_names(mesh) unk_tracers = [t for t in tracers if not t.is_known()] jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers) # TODO(mattjj): output forwarding optimization @@ -1340,7 +1449,7 @@ def todo(out): const_tracers = map(trace.new_instantiated_const, res_) env_tracers = map(trace.full_raise, env) - staged_in_names = ({0: (*mesh.axis_names,)},) * len(res_) + ({},) * len(env) + staged_in_names = ({0: all_names},) * len(res_) + ({},) * len(env) staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names, out_names=(*out_names_unknown,), check_rep=False, rewrite=rewrite, auto=auto) @@ -1359,7 +1468,7 @@ def todo(out): def out_names_transform(out_names): nonlocal out_names_unknown out_names_unknown, out_names_known = partition_list(out_knowns, out_names) - return (*out_names_known,) + ({0: (*mesh.axis_names,)},) * len(res) + return (*out_names_known,) + ({0: all_names},) * len(res) out_names_unknown: list | None = None return out, (todo, out_names_transform) @@ -1387,13 +1496,21 @@ def fun(*res_and_args): jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(fun, in_avals) return jaxpr + +def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]: + # We use a filtered-down version of unmentioned to avoid defensive-psum over + # more chips than required in the transpose-no-check-rep case. 
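# Illustration (hypothetical values, not part of the patch): with mesh axis
# names ('i', 'j') and a cotangent whose names are {0: ('i',)}, _unmentioned2
# returns ['j'], so the transpose rule below divides and psums that cotangent
# only over 'j' rather than over every axis in the mesh.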
+ name_set = {n for ns in names.values() for n in ns} + return [n for n in _all_mesh_names(mesh) if n not in name_set] + + def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, check_rep, rewrite, auto): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero - else x if rewrite - else mb_div(x, prod(map(mesh.shape.get, _unmentioned(mesh, ns)))) - for ns, x in zip(out_names, out_cts)] + else x if rewrite + else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns)))) + for ns, x in zip(out_names, out_cts)] args = [x if type(x) is not ad.UndefinedPrimal else ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval)) for ns, x in zip(in_names, args)] @@ -1409,8 +1526,9 @@ def fun_trans(out_cts, args): jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts ) out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero - else x if rewrite else jax.lax.psum(x, tuple(_unmentioned(mesh, ns))) - for ns, x in zip(in_names, out)] + else x if rewrite + else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns))) + for ns, x in zip(in_names, out)] return out fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans) @@ -1469,10 +1587,10 @@ def _partial_eval_jaxpr_custom_rule( _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) newvar = core.gensym() - params_known, params_staged = _pe_custom_params( + params_known, params_staged, all_names = _pe_custom_params( unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which, dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged)) - residuals = [newvar(_unshard_aval(mesh, {0: (*mesh.axis_names,)}, var.aval)) + residuals = [newvar(_unshard_aval(mesh, {0: all_names}, var.aval)) for var, w in zip(jaxpr_staged.invars[:num_res], which) if w] eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], eqn.primitive, params_known, jaxpr_known.effects, @@ -1521,9 +1639,10 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, in_fwd, out_fwd, which, params_known, params_staged): # prune inputs to jaxpr_known according to unks_in mesh = params_known['mesh'] + all_names = _all_mesh_names(mesh) in_names_known, _ = partition_list(unks_in, params_known['in_names']) _, out_names_known = partition_list(kept_outs_known, params_known['out_names']) - out_names_known = out_names_known + [{0: (*mesh.axis_names,)}] * sum(which) + out_names_known = out_names_known + [{0: all_names}] * sum(which) new_params_known = dict(params_known, in_names=tuple(in_names_known), out_names=tuple(out_names_known)) @@ -1531,12 +1650,22 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, _, in_names_staged = partition_list(inst_in, params_staged['in_names']) res_names = [in_names_known[f1] if f1 is not None else out_names_known[f2] if f2 is not None else - {0: (*mesh.axis_names,)} for f1, f2 in zip(in_fwd, out_fwd)] + {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)] in_names_staged = res_names + in_names_staged _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names']) new_params_staged = dict(params_staged, in_names=tuple(in_names_staged), out_names=tuple(out_names_staged), check_rep=False) - return new_params_known, new_params_staged + return new_params_known, new_params_staged, all_names + + +# TODO(mattjj): remove this mechanism when we revise mesh scopes +def _all_mesh_names(mesh: Mesh) -> tuple[AxisName, ...]: + stack = 
core.thread_local_state.trace_state.trace_stack.stack + names = {n for frame in stack + if (ns := frame.payload.get('spmd_axis_name', ())) is not None + for n in ns} + return tuple(name for name in mesh.axis_names if name not in names) + # DCE diff --git a/jax/experimental/slab/djax.py b/jax/experimental/slab/djax.py new file mode 100644 index 000000000000..c989ede8663e --- /dev/null +++ b/jax/experimental/slab/djax.py @@ -0,0 +1,187 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import collections +from collections.abc import Callable +from functools import partial +import sys + +import numpy as np +import jax +import jax.numpy as jnp +from jax import lax + +from jax._src import core +from jax._src import util + +import jax.experimental.slab.slab as sl + +map, zip = util.safe_map, util.safe_zip + +def make_djaxpr(f, abstracted_axes, **make_jaxpr_kwargs): + def djaxpr_maker(*args, **kwargs): + with jax._src.config.dynamic_shapes(True): + jaxpr_maker = jax.make_jaxpr( + f, abstracted_axes=abstracted_axes, **make_jaxpr_kwargs) + return jaxpr_maker(*args, **kwargs) + return djaxpr_maker + +@partial(jax.jit, static_argnums=(0,)) +def interp(djaxpr, slab, sizes, args): + views = [] + in_types = [x.aval for x in djaxpr.invars] + _, arg_types = util.split_list(in_types, [len(djaxpr.invars) - len(args)]) + for ty, x in zip(arg_types, args): + if isinstance(ty, core.DShapedArray): + resolved_shape = tuple(sizes.get(d, d) for d in ty.shape) + # TODO(frostig,mattjj): reconstructing slab views seems off? + views.append(sl.SlabView(x, resolved_shape, ty.dtype)) + else: + views.append(x) + slab, outs = eval_djaxpr(djaxpr, slab, *sizes.values(), *views) + return slab, outs + +def _check_axis_size_conflicts(all_axes, sizes): + if len(all_axes) != len(set(all_axes)): + d = collections.defaultdict(list) + for name, sz in zip(all_axes, sizes): + d[name].append(sz) + msg = '; '.join([f'{name}: {" != ".join(map(str, sizes))}' + for name, sizes in d.items() if len(sizes) > 1]) + raise ValueError(f'abstracted axes resolve to conflicting sizes. 
{msg}') + +def djit(f, abstracted_axes, **djit_kwargs): + # TODO(frostig,mattjj): un/flatten f + def f_wrapped(slab, *args): # TODO(frostig,mattjj): kw support + djaxpr = make_djaxpr(f, abstracted_axes, **djit_kwargs)(*args).jaxpr + in_types = [x.aval for x in djaxpr.invars] + _, arg_types = util.split_list(in_types, [len(djaxpr.invars) - len(args)]) + + def upload(slab, ty, x): + if isinstance(ty, core.DShapedArray): + return sl.slab_upload(slab, x) + elif isinstance(ty, core.ShapedArray): + return slab, x + else: + assert False + + slab, views = sl.chain(slab, upload, *zip(arg_types, args)) + + sizes: dict[core.Var, int] = {} + for ty, x in zip(arg_types, args): + for v, d in zip(ty.shape, x.shape): + if isinstance(v, core.Var): + d_ = sizes.setdefault(v, d) + if d_ != d: + raise ValueError( + f'abstract dimension bound to unequal sizes: {d_} != {d}') + + slab, out_views = interp( + djaxpr, slab, sizes, + [v.addr if isinstance(v, sl.SlabView) else v for v in views]) + return slab, tuple(sl.slab_download(slab, v) for v in out_views) + + return f_wrapped + +def eval_djaxpr(jaxpr: core.Jaxpr, slab: sl.Slab, *args: jax.Array | sl.SlabView): + if jaxpr.constvars: raise NotImplementedError + + env: dict[core.Var, jax.Array | sl.SlabView] = {} + + def read(a): + return env[a] if type(a) is core.Var else a.val + + def write(v, val): + env[v] = val + + map(write, jaxpr.invars, args) + for eqn in jaxpr.eqns: + invals = map(read, eqn.invars) + slab, outvals = rules[eqn.primitive](slab, *invals, **eqn.params) + map(write, eqn.outvars, outvals) + return slab, map(read, jaxpr.outvars) + +rules: dict[core.Primitive, Callable] = {} + +def matmul_rule(slab, lhs, rhs, *, dimension_numbers, **_): + slab, out = sl.matmul(slab, lhs, rhs) + return slab, [out] +rules[lax.dot_general_p] = matmul_rule + +def tanh_rule(slab, x, **_): + slab, out = sl.tanh(slab, x) + return slab, [out] +rules[lax.tanh_p] = tanh_rule + +# ------- + +def print_seg(msg): + print() + print(f'-- {msg}') + print() + +def check_djit(slab, f, abstracted_axes, *args): + refs, _ = jax.tree.flatten(f(*args)) + f_djit = djit(f, abstracted_axes=abstracted_axes) + slab, outs = f_djit(slab, *args) + for out, ref in zip(outs, refs): + abs_err = jnp.max(jnp.abs(out - ref)) + rel_err = jnp.max(jnp.abs(out - ref) / jnp.abs(ref)) + msg = f'abs={abs_err}, rel={rel_err}' + assert jnp.allclose(out, ref, atol=1e-4), msg + +def test(slab, xs): + a, b = xs + + def f(a, b): + c = jnp.dot(a, b) + return jnp.tanh(c) + + abstracted_axes = (('m', 'k'), ('k', 'n')) + + print_seg('djaxpr') + djaxpr = make_djaxpr(f, abstracted_axes)(a, b).jaxpr + print(djaxpr) + + print_seg('djax output') + f_djit = djit(f, abstracted_axes=abstracted_axes) + slab, [c] = f_djit(slab, a, b) + print(c) + + print_seg('djax -> jax lowering') + big_jaxpr = jax.make_jaxpr(f_djit)(slab, a, b) + print('\n'.join(str(big_jaxpr).split('\n')[:20])) + print('...') + print('\n'.join(str(big_jaxpr).split('\n')[-20:])) + print(len(str(big_jaxpr).split('\n'))) + + check_djit(slab, f, abstracted_axes, a, b) + +def parse_arr(i, s): + shape = eval(s) + return np.random.RandomState(i).normal(size=shape).astype(np.float32) + +def main(args): + slab_sz = eval(args[0]) + print('slab size', slab_sz) + xs = map(parse_arr, range(len(args[1:])), args[1:]) + assert all(len(x.shape) == 2 for x in xs) + slab = sl.slab_make(slab_sz) + test(slab, xs) + + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/jax/experimental/slab/slab.py b/jax/experimental/slab/slab.py new file mode 100644 index 
000000000000..af7b079eeb7f --- /dev/null +++ b/jax/experimental/slab/slab.py @@ -0,0 +1,365 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from functools import partial, reduce +import sys +import typing +from typing import NamedTuple, Union + +import numpy as np + +import jax +import jax.numpy as jnp + +from jax._src import core +from jax._src import util + +map, zip = util.safe_map, util.safe_zip + +DInt = jax.Array +Address = DInt +XInt = Union[int, DInt] +DShape = tuple[XInt, ...] +SShape = tuple[int, ...] +DType = jnp.dtype + +class Slab(NamedTuple): + data: jax.Array + cursor: Address + +@jax.tree_util.register_pytree_node_class +class SlabView(NamedTuple): + addr: Address + shape: DShape + dtype: DType + + def size(self): + return jnp.prod(jnp.array(self.shape)) + + def ndim(self): + return len(self.shape) + + def tree_flatten(self): + return (self.addr, self.shape), self.dtype + + @classmethod + def tree_unflatten(cls, dtype, xs): + addr, shape = xs + return cls(addr, shape, dtype) + +word_b = 4 +phrase_b = 512 +phrase_w = 128 +tile_aspect = 8 + +def xceil_div(x: XInt, y: XInt) -> XInt: + """ceil(x / y)""" + return (x + y - 1) // y + +def _xadd(x: XInt, y: XInt) -> XInt: + return x + y + +def _xmul(x: XInt, y: XInt) -> XInt: + return x * y + +def xadd(*xs: XInt) -> XInt: + return reduce(_xadd, xs, typing.cast(XInt, 0)) + +def xmul(*xs: XInt) -> XInt: + return reduce(_xmul, xs, typing.cast(XInt, 1)) + +def xsum(xs: Iterable[XInt]) -> XInt: + return xadd(*list(xs)) + +def xprod(xs: Iterable[XInt]) -> XInt: + return xmul(*list(xs)) + +def static_int(x: XInt) -> bool: + return isinstance(core.get_aval(x), core.ConcreteArray) + +def static_shape(s: DShape) -> bool: + return all(map(static_int, s)) + +def assert_static_int(x: XInt) -> int: + if not static_int(x): + raise TypeError(f'{x} is not a static int') + return int(x) + +def assert_static_shape(s: DShape) -> SShape: + if not static_shape(s): + raise TypeError(f'{s} is not a static shape') + return tuple(map(int, s)) + +def tile_shape(shape: DShape, dtype) -> SShape: + # Units: (1, 1, ..., elements, 1) + if len(shape) < 2: + raise NotImplementedError('matrices or bust') + num_leading = len(shape) - 2 + return (1,) * num_leading + (tile_aspect * word_b // dtype.itemsize, + phrase_b // word_b) + +def tile_phrases(shape: DShape, dtype: DType): + # Units: phrases + return xprod(tile_shape(shape, dtype)) * dtype.itemsize // phrase_b + +def slab_make(num_phrases): + return Slab(jnp.zeros((num_phrases, phrase_w), dtype=jnp.uint32), + jnp.array(0, dtype=jnp.int32)) + +def slab_alloc(slab: Slab, shape: DShape, dtype): + if len(shape) < 2: + raise NotImplementedError('matrices or bust') + tiled_shape = map(xceil_div, shape, tile_shape(shape, dtype)) + num_p = xmul(*tiled_shape, tile_phrases(shape, dtype)) + new_slab = Slab(slab.data, slab.cursor + num_p) + slab_val = SlabView(slab.cursor, shape, dtype) + return 
new_slab, slab_val + +def strides(xs): + s = 1 + ss = [] + for x in reversed(xs): + ss.append(s) + s *= x + return tuple(reversed(ss)) + +def slab_slices(view, slice_base_e: DShape, slice_shape_e: SShape): + view_shape_e = tile_shape(view.shape, view.dtype) + # dassert all(s % t == 0 for s, t in zip(slice_base, view_shape_e)) + # dassert all(s % t == 0 for s, t in zip(slice_shape, view_shape_e)) + slice_base_t = [s // t for s, t in zip(slice_base_e, view_shape_e)] + slice_shape_t = [s // t for s, t in zip(slice_shape_e, view_shape_e)] + tiled_shape = map(xceil_div, view.shape, view_shape_e) + tiled_strides = strides(tiled_shape) + tp = tile_phrases(view.shape, view.dtype) + for idx in np.ndindex(*slice_shape_t[:-1]): + linear_idx_t = xsum( + map(xmul, map(xadd, slice_base_t, (*idx, 0)), tiled_strides)) + yield (view.addr + linear_idx_t * tp, slice_shape_t[-1] * tp) + +def reinterpret_cast(x: jax.Array, shape: SShape, dtype: DType): + x_bytes = x.size * x.dtype.itemsize + if -1 in shape: + assert x_bytes % xprod(s for s in shape if s != -1) * dtype.itemsize == 0 + else: + assert x_bytes == xprod(shape) * dtype.itemsize, (x.shape, x.dtype, shape, dtype) + if x.dtype.itemsize != dtype.itemsize: + # reshape(x, -1) in conversion below becomes reshape(-1, a, b) for some a,b + raise NotImplementedError('todo') + return jax.lax.bitcast_convert_type(x.reshape(-1), dtype).reshape(shape) + +def slab_read(slab, view, slice_base: DShape, slice_shape: SShape): + view_tile_shape = tile_shape(view.shape, view.dtype) + tiled_shape = assert_static_shape( + tuple(map(xceil_div, slice_shape, view_tile_shape))) + slices = [ + jax.lax.dynamic_slice_in_dim(slab.data, addr, phrases) + for addr, phrases in slab_slices(view, slice_base, slice_shape)] + slice_mem = jnp.stack(slices, axis=0) + return reinterpret_cast( + slice_mem, (*tiled_shape, *view_tile_shape), view.dtype + ).swapaxes(-2, -3).reshape(slice_shape) + +# TODO: just take vjp of slab_read +def slab_write(slab, view, slice_base: DShape, inval: jax.Array): + slice_shape = inval.shape + view_tile_shape = tile_shape(view.shape, view.dtype) + tiled_shape = map(xceil_div, inval.shape, view_tile_shape) + inval_linearized = inval.reshape( + *tiled_shape[:-1], view_tile_shape[-2], tiled_shape[-1], view_tile_shape[-1] + ).swapaxes(-2, -3) + slice_mem = reinterpret_cast(inval_linearized, (-1, phrase_w), + jnp.dtype('uint32')) + slice_addr = 0 + new_slab = slab.data + for slab_addr, slice_sz_p in slab_slices(view, slice_base, slice_shape): + s = jax.lax.dynamic_slice_in_dim(slice_mem, slice_addr, slice_sz_p) + slice_addr += slice_sz_p + new_slab = jax.lax.dynamic_update_slice_in_dim( + new_slab, s, slab_addr, axis=0) + return Slab(new_slab, slab.cursor) + +def elementwise(f, slab: Slab, xs: Sequence[SlabView], out: SlabView): + if len(xs) == 0: + raise TypeError('missing input arguments') + x = xs[0] + for y in xs[1:]: + if x.shape != y.shape: + raise ValueError(f'elementwise shapes mismatch: {x.shape} != {y.shape}') + if x.dtype != y.dtype: + raise ValueError(f'elementwise dtypes mismatch: {x.dtype} != {y.dtype}') + if x.shape != out.shape: + raise ValueError( + f'elementwise input/output shape mismatch: {x.shape} != {out.shape}') + + tiled_shape = map(xceil_div, x.shape, tile_shape(x.shape, x.dtype)) + x_sz_p = xprod(tiled_shape) * tile_phrases(x.shape, x.dtype) + compute_tile_p = 16 + num_whole_blocks = x_sz_p // compute_tile_p + + def f_u32(*zs): + a = zs[0] + return reinterpret_cast( + f(*[reinterpret_cast(z, a.shape, x.dtype) for z in zs]), + a.shape, 
jnp.dtype('uint32')) + + def body(i_b, mem): + i_p = i_b * compute_tile_p + slices = [ + jax.lax.dynamic_slice_in_dim(mem, z.addr + i_p, compute_tile_p) + for z in xs] + out_slice = f_u32(*slices) + return jax.lax.dynamic_update_slice_in_dim( + mem, out_slice, out.addr + i_p, axis=0) + mem = jax.lax.fori_loop(0, num_whole_blocks, body, slab.data) + + epi_start_p = num_whole_blocks * compute_tile_p + epi_size_p = x_sz_p - epi_start_p + slices = [ + jax.lax.dynamic_slice_in_dim(mem, z.addr + epi_start_p, compute_tile_p) + for z in xs] + out_slice = f_u32(*slices) + return Slab(masked_store(mem, out.addr + epi_start_p, out_slice, epi_size_p), + slab.cursor) + +def masked_store(mem, addr, update, num_p): + update_p = update.shape[0] + prev_val = jax.lax.dynamic_slice_in_dim(mem, addr, update_p) + new_val = jnp.where(jnp.arange(update_p)[:, None] < num_p, update, prev_val) + return jax.lax.dynamic_update_slice_in_dim(mem, new_val, addr, axis=0) + +def _matmul(slab: Slab, ins: Sequence[SlabView], out: SlabView): + lhs, rhs = ins + dtype = lhs.dtype + n, k, m = (*lhs.shape, rhs.shape[1]) + # todo: shape + dtype check + # dassert shapes are tile aligned + tile_n, tile_k, tile_m = 128, 128, 128 + n_tiles = n // tile_n + k_tiles = k // tile_k + m_tiles = m // tile_m + + mem = slab + def loop_n(ni, mem): + def loop_m(mi, mem): + acc = jnp.zeros((tile_n, tile_m), dtype=dtype) + def loop_k(ki, acc): + lhs_tile = slab_read(mem, lhs, (ni * tile_n, ki * tile_k), (tile_n, tile_k)) + rhs_tile = slab_read(mem, rhs, (ki * tile_k, mi * tile_m), (tile_k, tile_m)) + acc += lhs_tile @ rhs_tile + return acc + acc = jax.lax.fori_loop(0, k_tiles, loop_k, acc) + return slab_write(mem, out, (ni * tile_n, mi * tile_m), acc) + return jax.lax.fori_loop(0, m_tiles, loop_m, mem) + mem = jax.lax.fori_loop(0, n_tiles, loop_n, mem) + return mem + +def make_allocating_op(op, type_rule): + def made_op(slab, *xs: SlabView): + out_shape, out_dtype = type_rule(*xs) + slab, out = slab_alloc(slab, out_shape, out_dtype) + slab = op(slab, xs, out) + return slab, out + return made_op + +add = make_allocating_op(partial(elementwise, jax.lax.add), + lambda x, *_: (x.shape, x.dtype)) +mul = make_allocating_op(partial(elementwise, jax.lax.mul), + lambda x, *_: (x.shape, x.dtype)) +tanh = make_allocating_op(partial(elementwise, jax.lax.tanh), + lambda x, *_: (x.shape, x.dtype)) +matmul = make_allocating_op(_matmul, + lambda a, b: ((a.shape[0], b.shape[1]), a.dtype)) + +def parse_arr(i, s): + shape = eval(s) + return np.random.RandomState(i).normal(size=shape).astype(np.float32) + +def print_seg(msg): + print() + print(f'-- {msg}') + print() + +def make_jaxpr_slab_write(slab, view, inval): + return jax.make_jaxpr( + lambda slab, x: slab_write(slab, view, (0, 0), x))(slab, inval) + +def make_jaxpr_slab_read(slab, view, outval_shape): + return jax.make_jaxpr( + lambda slab: slab_read(slab, view, (0, 0), outval_shape))(slab) + +def slab_download(slab, v): + if not static_shape(v.shape): raise Exception + return slab_read(slab, v, (0,) * v.ndim(), v.shape) + +def slab_upload(slab, x): + slab, xv = slab_alloc(slab, x.shape, x.dtype) + slab = slab_write(slab, xv, (0,) * x.ndim, x) + return slab, xv + +def chain(slab, fs, *argss, unary=False): + if callable(fs): + fs = [fs] * len(argss) + outss = [] + for f, args in zip(fs, argss): + if unary: + slab, outs = f(slab, args) + else: + slab, outs = f(slab, *args) + outss.append(outs) + return slab, outss + +def test_binop(op, ref_op, slab, x, y): + z = ref_op(x, y) + slab, xv = slab_upload(slab, x) + 
slab, yv = slab_upload(slab, y) + slab, zv = op(slab, xv, yv) + assert jnp.allclose(slab_download(slab, xv), x, atol=1e-4) + assert jnp.allclose(slab_download(slab, yv), y, atol=1e-4) + assert jnp.allclose(slab_download(slab, zv), z, atol=1e-4) + +def main(args): + xs = map(parse_arr, range(len(args)), args) + assert all(len(x.shape) == 2 for x in xs) + + slab = slab_make(1024) + + x, y, *_ = xs + test_binop(add, jax.lax.add, slab, x, x) + test_binop(mul, jax.lax.mul, slab, x, x) + test_binop(matmul, lambda a, b: a @ b, slab, x, y) + + def put(slab, x): + slab, v = slab_upload(slab, x) + print_seg('slab_read result') + print(slab_download(slab, v)) + return slab, v + + slab, vals = chain(slab, put, *xs, unary=True) + + if len(vals) >= 2: + x, y, *_ = vals + slab, z = mul(slab, x, x) + print_seg('mul') + print(slab_download(slab, z)) + slab, w = add(slab, x, z) + print_seg('add') + print(slab_download(slab, w)) + + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/jax/experimental/sparse/ad.py b/jax/experimental/sparse/ad.py index c789baf9d15e..2c235c9320d5 100644 --- a/jax/experimental/sparse/ad.py +++ b/jax/experimental/sparse/ad.py @@ -14,9 +14,9 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import itertools -from typing import Any, Callable, Union +from typing import Any import jax from jax._src import core @@ -81,7 +81,7 @@ def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0, taking the gradient with respect to a :class:`jax.experimental.sparse` array, the gradient is computed in the subspace defined by the array's sparsity pattern. - Example: + Examples: >>> from jax.experimental import sparse >>> X = sparse.BCOO.fromdense(jnp.arange(6.)) @@ -109,7 +109,7 @@ def grad(fun: Callable, argnums: int | Sequence[int] = 0, the gradient with respect to a :class:`jax.experimental.sparse` array, the gradient is computed in the subspace defined by the array's sparsity pattern. - Example: + Examples: >>> from jax.experimental import sparse >>> X = sparse.BCOO.fromdense(jnp.arange(6.)) diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index b6572b01cadf..4cbe52383751 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -385,7 +385,7 @@ def bcoo_extract(sparr: BCOO, arr: ArrayLike, *, assume_unique: bool | None = No extracted : a BCOO array with the same sparsity pattern as self. """ if not isinstance(sparr, BCOO): - raise ValueError(f"First argument to bcoo_extract should be a BCOO array. Got {type(sparr)=}") + raise TypeError(f"First argument to bcoo_extract should be a BCOO array. Got {type(sparr)=}") a = jnp.asarray(arr) if a.shape != sparr.shape: raise ValueError(f"shape mismatch: {sparr.shape=} {a.shape=}") @@ -514,7 +514,7 @@ def bcoo_transpose(mat: BCOO, *, permutation: Sequence[int]) -> BCOO: batch, sparse, and dense dimensions. The i’th axis of the returned array corresponds to the axis numbered permutation[i] of ``mat``. Transpose permutation currently does not support permuting batch axes with non-batch - axes nor permutating dense axes with non-dense axes. + axes nor permuting dense axes with non-dense axes. Returns: A BCOO-format array. 
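A small usage sketch of the APIs touched in the hunks above (illustrative only; array values and shapes are made up): ``bcoo_transpose`` permutes the logical axes of a BCOO array, and with this change passing a non-BCOO first argument to ``bcoo_extract`` raises ``TypeError`` rather than ``ValueError``.

    import jax.numpy as jnp
    from jax.experimental import sparse

    M = sparse.BCOO.fromdense(jnp.arange(6.0).reshape(2, 3))
    Mt = sparse.bcoo_transpose(M, permutation=[1, 0])  # dense shape (3, 2)

    try:
        # A dense ndarray is not a BCOO array, so the first argument is rejected.
        sparse.bcoo_extract(jnp.zeros((2, 3)), jnp.ones((2, 3)))
    except TypeError:  # previously raised ValueError
        pass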
@@ -641,7 +641,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: preferred_element_type=preferred_element_type, lhs_spinfo=lhs._info) elif isinstance(rhs, BCOO): - return _bcoo_rdot_general(lhs, rhs.data, rhs.indices, dimension_numbers=dimension_numbers, # type: ignore[arg-type] + return _bcoo_rdot_general(lhs, rhs.data, rhs.indices, dimension_numbers=dimension_numbers, preferred_element_type=preferred_element_type, rhs_spinfo=rhs._info) else: @@ -749,7 +749,7 @@ def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_num n_batch, n_sparse, _, _ = _validate_bcoo(lhs_data, lhs_indices, lhs_spinfo.shape) if lhs_batch and max(lhs_batch) >= n_batch: raise NotImplementedError( - "bcoo_dot_general batch dimensions must be among the batch dimensions in the sparse representtaion.\n" + "bcoo_dot_general batch dimensions must be among the batch dimensions in the sparse representation.\n" f"got {lhs_batch=}, {n_batch=}") # TODO: support contraction of dense dimensions? @@ -862,7 +862,7 @@ def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_num rhs_kept = remaining(range(rhs_ndim), rhs_contract, rhs_batch) ans_batch, ans_lhs, ans_rhs = map(list, ranges_like(lhs_batch, lhs_kept, rhs_kept)) if ad.is_undefined_primal(lhs_data): - dims: DotDimensionNumbers = ((ans_rhs, rhs_kept), (ans_batch, rhs_batch)) # type: ignore[assignment] + dims: DotDimensionNumbers = ((ans_rhs, rhs_kept), (ans_batch, rhs_batch)) lhs_contract_sorted_by_rhs = list(np.take(lhs_contract, np.argsort(rhs_contract))) permutation = list(lhs_batch) + lhs_kept + lhs_contract_sorted_by_rhs out_axes = list(np.argsort(permutation)) @@ -895,7 +895,7 @@ def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_num result = _bcoo_extract(lhs_indices, out_dense) return result, lhs_indices, rhs else: - dims = ((lhs_kept, ans_lhs), (lhs_batch, ans_batch)) # type: ignore[assignment] + dims = ((lhs_kept, ans_lhs), (lhs_batch, ans_batch)) rhs_contract_sorted_by_lhs = list(np.take(rhs_contract, np.argsort(lhs_contract))) out_axes = list(np.argsort(list(rhs_batch) + rhs_contract_sorted_by_lhs + rhs_kept)) result = _bcoo_dot_general(lhs_data, lhs_indices, ct, lhs_spinfo=lhs_spinfo, @@ -1224,15 +1224,24 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic out_n_batch = lhs.n_batch + rhs.n_batch - len(lhs_batch) out_nse = min(out_nse, math.prod(out_aval.shape[out_n_batch:])) + lhs_batch_shape = np.broadcast_shapes( + tuple(lhs_data.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch), + tuple(lhs_indices.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch), + ) + rhs_batch_shape = np.broadcast_shapes( + tuple(rhs_data.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch), + tuple(rhs_indices.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch), + ) + data_shape = ( *(lhs_shape[dim] for dim in lhs_batch), - *(lhs_data.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch), - *(rhs_data.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch), + *lhs_batch_shape, + *rhs_batch_shape, out_nse) indices_shape = ( *(lhs_shape[dim] for dim in lhs_batch), - *(lhs_indices.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch), - *(rhs_indices.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch), + *lhs_batch_shape, + *rhs_batch_shape, out_nse, lhs.n_sparse + rhs.n_sparse - 2 * len(lhs_contracting)) data_aval = core.ShapedArray(data_shape, 
out_aval.dtype) @@ -1951,7 +1960,7 @@ def bcoo_slice(mat: BCOO, *, start_indices: Sequence[int], limit_indices: Sequen out: BCOO array containing the slice. """ if not isinstance(mat, BCOO): - raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}") + raise TypeError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}") start_indices = [operator.index(i) for i in start_indices] limit_indices = [operator.index(i) for i in limit_indices] if strides is not None: @@ -2030,7 +2039,7 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq jax.eval_shape(partial(lax.dynamic_slice, slice_sizes=slice_sizes), jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices) if not isinstance(mat, BCOO): - raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}") + raise TypeError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}") start_indices = tuple(jnp.asarray(i) for i in start_indices) assert all(jnp.issubdtype(i.dtype, np.integer) for i in start_indices) assert all(i.shape == () for i in start_indices) @@ -2379,7 +2388,7 @@ def _convert_to_1d_for_conv(mat, index_dtype): # zero-out data at OOB indices, otherwise strange things happen. data = jnp.where(lax.squeeze(indices, (1,)) < mat.shape[-1], data, 0) else: - raise ValueError(f"bcoo_conv_general_dilated: input of type {type(mat)} not recognized.") + raise TypeError(f"bcoo_conv_general_dilated: input of type {type(mat)} not recognized.") return BCOO((data, indices), shape=mat.shape[2:]) def _bcoo_conv_1d(lhs: BCOO, rhs: BCOO, padding: Sequence[int]) -> BCOO: diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index c2990d3fed57..b0ac1fa5d380 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import Union, Callable +from collections.abc import Callable import functools import jax diff --git a/jax/experimental/sparse/nm.py b/jax/experimental/sparse/nm.py new file mode 100644 index 000000000000..6c827325befc --- /dev/null +++ b/jax/experimental/sparse/nm.py @@ -0,0 +1,244 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""N:M-sparsity associated primitives.""" + +from jax import core +from jax._src import dispatch +from jax._src.lax.lax import DotDimensionNumbers +from jax._src.lib import gpu_sparse +from jax._src.lib.mlir.dialects import mhlo +from jax._src.typing import Array, DTypeLike +from jax.interpreters import mlir +import jax.numpy as jnp +import numpy as np + +# -------------------------------------------------------------------- +# nm_spmm + +nm_spmm_p = core.Primitive("sparse_dense_matmul") + +_supported_input_types = (jnp.int8, jnp.int16, jnp.float16, jnp.bfloat16) +_supported_output_types = (jnp.bfloat16, jnp.float32) + + +def nm_spmm( + lhs: Array, + rhs: Array, + metadata: Array, + dimension_numbers: DotDimensionNumbers = (((1,), (0,)), (tuple(), tuple())), + sparse_operand_idx: int = 0, + output_dtype: DTypeLike = jnp.bfloat16, +) -> Array: + """Dot operation where one of the operands has N:M sparsity. + + Args: + lhs: An ndarray (first dot operand). + rhs: An ndarray (second dot operand). + metadata: An ndarray with structured sparsity metadata for the contracting + dimension. For 2:4 sparsity it should contain (N=2) two-bit index values + for each (M=4) element group. + dimension_numbers: a tuple of tuples of the form `((lhs_contracting_dims, + rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`. + sparse_operand_idx: index of the sparse operand (0 or 1). + output_dtype: result type. + + Returns: + An ndarray dense array containing the result. + """ + return nm_spmm_p.bind( + lhs, + rhs, + metadata, + dimension_numbers=dimension_numbers, + sparse_operand_idx=sparse_operand_idx, + output_dtype=output_dtype, + ) + + +def _calc_groups_per_element(n, m): + group_bits = n * (m.bit_length() - 1) # 4 bits per group for 2:4 + return 16 // group_bits + + +def _validate_dnums(rank, contract, batch, name): + non_contract = tuple(sorted(set(range(rank)) - set(contract + batch))) + if sorted(non_contract + contract + batch) != list(range(rank)): + raise TypeError(f"Incorrect dimension numbers for {name}") + return non_contract + + +def _validate_metadata(lhs, rhs, metadata, dimension_numbers, index, n=2, m=4): + assert index in (0, 1) + size_factor = n * _calc_groups_per_element(n, m) + + sparse = [lhs, rhs][index] + sparse_contract = dimension_numbers[0][index] + if metadata.dtype != np.uint16: + raise TypeError(f"Metadata must be uint16, got {metadata.dtype}") + if sparse_contract[0] != sparse.ndim - 1: + raise TypeError("Contracting dimension must be the minor one") + if metadata.shape[:-1] != sparse.shape[:-1]: + raise TypeError( + "Metadata shape must match the operand shape (except for the" + " contracting dimension)" + ) + if metadata.shape[-1] * size_factor != sparse.shape[-1]: + raise TypeError( + f"Metadata must be exactly {size_factor} times less than the" + f" contracting dimension for {n}:{m} structured sparsity (expected" + f" {sparse.shape[-1] // size_factor}, got {metadata.shape[-1]})" + ) + if sparse.shape[-1] % size_factor != 0: + raise NotImplementedError("Metadata with padding is not supported") + + dense = [lhs, rhs][1 - index] + dense_contract = dimension_numbers[0][1 - index] + a, b = sparse.shape[sparse_contract[0]], dense.shape[dense_contract[0]] + if n * b != m * a: + raise TypeError( + f"Contracting dimension sizes should have {n}:{m} ratio, got {a}:{b}" + ) + + +def _infer_result_shape(lhs, rhs, dimension_numbers): + ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers + if len(lhs_contract) != 1 or len(rhs_contract) != 1: + raise 
TypeError("Only single contracting dimension is supported") + lhs_dims = _validate_dnums(lhs.ndim, lhs_contract, lhs_batch, "lhs") + rhs_dims = _validate_dnums(rhs.ndim, rhs_contract, rhs_batch, "rhs") + if len(lhs_dims) != 1 or len(rhs_dims) != 1: + raise TypeError("Only single non-contracting dimension is supported") + batch = [lhs.shape[i] for i in lhs_batch] + if batch != [rhs.shape[i] for i in rhs_batch]: + raise TypeError("Batch dimension sizes do not match") + return tuple(batch + [lhs.shape[lhs_dims[0]], rhs.shape[rhs_dims[0]]]) + + +def _nm_spmm_default_lowering(*_args, **_kwargs): + raise NotImplementedError("Sparse N:M matmul is only implemented on GPU") + + +def _nm_spmm_gpu_lowering( + ctx, + lhs, + rhs, + metadata, + *, + dimension_numbers, + sparse_operand_idx, + output_dtype, +): + assert sparse_operand_idx in (0, 1) + sparsity_descriptor = mhlo.SparsityDescriptor.get( + dimension=dimension_numbers[0][sparse_operand_idx][0], n=2, m=4 + ) + dot_dnums = mhlo.DotDimensionNumbers.get( + lhs_batching_dimensions=dimension_numbers[1][sparse_operand_idx], + rhs_batching_dimensions=dimension_numbers[1][1 - sparse_operand_idx], + lhs_contracting_dimensions=dimension_numbers[0][sparse_operand_idx], + rhs_contracting_dimensions=dimension_numbers[0][1 - sparse_operand_idx], + ) + dot_type = ctx.avals_out[0] + key = ["lhs_sparsity", "rhs_sparsity"][sparse_operand_idx] + kwargs = {key: sparsity_descriptor} + op = mhlo.SparseDotOp( + mlir.aval_to_ir_type(dot_type), lhs, rhs, [metadata], dot_dnums, **kwargs + ) + return op.results + + +@nm_spmm_p.def_abstract_eval +def _nm_spmm_abstract_eval( + lhs, rhs, metadata, *, dimension_numbers, sparse_operand_idx, output_dtype +): + if lhs.dtype not in _supported_input_types: + raise TypeError(f"Unsupported lhs input type: {lhs.dtype}") + if rhs.dtype not in _supported_input_types: + raise TypeError(f"Unsupported rhs input type: {rhs.dtype}") + if output_dtype not in _supported_output_types: + raise TypeError(f"Unsupported output type: {output_dtype}") + + res_shape = _infer_result_shape(lhs, rhs, dimension_numbers) + _validate_metadata(lhs, rhs, metadata, dimension_numbers, sparse_operand_idx) + return core.ShapedArray(res_shape, output_dtype) + + +mlir.register_lowering(nm_spmm_p, _nm_spmm_default_lowering) +dispatch.simple_impl(nm_spmm_p) + +if gpu_sparse.cuda_is_supported: + mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="cuda") + +if gpu_sparse.rocm_is_supported: + mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="rocm") + +# -------------------------------------------------------------------- +# nm_pack + +nm_pack_p = core.Primitive("sparse_pack_nm") + + +def nm_pack(mask: Array, n=2, m=4) -> Array: + """Generate metadata tensor for an N:M mask. + + Args: + mask: Predicates for the input tensor, where the elements are grouped in the + minor dimension. In each group of size M there should be exactly N true + values, which mark the data elements to keep. + n: Number of non-zero elements in a group. + m: Group size. + + Returns: + An ndarray containing only the masked input elements. 
+ """ + return nm_pack_p.bind(mask, n=n, m=m) + + +def _compress(data, n, m, k): + result = [] + expected = n * (k // m) + for i in range(0, len(data), k): + index = tuple(jnp.nonzero(data[i : i + k], size=expected)[0] % m) + value = sum(j * pow(m, i) for i, j in enumerate(index)) + result.append(value) + return jnp.array(result, dtype=np.uint16) + + +@nm_pack_p.def_impl +def _nm_pack_impl(mask, *, n, m): + batch_size = m * _calc_groups_per_element(n, m) + return jnp.apply_along_axis( + lambda x: _compress(x, n, m, batch_size), -1, mask + ) + + +@nm_pack_p.def_abstract_eval +def _nm_pack_abstract_eval(mask, *, n, m): + size_factor = m * _calc_groups_per_element(n, m) + if mask.dtype != bool: + raise TypeError(f"Mask should be bool, got {mask.dtype}") + if mask.shape[-1] % size_factor != 0: + raise TypeError( + f"Inner dimension size should be divisible by {size_factor}, got" + f" {mask.shape}" + ) + res_shape = list(mask.shape) + res_shape[-1] //= size_factor + return core.ShapedArray(res_shape, np.uint16) + + +_nm_pack_lowering = mlir.lower_fun(_nm_pack_impl, multiple_results=False) +mlir.register_lowering(nm_pack_p, _nm_pack_lowering) +dispatch.simple_impl(nm_pack_p) diff --git a/jax/experimental/sparse/random.py b/jax/experimental/sparse/random.py index b22d466dcf5a..f90c2572d282 100644 --- a/jax/experimental/sparse/random.py +++ b/jax/experimental/sparse/random.py @@ -29,7 +29,7 @@ def random_bcoo(key, shape, *, dtype=jnp.float_, indices_dtype=None, """Generate a random BCOO matrix. Args: - key : random.PRNGKey to be passed to ``generator`` function. + key : PRNG key to be passed to ``generator`` function. shape : tuple specifying the shape of the array to be generated. dtype : dtype of the array to be generated. indices_dtype: dtype of the BCOO indices. diff --git a/jax/experimental/sparse/test_util.py b/jax/experimental/sparse/test_util.py index 52b704f5f52c..365c436521b8 100644 --- a/jax/experimental/sparse/test_util.py +++ b/jax/experimental/sparse/test_util.py @@ -15,12 +15,11 @@ from __future__ import annotations -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence import functools import itertools import math -from typing import Any, Callable, Union -from typing import NamedTuple +from typing import Any, NamedTuple import jax from jax import lax @@ -41,10 +40,6 @@ np.complex128: 1e-10, } -GPU_LOWERING_ENABLED = gpu_sparse and ( - gpu_sparse.cuda_is_supported or gpu_sparse.rocm_is_supported -) - def is_sparse(x): return isinstance(x, sparse.JAXSparse) diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 55460118de81..efdf1888f436 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -27,9 +27,9 @@ >>> from jax import random >>> from jax.experimental.sparse import BCOO, sparsify ->>> mat = random.uniform(random.PRNGKey(1701), (5, 5)) +>>> mat = random.uniform(random.key(1701), (5, 5)) >>> mat = mat.at[mat < 0.5].set(0) ->>> vec = random.uniform(random.PRNGKey(42), (5,)) +>>> vec = random.uniform(random.key(42), (5,)) >>> def f(mat, vec): ... 
return -(jnp.sin(mat) @ vec) @@ -47,9 +47,9 @@ -0.15574613], dtype=float32) """ -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools -from typing import Any, Callable, NamedTuple, Optional +from typing import Any, NamedTuple import numpy as np @@ -144,7 +144,7 @@ def __init__(self, bufs=()): self._buffers = list(bufs) def _push(self, arr: Array) -> int: - self._buffers.append(jnp.asarray(arr)) # type: ignore + self._buffers.append(jnp.asarray(arr)) return len(self._buffers) - 1 def data(self, spvalue: SparsifyValue) -> Array: @@ -772,7 +772,8 @@ def _while_sparse(spenv, *spvalues, cond_jaxpr, cond_nconsts, body_jaxpr, body_n def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, - resource_env, donated_invars, name, keep_unused, inline): + in_layouts, out_layouts, resource_env, donated_invars, name, + keep_unused, inline): if any(donated_invars): raise NotImplementedError("sparse xla_call with donated_invars") @@ -790,12 +791,20 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, sharding_impls.UNSPECIFIED for _ in range(len(sp_call_jaxpr.out_avals) - len(out_shardings)) ) + in_layouts = in_layouts + tuple( + None for _ in range(len(args_flat) - len(in_layouts)) + ) + out_layouts = out_layouts + tuple( + None for _ in range(len(sp_call_jaxpr.out_avals) - len(out_layouts)) + ) out_flat = pjit.pjit_p.bind( *args_flat, jaxpr=sp_call_jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, + in_layouts=in_layouts, + out_layouts=out_layouts, resource_env=resource_env, donated_invars=donated_invars, name=name, @@ -839,15 +848,14 @@ def _scan_sparse(spenv, *spvalues, jaxpr, num_consts, num_carry, **params): sparse_rules_bcoo[lax.scan_p] = _scan_sparse -def _cond_sparse(spenv, pred, *operands, branches, linear, **params): +def _cond_sparse(spenv, pred, *operands, branches, **params): sp_branches, treedefs = zip(*(_sparsify_jaxpr(spenv, jaxpr, *operands) for jaxpr in branches)) _check_tree_and_avals("sparsified true_fun and false_fun output", treedefs[0], sp_branches[0].out_avals, treedefs[1], sp_branches[1].out_avals) - sp_linear = tuple(_duplicate_for_sparse_spvalues(operands, linear)) args, _ = tree_flatten(spvalues_to_arrays(spenv, (pred, *operands))) - out_flat = lax.cond_p.bind(*args, branches=sp_branches, linear=sp_linear, **params) + out_flat = lax.cond_p.bind(*args, branches=sp_branches, **params) out = tree_unflatten(treedefs[0], out_flat) return arrays_to_spvalues(spenv, out) diff --git a/jax/export.py b/jax/export.py new file mode 100644 index 000000000000..13186f886f43 --- /dev/null +++ b/jax/export.py @@ -0,0 +1,36 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
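The new jax/export.py module that follows only re-exports the public export API; for reference, a typical round trip looks roughly like the sketch below (the function, shapes, and the serialize/call methods on Exported are assumptions here, not shown in this hunk):

import jax
import jax.numpy as jnp
from jax import export

def f(x):
  return jnp.sin(x) * 2.0

# Trace and lower for a concrete input spec, producing an Exported object.
exp = export.export(jax.jit(f))(jax.ShapeDtypeStruct((8,), jnp.float32))

# Serialize to bytes (e.g. to ship to another process), rehydrate, and call.
blob = exp.serialize()
rehydrated = export.deserialize(blob)
print(rehydrated.call(jnp.arange(8, dtype=jnp.float32)))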
+__all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize", + "maximum_supported_calling_convention_version", + "minimum_supported_calling_convention_version", + "default_export_platform", + "SymbolicScope", "is_symbolic_dim", + "symbolic_shape", "symbolic_args_specs"] + +from jax._src.export._export import ( + DisabledSafetyCheck, + Exported, + export, + deserialize, + maximum_supported_calling_convention_version, + minimum_supported_calling_convention_version, + default_export_platform) + +from jax._src.export import shape_poly_decision # Import only to set the decision procedure +del shape_poly_decision +from jax._src.export.shape_poly import ( + SymbolicScope, + is_symbolic_dim, + symbolic_shape, + symbolic_args_specs) diff --git a/jax/extend/BUILD b/jax/extend/BUILD index 7b367d3a2599..229b6cd6ec9d 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -14,6 +14,7 @@ load( "//jaxlib:jax.bzl", + "py_library_providing_imports_info", "pytype_strict_library", ) @@ -26,17 +27,25 @@ pytype_strict_library( name = "extend", srcs = ["__init__.py"], deps = [ + ":backend", ":core", + ":ffi", ":linear_util", ":random", ":source_info_util", ], ) -pytype_strict_library( +py_library_providing_imports_info( name = "core", - srcs = ["core.py"], - deps = ["//jax:abstract_arrays"], + srcs = glob(["core/**/*.py"]), + lib_rule = pytype_strict_library, + deps = [ + "//jax", + "//jax:abstract_arrays", + "//jax:ad_util", + "//jax:core", + ], ) pytype_strict_library( @@ -45,6 +54,12 @@ pytype_strict_library( deps = ["//jax:core"], ) +pytype_strict_library( + name = "backend", + srcs = ["backend.py"], + deps = ["//jax"], +) + pytype_strict_library( name = "random", srcs = ["random.py"], @@ -56,3 +71,9 @@ pytype_strict_library( srcs = ["source_info_util.py"], deps = ["//jax:source_info_util"], ) + +pytype_strict_library( + name = "ffi", + srcs = ["ffi.py"], + deps = ["//jax"], +) diff --git a/jax/extend/__init__.py b/jax/extend/__init__.py index 77c81488c5cb..e8ef32935cbf 100644 --- a/jax/extend/__init__.py +++ b/jax/extend/__init__.py @@ -29,7 +29,9 @@ """ from jax.extend import ( + backend as backend, core as core, + ffi as ffi, linear_util as linear_util, random as random, source_info_util as source_info_util, diff --git a/jax/extend/core.py b/jax/extend/backend.py similarity index 86% rename from jax/extend/core.py rename to jax/extend/backend.py index 99beaaf05191..7aa2c8a06ba8 100644 --- a/jax/extend/core.py +++ b/jax/extend/backend.py @@ -1,4 +1,4 @@ -# Copyright 2023 The JAX Authors. +# Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +15,6 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/google/jax/issues/7570 -from jax._src.abstract_arrays import ( - array_types as array_types +from jax._src.api import ( + clear_backends as clear_backends, ) diff --git a/jax/extend/core/__init__.py b/jax/extend/core/__init__.py new file mode 100644 index 000000000000..2732b1984c1d --- /dev/null +++ b/jax/extend/core/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 + +from jax._src.abstract_arrays import ( + array_types as array_types +) + +from jax._src.core import ( + ClosedJaxpr as ClosedJaxpr, + Jaxpr as Jaxpr, + JaxprEqn as JaxprEqn, + jaxpr_as_fun as jaxpr_as_fun, + Literal as Literal, + Primitive as Primitive, + Token as Token, + Var as Var, +) + +from . import primitives as primitives diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py new file mode 100644 index 000000000000..9f269cbc6287 --- /dev/null +++ b/jax/extend/core/primitives.py @@ -0,0 +1,229 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 + +from jax._src.ad_util import stop_gradient_p as stop_gradient_p + +from jax._src.core import ( + call_p as call_p, + closed_call_p as closed_call_p +) + +from jax._src.custom_derivatives import ( + custom_jvp_call_p as custom_jvp_call_p, + custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p, + custom_vjp_call_p as custom_vjp_call_p, + custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p, +) + +from jax._src.dispatch import device_put_p as device_put_p + +from jax._src.interpreters.ad import ( + add_jaxvals_p as add_jaxvals_p, + custom_lin_p as custom_lin_p, + zeros_like_p as zeros_like_p, +) + +from jax._src.interpreters.pxla import xla_pmap_p as xla_pmap_p + +from jax._src.lax.lax import ( + abs_p as abs_p, + acos_p as acos_p, + acosh_p as acosh_p, + add_p as add_p, + after_all_p as after_all_p, + and_p as and_p, + argmax_p as argmax_p, + argmin_p as argmin_p, + asin_p as asin_p, + asinh_p as asinh_p, + atan_p as atan_p, + atan2_p as atan2_p, + atanh_p as atanh_p, + bitcast_convert_type_p as bitcast_convert_type_p, + broadcast_in_dim_p as broadcast_in_dim_p, + cbrt_p as cbrt_p, + ceil_p as ceil_p, + clamp_p as clamp_p, + clz_p as clz_p, + complex_p as complex_p, + concatenate_p as concatenate_p, + conj_p as conj_p, + convert_element_type_p as convert_element_type_p, + copy_p as copy_p, + cos_p as cos_p, + cosh_p as cosh_p, + create_token_p as create_token_p, + div_p as div_p, + dot_general_p as dot_general_p, + eq_p as eq_p, + eq_to_p as eq_to_p, + exp_p as exp_p, + exp2_p as exp2_p, + expm1_p as expm1_p, + floor_p as floor_p, + ge_p as ge_p, + gt_p as gt_p, + imag_p as imag_p, + infeed_p as infeed_p, + integer_pow_p as integer_pow_p, + iota_p as iota_p, + is_finite_p as is_finite_p, + le_p as le_p, + le_to_p as le_to_p, + log1p_p as log1p_p, + 
log_p as log_p, + logistic_p as logistic_p, + lt_p as lt_p, + lt_to_p as lt_to_p, + max_p as max_p, + min_p as min_p, + mul_p as mul_p, + ne_p as ne_p, + neg_p as neg_p, + nextafter_p as nextafter_p, + not_p as not_p, + or_p as or_p, + outfeed_p as outfeed_p, + pad_p as pad_p, + population_count_p as population_count_p, + pow_p as pow_p, + real_p as real_p, + reduce_and_p as reduce_and_p, + reduce_max_p as reduce_max_p, + reduce_min_p as reduce_min_p, + reduce_or_p as reduce_or_p, + reduce_p as reduce_p, + reduce_precision_p as reduce_precision_p, + reduce_prod_p as reduce_prod_p, + reduce_sum_p as reduce_sum_p, + reduce_xor_p as reduce_xor_p, + rem_p as rem_p, + reshape_p as reshape_p, + rev_p as rev_p, + rng_bit_generator_p as rng_bit_generator_p, + rng_uniform_p as rng_uniform_p, + round_p as round_p, + rsqrt_p as rsqrt_p, + select_n_p as select_n_p, + shift_left_p as shift_left_p, + shift_right_arithmetic_p as shift_right_arithmetic_p, + shift_right_logical_p as shift_right_logical_p, + sign_p as sign_p, + sin_p as sin_p, + sinh_p as sinh_p, + sort_p as sort_p, + sqrt_p as sqrt_p, + squeeze_p as squeeze_p, + sub_p as sub_p, + tan_p as tan_p, + tanh_p as tanh_p, + top_k_p as top_k_p, + transpose_p as transpose_p, + xor_p as xor_p, +) + +from jax._src.lax.special import ( + bessel_i0e_p as bessel_i0e_p, + bessel_i1e_p as bessel_i1e_p, + digamma_p as digamma_p, + erfc_p as erfc_p, + erf_inv_p as erf_inv_p, + erf_p as erf_p, + igammac_p as igammac_p, + igamma_grad_a_p as igamma_grad_a_p, + igamma_p as igamma_p, + lgamma_p as lgamma_p, + polygamma_p as polygamma_p, + random_gamma_grad_p as random_gamma_grad_p, + regularized_incomplete_beta_p as regularized_incomplete_beta_p, + zeta_p as zeta_p, +) + +from jax._src.lax.slicing import ( + dynamic_slice_p as dynamic_slice_p, + dynamic_update_slice_p as dynamic_update_slice_p, + gather_p as gather_p, + scatter_add_p as scatter_add_p, + scatter_max_p as scatter_max_p, + scatter_min_p as scatter_min_p, + scatter_mul_p as scatter_mul_p, + scatter_p as scatter_p, + slice_p as slice_p, +) + +from jax._src.lax.convolution import ( + conv_general_dilated_p as conv_general_dilated_p, +) + +from jax._src.lax.windowed_reductions import ( + reduce_window_max_p as reduce_window_max_p, + reduce_window_min_p as reduce_window_min_p, + reduce_window_p as reduce_window_p, + reduce_window_sum_p as reduce_window_sum_p, + select_and_gather_add_p as select_and_gather_add_p, + select_and_scatter_p as select_and_scatter_p, + select_and_scatter_add_p as select_and_scatter_add_p, +) + +from jax._src.lax.control_flow import ( + cond_p as cond_p, + cumlogsumexp_p as cumlogsumexp_p, + cummax_p as cummax_p, + cummin_p as cummin_p, + cumprod_p as cumprod_p, + cumsum_p as cumsum_p, + linear_solve_p as linear_solve_p, + scan_p as scan_p, + while_p as while_p, +) + +from jax._src.lax.fft import ( + fft_p as fft_p, +) + +from jax._src.lax.parallel import ( + all_gather_p as all_gather_p, + all_to_all_p as all_to_all_p, + axis_index_p as axis_index_p, + pmax_p as pmax_p, + pmin_p as pmin_p, + ppermute_p as ppermute_p, + psum_p as psum_p, +) + +from jax._src.lax.ann import ( + approx_top_k_p as approx_top_k_p +) + +from jax._src.lax.linalg import ( + cholesky_p as cholesky_p, + eig_p as eig_p, + eigh_p as eigh_p, + hessenberg_p as hessenberg_p, + lu_p as lu_p, + householder_product_p as householder_product_p, + qr_p as qr_p, + svd_p as svd_p, + triangular_solve_p as triangular_solve_p, + tridiagonal_p as tridiagonal_p, + tridiagonal_solve_p as tridiagonal_solve_p, + schur_p as 
schur_p, +) + +from jax._src.pjit import sharding_constraint_p as sharding_constraint_p +from jax._src.prng import threefry2x32_p as threefry2x32_p +from jax._src.random import random_gamma_p as random_gamma_p diff --git a/jax/extend/ffi.py b/jax/extend/ffi.py new file mode 100644 index 000000000000..5862d0ecebde --- /dev/null +++ b/jax/extend/ffi.py @@ -0,0 +1,23 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 + +from jax._src.extend.ffi import ( + ffi_call as ffi_call, + ffi_lowering as ffi_lowering, + include_dir as include_dir, + pycapsule as pycapsule, +) diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 01416c63cbe4..6663df3ac473 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -73,32 +73,18 @@ zeros_like_p as zeros_like_p, ) -from jax import config as _deprecated_config -from jax._src import source_info_util as _deprecated_source_info_util _deprecations = { - # Added Oct 13, 2023: + # Finalized Mar 18, 2024; remove after June 18, 2024 "config": ( "jax.interpreters.ad.config is deprecated. Use jax.config directly.", - _deprecated_config, + None, ), "source_info_util": ( "jax.interpreters.ad.source_info_util is deprecated. 
Use jax.extend.source_info_util.", - _deprecated_source_info_util, + None, ), } -import typing -if typing.TYPE_CHECKING: - config = _deprecated_config - source_info_util = _deprecated_source_info_util -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing -del _deprecated_config -del _deprecated_source_info_util - def backward_pass(jaxpr, reduce_axes, transform_stack, consts, primals_in, cotangents_in): if reduce_axes: diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 7a72a478b807..ba476c75e519 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -37,7 +37,6 @@ dense_bool_elements as dense_bool_elements, dense_bool_array as dense_bool_array, dense_int_array as dense_int_array, - dense_int_array_v6 as dense_int_array_v6, dense_int_elements as dense_int_elements, dtype_to_ir_type as dtype_to_ir_type, emit_python_callback as emit_python_callback, diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index a515f2293214..c5aa31a536f6 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -23,7 +23,6 @@ global_avals_to_results_handler as global_avals_to_results_handler, global_result_handlers as global_result_handlers, parallel_callable as parallel_callable, - shard_arg as shard_arg, shard_args as shard_args, xla_pmap_p as xla_pmap_p, ) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 3b042dcea818..bbd5b65d5d3e 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -17,77 +17,78 @@ canonicalize_dtype as canonicalize_dtype, canonicalize_dtype_handlers as canonicalize_dtype_handlers, pytype_aval_mappings as pytype_aval_mappings, - - # Deprecations - backend_specific_translations as _deprecated_backend_specific_translations, - register_translation as _deprecated_register_translation, - translations as _deprecated_translations, - xla_destructure as _deprecated_xla_destructure, - TranslationContext as _deprecated_TranslationContext, - TranslationRule as _deprecated_TranslationRule, ) from jax._src.dispatch import ( apply_primitive as apply_primitive, ) -from jax._src import xla_bridge as xb -from jax._src.lib import xla_client as xc # type: ignore +from jax._src import xla_bridge as _xb +from jax._src.lib import xla_client as _xc -xe = xc._xla -Backend = xe.Client +_xe = _xc._xla +Backend = _xe.Client # Deprecations _deprecations = { - # Added Aug 29, 2023: + # Added 2024-06-28 + "xb": ( + "jax.interpreters.xla.xb is deprecated. Use jax.lib.xla_bridge instead.", + _xb + ), + "xc": ( + "jax.interpreters.xla.xc is deprecated. Use jax.lib.xla_client instead.", + _xc, + ), + "xe": ( + "jax.interpreters.xla.xe is deprecated. Use jax.lib.xla_extension instead.", + _xe, + ), + # Finalized 2024-05-13; remove after 2024-08-13 "backend_specific_translations": ( "jax.interpreters.xla.backend_specific_translations is deprecated. " "Register custom primitives via jax.interpreters.mlir instead.", - _deprecated_backend_specific_translations, + None, ), "translations": ( "jax.interpreters.xla.translations is deprecated. " "Register custom primitives via jax.interpreters.mlir instead.", - _deprecated_translations, + None, ), "register_translation": ( "jax.interpreters.xla.register_translation is deprecated. 
" "Register custom primitives via jax.interpreters.mlir instead.", - _deprecated_register_translation, + None, ), "xla_destructure": ( "jax.interpreters.xla.xla_destructure is deprecated. " "Register custom primitives via jax.interpreters.mlir instead.", - _deprecated_xla_destructure, + None, ), "TranslationRule": ( "jax.interpreters.xla.TranslationRule is deprecated. " "Register custom primitives via jax.interpreters.mlir instead.", - _deprecated_TranslationRule, + None, ), "TranslationContext": ( "jax.interpreters.xla.TranslationContext is deprecated. " "Register custom primitives via jax.interpreters.mlir instead.", - _deprecated_TranslationContext, + None, ), "XlaOp": ( "jax.interpreters.xla.XlaOp is deprecated. " "Register custom primitives via jax.interpreters.mlir instead.", - xc.XlaOp, + None, ), } import typing +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr if typing.TYPE_CHECKING: - backend_specific_translations = _deprecated_backend_specific_translations - translations = _deprecated_translations - register_translation = _deprecated_register_translation - xla_destructure = _deprecated_xla_destructure - TranslationRule = _deprecated_TranslationRule - TranslationContext = _deprecated_TranslationContext - XlaOp = xc.XlaOp + xb = _xb + xc = _xc + xe = _xe else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr +del _deprecation_getattr del typing diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 1629ff42eb23..040786c22735 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -152,6 +152,7 @@ population_count_p as population_count_p, pow as pow, pow_p as pow_p, + ragged_dot as ragged_dot, real as real, real_p as real_p, reciprocal as reciprocal, @@ -210,7 +211,6 @@ tan_p as tan_p, tanh as tanh, tanh_p as tanh_p, - tie_in as _deprecated_tie_in, top_k as top_k, top_k_p as top_k_p, transpose as transpose, @@ -342,6 +342,7 @@ all_to_all_p as all_to_all_p, axis_index as axis_index, axis_index_p as axis_index_p, + pbroadcast as pbroadcast, pmax as pmax, pmax_p as pmax_p, pmean as pmean, @@ -375,18 +376,13 @@ _deprecations = { - # Added January 18 2023 + # Finalized 2024-05-13; remove after 2024-08-13 "tie_in": ( "jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0. " - "Replace z = tie_in(x, y) with z = y.", _deprecated_tie_in, + "Replace z = tie_in(x, y) with z = y.", None, ), } -import typing as _typing -if _typing.TYPE_CHECKING: - tie_in = _deprecated_tie_in -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del _typing +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 0f008b30ecd1..9eebebd7b8d6 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -41,26 +41,24 @@ soft_sign as soft_sign, softmax as softmax, softplus as softplus, + sparse_plus as sparse_plus, + sparse_sigmoid as sparse_sigmoid, silu as silu, swish as swish, squareplus as squareplus, + mish as mish, ) # Deprecations _deprecations = { - # Added Nov 8, 2023: + # Finalized 2024-05-13; remove after 2024-08-13 "normalize": ( "jax.nn.normalize is deprecated. 
Use jax.nn.standardize instead.", - standardize, + None, ), } -import typing -if typing.TYPE_CHECKING: - normalize = standardize -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 72b2ed7f9cc0..1b9a990f3a0d 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -233,6 +233,7 @@ tensordot as tensordot, tile as tile, trace as trace, + trapezoid as trapezoid, transpose as transpose, tri as tri, tril as tril, @@ -252,6 +253,7 @@ unpackbits as unpackbits, unravel_index as unravel_index, unsignedinteger as unsignedinteger, + unstack as unstack, unwrap as unwrap, vander as vander, vdot as vdot, @@ -294,6 +296,7 @@ count_nonzero as count_nonzero, cumsum as cumsum, cumprod as cumprod, + cumulative_sum as cumulative_sum, max as max, mean as mean, median as median, @@ -447,7 +450,15 @@ register_jax_array_methods() del register_jax_array_methods -try: - from numpy import issubsctype as _deprecated_issubsctype -except ImportError: - _deprecated_issubsctype = None + +_deprecations = { + # Deprecated 18 Sept 2023 and removed 06 Feb 2024 + "trapz": ( + "jnp.trapz is deprecated; use jnp.trapezoid instead.", + None + ), +} + +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index a618f457016d..6e9c7af5eabb 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -2,14 +2,19 @@ from __future__ import annotations import builtins -from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, TypeVar, Union, overload +from collections.abc import Callable, Sequence +from typing import Any, Literal, NamedTuple, TypeVar, Union, overload from jax._src import core as _core from jax._src import dtypes as _dtypes from jax._src.lax.lax import PrecisionLike from jax._src.lax.slicing import GatherScatterMode +from jax._src.lib import Device from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass -from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape +from jax._src.typing import ( + Array, ArrayLike, DType, DTypeLike, DeprecatedArg, + DimSize, DuckTypedArray, Shape, StaticScalar, +) from jax.numpy import fft as fft, linalg as linalg from jax.sharding import Sharding as _Sharding import numpy as _np @@ -18,8 +23,7 @@ _T = TypeVar('_T') _Axis = Union[None, int, Sequence[int]] -# TODO(jakevdp): use xla_client.Device here -_Device = Any +_Device = Device ComplexWarning: type @@ -30,21 +34,21 @@ def acos(x: ArrayLike, /) -> Array: ... def acosh(x: ArrayLike, /) -> Array: ... def add(x: ArrayLike, y: ArrayLike, /) -> Array: ... def amax(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def amin(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) 
-> Array: ... + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def all(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., *, where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., *, where: ArrayLike | None = ...) -> Array: ... def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ..., atol: ArrayLike = ..., equal_nan: builtins.bool = ...) -> Array: ... alltrue = all def angle(z: ArrayLike, deg: builtins.bool = ...) -> Array: ... def any(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., *, where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., *, where: ArrayLike | None = ...) -> Array: ... def append( - arr: ArrayLike, values: ArrayLike, axis: Optional[int] = ... + arr: ArrayLike, values: ArrayLike, axis: int | None = ... ) -> Array: ... def apply_along_axis(func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs) -> Array: ... @@ -53,9 +57,10 @@ def apply_over_axes( ) -> Array: ... def arange( start: DimSize, - stop: Optional[DimSize] = ..., - step: Optional[DimSize] = ..., - dtype: Optional[DTypeLike] = ..., + stop: DimSize | None = ..., + step: DimSize | None = ..., + dtype: DTypeLike | None = ..., *, + device: _Device | _Sharding | None = ..., ) -> Array: ... def arccos(x: ArrayLike, /) -> Array: ... def arccosh(x: ArrayLike, /) -> Array: ... @@ -66,31 +71,31 @@ def arctan2(x: ArrayLike, y: ArrayLike, /) -> Array: ... def arctanh(x: ArrayLike, /) -> Array: ... def argmax( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - keepdims: Optional[builtins.bool] = ..., + keepdims: builtins.bool | None = ..., ) -> Array: ... def argmin( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - keepdims: Optional[builtins.bool] = ..., + keepdims: builtins.bool | None = ..., ) -> Array: ... def argpartition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ... def argsort( a: ArrayLike, - axis: Optional[int] = ..., - kind: str | None = ..., - order: None = ..., + axis: int | None = ..., *, stable: builtins.bool = ..., descending: builtins.bool = ..., + kind: str | None = ..., + order: None = ..., ) -> Array: ... def argwhere( a: ArrayLike, *, - size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ..., + size: int | None = ..., + fill_value: ArrayLike | None = ..., ) -> Array: ... around = round def array(object: Any, dtype: DTypeLike | None = ..., copy: builtins.bool = True, @@ -102,17 +107,17 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: ... array_repr = _np.array_repr def array_split( ary: ArrayLike, - indices_or_sections: Union[int, Sequence[int], ArrayLike], + indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = ..., ) -> list[Array]: ... array_str = _np.array_str def asarray( - a: Any, dtype: Optional[DTypeLike] = ..., order: Optional[str] = ..., - *, copy: Optional[builtins.bool] = ... + a: Any, dtype: DTypeLike | None = ..., order: str | None = ..., + *, copy: builtins.bool | None = ... ) -> Array: ... def asin(x: ArrayLike, /) -> Array: ... def asinh(x: ArrayLike, /) -> Array: ... -def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ...) -> Array: ... +def astype(a: ArrayLike, dtype: DTypeLike | None, /, *, copy: builtins.bool = ..., device: _Device | _Sharding | None = ...) -> Array: ... def atan(x: ArrayLike, /) -> Array: ... 
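In the jax/numpy/__init__.pyi hunks above, argsort's kind/order parameters move behind the keyword-only marker and astype gains a device argument; the resulting call style looks like this (values and device choice are arbitrary):

import jax
import jax.numpy as jnp

x = jnp.array([3, 1, 2])

# stable/descending (and now kind/order) are keyword-only in the updated stub.
order = jnp.argsort(x, stable=True, descending=True)

# astype also accepts a device (or Sharding) keyword per the new stub.
y = jnp.astype(x, jnp.float32, device=jax.devices()[0])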
def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ... def atanh(x: ArrayLike, /) -> Array: ... @@ -138,19 +143,19 @@ def atleast_3d(x: ArrayLike, /) -> Array: ... def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... @overload -def average(a: ArrayLike, axis: _Axis = ..., weights: Optional[ArrayLike] = ..., +def average(a: ArrayLike, axis: _Axis = ..., weights: ArrayLike | None = ..., returned: Literal[False] = False, keepdims: builtins.bool = False) -> Array: ... @overload -def average(a: ArrayLike, axis: _Axis = ..., weights: Optional[ArrayLike] = ..., *, +def average(a: ArrayLike, axis: _Axis = ..., weights: ArrayLike | None = ..., *, returned: Literal[True], keepdims: builtins.bool = False) -> tuple[Array, Array]: ... @overload -def average(a: ArrayLike, axis: _Axis = ..., weights: Optional[ArrayLike] = ..., - returned: builtins.bool = False, keepdims: builtins.bool = False) -> Union[Array, tuple[Array, Array]]: ... +def average(a: ArrayLike, axis: _Axis = ..., weights: ArrayLike | None = ..., + returned: builtins.bool = False, keepdims: builtins.bool = False) -> Array | tuple[Array, Array]: ... def bartlett(M: int) -> Array: ... bfloat16: Any -def bincount(x: ArrayLike, weights: Optional[ArrayLike] = ..., - minlength: int = ..., *, length: Optional[int] = ...) -> Array: ... +def bincount(x: ArrayLike, weights: ArrayLike | None = ..., + minlength: int = ..., *, length: int | None = ...) -> Array: ... def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_count(x: ArrayLike, /) -> Array: ... def bitwise_invert(x: ArrayLike, /) -> Array: ... @@ -160,7 +165,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... def blackman(M: int) -> Array: ... -def block(arrays: Union[ArrayLike, Sequence[ArrayLike], Sequence[Sequence[ArrayLike]]]) -> Array: ... +def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ... bool: Any bool_: Any def broadcast_arrays(*args: ArrayLike) -> list[Array]: ... @@ -169,8 +174,8 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]: ... def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... @overload -def broadcast_shapes(*shapes: Sequence[Union[int, _core.Tracer]] - ) -> tuple[Union[int, _core.Tracer], ...]: ... +def broadcast_shapes(*shapes: Sequence[int | _core.Tracer] + ) -> tuple[int | _core.Tracer, ...]: ... def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: ... c_: _CClass @@ -181,49 +186,56 @@ def ceil(x: ArrayLike, /) -> Array: ... character = _np.character def choose(a: ArrayLike, choices: Sequence[ArrayLike], out: None = ..., mode: str = ...) -> Array: ... -def clip(a: ArrayLike, a_min: Optional[ArrayLike] = ..., - a_max: Optional[ArrayLike] = ..., out: None = ...) -> Array: ... +def clip( + x: ArrayLike | None = ..., + /, + min: ArrayLike | None = ..., + max: ArrayLike | None = ..., + a: ArrayLike | DeprecatedArg | None = ..., + a_min: ArrayLike | DeprecatedArg | None = ..., + a_max: ArrayLike | DeprecatedArg | None = ... +) -> Array: ... def column_stack( - tup: Union[_np.ndarray, Array, Sequence[ArrayLike]] + tup: _np.ndarray | Array | Sequence[ArrayLike] ) -> Array: ... 
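The clip stub above switches to the array-API spelling, keeping a/a_min/a_max only as DeprecatedArg aliases; in practice:

import jax.numpy as jnp

x = jnp.linspace(-2.0, 2.0, 9)

clipped = jnp.clip(x, min=-1.0, max=1.0)    # preferred spelling per the new stub
# jnp.clip(x, a_min=-1.0, a_max=1.0)        # still typed, but deprecated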
complex128: Any complex64: Any complex_: Any complexfloating = _np.complexfloating -def compress(condition: ArrayLike, a: ArrayLike, axis: Optional[int] = ..., +def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = ..., out: None = ...) -> Array: ... def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: ... def concatenate( - arrays: Union[_np.ndarray, Array, Sequence[ArrayLike]], - axis: Optional[int] = ..., - dtype: Optional[DTypeLike] = ..., + arrays: _np.ndarray | Array | Sequence[ArrayLike], + axis: int | None = ..., + dtype: DTypeLike | None = ..., ) -> Array: ... def conjugate(x: ArrayLike, /) -> Array: ... conj = conjugate def convolve(a: ArrayLike, v: ArrayLike, mode: str = ..., *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... -def copy(a: ArrayLike, order: Optional[str] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... +def copy(a: ArrayLike, order: str | None = ...) -> Array: ... def copysign(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = ..., rowvar: builtins.bool = ...) -> Array: ... +def corrcoef(x: ArrayLike, y: ArrayLike | None = ..., rowvar: builtins.bool = ...) -> Array: ... def correlate(a: ArrayLike, v: ArrayLike, mode: str = ..., *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... def cos(x: ArrayLike, /) -> Array: ... def cosh(x: ArrayLike, /) -> Array: ... def count_nonzero(a: ArrayLike, axis: _Axis = ..., keepdims: builtins.bool = ...) -> Array: ... -def cov(m: ArrayLike, y: Optional[ArrayLike] = ..., rowvar: builtins.bool = ..., - bias: builtins.bool = ..., ddof: Optional[int] = ..., - fweights: Optional[ArrayLike] = ..., - aweights: Optional[ArrayLike] = ...) -> Array: ... +def cov(m: ArrayLike, y: ArrayLike | None = ..., rowvar: builtins.bool = ..., + bias: builtins.bool = ..., ddof: int | None = ..., + fweights: ArrayLike | None = ..., + aweights: ArrayLike | None = ...) -> Array: ... def cross( a: ArrayLike, b: ArrayLike, axisa: int = -1, axisb: int = -1, axisc: int = -1, - axis: Optional[int] = ..., + axis: int | None = ..., ) -> Array: ... csingle: Any def cumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., @@ -231,13 +243,16 @@ def cumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., cumproduct = cumprod def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... +def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ..., + dtype: DTypeLike | None = ..., + include_initial: bool = ...) -> Array: ... def deg2rad(x: ArrayLike, /) -> Array: ... degrees = rad2deg def delete( arr: ArrayLike, - obj: Union[ArrayLike, slice], - axis: Optional[int] = ..., + obj: ArrayLike | slice, + axis: int | None = ..., *, assume_unique_indices: builtins.bool = ..., ) -> Array: ... @@ -249,33 +264,33 @@ def diagonal( a: ArrayLike, offset: ArrayLike = ..., axis1: int = ..., axis2: int = ... ): ... def diff(a: ArrayLike, n: int = ..., axis: int = ..., - prepend: Optional[ArrayLike] = ..., - append: Optional[ArrayLike] = ...) -> Array: ... + prepend: ArrayLike | None = ..., + append: ArrayLike | None = ...) -> Array: ... def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ...) -> Array: ... divide = true_divide def divmod(x: ArrayLike, y: ArrayLike, /) -> tuple[Array, Array]: ... 
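cumulative_sum above follows the array API, joining the trapezoid and unstack exports added earlier in this patch; a quick sketch of the new entry points (arrays are arbitrary):

import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)

s = jnp.cumulative_sum(x, axis=1)                         # array-API cumulative sum
s0 = jnp.cumulative_sum(x, axis=1, include_initial=True)  # prepends the zero "initial" slice
rows = jnp.unstack(x, axis=0)                             # tuple of 1-D arrays
area = jnp.trapezoid(jnp.sin(x[0]), x[0])                 # replacement for the removed jnp.trapz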
def dot( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... double: Any def dsplit( - ary: ArrayLike, indices_or_sections: Union[int, ArrayLike] + ary: ArrayLike, indices_or_sections: int | ArrayLike ) -> list[Array]: ... -def dstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]], - dtype: Optional[DTypeLike] = ...) -> Array: ... +def dstack(tup: _np.ndarray | Array | Sequence[ArrayLike], + dtype: DTypeLike | None = ...) -> Array: ... dtype = _np.dtype e: float -def ediff1d(ary: ArrayLike, to_end: Optional[ArrayLike] = ..., - to_begin: Optional[ArrayLike] = ...) -> Array: ... +def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = ..., + to_begin: ArrayLike | None = ...) -> Array: ... @overload def einsum( subscript: str, /, *operands: ArrayLike, out: None = ..., - optimize: str = "optimal", + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ..., + preferred_element_type: DTypeLike | None = ..., _use_xeinsum: builtins.bool = False, _dot_general: Callable[..., Array] = ..., ) -> Array: ... @@ -284,11 +299,11 @@ def einsum( def einsum( arr: ArrayLike, axes: Sequence[Any], /, - *operands: Union[ArrayLike, Sequence[Any]], + *operands: ArrayLike | Sequence[Any], out: None = ..., - optimize: str = "optimal", + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ..., + preferred_element_type: DTypeLike | None = ..., _use_xeinsum: builtins.bool = False, _dot_general: Callable[..., Array] = ..., ) -> Array: ... @@ -297,41 +312,62 @@ def einsum( subscripts, /, *operands, out: None = ..., - optimize: str = ..., + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ..., + preferred_element_type: DTypeLike | None = ..., _use_xeinsum: builtins.bool = ..., _dot_general: Callable[..., Array] = ..., ) -> Array: ... -def einsum_path(subscripts, *operands, optimize = ...): ... -def empty(shape: Any, dtype: Optional[DTypeLike] = ..., - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... -def empty_like(prototype: Union[ArrayLike, DuckTypedArray], - dtype: Optional[DTypeLike] = ..., +@overload +def einsum_path( + subscripts: str, /, + *operands: ArrayLike, + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., +) -> tuple[list[tuple[int, ...]], Any]: ... +@overload +def einsum_path( + arr: ArrayLike, + axes: Sequence[Any], /, + *operands: ArrayLike | Sequence[Any], + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., +) -> tuple[list[tuple[int, ...]], Any]: ... +@overload +def einsum_path( + subscripts, /, + *operands: ArrayLike, + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., +) -> tuple[list[tuple[int, ...]], Any]: ... + +def empty(shape: Any, dtype: DTypeLike | None = ..., + device: _Device | _Sharding | None = ...) -> Array: ... +def empty_like(prototype: ArrayLike | DuckTypedArray, + dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... + device: _Device | _Sharding | None = ...) -> Array: ... def equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... euler_gamma: float def exp(x: ArrayLike, /) -> Array: ... def exp2(x: ArrayLike, /) -> Array: ... 
-def expand_dims(a: ArrayLike, axis: Union[int, Sequence[int]]) -> Array: ... +def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: ... def expm1(x: ArrayLike, /) -> Array: ... -def extract(condition: ArrayLike, arr: ArrayLike) -> Array: ... -def eye(N: DimSize, M: Optional[DimSize] = ..., k: int = ..., - dtype: Optional[DTypeLike] = ...) -> Array: ... +def extract(condition: ArrayLike, arr: ArrayLike, *, + size: int | None = None, fill_value: ArrayLike = 0) -> Array: ... +def eye(N: DimSize, M: DimSize | None = ..., k: int | ArrayLike = ..., + dtype: DTypeLike | None = ..., *, + device: _Device | _Sharding | None = ...) -> Array: ... def fabs(x: ArrayLike, /) -> Array: ... finfo = _dtypes.finfo def fix(x: ArrayLike, out: None = ...) -> Array: ... def flatnonzero( a: ArrayLike, *, - size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike]] = ..., + size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike] = ..., ) -> Array: ... flexible = _np.flexible def flip( - m: ArrayLike, axis: Optional[Union[int, Sequence[int]]] = ... + m: ArrayLike, axis: int | Sequence[int] | None = ... ) -> Array: ... def fliplr(m: ArrayLike) -> Array: ... @@ -353,8 +389,9 @@ def fmax(x: ArrayLike, y: ArrayLike, /) -> Array: ... def fmin(x: ArrayLike, y: ArrayLike, /) -> Array: ... def fmod(x: ArrayLike, y: ArrayLike, /) -> Array: ... def frexp(x: ArrayLike, /) -> tuple[Array, Array]: ... -def from_dlpack(x: Any) -> Array: ... -def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = ..., +def from_dlpack(x: Any, /, *, device: _Device | _Sharding | None = None, + copy: builtins.bool | None = None) -> Array: ... +def frombuffer(buffer: bytes | Any, dtype: DTypeLike = ..., count: int = ..., offset: int = ...) -> Array: ... def fromfile(*args, **kwargs): ... def fromfunction(function: Callable[..., Array], shape: Any, @@ -364,12 +401,12 @@ def fromstring( string: str, dtype: DTypeLike = ..., count: int = ..., *, sep: str ) -> Array: ... def full(shape: Any, fill_value: ArrayLike, - dtype: Optional[DTypeLike] = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... -def full_like(a: Union[ArrayLike, DuckTypedArray], - fill_value: ArrayLike, dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., *, + device: _Device | _Sharding | None = ...) -> Array: ... +def full_like(a: ArrayLike | DuckTypedArray, + fill_value: ArrayLike, dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... + device: _Device | _Sharding | None = ...) -> Array: ... def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: ... generic = _np.generic def geomspace( @@ -377,48 +414,48 @@ def geomspace( stop: ArrayLike, num: int = ..., endpoint: builtins.bool = ..., - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., axis: int = ..., ) -> Array: ... get_printoptions = _np.get_printoptions def gradient(f: ArrayLike, *varargs: ArrayLike, - axis: Optional[Union[int, Sequence[int]]] = ..., - edge_order: Optional[int] = ...) -> Union[Array, list[Array]]: ... + axis: int | Sequence[int] | None = ..., + edge_order: int | None = ...) -> Array | list[Array]: ... def greater(x: ArrayLike, y: ArrayLike, /) -> Array: ... def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... def hamming(M: int) -> Array: ... def hanning(M: int) -> Array: ... def heaviside(x: ArrayLike, y: ArrayLike, /) -> Array: ... 
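Several creation-style functions above (empty, empty_like, eye, full, full_like, from_dlpack) gain a device argument accepting either a Device or a Sharding; for instance (the device choice is illustrative):

import jax
import jax.numpy as jnp

dev = jax.devices()[0]
a = jnp.full((4, 4), 3.0, device=dev)   # committed to a specific device
b = jnp.eye(4, device=dev)
print(a.devices(), b.devices())         # both report the chosen device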
def histogram(a: ArrayLike, bins: ArrayLike = ..., - range: Optional[Sequence[ArrayLike]] = ..., - weights: Optional[ArrayLike] = ..., - density: Optional[builtins.bool] = ...) -> tuple[Array, Array]: ... + range: Sequence[ArrayLike] | None = ..., + weights: ArrayLike | None = ..., + density: builtins.bool | None = ...) -> tuple[Array, Array]: ... def histogram2d( x: ArrayLike, y: ArrayLike, - bins: Union[ArrayLike, Sequence[ArrayLike]] = ..., - range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]] = ..., - weights: Optional[ArrayLike] = ..., - density: Optional[builtins.bool] = ..., + bins: ArrayLike | Sequence[ArrayLike] = ..., + range: Sequence[None | Array | Sequence[ArrayLike]] | None = ..., + weights: ArrayLike | None = ..., + density: builtins.bool | None = ..., ) -> tuple[Array, Array, Array]: ... def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = ..., - range: Union[None, Array, Sequence[ArrayLike]] = ..., - weights: Optional[ArrayLike] = ...) -> Array: ... + range: None | Array | Sequence[ArrayLike] = ..., + weights: ArrayLike | None = ...) -> Array: ... def histogramdd( sample: ArrayLike, - bins: Union[ArrayLike, Sequence[ArrayLike]] = ..., - range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]] = ..., - weights: Optional[ArrayLike] = ..., - density: Optional[builtins.bool] = ..., + bins: ArrayLike | Sequence[ArrayLike] = ..., + range: Sequence[None | Array | Sequence[ArrayLike]] | None = ..., + weights: ArrayLike | None = ..., + density: builtins.bool | None = ..., ) -> tuple[Array, list[Array]]: ... def hsplit( - ary: ArrayLike, indices_or_sections: Union[int, ArrayLike] + ary: ArrayLike, indices_or_sections: int | ArrayLike ) -> list[Array]: ... -def hstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]], - dtype: Optional[DTypeLike] = ...) -> Array: ... +def hstack(tup: _np.ndarray | Array | Sequence[ArrayLike], + dtype: DTypeLike | None = ...) -> Array: ... def hypot(x: ArrayLike, y: ArrayLike, /) -> Array: ... def i0(x: ArrayLike) -> Array: ... -def identity(n: DimSize, dtype: Optional[DTypeLike] = ...) -> Array: ... +def identity(n: DimSize, dtype: DTypeLike | None = ...) -> Array: ... iinfo = _dtypes.iinfo def imag(x: ArrayLike, /) -> Array: ... index_exp = _np.index_exp @@ -431,15 +468,15 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, *, sparse: Literal[True]) -> tuple[Array, ...]: ... @overload def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, - sparse: builtins.bool = False) -> Union[Array, tuple[Array, ...]]: ... + sparse: builtins.bool = False) -> Array | tuple[Array, ...]: ... inexact = _np.inexact inf: float def inner( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... -def insert(arr: ArrayLike, obj: Union[ArrayLike, slice], values: ArrayLike, - axis: Optional[int] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... +def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, + axis: int | None = ...) -> Array: ... int16: Any int32: Any int4: Any @@ -448,17 +485,17 @@ int8: Any int_: Any integer = _np.integer def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, - left: Union[ArrayLike, str, None] = ..., - right: Union[ArrayLike, str, None] = ..., - period: Optional[ArrayLike] = ...) -> Array: ... + left: ArrayLike | str | None = ..., + right: ArrayLike | str | None = ..., + period: ArrayLike | None = ...) -> Array: ... 
def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ..., - return_indices: builtins.bool = ...) -> Union[Array, tuple[Array, Array, Array]]: ... + return_indices: builtins.bool = ...) -> Array | tuple[Array, Array, Array]: ... def invert(x: ArrayLike, /) -> Array: ... def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ..., atol: ArrayLike = ..., equal_nan: builtins.bool = ...) -> Array: ... def iscomplex(m: ArrayLike) -> Array: ... def iscomplexobj(x: Any) -> builtins.bool: ... -def isdtype(dtype: DTypeLike, kind: Union[DType, str, tuple[Union[DType, str], ...]]) -> builtins.bool: ... +def isdtype(dtype: DTypeLike, kind: DType | str | tuple[DType | str, ...]) -> builtins.bool: ... def isfinite(x: ArrayLike, /) -> Array: ... def isin(element: ArrayLike, test_elements: ArrayLike, assume_unique: builtins.bool = ..., invert: builtins.bool = ...) -> Array: ... @@ -484,23 +521,23 @@ def lexsort(keys: Sequence[ArrayLike], axis: int = ...) -> Array: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: builtins.bool = True, retstep: Literal[False] = False, - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., axis: int = 0) -> Array: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int, endpoint: builtins.bool, retstep: Literal[True], - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., axis: int = 0) -> tuple[Array, Array]: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: builtins.bool = True, *, retstep: Literal[True], - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., axis: int = 0) -> tuple[Array, Array]: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: builtins.bool = True, retstep: builtins.bool = False, - dtype: Optional[DTypeLike] = ..., - axis: int = 0) -> Union[Array, tuple[Array, Array]]: ... + dtype: DTypeLike | None = ..., + axis: int = 0) -> Array | tuple[Array, Array]: ... def load(*args: Any, **kwargs: Any) -> Array: ... def log(x: ArrayLike, /) -> Array: ... @@ -515,20 +552,20 @@ def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... def logspace(start: ArrayLike, stop: ArrayLike, num: int = ..., endpoint: builtins.bool = ..., base: ArrayLike = ..., - dtype: Optional[DTypeLike] = ..., axis: int = ...) -> Array: ... + dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ... def mask_indices( n: int, mask_func: Callable, k: int = ... ) -> tuple[Array, ...]: ... def matmul( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... def matrix_transpose(x: ArrayLike, /) -> Array: ... max = amax def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: ... def mean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., *, - where: Optional[ArrayLike] = ...) -> Array: ... -def median(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ..., + where: ArrayLike | None = ...) -> Array: ... +def median(a: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., keepdims: builtins.bool = ...) -> Array: ... def meshgrid(*xi: ArrayLike, copy: builtins.bool = ..., sparse: builtins.bool = ..., @@ -538,121 +575,121 @@ min = amin def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: ... 
def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ... def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ... -def moveaxis(a: ArrayLike, source: Union[int, Sequence[int]], - destination: Union[int, Sequence[int]]) -> Array: ... +def moveaxis(a: ArrayLike, source: int | Sequence[int], + destination: int | Sequence[int]) -> Array: ... def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: ... nan: float def nan_to_num(x: ArrayLike, copy: builtins.bool = ..., nan: ArrayLike = ..., - posinf: Optional[ArrayLike] = ..., - neginf: Optional[ArrayLike] = ...) -> Array: ... + posinf: ArrayLike | None = ..., + neginf: ArrayLike | None = ...) -> Array: ... def nanargmax( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - keepdims: Optional[builtins.bool] = ..., + keepdims: builtins.bool | None = ..., ) -> Array: ... def nanargmin( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - keepdims: Optional[builtins.bool] = ..., + keepdims: builtins.bool | None = ..., ) -> Array: ... def nancumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def nancumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def nanmax(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def nanmean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., - where: Optional[ArrayLike] = ...) -> Array: ... -def nanmedian(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ..., + where: ArrayLike | None = ...) -> Array: ... +def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., keepdims: builtins.bool = ...) -> Array: ... def nanmin(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def nanpercentile(a: ArrayLike, q: ArrayLike, - axis: Optional[Union[int, tuple[int, ...]]] = ..., + axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., - keepdims: builtins.bool = ..., interpolation: None = ...) -> Array: ... + keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... def nanprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... -def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ..., + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... +def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., - keepdims: builtins.bool = ..., interpolation: None = ...) -> Array: ... + keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... 
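In the nanpercentile/nanquantile stubs above (and percentile/quantile further down), interpolation becomes a keyword-only DeprecatedArg; the supported spelling goes through method, e.g.:

import jax.numpy as jnp

x = jnp.arange(10.0)

p = jnp.nanpercentile(x, 90.0, method="linear")
q = jnp.nanquantile(x, 0.5, method="nearest")
# jnp.nanquantile(x, 0.5, interpolation="nearest")   # deprecated alias of method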
def nanstd(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + where: ArrayLike | None = ...) -> Array: ... def nansum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., - initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = 0, keepdims: builtins.bool = False, - where: Optional[ArrayLike] = ...) -> Array: ... + where: ArrayLike | None = ...) -> Array: ... ndarray = Array ndim = _np.ndim def negative(x: ArrayLike, /) -> Array: ... newaxis = ... def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def nonzero(a: ArrayLike, *, size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... +def nonzero(a: ArrayLike, *, size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... ) -> tuple[Array, ...]: ... def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... number = _np.number object_ = _np.object_ ogrid: _Ogrid -def ones(shape: Any, dtype: Optional[DTypeLike] = ..., - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... -def ones_like(a: Union[ArrayLike, DuckTypedArray], - dtype: Optional[DTypeLike] = ..., +def ones(shape: Any, dtype: DTypeLike | None = ..., + device: _Device | _Sharding | None = ...) -> Array: ... +def ones_like(a: ArrayLike | DuckTypedArray, + dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... + device: _Device | _Sharding | None = ...) -> Array: ... def outer(a: ArrayLike, b: Array, out: None = ...) -> Array: ... def packbits( - a: ArrayLike, axis: Optional[int] = ..., bitorder: str = ... + a: ArrayLike, axis: int | None = ..., bitorder: str = ... ) -> Array: ... PadValueLike = Union[_T, Sequence[_T], Sequence[Sequence[_T]]] def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | _np.ndarray], - mode: Union[str, Callable[..., Any]] = ..., **kwargs) -> Array: ... + mode: str | Callable[..., Any] = ..., **kwargs) -> Array: ... def partition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ... def percentile(a: ArrayLike, q: ArrayLike, - axis: Optional[Union[int, tuple[int, ...]]] = ..., + axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., - keepdims: builtins.bool = ..., interpolation: None = ...) -> Array: ... + keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... def permute_dims(x: ArrayLike, /, axes: tuple[int, ...]) -> Array: ... pi: float -def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]], - funclist: Sequence[Union[ArrayLike, Callable[..., Array]]], +def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike], + funclist: Sequence[ArrayLike | Callable[..., Array]], *args, **kw) -> Array: ... def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, inplace: builtins.bool = ...) -> Array: ... -def poly(seq_of_zeros: Array) -> Array: ... -def polyadd(a1: Array, a2: Array) -> Array: ... -def polyder(p: Array, m: int = ...) -> Array: ... +def poly(seq_of_zeros: ArrayLike) -> Array: ... +def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: ... +def polyder(p: ArrayLike, m: int = ...) -> Array: ... 
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: builtins.bool = ...) -> tuple[Array, Array]: ... -def polyfit(x: Array, y: Array, deg: int, rcond: Optional[float] = ..., - full: builtins.bool = ..., w: Optional[Array] = ..., cov: builtins.bool = ... - ) -> Union[Array, tuple[Array, ...]]: ... -def polyint(p: Array, m: int = ..., k: Optional[int] = ...) -> Array: ... +def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = ..., + full: builtins.bool = ..., w: ArrayLike | None = ..., cov: builtins.bool = ... + ) -> Array | tuple[Array, ...]: ... +def polyint(p: ArrayLike, m: int = ..., k: int | ArrayLike | None = ...) -> Array: ... def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: builtins.bool = ...) -> Array: ... -def polysub(a1: Array, a2: Array) -> Array: ... -def polyval(p: Array, x: Array, *, unroll: int = ...) -> Array: ... +def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: ... +def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = ...) -> Array: ... def positive(x: ArrayLike, /) -> Array: ... def pow(x: ArrayLike, y: ArrayLike, /) -> Array: ... def power(x: ArrayLike, y: ArrayLike, /) -> Array: ... printoptions = _np.printoptions def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., - initial: Optional[ArrayLike] = ..., where: Optional[ArrayLike] = ..., + initial: ArrayLike | None = ..., where: ArrayLike | None = ..., promote_integers: builtins.bool = ...) -> Array: ... product = prod promote_types = _np.promote_types @@ -660,9 +697,9 @@ def ptp(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ...) -> Array: ... def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ... -def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ..., +def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., - keepdims: builtins.bool = ..., interpolation: None = ...) -> Array: ... + keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... r_: _RClass def rad2deg(x: ArrayLike, /) -> Array: ... radians = deg2rad @@ -673,18 +710,19 @@ def real(x: ArrayLike, /) -> Array: ... def reciprocal(x: ArrayLike, /) -> Array: ... register_jax_array_methods: Any def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = ..., *, - total_repeat_length: Optional[int] = ...) -> Array: ... +def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *, + total_repeat_length: int | None = ...) -> Array: ... def reshape( - a: ArrayLike, newshape: Union[DimSize, Shape], order: str = ... + a: ArrayLike, shape: DimSize | Shape = ..., + newshape: DimSize | Shape | None = ..., order: str = ... ) -> Array: ... def resize(a: ArrayLike, new_shape: Shape) -> Array: ... def result_type(*args: Any) -> DType: ... def right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def rint(x: ArrayLike, /) -> Array: ... -def roll(a: ArrayLike, shift: Union[ArrayLike, Sequence[int]], - axis: Optional[Union[int, Sequence[int]]] = ...) -> Array: ... +def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], + axis: int | Sequence[int] | None = ...) -> Array: ... def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: ... def roots(p: ArrayLike, *, strip_zeros: builtins.bool = ...) -> Array: ... 
def rot90(m: ArrayLike, k: int = ..., axes: tuple[int, int] = ...) -> Array: ... @@ -694,7 +732,7 @@ s_ = _np.s_ save = _np.save savez = _np.savez def searchsorted(a: ArrayLike, v: ArrayLike, side: str = ..., - sorter: None = ..., *, method: str = ...) -> Array: ... + sorter: ArrayLike | None = ..., *, method: str = ...) -> Array: ... def select( condlist: Sequence[ArrayLike], choicelist: Sequence[ArrayLike], @@ -706,8 +744,8 @@ def setdiff1d( ar2: ArrayLike, assume_unique: builtins.bool = ..., *, - size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ..., + size: int | None = ..., + fill_value: ArrayLike | None = ..., ) -> Array: ... def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ... shape = _np.shape @@ -722,34 +760,34 @@ size = _np.size sometrue = any def sort( a: ArrayLike, - axis: Optional[int] = ..., - kind: str | None = ..., - order: None = ..., + axis: int | None = ..., *, stable: builtins.bool = ..., descending: builtins.bool = ..., + kind: str | None = ..., + order: None = ..., ) -> Array: ... def sort_complex(a: ArrayLike) -> Array: ... def split( ary: ArrayLike, - indices_or_sections: Union[int, Sequence[int], ArrayLike], + indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = ..., ) -> list[Array]: ... def sqrt(x: ArrayLike, /) -> Array: ... def square(x: ArrayLike, /) -> Array: ... def squeeze( - a: ArrayLike, axis: Optional[Union[int, Sequence[int]]] = ... + a: ArrayLike, axis: int | Sequence[int] | None = ... ) -> Array: ... def stack( - arrays: Union[_np.ndarray, Array, Sequence[ArrayLike]], + arrays: _np.ndarray | Array | Sequence[ArrayLike], axis: int = ..., out: None = ..., - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., ) -> Array: ... def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, - where: Optional[ArrayLike] = ...) -> Array: ... + where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ... def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: ... def sum( a: ArrayLike, @@ -757,50 +795,53 @@ def sum( dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., - initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ..., + initial: ArrayLike | None = ..., + where: ArrayLike | None = ..., promote_integers: builtins.bool = ..., ) -> Array: ... def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: ... def take( a: ArrayLike, indices: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - mode: Optional[str] = ..., + mode: str | None = ..., unique_indices: builtins.bool = ..., indices_are_sorted: builtins.bool = ..., - fill_value: Optional[ArrayLike] = ..., + fill_value: StaticScalar | None = ..., ) -> Array: ... def take_along_axis( arr: ArrayLike, indices: ArrayLike, - axis: Optional[int], - mode: Optional[Union[str, GatherScatterMode]] = ..., + axis: int | None, + mode: str | GatherScatterMode | None = ..., + fill_value: StaticScalar | None = None, ) -> Array: ... def tan(x: ArrayLike, /) -> Array: ... def tanh(x: ArrayLike, /) -> Array: ... def tensordot(a: ArrayLike, b: ArrayLike, - axes: Union[int, Sequence[int], Sequence[Sequence[int]]] = ..., + axes: int | Sequence[int] | Sequence[Sequence[int]] = ..., *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... -def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array: ... 
-def trace(a: ArrayLike, offset: int = ..., axis1: int = ..., axis2: int = ..., - dtype: Optional[DTypeLike] = ..., out: None = ...) -> Array: ... -def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... +def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: ... +def trace(a: ArrayLike, offset: int | ArrayLike = ..., axis1: int = ..., axis2: int = ..., + dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... +def transpose(a: ArrayLike, axes: Sequence[int] | None = ...) -> Array: ... +def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = ..., + axis: int = ...) -> Array: ... def tri( - N: int, M: Optional[int] = ..., k: int = ..., dtype: DTypeLike = ... + N: int, M: int | None = ..., k: int = ..., dtype: DTypeLike = ... ) -> Array: ... def tril(m: ArrayLike, k: int = ...) -> Array: ... def tril_indices( - n: int, k: int = ..., m: Optional[int] = ... + n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... def tril_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: builtins.bool = ..., *, inplace: builtins.bool = ...) -> Array: ... def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ... def triu(m: ArrayLike, k: int = ...) -> Array: ... def triu_indices( - n: int, k: int = ..., m: Optional[int] = ... + n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... def triu_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... def true_divide(x: ArrayLike, y: ArrayLike, /) -> Array: ... @@ -815,8 +856,8 @@ def union1d( ar1: ArrayLike, ar2: ArrayLike, *, - size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ..., + size: int | None = ..., + fill_value: ArrayLike | None = ..., ) -> Array: ... class _UniqueAllResult(NamedTuple): values: Array @@ -830,66 +871,71 @@ class _UniqueInverseResult(NamedTuple): values: Array inverse_indices: Array def unique(ar: ArrayLike, return_index: builtins.bool = ..., return_inverse: builtins.bool = ..., - return_counts: builtins.bool = ..., axis: Optional[int] = ..., - *, equal_nan: builtins.bool = ..., size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ... + return_counts: builtins.bool = ..., axis: int | None = ..., + *, equal_nan: builtins.bool = ..., size: int | None = ..., + fill_value: ArrayLike | None = ... ): ... -def unique_all(x: ArrayLike, /) -> _UniqueAllResult: ... -def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult: ... -def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult: ... -def unique_values(x: ArrayLike, /) -> Array: ... +def unique_all(x: ArrayLike, /, *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> _UniqueAllResult: ... +def unique_counts(x: ArrayLike, /, *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> _UniqueCountsResult: ... +def unique_inverse(x: ArrayLike, /, *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> _UniqueInverseResult: ... +def unique_values(x: ArrayLike, /, *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> Array: ... def unpackbits( a: ArrayLike, - axis: Optional[int] = ..., - count: Optional[ArrayLike] = ..., + axis: int | None = ..., + count: ArrayLike | None = ..., bitorder: str = ..., ) -> Array: ... def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: ... 
unsignedinteger = _np.unsignedinteger -def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = ..., +def unstack(x: ArrayLike , /, *, axis: int = ...) -> tuple[Array, ...]: ... +def unwrap(p: ArrayLike, discont: ArrayLike | None = ..., axis: int = ..., period: ArrayLike = ...) -> Array: ... def vander( - x: ArrayLike, N: Optional[int] = ..., increasing: builtins.bool = ... + x: ArrayLike, N: int | None = ..., increasing: builtins.bool = ... ) -> Array: ... def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, - where: Optional[ArrayLike] = ...) -> Array: ... + where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ... def vdot( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = ..., precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... def vsplit( - ary: ArrayLike, indices_or_sections: Union[int, ArrayLike] + ary: ArrayLike, indices_or_sections: int | ArrayLike ) -> list[Array]: ... -def vstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]], - dtype: Optional[DTypeLike] = ...) -> Array: ... +def vstack(tup: _np.ndarray | Array | Sequence[ArrayLike], + dtype: DTypeLike | None = ...) -> Array: ... @overload def where(condition: ArrayLike, x: Literal[None] = ..., y: Literal[None] = ..., - /, *, size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... + /, *, size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... ) -> tuple[Array, ...]: ... @overload def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, /, *, - size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... + size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... ) -> Array: ... @overload -def where(condition: ArrayLike, x: Optional[ArrayLike] = ..., - y: Optional[ArrayLike] = ..., /, *, size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... - ) -> Union[Array, tuple[Array, ...]]: ... - -def zeros(shape: Any, dtype: Optional[DTypeLike] = ..., - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... -def zeros_like(a: Union[ArrayLike, DuckTypedArray], - dtype: Optional[DTypeLike] = ..., +def where(condition: ArrayLike, x: ArrayLike | None = ..., + y: ArrayLike | None = ..., /, *, size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... + ) -> Array | tuple[Array, ...]: ... + +def zeros(shape: Any, dtype: DTypeLike | None = ..., + device: _Device | _Sharding | None = ...) -> Array: ... +def zeros_like(a: ArrayLike | DuckTypedArray, + dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... + device: _Device | _Sharding | None = ...) -> Array: ... def vectorize(pyfunc, *, excluded = ..., signature = ...) -> Callable: ... 
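Note: the stub hunk above mechanically rewrites `typing.Optional[...]`/`typing.Union[...]` spellings to PEP 604 unions (`X | None`, `A | B`). A minimal, self-contained sketch of the style; the function here is hypothetical and not part of the stub:

    from __future__ import annotations  # allows the | syntax in annotations on older interpreters

    # Before: def clip(x: float, lo: Optional[float] = None, hi: Optional[float] = None) -> float
    # After (PEP 604 style, as used throughout the stub above):
    def clip(x: float, lo: float | None = None, hi: float | None = None) -> float:
        if lo is not None:
            x = max(x, lo)
        if hi is not None:
            x = min(x, hi)
        return x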
diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index a4e65fc32fb9..c342fde0ae6e 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -17,6 +17,7 @@ from jax._src.numpy.linalg import ( cholesky as cholesky, + cond as cond, cross as cross, det as det, diagonal as diagonal, @@ -31,6 +32,7 @@ matrix_power as matrix_power, matrix_rank as matrix_rank, matrix_transpose as matrix_transpose, + multi_dot as multi_dot, norm as norm, outer as outer, pinv as pinv, @@ -40,12 +42,9 @@ svd as svd, svdvals as svdvals, tensordot as tensordot, - vector_norm as vector_norm, - vecdot as vecdot, -) -from jax._src.third_party.numpy.linalg import ( - cond as cond, - multi_dot as multi_dot, tensorinv as tensorinv, tensorsolve as tensorsolve, + trace as trace, + vector_norm as vector_norm, + vecdot as vecdot, ) diff --git a/jax/random.py b/jax/random.py index c06f48b3583b..f951fce406a8 100644 --- a/jax/random.py +++ b/jax/random.py @@ -22,24 +22,25 @@ >>> seed = 1701 >>> num_steps = 100 ->>> key = jax.random.PRNGKey(seed) +>>> key = jax.random.key(seed) >>> for i in range(num_steps): ... key, subkey = jax.random.split(key) ... params = compiled_update(subkey, params, next(batches)) # doctest: +SKIP -PRNG Keys +PRNG keys --------- Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to be passed as a first argument. -The random state is described by two unsigned 32-bit integers that we call a **key**, -usually generated by the :py:func:`jax.random.PRNGKey` function:: +The random state is described by a special array element type that we call a **key**, +usually generated by the :py:func:`jax.random.key` function:: >>> from jax import random - >>> key = random.PRNGKey(0) + >>> key = random.key(0) >>> key - Array([0, 0], dtype=uint32) + Array((), dtype=key) overlaying: + [0 0] This key can then be used in any of JAX's random number generation routines:: @@ -57,11 +58,47 @@ >>> random.uniform(subkey) Array(0.10536897, dtype=float32) +.. note:: + + Typed key arrays, with element types such as ``key`` above, + were introduced in JAX v0.4.16. Before then, keys were + conventionally represented in ``uint32`` arrays, whose final + dimension represented the key's bit-level representation. + + Both forms of key array can still be created and used with the + :mod:`jax.random` module. New-style typed key arrays are made with + :py:func:`jax.random.key`. Legacy ``uint32`` key arrays are made + with :py:func:`jax.random.PRNGKey`. + + To convert between the two, use :py:func:`jax.random.key_data` and + :py:func:`jax.random.wrap_key_data`. The legacy key format may be + needed when interfacing with systems outside of JAX (e.g. exporting + arrays to a serializable format), or when passing keys to JAX-based + libraries that assume the legacy format. + + Otherwise, typed keys are recommended. Caveats of legacy keys + relative to typed ones include: + + * They have an extra trailing dimension. + + * They have a numeric dtype (``uint32``), allowing for operations + that are typically not meant to be carried out over keys, such as + integer arithmetic. + + * They do not carry information about the RNG implementation. When + legacy keys are passed to :mod:`jax.random` functions, a global + configuration setting determines the RNG implementation (see + "Advanced RNG configuration" below). + + To learn more about this upgrade, and the design of key types, see + `JEP 9263 + `_. 
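The note above introduces typed key arrays and the helpers for moving between the typed and legacy key formats. A short sketch, assuming a JAX version with typed keys (v0.4.16 or newer):

    import jax

    typed_key = jax.random.key(0)                 # new-style typed key array
    raw_bits = jax.random.key_data(typed_key)     # legacy uint32 view of the key bits
    back = jax.random.wrap_key_data(raw_bits)     # wrap the bits back into a typed key

    key, subkey = jax.random.split(typed_key)
    sample = jax.random.uniform(subkey, (3,))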
+ Advanced -------- -Design and Context -================== +Design and background +===================== **TLDR**: JAX PRNG = `Threefry counter PRNG `_ + a functional array-oriented `splitting model `_ @@ -79,16 +116,19 @@ Advanced RNG configuration ========================== -JAX provides several PRNG implementations (controlled by the -`jax_default_prng_impl` flag). +JAX provides several PRNG implementations. A specific one can be +selected with the optional `impl` keyword argument to +`jax.random.key`. When no `impl` option is passed to the `key` +constructor, the implementation is determined by the global +`jax_default_prng_impl` configuration flag. -- **default** +- **default**, `"threefry2x32"`: `A counter-based PRNG built around the Threefry hash function `_. - *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See `TF doc `_. - - "rbg" uses ThreeFry for splitting, and XLA RBG for data generation. - - "unsafe_rbg" exists only for demonstration purposes, using RBG both for + - `"rbg"` uses ThreeFry for splitting, and XLA RBG for data generation. + - `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for splitting (using an untested made up algorithm) and generating. The random streams generated by these experimental implementations haven't @@ -116,8 +156,9 @@ identical across JAX/XLA versions ✅ ✅ ================================= ======== ========= === ========== ===== ============ -(*): with jax_threefry_partitionable=1 set -(**): with XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1 set +(*): with ``jax_threefry_partitionable=1`` set + +(**): with ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1`` set The difference between "rbg" and "unsafe_rbg" is that while "rbg" uses a less robust/studied hash function for random value generation (but not for @@ -126,7 +167,7 @@ less safe in the sense that the quality of random streams it generates from different keys is less well understood. 
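A sketch of the two selection mechanisms described above: the per-key `impl` argument to `jax.random.key` and the global `jax_default_prng_impl` flag (shown here updated via `jax.config`, one common way to set it):

    import jax

    # Per-key selection via the optional `impl` argument:
    rbg_key = jax.random.key(0, impl="rbg")

    # Global default, used when no `impl` is passed (and for legacy PRNGKey keys):
    jax.config.update("jax_default_prng_impl", "threefry2x32")
    default_key = jax.random.key(1)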
-For more about jax_threefry_partitionable, see +For more about `jax_threefry_partitionable`, see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers """ @@ -144,6 +185,7 @@ cauchy as cauchy, chisquare as chisquare, choice as choice, + clone as clone, dirichlet as dirichlet, double_sided_maxwell as double_sided_maxwell, exponential as exponential, diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index 67d44bcfa96f..059f927ec46c 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -26,6 +26,7 @@ expm as expm, expm_frechet as expm_frechet, hessenberg as hessenberg, + hilbert as hilbert, inv as inv, lu as lu, lu_factor as lu_factor, diff --git a/jax/scipy/special.py b/jax/scipy/special.py index 9b74fb3746dc..e244c3705af3 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -17,10 +17,10 @@ from jax._src.scipy.special import ( bernoulli as bernoulli, + bessel_jn as bessel_jn, + beta as beta, betainc as betainc, betaln as betaln, - beta as beta, - bessel_jn as bessel_jn, digamma as digamma, entr as entr, erf as erf, @@ -31,30 +31,33 @@ expit as expit, expn as expn, factorial as factorial, + gamma as gamma, gammainc as gammainc, gammaincc as gammaincc, gammaln as gammaln, - gamma as gamma, + gammasgn as gammasgn, + hyp1f1 as hyp1f1, i0 as i0, i0e as i0e, i1 as i1, i1e as i1e, + kl_div as kl_div, + log_ndtr as log_ndtr, + log_softmax as log_softmax, logit as logit, logsumexp as logsumexp, lpmn as lpmn, lpmn_values as lpmn_values, multigammaln as multigammaln, - log_ndtr as log_ndtr, ndtr as ndtr, ndtri as ndtri, + poch as poch, polygamma as polygamma, + rel_entr as rel_entr, + softmax as softmax, spence as spence, sph_harm as sph_harm, - xlogy as xlogy, xlog1py as xlog1py, + xlogy as xlogy, zeta as zeta, - kl_div as kl_div, - rel_entr as rel_entr, - poch as poch, - hyp1f1 as hyp1f1, ) diff --git a/jax/sharding.py b/jax/sharding.py index 18caa9eb0f57..fe221f90af67 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -17,7 +17,7 @@ from jax._src.sharding import Sharding as Sharding from jax._src.sharding_impls import ( - XLACompatibleSharding as XLACompatibleSharding, + XLACompatibleSharding as _deprecated_XLACompatibleSharding, NamedSharding as NamedSharding, SingleDeviceSharding as SingleDeviceSharding, PmapSharding as PmapSharding, @@ -28,3 +28,23 @@ PartitionSpec as PartitionSpec, ) from jax._src.interpreters.pxla import Mesh as Mesh + +_deprecations = { + # Added Jun 4, 2024. + "XLACompatibleSharding": ( + ( + "jax.sharding.XLACompatibleSharding is deprecated. Use" + " jax.sharding.Sharding instead." + ), + _deprecated_XLACompatibleSharding, + ) +} + +import typing +if typing.TYPE_CHECKING: + XLACompatibleSharding = _deprecated_XLACompatibleSharding +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing diff --git a/jax/stages.py b/jax/stages.py index 0a6e6082f2ea..6ffc3144c3bc 100644 --- a/jax/stages.py +++ b/jax/stages.py @@ -30,4 +30,6 @@ Lowered as Lowered, Wrapped as Wrapped, ArgInfo as ArgInfo, + OutInfo as OutInfo, + Traced as Traced, ) diff --git a/jax/tools/BUILD b/jax/tools/BUILD index f4058c0c6760..80f757ca421c 100644 --- a/jax/tools/BUILD +++ b/jax/tools/BUILD @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
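The `jax/sharding.py` hunk above routes `XLACompatibleSharding` through a module-level `__getattr__` so that attribute access emits a deprecation warning. A generic sketch of that PEP 562 pattern, with hypothetical names rather than JAX's internal `deprecation_getattr` helper:

    import warnings

    class NewSharding:  # stand-in for the replacement symbol
        pass

    _deprecated = {"OldSharding": NewSharding}

    def __getattr__(name):
        if name in _deprecated:
            warnings.warn(f"{name} is deprecated; use {_deprecated[name].__name__} instead.",
                          DeprecationWarning, stacklevel=2)
            return _deprecated[name]
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")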
+load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "py_deps", "pytype_strict_library", ) -load("@rules_python//python:defs.bzl", "py_library") licenses(["notice"]) diff --git a/jax/tools/build_utils.py b/jax/tools/build_utils.py index 7d961cb11df8..84cc697d1894 100644 --- a/jax/tools/build_utils.py +++ b/jax/tools/build_utils.py @@ -75,10 +75,17 @@ def build_wheel( for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")): output_file = os.path.join(output_path, os.path.basename(wheel)) sys.stderr.write(f"Output wheel: {output_file}\n\n") - sys.stderr.write(f"To install the newly-built {package_name} wheel, run:\n") + sys.stderr.write(f"To install the newly-built {package_name} wheel " + + "on system Python, run:\n") sys.stderr.write(f" pip install {output_file} --force-reinstall\n\n") - shutil.copy(wheel, output_path) + py_version = ".".join(platform.python_version_tuple()[:-1]) + sys.stderr.write(f"To install the newly-built {package_name} wheel " + + "on hermetic Python, run:\n") + sys.stderr.write(f' echo -e "\\n{output_file}" >> build/requirements.in\n') + sys.stderr.write(" bazel run //build:requirements.update" + + f" --repo_env=HERMETIC_PYTHON_VERSION={py_version}\n\n") + shutil.copy(wheel, output_path) def build_editable( sources_path: str, output_path: str, package_name: str @@ -100,3 +107,13 @@ def update_setup_with_cuda_version(file_dir: pathlib.Path, cuda_version: str): ) with open(src_file, "w") as f: f.write(content) + +def update_setup_with_rocm_version(file_dir: pathlib.Path, rocm_version: str): + src_file = file_dir / "setup.py" + with open(src_file) as f: + content = f.read() + content = content.replace( + "rocm_version = 0 # placeholder", f"rocm_version = {rocm_version}" + ) + with open(src_file, "w") as f: + f.write(content) diff --git a/jax/tools/jax_to_ir.py b/jax/tools/jax_to_ir.py index a0f86bd62f1d..f3e6fc571594 100644 --- a/jax/tools/jax_to_ir.py +++ b/jax/tools/jax_to_ir.py @@ -85,7 +85,7 @@ def fn(x, y, z): try: import tensorflow as tf except ImportError: - tf = None # type: ignore + tf = None _FN = flags.DEFINE_string( @@ -151,7 +151,7 @@ def ordered_wrapper(*args): return fn_curried(**dict(zip(arg_names, args))) if format == 'HLO': - comp = jax.xla_computation(ordered_wrapper)(*args) + comp = jax.jit(ordered_wrapper).lower(*args).compiler_ir('hlo') serialized_proto = comp.as_serialized_hlo_module_proto() debug_txt = comp.as_hlo_text() else: @@ -238,12 +238,14 @@ def parse_shape_str(s): shape = () return jax.core.ShapedArray(shape, dtype) -_DT = {'pred': jnp.bool_, - 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64, - 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64, - 'bf16': jnp.bfloat16, - 'f16': jnp.float16, 'f32': jnp.float32, 'f64': jnp.float64, - 'c64': jnp.complex64, 'c128': jnp.complex128} +_DT = { + 'pred': jnp.bool_, + 'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64, + 's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64, + 'bf16': jnp.bfloat16, + 'f16': jnp.float16, 'f32': jnp.float32, 'f64': jnp.float64, + 'c64': jnp.complex64, 'c128': jnp.complex128 +} _SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$") diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py new file mode 100644 index 000000000000..5e87220be606 --- /dev/null +++ b/jax/tools/pgo_nsys_converter.py @@ -0,0 +1,62 @@ +# Copyright 2024 The JAX Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import re +import sys +import argparse +import os +import shutil +import subprocess + +if __name__ == '__main__': + + print("Script to convert NVIDIA Nsys Profiles to the .pbtxt format. This format is readable by XLA's Profile Guided Latency Estimator. Usage: pgo_nsys_converter.py --profile_path --pgle_output_path ") + + nsys_path = shutil.which("nsys") + + parser = argparse.ArgumentParser(description='Tool to convert NVIDIA Nsys Profiles to the .pbtxt format') + parser.add_argument("--profile_path", type=str, help="path to nsys profile") + parser.add_argument("--post_process", help="post process pbtxt to get minimum cost value for each instruction", action="store_true") + parser.add_argument("--pgle_output_path", type=str, help="output directory", default="/opt/paxml/workspace/lhs_pbtxt/temp.pbtxt") + + args = parser.parse_args() + + pgle_filename = os.path.basename(args.pgle_output_path).partition('.')[0] + pgle_folder = os.path.join(os.path.split(args.pgle_output_path)[0], '') + profile_folder = os.path.join(os.path.split(args.profile_path)[0], '') + + assert isinstance(nsys_path, str) + stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", "nvtxkernsum", f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] + + print(f""" + ******Starting stats command****** + {stats_command}.""") + + proc = subprocess.Popen(stats_command, stdout=sys.stdout, stderr=sys.stderr) + proc.wait() + + thunk_re = re.compile("hlo_op=(.*)#") + with open(f"{args.pgle_output_path}", 'w', newline='') as protofile: + with open(f"{pgle_folder}{pgle_filename}.pbtxt_nvtxkernsum.csv", newline='') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + name = row['NVTX Range'] + time_ns = float(row['Avg (ns)']) + m = thunk_re.search(name) + if m is not None: + protofile.write(f'costs {{ name: "{m.group(1)}" cost_us: {time_ns / 1000.0} }}\n') + + clean_command = f"rm {profile_folder}/*.sqlite; rm {pgle_folder}/*.csv" + subprocess.call(clean_command, shell=True) diff --git a/jax/experimental/array_api/_dtypes.py b/jax/tools/toolchains/BUILD similarity index 56% rename from jax/experimental/array_api/_dtypes.py rename to jax/tools/toolchains/BUILD index 72229bfc28af..1401ee12a06c 100644 --- a/jax/experimental/array_api/_dtypes.py +++ b/jax/tools/toolchains/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The JAX Authors. +# Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,18 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
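For reference, the `pgo_nsys_converter.py` script added above is driven by the `--profile_path` and `--pgle_output_path` flags it defines. A hedged invocation sketch; the file names are placeholders:

    import subprocess

    subprocess.run(
        [
            "python", "pgo_nsys_converter.py",
            "--profile_path", "profile.nsys-rep",        # nsys profile to convert
            "--pgle_output_path", "pgle_profile.pbtxt",  # destination .pbtxt
        ],
        check=True,
    )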
-import numpy as np +licenses(["notice"]) -bool = np.dtype('bool') -int8 = np.dtype('int8') -int16 = np.dtype('int16') -int32 = np.dtype('int32') -int64 = np.dtype('int64') -uint8 = np.dtype('uint8') -uint16 = np.dtype('uint16') -uint32 = np.dtype('uint32') -uint64 = np.dtype('uint64') -float32 = np.dtype('float32') -float64 = np.dtype('float64') -complex64 = np.dtype('complex64') -complex128 = np.dtype('complex128') +package( + default_applicable_licenses = [], + default_visibility = ["//visibility:public"], +) + +# Refer to https://bazel.build/configure/windows#clang for potential changes. +platform( + name = "x64_windows-clang-cl", + constraint_values = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + "@bazel_tools//tools/cpp:clang-cl", + ], +) diff --git a/jax/tree_util.py b/jax/tree_util.py index 7d8b54276944..800086c220ac 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -39,33 +39,34 @@ # See PEP 484 & https://github.com/google/jax/issues/7570 from jax._src.tree_util import ( - Partial as Partial, - PyTreeDef as PyTreeDef, - all_leaves as all_leaves, - build_tree as build_tree, - default_registry as default_registry, - register_pytree_node as register_pytree_node, - register_pytree_node_class as register_pytree_node_class, - tree_all as tree_all, - tree_flatten as tree_flatten, - tree_leaves as tree_leaves, - tree_map as tree_map, - tree_reduce as tree_reduce, - tree_structure as tree_structure, - tree_transpose as tree_transpose, - tree_unflatten as tree_unflatten, - treedef_children as treedef_children, - treedef_is_leaf as treedef_is_leaf, - treedef_tuple as treedef_tuple, - register_pytree_with_keys as register_pytree_with_keys, - register_pytree_with_keys_class as register_pytree_with_keys_class, - tree_map_with_path as tree_map_with_path, - tree_flatten_with_path as tree_flatten_with_path, - tree_leaves_with_path as tree_leaves_with_path, - keystr as keystr, - SequenceKey as SequenceKey, - DictKey as DictKey, - GetAttrKey as GetAttrKey, - FlattenedIndexKey as FlattenedIndexKey, - register_static as register_static, + DictKey as DictKey, + FlattenedIndexKey as FlattenedIndexKey, + GetAttrKey as GetAttrKey, + Partial as Partial, + PyTreeDef as PyTreeDef, + SequenceKey as SequenceKey, + all_leaves as all_leaves, + build_tree as build_tree, + default_registry as default_registry, + keystr as keystr, + register_pytree_node_class as register_pytree_node_class, + register_pytree_node as register_pytree_node, + register_pytree_with_keys_class as register_pytree_with_keys_class, + register_dataclass as register_dataclass, + register_pytree_with_keys as register_pytree_with_keys, + register_static as register_static, + tree_all as tree_all, + tree_flatten_with_path as tree_flatten_with_path, + tree_flatten as tree_flatten, + tree_leaves_with_path as tree_leaves_with_path, + tree_leaves as tree_leaves, + tree_map_with_path as tree_map_with_path, + tree_map as tree_map, + tree_reduce as tree_reduce, + tree_structure as tree_structure, + tree_transpose as tree_transpose, + tree_unflatten as tree_unflatten, + treedef_children as treedef_children, + treedef_is_leaf as treedef_is_leaf, + treedef_tuple as treedef_tuple, ) diff --git a/jax/util.py b/jax/util.py index 4519576a7d49..c1259e9c5f56 100644 --- a/jax/util.py +++ b/jax/util.py @@ -23,6 +23,7 @@ safe_zip as safe_zip, split_dict as split_dict, split_list as split_list, + split_list_checked as split_list_checked, split_merge as split_merge, subvals as subvals, toposort as toposort, diff --git a/jax/version.py 
b/jax/version.py index 14ba0d9d461a..f3d007eec9b1 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.4.26" +_version = "0.4.31" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None @@ -47,21 +47,15 @@ def _version_from_git_tree(base_version: str) -> str | None: try: root_directory = os.path.dirname(os.path.realpath(__file__)) - # Get date string from date of most recent git commit. - p = subprocess.Popen(["git", "show", "-s", "--format=%at", "HEAD"], + # Get date string from date of most recent git commit, and the abbreviated + # hash of that commit. + p = subprocess.Popen(["git", "show", "-s", "--format=%at-%h", "HEAD"], cwd=root_directory, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, _ = p.communicate() - timestamp = int(stdout.decode().strip()) - datestring = datetime.date.fromtimestamp(timestamp).strftime("%Y%m%d") + timestamp, commit_hash = stdout.decode().strip().split('-', 1) + datestring = datetime.date.fromtimestamp(int(timestamp)).strftime("%Y%m%d") assert datestring.isnumeric() - - # Get commit hash from most recent git commit. - p = subprocess.Popen(["git", "describe", "--long", "--always"], - cwd=root_directory, - stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout, _ = p.communicate() - commit_hash = stdout.decode().strip().rsplit('-g', 1)[-1] assert commit_hash.isalnum() except: return None @@ -139,7 +133,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.4.20" +_minimum_jaxlib_version = "0.4.30" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/jax_plugins/BUILD.bazel b/jax_plugins/BUILD.bazel index 2102c6404c5a..6e2cf6aadbaf 100644 --- a/jax_plugins/BUILD.bazel +++ b/jax_plugins/BUILD.bazel @@ -17,6 +17,7 @@ licenses(["notice"]) load( "//jaxlib:jax.bzl", "if_cuda_is_configured", + "if_rocm_is_configured", "py_library_providing_imports_info", ) @@ -30,5 +31,7 @@ py_library( ":jax_plugins", ] + if_cuda_is_configured([ "//jax_plugins/cuda:cuda_plugin", + ]) + if_rocm_is_configured([ + "//jax_plugins/rocm:rocm_plugin", ]), -) \ No newline at end of file +) diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index d8f35fc33969..ff5a1561dbbc 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -25,7 +25,7 @@ # cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without # preinstalled jax cuda plugin packages. 
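Two files above, the reordered `jax/tree_util.py` exports also pick up the new `register_dataclass`. A hedged sketch of its use, assuming the `(cls, data_fields, meta_fields)` calling convention:

    import dataclasses
    import jax
    import jax.numpy as jnp

    @dataclasses.dataclass
    class Params:
        weights: jax.Array   # pytree leaf (traced)
        name: str            # static metadata

    jax.tree_util.register_dataclass(Params, data_fields=["weights"], meta_fields=["name"])

    p = Params(weights=jnp.ones(3), name="layer0")
    doubled = jax.tree_util.tree_map(lambda w: 2 * w, p)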
-for pkg_name in ['jax_cuda12_plugin', 'jax_cuda11_plugin', 'jaxlib']: +for pkg_name in ['jax_cuda12_plugin', 'jaxlib']: try: cuda_plugin_extension = importlib.import_module( f'{pkg_name}.cuda_plugin_extension' diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index 77454fb488a3..cd26731aa629 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -49,13 +49,33 @@ def has_ext_modules(self): author="JAX team", author_email="jax-dev@google.com", packages=[package_name], - python_requires=">=3.9", + python_requires=">=3.10", install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"], + extras_require={ + 'with_cuda': [ + "nvidia-cublas-cu12>=12.1.3.1", + "nvidia-cuda-cupti-cu12>=12.1.105", + "nvidia-cuda-nvcc-cu12>=12.1.105", + "nvidia-cuda-runtime-cu12>=12.1.105", + "nvidia-cudnn-cu12>=9.0,<10.0", + "nvidia-cufft-cu12>=11.0.2.54", + "nvidia-cusolver-cu12>=11.4.5.107", + "nvidia-cusparse-cu12>=12.1.0.106", + "nvidia-nccl-cu12>=2.18.1", + # nvjitlink is not a direct dependency of JAX, but it is a transitive + # dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages + # do not have a version constraint on their dependencies, so the + # package doesn't get upgraded even though not doing that can cause + # problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196) + # Until NVIDIA add version constraints, add a version constraint + # here. + "nvidia-nvjitlink-cu12>=12.1.105", + ], + }, url="https://github.com/google/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", diff --git a/jax_plugins/rocm/BUILD.bazel b/jax_plugins/rocm/BUILD.bazel new file mode 100644 index 000000000000..08a61c786262 --- /dev/null +++ b/jax_plugins/rocm/BUILD.bazel @@ -0,0 +1,55 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) + +load("//jaxlib:symlink_files.bzl", "symlink_files") +load( + "//jaxlib:jax.bzl", + "if_windows", + "py_library_providing_imports_info", + "pytype_library", +) + +package( + default_applicable_licenses = [], + default_visibility = ["//:__subpackages__"], +) + +exports_files([ + "__init__.py", + "plugin_pyproject.toml", + "plugin_setup.py", + "pyproject.toml", + "setup.py", +]) + +symlink_files( + name = "pjrt_c_api_gpu_plugin", + srcs = if_windows( + ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], + ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], + ), + dst = ".", + flatten = True, +) + +py_library_providing_imports_info( + name = "rocm_plugin", + srcs = [ + "__init__.py", + ], + data = [":pjrt_c_api_gpu_plugin"], + lib_rule = pytype_library, +) diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py new file mode 100644 index 000000000000..4535f1b3bbc8 --- /dev/null +++ b/jax_plugins/rocm/__init__.py @@ -0,0 +1,91 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import importlib +import logging +import pathlib +import platform + +from jax._src.lib import xla_client +import jax._src.xla_bridge as xb + +# rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without +# preinstalled jax rocm plugin packages. +for pkg_name in ['jax_rocm60_plugin', 'jaxlib']: + try: + rocm_plugin_extension = importlib.import_module( + f'{pkg_name}.rocm_plugin_extension' + ) + except ImportError: + rocm_plugin_extension = None + else: + break + +logger = logging.getLogger(__name__) + + +def _get_library_path(): + base_path = pathlib.Path(__file__).resolve().parent + installed_path = ( + base_path / 'xla_rocm_plugin.so' + ) + if installed_path.exists(): + return installed_path + + local_path = ( + base_path / 'pjrt_c_api_gpu_plugin.so' + ) + if local_path.exists(): + logger.debug( + 'Native library %s does not exist. This most likely indicates an issue' + ' with how %s was built or installed. Fallback to local test' + ' library %s', + installed_path, + __package__, + local_path, + ) + return local_path + + logger.debug( + 'WARNING: Native library %s and local test library path %s do not' + ' exist. 
This most likely indicates an issue with how %s was built or' + ' installed or missing src files.', + installed_path, + local_path, + __package__, + ) + return None + + +def initialize(): + path = _get_library_path() + if path is None: + return + options = xla_client.generate_pjrt_gpu_plugin_options() + options["platform_name"] = "ROCM" + c_api = xb.register_plugin( + 'rocm', priority=500, library_path=str(path), options=options + ) + if rocm_plugin_extension: + xla_client.register_custom_call_handler( + "ROCM", + functools.partial( + rocm_plugin_extension.register_custom_call_target, c_api + ), + ) + for _name, _value in rocm_plugin_extension.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM") + else: + logger.warning('rocm_plugin_extension is not found.') diff --git a/jax_plugins/rocm/plugin_pyproject.toml b/jax_plugins/rocm/plugin_pyproject.toml new file mode 100644 index 000000000000..8fe2f47af9a1 --- /dev/null +++ b/jax_plugins/rocm/plugin_pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/jax_plugins/rocm/plugin_setup.py b/jax_plugins/rocm/plugin_setup.py new file mode 100644 index 000000000000..9ccf3bf44339 --- /dev/null +++ b/jax_plugins/rocm/plugin_setup.py @@ -0,0 +1,70 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
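Once the `initialize()` hook above has registered the ROCm PJRT plugin, a quick way to confirm it loaded is to enumerate devices; a minimal check, assuming the backend is queryable under the "rocm" name registered above:

    import jax

    print(jax.devices())          # should list ROCm GPU devices when the plugin loads
    print(jax.devices("rocm"))    # or query the registered backend by name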
+ +import importlib +import os +from setuptools import setup +from setuptools.dist import Distribution + +__version__ = None +rocm_version = 0 # placeholder +project_name = f"jax-rocm{rocm_version}-plugin" +package_name = f"jax_rocm{rocm_version}_plugin" + +def load_version_module(pkg_path): + spec = importlib.util.spec_from_file_location( + 'version', os.path.join(pkg_path, 'version.py')) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + +_version_module = load_version_module(package_name) +__version__ = _version_module._get_version_for_build() +_cmdclass = _version_module._get_cmdclass(package_name) + +class BinaryDistribution(Distribution): + """This class makes 'bdist_wheel' include an ABI tag on the wheel.""" + + def has_ext_modules(self): + return True + +setup( + name=project_name, + version=__version__, + cmdclass=_cmdclass, + description="JAX Plugin for AMD GPUs", + long_description="", + long_description_content_type="text/markdown", + author="Ruturaj4", + author_email="Ruturaj.Vaidya@amd.com", + packages=[package_name], + python_requires=">=3.9", + install_requires=[f"jax-rocm{rocm_version}-pjrt=={__version__}"], + url="https://github.com/google/jax", + license="Apache-2.0", + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + ], + package_data={ + package_name: [ + "*", + ], + }, + zip_safe=False, + distclass=BinaryDistribution, +) diff --git a/jax_plugins/rocm/pyproject.toml b/jax_plugins/rocm/pyproject.toml new file mode 100644 index 000000000000..8fe2f47af9a1 --- /dev/null +++ b/jax_plugins/rocm/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/jax_plugins/rocm/setup.py b/jax_plugins/rocm/setup.py new file mode 100644 index 000000000000..8782676ce9a2 --- /dev/null +++ b/jax_plugins/rocm/setup.py @@ -0,0 +1,66 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import importlib +import os +from setuptools import setup, find_namespace_packages + +__version__ = None +rocm_version = 0 # placeholder +project_name = f"jax-rocm{rocm_version}-pjrt" +package_name = f"jax_plugins.xla_rocm{rocm_version}" + +def load_version_module(pkg_path): + spec = importlib.util.spec_from_file_location( + 'version', os.path.join(pkg_path, 'version.py')) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + +_version_module = load_version_module(f"jax_plugins/xla_rocm{rocm_version}") +__version__ = _version_module._get_version_for_build() + +packages = find_namespace_packages( + include=[ + package_name, + f"{package_name}.*", + ] +) + +setup( + name=project_name, + version=__version__, + description="JAX XLA PJRT Plugin for AMD GPUs", + long_description="", + long_description_content_type="text/markdown", + author="Ruturaj4", + author_email="Ruturaj.Vaidya@amd.com", + packages=packages, + install_requires=[], + url="https://github.com/google/jax", + license="Apache-2.0", + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + ], + package_data={ + package_name: ["xla_rocm_plugin.so"], + }, + zip_safe=False, + entry_points={ + "jax_plugins": [ + f"xla_rocm{rocm_version} = {package_name}", + ], + }, +) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 988fe9afca03..dc8b5148ca93 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -17,7 +17,6 @@ load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", - "if_windows", "py_library_providing_imports_info", "pybind_extension", "pytype_library", @@ -30,10 +29,16 @@ package( default_visibility = ["//:__subpackages__"], ) +# This makes xla_extension module accessible from jax._src.lib. +genrule( + name = "xla_extension_py", + outs = ["xla_extension.py"], + cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@", +) + py_library_providing_imports_info( name = "jaxlib", srcs = [ - "ducc_fft.py", "gpu_common_utils.py", "gpu_linalg.py", "gpu_prng.py", @@ -46,23 +51,27 @@ py_library_providing_imports_info( "lapack.py", ":version", ":xla_client", + ":xla_extension_py", ], - data = [":xla_extension"], + data = [":ffi_headers"], lib_rule = pytype_library, deps = [ ":cpu_feature_guard", ":utils", - "//jaxlib/cpu:_ducc_fft", "//jaxlib/cpu:_lapack", "//jaxlib/mlir", "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", "//jaxlib/mlir:chlo_dialect", "//jaxlib/mlir:func_dialect", + "//jaxlib/mlir:gpu_dialect", "//jaxlib/mlir:ir", + "//jaxlib/mlir:llvm_dialect", "//jaxlib/mlir:math_dialect", "//jaxlib/mlir:memref_dialect", "//jaxlib/mlir:mhlo_dialect", + "//jaxlib/mlir:nvgpu_dialect", + "//jaxlib/mlir:nvvm_dialect", "//jaxlib/mlir:pass_manager", "//jaxlib/mlir:scf_dialect", "//jaxlib/mlir:sparse_tensor_dialect", @@ -70,6 +79,7 @@ py_library_providing_imports_info( "//jaxlib/mlir:vector_dialect", "//jaxlib/mosaic", "//jaxlib/triton", + "@xla//xla/python:xla_extension", ], ) @@ -88,12 +98,9 @@ symlink_files( ) symlink_files( - name = "xla_extension", - srcs = if_windows( - ["@xla//xla/python:xla_extension.pyd"], - ["@xla//xla/python:xla_extension.so"], - ), - dst = ".", + name = "ffi_headers", + srcs = ["@xla//xla/ffi/api:all_headers"], + dst = "include/xla/ffi/api", flatten = True, ) @@ -126,7 +133,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":kernel_helpers", - "@tsl//tsl/python/lib/core:numpy", + "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/base", "@nanobind", ], @@ -189,11 +196,13 
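The ROCm `setup.py` above advertises the PJRT package through the `jax_plugins` entry-point group. A hedged sketch of inspecting what is installed under that group (JAX's own discovery logic may differ in detail):

    from importlib.metadata import entry_points

    for ep in entry_points(group="jax_plugins"):   # Python 3.10+ keyword selection
        print(f"{ep.name} -> {ep.value}")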
@@ pybind_extension( srcs = ["cuda_plugin_extension.cc"], module_name = "cuda_plugin_extension", deps = [ + "@com_google_absl//absl/status", "@nanobind", "//jaxlib:kernel_nanobind_helpers", "@xla//third_party/python_runtime:headers", - "@xla//xla:status", + "@local_config_cuda//cuda:cuda_headers", "@xla//xla:util", + "@xla//xla/ffi/api:c_api", "@xla//xla/pjrt:status_casters", "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", @@ -202,7 +211,29 @@ pybind_extension( "@xla//xla/python:py_client_gpu", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", - "@tsl//tsl/python/lib/core:numpy", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +pybind_extension( + name = "rocm_plugin_extension", + srcs = ["rocm_plugin_extension.cc"], + module_name = "rocm_plugin_extension", + deps = [ + "//jaxlib:kernel_nanobind_helpers", + "@xla//third_party/python_runtime:headers", + "@xla//xla:status", + "@xla//xla:util", + "@xla//xla/ffi/api:c_api", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_helpers", + "@xla//xla/python:py_client_gpu", + "@xla//xla/tsl/python/lib/core:numpy", + "@com_google_absl//absl/status", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", ], ) diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index 279cf5c2aab5..bb406ffd3adc 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -16,7 +16,6 @@ load( "//jaxlib:jax.bzl", - "flatbuffer_cc_library", "pybind_extension", ) @@ -36,8 +35,14 @@ cc_library( copts = ["-fexceptions"], features = ["-use_header_modules"], deps = [ + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", ], ) @@ -51,6 +56,7 @@ cc_library( pybind_extension( name = "_lapack", srcs = ["lapack.cc"], + hdrs = ["lapack.h"], copts = [ "-fexceptions", "-fno-strict-aliasing", @@ -64,49 +70,7 @@ pybind_extension( deps = [ ":lapack_kernels", "//jaxlib:kernel_nanobind_helpers", - "@nanobind", - ], -) - -# DUCC (CPU FFTs) - -flatbuffer_cc_library( - name = "ducc_fft_flatbuffers_cc", - srcs = ["ducc_fft.fbs"], -) - -cc_library( - name = "ducc_fft_kernels", - srcs = ["ducc_fft_kernels.cc"], - hdrs = ["ducc_fft_kernels.h"], - copts = ["-fexceptions"], # DUCC may throw. 
- features = ["-use_header_modules"], - deps = [ - ":ducc_fft_flatbuffers_cc", - "@xla//xla/service:custom_call_status", - "@com_github_google_flatbuffers//:flatbuffers", - "@ducc//:fft", - ], -) - -pybind_extension( - name = "_ducc_fft", - srcs = ["ducc_fft.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - enable_stub_generation = True, - features = ["-use_header_modules"], - module_name = "_ducc_fft", - pytype_srcs = [ - "_ducc_fft.pyi", - ], - deps = [ - ":ducc_fft_flatbuffers_cc", - ":ducc_fft_kernels", - "//jaxlib:kernel_nanobind_helpers", - "@com_github_google_flatbuffers//:flatbuffers", + "@xla//xla/ffi/api:ffi", "@nanobind", ], ) @@ -114,11 +78,13 @@ pybind_extension( cc_library( name = "cpu_kernels", srcs = ["cpu_kernels.cc"], + hdrs = ["lapack.h"], visibility = ["//visibility:public"], deps = [ - ":ducc_fft_kernels", ":lapack_kernels", ":lapack_kernels_using_lapack", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_target_registry", ], alwayslink = 1, diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index dc5d657a9430..0cb9e7cb3328 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -16,13 +16,23 @@ limitations under the License. // This file is not used by JAX itself, but exists to assist with running // JAX-generated HLO code from outside of JAX. -#include "jaxlib/cpu/ducc_fft_kernels.h" +#include + +#include "jaxlib/cpu/lapack.h" #include "jaxlib/cpu/lapack_kernels.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_target_registry.h" +#define JAX_CPU_REGISTER_HANDLER(name) \ + XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), #name, "Host", name); + namespace jax { namespace { +// Old-style kernels +// TODO(b/344892332): To be removed after the 6M compatibility period is over. + XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_strsm", Trsm::Kernel, "Host"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_dtrsm", Trsm::Kernel, @@ -105,10 +115,19 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( "lapack_cgees", ComplexGees>::Kernel, "Host"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( "lapack_zgees", ComplexGees>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "ducc_fft", DuccFft, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "dynamic_ducc_fft", DynamicDuccFft, "Host"); + +// FFI Kernels + +JAX_CPU_REGISTER_HANDLER(lapack_sgetrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_dgetrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_cgetrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_zgetrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_spotrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_dpotrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_cpotrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_zpotrf_ffi); + +#undef JAX_CPU_REGISTER_HANDLER } // namespace } // namespace jax diff --git a/jaxlib/cpu/ducc_fft.cc b/jaxlib/cpu/ducc_fft.cc deleted file mode 100644 index 33e73c5f4214..000000000000 --- a/jaxlib/cpu/ducc_fft.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2020 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "nanobind/nanobind.h" -#include "nanobind/stl/vector.h" -#include "jaxlib/cpu/ducc_fft_generated.h" -#include "jaxlib/cpu/ducc_fft_kernels.h" -#include "jaxlib/kernel_nanobind_helpers.h" - -namespace nb = nanobind; - -namespace jax { -namespace { - - -nb::bytes BuildDynamicDuccFftDescriptor( - const uint32_t ndims, - bool is_double, int fft_type, - const std::vector &axes, - bool forward) { - DynamicDuccFftDescriptorT descriptor; - descriptor.ndims = ndims; - descriptor.fft_type = static_cast(fft_type); - descriptor.dtype = - is_double ? DuccFftDtype_COMPLEX128 : DuccFftDtype_COMPLEX64; - descriptor.axes = axes; - descriptor.forward = forward; - flatbuffers::FlatBufferBuilder fbb; - fbb.Finish(DynamicDuccFftDescriptor::Pack(fbb, &descriptor)); - return nb::bytes(reinterpret_cast(fbb.GetBufferPointer()), - fbb.GetSize()); -} - -nb::dict Registrations() { - nb::dict dict; - // TODO(b/287702203): this must be kept until EOY 2023 for backwards - // of serialized functions using fft. - dict["ducc_fft"] = EncapsulateFunction(DuccFft); - dict["dynamic_ducc_fft"] = EncapsulateFunction(DynamicDuccFft); - return dict; -} - -NB_MODULE(_ducc_fft, m) { - m.def("registrations", &Registrations); - m.def("dynamic_ducc_fft_descriptor", &BuildDynamicDuccFftDescriptor, - nb::arg("ndims"), nb::arg("is_double"), nb::arg("fft_type"), - nb::arg("axes"), nb::arg("forward")); -} - -} // namespace -} // namespace jax diff --git a/jaxlib/cpu/ducc_fft_kernels.cc b/jaxlib/cpu/ducc_fft_kernels.cc deleted file mode 100644 index eab35905f31c..000000000000 --- a/jaxlib/cpu/ducc_fft_kernels.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* Copyright 2020 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "ducc/src/ducc0/fft/fft.h" -#include "ducc/src/ducc0/fft/fft1d_impl.h" // NOLINT: required for fft definitions. -#include "ducc/src/ducc0/fft/fftnd_impl.h" // NOLINT: required for fft definitions. 
-#include "flatbuffers/flatbuffers.h" -#include "jaxlib/cpu/ducc_fft_generated.h" -#include "xla/service/custom_call_status.h" - -namespace jax { - -using shape_t = ducc0::fmav_info::shape_t; -using stride_t = ducc0::fmav_info::stride_t; - -namespace { - -void DuccFftImpl(void *out, void *operand, jax::DuccFftDtype dtype, - jax::DuccFftType fft_type, - shape_t shape, stride_t strides_in, stride_t strides_out, shape_t axes, - bool forward, double scale) { - - switch (fft_type) { - case DuccFftType_C2C: - if (dtype == DuccFftDtype_COMPLEX64) { - ducc0::cfmav> m_in( - reinterpret_cast *>(operand), shape, strides_in); - ducc0::vfmav> m_out( - reinterpret_cast *>(out), shape, strides_out); - ducc0::c2c(m_in, m_out, axes, forward, static_cast(scale)); - } else { - ducc0::cfmav> m_in( - reinterpret_cast *>(operand), shape, strides_in); - ducc0::vfmav> m_out( - reinterpret_cast *>(out), shape, strides_out); - ducc0::c2c(m_in, m_out, axes, forward, scale); - } - break; - case DuccFftType_C2R: - if (dtype == DuccFftDtype_COMPLEX64) { - auto shape_in = shape; - shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1; - ducc0::cfmav> m_in( - reinterpret_cast *>(operand), - shape_in, strides_in); - ducc0::vfmav m_out(reinterpret_cast(out), shape, - strides_out); - ducc0::c2r(m_in, m_out, axes, forward, static_cast(scale)); - } else { - auto shape_in = shape; - shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1; - ducc0::cfmav> m_in( - reinterpret_cast *>(operand), - shape_in, strides_in); - ducc0::vfmav m_out(reinterpret_cast(out), shape, - strides_out); - ducc0::c2r(m_in, m_out, axes, forward, scale); - } - break; - case DuccFftType_R2C: - if (dtype == DuccFftDtype_COMPLEX64) { - auto shape_out = shape; - shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1; - ducc0::cfmav m_in(reinterpret_cast(operand), shape, - strides_in); - ducc0::vfmav> m_out( - reinterpret_cast *>(out), - shape_out, strides_out); - ducc0::r2c(m_in, m_out, axes, forward, static_cast(scale)); - } else { - auto shape_out = shape; - shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1; - ducc0::cfmav m_in(reinterpret_cast(operand), shape, - strides_in); - ducc0::vfmav> m_out( - reinterpret_cast *>(out), - shape_out, strides_out); - ducc0::r2c(m_in, m_out, axes, forward, scale); - } - break; - } -} - -} // namespace - - -// TODO(b/287702203): this must be kept until EOY 2023 for backwards -// of serialized functions using fft. -void DuccFft(void *out, void **in, XlaCustomCallStatus *) { - const DuccFftDescriptor *descriptor = GetDuccFftDescriptor(in[0]); - shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end()); - stride_t strides_in(descriptor->strides_in()->begin(), - descriptor->strides_in()->end()); - stride_t strides_out(descriptor->strides_out()->begin(), - descriptor->strides_out()->end()); - shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end()); - - DuccFftImpl(out, in[1], descriptor->dtype(), descriptor->fft_type(), - shape, strides_in, strides_out, axes, - descriptor->forward(), descriptor->scale()); -} - - -void DynamicDuccFft(void *out, void **in, XlaCustomCallStatus *) { - // in[0]=descriptor, in[1]=operand, - // in[2]=shape, in[3]=strides_in, in[4]=strides_out, in[5]=scale. 
- const DynamicDuccFftDescriptor *descriptor = - flatbuffers::GetRoot(in[0]); - const std::uint32_t *dynamic_shape = - reinterpret_cast(in[2]); - shape_t shape(dynamic_shape, dynamic_shape + descriptor->ndims()); - const std::uint32_t *dynamic_strides_in = - reinterpret_cast(in[3]); - stride_t strides_in(dynamic_strides_in, - dynamic_strides_in + descriptor->ndims()); - const std::uint32_t *dynamic_strides_out = - reinterpret_cast(in[4]); - stride_t strides_out(dynamic_strides_out, - dynamic_strides_out + descriptor->ndims()); - shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end()); - const double *dynamic_scale = reinterpret_cast(in[5]); - - DuccFftImpl(out, in[1], descriptor->dtype(), descriptor->fft_type(), - shape, strides_in, strides_out, axes, - descriptor->forward(), *dynamic_scale); -} - -} // namespace jax diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index ddf605fdd0bd..d01efa7f7864 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "jaxlib/cpu/lapack.h" + #include #include "nanobind/nanobind.h" @@ -24,6 +26,8 @@ namespace { namespace nb = nanobind; +using ::xla::ffi::DataType; + void GetLapackKernelsFromScipy() { static bool initialized = false; // Protected by GIL if (initialized) return; @@ -35,12 +39,11 @@ void GetLapackKernelsFromScipy() { auto blas_ptr = [&](const char* name) { return nb::cast(blas_capi[name]).data(); }; - Trsm::fn = reinterpret_cast::FnType*>(blas_ptr("strsm")); - Trsm::fn = reinterpret_cast::FnType*>(blas_ptr("dtrsm")); - Trsm>::fn = - reinterpret_cast>::FnType*>(blas_ptr("ctrsm")); - Trsm>::fn = - reinterpret_cast>::FnType*>(blas_ptr("ztrsm")); + + AssignKernelFn>(blas_ptr("strsm")); + AssignKernelFn>(blas_ptr("dtrsm")); + AssignKernelFn>>(blas_ptr("ctrsm")); + AssignKernelFn>>(blas_ptr("ztrsm")); nb::module_ cython_lapack = nb::module_::import_("scipy.linalg.cython_lapack"); @@ -48,106 +51,63 @@ void GetLapackKernelsFromScipy() { auto lapack_ptr = [&](const char* name) { return nb::cast(lapack_capi[name]).data(); }; - Getrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgetrf")); - Getrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgetrf")); - Getrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgetrf")); - Getrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgetrf")); - Geqrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgeqrf")); - Geqrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgeqrf")); - Geqrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgeqrf")); - Geqrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgeqrf")); - Orgqr::fn = - reinterpret_cast::FnType*>(lapack_ptr("sorgqr")); - Orgqr::fn = - reinterpret_cast::FnType*>(lapack_ptr("dorgqr")); - Orgqr>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cungqr")); - Orgqr>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zungqr")); - Potrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("spotrf")); - Potrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("dpotrf")); - Potrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cpotrf")); - Potrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zpotrf")); - RealGesdd::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgesdd")); - RealGesdd::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgesdd")); - ComplexGesdd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgesdd")); - 
ComplexGesdd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgesdd")); - RealSyevd::fn = - reinterpret_cast::FnType*>(lapack_ptr("ssyevd")); - RealSyevd::fn = - reinterpret_cast::FnType*>(lapack_ptr("dsyevd")); - ComplexHeevd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cheevd")); - ComplexHeevd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zheevd")); - RealGeev::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgeev")); - RealGeev::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgeev")); - ComplexGeev>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgeev")); - ComplexGeev>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgeev")); - RealGees::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgees")); - RealGees::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgees")); - ComplexGees>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgees")); - ComplexGees>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgees")); - Gehrd::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgehrd")); - Gehrd::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgehrd")); - Gehrd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgehrd")); - Gehrd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgehrd")); - Sytrd::fn = - reinterpret_cast::FnType*>(lapack_ptr("ssytrd")); - Sytrd::fn = - reinterpret_cast::FnType*>(lapack_ptr("dsytrd")); - Sytrd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("chetrd")); - Sytrd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zhetrd")); + AssignKernelFn>(lapack_ptr("sgetrf")); + AssignKernelFn>(lapack_ptr("dgetrf")); + AssignKernelFn>>(lapack_ptr("cgetrf")); + AssignKernelFn>>(lapack_ptr("zgetrf")); + AssignKernelFn>(lapack_ptr("sgetrf")); + AssignKernelFn>(lapack_ptr("dgetrf")); + AssignKernelFn>(lapack_ptr("cgetrf")); + AssignKernelFn>(lapack_ptr("zgetrf")); + + AssignKernelFn>(lapack_ptr("sgeqrf")); + AssignKernelFn>(lapack_ptr("dgeqrf")); + AssignKernelFn>>(lapack_ptr("cgeqrf")); + AssignKernelFn>>(lapack_ptr("zgeqrf")); + + AssignKernelFn>(lapack_ptr("sorgqr")); + AssignKernelFn>(lapack_ptr("dorgqr")); + AssignKernelFn>>(lapack_ptr("cungqr")); + AssignKernelFn>>(lapack_ptr("zungqr")); + + AssignKernelFn>(lapack_ptr("spotrf")); + AssignKernelFn>(lapack_ptr("dpotrf")); + AssignKernelFn>>(lapack_ptr("cpotrf")); + AssignKernelFn>>(lapack_ptr("zpotrf")); + AssignKernelFn>(lapack_ptr("spotrf")); + AssignKernelFn>(lapack_ptr("dpotrf")); + AssignKernelFn>(lapack_ptr("cpotrf")); + AssignKernelFn>(lapack_ptr("zpotrf")); + + AssignKernelFn>(lapack_ptr("sgesdd")); + AssignKernelFn>(lapack_ptr("dgesdd")); + AssignKernelFn>>(lapack_ptr("cgesdd")); + AssignKernelFn>>(lapack_ptr("zgesdd")); + + AssignKernelFn>(lapack_ptr("ssyevd")); + AssignKernelFn>(lapack_ptr("dsyevd")); + AssignKernelFn>>(lapack_ptr("cheevd")); + AssignKernelFn>>(lapack_ptr("zheevd")); + + AssignKernelFn>(lapack_ptr("sgeev")); + AssignKernelFn>(lapack_ptr("dgeev")); + AssignKernelFn>>(lapack_ptr("cgeev")); + AssignKernelFn>>(lapack_ptr("zgeev")); + + AssignKernelFn>(lapack_ptr("sgees")); + AssignKernelFn>(lapack_ptr("dgees")); + AssignKernelFn>>(lapack_ptr("cgees")); + AssignKernelFn>>(lapack_ptr("zgees")); + + AssignKernelFn>(lapack_ptr("sgehrd")); + AssignKernelFn>(lapack_ptr("dgehrd")); + AssignKernelFn>>(lapack_ptr("cgehrd")); + AssignKernelFn>>(lapack_ptr("zgehrd")); + + AssignKernelFn>(lapack_ptr("ssytrd")); + AssignKernelFn>(lapack_ptr("dsytrd")); + AssignKernelFn>>(lapack_ptr("chetrd")); + AssignKernelFn>>(lapack_ptr("zhetrd")); initialized = true; } @@ -222,14 +182,24 @@ nb::dict 
Registrations() { dict["lapack_zhetrd"] = EncapsulateFunction(Sytrd>::Kernel); + dict["lapack_sgetrf_ffi"] = EncapsulateFunction(lapack_sgetrf_ffi); + dict["lapack_dgetrf_ffi"] = EncapsulateFunction(lapack_dgetrf_ffi); + dict["lapack_cgetrf_ffi"] = EncapsulateFunction(lapack_cgetrf_ffi); + dict["lapack_zgetrf_ffi"] = EncapsulateFunction(lapack_zgetrf_ffi); + dict["lapack_spotrf_ffi"] = EncapsulateFunction(lapack_spotrf_ffi); + dict["lapack_dpotrf_ffi"] = EncapsulateFunction(lapack_dpotrf_ffi); + dict["lapack_cpotrf_ffi"] = EncapsulateFunction(lapack_cpotrf_ffi); + dict["lapack_zpotrf_ffi"] = EncapsulateFunction(lapack_zpotrf_ffi); + return dict; } NB_MODULE(_lapack, m) { // Populates the LAPACK kernels from scipy on first call. m.def("initialize", GetLapackKernelsFromScipy); - m.def("registrations", &Registrations); + + // Old-style LAPACK Workspace Size Queries m.def("lapack_sgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), nb::arg("n")); m.def("lapack_dgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), diff --git a/jaxlib/cpu/lapack.h b/jaxlib/cpu/lapack.h new file mode 100644 index 000000000000..b00440616f19 --- /dev/null +++ b/jaxlib/cpu/lapack.h @@ -0,0 +1,61 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_CPU_LAPACK_H_ +#define JAXLIB_CPU_LAPACK_H_ + +#include "jaxlib/cpu/lapack_kernels.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { + +// FFI Definition Macros (by DataType) + +#define JAX_CPU_DEFINE_GETRF(name, data_type) \ + XLA_FFI_DEFINE_HANDLER( \ + name, LuDecomposition::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*ipiv*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/)) + +#define JAX_CPU_DEFINE_POTRF(name, data_type) \ + XLA_FFI_DEFINE_HANDLER( \ + name, CholeskyFactorization::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("uplo") \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/)) + +// FFI Handlers + +JAX_CPU_DEFINE_GETRF(lapack_sgetrf_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_GETRF(lapack_cgetrf_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_GETRF(lapack_zgetrf_ffi, ::xla::ffi::DataType::C128); + +JAX_CPU_DEFINE_POTRF(lapack_spotrf_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_POTRF(lapack_dpotrf_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_POTRF(lapack_cpotrf_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_POTRF(lapack_zpotrf_ffi, ::xla::ffi::DataType::C128); + +#undef JAX_CPU_DEFINE_GETRF +#undef JAX_CPU_DEFINE_POTRF + +} // namespace jax + +#endif // JAXLIB_CPU_LAPACK_H_ diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 00b54bab0822..85c4cc44b065 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -15,32 +15,133 @@ limitations under the License. 
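// A minimal sketch of what one expansion of JAX_CPU_DEFINE_GETRF from
// jaxlib/cpu/lapack.h amounts to, assuming the XLA_FFI_DEFINE_HANDLER(name,
// kernel, binding) form used in that header; the handler name below is
// hypothetical and the template arguments are spelled out in full for
// readability.
XLA_FFI_DEFINE_HANDLER(
    lapack_sgetrf_ffi_expanded,  // illustrative name only
    ::jax::LuDecomposition<::xla::ffi::DataType::F32>::Kernel,
    ::xla::ffi::Ffi::Bind()
        .Arg<::xla::ffi::Buffer<::xla::ffi::DataType::F32>>(/*x*/)
        .Ret<::xla::ffi::Buffer<::xla::ffi::DataType::F32>>(/*x_out*/)
        .Ret<::xla::ffi::Buffer<::jax::LapackIntDtype>>(/*ipiv*/)
        .Ret<::xla::ffi::Buffer<::jax::LapackIntDtype>>(/*info*/));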
#include "jaxlib/cpu/lapack_kernels.h" +#include #include +#include +#include #include -#include +#include #include +#include +#include +#include +#include +#include +#include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +static_assert(sizeof(jax::lapack_int) == sizeof(int32_t), + "Expected LAPACK integers to be 32-bit"); + +namespace ffi = xla::ffi; + +// TODO(danfm): These macros and the casting functions should be moved to a +// separate header for use in other FFI kernels. +#define ASSIGN_OR_RETURN_FFI_ERROR(lhs, rhs) \ + if (!rhs.ok()) { \ + return ffi::Error(static_cast(rhs.status().code()), \ + std::string(rhs.status().message())); \ + } \ + lhs = rhs.value() + +#define RETURN_IF_FFI_ERROR(...) \ + do { \ + ffi::Error err = (__VA_ARGS__); \ + if (err.failure()) { \ + return err; \ + } \ + } while (0) namespace { -inline int64_t catch_lapack_int_overflow(const std::string& source, int64_t value) { - if constexpr (sizeof(jax::lapack_int) == sizeof(int64_t)) { +template +inline absl::StatusOr MaybeCastNoOverflow( + int64_t value, const std::string& source = __FILE__) { + if constexpr (sizeof(T) == sizeof(int64_t)) { return value; } else { - if (value > std::numeric_limits::max()) { - throw std::overflow_error(source + "(=" + std::to_string(value) + ") exceeds maximum value of jax::lapack_int"); + if (value > std::numeric_limits::max()) [[unlikely]] { + return absl::InvalidArgumentError( + absl::StrFormat("%s: Value (=%d) exceeds the maximum representable " + "value of the desired type", + source, value)); } - return value; + return static_cast(value); + } +} + +template +inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) { + auto result = MaybeCastNoOverflow(value, source); + if (!result.ok()) { + throw std::overflow_error{std::string(result.status().message())}; } + return result.value(); } +template +ffi::Error CheckMatrixDimensions(ffi::Span dims) { + if (dims.size() < 2) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "Matrix must have at least 2 dimensions"); + } + return ffi::Error::Success(); } +template +std::tuple SplitBatch2D(ffi::Span dims) { + auto matrix_dims = dims.last(2); + return std::make_tuple(absl::c_accumulate(dims.first(dims.size() - 2), 1, + std::multiplies()), + matrix_dims.front(), matrix_dims.back()); +} + +template +void CopyIfDiffBuffer(ffi::Buffer x, ffi::ResultBuffer x_out) { + auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions); + if (x.data != x_out->data) { + const auto x_size = batch_count * x_rows * x_cols; + std::copy_n(x.data, x_size, x_out->data); + } +} + +} // namespace + +#define REGISTER_CHAR_ENUM_ATTR_DECODING(type) \ + std::optional xla::ffi::AttrDecoding::Decode( \ + XLA_FFI_AttrType attr_type, void* attr, DiagnosticEngine& diagnostic) { \ + if (attr_type != XLA_FFI_AttrType_SCALAR) [[unlikely]] { \ + return diagnostic.Emit("Wrong attribute type: expected ") \ + << XLA_FFI_AttrType_SCALAR << " but got" << attr_type; \ + } \ + auto* scalar = reinterpret_cast(attr); \ + if (scalar->dtype != XLA_FFI_DataType_U8) [[unlikely]] { \ + return diagnostic.Emit("Wrong scalar data type: expected ") \ + << XLA_FFI_DataType_U8 << " but got " << scalar->dtype; \ + } \ + auto underlying = \ + *reinterpret_cast*>(scalar->value); \ + return static_cast(underlying); \ + } + 
+REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side); +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); + +#undef REGISTER_CHAR_ENUM_ATTR_DECODING + namespace jax { -static_assert(sizeof(lapack_int) == sizeof(int32_t), - "Expected LAPACK integers to be 32-bit"); +//== Triangular System Solver ==// + +// lapack trsm template typename Trsm::FnType* Trsm::fn = nullptr; @@ -92,7 +193,9 @@ template struct Trsm; template struct Trsm>; template struct Trsm>; -// Getrf +//== LU Decomposition ==// + +// lapack getrf template typename Getrf::FnType* Getrf::fn = nullptr; @@ -126,7 +229,47 @@ template struct Getrf; template struct Getrf>; template struct Getrf>; -// Geqrf +// FFI Kernel + +template +ffi::Error LuDecomposition::Kernel( + ffi::Buffer x, ffi::ResultBuffer x_out, + ffi::ResultBuffer ipiv, + ffi::ResultBuffer info) { + RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions)); + auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions); + auto* x_out_data = x_out->data; + auto* ipiv_data = ipiv->data; + auto* info_data = info->data; + + CopyIfDiffBuffer(x, x_out); + + ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v, + MaybeCastNoOverflow(x_rows)); + ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v, + MaybeCastNoOverflow(x_cols)); + auto x_leading_dim_v = x_rows_v; + + const int64_t x_out_step{x_rows * x_cols}; + const int64_t ipiv_step{std::min(x_rows, x_cols)}; + for (int64_t i = 0; i < batch_count; ++i) { + fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, ipiv_data, + info_data); + x_out_data += x_out_step; + ipiv_data += ipiv_step; + ++info_data; + } + return ffi::Error::Success(); +} + +template struct LuDecomposition; +template struct LuDecomposition; +template struct LuDecomposition; +template struct LuDecomposition; + +//== QR Factorization ==// + +// lapack geqrf template typename Geqrf::FnType* Geqrf::fn = nullptr; @@ -173,7 +316,10 @@ template struct Geqrf; template struct Geqrf>; template struct Geqrf>; -// Orgqr +//== Orthogonal QR ==// +//== Computes orthogonal matrix Q from QR Decomposition ==// + +// lapack orgqr template typename Orgqr::FnType* Orgqr::fn = nullptr; @@ -221,7 +367,9 @@ template struct Orgqr; template struct Orgqr>; template struct Orgqr>; -// Potrf +//== Cholesky Factorization ==// + +// lapack potrf template typename Potrf::FnType* Potrf::fn = nullptr; @@ -255,7 +403,42 @@ template struct Potrf; template struct Potrf>; template struct Potrf>; -// Gesdd +// FFI Kernel + +template +ffi::Error CholeskyFactorization::Kernel( + ffi::Buffer x, MatrixParams::UpLo uplo, + ffi::ResultBuffer x_out, ffi::ResultBuffer info) { + RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions)); + auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions); + auto* x_out_data = x_out->data; + auto* info_data = info->data; + + CopyIfDiffBuffer(x, x_out); + + auto uplo_v = static_cast(uplo); + ASSIGN_OR_RETURN_FFI_ERROR( + auto x_order_v, MaybeCastNoOverflow(x.dimensions.back())); + auto x_leading_dim_v = x_order_v; + + const int64_t x_out_step{x_rows * x_cols}; + for (int64_t i = 0; i < batch_count; ++i) { + fn(&uplo_v, &x_order_v, x_out_data, &x_leading_dim_v, info_data); + x_out_data += x_out_step; + ++info_data; + } + return ffi::Error::Success(); +} + +template struct CholeskyFactorization; +template struct CholeskyFactorization; +template struct CholeskyFactorization; +template struct CholeskyFactorization; + +//== Singular Value 
Decomposition (SVD) ==// +//== using a divide and conquer method ==// + +// lapack gesdd static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) { if (!job_opt_compute_uv) { @@ -267,7 +450,7 @@ static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) { } lapack_int GesddIworkSize(int64_t m, int64_t n) { - return catch_lapack_int_overflow("gesdd iwork", 8 * std::min(m, n)); + return CastNoOverflow(8 * std::min(m, n), "gesdd iwork"); } template @@ -333,11 +516,12 @@ int64_t RealGesdd::Workspace(lapack_int m, lapack_int n, lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv) { int64_t mn = std::min(m, n); if (compute_uv == 0) { - return catch_lapack_int_overflow("complex gesdd rwork", 7 * mn); + return CastNoOverflow(7 * mn, "complex gesdd rwork"); } int64_t mx = std::max(m, n); - return catch_lapack_int_overflow("complex gesdd rwork", - std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn)); + return CastNoOverflow( + std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn), + "complex gesdd rwork"); } template @@ -408,13 +592,17 @@ template struct RealGesdd; template struct ComplexGesdd>; template struct ComplexGesdd>; +//== Eigenvalues and eigenvectors ==// + +// lapack syevd/heevd + // # Workspace sizes, taken from the LAPACK documentation. lapack_int SyevdWorkSize(int64_t n) { - return catch_lapack_int_overflow("syevd lwork", 1 + 6 * n + 2 * n * n); + return CastNoOverflow(1 + 6 * n + 2 * n * n, "syevd lwork"); } lapack_int SyevdIworkSize(int64_t n) { - return catch_lapack_int_overflow("syevd iwork", 3 + 5 * n); + return CastNoOverflow(3 + 5 * n, "syevd iwork"); } template @@ -454,11 +642,11 @@ void RealSyevd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { // Workspace sizes, taken from the LAPACK documentation. lapack_int HeevdWorkSize(int64_t n) { - return catch_lapack_int_overflow("heevd work", 1 + 2 * n + n * n); + return CastNoOverflow(1 + 2 * n + n * n, "heevd work"); } lapack_int HeevdRworkSize(int64_t n) { - return catch_lapack_int_overflow("heevd rwork", 1 + 5 * n + 2 * n * n); + return CastNoOverflow(1 + 5 * n + 2 * n * n, "heevd rwork"); } template @@ -534,6 +722,8 @@ static void UnpackEigenvectors(int n, const T* im_eigenvalues, const T* packed, } } +// lapack geev + template typename RealGeev::FnType* RealGeev::fn = nullptr; @@ -679,7 +869,9 @@ template struct RealGeev; template struct ComplexGeev>; template struct ComplexGeev>; -// Gees +//== Schur Decomposition ==// + +// lapack gees template typename RealGees::FnType* RealGees::fn = nullptr; @@ -809,6 +1001,10 @@ template struct RealGees; template struct ComplexGees>; template struct ComplexGees>; +//== Hessenberg Decomposition ==// + +// lapack gehrd + template typename Gehrd::FnType* Gehrd::fn = nullptr; @@ -859,6 +1055,10 @@ template struct Gehrd; template struct Gehrd>; template struct Gehrd>; +//== Tridiagonal Reduction ==// + +// lapack sytrd/hetrd + template typename Sytrd::FnType* Sytrd::fn = nullptr; @@ -917,3 +1117,6 @@ template struct Sytrd>; template struct Sytrd>; } // namespace jax + +#undef ASSIGN_OR_RETURN_FFI_ERROR +#undef RETURN_IF_FFI_ERROR diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 4641b772c2ab..4119f6ba08a2 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -16,19 +16,70 @@ limitations under the License. 
#ifndef JAXLIB_CPU_LAPACK_KERNELS_H_ #define JAXLIB_CPU_LAPACK_KERNELS_H_ -#include #include +#include +#include +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/api/c_api.h" #include "xla/service/custom_call_status.h" -// Underlying function pointers (e.g., Trsm::Fn) are initialized either +// Underlying function pointers (i.e., KERNEL_CLASS::Fn) are initialized either // by the pybind wrapper that links them to an existing SciPy lapack instance, // or using the lapack_kernels_strong.cc static initialization to link them // directly to lapack for use in a pure C++ context. namespace jax { -typedef int lapack_int; +struct MatrixParams { + enum class Side : char { kLeft = 'L', kRight = 'R' }; + enum class UpLo : char { kLower = 'L', kUpper = 'U' }; + enum class Diag : char { kNonUnit = 'N', kUnit = 'U' }; + enum class Transpose : char { + kNoTrans = 'N', + kTrans = 'T', + kConjTrans = 'C' + }; +}; + +template +void AssignKernelFn(void* func) { + KernelType::fn = reinterpret_cast(func); +} + +template +void AssignKernelFn(typename KernelType::FnType* func) { + KernelType::fn = func; +} + +} // namespace jax + +#define DEFINE_CHAR_ENUM_ATTR_DECODING(ATTR) \ + template <> \ + struct xla::ffi::AttrDecoding { \ + using Type = ATTR; \ + static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ + DiagnosticEngine& diagnostic); \ + } + +// XLA needs attributes to have deserialization method specified +DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side); +DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); +DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); +DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); + +#undef DEFINE_CHAR_ENUM_ATTR_DECODING + +namespace jax { + +using lapack_int = int; +inline constexpr auto LapackIntDtype = ::xla::ffi::DataType::S32; +static_assert( + std::is_same_v<::xla::ffi::NativeType, lapack_int>); + +//== Triangular System Solver ==// + +// lapack trsm template struct Trsm { @@ -40,6 +91,10 @@ struct Trsm { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +//== LU Decomposition ==// + +// lapack getrf + template struct Getrf { using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda, @@ -49,6 +104,25 @@ struct Getrf { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +// FFI Kernel + +template <::xla::ffi::DataType dtype> +struct LuDecomposition { + using ValueType = ::xla::ffi::NativeType; + using FnType = void(lapack_int* m, lapack_int* n, ValueType* a, + lapack_int* lda, lapack_int* ipiv, lapack_int* info); + + inline static FnType* fn = nullptr; + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer ipiv, + ::xla::ffi::ResultBuffer info); +}; + +//== QR Factorization ==// + +// lapack geqrf + template struct Geqrf { using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda, @@ -60,6 +134,10 @@ struct Geqrf { static int64_t Workspace(lapack_int m, lapack_int n); }; +//== Orthogonal QR ==// + +// lapack orgqr + template struct Orgqr { using FnType = void(lapack_int* m, lapack_int* n, lapack_int* k, T* a, @@ -70,6 +148,10 @@ struct Orgqr { static int64_t Workspace(lapack_int m, lapack_int n, lapack_int k); }; +//== Cholesky Factorization ==// + +// lapack potrf + template struct Potrf { using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda, @@ -78,6 +160,24 @@ struct Potrf { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +template <::xla::ffi::DataType dtype> +struct 
CholeskyFactorization { + using ValueType = ::xla::ffi::NativeType; + using FnType = void(char* uplo, lapack_int* n, ValueType* a, lapack_int* lda, + lapack_int* info); + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, MatrixParams::UpLo uplo, + ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer info); +}; + +//== Singular Value Decomposition (SVD) ==// + +// lapack gesdd + lapack_int GesddIworkSize(int64_t m, int64_t n); template @@ -109,6 +209,10 @@ struct ComplexGesdd { bool job_opt_full_matrices); }; +//== Eigenvalues and eigenvectors ==// + +// lapack syevd/heevd + lapack_int SyevdWorkSize(int64_t n); lapack_int SyevdIworkSize(int64_t n); @@ -135,6 +239,8 @@ struct ComplexHeevd { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +// lapack geev + template struct RealGeev { using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, @@ -155,6 +261,10 @@ struct ComplexGeev { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +//== Schur Decomposition ==// + +// lapack gees + template struct RealGees { using FnType = void(char* jobvs, char* sort, bool (*select)(T, T), @@ -176,7 +286,11 @@ struct ComplexGees { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; -// Gehrd: Reduces a non-symmetric square matrix to upper Hessenberg form. +//== Hessenberg Decomposition ==// +//== Reduces a non-symmetric square matrix to upper Hessenberg form ==// + +// lapack gehrd + template struct Gehrd { using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, T* a, @@ -199,14 +313,16 @@ struct real_type> { typedef T type; }; -// Sytrd/Hetrd: Reduces a symmetric (Hermitian) square matrix to tridiagonal -// form. +//== Tridiagonal Reduction ==// +//== Reduces a Symmetric/Hermitian square matrix to tridiagonal form ==// + +// lapack sytrd/hetrd + template struct Sytrd { using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda, typename real_type::type* d, - typename real_type::type* e, - T* tau, T* work, + typename real_type::type* e, T* tau, T* work, lapack_int* lwork, lapack_int* info); static FnType* fn; diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index bc67fc556a49..48b1d5bffc1b 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -13,12 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "jaxlib/cpu/lapack_kernels.h" // From a Python binary, JAX obtains its LAPACK/BLAS kernels from Scipy, but // a C++ user should link against LAPACK directly. This is needed when using // JAX-generated HLO from C++. 
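// A minimal sketch of the linking idiom this file relies on, with made-up
// names (SomeKernel, some_lapack_symbol_): each kernel struct exposes a
// static function pointer, the LAPACK routine is declared extern "C" with
// the matching FnType, and a static initializer assigns it, so simply
// linking this translation unit wires the kernels up without any call.
struct SomeKernel {
  using FnType = void(int* n, double* a, int* lda, int* info);
  inline static FnType* fn = nullptr;
};

extern "C" SomeKernel::FnType some_lapack_symbol_;  // provided by LAPACK

namespace {
const int kSomeKernelInit = [] {
  // AssignKernelFn<SomeKernel>(some_lapack_symbol_) amounts to this line.
  SomeKernel::fn = &some_lapack_symbol_;
  return 0;
}();
}  // namespace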
+namespace ffi = xla::ffi; + extern "C" { jax::Trsm::FnType strsm_; @@ -26,10 +31,10 @@ jax::Trsm::FnType dtrsm_; jax::Trsm>::FnType ctrsm_; jax::Trsm>::FnType ztrsm_; -jax::Getrf::FnType sgetrf_; -jax::Getrf::FnType dgetrf_; -jax::Getrf>::FnType cgetrf_; -jax::Getrf>::FnType zgetrf_; +jax::LuDecomposition::FnType sgetrf_; +jax::LuDecomposition::FnType dgetrf_; +jax::LuDecomposition::FnType cgetrf_; +jax::LuDecomposition::FnType zgetrf_; jax::Geqrf::FnType sgeqrf_; jax::Geqrf::FnType dgeqrf_; @@ -41,10 +46,10 @@ jax::Orgqr::FnType dorgqr_; jax::Orgqr>::FnType cungqr_; jax::Orgqr>::FnType zungqr_; -jax::Potrf::FnType spotrf_; -jax::Potrf::FnType dpotrf_; -jax::Potrf>::FnType cpotrf_; -jax::Potrf>::FnType zpotrf_; +jax::CholeskyFactorization::FnType spotrf_; +jax::CholeskyFactorization::FnType dpotrf_; +jax::CholeskyFactorization::FnType cpotrf_; +jax::CholeskyFactorization::FnType zpotrf_; jax::RealGesdd::FnType sgesdd_; jax::RealGesdd::FnType dgesdd_; @@ -80,51 +85,106 @@ jax::Sytrd>::FnType zhetrd_; namespace jax { +#define JAX_KERNEL_FNTYPE_MISMATCH_MSG "FFI Kernel FnType mismatch" + +static_assert(std::is_same_v::FnType, + jax::Getrf::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert(std::is_same_v::FnType, + jax::Getrf::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert(std::is_same_v::FnType, + jax::Getrf>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert(std::is_same_v::FnType, + jax::Getrf>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Potrf::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Potrf::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Potrf>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Potrf>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); + +#undef JAX_KERNEL_FNTYPE_MISMATCH_MSG + static auto init = []() -> int { - Trsm::fn = strsm_; - Trsm::fn = dtrsm_; - Trsm>::fn = ctrsm_; - Trsm>::fn = ztrsm_; - Getrf::fn = sgetrf_; - Getrf::fn = dgetrf_; - Getrf>::fn = cgetrf_; - Getrf>::fn = zgetrf_; - Geqrf::fn = sgeqrf_; - Geqrf::fn = dgeqrf_; - Geqrf>::fn = cgeqrf_; - Geqrf>::fn = zgeqrf_; - Orgqr::fn = sorgqr_; - Orgqr::fn = dorgqr_; - Orgqr>::fn = cungqr_; - Orgqr>::fn = zungqr_; - Potrf::fn = spotrf_; - Potrf::fn = dpotrf_; - Potrf>::fn = cpotrf_; - Potrf>::fn = zpotrf_; - RealGesdd::fn = sgesdd_; - RealGesdd::fn = dgesdd_; - ComplexGesdd>::fn = cgesdd_; - ComplexGesdd>::fn = zgesdd_; - RealSyevd::fn = ssyevd_; - RealSyevd::fn = dsyevd_; - ComplexHeevd>::fn = cheevd_; - ComplexHeevd>::fn = zheevd_; - RealGeev::fn = sgeev_; - RealGeev::fn = dgeev_; - ComplexGeev>::fn = cgeev_; - ComplexGeev>::fn = zgeev_; - RealGees::fn = sgees_; - RealGees::fn = dgees_; - ComplexGees>::fn = cgees_; - ComplexGees>::fn = zgees_; - Gehrd::fn = sgehrd_; - Gehrd::fn = dgehrd_; - Gehrd>::fn = cgehrd_; - Gehrd>::fn = zgehrd_; - Sytrd::fn = ssytrd_; - Sytrd::fn = dsytrd_; - Sytrd>::fn = chetrd_; - Sytrd>::fn = zhetrd_; + AssignKernelFn>(strsm_); + AssignKernelFn>(dtrsm_); + AssignKernelFn>>(ctrsm_); + AssignKernelFn>>(ztrsm_); + + AssignKernelFn>(sgetrf_); + AssignKernelFn>(dgetrf_); + AssignKernelFn>>(cgetrf_); + AssignKernelFn>>(zgetrf_); + + AssignKernelFn>(sgeqrf_); + AssignKernelFn>(dgeqrf_); + AssignKernelFn>>(cgeqrf_); + AssignKernelFn>>(zgeqrf_); + + AssignKernelFn>(sorgqr_); + AssignKernelFn>(dorgqr_); + AssignKernelFn>>(cungqr_); + AssignKernelFn>>(zungqr_); + + 
AssignKernelFn>(spotrf_); + AssignKernelFn>(dpotrf_); + AssignKernelFn>>(cpotrf_); + AssignKernelFn>>(zpotrf_); + + AssignKernelFn>(sgesdd_); + AssignKernelFn>(dgesdd_); + AssignKernelFn>>(cgesdd_); + AssignKernelFn>>(zgesdd_); + + AssignKernelFn>(ssyevd_); + AssignKernelFn>(dsyevd_); + AssignKernelFn>>(cheevd_); + AssignKernelFn>>(zheevd_); + + AssignKernelFn>(sgeev_); + AssignKernelFn>(dgeev_); + AssignKernelFn>>(cgeev_); + AssignKernelFn>>(zgeev_); + + AssignKernelFn>(sgees_); + AssignKernelFn>(dgees_); + AssignKernelFn>>(cgees_); + AssignKernelFn>>(zgees_); + + AssignKernelFn>(sgehrd_); + AssignKernelFn>(dgehrd_); + AssignKernelFn>>(cgehrd_); + AssignKernelFn>>(zgehrd_); + + AssignKernelFn>(ssytrd_); + AssignKernelFn>(dsytrd_); + AssignKernelFn>>(chetrd_); + AssignKernelFn>>(zhetrd_); + + // FFI Kernels + + AssignKernelFn>(sgetrf_); + AssignKernelFn>(dgetrf_); + AssignKernelFn>(cgetrf_); + AssignKernelFn>(zgetrf_); + + AssignKernelFn>(spotrf_); + AssignKernelFn>(dpotrf_); + AssignKernelFn>(cpotrf_); + AssignKernelFn>(zpotrf_); return 0; }(); diff --git a/jaxlib/cpu_feature_guard.c b/jaxlib/cpu_feature_guard.c index b7fe688eaf52..7c8ff2951a79 100644 --- a/jaxlib/cpu_feature_guard.c +++ b/jaxlib/cpu_feature_guard.c @@ -77,9 +77,19 @@ static int GetXCR0EAX() { static void ReportMissingCpuFeature(const char* name) { PyErr_Format( PyExc_RuntimeError, +#if defined(__APPLE__) + "This version of jaxlib was built using %s instructions, which your " + "CPU and/or operating system do not support. This error is frequently " + "encountered on macOS when running an x86 Python installation on ARM " + "hardware. In this case, try installing an ARM build of Python. " + "Otherwise, you may be able work around this issue by building jaxlib " + "from source.", +#else "This version of jaxlib was built using %s instructions, which your " "CPU and/or operating system do not support. 
You may be able work around " - "this issue by building jaxlib from source.", name); + "this issue by building jaxlib from source.", +#endif + name); } static PyObject *CheckCpuFeatures(PyObject *self, PyObject *args) { diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 2e97ca7b1474..ed2792f1e7ce 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -14,13 +14,13 @@ # NVIDIA CUDA kernels +load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "cuda_library", "if_cuda_is_configured", "pybind_extension", ) -load("@rules_python//python:defs.bzl", "py_library") licenses(["notice"]) @@ -37,6 +37,7 @@ cc_library( defines = ["JAX_GPU_CUDA=1"], visibility = ["//visibility:public"], deps = [ + "@xla//xla/tsl/cuda:cupti", "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", ], @@ -118,7 +119,7 @@ pybind_extension( ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", "@xla//xla/tsl/cuda:cublas", - "@tsl//tsl/python/lib/core:numpy", + "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", "@nanobind", @@ -205,7 +206,7 @@ pybind_extension( "//jaxlib:kernel_nanobind_helpers", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusolver", - "@tsl//tsl/python/lib/core:numpy", + "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", @@ -255,7 +256,7 @@ pybind_extension( "//jaxlib:kernel_nanobind_helpers", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusparse", - "@tsl//tsl/python/lib/core:numpy", + "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -280,8 +281,12 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_lu_pivot_kernels_impl", ":cuda_vendor", - "//jaxlib:kernel_helpers", - "@xla//xla/service:custom_call_status", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", ], ) @@ -295,6 +300,43 @@ cuda_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", + "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", + "@local_config_cuda//cuda:cuda_headers", + ], +) + +cc_library( + name = "cholesky_update_kernel", + srcs = [ + "//jaxlib/gpu:cholesky_update_kernel.cc", + ], + hdrs = ["//jaxlib/gpu:cholesky_update_kernel.h"], + features = ["-use_header_modules"], + deps = [ + ":cholesky_update_kernel_impl", + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", + ":cusolver_kernels", + "//jaxlib:kernel_helpers", + "@xla//xla/service:custom_call_status", + "@com_google_absl//absl/status", + "@local_config_cuda//cuda:cuda_headers", + ], +) + +cuda_library( + name = "cholesky_update_kernel_impl", + srcs = [ + "//jaxlib/gpu:cholesky_update_kernel.cu.cc", + ], + hdrs = [ + "//jaxlib/gpu:cholesky_update_kernel.h", + ], + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", + ":cusolver_kernels", "//jaxlib:kernel_helpers", "@xla//xla/service:custom_call_status", "@local_config_cuda//cuda:cuda_headers", @@ -311,12 +353,14 @@ pybind_extension( features = ["-use_header_modules"], module_name = "_linalg", deps = [ + ":cholesky_update_kernel", ":cuda_gpu_kernel_helpers", ":cuda_lu_pivot_kernels", 
":cuda_lu_pivot_kernels_impl", ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/python/lib/core:numpy", "@local_config_cuda//cuda:cuda_headers", "@nanobind", ], @@ -333,7 +377,10 @@ cc_library( ":cuda_prng_kernels_impl", ":cuda_vendor", "//jaxlib:kernel_helpers", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", + "@com_google_absl//absl/algorithm:container", "@local_config_cuda//cuda:cuda_headers", ], ) @@ -348,6 +395,8 @@ cuda_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", "//jaxlib:kernel_helpers", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", "@local_config_cuda//cuda:cuda_headers", ], @@ -377,13 +426,17 @@ cc_library( srcs = ["//jaxlib/gpu:gpu_kernels.cc"], visibility = ["//visibility:public"], deps = [ + ":cholesky_update_kernel", ":cublas_kernels", ":cuda_lu_pivot_kernels", ":cuda_prng_kernels", ":cuda_vendor", + ":cudnn_rnn_kernels", ":cusolver_kernels", ":cusparse_kernels", ":triton_kernels", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_target_registry", ], alwayslink = 1, @@ -399,12 +452,13 @@ cc_library( ":triton_utils", "//jaxlib/gpu:triton_cc_proto", "@xla//xla/service:custom_call_status", - "@xla//xla/stream_executor/gpu:asm_compiler", + "@xla//xla/stream_executor/cuda:cuda_asm_compiler", "@xla//xla/tsl/cuda:cudart", "@tsl//tsl/platform:env", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -470,8 +524,6 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "//jaxlib:absl_status_casters", - "//jaxlib:kernel_nanobind_helpers", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudnn", @@ -479,6 +531,7 @@ cc_library( "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", "@xla//xla/tsl/cuda:cusparse", + "@com_google_absl//absl/base:dynamic_annotations", ], ) @@ -532,6 +585,7 @@ py_library( ":_sparse", ":_triton", ":_versions", + "//jaxlib/mosaic/gpu:mosaic_gpu", ], ) diff --git a/jaxlib/cuda/versions_helpers.cc b/jaxlib/cuda/versions_helpers.cc index 9ecd9a83ccc8..d42199d37467 100644 --- a/jaxlib/cuda/versions_helpers.cc +++ b/jaxlib/cuda/versions_helpers.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "absl/base/dynamic_annotations.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" @@ -30,39 +31,45 @@ namespace jax::cuda { int CudaRuntimeGetVersion() { int version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaRuntimeGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CudaDriverGetVersion() { int version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaDriverGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } uint32_t CuptiGetVersion() { uint32_t version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuptiGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CufftGetVersion() { int version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cufftGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CusolverGetVersion() { int version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CublasGetVersion() { int version; - // NVIDIA promise that it's safe to parse nullptr as the handle to this + // NVIDIA promise that it's safe to pass a null pointer as the handle to this // function. JAX_THROW_IF_ERROR( JAX_AS_STATUS(cublasGetVersion(/*handle=*/nullptr, &version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } @@ -73,6 +80,9 @@ int CusparseGetVersion() { JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MAJOR_VERSION, &major))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MINOR_VERSION, &minor))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(PATCH_LEVEL, &patch))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&major, sizeof major); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&minor, sizeof minor); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&patch, sizeof patch); return major * 1000 + minor * 100 + patch; } size_t CudnnGetVersion() { @@ -82,14 +92,18 @@ size_t CudnnGetVersion() { if (version == 0) { throw std::runtime_error("cuDNN not found."); } + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CudaComputeCapability(int device) { int major, minor; + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuInit(0))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute( &major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute( &minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&major, sizeof major); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&minor, sizeof minor); return major * 10 + minor; } @@ -98,8 +112,9 @@ int CudaDeviceCount() { JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuInit(0))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuDeviceGetCount(&device_count))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&device_count, sizeof device_count); return device_count; } -} // namespace jax::cuda \ No newline at end of file +} // namespace jax::cuda diff --git a/jaxlib/cuda_plugin_extension.cc b/jaxlib/cuda_plugin_extension.cc index 565206051b8a..0bb8cbbace65 100644 --- a/jaxlib/cuda_plugin_extension.cc +++ b/jaxlib/cuda_plugin_extension.cc @@ -14,34 +14,33 @@ limitations under the License. 
==============================================================================*/ #include +#include #include #include #include #include "nanobind/nanobind.h" +#include "absl/status/status.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/ffi/api/c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/status_casters.h" #include "xla/python/py_client_gpu.h" -#include "xla/status.h" +#include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" -#include "tsl/python/lib/core/numpy.h" namespace nb = nanobind; namespace xla { namespace { -Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, - nb::capsule fn, int api_version) { - static const char* const kName = "xla._CUSTOM_CALL_TARGET"; - if (std::string_view(fn.name()) != kName) { - return InvalidArgument( - "Argument to RegisterCustomCallTargetRegistry was not a " - "xla._CUSTOM_CALL_TARGET capsule."); - } - +absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, + const char* fn_name_c_str, + size_t fn_name_size, nb::capsule fn, + int api_version, + XLA_FFI_Handler_Traits traits) { if (c_api->extension_start == nullptr) { return Unimplemented("The plugin does not have extension."); } @@ -56,10 +55,14 @@ Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, return Unimplemented("The plugin does not have a custom call extension."); } + if (traits != 0) { + return Unimplemented("The plugin does not support custom call traits."); + } + PJRT_Gpu_Register_Custom_Call_Args args; args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; - args.function_name = fn_name.c_str(); - args.function_name_size = nb::len(fn_name); + args.function_name = fn_name_c_str; + args.function_name_size = fn_name_size; #if PJRT_API_GPU_EXTENSION_VERSION >= 1 args.api_version = api_version; #endif @@ -67,7 +70,7 @@ Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, RETURN_STATUS_IF_PJRT_ERROR( reinterpret_cast(next)->custom_call(&args), c_api); - return OkStatus(); + return absl::OkStatus(); } nb::dict Registrations() { @@ -76,20 +79,64 @@ nb::dict Registrations() { jax::EncapsulateFunction(xla::XlaPythonGpuCallback); return dict; } + +static std::string ToString(CUresult result) { + const char* error_name; + if (cuGetErrorName(result, &error_name)) { + return absl::StrCat("UNKNOWN ERROR (", static_cast(result), ")"); + } + const char* error_string; + if (cuGetErrorString(result, &error_string)) { + return error_name; + } + return absl::StrCat(error_name, ": ", error_string); +} } // namespace NB_MODULE(cuda_plugin_extension, m) { tsl::ImportNumpy(); m.def( "register_custom_call_target", - [](nb::capsule c_api, nb::str fn_name, nb::capsule fn, - nb::str xla_platform_name, int api_version) { - xla::ThrowIfError( - RegisterCustomCallTarget(static_cast(c_api.data()), - fn_name, std::move(fn), api_version)); + [](nb::capsule c_api, nb::object fn_name_py, nb::capsule fn, + nb::str xla_platform_name, int api_version, + XLA_FFI_Handler_Traits traits) { + const char* fn_name_c_str; + size_t fn_name_size; + nb::str fn_name_bn_str; + if (nb::try_cast(fn_name_py, fn_name_bn_str)) { + fn_name_c_str = fn_name_bn_str.c_str(); + fn_name_size = nb::len(fn_name_bn_str); + } else{ + nb::bytes bytes = nb::cast(fn_name_py); + fn_name_c_str = bytes.c_str(); + fn_name_size = bytes.size(); + } + xla::ThrowIfError(RegisterCustomCallTarget( + 
static_cast(c_api.data()), fn_name_c_str, + fn_name_size, std::move(fn), api_version, traits)); }, nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"), - nb::arg("xla_platform_name"), nb::arg("api_version") = 0); + nb::arg("xla_platform_name"), nb::arg("api_version") = 0, + nb::arg("traits") = 0); m.def("registrations", &Registrations); + m.def( + "get_device_ordinal", + [](std::intptr_t data_value) { + if (data_value == 0) { + return 0; + } + int device_ordinal; + void* data_ptr = reinterpret_cast(data_value); + CUresult result = + cuPointerGetAttribute(static_cast(&device_ordinal), + CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + reinterpret_cast(data_ptr)); + if (result != CUDA_SUCCESS) { + xla::ThrowIfError(absl::InvalidArgumentError(absl::StrCat( + "Not able to get the device_ordinal: ", ToString(result)))); + } + return device_ordinal; + }, + nb::arg("data_value")); } } // namespace xla diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 2d8c11757eb0..37a57463374c 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -27,6 +27,9 @@ exports_files(srcs = [ "blas.cc", "blas_kernels.cc", "blas_kernels.h", + "cholesky_update_kernel.cc", + "cholesky_update_kernel.cu.cc", + "cholesky_update_kernel.h", "gpu_kernel_helpers.cc", "gpu_kernel_helpers.h", "gpu_kernels.cc", diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc index 9f83b86f61d4..c6ec7d038296 100644 --- a/jaxlib/gpu/blas.cc +++ b/jaxlib/gpu/blas.cc @@ -25,7 +25,7 @@ limitations under the License. #include "jaxlib/gpu/blas_kernels.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" -#include "tsl/python/lib/core/numpy.h" +#include "xla/tsl/python/lib/core/numpy.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/cholesky_update_kernel.cc b/jaxlib/gpu/cholesky_update_kernel.cc new file mode 100644 index 000000000000..eb2cf0dfa600 --- /dev/null +++ b/jaxlib/gpu/cholesky_update_kernel.cc @@ -0,0 +1,53 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "jaxlib/gpu/cholesky_update_kernel.h" +#include +#include + +#include "absl/status/status.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" +#include "jaxlib/kernel_helpers.h" +#include "xla/service/custom_call_status.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +namespace { + + +absl::Status CholeskyUpdateImpl(gpuStream_t stream, void** buffers, + const char* opaque, std::size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const CholeskyUpdateDescriptor& d = **s; + LaunchCholeskyUpdateKernel(stream, buffers, d); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); + return absl::OkStatus(); +} + +} // namespace + +void CholeskyUpdate(gpuStream_t stream, void** buffers, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus* status) { + auto s = CholeskyUpdateImpl(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + std::string_view message = s.message(); + XlaCustomCallStatusSetFailure(status, message.data(), message.length()); + } +} + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/cholesky_update_kernel.cu.cc b/jaxlib/gpu/cholesky_update_kernel.cu.cc new file mode 100644 index 000000000000..780a051d3ccd --- /dev/null +++ b/jaxlib/gpu/cholesky_update_kernel.cu.cc @@ -0,0 +1,136 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/gpu/cholesky_update_kernel.h" +#include + +#ifdef JAX_GPU_HIP +#include "rocm/include/hip/amd_detail/amd_hip_cooperative_groups.h" +#else // JAX_GPU_CUDA +#include "third_party/gpus/cuda/include/cooperative_groups.h" +#endif + +namespace cg = cooperative_groups; + +namespace jax { +namespace JAX_GPU_NAMESPACE { +namespace { + + +template +__device__ void drotg(T* da, T* db, T* c, T* s) { + if (*db == 0) { + *c = 1.; + *s = 0.; + return; + } + T denominator = max(abs(*da), abs(*db)); + T a = *da / denominator; + T b = *db / denominator; + T rh = rhypot(a, b); + *c = a * rh; + *s = -(b * rh); + return; +} + +template +__global__ void CholeskyUpdateKernel( + T* rMatrix, T* uVector, + int nSize) { + + cg::grid_group grid = cg::this_grid(); + int k = grid.thread_rank(); + + T c, s; + + for (int step = 0; step < 2 * nSize; ++step) { + grid.sync(); + + int i = step - k; + if (i < k || i >= nSize || k >= nSize) { + continue; + } + if (i == k) { + drotg( + rMatrix + k * nSize + k, + uVector + k, + &c, + &s); + } + T r_i = c * rMatrix[k * nSize + i] - s * uVector[i]; + uVector[i] = s * rMatrix[k * nSize + i] + c * uVector[i]; + rMatrix[k * nSize + i] = r_i; + } +} +} // namespace + + +template +void LaunchCholeskyUpdateKernelBody( + gpuStream_t stream, void** buffers, + int grid_dim, int block_dim, int nSize) { + T* rMatrix = reinterpret_cast(buffers[2]); + T* uVector = reinterpret_cast(buffers[3]); + + void* arg_ptrs[3] = { + reinterpret_cast(&rMatrix), + reinterpret_cast(&uVector), + reinterpret_cast(&nSize), + }; +#ifdef JAX_GPU_HIP + hipLaunchCooperativeKernel( + (void*) CholeskyUpdateKernel, grid_dim, block_dim, arg_ptrs, + /*dynamic_shared_mem_bytes=*/ 0, stream); +#else // JAX_GPU_CUDA + cudaLaunchCooperativeKernel( + (void*) CholeskyUpdateKernel, grid_dim, block_dim, arg_ptrs, + /*dynamic_shared_mem_bytes=*/ 0, stream); +#endif +} + + +void LaunchCholeskyUpdateKernel( + gpuStream_t stream, void** buffers, + CholeskyUpdateDescriptor descriptor) { + + int nSize = descriptor.matrix_size; + LinalgType type = descriptor.linalg_type; + + int dev = 0; +#ifdef JAX_GPU_HIP + hipDeviceProp_t deviceProp; + hipGetDeviceProperties(&deviceProp, dev); +#else // JAX_GPU_CUDA + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, dev); +#endif + + int block_dim = deviceProp.maxThreadsPerBlock; + int grid_dim = deviceProp.multiProcessorCount; + + switch (type) { + case LinalgType::F64: + LaunchCholeskyUpdateKernelBody( + stream, buffers, grid_dim, block_dim, nSize); + break; + case LinalgType::F32: + LaunchCholeskyUpdateKernelBody( + stream, buffers, grid_dim, block_dim, nSize); + break; + } +} + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/cholesky_update_kernel.h b/jaxlib/gpu/cholesky_update_kernel.h new file mode 100644 index 000000000000..2f10cbe1ad5e --- /dev/null +++ b/jaxlib/gpu/cholesky_update_kernel.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_GPU_CHOLESKY_UPDATE_KERNEL_H_ +#define JAXLIB_GPU_CHOLESKY_UPDATE_KERNEL_H_ + +#include +#include +#include + +#include "jaxlib/gpu/vendor.h" +#include "xla/service/custom_call_status.h" + + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +enum LinalgType { + F32 = 0, + F64 = 1, +}; + +struct CholeskyUpdateDescriptor { + LinalgType linalg_type; + std::int64_t matrix_size; // leading dim (N) for a square (NxN)matrix +}; + +void LaunchCholeskyUpdateKernel( + gpuStream_t stream, void** buffers, CholeskyUpdateDescriptor descriptor); + +void CholeskyUpdate(gpuStream_t stream, void** buffers, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus* status); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAXLIB_GPU_CHOLESKY_UPDATE_KERNEL_H_ diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index bd3dde4fdee6..c7a075224eb7 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -17,12 +17,16 @@ limitations under the License. // JAX-generated HLO code from outside of JAX. #include "jaxlib/gpu/blas_kernels.h" +#include "jaxlib/gpu/cholesky_update_kernel.h" #include "jaxlib/gpu/lu_pivot_kernels.h" #include "jaxlib/gpu/prng_kernels.h" +#include "jaxlib/gpu/rnn_kernels.h" #include "jaxlib/gpu/solver_kernels.h" #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/triton_kernels.h" #include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_target_registry.h" namespace jax { @@ -31,16 +35,27 @@ namespace { XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched, "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched, + "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn_bwd", RNNBackward, "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_cholesky_update", + CholeskyUpdate, "CUDA"); +// TODO(b/350111820): use the new FFI registration mechanism XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_lu_pivots_to_permutation", LuPivotsToPermutation, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32, "CUDA"); +// TODO(b/350111820): use the new FFI registration mechanism +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32_ffi", + ThreeFry2x32Ffi, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA"); diff --git a/jaxlib/gpu/linalg.cc b/jaxlib/gpu/linalg.cc index 6397647105ad..dfdbbc111a83 100644 --- a/jaxlib/gpu/linalg.cc +++ b/jaxlib/gpu/linalg.cc @@ -14,9 +14,12 @@ limitations under the License. 
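The descriptor declared above is what the Python lowering serializes into the custom call's opaque string, choosing F32 or F64 from the dtype's itemsize. A schematic of that selection (the actual bytes come from the compiled build_cholesky_update_descriptor helper; the struct layout below is an assumption and not byte-exact):

import struct
import numpy as np

F32, F64 = 0, 1  # mirrors LinalgType

def pack_cholesky_update_descriptor(dtype, matrix_size):
    linalg_type = F32 if np.dtype(dtype).itemsize == 4 else F64
    # 32-bit enum followed by an int64 size; real packing must match the C++
    # struct layout (including padding), which this sketch does not reproduce.
    return struct.pack("=iq", linalg_type, matrix_size)

opaque = pack_cholesky_update_descriptor(np.float32, 128)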
==============================================================================*/ #include "nanobind/nanobind.h" +#include "jaxlib/gpu/cholesky_update_kernel.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/lu_pivot_kernels.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/tsl/python/lib/core/numpy.h" namespace jax { namespace JAX_GPU_NAMESPACE { @@ -24,29 +27,27 @@ namespace { namespace nb = nanobind; -std::string BuildLuPivotsToPermutationDescriptor( - std::int64_t batch_size, std::int32_t pivot_size, - std::int32_t permutation_size) { - return PackDescriptorAsString(LuPivotsToPermutationDescriptor{ - batch_size, pivot_size, permutation_size}); -} -nb::dict Registrations() { - nb::dict dict; - dict[JAX_GPU_PREFIX "_lu_pivots_to_permutation"] = - EncapsulateFunction(LuPivotsToPermutation); - return dict; +nb::bytes BuildCholeskyUpdateDescriptor( + dtype np_type, + std::int64_t matrix_size) { + + LinalgType linalg_type = ( + np_type.itemsize() == 4 ? LinalgType::F32 : LinalgType::F64); + + return PackDescriptor(CholeskyUpdateDescriptor{linalg_type, matrix_size}); } NB_MODULE(_linalg, m) { - m.def("registrations", &Registrations); - m.def("lu_pivots_to_permutation_descriptor", - [](std::int64_t batch_size, std::int32_t pivot_size, - std::int32_t permutation_size) { - std::string result = BuildLuPivotsToPermutationDescriptor( - batch_size, pivot_size, permutation_size); - return nb::bytes(result.data(), result.size()); - }); + tsl::ImportNumpy(); + m.def("registrations", []() { + nb::dict dict; + dict[JAX_GPU_PREFIX "_lu_pivots_to_permutation"] = + nb::capsule(reinterpret_cast(+LuPivotsToPermutation)); + dict["cu_cholesky_update"] = EncapsulateFunction(CholeskyUpdate); + return dict; + }); + m.def("build_cholesky_update_descriptor", &BuildCholeskyUpdateDescriptor); } } // namespace diff --git a/jaxlib/gpu/lu_pivot_kernels.cc b/jaxlib/gpu/lu_pivot_kernels.cc index c705f6f08bfc..b2c6362273ab 100644 --- a/jaxlib/gpu/lu_pivot_kernels.cc +++ b/jaxlib/gpu/lu_pivot_kernels.cc @@ -15,38 +15,70 @@ limitations under the License. 
#include "jaxlib/gpu/lu_pivot_kernels.h" -#include +#include +#include +#include +#include +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" namespace jax { namespace JAX_GPU_NAMESPACE { -namespace { - -absl::Status LuPivotsToPermutation_(gpuStream_t stream, void** buffers, - const char* opaque, - std::size_t opaque_len) { - auto s = - UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - LaunchLuPivotsToPermutationKernel(stream, buffers, **s); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - return absl::OkStatus(); -} -} // namespace +namespace ffi = xla::ffi; + +template +inline absl::StatusOr MaybeCastNoOverflow( + std::int64_t value, const std::string& source = __FILE__) { + if constexpr (sizeof(T) == sizeof(std::int64_t)) { + return value; + } else { + if (value > std::numeric_limits::max()) [[unlikely]] { + return absl::InvalidArgumentError(absl::StrFormat( + "%s: Value (=%d) exceeds the maximum representable value of the " + "desired type", + source, value)); + } + return static_cast(value); + } +} -void LuPivotsToPermutation(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len, - XlaCustomCallStatus* status) { - auto s = LuPivotsToPermutation_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - std::string_view message = s.message(); - XlaCustomCallStatusSetFailure(status, message.data(), message.length()); +ffi::Error LuPivotsToPermutationImpl( + gpuStream_t stream, std::int32_t permutation_size, + ffi::Buffer pivots, + ffi::Result> permutation) { + auto dims = pivots.dimensions; + if (dims.size() < 1) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "pivots must have at least one dimension"); + } + auto maybe_pivot_size = MaybeCastNoOverflow(dims.back()); + if (!maybe_pivot_size.ok()) { + return ffi::Error( + static_cast(maybe_pivot_size.status().code()), + std::string(maybe_pivot_size.status().message())); + } + std::int32_t pivot_size = maybe_pivot_size.value(); + std::int64_t batch_size = 1; + if (dims.size() >= 2) { + batch_size = + absl::c_accumulate(dims.first(dims.size() - 1), 1, std::multiplies<>()); + } + LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size, + permutation_size, pivots.data, + permutation->data); + if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) { + return ffi::Error(static_cast(status.code()), + std::string(status.message())); } + return ffi::Error::Success(); } } // namespace JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu/lu_pivot_kernels.cu.cc b/jaxlib/gpu/lu_pivot_kernels.cu.cc index 6a5c6e7553f3..1d24a38d2c7f 100644 --- a/jaxlib/gpu/lu_pivot_kernels.cu.cc +++ b/jaxlib/gpu/lu_pivot_kernels.cu.cc @@ -61,21 +61,19 @@ __global__ void LuPivotsToPermutationKernel( } // namespace -void LaunchLuPivotsToPermutationKernel( - gpuStream_t stream, void** buffers, - LuPivotsToPermutationDescriptor descriptor) { - const std::int32_t* pivots = - reinterpret_cast(buffers[0]); - std::int32_t* permutation_out = reinterpret_cast(buffers[1]); - +void LaunchLuPivotsToPermutationKernel(gpuStream_t stream, + std::int64_t batch_size, + std::int32_t pivot_size, + std::int32_t permutation_size, + const std::int32_t* pivots, + std::int32_t* permutation) { const int block_dim = 
128; - const std::int64_t grid_dim = std::min( - 1024, (descriptor.batch_size + block_dim - 1) / block_dim); + const std::int64_t grid_dim = + std::min(1024, (batch_size + block_dim - 1) / block_dim); LuPivotsToPermutationKernel<<>>( - pivots, permutation_out, descriptor.batch_size, descriptor.pivot_size, - descriptor.permutation_size); + pivots, permutation, batch_size, pivot_size, permutation_size); } } // namespace JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu/lu_pivot_kernels.h b/jaxlib/gpu/lu_pivot_kernels.h index 1c1a137d0838..b2cceb883dc9 100644 --- a/jaxlib/gpu/lu_pivot_kernels.h +++ b/jaxlib/gpu/lu_pivot_kernels.h @@ -16,29 +16,34 @@ limitations under the License. #ifndef JAXLIB_GPU_LU_PIVOT_KERNELS_H_ #define JAXLIB_GPU_LU_PIVOT_KERNELS_H_ -#include #include -#include #include "jaxlib/gpu/vendor.h" -#include "xla/service/custom_call_status.h" +#include "xla/ffi/api/ffi.h" namespace jax { namespace JAX_GPU_NAMESPACE { -struct LuPivotsToPermutationDescriptor { - std::int64_t batch_size; - std::int32_t pivot_size; - std::int32_t permutation_size; -}; - -void LaunchLuPivotsToPermutationKernel( - gpuStream_t stream, void** buffers, - LuPivotsToPermutationDescriptor descriptor); - -void LuPivotsToPermutation(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len, - XlaCustomCallStatus* status); +namespace ffi = xla::ffi; + +void LaunchLuPivotsToPermutationKernel(gpuStream_t stream, + std::int64_t batch_size, + std::int32_t pivot_size, + std::int32_t permutation_size, + const std::int32_t* pivots, + std::int32_t* permutation); + +ffi::Error LuPivotsToPermutationImpl( + gpuStream_t stream, std::int32_t permutation_size, + ffi::Buffer pivots, + ffi::Result> permutation); + +XLA_FFI_DEFINE_HANDLER(LuPivotsToPermutation, LuPivotsToPermutationImpl, + ffi::Ffi::Bind() + .Ctx>() + .Attr("permutation_size") + .Arg>() + .Ret>()); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/prng.cc b/jaxlib/gpu/prng.cc index 8aec8d81f8a3..79df5c005926 100644 --- a/jaxlib/gpu/prng.cc +++ b/jaxlib/gpu/prng.cc @@ -28,12 +28,16 @@ std::string BuildThreeFry2x32Descriptor(std::int64_t n) { } nb::dict Registrations() { nb::dict dict; + dict[JAX_GPU_PREFIX "_threefry2x32_ffi"] = + EncapsulateFunction(ThreeFry2x32Ffi); + // TODO(b/338022728): remove after 3 weeks dict[JAX_GPU_PREFIX "_threefry2x32"] = EncapsulateFunction(ThreeFry2x32); return dict; } NB_MODULE(_prng, m) { m.def("registrations", &Registrations); + // TODO(b/338022728): remove after 3 weeks m.def("threefry2x32_descriptor", [](std::int64_t n) { std::string result = BuildThreeFry2x32Descriptor(n); return nb::bytes(result.data(), result.size()); diff --git a/jaxlib/gpu/prng_kernels.cc b/jaxlib/gpu/prng_kernels.cc index 00ee4e648b11..56ce0323ec96 100644 --- a/jaxlib/gpu/prng_kernels.cc +++ b/jaxlib/gpu/prng_kernels.cc @@ -15,16 +15,26 @@ limitations under the License. 
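As context for the FFI handler above: the kernel expands LAPACK-style pivot indices into an explicit permutation, with batch_size taken as the product of all leading dimensions of pivots. A rough NumPy model of the assumed semantics (illustrative only):

import numpy as np

def lu_pivots_to_permutation_reference(pivots, permutation_size):
    # pivots[..., i] == j means "swap rows i and j", applied in order, as in
    # LAPACK getrf output; the result lists the permuted row indices.
    batch_shape = pivots.shape[:-1]
    out = np.tile(np.arange(permutation_size, dtype=np.int32),
                  batch_shape + (1,))
    for idx in np.ndindex(*batch_shape):
        perm = out[idx]
        for i, j in enumerate(pivots[idx]):
            perm[i], perm[j] = perm[j], perm[i]
    return out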
#include "jaxlib/gpu/prng_kernels.h" +#include +#include +#include #include +#include "absl/algorithm/container.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/kernel_helpers.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { + +namespace ffi = xla::ffi; + namespace { +// TODO(b/338022728): old custom call target, remove after 6 months absl::Status ThreeFry2x32_(gpuStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); @@ -36,6 +46,7 @@ absl::Status ThreeFry2x32_(gpuStream_t stream, void** buffers, } // namespace +// TODO(b/338022728): remove after 6 months void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = ThreeFry2x32_(stream, buffers, opaque, opaque_len); @@ -45,5 +56,32 @@ void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque, } } +XLA_FFI_Error* ThreeFry2x32Ffi(XLA_FFI_CallFrame* call_frame) { + static const auto* kImpl = + ffi::Ffi::Bind() + .Ctx>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Ret>() + .Ret>() + .To([](gpuStream_t stream, auto keys0, auto keys1, auto data0, + auto data1, auto out0, auto out1) -> ffi::Error { + std::int64_t n = absl::c_accumulate(out0->dimensions, 1, + std::multiplies()); + LaunchThreeFry2x32KernelFfi(stream, n, keys0.data, keys1.data, + data0.data, data1.data, out0->data, + out1->data); + if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) { + return ffi::Error(static_cast(status.code()), + std::string(status.message())); + } + return ffi::Error::Success(); + }) + .release(); + return kImpl->Call(call_frame); +} + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/prng_kernels.cu.cc b/jaxlib/gpu/prng_kernels.cu.cc index 18a435ae2d96..4799614b475a 100644 --- a/jaxlib/gpu/prng_kernels.cu.cc +++ b/jaxlib/gpu/prng_kernels.cu.cc @@ -105,6 +105,23 @@ __global__ void ThreeFry2x32Kernel(const std::uint32_t* key0, } // namespace +void LaunchThreeFry2x32KernelFfi(gpuStream_t stream, + std::int64_t n, + std::uint32_t *keys0, + std::uint32_t *keys1, + std::uint32_t *data0, + std::uint32_t *data1, + std::uint32_t *out0, + std::uint32_t *out1) { + const int block_dim = 128; + const std::int64_t grid_dim = + std::min(1024, (n + block_dim - 1) / block_dim); + ThreeFry2x32Kernel<<>>(keys0, keys1, data0, data1, out0, + out1, n, nullptr); +} + +// TODO(b/338022728): remove after 6 months void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers, ThreeFry2x32Descriptor descriptor) { std::array keys; diff --git a/jaxlib/gpu/prng_kernels.h b/jaxlib/gpu/prng_kernels.h index 5f1c4bd749e2..363023e32b3b 100644 --- a/jaxlib/gpu/prng_kernels.h +++ b/jaxlib/gpu/prng_kernels.h @@ -21,21 +21,33 @@ limitations under the License. 
#include #include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/c_api.h" #include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { +// TODO(b/338022728): remove after 6 months struct ThreeFry2x32Descriptor { std::int64_t n; // If -1 then the length is passed as a 5th operand }; +// TODO(b/338022728): remove after 6 months void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers, ThreeFry2x32Descriptor descriptor); +// TODO(b/338022728): remove after 6 months void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); +XLA_FFI_Error* ThreeFry2x32Ffi(XLA_FFI_CallFrame* call_frame); + +void LaunchThreeFry2x32KernelFfi(gpuStream_t stream, + std::int64_t n, + std::uint32_t *keys0, std::uint32_t *keys1, + std::uint32_t *data0, std::uint32_t *data1, + std::uint32_t *out0, std::uint32_t *out1); + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index d97718b765ad..c040f3875927 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -27,7 +27,8 @@ limitations under the License. #include "jaxlib/gpu/solver_kernels.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" -#include "tsl/python/lib/core/numpy.h" +#include "xla/tsl/python/lib/core/numpy.h" + namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index cda214794dd7..b9eb51388fa9 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -30,7 +30,7 @@ limitations under the License. #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" -#include "tsl/python/lib/core/numpy.h" +#include "xla/tsl/python/lib/core/numpy.h" namespace nb = nanobind; diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index e48dce4b1d7f..89d804511313 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -19,6 +19,7 @@ #include "absl/base/thread_annotations.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -30,7 +31,11 @@ #include "jaxlib/gpu/triton_utils.h" #include "jaxlib/gpu/vendor.h" #include "xla/service/custom_call_status.h" -#include "xla/stream_executor/gpu/asm_compiler.h" +#include "tsl/platform/env.h" + +#ifdef JAX_GPU_CUDA +#include "xla/stream_executor/cuda/cuda_asm_compiler.h" +#endif #define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr)) @@ -553,7 +558,7 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const { // GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context)); // absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); }; - // Autotuning is not supported if the the stream is in graph capture mode. + // Autotuning is not supported if the stream is in graph capture mode. gpustreamCaptureStatus_t capture_status; GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status)); if (capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE) { @@ -583,7 +588,13 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const { // First run a single iteration of each to config to determine how many // iterations to run for benchmarking. 
float best = std::numeric_limits::infinity(); + JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream)); + absl::flat_hash_set configs_to_skip; for (Config& config : kernel_call.configs_) { + if (!config.kernel_call.CanLaunchOnDevice(device)) { + configs_to_skip.insert(&config); + continue; + } JAX_ASSIGN_OR_RETURN(float t, Benchmark(stream, config.kernel_call, buffers, 1)); LOG(INFO) << config.description << ", ran 1 iter in " << t << " ms"; @@ -601,9 +612,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const { } best = std::numeric_limits::infinity(); - JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream)); for (Config& config : kernel_call.configs_) { - if (!config.kernel_call.CanLaunchOnDevice(device)) { + if (configs_to_skip.contains(&config)) { LOG(WARNING) << "Unable to launch autotune config on device: " << config.description; continue; diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index dcbb95e8b360..f1d4bfa86b15 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -26,32 +26,21 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuComplex.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cublas_v2.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: export +#include "third_party/gpus/cuda/include/cuda_fp8.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cuda_runtime_api.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cufft.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cusolverDn.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export #include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export -// Some sparse functionality is only available in CUSPARSE 11.3 or newer. -#define JAX_GPU_HAVE_SPARSE (CUSPARSE_VERSION >= 11300) +#if CUDA_VERSION < 11080 +#error "JAX requires CUDA 11.8 or newer." +#endif // CUDA_VERSION < 11080 + +#define JAX_GPU_HAVE_SPARSE 1 // CUDA-11.8 introduces FP8 E4M3/E5M2 types. -#define JAX_GPU_HAVE_FP8 (CUDA_VERSION >= 11080) - -#if JAX_GPU_HAVE_FP8 -#include "third_party/gpus/cuda/include/cuda_fp8.h" -#endif - -// cuSPARSE generic APIs are not supported on Windows until 11.0 -// cusparseIndexType_t is used in very limited scope so manually define will -// workaround compiling issue without harm. -#if defined(_WIN32) && (CUSPARSE_VERSION < 11000) -typedef enum { - CUSPARSE_INDEX_16U = 1, - CUSPARSE_INDEX_32I = 2, - CUSPARSE_INDEX_64I = 3 -} cusparseIndexType_t; -#endif +#define JAX_GPU_HAVE_FP8 1 #define JAX_GPU_NAMESPACE cuda #define JAX_GPU_PREFIX "cu" @@ -232,7 +221,6 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; // provide deterministic (bit-wise) results for each run. These indexing modes // are fully supported (both row- and column-major inputs) in CUSPARSE 11.7.1 // and newer (which was released as part of CUDA 11.8) -#if CUSPARSE_VERSION > 11700 #define GPUSPARSE_SPMV_COO_ALG CUSPARSE_SPMV_COO_ALG2 #define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_SPMV_CSR_ALG2 #define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_COO_ALG2 @@ -242,12 +230,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; // does not cover all cases and silently fell back to other algorithms for cases // it did not cover. CUDA 12.2.1 removed the fallback behavior. 
#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_ALG_DEFAULT -#else -#define GPUSPARSE_SPMV_COO_ALG CUSPARSE_MV_ALG_DEFAULT -#define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_MV_ALG_DEFAULT -#define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_ALG_DEFAULT -#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_ALG_DEFAULT -#endif + #define GPUSPARSE_OPERATION_NON_TRANSPOSE CUSPARSE_OPERATION_NON_TRANSPOSE #define GPUSPARSE_OPERATION_TRANSPOSE CUSPARSE_OPERATION_TRANSPOSE #define GPUSPARSE_ORDER_ROW CUSPARSE_ORDER_ROW @@ -513,6 +496,7 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpuMemcpyHostToDevice hipMemcpyHostToDevice #define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost #define gpuStreamSynchronize hipStreamSynchronize +#define gpuStreamWaitEvent hipStreamWaitEvent #define gpuSuccess hipSuccess #define gpuCtxGetDevice hipCtxGetDevice @@ -542,7 +526,7 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpuThreadExchangeStreamCaptureMode hipThreadExchangeStreamCaptureMode #define gpuStreamCreate hipStreamCreateWithFlags #define gpuStreamDestroy hipStreamDestroy -#define gpuStreamIsCapturing hipStreamIsCapturing +#define gpuStreamIsCapturing hipStreamIsCapturing #define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR \ hipDeviceAttributeComputeCapabilityMajor diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index 1002b66171b7..af79b3ae756f 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -15,6 +15,7 @@ import functools from functools import partial import importlib +import numpy as np import operator import jaxlib.mlir.ir as ir @@ -24,7 +25,7 @@ from jaxlib import xla_client -for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]: +for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: try: _cuda_linalg = importlib.import_module( f"{cuda_module_name}._linalg", package="jaxlib" @@ -36,14 +37,26 @@ if _cuda_linalg: for _name, _value in _cuda_linalg.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="CUDA") + api_version = 0 if _name == "cu_cholesky_update" else 1 + xla_client.register_custom_call_target( + _name, _value, platform="CUDA", api_version=api_version + ) -try: - from .rocm import _linalg as _hip_linalg # pytype: disable=import-error +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hip_linalg = importlib.import_module( + f"{rocm_module_name}._linalg", package="jaxlib" + ) + except ImportError: + _hip_linalg = None + else: + break + +if _hip_linalg: for _name, _value in _hip_linalg.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") -except ImportError: - _hip_linalg = None + xla_client.register_custom_call_target( + _name, _value, platform="ROCM", api_version=1 + ) _prod = lambda xs: functools.reduce(operator.mul, xs, 1) @@ -56,14 +69,9 @@ def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_s assert typ.element_type == i32_type, typ - batch_size = _prod(dims[:-1]) - pivot_size = dims[-1] - if not gpu_linalg: raise GpuLibNotLinkedError() - opaque = gpu_linalg.lu_pivots_to_permutation_descriptor( - batch_size, pivot_size, permutation_size) pivots_layout = tuple(range(len(dims) - 1, -1, -1)) permutations_layout = pivots_layout permutations_dims = list(dims) @@ -71,13 +79,47 @@ def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_s permutations_type = ir.RankedTensorType.get(permutations_dims, i32_type) return custom_call( f"{platform}_lu_pivots_to_permutation", + api_version=4, 
result_types=[permutations_type], operands=[pivots], - backend_config=opaque, + backend_config=dict( + permutation_size=ir.IntegerAttr.get(i32_type, permutation_size), + ), operand_layouts=[pivots_layout], - result_layouts=[permutations_layout]).results + result_layouts=[permutations_layout], + ).results cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_hlo, "cu", _cuda_linalg) hip_lu_pivots_to_permutation = partial( _lu_pivots_to_permutation_hlo, "hip", _hip_linalg) + + + +def _cholesky_update_hlo(platform, gpu_linalg, r_matrix, w_vector, dtype): + """Cholesky update.""" + del platform + r_type = ir.RankedTensorType(r_matrix.type) + dims = r_type.shape + assert dims[0] == dims[1] + n = dims[0] + + if not gpu_linalg: + raise GpuLibNotLinkedError() + + np_type = np.dtype(dtype) + opaque = gpu_linalg.build_cholesky_update_descriptor(np_type, n) + + return custom_call( + "cu_cholesky_update", + operands = [r_matrix, w_vector], + result_types=[ + ir.RankedTensorType.get((n, n), r_type.element_type), + ir.RankedTensorType.get((n,), r_type.element_type), + ], + operand_output_aliases={0: 0, 1: 1}, + backend_config=opaque, + ).results[:1] + + +cuda_cholesky_update = partial(_cholesky_update_hlo, "cu", _cuda_linalg) diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index f1e0ed6a400a..573dae268cee 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -27,7 +27,7 @@ from .hlo_helpers import custom_call from .gpu_common_utils import GpuLibNotLinkedError -for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]: +for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: try: _cuda_prng = importlib.import_module( f"{cuda_module_name}._prng", package="jaxlib" @@ -39,26 +39,41 @@ if _cuda_prng: for _name, _value in _cuda_prng.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="CUDA") + # TODO(b/338022728): remove after 6 months, always api_version=1 + api_version = 1 if "_ffi" in _name else 0 + xla_client.register_custom_call_target(_name, _value, platform="CUDA", + api_version=api_version) -try: - from .rocm import _prng as _hip_prng # pytype: disable=import-error +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hip_prng = importlib.import_module( + f"{rocm_module_name}._prng", package="jaxlib" + ) + except ImportError: + _hip_prng = None + else: + break + +if _hip_prng: for _name, _value in _hip_prng.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") -except ImportError: - _hip_prng = None + # TODO(b/338022728): remove after 6 months, always api_version=1 + api_version = 1 if "_ffi" in _name else 0 + xla_client.register_custom_call_target(_name, _value, platform="ROCM", + api_version=api_version) _prod = lambda xs: functools.reduce(operator.mul, xs, 1) -def _threefry2x32_lowering(prng, platform, keys, data, +# TODO(b/338022728): forward_compatibility_mode=False after 3 weeks. +def _threefry2x32_lowering(prng, platform: str, keys, data, length: int | ir.Value | None = None, - output_shape: ir.Value | None = None): + output_shape: ir.Value | None = None, + forward_compatibility_mode: bool = True): """ThreeFry2x32 kernel for GPU. In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape` is a 1D tensor describing the shape of the two outputs. 
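The forward_compatibility_mode flag above selects between the legacy descriptor-based custom call and the new typed-FFI target. Schematically (target names taken from this diff; a simplified sketch, not the real lowering):

def select_threefry_target(platform, forward_compatibility_mode):
    if forward_compatibility_mode:
        # Legacy path: opaque descriptor bytes, custom call api_version=2.
        return f"{platform}_threefry2x32", 2
    # FFI path: attributes are passed structurally, api_version=4.
    return f"{platform}_threefry2x32_ffi", 4

assert select_threefry_target("cu", False) == ("cu_threefry2x32_ffi", 4)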
""" - if not prng: + if forward_compatibility_mode and not prng: raise GpuLibNotLinkedError() assert len(keys) == 2, keys assert len(data) == 2, data @@ -75,28 +90,37 @@ def _threefry2x32_lowering(prng, platform, keys, data, operand_layouts = [layout] * 4 operands = [keys[0], keys[1], data[0], data[1]] - if length is None: + if forward_compatibility_mode and length is None: length = _prod(dims) + opaque = {} # Use if not forward_compatibility_mode to trigger the FFI (v4). if isinstance(length, int): - opaque = prng.threefry2x32_descriptor(length) + if forward_compatibility_mode: + opaque = prng.threefry2x32_descriptor(length) result_shapes = None else: assert output_shape is not None - opaque = prng.threefry2x32_descriptor(-1) - assert (ir.RankedTensorType(length.type).element_type == - ir.IntegerType.get_signless(64)), length - assert (ir.RankedTensorType(length.type).shape == - [1]), (length, ir.RankedTensorType(length.type).shape) - # Pass the length, which will be used by the custom call target since the - # static length in the descriptor is -1. - operands.append(length) - operand_layouts.append((0,)) + if forward_compatibility_mode: + opaque = prng.threefry2x32_descriptor(-1) + assert (ir.RankedTensorType(length.type).element_type == # type: ignore[attribute-error] + ir.IntegerType.get_signless(64)), length + assert (ir.RankedTensorType(length.type).shape == # type: ignore[attribute-error] + [1]), (length, ir.RankedTensorType(length.type).shape) # type: ignore[attribute-error] + # Pass the length, which will be used by the custom call target since the + # static length in the descriptor is -1. + operands.append(length) + operand_layouts.append((0,)) # We also need to pass separately the shapes of the outputs. result_shapes = [output_shape, output_shape] + custom_call_target = ( + f"{platform}_threefry2x32" + if forward_compatibility_mode + else f"{platform}_threefry2x32_ffi" + ) return custom_call( - f"{platform}_threefry2x32", + custom_call_target, + api_version=(2 if forward_compatibility_mode else 4), result_types=[typ, typ], operands=operands, backend_config=opaque, diff --git a/jaxlib/gpu_rnn.py b/jaxlib/gpu_rnn.py index 59192c99911c..0fc3dc350967 100644 --- a/jaxlib/gpu_rnn.py +++ b/jaxlib/gpu_rnn.py @@ -22,7 +22,7 @@ from jaxlib import xla_client from .gpu_common_utils import GpuLibNotLinkedError -for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]: +for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: try: _rnn = importlib.import_module(f"{cuda_module_name}._rnn", package="jaxlib") except ImportError: diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index ea833dccd304..f9a704a79f3d 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -28,12 +28,12 @@ from .hlo_helpers import ( DimensionSize, ShapeTypePair, mk_result_types_and_shapes, - custom_call, ensure_hlo_s32, hlo_s32, dense_int_array, dense_int_array_v6) + custom_call, ensure_hlo_s32, hlo_s32, dense_int_array) try: from .cuda import _blas as _cublas # pytype: disable=import-error except ImportError: - for cuda_module_name in ["jax_cuda12_plugin", "jax_cuda11_plugin"]: + for cuda_module_name in ["jax_cuda12_plugin"]: try: _cublas = importlib.import_module(f"{cuda_module_name}._blas") except ImportError: @@ -45,7 +45,7 @@ for _name, _value in _cublas.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") -for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]: +for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: 
try: _cusolver = importlib.import_module( f"{cuda_module_name}._solver", package="jaxlib" @@ -59,21 +59,34 @@ for _name, _value in _cusolver.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") - try: from .rocm import _blas as _hipblas # pytype: disable=import-error - for _name, _value in _hipblas.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") except ImportError: - _hipblas = None + for rocm_module_name in ["jax_rocm60_plugin"]: + try: + _hipblas = importlib.import_module(f"{rocm_module_name}._blas") + except: + _hipblas = None + else: + break -try: - from .rocm import _solver as _hipsolver # pytype: disable=import-error - for _name, _value in _hipsolver.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") -except ImportError: - _hipsolver = None +if _hipblas: + for _name, _value in _hipblas.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM") + +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hipsolver = importlib.import_module( + f"{rocm_module_name}._solver", package="jaxlib" + ) + except ImportError: + _hipsolver = None + else: + break +if _hipsolver: + for _name, _value in _hipsolver.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM") def _real_type(dtype): """Returns the real equivalent of 'dtype'.""" @@ -536,14 +549,13 @@ def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower): # simply copy it back to where it needs to be: intattr = lambda xs: ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)) intarrattr = lambda xs: dense_int_array(np.asarray(xs, np.int64)) - intarrattr_v6 = lambda xs: dense_int_array_v6(np.asarray(xs, np.int64)) if not lower and platform == "cu" and m > 1: start = (0,) * len(batch_dims) + (0,) end = batch_dims + (1,) s = hlo.slice( e, intarrattr(start), intarrattr(end), intarrattr([1] * len(start))) s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type) - s = hlo.broadcast_in_dim(s_type, s, intarrattr_v6(range(len(dims) - 1))) + s = hlo.broadcast_in_dim(s_type, s, intarrattr(range(len(dims) - 1))) # The diagonals are always real; convert to complex if needed. 
s = hlo.convert( ir.RankedTensorType.get(s_type.shape, a_type.element_type), s) diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index 078bb71b75ce..84192d4d0286 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -27,7 +27,7 @@ from .hlo_helpers import custom_call, mk_result_types_and_shapes -for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]: +for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: try: _cusparse = importlib.import_module( f"{cuda_module_name}._sparse", package="jaxlib" @@ -41,11 +41,17 @@ for _name, _value in _cusparse.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") -try: - from .rocm import _sparse as _hipsparse # pytype: disable=import-error -except ImportError: - _hipsparse = None -else: +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hipsparse = importlib.import_module( + f"{rocm_module_name}._sparse", package="jaxlib" + ) + except ImportError: + _hipsparse = None + else: + break + +if _hipsparse: for _name, _value in _hipsparse.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") diff --git a/jaxlib/gpu_triton.py b/jaxlib/gpu_triton.py index a004954a6164..f2d37bfec03d 100644 --- a/jaxlib/gpu_triton.py +++ b/jaxlib/gpu_triton.py @@ -15,7 +15,7 @@ from jaxlib import xla_client -for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]: +for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: try: _cuda_triton = importlib.import_module( f"{cuda_module_name}._triton", package="jaxlib" @@ -38,8 +38,17 @@ get_custom_call = _cuda_triton.get_custom_call get_serialized_metadata = _cuda_triton.get_serialized_metadata -try: - from .rocm import _triton as _hip_triton # pytype: disable=import-error +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hip_triton = importlib.import_module( + f"{rocm_module_name}._triton", package="jaxlib" + ) + except ImportError: + _hip_triton = None + else: + break + +if _hip_triton: xla_client.register_custom_call_target( "triton_kernel_call", _hip_triton.get_custom_call(), platform='ROCM') @@ -51,5 +60,3 @@ get_compute_capability = _hip_triton.get_compute_capability get_custom_call = _hip_triton.get_custom_call get_serialized_metadata = _hip_triton.get_serialized_metadata -except ImportError: - _hip_triton = None diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py index ee59a1b96ee1..0d57a04f1aa7 100644 --- a/jaxlib/hlo_helpers.py +++ b/jaxlib/hlo_helpers.py @@ -16,9 +16,9 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial -from typing import Callable, Union +from typing import Union import jaxlib.mlir.ir as ir import jaxlib.mlir.dialects.stablehlo as hlo @@ -110,16 +110,7 @@ def hlo_s32(x: int): def ensure_hlo_s32(x: DimensionSize): return hlo_s32(x) if isinstance(x, int) else x -def dense_int_array(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr: - # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher - if hlo.get_api_version() < 5: - return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)) - return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) - -# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher -def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr: - if hlo.get_api_version() < 6: - return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)) +def 
dense_int_array(xs) -> ir.DenseI64ArrayAttr: return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize: diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 0593870d1073..a4da463a3447 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -17,9 +17,10 @@ load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library") load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") +load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") -load("@tsl//tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") +load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") # Explicitly re-exports names to avoid "unused variable" warnings from .bzl # lint tools. @@ -37,6 +38,7 @@ tf_cuda_tests_tags = _tf_cuda_tests_tags jax_internal_packages = [] jax_extend_internal_users = [] +mosaic_gpu_internal_users = [] mosaic_internal_users = [] pallas_gpu_internal_users = [] pallas_tpu_internal_users = [] @@ -46,13 +48,52 @@ jax_internal_test_harnesses_visibility = [] jax_test_util_visibility = [] loops_visibility = [] +# TODO(vam): remove this once zstandard builds against Python 3.13 +def get_zstandard(): + if HERMETIC_PYTHON_VERSION == "3.13": + return [] + return ["@pypi_zstandard//:pkg"] + +_py_deps = { + "absl/logging": ["@pypi_absl_py//:pkg"], + "absl/testing": ["@pypi_absl_py//:pkg"], + "absl/flags": ["@pypi_absl_py//:pkg"], + "cloudpickle": ["@pypi_cloudpickle//:pkg"], + "colorama": ["@pypi_colorama//:pkg"], + "epath": ["@pypi_etils//:pkg"], # etils.epath + "filelock": ["@pypi_filelock//:pkg"], + "flatbuffers": ["@pypi_flatbuffers//:pkg"], + "hypothesis": ["@pypi_hypothesis//:pkg"], + "matplotlib": ["@pypi_matplotlib//:pkg"], + "opt_einsum": ["@pypi_opt_einsum//:pkg"], + "pil": ["@pypi_pillow//:pkg"], + "portpicker": ["@pypi_portpicker//:pkg"], + "ml_dtypes": ["@pypi_ml_dtypes//:pkg"], + "numpy": ["@pypi_numpy//:pkg"], + "scipy": ["@pypi_scipy//:pkg"], + "tensorflow_core": [], + "torch": [], + "zstandard": get_zstandard(), +} + +def all_py_deps(excluded = []): + py_deps_copy = dict(_py_deps) + for excl in excluded: + py_deps_copy.pop(excl) + return py_deps(py_deps_copy.keys()) + def py_deps(_package): """Returns the Bazel deps for Python package `package`.""" - # We assume the user has installed all dependencies in their Python environment. - # This indirection exists because in Google's internal build we build - # dependencies from source with Bazel, but that's not something most people would want. 
- return [] + if type(_package) == type([]) or type(_package) == type(()): + deduped_py_deps = {} + for _pkg in _package: + for py_dep in _py_deps[_pkg]: + deduped_py_deps[py_dep] = _pkg + + return deduped_py_deps.keys() + + return _py_deps[_package] def jax_visibility(_target): """Returns the additional Bazel visibilities for `target`.""" @@ -155,7 +196,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = [], exported_sym ALL_BACKENDS = ["cpu", "gpu", "tpu"] -def if_building_jaxlib(if_building, if_not_building = []): +def if_building_jaxlib(if_building, if_not_building = ["@pypi_jaxlib//:pkg"]): return select({ "//jax:enable_jaxlib_build": if_building, "//conditions:default": if_not_building, @@ -169,6 +210,7 @@ def jax_test( env = {}, shard_count = None, deps = [], + data = [], disable_backends = None, # buildifier: disable=unused-variable backend_variant_args = {}, # buildifier: disable=unused-variable backend_tags = {}, # buildifier: disable=unused-variable @@ -210,6 +252,7 @@ def jax_test( "//jax:enable_build_cuda_plugin_from_source": ["//jax_plugins:gpu_plugin_only_test_deps"], "//conditions:default": [], }), + data = data, shard_count = test_shards, tags = test_tags, main = main, diff --git a/jaxlib/kernel_nanobind_helpers.h b/jaxlib/kernel_nanobind_helpers.h index a5863441569b..01d027cf6719 100644 --- a/jaxlib/kernel_nanobind_helpers.h +++ b/jaxlib/kernel_nanobind_helpers.h @@ -21,7 +21,7 @@ limitations under the License. #include "nanobind/nanobind.h" #include "absl/base/casts.h" #include "jaxlib/kernel_helpers.h" -#include "tsl/python/lib/core/numpy.h" // NOLINT +#include "xla/tsl/python/lib/core/numpy.h" // NOLINT namespace jax { diff --git a/jaxlib/mlir/BUILD.bazel b/jaxlib/mlir/BUILD.bazel index c9268aaf3dd7..6626d0d162ab 100644 --- a/jaxlib/mlir/BUILD.bazel +++ b/jaxlib/mlir/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
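The py_deps/all_py_deps helpers above now resolve a single package name or a list of names to deduplicated Bazel labels from the _py_deps table. A plain-Python model of that lookup (Starlark uses type() checks rather than isinstance; illustrative only):

_py_deps_example = {
    "numpy": ["@pypi_numpy//:pkg"],
    "scipy": ["@pypi_scipy//:pkg"],
    "absl/testing": ["@pypi_absl_py//:pkg"],
    "absl/flags": ["@pypi_absl_py//:pkg"],
}

def py_deps_model(package):
    if isinstance(package, (list, tuple)):
        deduped = {}
        for pkg in package:
            for dep in _py_deps_example[pkg]:
                deduped[dep] = pkg
        return list(deduped.keys())
    return _py_deps_example[package]

assert py_deps_model(["absl/testing", "absl/flags"]) == ["@pypi_absl_py//:pkg"]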
-load("//jaxlib:symlink_files.bzl", "symlink_inputs") +load("//jaxlib:symlink_files.bzl", "symlink_files", "symlink_inputs") package( default_visibility = [ @@ -44,7 +44,10 @@ symlink_inputs( name = "ir", rule = py_library, symlinked_inputs = {"srcs": { - ".": ["@llvm-project//mlir/python:IRPyFiles"], + ".": [ + "@llvm-project//mlir/python:IRPyFiles", + "@llvm-project//mlir/python:IRPyIFiles", + ], }}, deps = [ ":mlir", @@ -197,7 +200,10 @@ symlink_inputs( name = "pass_manager", rule = py_library, symlinked_inputs = {"srcs": { - ".": ["@llvm-project//mlir/python:PassManagerPyFiles"], + ".": [ + "@llvm-project//mlir/python:PassManagerPyFiles", + "@llvm-project//mlir/python:PassManagerPyIFiles", + ], }}, deps = [ ":mlir", @@ -217,3 +223,82 @@ symlink_inputs( "//jaxlib/mlir/_mlir_libs:_stablehlo", ], ) + +symlink_inputs( + name = "nvgpu_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ + "@llvm-project//mlir/python:NVGPUOpsPyFiles", + ]}}, + deps = [ + ":core", + ":ir", + ":mlir", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsNVGPU", + ], +) + +symlink_inputs( + name = "nvvm_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ + "@llvm-project//mlir/python:NVVMOpsPyFiles", + ]}}, + deps = [ + ":core", + ":ir", + ":mlir", + ], +) + +symlink_files( + name = "gpu_files", + srcs = ["@llvm-project//mlir/python:GPUOpsPyFiles"], + dst = "dialects", + flatten = True, +) + +symlink_files( + name = "gpu_package_files", + srcs = ["@llvm-project//mlir/python:GPUOpsPackagePyFiles"], + dst = "dialects/gpu", + flatten = True, +) + +symlink_files( + name = "gpu_package_passes_files", + srcs = ["@llvm-project//mlir/python:GPUOpsPackagePassesPyFiles"], + dst = "dialects/gpu/passes", + flatten = True, +) + +py_library( + name = "gpu_dialect", + srcs = [ + ":gpu_files", + ":gpu_package_files", + ":gpu_package_passes_files", + ], + deps = [ + ":core", + ":ir", + ":mlir", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsGPU", + "//jaxlib/mlir/_mlir_libs:_mlirGPUPasses", + ], +) + +symlink_inputs( + name = "llvm_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ + "@llvm-project//mlir/python:LLVMOpsPyFiles", + ]}}, + deps = [ + ":core", + ":ir", + ":mlir", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM", + ], +) + diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 082e0f765e84..22735eeaf1ad 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -34,11 +34,11 @@ COPTS = [ ] LINKOPTS = select({ - "@tsl//tsl:macos": [ + "@xla//xla/tsl:macos": [ "-Wl,-rpath,@loader_path/", "-Wl,-rename_section,__TEXT,text_env,__TEXT,__text", ], - "@tsl//tsl:windows": [], + "@xla//xla/tsl:windows": [], "//conditions:default": [ "-Wl,-rpath,$$ORIGIN/", ], @@ -58,6 +58,68 @@ py_extension( ], ) +py_extension( + name = "_mlirDialectsGPU", + srcs = [ + "@llvm-project//mlir:lib/Bindings/Python/DialectGPU.cpp", + ], + copts = COPTS, + linkopts = LINKOPTS, + deps = [ + ":jaxlib_mlir_capi_shared_library", + "@llvm-project//mlir:CAPIGPUHeaders", + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:MLIRBindingsPythonHeaders", + "@pybind11", + ], +) + +py_extension( + name = "_mlirGPUPasses", + srcs = [ + "@llvm-project//mlir:lib/Bindings/Python/GPUPasses.cpp", + ], + copts = COPTS, + linkopts = LINKOPTS, + deps = [ + ":jaxlib_mlir_capi_shared_library", + "@llvm-project//mlir:CAPIGPUHeaders", + "@pybind11", + ], +) + +py_extension( + name = "_mlirDialectsNVGPU", + srcs = [ + 
"@llvm-project//mlir:lib/Bindings/Python/DialectNVGPU.cpp", + ], + copts = COPTS, + linkopts = LINKOPTS, + deps = [ + ":jaxlib_mlir_capi_shared_library", + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPINVGPUHeaders", + "@llvm-project//mlir:MLIRBindingsPythonHeaders", + "@pybind11", + ], +) + +py_extension( + name = "_mlirDialectsLLVM", + srcs = [ + "@llvm-project//mlir:lib/Bindings/Python/DialectLLVM.cpp", + ], + copts = COPTS, + linkopts = LINKOPTS, + deps = [ + ":jaxlib_mlir_capi_shared_library", + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPILLVMHeaders", + "@llvm-project//mlir:MLIRBindingsPythonHeaders", + "@pybind11", + ], +) + py_extension( name = "_mlirDialectsSparseTensor", srcs = [ @@ -148,6 +210,31 @@ symlink_inputs( ], ) +cc_library( + name = "jaxlib_mlir_capi_shims", + srcs = ["jaxlib_mlir_capi_shims.cc"], + hdrs = ["jaxlib_mlir_capi_shims.h"], + deps = [ + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:GPUPipelines", + "@llvm-project//mlir:GPUToLLVMIRTranslation", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:MemRefTransforms", + "@llvm-project//mlir:NVVMTarget", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + ], + alwayslink = 1, +) + +cc_library( + name = "jaxlib_mlir_capi_shims_hdrs", + hdrs = ["jaxlib_mlir_capi_shims.h"], + deps = [ + "@llvm-project//mlir:CAPIIRHeaders", + ], +) + # JAX-specific registrations. py_extension( name = "register_jax_dialects", @@ -157,9 +244,13 @@ py_extension( deps = [ ":jaxlib_mlir_capi_shared_library", "@llvm-project//mlir:CAPIArithHeaders", + "@llvm-project//mlir:CAPIGPUHeaders", "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPILLVMHeaders", "@llvm-project//mlir:CAPIMathHeaders", "@llvm-project//mlir:CAPIMemRefHeaders", + "@llvm-project//mlir:CAPINVGPUHeaders", + "@llvm-project//mlir:CAPINVVMHeaders", "@llvm-project//mlir:CAPISCFHeaders", "@llvm-project//mlir:CAPITransformsHeaders", "@llvm-project//mlir:CAPIVectorHeaders", @@ -223,9 +314,11 @@ py_extension( deps = [ ":jaxlib_mlir_capi_shared_library", "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:IR", "@llvm-project//mlir:MLIRBindingsPythonHeaders", "@local_config_python//:headers", "@pybind11", + "@stablehlo//:reference_api", "@stablehlo//:stablehlo_capi_headers", "@stablehlo//:stablehlo_portable_api", "@stablehlo//:stablehlo_serialization", @@ -237,12 +330,12 @@ py_extension( cc_library( name = "jaxlib_mlir_capi_shared_library", srcs = select({ - "@tsl//tsl:windows": [":jaxlib_mlir_capi.dll"], - "@tsl//tsl:macos": [":libjaxlib_mlir_capi.dylib"], + "@xla//xla/tsl:windows": [":jaxlib_mlir_capi.dll"], + "@xla//xla/tsl:macos": [":libjaxlib_mlir_capi.dylib"], "//conditions:default": [":libjaxlib_mlir_capi.so"], }), deps = select({ - "@tsl//tsl:windows": [":jaxlib_mlir_capi_dll"], + "@xla//xla/tsl:windows": [":jaxlib_mlir_capi_dll"], "//conditions:default": [], }), ) @@ -252,9 +345,13 @@ cc_library( deps = [ "//jaxlib/mosaic:tpu_dialect_capi_objects", "@llvm-project//mlir:CAPIArithObjects", + "@llvm-project//mlir:CAPIGPUObjects", "@llvm-project//mlir:CAPIIRObjects", + "@llvm-project//mlir:CAPILLVMObjects", "@llvm-project//mlir:CAPIMathObjects", "@llvm-project//mlir:CAPIMemRefObjects", + "@llvm-project//mlir:CAPINVGPUObjects", + "@llvm-project//mlir:CAPINVVMObjects", "@llvm-project//mlir:CAPISCFObjects", "@llvm-project//mlir:CAPISparseTensorObjects", "@llvm-project//mlir:CAPITransformsObjects", diff --git 
a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 87b1cfbb3c9a..e1958c211b33 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -3,8 +3,12 @@ #include "mlir-c/Dialect/Arith.h" #include "mlir-c/Dialect/Func.h" +#include "mlir-c/Dialect/GPU.h" +#include "mlir-c/Dialect/LLVM.h" #include "mlir-c/Dialect/Math.h" #include "mlir-c/Dialect/MemRef.h" +#include "mlir-c/Dialect/NVGPU.h" +#include "mlir-c/Dialect/NVVM.h" #include "mlir-c/Dialect/SCF.h" #include "mlir-c/Dialect/Vector.h" #include "mlir-c/Transforms.h" @@ -24,6 +28,11 @@ PYBIND11_MODULE(register_jax_dialects, m) { REGISTER_DIALECT(memref); REGISTER_DIALECT(scf); REGISTER_DIALECT(vector); + // For Mosaic GPU + REGISTER_DIALECT(gpu); + REGISTER_DIALECT(nvgpu); + REGISTER_DIALECT(nvvm); + REGISTER_DIALECT(llvm); mlirRegisterTransformsPasses(); // Transforms used by JAX. mlirRegisterTransformsStripDebugInfo(); diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index 004bd531ced6..a1531ff9f6de 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -35,6 +36,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Bindings/Python/Interop.h" @@ -70,7 +72,9 @@ constexpr MlirTpuI64TargetTuple TARGET_SHAPE{8, 128}; // TODO(tlongeri): For our use-case, we don't really need C++ exceptions - just // setting the exception object and returning NULL to Python should suffice, but // not sure if this is possible with pybind. -class NotImplementedException : public std::exception {}; +class NotImplementedException : public std::runtime_error { + using runtime_error::runtime_error; +}; } // namespace template <> @@ -92,7 +96,7 @@ struct py::detail::type_caster { } else if (src.is(implicit_dim_cls.attr("SECOND_MINOR"))) { value = MlirTpuImplicitDimSecondMinor; } else { - throw NotImplementedException(); + throw py::value_error(); } return true; } @@ -156,37 +160,59 @@ struct py::detail::type_caster { }; namespace { -class NotImplementedDetector { +// Handler for use with MLIR C API print functions. The 2nd parameter is an +// opaque pointer to "user data" that should always be a string. +void printToString(MlirStringRef c_mlir_str, void* opaque_string) { + std::string* str = static_cast(opaque_string); + CHECK(str != nullptr); + str->append(c_mlir_str.data, c_mlir_str.length); +} + +class DiagnosticCapture { public: - NotImplementedDetector(MlirContext ctx) + DiagnosticCapture(MlirContext ctx) : ctx_(ctx), id_(mlirContextAttachDiagnosticHandler(ctx, handleDiagnostic, this, nullptr)) {} - ~NotImplementedDetector() { mlirContextDetachDiagnosticHandler(ctx_, id_); } - bool detected() const { return detected_; } - - private: - static void handleDiagnosticMessage(MlirStringRef str, - void* opaque_detector) { - // Note that we receive each argument to the stream separately. - // "Not implemented" must be entirely in a single argument. 
- NotImplementedDetector* detector = - static_cast(opaque_detector); - if (llvm::StringRef(str.data, str.length).contains("Not implemented")) { - detector->detected_ = true; + ~DiagnosticCapture() { mlirContextDetachDiagnosticHandler(ctx_, id_); } + + void throwIfError() { + if (error_messages_.size() == 1) { + // Throw NotImplementedException if we got a single diagnostic that + // contains "Not implemented". + llvm::StringRef ref = error_messages_.front(); + constexpr llvm::StringRef not_implemented = "Not implemented"; + if (const size_t pos = ref.find(not_implemented); + pos != llvm::StringRef::npos) { + // We strip "Not implemented" only if it is a prefix. Sometimes it may + // come after another prefix (e.g. op prefix), in which case we leave it + if (pos == 0) { + ref = ref.drop_front(not_implemented.size()); + ref.consume_front(": "); + } + throw NotImplementedException(ref.str()); + } + } + if (!error_messages_.empty()) { + // Note that it is unusual/unexpected to get multiple diagnostics, so we + // just forward all the error messages. + throw std::runtime_error(llvm::join(error_messages_, "\n")); } } + + private: static MlirLogicalResult handleDiagnostic(MlirDiagnostic diag, void* opaque_detector) { - NotImplementedDetector* detector = - static_cast(opaque_detector); + DiagnosticCapture* detector = + static_cast(opaque_detector); if (mlirDiagnosticGetSeverity(diag) == MlirDiagnosticError) { - mlirDiagnosticPrint(diag, handleDiagnosticMessage, detector); + std::string& message = detector->error_messages_.emplace_back(); + mlirDiagnosticPrint(diag, printToString, &message); } return mlirLogicalResultFailure(); // Propagate to other handlers } - bool detected_ = false; + llvm::SmallVector error_messages_; const MlirContext ctx_; const MlirDiagnosticHandlerID id_; }; @@ -562,7 +588,13 @@ PYBIND11_MODULE(_tpu_ext, m) { " shape: An optional shape of the vector to which both layouts " "apply. More layouts are considered equivalent when the shape is " "specified. Also see the docstring of the generalizes method.") - .def("__eq__", mlirTpuVectorLayoutEquals); + .def("__eq__", mlirTpuVectorLayoutEquals) + .def("__repr__", + [](MlirTpuVectorLayout self) { + std::string str; + mlirTpuVectorLayoutPrint(self, printToString, &str); + return str; + }); // TODO(tlongeri): Can we make the first parameter a VectorType? 
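A compact Python model of the throwIfError contract above (illustrative only; the real class raises C++ exceptions that the pybind translator maps to Python):

def throw_if_error(error_messages):
    # A single diagnostic containing "Not implemented" becomes
    # NotImplementedError, with the marker stripped only when it is a prefix;
    # anything else is forwarded verbatim (joined if there are several).
    if len(error_messages) == 1:
        msg = error_messages[0]
        pos = msg.find("Not implemented")
        if pos != -1:
            if pos == 0:
                msg = msg[len("Not implemented"):].removeprefix(": ")
            raise NotImplementedError(msg)
    if error_messages:
        raise RuntimeError("\n".join(error_messages))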
m.def("assemble", @@ -589,13 +621,11 @@ PYBIND11_MODULE(_tpu_ext, m) { TARGET_SHAPE); }); m.def("disassemble", [](MlirTpuVectorLayout layout, MlirValue val) { - NotImplementedDetector detector(getDefaultContext()); + DiagnosticCapture diag_capture(getDefaultContext()); MlirTpuValueArray val_arr = mlirTpuDisassemble(getDefaultInsertionPoint(), layout, val, TARGET_SHAPE); if (val_arr.vals == nullptr) { - if (detector.detected()) { - throw NotImplementedException(); - } + diag_capture.throwIfError(); throw py::value_error("Failed to disassemble"); } py::array_t np_vals( @@ -609,25 +639,21 @@ PYBIND11_MODULE(_tpu_ext, m) { }); m.def("apply_layout_op", [](int hardware_generation, const MlirOperation c_op) { - NotImplementedDetector detector(getDefaultContext()); + DiagnosticCapture diag_capture(getDefaultContext()); MlirLogicalResult res = mlirTpuApplyLayoutOp(hardware_generation, c_op, TARGET_SHAPE); if (mlirLogicalResultIsFailure(res)) { - if (detector.detected()) { - throw NotImplementedException(); - } + diag_capture.throwIfError(); throw std::runtime_error("applyLayoutOp failed"); } }); m.def("relayout", [](MlirValue v, MlirTpuVectorLayout src, MlirTpuVectorLayout dst) { - NotImplementedDetector detector(getDefaultContext()); + DiagnosticCapture diag_capture(getDefaultContext()); MlirValue new_v = mlirTpuRelayout(getDefaultInsertionPoint(), v, src, dst, TARGET_SHAPE); if (new_v.ptr == nullptr) { - if (detector.detected()) { - throw NotImplementedException(); - } + diag_capture.throwIfError(); throw py::value_error("Failed to relayout"); } return new_v; @@ -636,7 +662,7 @@ PYBIND11_MODULE(_tpu_ext, m) { try { if (p) std::rethrow_exception(p); } catch (const NotImplementedException& e) { - PyErr_SetNone(PyExc_NotImplementedError); + PyErr_SetString(PyExc_NotImplementedError, e.what()); } }); diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 51986d1fcccb..a14d69881d05 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -57,6 +57,7 @@ cc_library( "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransforms", @@ -65,16 +66,21 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorTransforms", "@xla//xla:array", "@xla//xla:shape_util", "@xla//xla:util", + "@tsl//tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], ) diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc index 4ae2d738d93b..73b1b1e56ef2 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc @@ -24,10 +24,12 @@ limitations under the License. 
#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MemAlloc.h" +#include "llvm/Support/raw_ostream.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Registration.h" +#include "mlir/CAPI/Utils.h" #include "mlir/CAPI/Wrap.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" @@ -295,6 +297,12 @@ bool mlirTpuVectorLayoutEquivalentTo(MlirTpuVectorLayout layout, unwrap(target_shape)); } +void mlirTpuVectorLayoutPrint( + MlirTpuVectorLayout layout, MlirStringCallback callback, void *userData) { + mlir::detail::CallbackOstream stream(callback, userData); + unwrap(layout)->print(stream); +} + void mlirTpuVregDataBoundsDestroy(MlirTpuVregDataBounds data_bounds) { delete unwrap(data_bounds); } diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h index 60774147abe9..d1f126db4566 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h @@ -176,6 +176,9 @@ MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutEquivalentTo( MlirTpuVectorLayout layout, MlirTpuVectorLayout other, MlirTpuI64ArrayRef shape, MlirTpuI64TargetTuple target_shape); +MLIR_CAPI_EXPORTED void mlirTpuVectorLayoutPrint( + MlirTpuVectorLayout layout, MlirStringCallback callback, void* user_data); + MLIR_CAPI_EXPORTED void mlirTpuVregDataBoundsDestroy( MlirTpuVregDataBounds data_bounds); diff --git a/jaxlib/mosaic/dialect/tpu/layout.cc b/jaxlib/mosaic/dialect/tpu/layout.cc index 123ee2d56760..7be9c85b2e7c 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.cc +++ b/jaxlib/mosaic/dialect/tpu/layout.cc @@ -40,7 +40,6 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Support/MathExtras.h" #include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" @@ -440,26 +439,10 @@ bool VectorLayout::hasNativeTiling( SmallVector VectorLayout::implicitShape( ArrayRef shape) const { - CHECK(!shape.empty()); - switch (implicit_dim_) { - case ImplicitDim::kNone: - return SmallVector(shape); - case ImplicitDim::kMinor: { - SmallVector implicit_shape; - implicit_shape.reserve(shape.size() + 1); - implicit_shape.append(shape.begin(), shape.end()); - implicit_shape.push_back(1); - return implicit_shape; - } - case ImplicitDim::kSecondMinor: { - SmallVector implicit_shape; - implicit_shape.reserve(shape.size() + 1); - implicit_shape.append(shape.begin(), std::prev(shape.end())); - implicit_shape.push_back(1); - implicit_shape.push_back(shape.back()); - return implicit_shape; - } - } + SmallVector implicit_shape(shape); + implicit_shape.reserve(shape.size() + num_implicit_dims()); + insertImplicit(implicit_shape, 1); + return implicit_shape; } SmallVector VectorLayout::tileArrayImplicitShape( @@ -482,16 +465,7 @@ SmallVector VectorLayout::tileArrayShape( SmallVector tiles_shape = tileArrayImplicitShape(shape, target_shape); // Remove the implicit dimension --- it's always of size 1. 
- switch (implicit_dim_) { - case ImplicitDim::kNone: - break; - case ImplicitDim::kMinor: - tiles_shape.pop_back(); - break; - case ImplicitDim::kSecondMinor: - tiles_shape.erase(tiles_shape.end() - 2); - break; - } + eraseImplicit(tiles_shape); return tiles_shape; } @@ -502,29 +476,17 @@ std::unique_ptr VectorLayout::tileDataBounds( // TODO(apaszke): allow_replicated could have been generalized to specify // what action should be taken when a REPLICATED offset is encountered. // Right now it either disallows replication, or selects the whole dimension. - int64_t s, l; - switch (implicit_dim_) { - case ImplicitDim::kNone: - s = idxs[idxs.size() - 2]; - l = idxs[idxs.size() - 1]; - break; - case ImplicitDim::kMinor: - s = idxs[idxs.size() - 1]; - l = 0; - break; - case ImplicitDim::kSecondMinor: - s = 0; - l = idxs[idxs.size() - 1]; - break; - } - + const std::array tiled_idxs = getImplicitTiledDims(idxs, 0); + const int64_t s = tiled_idxs[0]; + const int64_t l = tiled_idxs[1]; const SmallVector tiles_implicit_shape = tileArrayImplicitShape(full_shape, target_shape); - const int64_t ns = tiles_implicit_shape[tiles_implicit_shape.size() - 2]; - const int64_t nl = tiles_implicit_shape[tiles_implicit_shape.size() - 1]; - const SmallVector implicit_shape = implicitShape(full_shape); - const int64_t is = implicit_shape[implicit_shape.size() - 2]; - const int64_t il = implicit_shape[implicit_shape.size() - 1]; + const int64_t ns = *(tiles_implicit_shape.end() - 2); + const int64_t nl = *(tiles_implicit_shape.end() - 1); + const std::array shape_tiled_dims = + getImplicitTiledDims(full_shape, 1); + const int64_t is = shape_tiled_dims[0]; + const int64_t il = shape_tiled_dims[1]; if (!hasNaturalTopology(target_shape)) { if (!offsets_[0].has_value() || !offsets_[1].has_value()) { @@ -608,30 +570,26 @@ bool VectorLayout::generalizes( } } if (implicit_dim_ != other.implicit_dim_) { - // Don't fail yet! implicit_dim might not matter for some shapes. - if (shape.data() == nullptr) { - return false; - } - // If the second-minor dimension is of size 1, then it does not matter - // whether we have a second minor implicit dim or not. - bool ok = false; - if (((implicit_dim_ == ImplicitDim::kSecondMinor && + // Don't fail yet! + if (tiling_[0] == 1 && other.tiling_[0] == 1 && + ((implicit_dim_ == ImplicitDim::kSecondMinor && other.implicit_dim_ == ImplicitDim::kNone) || - (other.implicit_dim_ == ImplicitDim::kSecondMinor && - implicit_dim_ == ImplicitDim::kNone)) && - shape[shape.size() - 2] == 1) { - ok = true; - } - // If sufficiently many trailing dimensions are of size 1, then it does not - // matter if we use implicit dims to insert more. - int max_rank = std::max(layout_rank(), other.layout_rank()); - CHECK_GE(max_rank, 1); - CHECK_LE(max_rank, 2); - if (*(shape.end() - 1) == 1 && (max_rank == 1 || *(shape.end() - 2) == 1)) { - ok = true; - } - if (!ok) { - return false; + (implicit_dim_ == ImplicitDim::kNone && + other.implicit_dim_ == ImplicitDim::kSecondMinor))) { + // If the tiling is (1, n), we can always squeeze an implicit 2nd minor + // dimension without having to combine vregs. + } else { + if (shape.data() == nullptr) { + return false; + } + // Since we do not reorder axes, if the shapes resulting from inserting + // implicit dimensions are the same in the 2 minormost dimensions for both + // layouts, then the elements must be laid out the same way (before + // tiling). 
+ if (getImplicitTiledDims(shape, 1) != + other.getImplicitTiledDims(shape, 1)) { + return false; + } } } if (tiling_ != other.tiling_) { @@ -640,11 +598,15 @@ bool VectorLayout::generalizes( if (shape.data() == nullptr) { return false; } - const SmallVector ishape = implicitShape(shape); + + // We can assume the implicit shape is the same for both layouts. They are + // only allowed to be different when both tilings are equal to (1, n) (and + // each other), and we've checked that tilings are different above. + const std::array ishape_tiled_dims = + getImplicitTiledDims(shape, 1); if (!(tiling_[1] == other.tiling_[1] && tiling_[1] == target_shape[1] && - offsets_[1].value_or(0) + ishape[ishape.size() - 1] <= - target_shape[1] && - offsets_[0].value_or(0) + ishape[ishape.size() - 2] <= + offsets_[1].value_or(0) + ishape_tiled_dims[1] <= target_shape[1] && + offsets_[0].value_or(0) + ishape_tiled_dims[0] <= std::min(tiling_[0], other.tiling_[0]))) { return false; } @@ -682,26 +644,8 @@ std::optional VectorLayout::join(const VectorLayout& l, if (l.bitwidth_ != r.bitwidth_ || l.tiling_ != r.tiling_) { return std::nullopt; } - if (l.implicit_dim_ != r.implicit_dim_) { - if (shape.size() < 2) { - return std::nullopt; - } - ImplicitDim dim; - if (l.implicit_dim_ == ImplicitDim::kNone) { - dim = r.implicit_dim_; - } else if (r.implicit_dim_ == ImplicitDim::kNone) { - dim = l.implicit_dim_; - } else { - return std::nullopt; - } - if (dim == ImplicitDim::kMinor && shape[shape.size() - 1] == 1) { - // OK, they are equivalent. - } else if (dim == ImplicitDim::kSecondMinor && - shape[shape.size() - 2] == 1) { - // OK, they are equivalent. - } else { - return std::nullopt; - } + if (l.getImplicitTiledDims(shape, 1) != r.getImplicitTiledDims(shape, 1)) { + return std::nullopt; } LayoutOffsets offsets; for (int i = 0; i < 2; ++i) { diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index e23310f968e9..f6903469f0b1 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -139,13 +139,13 @@ class RectangularVregBounds : public VRegDataBounds { // . . . . . . . . . . // . . a b c d e . . . // . . f g h i j . . . -// . . . . . . . . . . +// . . k l m n o . . . // // vreg 3 vreg 4 -// . . k l m n o . . . // . . p q r s t . . . // . . . . . . . . . . // . . . . . . . . . . +// . . . . . . . . . . // // The dot character indicates padding. Nothing should be assumed about the // value of those entries. @@ -207,6 +207,15 @@ class RectangularVregBounds : public VRegDataBounds { // one specified as an attribute. // implicit_dim: If specified, the value has an implicit dim inserted in // either minormost or second minormost position. +// +// Note: There is a special case when VectorLayout is used for an mlir::Value +// of i1 type. In this case, we use it to represent a vmask, which has a smaller +// bitwidth than a vreg. For these types, the packing() is accurate but the +// bitwidth() is a lie, and the i1 value is replicated for every bit. +// For example, if the vmask is 8 x 128 x 4 bits and packing() == 2, each 4-bit +// register contains two logical bool values which are represented as either b11 +// or b00. Its usage is currently limited to MLIR arith.cmp and arith.select ops +// but we might want to split out a separate class if it gets used more widely. 
class VectorLayout { public: enum class ImplicitDim { @@ -236,8 +245,17 @@ class VectorLayout { const std::array<int64_t, 2> &tiling() const { return tiling_; } ImplicitDim implicit_dim() const { return implicit_dim_; } int packing() const { return 32 / bitwidth_; } + int num_implicit_dims() const { + switch (implicit_dim_) { + case ImplicitDim::kNone: + return 0; + case ImplicitDim::kMinor: + case ImplicitDim::kSecondMinor: + return 1; + } + } // The number of minormost dimensions tiled by this layout. - int layout_rank() const { return 1 + (implicit_dim_ == ImplicitDim::kNone); } + int layout_rank() const { return 2 - num_implicit_dims(); } bool operator==(const VectorLayout &other) const; bool operator!=(const VectorLayout &other) const { @@ -268,9 +286,56 @@ class VectorLayout { return {tiling_[0], tilesPerVreg(target_shape) * tiling_[1]}; } + + template <typename T> + void insertImplicit(SmallVector<T> &vec, T value) const { + CHECK_GE(vec.size(), layout_rank()); + switch (implicit_dim_) { + case ImplicitDim::kNone: + break; + case ImplicitDim::kMinor: + case ImplicitDim::kSecondMinor: + vec.insert(vec.end() - (static_cast<int64_t>(implicit_dim_) - 1), + value); + break; + } + } + + template <typename T> + void eraseImplicit(SmallVector<T> &vec) const { + CHECK_GE(vec.size(), 2); + switch (implicit_dim_) { + case ImplicitDim::kNone: + break; + case ImplicitDim::kMinor: + case ImplicitDim::kSecondMinor: + vec.erase(vec.end() - static_cast<int64_t>(implicit_dim_)); + break; + } + } + + // Returns the value of the tiled (2 minormost) dimensions of the given array + // with implicit dims inserted. + // + // Roughly equivalent to the following (but avoids vector allocation): + // + // SmallVector<int64_t> vec = arr; + // insertImplicit(arr, implicit_value); + // return {*(vec.end() - 2), *(vec.end() - 1)}; + std::array<int64_t, 2> getImplicitTiledDims( + const ArrayRef<int64_t> arr, const int64_t implicit_value) const { + CHECK_GE(arr.size(), layout_rank()); + switch (implicit_dim_) { + case ImplicitDim::kNone: + return {*(arr.end() - 2), *(arr.end() - 1)}; + case ImplicitDim::kMinor: + return {*(arr.end() - 1), implicit_value}; + case ImplicitDim::kSecondMinor: + return {implicit_value, *(arr.end() - 1)}; + } + } + SmallVector<int64_t> implicitShape(ArrayRef<int64_t> shape) const; - private: SmallVector<int64_t> tileArrayImplicitShape( ArrayRef<int64_t> shape, std::array<int64_t, 2> target_shape) const; diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 8bd083e8321b..93e4abe3e422 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -31,7 +31,6 @@ def TPU_Dialect : Dialect { let cppNamespace = "::mlir::tpu"; let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; - let usePropertiesForAttributes = 0; } class TPU_Attr traits = []> @@ -98,6 +97,18 @@ def TPU_ContractPrecisionEnum let assemblyFormat = "`<` $value `>`"; } +def TPU_PackFormat : I32EnumAttr<"PackFormat", "Pack format", [ + I32EnumAttrCase<"kCompressed", 0, "compressed">, + I32EnumAttrCase<"kInterleaved", 1, "interleaved"> ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_PackFormatEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + def TPU_TiledCase : I32EnumAttrCase<"tiled", 0>; def TPU_LaneCase : I32EnumAttrCase<"lanes", 1>; def TPU_SublaneCase : I32EnumAttrCase<"sublanes", 2>; @@ -194,6 +205,34 @@ def TPU_LoadOp : TPU_Op<"load"> { }]; } +def TPU_StridedLoadOp : TPU_Op<"strided_load"> { + let arguments = (ins + AnyMemRef:$base, + Variadic<Index>:$indices, + DenseI32ArrayAttr:$strides + ); + let results = (outs AnyVector:$result); + 
let assemblyFormat = [{ + $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) + }]; + let hasVerifier = 1; +} + +def TPU_StridedStoreOp : TPU_Op<"strided_store"> { + let arguments = (ins + AnyVector:$valueToStore, + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore) + }]; + let hasVerifier = 1; +} + +// TODO(jevinjiang): deprecate to use dynamic_rotate. def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { let arguments = (ins AnyVector:$value, @@ -211,6 +250,23 @@ def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { let hasVerifier = 1; } +def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> { + let arguments = (ins + AnyVector:$value, + I32:$amount, + SI32Attr:$dimension, + // When the stride is specified, the rotation amount for each index on the + // stride dimension will be (amount + stride * index). + OptionalAttr:$stride, + OptionalAttr:$stride_dimension + ); + let results = (outs AnyVector:$result); + let assemblyFormat = [{ + $value `by` $amount `dim` $dimension attr-dict `:` type($value) `,` type($amount) `->` type($result) + }]; + let hasVerifier = 1; +} + def TPU_IotaOp : TPU_Op<"iota", [Pure]> { let arguments = (ins OptionalAttr:$dimension); let results = (outs AnyVector:$output); @@ -252,7 +308,10 @@ def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> { // Integer packs are always signed at the moment. def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure]> { - let arguments = (ins Variadic:$sources); + let arguments = (ins + Variadic:$sources, + TPU_PackFormatEnum:$pack_format + ); let results = (outs AnyVector:$output); let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }]; } @@ -351,6 +410,35 @@ def TPU_CreateMaskOp : TPU_Op<"create_mask", [Pure, SameVariadicOperandSize]> { } def TPU_CreateSubelementMaskOp : TPU_Op<"create_subelement_mask", [Pure]> { + let summary = "Create a mask masking contiguous rows of subelements."; + // TODO(tlongeri): Why don't we just get `num_subelems` from the result type? + // Taking a parameter and allowing a mismatch is confusing. + let description = [{ + The "half-sublanes", "quarter-sublanes", etc. (unit is determined by + `num_subelems`) of the mask are masked in the range specified by `from` and + `to`. + + - If `from <= to`, the range `[from, to)` is set and the rest is unset. + - If `to <= from`, the range `[to, from)` is unset and the rest is set. + + All lanes are set identically. + + Example: + + ```mlir + %msk = tpu.create_subelement_mask 3, 9, 2 : vector<8x128x2xi1> + ``` + + This creates a mask `%msk` where, for all `lane`s, `%msk[*][lane][*]` is: + + ``` + [[0, 0], [0, 1], [1, 1], [1, 1], [1, 0], [0, 0], [0, 0], [0, 0]] + ``` + + It is currently only supported: + - In TPU v4, for `num_subelems` of 1 and 2. + - In TPU v5, for `num_subelems` of 1, 2, and 4. 
+ }]; let arguments = (ins I32Attr:$from, // inclusive I32Attr:$to, // exclusive @@ -370,14 +458,16 @@ def TPU_AssumeMultipleOp : TPU_Op<"assume_multiple", [Pure, SameOperandsAndResul let results = (outs AnyTypeOf<[Index, AnyInteger]>:$result); } -def TPU_MemRefSliceOp : TPU_Op<"memref_slice", [Pure]> { +def TPU_MemRefSliceOp : TPU_Op<"memref_slice", [Pure, AttrSizedOperandSegments]> { let arguments = (ins AnyMemRef:$mem_ref, - Variadic:$base_idx + Variadic:$base_idx, + Variadic:$dynamic_sizes ); let results = (outs AnyMemRef:$result); let assemblyFormat = [{ - $mem_ref `[` $base_idx `]` attr-dict `:` type($mem_ref) `->` type($result) + $mem_ref `[` $base_idx `]` (`<` $dynamic_sizes^ `>`)? + attr-dict `:` type($mem_ref) `->` type($result) }]; let hasVerifier = 1; let hasCanonicalizeMethod = 1; @@ -424,6 +514,12 @@ def TPU_DeviceIdOp : TPU_Op<"device_id", [Pure]> { let assemblyFormat = [{ attr-dict `:` type($result) }]; } +def TPU_SemaphoreReadOp : TPU_Op<"sem_read"> { + let arguments = (ins MemRefOf<[TPU_SemaphoreType]>:$semaphore); + let results = (outs I32:$result); + let assemblyFormat = [{ $semaphore attr-dict `:` type($semaphore) `->` type($result)}]; +} + def TPU_SemaphoreWaitOp : TPU_Op<"sem_wait"> { let arguments = (ins MemRefOf<[TPU_SemaphoreType]>:$semaphore, @@ -446,25 +542,27 @@ def TPU_GetBarrierSemaphoreOp : TPU_Op<"sem_barrier"> { let hasVerifier = 1; } -def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal"> { +def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { let arguments = (ins MemRefOf<[TPU_SemaphoreType]>:$semaphore, I32:$amount, - Optional:$device_id // For remote DMAs + Optional:$device_id, // For remote DMAs + Optional:$core_id // For megacore ); let assemblyFormat = [{ - $semaphore `,` $amount (`,` $device_id^)? attr-dict `:` type($semaphore) + $semaphore `,` $amount (`,` $device_id^)? (`,` $core_id^)? attr-dict `:` type($semaphore) }]; let hasVerifier = 1; } -def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [SameVariadicOperandSize]> { +def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> { let arguments = (ins AnyMemRef:$source, Optional>:$source_semaphore, // For remote DMAs AnyMemRef:$target, MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore, - Optional:$device_id // For remote DMAs + Optional:$device_id, // For remote DMAs + Optional:$core_id // For megacore ); let hasVerifier = 1; } @@ -504,6 +602,11 @@ def TPU_YieldOp : TPU_Op<"yield", [Pure, ReturnLike, Terminator]> { let assemblyFormat = [{ attr-dict ($results^ `:` type($results))? }]; } +def TPU_DelayOp : TPU_Op<"delay"> { + let arguments = (ins I32:$nanos); + let results = (outs); +} + // Expands the granularity of mask to subelements. 
def TPU_MaskCastOp : TPU_Op<"mask_cast", [Pure]> { let arguments = (ins AnyVector:$input); @@ -520,6 +623,32 @@ def TPU_GetIterationBoundOp : TPU_Op<"iteration_bound"> { let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; } +def TPU_GetInternalScratchOp : TPU_Op<"internal_scratch"> { + let arguments = (ins); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ attr-dict `:` type($result) }]; +} + +def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> { + let arguments = (ins Variadic:$seeds); + let results = (outs); +} + +def TPU_PRNGRandomBitsOp : TPU_Op<"prng_random_bits"> { + let arguments = (ins); + let results = (outs AnyVector:$output); +} + +def TPU_LogOp : TPU_Op<"log"> { + let arguments = (ins + Variadic:$inputs, + StrAttr:$tag, + DefaultValuedAttr:$formatted + ); + let results = (outs); + let assemblyFormat = [{ $tag attr-dict (`:` `[` $inputs^ `]` `:` type($inputs))? }]; +} + def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::FuncOp"> { let dependentDialects = [ "::mlir::func::FuncDialect", @@ -590,6 +719,7 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">, Option<"mxu_contracting_size", "mxu-contracting-size", "int", /*default=*/"128", "">, Option<"mxu_noncontracting_size", "mxu-noncontracting-size", "int", /*default=*/"128", "">, + Option<"max_sublanes_in_scratch", "max-sublanes-in-scratch", "int", /*default=*/"0", "">, ]; } @@ -602,7 +732,11 @@ def LinalgVectorizationPass : Pass<"linalg-vectorization", "::mlir::func::FuncOp "::mlir::vector::VectorDialect", "::mlir::tpu::TPUDialect", ]; - let constructor = "::mlir::tpu::createLinalgVectorizationPass()"; + let constructor = "::mlir::tpu::createLinalgVectorizationPass(false)"; + let options = [ + Option<"supports_bf16_alu_instructions", "supports-bf16-alu-instructions", "bool", "", "">, + Option<"supports_bf16_matmul", "supports-bf16-matmul", "bool", "", "">, + ]; } #endif // TPU_ATTRS diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index abc7cc595cd6..df00093fabe6 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -189,6 +189,9 @@ bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) { if (fuel <= 0) { return false; } + if (divisor == 1) { + return true; + } if (auto assume_op = value.getDefiningOp()) { return assume_op.getMultiple() % divisor == 0; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index c55268f673d6..dc5b68246e3f 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -58,12 +58,15 @@ std::unique_ptr> createInferVectorLayoutPass( std::unique_ptr> createApplyVectorLayoutPass( int hardware_generation = -1, int lane_count = 128, int sublane_count = 8, - int mxu_contracting_size = 128, int mxu_noncontracting_size = 128); + int mxu_contracting_size = 128, int mxu_noncontracting_size = 128, + int max_sublanes_in_scratch = 0); std::unique_ptr> createLogicalToPhysicalDeviceIdPass(int64_t total_devices); -std::unique_ptr> createLinalgVectorizationPass(); +std::unique_ptr> createLinalgVectorizationPass( + bool supports_bf16_alu_instructions = false, + bool supports_bf16_matmul = false); std::unique_ptr> createDebugAssertInsertionPass(); diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index be22d809c3ee..8cea43739719 100644 
--- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/include/mlir/IR/BuiltinTypes.h" #include "mlir/include/mlir/IR/IRMapping.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" @@ -65,7 +66,8 @@ LogicalResult MemRefSliceOp::verify() { (target_memory_space == nullptr || target_memory_space == source_type.getMemorySpace()) && ((isa(target_layout) && target_layout.isIdentity()) || - target_type.getLayout() == source_type.getLayout())); + target_type.getLayout() == source_type.getLayout()) && + getDynamicSizes().size() == target_type.getNumDynamicDims()); } LogicalResult MemRefSliceOp::canonicalize(MemRefSliceOp op, @@ -81,8 +83,9 @@ LogicalResult MemRefSliceOp::canonicalize(MemRefSliceOp op, auto new_result_type = MemRefType::get( op.getResult().getType().getShape(), layout_ty.getElementType(), layout_ty.getLayout(), layout_ty.getMemorySpace()); - auto slice = rewriter.create(op.getLoc(), new_result_type, - layout_ref, op.getBaseIdx()); + auto slice = + rewriter.create(op.getLoc(), new_result_type, layout_ref, + op.getBaseIdx(), op.getDynamicSizes()); rewriter.replaceOpWithNewOp(op, op.getType(), slice); return success(); } @@ -180,6 +183,45 @@ LogicalResult MemRefSqueezeOp::canonicalize(MemRefSqueezeOp op, return success(); } +template +LogicalResult verifyStridedOp(Op op, MemRefType memref_ty, + VectorType vector_ty) { + auto indices = op.getIndices(); + auto strides = op.getStrides(); + if (memref_ty.getRank() != indices.size()) { + op.emitError("Base memref's rank and indices size do not match: ") + << memref_ty.getRank() << " vs " << indices.size(); + return failure(); + } + if (memref_ty.getRank() != strides.size()) { + op.emitError("Base memref's rank and strides size do not match: ") + << memref_ty.getRank() << " vs " << strides.size(); + return failure(); + } + if (memref_ty.getRank() != vector_ty.getRank()) { + op.emitError("Base memref's rank and result's rank do not match: ") + << memref_ty.getRank() << " vs " << vector_ty.getRank(); + return failure(); + } + for (int64_t i = 0; i < memref_ty.getRank(); ++i) { + if (strides[i] < 1) { + op.emitError("Strides[") << i << "]=" << strides[i] << " must be >= 1"; + return failure(); + } + } + return success(); +} + +LogicalResult StridedLoadOp::verify() { + return verifyStridedOp(*this, getMemRefType(getBase()), + getType()); +} + +LogicalResult StridedStoreOp::verify() { + return verifyStridedOp(*this, getMemRefType(getBase()), + getValueToStore().getType()); +} + LogicalResult ReinterpretCastOp::verify() { auto source_type = getMemRefType(getInput()); auto target_type = getType(); @@ -188,28 +230,26 @@ LogicalResult ReinterpretCastOp::verify() { source_type.getMemorySpace() == target_type.getMemorySpace()); } -LogicalResult RotateOp::verify() { - auto vty = getResult().getType(); - if (vty.getRank() <= getDimension() || getDimension() < 0) { - emitOpError("Invalid dimension: ") << getDimension(); +template +LogicalResult verifyRotateOp(Op op) { + auto vty = op.getResult().getType(); + if (vty.getRank() <= op.getDimension() || op.getDimension() < 0) { + op.emitOpError("Invalid dimension: ") << op.getDimension(); return failure(); } - if (getAmount() < 0) { - emitOpError("Rotate amount must be >= 0"); + if (op.getStride().has_value() && op.getStride().value() < 0) { + op.emitOpError("Rotate stride must be >= 0 if it is specified"); return 
failure(); } - if (getStride().has_value() && getStride().value() < 0) { - emitOpError("Rotate stride must be >= 0 if it is specified"); + if (op.getStrideDimension().has_value() && + (vty.getRank() <= op.getStrideDimension().value() || + op.getStrideDimension().value() < 0)) { + op.emitOpError("Invalid stride dimension: ") + << op.getStrideDimension().value(); return failure(); } - if (getStrideDimension().has_value() && - (vty.getRank() <= getStrideDimension().value() || - getStrideDimension().value() < 0)) { - emitOpError("Invalid stride dimension: ") << getStrideDimension().value(); - return failure(); - } - if (getStride().has_value() != getStrideDimension().has_value()) { - emitOpError( + if (op.getStride().has_value() != op.getStrideDimension().has_value()) { + op.emitOpError( "Expected either none or both stride and stride dimension are " "present"); return failure(); @@ -217,6 +257,13 @@ LogicalResult RotateOp::verify() { return success(); } +// TODO(b/347016737): deprecate static rotate +LogicalResult RotateOp::verify() { return verifyRotateOp(*this); } + +LogicalResult DynamicRotateOp::verify() { + return verifyRotateOp(*this); +} + // a + matmul(l, r, 0) == matmul(l, r, a) template class CanonicalizeAddOfMatmul : public OpRewritePattern { @@ -275,8 +322,7 @@ LogicalResult GetBarrierSemaphoreOp::verify() { LogicalResult SemaphoreSignalOp::verify() { auto sem_type = getMemRefType(getSemaphore()); if (sem_type.getRank() != 0) { - emitOpError("Semaphore reference must be rank 0"); - return failure(); + return emitOpError("Semaphore reference must be rank 0"); } return success(); } @@ -286,14 +332,19 @@ LogicalResult EnqueueDMAOp::verify() { if (source_sem) { auto source_sem_type = getMemRefType(getSourceSemaphore()); if (source_sem_type.getRank() != 0) { - emitOpError("DMA source semaphore reference must be rank 0"); - return failure(); + return emitOpError("DMA source semaphore reference must be rank 0"); } } auto target_sem_type = getMemRefType(getTargetSemaphore()); if (target_sem_type.getRank() != 0) { - emitOpError("DMA target semaphore must be rank 0"); - return failure(); + return emitOpError("DMA target semaphore must be rank 0"); + } + if (getDeviceId() || getCoreId()) { + if (!getSourceSemaphore()) { + return emitOpError( + "DMA source semaphore must be specified when " + "device_id or core_id is specified"); + } } return success(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index d05ca3024e13..f0c568aeea48 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -7,10 +7,10 @@ #include #include #include +#include #include #include #include -#include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" @@ -110,6 +110,14 @@ namespace mlir::tpu { #define TPU_ASSERT_LE_LOC(loc, lhs, rhs) \ TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, <=) +// The minimum bound required to rotate with scratch space. The bound refers to +// the number of VREGs on rotation dim. This number was concluded from some cost +// analysis for comparing different dynamic rotation implementations. If +// actual bound is greater than this, dynamic rotation with internal scratch +// space is more efficient. +// TODO(jevinjiang): need to update it based on the generation. 
+static constexpr int kMinBoundToRotateWithScratch = 27; + LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block); namespace { @@ -157,6 +165,36 @@ FailureOr maskOOB(RewriteContext &ctx, OpBuilder &builder, .getResult(); } +// Get the address of pre-allocated internal scratch space with requested shape. +// +// Arguments: +// shape: The shape of the requested scratch space. +// elem_ty: The type of the elements in the requested scratch space. +// +// Returns: +// A memref of the requested shape and type. +FailureOr getInternalScratch(RewriteContext &ctx, OpBuilder &builder, + Location loc, ArrayRef shape, + Type elem_ty) { + if (shape.empty()) { + return failure(); + } + if (shape.back() % ctx.target_shape[1] != 0) { + return failure(); + } + int sublane_count = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) / + ctx.target_shape[1]; + if (sublane_count > ctx.max_sublanes_in_scratch) { + return failure(); + } + FAILUREOR_ASSIGN_OR_RETURN( + MemRefType scratch_ref_ty, + inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation)); + return builder.create(loc, scratch_ref_ty) + .getResult(); +} + // Models Numpy's np.repeat, repeating each element `repeats` times along the // specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is // 3, this will return [1, 1, 1, 2, 2, 2]. @@ -204,6 +242,21 @@ xla::Array concatenate(const ArrayRef> arrays, return res; } +SmallVector> split(const xla::Array &vregs, int axis) { + CHECK(axis >= 0 && axis < vregs.num_dimensions()); + SmallVector> chunks; + chunks.reserve(vregs.dim(axis)); + SmallVector starts(vregs.num_dimensions(), 0); + SmallVector limits(vregs.dimensions().begin(), + vregs.dimensions().end()); + for (int64_t i = 0; i < vregs.dim(axis); ++i) { + starts[axis] = i; + limits[axis] = i + 1; + chunks.push_back(vregs.Slice(starts, limits)); + } + return chunks; +}; + template ArrayRef XlaArrayToFlatArrayRef(xla::Array xla_array) { return ArrayRef(xla_array.data(), xla_array.num_elements()); @@ -249,7 +302,7 @@ bool incrementIndex(const MutableArrayRef idx, } bool sliceIsEmpty(const absl::Span starts, - const absl::Span limits) { + const absl::Span limits) { for (auto [s, l] : llvm::zip_equal(starts, limits)) { CHECK_LE(s, l); if (s == l) { @@ -282,9 +335,19 @@ void updateSliceFromRange(xla::Array &arr, Range data, return; } SmallVector idx(toArrayRef(starts)); + auto in_bounds = [&] { + for (int64_t i = 0; i < idx.size(); ++i) { + if (idx[i] >= arr.dim(i)) { + return false; + } + } + return true; + }; auto data_it = data.begin(); do { - arr(idx) = *data_it; + if (in_bounds()) { + arr(idx) = *data_it; + } ++data_it; } while (incrementSliceIndex(idx, starts, limits)); CHECK(data_it == data.end()); @@ -367,7 +430,7 @@ FailureOr>> sliceRef( Value sliced_ref = builder.create( MemRefType::get(pad_slice_shape, ref_ty.getElementType(), ref_ty.getLayout(), ref_ty.getMemorySpace()), - base_ref, slice_base_indices); + base_ref, slice_base_indices, /*dynamic_sizes=*/ValueRange()); return std::make_pair(sliced_ref, indices_within_slice); } @@ -489,18 +552,34 @@ FailureOr appendConstant(RewriteContext &ctx, return argument; } -FailureOr getNativeVregType( - Type elem_ty, const std::array target_shape) { - FAILUREOR_ASSIGN_OR_RETURN(const int8_t bitwidth, - getTypeBitwidth(elem_ty)); +FailureOr getNativeVregOrVmaskTypeImpl( + Type elem_ty, const int8_t bitwidth, + const std::array target_shape) { if (bitwidth == 32) { return VectorType::get(target_shape, elem_ty); } - // bitwidth != 32 return 
VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth}, elem_ty); } +FailureOr getNativeVregOrVmaskType( + Type elem_ty, const int8_t layout_bitwidth, + const std::array target_shape) { + int8_t bitwidth = elem_ty.getIntOrFloatBitWidth(); + if (bitwidth == 1) { + bitwidth = layout_bitwidth; + } else { + CHECK_EQ(bitwidth, layout_bitwidth); + } + return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape); +} + +FailureOr getNativeVregType( + Type elem_ty, const std::array target_shape) { + return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(), + target_shape); +} + // Returns empty vector on null attribute FailureOr> getLayoutArrayFromAttr(const Attribute attr) { if (const auto array_attr = dyn_cast_if_present(attr)) { @@ -522,8 +601,21 @@ bool layoutIsValidForValue(const Layout &l, const Value v, const std::array target_shape) { // l must be non-null iff v is of vector type if (const auto vty = dyn_cast(v.getType())) { - return l.has_value() && l->isValid(target_shape) && - l->layout_rank() <= vty.getRank(); + if (!l.has_value()) { + return false; + } + + // Vector type should have the same bitwidth as the layout, except for the + // i1 special case, used for vmasks (see comment for VectorLayout class). + if (!vty.getElementType().isIntOrFloat()) { + return false; + } + const int8_t bitwidth = vty.getElementTypeBitWidth(); + if (bitwidth != l->bitwidth() && bitwidth != 1) { + return false; + } + + return l->isValid(target_shape) && l->layout_rank() <= vty.getRank(); } return !l.has_value(); } @@ -600,7 +692,8 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op, FAILUREOR_ASSIGN_OR_RETURN( const VectorType out_vreg_ty, - getNativeVregType(out_ty.getElementType(), ctx.target_shape)); + getNativeVregOrVmaskType(out_ty.getElementType(), layout_out.bitwidth(), + ctx.target_shape)); NamedAttrList attributes(op.getAttrDictionary()); attributes.erase("in_layout"); @@ -658,69 +751,62 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, const auto result_ty = cast(op.getResult().getType()); auto source = cast>(op.getIn()); const auto source_ty = source.getType(); + auto output_vregs_shape = + layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape); if (layout_out.bitwidth() != 32) { return op.emitOpError( "Not implemented: Only extensions to 32-bit supported"); } FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array input_vregs, + xla::Array input_vregs, disassemble(builder, layout_in, source, ctx.target_shape)); - xla::Array output_vregs( - layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape)); + xla::Array output_vregs(output_vregs_shape); + // TODO(jevinjiang): maybe just use tileArrayImplicitShape in disassemble? 
+ if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) { + input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(), + ctx.target_shape)); + output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(), + ctx.target_shape)); + } FAILUREOR_ASSIGN_OR_RETURN( const VectorType res_vreg_ty, getNativeVregType(result_ty.getElementType(), ctx.target_shape)); if (layout_in.implicit_dim() != layout_out.implicit_dim()) { - return op.emitOpError("Not implemented: Change of layout during the cast"); + return op.emitOpError( + "Not implemented: Change of implicit dim during the cast"); } if (layout_in.offsets() != layout_out.offsets()) { return op.emitOpError("Not implemented: Change of offsets during the cast"); } - switch (layout_in.implicit_dim()) { - case VectorLayout::ImplicitDim::kNone: { - if (layout_in.tiling() != layout_out.tiling()) { - return op.emitOpError( - "Not implemented: Changing tiling during the cast"); - } - auto tiling = layout_in.tiling(); - if (ctx.target_shape[0] % tiling[0] != 0 || - ctx.target_shape[1] != tiling[1]) { - return op.emitOpError("Not implemented: tiling not supported"); - } - const int packing = layout_in.packing(); - output_vregs.Each([&](absl::Span idxs, Value *v) { - SmallVector input_vreg_idxs(toArrayRef(idxs)); - input_vreg_idxs.back() /= packing; - const int64_t vreg_part = idxs.back() % packing; - *v = builder.create( - res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); - }); - } break; - case VectorLayout::ImplicitDim::kMinor: - return op.emitOpError( - "Not implemented: Only casts of lane-oriented values supported"); - case VectorLayout::ImplicitDim::kSecondMinor: { - auto is_one_tile = [](VectorType vty, VectorLayout layout) { - auto implicit_shape = layout.implicitShape(vty.getShape()); - auto tiled_shape = ArrayRef(implicit_shape).take_back(2); - return (layout.offsets()[0].value_or(0) + tiled_shape[0] <= - layout.tiling()[0]) && - (layout.offsets()[1].value_or(0) + tiled_shape[1] <= - layout.tiling()[1]); - }; - if (input_vregs.dimensions() != absl::Span{1} || - output_vregs.dimensions() != absl::Span{1} || - !is_one_tile(source_ty, layout_in) || - !is_one_tile(result_ty, layout_out)) { - return op.emitOpError("Not implemented"); - } - if (layout_in.offsets()[0] >= ctx.target_shape[0]) { - return op.emitOpError("Not implemented"); - } - auto unpack_subelements_op = builder.create( - res_vreg_ty, *input_vregs.begin(), 0); - output_vregs.Fill(unpack_subelements_op.getResult()); + const int packing = layout_in.packing(); + if (layout_in.hasNativeTiling(ctx.target_shape) && + layout_out.hasNativeTiling(ctx.target_shape)) { + output_vregs.Each([&](absl::Span idxs, Value *v) { + SmallVector input_vreg_idxs(toArrayRef(idxs)); + int64_t vreg_part = *(input_vreg_idxs.end() - 2) % packing; + *(input_vreg_idxs.end() - 2) /= packing; + *v = builder.create( + res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); + }); + } else { + if (layout_in.tiling() != layout_out.tiling()) { + return op.emitOpError("Not implemented: Changing tiling during the cast"); } + auto tiling = layout_in.tiling(); + if (ctx.target_shape[0] % tiling[0] != 0 || + ctx.target_shape[1] != tiling[1]) { + return op.emitOpError("Not implemented: tiling not supported"); + } + output_vregs.Each([&](absl::Span idxs, Value *v) { + SmallVector input_vreg_idxs(toArrayRef(idxs)); + input_vreg_idxs.back() /= packing; + const int64_t vreg_part = idxs.back() % packing; + *v = builder.create( + res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); + }); + 
} + if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) { + output_vregs.Reshape(output_vregs_shape); } op.replaceAllUsesWith(assemble(builder, result_ty, layout_out, std::move(output_vregs), ctx.target_shape) @@ -762,73 +848,87 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op, const VectorLayout &layout_in, const VectorLayout &layout_out) { ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); + auto source = cast>(op.getIn()); + const auto source_ty = source.getType(); auto result_ty = cast(op.getResult().getType()); + auto output_vregs_shape = + layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape); FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array input_vregs, - disassemble(builder, layout_in, cast>(op.getIn()), - ctx.target_shape)); - xla::Array output_vregs( - layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape)); + xla::Array input_vregs, + disassemble(builder, layout_in, source, ctx.target_shape)); + xla::Array output_vregs(output_vregs_shape); if (layout_in.bitwidth() != 32) { return op.emitOpError("Not implemented: Only 32-bit truncation supported"); } + if (layout_in.offsets() != layout_out.offsets()) { + return op.emitOpError( + "Not implemented: Change of offsets during the truncation"); + } + if (layout_in.implicit_dim() != layout_out.implicit_dim()) { + return op.emitOpError("Not implemented: Change of layout during the cast"); + } + if (layout_in.tiling() != ctx.target_shape) { + return op.emitOpError("Not implemented: Only (8,128) tiling supported"); + } + if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) { + input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(), + ctx.target_shape)); + output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(), + ctx.target_shape)); + } FAILUREOR_ASSIGN_OR_RETURN( VectorType res_vreg_ty, getNativeVregType(result_ty.getElementType(), ctx.target_shape)); - if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && - layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone) { - if (layout_in.tiling() != ctx.target_shape) { - return op.emitOpError("Not implemented: Only (8,128) tiling supported"); - } - if (layout_out.tiling() == ctx.target_shape) { - const int packing = layout_out.packing(); - output_vregs.Each([&](absl::Span idxs, Value *v) { - SmallVector parts; - SmallVector idxs_local(toArrayRef(idxs)); - idxs_local.back() *= packing; - for (int64_t i = 0; i < packing; ++i) { - parts.push_back(input_vregs(idxs_local)); - // Pack any data lying around if OOB - if (idxs_local.back() < input_vregs.dimensions().back() - 1) { - ++idxs_local.back(); - } - } - *v = builder.create(res_vreg_ty, parts); - }); - - } else if (layout_out.hasNativeTiling(ctx.target_shape)) { - int packing = layout_out.packing(); + if (layout_out.tiling() == ctx.target_shape) { + const int packing = layout_out.packing(); + output_vregs.Each([&](absl::Span idxs, Value *v) { SmallVector parts; - parts.reserve(packing); - output_vregs.Each([&](absl::Span idxs, Value *v) { - CHECK_GE(idxs.size(), 2); - SmallVector idxs_local(toArrayRef(idxs)); - idxs_local[idxs.size() - 2] *= packing; + SmallVector idxs_local(toArrayRef(idxs)); + idxs_local.back() *= packing; + for (int64_t i = 0; i < packing; ++i) { parts.push_back(input_vregs(idxs_local)); - idxs_local[idxs.size() - 2]++; - while (parts.size() < packing) { - if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) { - parts.push_back(input_vregs(idxs_local)); - idxs_local[idxs.size() 
- 2]++; - } else { - // Once we run out of tiles, we can pick any one we like. - parts.push_back(parts.back()); - } + // Pack any data lying around if OOB + if (idxs_local.back() < input_vregs.dimensions().back() - 1) { + ++idxs_local.back(); } - *v = builder.create(res_vreg_ty, parts); - parts.clear(); - }); - } else { - return op.emitOpError("Not implemented: unsupported output tiling"); - } - op.replaceAllUsesWith(assemble(builder, result_ty, layout_out, - std::move(output_vregs), ctx.target_shape) - .getResult()); - op.erase(); - return success(); + } + *v = builder.create(res_vreg_ty, parts, + tpu::PackFormat::kCompressed); + }); + } else if (layout_out.hasNativeTiling(ctx.target_shape)) { + int packing = layout_out.packing(); + SmallVector parts; + parts.reserve(packing); + output_vregs.Each([&](absl::Span idxs, Value *v) { + CHECK_GE(idxs.size(), 2); + SmallVector idxs_local(toArrayRef(idxs)); + idxs_local[idxs.size() - 2] *= packing; + parts.push_back(input_vregs(idxs_local)); + idxs_local[idxs.size() - 2]++; + while (parts.size() < packing) { + if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) { + parts.push_back(input_vregs(idxs_local)); + idxs_local[idxs.size() - 2]++; + } else { + // Once we run out of tiles, we can pick any one we like. + parts.push_back(parts.back()); + } + } + *v = builder.create(res_vreg_ty, parts, + tpu::PackFormat::kCompressed); + parts.clear(); + }); + } else { + return op.emitOpError("Not implemented: unsupported output tiling"); + } + if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) { + output_vregs.Reshape(output_vregs_shape); } - // TODO(tlongeri): why wasn't this part of the original code? - return op.emitOpError("Not implemented"); + op.replaceAllUsesWith(assemble(builder, result_ty, layout_out, + std::move(output_vregs), ctx.target_shape) + .getResult()); + op.erase(); + return success(); } LogicalResult arith_truncf_rule(RewriteContext &ctx, Operation &op, @@ -840,9 +940,10 @@ LogicalResult arith_truncf_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_OP(layouts_out.front().has_value()); auto truncf_op = cast(op); if (layouts_in.front()->bitwidth() != 32 || - layouts_out.front()->bitwidth() != 16) { + (layouts_out.front()->bitwidth() != 16 && + layouts_out.front()->bitwidth() != 8)) { return op.emitOpError( - "Not implemented: Only 32-bit to 16-bit conversion supported"); + "Not implemented: Only 32-bit to 16-or-8-bit conversion supported"); } return trunc_op_rule_impl(ctx, truncf_op, *layouts_in.front(), *layouts_out.front()); @@ -878,16 +979,35 @@ LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op, scf::ForOp for_op = cast(op); TPU_ASSERT_EQ_OP(layouts_in.size(), for_op->getNumOperands()); TPU_ASSERT_EQ_OP(layouts_out.size(), for_op->getNumResults()); - if (!llvm::equal(layouts_in.drop_front(3), layouts_out)) { - return op.emitOpError( - "Expected matched layouts in scf.for's inputs and outputs"); - } FAILUREOR_ASSIGN_OR_RETURN( const SmallVector yield_in_layouts, getInLayouts(*for_op.getBody()->getTerminator(), ctx.target_shape)); - if (!llvm::equal(ArrayRef(yield_in_layouts), layouts_out)) { - return op.emitOpError( - "Expected matched layouts in scf.yield operands and scf.for's results"); + int out_idx = 0; + for (auto [in_layout, yield_layout, out_layout, result] : + llvm::zip_equal(layouts_in.drop_front(3), yield_in_layouts, layouts_out, + op.getResults())) { + if (auto vty = dyn_cast(result.getType())) { + TPU_ASSERT_OP(in_layout.has_value()); + TPU_ASSERT_OP(yield_layout.has_value()); + 
TPU_ASSERT_OP(out_layout.has_value()); + if (in_layout.value() != yield_layout.value()) { + return op.emitOpError( + "Not implemented: for loop input layout does not match with " + "yield layout ") + << out_idx; + } + if (in_layout.value() != out_layout.value()) { + return op.emitOpError( + "Not implemented: for loop input layout does not match with " + "out layout ") + << out_idx; + } + } else { + TPU_ASSERT_EQ_OP(in_layout, kNoLayout); + TPU_ASSERT_EQ_OP(yield_layout, kNoLayout); + TPU_ASSERT_EQ_OP(out_layout, kNoLayout); + } + ++out_idx; } if (failed(applyLayoutBlock(ctx, *for_op.getBody()))) { @@ -1000,6 +1120,208 @@ LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op, return success(); } +LogicalResult scf_while_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + scf::WhileOp while_op = cast(op); + TPU_ASSERT_EQ_OP(layouts_in.size(), while_op->getNumOperands()); + TPU_ASSERT_EQ_OP(layouts_out.size(), while_op->getNumResults()); + TPU_ASSERT_EQ_OP(layouts_in.size(), layouts_out.size()); + + // The terminator for the before region is the condition op. + // It takes multiple arguments -- the first being the decision to execute the + // after region or branch to the exit. + FAILUREOR_ASSIGN_OR_RETURN( + const SmallVector cond_in_layouts, + getInLayouts(*while_op.getBeforeBody()->getTerminator(), + ctx.target_shape)); + + FAILUREOR_ASSIGN_OR_RETURN( + const SmallVector yield_in_layouts, + getInLayouts(*while_op.getYieldOp(), ctx.target_shape)); + int out_idx = 0; + for (auto [in_layout, cond_layout, yield_layout, out_layout, result] : + llvm::zip_equal(layouts_in, + ArrayRef(cond_in_layouts).drop_front(1), + yield_in_layouts, layouts_out, op.getResults())) { + if (auto vty = dyn_cast(result.getType())) { + TPU_ASSERT_OP(in_layout.has_value()); + TPU_ASSERT_OP(yield_layout.has_value()); + TPU_ASSERT_OP(out_layout.has_value()); + if (in_layout.value() != cond_layout.value()) { + return op.emitOpError( + "Not implemented: while loop input layout does not match " + "with condition layout ") + << out_idx; + } + if (in_layout.value() != yield_layout.value()) { + return op.emitOpError( + "Not implemented: while loop input layout does not match " + "with yield layout ") + << out_idx; + } + if (in_layout.value() != out_layout.value()) { + return op.emitOpError( + "Not implemented: while loop input layout does not match " + "with output layout ") + << out_idx; + } + } else { + TPU_ASSERT_EQ_OP(in_layout, kNoLayout); + TPU_ASSERT_EQ_OP(cond_layout, kNoLayout); + TPU_ASSERT_EQ_OP(yield_layout, kNoLayout); + TPU_ASSERT_EQ_OP(out_layout, kNoLayout); + } + ++out_idx; + } + + if (failed(applyLayoutBlock(ctx, *while_op.getBeforeBody()))) { + return failure(); + } + + if (failed(applyLayoutBlock(ctx, *while_op.getAfterBody()))) { + return failure(); + } + + if (op.getNumResults() == 0) { + return success(); + } + + OpBuilder builder(&op); + SmallVector unrolled_args; + for (int i = 0; i < layouts_in.size(); ++i) { + auto layout = layouts_in[i]; + auto operand = while_op.getOperand(i); + if (auto vector_operand = dyn_cast>(operand)) { + if (!layout.has_value()) { + return op.emitOpError("Expected layout for vector operand"); + } + FAILUREOR_ASSIGN_OR_RETURN( + const xla::Array tiles, + disassemble(builder, *layout, vector_operand, ctx.target_shape)); + unrolled_args.append(tiles.begin(), tiles.end()); + } else { + if (layout.has_value()) { + return op.emitOpError("Expected no layout for scalar operand"); + } + unrolled_args.push_back(operand); + } 
+ } + + // Create a new scf::WhileOp with unrolled args. + auto new_op = builder.create( + while_op->getLoc(), + TypeRange(while_op.getConditionOp().getOperands().drop_front(1)), + unrolled_args, nullptr, nullptr); + + const auto tile_body_args = [&](::mlir::Block *old_body, + ::mlir::Block *new_body, + const ArrayRef layouts) { + TPU_ASSERT_OP(old_body != nullptr); + TPU_ASSERT_OP(new_body != nullptr); + int num_old_args = old_body->getNumArguments(); + SmallVector locs(new_body->getNumArguments(), while_op.getLoc()); + old_body->addArguments(TypeRange(new_body->getArguments()), locs); + builder.setInsertionPointToStart(old_body); + auto arg_idx = num_old_args; + for (auto [old_arg, layout] : llvm::zip_equal( + old_body->getArguments().take_front(num_old_args), layouts)) { + if (const auto vty = dyn_cast(old_arg.getType())) { + TPU_ASSERT_OP(layout.has_value()); + const SmallVector tiles_shape = + layout->tileArrayShape(vty.getShape(), ctx.target_shape); + const int64_t num_vectors = ShapedType::getNumElements(tiles_shape); + xla::Array tiles(tiles_shape); + TPU_ASSERT_LE_OP(arg_idx + num_vectors, old_body->getNumArguments()); + tiles.SetValues(llvm::make_range( + old_body->getArguments().begin() + arg_idx, + old_body->getArguments().begin() + arg_idx + num_vectors)); + arg_idx += num_vectors; + RollVectorsOp rolled_op = + assemble(builder, vty, *layout, tiles, ctx.target_shape); + old_arg.replaceUsesWithIf(rolled_op, [&](OpOperand &operand) { + return operand.getOwner() != rolled_op; + }); + } else { + TPU_ASSERT_OP(!layout.has_value()); + old_arg.replaceAllUsesWith(old_body->getArgument(arg_idx)); + ++arg_idx; + } + } + old_body->eraseArguments(0, num_old_args); + return success(); + }; + + const auto before_status = tile_body_args(while_op.getBeforeBody(), + new_op.getBeforeBody(), layouts_in); + if (before_status.failed()) return before_status; + new_op.getBefore().takeBody(while_op.getBefore()); + + const auto after_status = tile_body_args(while_op.getAfterBody(), + new_op.getAfterBody(), layouts_out); + if (after_status.failed()) return after_status; + new_op.getAfter().takeBody(while_op.getAfter()); + + builder.setInsertionPointAfter(new_op); + int64_t res_idx = 0; + SmallVector rolled_results; + for (auto [result, layout] : + llvm::zip_equal(while_op.getResults(), layouts_out)) { + if (const auto vty = dyn_cast(result.getType())) { + TPU_ASSERT_OP(layout.has_value()); + const SmallVector tiles_shape = + layout->tileArrayShape(vty.getShape(), ctx.target_shape); + const int64_t num_vectors = ShapedType::getNumElements(tiles_shape); + xla::Array tiles(tiles_shape); + TPU_ASSERT_LE_OP(res_idx + num_vectors, new_op.getResults().size()); + tiles.SetValues(llvm::make_range( + new_op.getResults().begin() + res_idx, + new_op.getResults().begin() + res_idx + num_vectors)); + res_idx += num_vectors; + RollVectorsOp rolled_op = + assemble(builder, vty, *layout, tiles, ctx.target_shape); + rolled_results.push_back(rolled_op); + } else { + TPU_ASSERT_OP(!layout.has_value()); + rolled_results.push_back(new_op.getResult(res_idx)); + ++res_idx; + } + } + + while_op.replaceAllUsesWith(rolled_results); + while_op.erase(); + return success(); +} + +LogicalResult scf_condition_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + OpBuilder builder(&op); + auto condition_op = cast(op); + TPU_ASSERT_EQ_OP(layouts_in.size(), condition_op.getNumOperands()); + TPU_ASSERT_EQ_OP(layouts_out.size(), 0); + SmallVector unrolled; + + for (auto [operand, layout] : 
+ llvm::zip_equal(condition_op.getOperands(), layouts_in)) { + if (auto vector_operand = dyn_cast>(operand)) { + // When the operand has vector type, disassemble the operand. + TPU_ASSERT_OP(layout.has_value()); + FAILUREOR_ASSIGN_OR_RETURN( + const xla::Array tiles, + disassemble(builder, *layout, vector_operand, ctx.target_shape)); + unrolled.append(tiles.begin(), tiles.end()); + } else { + TPU_ASSERT_OP(!layout.has_value()); + unrolled.push_back(operand); + } + } + + // Replace the old operands with unrolled operands. + condition_op->setOperands(unrolled); + return success(); +} + LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -1007,17 +1329,42 @@ LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_OP(!layouts_in.front().has_value()); ImplicitLocOpBuilder builder(op.getLoc(), &op); scf::IfOp if_op = cast(op); + SmallVector then_yield_in_layouts; + SmallVector else_yield_in_layouts; FAILUREOR_ASSIGN_OR_RETURN( - const SmallVector then_yield_in_layouts, + then_yield_in_layouts, getInLayouts(*if_op.thenYield(), ctx.target_shape)); - // TODO(tlongeri): ArrayRef conversion should not be necessary, fix - // after LLVM adds const qualifiers to ==/!= operators. Also - // applies to else_yield_in_layouts comparison below. - if (!layouts_out.empty() && - ArrayRef(then_yield_in_layouts) != layouts_out) { - return op.emitOpError( - "Not implemented: different layouts in then yield's operands and if's " - "results"); + if (!if_op.getElseRegion().empty()) { + FAILUREOR_ASSIGN_OR_RETURN( + else_yield_in_layouts, + getInLayouts(*if_op.elseYield(), ctx.target_shape)); + } + int out_idx = 0; + for (auto [then_layout, else_layout, result_layout, result] : + llvm::zip_equal(then_yield_in_layouts, else_yield_in_layouts, + layouts_out, op.getResults())) { + if (auto vty = dyn_cast(result.getType())) { + TPU_ASSERT_OP(then_layout.has_value()); + TPU_ASSERT_OP(else_layout.has_value()); + TPU_ASSERT_OP(result_layout.has_value()); + if (result_layout.value() != then_layout.value()) { + return op.emitOpError( + "Not implemented: yield layout from then branch does not " + "match with output layout ") + << out_idx; + } + if (result_layout.value() != else_layout.value()) { + return op.emitOpError( + "Not implemented: yield layout from else branch does not " + "match with output layout ") + << out_idx; + } + } else { + TPU_ASSERT_EQ_OP(then_layout, kNoLayout); + TPU_ASSERT_EQ_OP(else_layout, kNoLayout); + TPU_ASSERT_EQ_OP(result_layout, kNoLayout); + } + ++out_idx; } if (failed(applyLayoutBlock(ctx, *if_op.thenBlock()))) { return failure(); @@ -1027,15 +1374,6 @@ LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(layouts_out.size(), 0); return success(); } - FAILUREOR_ASSIGN_OR_RETURN( - const SmallVector else_yield_in_layouts, - getInLayouts(*if_op.elseYield(), ctx.target_shape)); - if (!layouts_out.empty() && - ArrayRef(else_yield_in_layouts) != layouts_out) { - return op.emitOpError( - "Not implemented: different layouts in else yield's operands and if's " - "results"); - } if (failed(applyLayoutBlock(ctx, *if_op.elseBlock()))) { return failure(); } @@ -1135,7 +1473,7 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op, } tpu::LoadOp load_op = cast(op); if (layout_out != VectorLayout(32, {0, 0}, ctx.target_shape, - VectorLayout::ImplicitDim::kNone)) { + VectorLayout::ImplicitDim::kNone)) { return op.emitOpError("Invalid output layout for ") << load_op->getName(); } 
FAILUREOR_ASSIGN_OR_RETURN( @@ -1158,6 +1496,137 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op, return success(); } +LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op, + Value base_ref, ValueRange indices, + const VectorType &vty, + const VectorLayout &layout, + const ArrayRef &strides) { + if (!isa(op)) { + return op.emitOpError("Not implemented: Unsupported strided op") + << op.getName(); + } + if (layout != VectorLayout(32, {0, 0}, ctx.target_shape, + VectorLayout::ImplicitDim::kNone)) { + return op.emitOpError("Not implemented: Unsupported vector layout in ") + << op.getName(); + } + const auto base_ty = getMemRefType(base_ref); + auto rank = base_ty.getRank(); + CHECK_EQ(rank, indices.size()); + CHECK_EQ(rank, strides.size()); + CHECK_EQ(rank, vty.getShape().size()); + if (rank < 2) { + return op.emitOpError("Not implemented: Stride on 1D vector"); + } + auto mem_layout = dyn_cast(base_ty.getLayout()); + if (!mem_layout) { + return op.emitOpError("Expected a tiled memref"); + } + auto tile_strides = mem_layout.getTileStrides(); + + // Currently we hold constraints that the last dim size of memref needs to be + // exactly same as the lane size of native vreg and the memref has never + // been sliced before on the last dim. In other words, the original base + // memref's shape needs to be (..., target_shape[1]). + if (base_ty.getShape()[rank - 1] != ctx.target_shape[1] || + tile_strides.take_back(2) != ArrayRef{1, 1}) { + return op.emitOpError("Not Implemented: The last dim size is not ") + << ctx.target_shape[1] << " in original base memref"; + } + if (strides[rank - 1] != 1) { + return op.emitOpError("Not Implemented: Stride on last dim is not 1"); + } + auto last_idx = getIntConst(indices[rank - 1], /*silent=*/true); + if (failed(last_idx)) { + return op.emitOpError("Not Implemented: Dynamic index on last dim"); + } else if (last_idx.value() != 0) { + return op.emitOpError("Not Implemented: Index on last dim is not 0"); + } + ImplicitLocOpBuilder builder(op.getLoc(), &op); + + FAILUREOR_ASSIGN_OR_RETURN( + VectorType vreg_ty, + getNativeVregType(vty.getElementType(), ctx.target_shape)); + + bool is_load_op = true; + xla::Array tiles( + layout.tileArrayShape(vty.getShape(), ctx.target_shape)); + if (auto store_op = dyn_cast(op)) { + is_load_op = false; + FAILUREOR_ASSIGN_OR_RETURN( + tiles, disassemble(builder, layout, store_op.getValueToStore(), + ctx.target_shape)); + } + + tiles.Each([&](absl::Span tile_idxs, Value *v) { + CHECK_EQ(tile_idxs.size(), rank); + SmallVector idxs(rank); + for (int64_t i = 0; i < rank; ++i) { + int64_t stride = (i < rank - 2) + ? 
strides[i] + : (strides[i] * ctx.target_shape[i - rank + 2]); + idxs[i] = builder.create( + indices[i], IdxConst(tile_idxs[i] * stride, builder, op.getLoc())); + } + SmallVector sublane_mask(ctx.target_shape[0], true); + int64_t sublane_rem = vty.getDimSize(rank - 2) % ctx.target_shape[0]; + if (sublane_rem > 0 && tile_idxs[rank - 2] == tiles.dim(rank - 2) - 1) { + for (int64_t i = sublane_rem; i < ctx.target_shape[0]; ++i) { + sublane_mask[i] = false; + } + } + const auto sublane_mask_attr = + DenseBoolArrayAttr::get(op.getContext(), sublane_mask); + if (is_load_op) { + *v = builder.create( + vreg_ty, base_ref, idxs, sublane_mask_attr, + builder.getI32IntegerAttr(strides[rank - 2])); + } else { + builder.create( + *v, base_ref, idxs, sublane_mask_attr, + /*mask=*/nullptr, builder.getI32IntegerAttr(strides[rank - 2])); + } + }); + if (is_load_op) { + op.replaceAllUsesWith( + assemble(builder, vty, layout, std::move(tiles), ctx.target_shape)); + } + op.erase(); + return success(); +} + +// TODO(jevinjiang): maybe unify with vector load? +LogicalResult tpu_strided_load_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + TPU_ASSERT_OP(llvm::none_of(layouts_in, + [&](const Layout &l) { return l.has_value(); })); + TPU_ASSERT_EQ_OP(layouts_out.size(), 1); + TPU_ASSERT_OP(layouts_out.front().has_value()); + const VectorLayout &layout_out = *layouts_out.front(); + auto load_op = cast(op); + const auto vty = cast(load_op.getResult().getType()); + return strided_op_rule_impl(ctx, op, load_op.getBase(), load_op.getIndices(), + vty, layout_out, load_op.getStrides()); +} + +// TODO(jevinjiang): maybe unify with vector store? +LogicalResult tpu_strided_store_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + TPU_ASSERT_OP(layouts_in.front().has_value()); + TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(), + [&](const Layout &l) { return l.has_value(); })); + TPU_ASSERT_EQ_OP(layouts_out.size(), 0); + + const VectorLayout &to_store_layout = *layouts_in.front(); + auto store_op = cast(op); + const auto vty = store_op.getValueToStore().getType(); + return strided_op_rule_impl(ctx, op, store_op.getBase(), + store_op.getIndices(), vty, to_store_layout, + store_op.getStrides()); +} + LogicalResult matmul_rule_impl(RewriteContext &ctx, Operation &op, const bool transpose_lhs, const bool transpose_rhs, @@ -1560,7 +2029,7 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(op.getNumResults(), 1); TPU_ASSERT_EQ_OP(layouts_in.size(), 1); TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - if (layouts_in[0] !=layouts_out[0]) { + if (layouts_in[0] != layouts_out[0]) { return op.emitOpError("Expected same input and output layout"); } OpBuilder builder(&op); @@ -1585,19 +2054,13 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op, return success(); } -LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - CHECK_EQ(layouts_in.size(), 1); - CHECK_EQ(layouts_out.size(), 1); - if (!layouts_in.front().has_value()) { - return op.emitOpError("Expected non-null input layout"); - } - if (!layouts_out.front().has_value()) { - return op.emitOpError("Expected non-null output layout"); - } - const VectorLayout &layout_in = *layouts_in.front(); - const VectorLayout &layout_out = *layouts_out.front(); +// TODO(b/347016737): Deprecate tpu.rotate and only use tpu.dynamic_rotate. 
So +// we do not need template for the op type and to explicitly force amount +// argument to dynamic. +template +LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, + const VectorLayout &layout_in, + const VectorLayout &layout_out) { auto layout = VectorLayout(32, {0, 0}, ctx.target_shape, VectorLayout::ImplicitDim::kNone); if (layout_in != layout) { @@ -1606,8 +2069,7 @@ LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op, if (layout_out != layout) { return op.emitOpError("Not implemented: unsupported layout for output"); } - tpu::RotateOp rotate_op = cast(op); - auto vty = rotate_op.getResult().getType(); + auto vty = op.getResult().getType(); if (vty.getRank() < 2) { return op.emitOpError("Not implemented: unsupported 1D shape"); } @@ -1616,23 +2078,77 @@ LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op, return op.emitOpError("Not implemented: unsupported unaliged shape"); } - ImplicitLocOpBuilder builder(op.getLoc(), &op); + ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); FAILUREOR_ASSIGN_OR_RETURN( VectorType res_vreg_ty, getNativeVregType(vty.getElementType(), ctx.target_shape)); FAILUREOR_ASSIGN_OR_RETURN( const xla::Array in_tiles, - disassemble(builder, layout_in, rotate_op.getValue(), ctx.target_shape)); + disassemble(builder, layout_in, op.getValue(), ctx.target_shape)); FAILUREOR_ASSIGN_OR_RETURN( const VectorType i32_vreg, getNativeVregType(builder.getI32Type(), ctx.target_shape)); - auto getVmaskByPaddingEnd = [&](int dim, int padding, int stride = 0) { + + // Some helper functions for math ops. + auto mlirI32Const = [&](int d) { + return builder.create( + builder.getIntegerAttr(builder.getI32Type(), d)); + }; + auto mlirIndexConst = [&](int d) { + return builder.create( + builder.getIntegerAttr(builder.getIndexType(), d)); + }; + auto modI = [&](const Value &v, unsigned d) -> Value { + if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + return mlirI32Const(cst.value() % d); + } + return builder.create(v, mlirI32Const(d)); + }; + auto divI = [&](const Value &v, unsigned d) -> Value { + if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + return mlirI32Const(cst.value() / d); + } + return builder.create(v, mlirI32Const(d)); + }; + auto addI = [&](const Value &v, unsigned d) -> Value { + if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + return mlirI32Const(cst.value() + d); + } + return builder.create(v, mlirI32Const(d)); + }; + + // A helper function that creates a VMASK with false flags to bottom (dim = 0) + // or right (dim = 1) where the flag count corresponds to the (dim_size - + // padding). If stride is provided, the padding value is sequentially + // increased by the stride value along the dim. 
+ // + // For example, assume VMASK shape is (4, 8) + // + // getVmaskByPaddingEnd(padding=3, dim=1) creates: + // [T, T, T, T, T, F, F, F] + // [T, T, T, T, T, F, F, F] + // [T, T, T, T, T, F, F, F] + // [T, T, T, T, T, F, F, F] + // + // getVmaskByPaddingEnd(padding=3, dim=1, stride=1) creates: + // [T, T, T, T, T, F, F, F] + // [T, T, T, T, T, T, F, F] + // [T, T, T, T, T, T, T, F] + // [T, T, T, T, T, T, T, T] + auto getVmaskByPaddingEnd = [&](Value padding, int dim, int stride = 0) { CHECK(dim == 0 || dim == 1); - CHECK(padding >= 0 && padding <= ctx.target_shape[dim]); - Value padding_vreg = builder.create( - DenseElementsAttr::get(i32_vreg, builder.getI32IntegerAttr( - ctx.target_shape[dim] - padding))); + Value padding_vreg; + if (auto padding_cst = getIntConst(padding, /*silent=*/true); + succeeded(padding_cst)) { + CHECK_GE(padding_cst.value(), 0); + CHECK_LE(padding_cst.value(), ctx.target_shape[dim]); + padding_vreg = builder.create(DenseElementsAttr::get( + i32_vreg, builder.getI32IntegerAttr(padding_cst.value()))); + } else { + padding_vreg = builder.create(i32_vreg, padding); + } + if (stride > 0) { auto offset = builder.create( i32_vreg, @@ -1649,77 +2165,155 @@ LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op, padding_vreg); }; - auto splitVregs = [](const xla::Array &vregs, int axis) { - CHECK(axis >= 0 && axis < vregs.num_dimensions()); - SmallVector> chunks; - chunks.reserve(vregs.dim(axis)); - for (int64_t i = 0; i < vregs.dim(axis); ++i) { - SmallVector starts(vregs.num_dimensions(), 0); - starts[axis] = i; - SmallVector limits(vregs.dimensions().begin(), - vregs.dimensions().end()); - limits[axis] = i + 1; - chunks.push_back(vregs.Slice(starts, limits)); + // Apply rotation on each vreg with the assumption that shift <= VREG dim size + // and blend the data from contiguous vregs to emulate circular rotation. + auto rotateOnTilingDim = [&](const xla::Array &vregs, + const Value &shift, int axis, int stride = 0) { + if (auto shift_cst = getIntConst(shift, /*silent=*/true); + succeeded(shift_cst)) { + if (shift_cst.value() == 0 && stride == 0) { + return vregs; + } + } + int tiling_dim = axis - (vregs.num_dimensions() - 2); + CHECK((tiling_dim == 0 && stride == 0) || (tiling_dim == 1 && stride >= 0)); + xla::Array result(vregs.dimensions()); + auto chunks = split(vregs, axis); + for (int64_t i = 0; i < chunks.size(); ++i) { + chunks[i].Each([&](absl::Span idxs, Value *v) { + auto stride_attr = + stride > 0 ? builder.getSI32IntegerAttr(stride) : nullptr; + auto stride_dimension_attr = + stride > 0 ? 
builder.getSI32IntegerAttr(0) : nullptr; + *v = builder.create(res_vreg_ty, *v, shift, + tiling_dim, stride_attr, + stride_dimension_attr); + }); + } + auto mask = getVmaskByPaddingEnd(shift, tiling_dim, stride); + xla::Array last_chunk_copy(chunks[chunks.size() - 1]); + for (int64_t i = chunks.size() - 1; i > 0; --i) { + chunks[i].Each([&](absl::Span idxs, Value *v) { + *v = builder.create(mask, chunks[i - 1](idxs), *v); + }); } - return chunks; + chunks[0].Each([&](absl::Span idxs, Value *v) { + *v = builder.create(mask, last_chunk_copy(idxs), *v); + }); + return concatenate(chunks, axis); }; - auto roll = [&](const xla::Array &vregs, int64_t shift, int axis, - int stride = 0) { + + std::function(const xla::Array &, Value, int, int)> + rotate; + rotate = [&](const xla::Array &vregs, Value shift, int axis, + int stride) { xla::Array result(vregs.dimensions()); CHECK(axis >= 0 && axis < vregs.num_dimensions()); - auto chunks = splitVregs(vregs, axis); - if (axis >= vregs.num_dimensions() - 2) { - int tiling_dim = axis - (vregs.num_dimensions() - 2); - int64_t shift_in_vreg = shift % ctx.target_shape[tiling_dim]; - shift /= ctx.target_shape[tiling_dim]; - CHECK((tiling_dim == 0 && stride == 0) || - (tiling_dim == 1 && stride >= 0)); - if (shift_in_vreg != 0 || stride != 0) { - for (int64_t i = 0; i < chunks.size(); ++i) { - chunks[i].Each([&](absl::Span idxs, Value *v) { - auto stride_attr = - stride > 0 ? builder.getSI32IntegerAttr(stride) : nullptr; - auto stride_dimension_attr = - stride > 0 ? builder.getSI32IntegerAttr(0) : nullptr; - *v = builder.create(res_vreg_ty, *v, shift_in_vreg, - tiling_dim, stride_attr, - stride_dimension_attr); - }); - } - // After rotation on each vreg, we need to select the wrapped data - // from the previous vreg and overwrite them to the current vreg. - auto mask = getVmaskByPaddingEnd( - tiling_dim, ctx.target_shape[tiling_dim] - shift_in_vreg, stride); - xla::Array last_chunk_copy(chunks[chunks.size() - 1]); - for (int64_t i = chunks.size() - 1; i > 0; --i) { - chunks[i].Each([&](absl::Span idxs, Value *v) { - *v = builder.create(mask, chunks[i - 1](idxs), *v); - }); - } + int tiling_dim = axis - (vregs.num_dimensions() - 2); + CHECK((tiling_dim != 1 && stride == 0) || (tiling_dim == 1 && stride >= 0)); + SmallVector, 4> chunks; + // Handle rotation with static shift. + if (auto shift_cst = getIntConst(shift, /*silent=*/true); + succeeded(shift_cst)) { + int64_t static_shift = shift_cst.value(); + if (tiling_dim >= 0) { + shift = mlirI32Const(static_shift % ctx.target_shape[tiling_dim]); + static_shift /= ctx.target_shape[tiling_dim]; + chunks = split(rotateOnTilingDim(vregs, shift, axis, stride), axis); + } else { + chunks = split(vregs, axis); + } + // Now we only need to shuffle vregs. + for (int64_t i = 0; i < chunks.size(); ++i) { + SmallVector starts(result.num_dimensions(), 0); + starts[axis] = (i + static_shift) % result.dim(axis); + result.UpdateSlice(chunks[i], starts); + } + return result; + } + // Handle rotation with dynamic shift. + // TODO(jevinjiang): consider optimize with assume_multiple op. + Value in_vreg_shift = tiling_dim >= 0 + ? modI(shift, ctx.target_shape[tiling_dim]) + : mlirI32Const(0); + Value vreg_shift = + tiling_dim >= 0 ? divI(shift, ctx.target_shape[tiling_dim]) : shift; + result = tiling_dim >= 0 + ? 
rotateOnTilingDim(vregs, in_vreg_shift, axis, stride) + : vregs; + int bound = vregs.dim(axis); + if (bound <= ctx.max_sublanes_in_scratch / ctx.target_shape[0] && + bound >= kMinBoundToRotateWithScratch) { + // Use static store + dynamic load to implement dynamic shift. + if (auto scratch_ref = getInternalScratch( + ctx, builder, op.getLoc(), + {ctx.max_sublanes_in_scratch / ctx.target_shape[0], + ctx.target_shape[0], ctx.target_shape[1]}, + vty.getElementType()); + succeeded(scratch_ref)) { + auto cst_0 = mlirIndexConst(0); + SmallVector scratch_indices(3, cst_0); + SmallVector sublane_mask(ctx.target_shape[0], true); + const auto sublane_mask_attr = + DenseBoolArrayAttr::get(op.getContext(), sublane_mask); + chunks = split(result, axis); chunks[0].Each([&](absl::Span idxs, Value *v) { - *v = builder.create(mask, last_chunk_copy(idxs), *v); + // Static store vregs. + for (int i = 0; i < bound; ++i) { + scratch_indices[0] = mlirIndexConst(i); + builder.create(chunks[i](idxs), scratch_ref.value(), + scratch_indices, sublane_mask_attr, + /*mask=*/nullptr, + /*sublane_stride=*/nullptr); + } + // Dynamic load vregs back from a circular buffer. + for (int i = 0; i < bound; ++i) { + scratch_indices[0] = builder.create( + builder.getIndexType(), + modI(builder.create(mlirI32Const(bound + i), + vreg_shift), + bound)); + chunks[i](idxs) = + builder.create(v->getType(), scratch_ref.value(), + scratch_indices, sublane_mask_attr, + /*sublane_stride=*/nullptr); + } }); + return concatenate(chunks, axis); } - } else { - CHECK_EQ(stride, 0); } - // Now we only need to shuffle vregs. - for (int64_t i = 0; i < chunks.size(); ++i) { - SmallVector starts(result.num_dimensions(), 0); - starts[axis] = (i + shift) % result.dim(axis); - result.UpdateSlice(chunks[i], starts); + // Convert dynamic shift to log(bound) static ops. + int roll_by = 1; + Value cst_1 = mlirI32Const(1); + while (bound > 0) { + auto new_result = rotate( + result, + mlirI32Const(tiling_dim >= 0 ? 
roll_by * ctx.target_shape[tiling_dim] + : roll_by), + axis, /*stride=*/0); + auto mask = builder.create( + arith::CmpIPredicate::ne, + builder.create( + i32_vreg, builder.create(vreg_shift, cst_1)), + builder.create( + DenseElementsAttr::get(i32_vreg, builder.getI32IntegerAttr(0)))); + result.Each([&](absl::Span idxs, Value *v) { + *v = builder.create(mask, new_result(idxs), *v); + }); + roll_by *= 2; + bound /= 2; + vreg_shift = divI(vreg_shift, 2); } return result; }; xla::Array out_tiles(in_tiles.dimensions()); - const auto dim = rotate_op.getDimension(); - const auto amount = rotate_op.getAmount() % vty.getDimSize(dim); + const auto dim = op.getDimension(); + amount = modI(amount, vty.getDimSize(dim)); - if (rotate_op.getStride().has_value() && - rotate_op.getStrideDimension().has_value()) { - auto stride_dim = rotate_op.getStrideDimension().value(); - auto stride = rotate_op.getStride().value() % vty.getDimSize(stride_dim); + if (op.getStride().has_value() && op.getStrideDimension().has_value()) { + auto stride_dim = op.getStrideDimension().value(); + auto stride = op.getStride().value() % vty.getDimSize(stride_dim); if (stride_dim == dim) { return op.emitOpError( "Expected rotation dimension and stride dimension are not equal"); @@ -1734,46 +2328,96 @@ LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op, "is the minor most when stride dimension is the second minor most"); } CHECK_GE(stride, 0); - auto chunks = splitVregs(in_tiles, stride_dim); + auto chunks = split(in_tiles, stride_dim); for (int64_t i = 0; i < chunks.size(); ++i) { - int64_t base_amount = - (ctx.target_shape[0] * i * stride + amount) % vty.getDimSize(dim); + Value base_amount = modI(addI(amount, ctx.target_shape[0] * i * stride), + vty.getDimSize(dim)); // After applying stride, we expect all shifts in a vreg are less or // equal to the vreg's lane count for now. - auto max_shift_in_vreg = base_amount % ctx.target_shape[1] + - (ctx.target_shape[0] - 1) * stride; - if (max_shift_in_vreg > ctx.target_shape[1]) { - return op.emitOpError("Not implemented: the max shift in a vreg ") - << max_shift_in_vreg << " is larger than the vreg's width " - << ctx.target_shape[1]; + if (auto base_amount_cst = getIntConst(base_amount, /*silent=*/true); + succeeded(base_amount_cst)) { + int64_t static_base_amount = base_amount_cst.value(); + auto max_shift_in_vreg = static_base_amount % ctx.target_shape[1] + + (ctx.target_shape[0] - 1) * stride; + if (max_shift_in_vreg > ctx.target_shape[1]) { + return op.emitOpError("Not implemented: the max shift in a vreg ") + << max_shift_in_vreg << " is larger than the vreg's width " + << ctx.target_shape[1]; + } } SmallVector starts(out_tiles.num_dimensions(), 0); starts[stride_dim] = i; - out_tiles.UpdateSlice(roll(chunks[i], base_amount, dim, stride), + out_tiles.UpdateSlice(rotate(chunks[i], base_amount, dim, stride), starts); } } else { // Split vregs along the stride dimension. - auto chunks = splitVregs(in_tiles, stride_dim); + auto chunks = split(in_tiles, stride_dim); for (int64_t i = 0; i < chunks.size(); ++i) { SmallVector starts(out_tiles.num_dimensions(), 0); starts[stride_dim] = i; - out_tiles.UpdateSlice(roll(chunks[i], amount + i * stride, dim), - starts); + out_tiles.UpdateSlice( + rotate(chunks[i], addI(amount, i * stride), dim, /*stride=*/0), + starts); } } } else { // No stride. 
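+      // Rotate the full vreg array along `dim` by the (possibly dynamic) amount.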
- out_tiles = roll(in_tiles, amount, dim); + out_tiles = rotate(in_tiles, amount, dim, /*stride=*/0); } const RollVectorsOp rolled_op = - assemble(builder, rotate_op.getResult().getType(), layout_out, out_tiles, + assemble(builder, op.getResult().getType(), layout_out, out_tiles, ctx.target_shape); op.replaceAllUsesWith(rolled_op); op.erase(); return success(); } +// TODO(b/347016737): deprecate the static rotate. +LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + CHECK_EQ(layouts_in.size(), 1); + CHECK_EQ(layouts_out.size(), 1); + if (!layouts_in.front().has_value()) { + return op.emitOpError("Expected non-null input layout"); + } + if (!layouts_out.front().has_value()) { + return op.emitOpError("Expected non-null output layout"); + } + auto rotate_op = cast(op); + if (rotate_op.getAmount() < 0) { + return op.emitOpError("Not implemented: shifting by negative amount"); + } + ImplicitLocOpBuilder builder(op.getLoc(), &op); + Value shift = builder.create( + builder.getIntegerAttr(builder.getI32Type(), rotate_op.getAmount())); + const VectorLayout &layout_in = *layouts_in.front(); + const VectorLayout &layout_out = *layouts_out.front(); + return rotate_rule_impl(ctx, rotate_op, shift, layout_in, layout_out); +} + +LogicalResult tpu_dynamic_rotate_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + CHECK_EQ(layouts_in.size(), 2); + CHECK_EQ(layouts_out.size(), 1); + if (!layouts_in.front().has_value()) { + return op.emitOpError("Expected non-null layout for the value to rotate"); + } + if (layouts_in[1].has_value()) { + return op.emitOpError("Expected null layout for the shift"); + } + if (!layouts_out.front().has_value()) { + return op.emitOpError("Expected non-null output layout"); + } + auto rotate_op = cast(op); + const VectorLayout &layout_in = *layouts_in.front(); + const VectorLayout &layout_out = *layouts_out.front(); + return rotate_rule_impl(ctx, rotate_op, rotate_op.getAmount(), layout_in, + layout_out); +} + LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -1788,15 +2432,12 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, return op.emitOpError("Not implemented: Inconsistent layouts"); } } - if (!layout.hasNaturalTopology(ctx.target_shape)) { - return op.emitOpError("Not implemented"); - } OpBuilder builder(&op); auto concatenate_op = cast(op); const VectorType res_ty = concatenate_op.getResult().getType(); const uint32_t dimension = concatenate_op.getDimension(); if (dimension - res_ty.getRank() >= -2) { - if (!layout.hasNaturalTopology(ctx.target_shape) || + if (!layout.hasNativeTiling(ctx.target_shape) || layout.offsets() != LayoutOffsets{0, 0}) { return op.emitOpError( "Not implemented: Only native tiling with offset (0, 0) is supported " @@ -1920,7 +2561,24 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op, op.erase(); return success(); } - return op.emitOpError("Not implemented: Unsupported dimension"); + // We take the iota over an untiled dimension. 
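+  // Each vreg along the untiled dimension holds a single index value, so we
+  // materialize one constant splat vreg per index and replicate it across the
+  // remaining vreg grid.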
+ CHECK_LT(*dimension, vty.getRank()); + SmallVector tiles; + tiles.reserve(vty.getDimSize(*dimension)); + for (int64_t i = 0; i < vty.getDimSize(*dimension); ++i) { + tiles.push_back(builder.create( + native_vreg_ty, + DenseElementsAttr::get(native_vreg_ty, + IntegerAttr::get(vty.getElementType(), i)))); + } + xla::Array out_tiles(tile_array_shape); + out_tiles.Each([&](absl::Span idxs, Value *v) { + *v = tiles[idxs[*dimension]]; + }); + op.replaceAllUsesWith( + assemble(builder, vty, layout_out, out_tiles, ctx.target_shape)); + op.erase(); + return success(); } LogicalResult tpu_gather_rule(RewriteContext &ctx, Operation &op, @@ -2102,7 +2760,9 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, FAILUREOR_ASSIGN_OR_RETURN( Tiling memref_tiling, getMemRefTiling(load_op.getBase(), ctx.target_shape)); - if (layout_out.tiling() != memref_tiling) { + if (memref_tiling != layout_out.tiling() && + !(memref_tiling[0] == 1 && layout_out.tiling()[0] == 1 && + memref_tiling[1] % layout_out.tiling()[1] == 0)) { // Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes). // TODO(b/295393167): need to support strided load for bitwidth < 32. if (layout_out.bitwidth() != 32 || @@ -2307,8 +2967,11 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, ImplicitLocOpBuilder builder(op.getLoc(), &op); vector::BroadcastOp broadcast_op = cast(op); const VectorType dst_ty = broadcast_op.getResult().getType(); + const ArrayRef dst_shape = dst_ty.getShape(); const SmallVector dst_tiles_shape = - layout_out.tileArrayShape(dst_ty.getShape(), ctx.target_shape); + layout_out.tileArrayShape(dst_shape, ctx.target_shape); + const SmallVector dst_tiles_implicit_shape = + layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape); if (auto src = dyn_cast>(broadcast_op.getSource())) { VectorType src_ty = src.getType(); TPU_ASSERT_OP(maybe_layout_in.has_value()); @@ -2317,89 +2980,84 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, return op.emitOpError( "Not implemented: Changing implicit dims mid-broadcast"); } - const VectorLayout::ImplicitDim implicit_dim = layout_in.implicit_dim(); - const int layout_rank = layout_in.layout_rank(); const LayoutOffsets offsets_in = layout_in.offsets(); const LayoutOffsets offsets_out = layout_out.offsets(); if (layout_in.tiling() != layout_out.tiling()) { - return op.emitOpError( - "Not implemented: Changing tiling mid-broadcast"); + return op.emitOpError("Not implemented: Changing tiling mid-broadcast"); } auto tiling = layout_in.tiling(); const int64_t expand_rank = dst_ty.getRank() - src_ty.getRank(); - SmallVector src_shape_padded(expand_rank, -1); const ArrayRef src_shape = src_ty.getShape(); - src_shape_padded.append(src_shape.begin(), src_shape.end()); - const SmallVector dim_eq = llvm::map_to_vector( - llvm::zip(src_shape_padded, dst_ty.getShape()), [](auto tup) { - auto [i, o] = tup; - return i == o; - }); - bool no_op = false; - switch (implicit_dim) { - case VectorLayout::ImplicitDim::kNone: { - const ArrayRef tiled_dim_eq = ArrayRef(dim_eq).take_back(2); - for (auto [in_off, out_off, eq] : - llvm::zip(offsets_in, offsets_out, tiled_dim_eq)) { - if (eq && in_off != out_off) { - return op.emitOpError( - "Not implemented: Changing offsets mid-broadcast"); - } + SmallVector src_implicit_shape_padded; + // `is_logical_broadcast` stores whether each dimension of the implicit + // shape of the result is a broadcast. E.g. 
if the implicit shape goes from + // (2, 1, 3) to (4, 2, 5, 3) it's (true, false, true, false). + SmallVector is_logical_broadcast; + src_implicit_shape_padded.reserve(dst_shape.size() + + layout_in.num_implicit_dims()); + is_logical_broadcast.reserve(dst_shape.size() + + layout_in.num_implicit_dims()); + src_implicit_shape_padded.append(expand_rank, 1); + src_implicit_shape_padded.append(src_shape.begin(), src_shape.end()); + for (auto [i, o] : llvm::zip(src_implicit_shape_padded, dst_shape)) { + TPU_ASSERT_OP(i == o || i == 1); // Verifier should guarantee this. + is_logical_broadcast.push_back(i != o); + } + layout_in.insertImplicit(src_implicit_shape_padded, 1); + layout_in.insertImplicit(is_logical_broadcast, false); + + // Verify that the offsets are valid. + for (auto [is_logical_broadcast_on_dim, in_off, out_off] : + llvm::zip_equal(ArrayRef(is_logical_broadcast).take_back(2), + offsets_in, offsets_out)) { + if (is_logical_broadcast_on_dim) { + if (out_off.has_value()) { + // There's no reason to ever assign a non-replicated offset to a + // broadcasted dimension in the output. + return op.emitOpError( + // TODO(tlongeri): This should never be implemented but the fuzzed + // tests expect a NotImplementedError, which + // is raised with a "Not implemented" (see + // NotImplementedDetector in tpu_ext.cc). Fix. + "Not implemented: Broadcast output expected to have replicated " + "offsets."); } - no_op = layout_in.hasNaturalTopology(ctx.target_shape) && - layout_out.hasNaturalTopology(ctx.target_shape) && - llvm::all_of(llvm::zip_equal(offsets_in, tiled_dim_eq), - [](auto tup) { - auto [o, eq] = tup; - return eq || !o.has_value(); - }); - } break; - case VectorLayout::ImplicitDim::kMinor: - case VectorLayout::ImplicitDim::kSecondMinor: - if (dim_eq.back()) { - if (offsets_in != offsets_out) { - return op.emitOpError( - "Not implemented: Changing offsets mid-broadcast"); - } - no_op = true; - } else if (implicit_dim == VectorLayout::ImplicitDim::kSecondMinor && - !offsets_in[1].has_value()) { - no_op = true; - } else if (implicit_dim == VectorLayout::ImplicitDim::kMinor && - !offsets_in[0].has_value()) { - no_op = true; + } else { // !is_logical_broadcast_on_dim + if (in_off != out_off) { + return op.emitOpError( + "Not implemented: Changing offsets mid-broadcast"); } - break; - } - TPU_ASSERT_OP(layout_rank); - if (src_ty.getShape().take_back(layout_rank) == - dst_ty.getShape().take_back(layout_rank)) { - if (offsets_in != offsets_out) { - op.emitOpError("Not implemented: Changing offsets mid-broadcast"); } - no_op = true; } + // `needs_physical_broadcast` specifies whether we need to broadcast vregs + // vregs in the sublane and lane dimensions. We only need to do this if the + // corresponding dimension of the implicit shape is logically broadcast and + // if the input vregs are not already replicated along this dimension. 
+ const std::array needs_physical_broadcast{ + *(is_logical_broadcast.end() - 2) && offsets_in[0].has_value(), + *(is_logical_broadcast.end() - 1) && offsets_in[1].has_value()}; FAILUREOR_ASSIGN_OR_RETURN( xla::Array src_tiles, - disassemble(builder, layout_in, src, ctx.target_shape)); - xla::Array dst_tiles(dst_tiles_shape); - if (no_op) { + disassemble(builder, layout_in, src, ctx.target_shape, + /*use_implicit_shape=*/true)); + xla::Array dst_tiles(dst_tiles_implicit_shape); + if (needs_physical_broadcast == std::array{false, false}) { // No-op SmallVector reshape_dims(expand_rank, 1); const absl::Span src_tiles_dims = src_tiles.dimensions(); reshape_dims.append(src_tiles_dims.begin(), src_tiles_dims.end()); src_tiles.Reshape(reshape_dims); dst_tiles.Each([&](const absl::Span dst_idx, Value *tile) { - const SmallVector src_idx = - llvm::map_to_vector(llvm::zip_equal(dst_idx, dim_eq), [](auto tup) { - auto [i, eq] = tup; - return eq ? i : 0; + const SmallVector src_idx = llvm::map_to_vector( + llvm::zip_equal(dst_idx, is_logical_broadcast), [](auto tup) { + auto [i, is_logical_broadcast_on_dim] = tup; + return is_logical_broadcast_on_dim ? 0 : i; }); *tile = src_tiles(src_idx); }); - } else if (implicit_dim == VectorLayout::ImplicitDim::kNone) { + } else { if (layout_in.bitwidth() != 32) { return op.emitOpError( "Not implemented: Only 32-bit broadcast supported"); @@ -2408,8 +3066,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, return op.emitOpError("Not implemented: unsupported tiling"); } int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape); - TPU_ASSERT_OP(!*(dim_eq.end() - 1) || !*(dim_eq.end() - 2)); - if (*(dim_eq.end() - 1)) { // Sublane broadcast + if (needs_physical_broadcast == + std::array{true, false}) { // Sublane broadcast if (num_tiles != 1) { return op.emitOpError( "Not implemented: Only native tiling supported"); @@ -2421,12 +3079,12 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, SmallVector(ctx.target_shape[0], offset)); src_tiles.Each([&](const absl::Span src_idx, Value *const src_tile) { - SmallVector dst_starts(dst_tiles_shape.size()); - SmallVector dst_limits(dst_tiles_shape.size()); + SmallVector dst_starts(dst_tiles_implicit_shape.size()); + SmallVector dst_limits(dst_tiles_implicit_shape.size()); for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) { - if (i < expand_rank || !dim_eq[i]) { + if (i < expand_rank || is_logical_broadcast[i]) { dst_starts[i] = 0; - dst_limits[i] = dst_tiles_shape[i]; + dst_limits[i] = dst_tiles_implicit_shape[i]; } else { dst_starts[i] = src_idx[i - expand_rank]; dst_limits[i] = dst_starts[i] + 1; @@ -2437,7 +3095,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, src_tile->getType(), *src_tile, indices, 0), dst_starts, dst_limits); }); - } else if (*(dim_eq.end() - 2)) { // Lane broadcast + } else if (needs_physical_broadcast == + std::array{false, true}) { // Lane broadcast TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 1), 1); TPU_ASSERT_OP(offsets_in[1].has_value()); const int64_t offset = *offsets_in[1]; @@ -2445,8 +3104,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, VectorType::get(ctx.target_shape, builder.getI32Type()); auto idx_const = builder.create( broadcast_op.getLoc(), idx_ty, - DenseElementsAttr::get(idx_ty, - builder.getI32IntegerAttr(offset))); + DenseElementsAttr::get(idx_ty, builder.getI32IntegerAttr(offset))); int64_t sublanes_per_tile = layout_in.sublanesPerTile(ctx.target_shape); 
DenseI32ArrayAttr sublane_pattern; if (num_tiles != 1) { @@ -2461,12 +3119,12 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, } src_tiles.Each([&](const absl::Span src_idx, Value *const src_tile) { - SmallVector dst_starts(dst_tiles_shape.size()); - SmallVector dst_limits(dst_tiles_shape.size()); + SmallVector dst_starts(dst_tiles_implicit_shape.size()); + SmallVector dst_limits(dst_tiles_implicit_shape.size()); for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) { - if (i < expand_rank || !dim_eq[i]) { + if (i < expand_rank || is_logical_broadcast[i]) { dst_starts[i] = 0; - dst_limits[i] = dst_tiles_shape[i]; + dst_limits[i] = dst_tiles_implicit_shape[i]; } else { dst_starts[i] = src_idx[i - expand_rank]; dst_limits[i] = dst_starts[i] + 1; @@ -2483,14 +3141,15 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, updateSlice(dst_tiles, res_vreg, dst_starts, dst_limits); }); } else { - return op.emitOpError("Not implemented"); + TPU_ASSERT_OP((needs_physical_broadcast == std::array{true, true})); + return op.emitOpError( + "Not implemented: Broadcast in both sublanes and lanes"); } - } else { - return op.emitOpError("Not implemented"); } - broadcast_op.replaceAllUsesWith( - assemble(builder, dst_ty, layout_out, dst_tiles, ctx.target_shape) - .getOperation()); + broadcast_op.replaceAllUsesWith(assemble(builder, dst_ty, layout_out, + dst_tiles, ctx.target_shape, + /*use_implicit_shape=*/true) + .getOperation()); broadcast_op.erase(); return success(); } else if (layout_out.bitwidth() == 32 && @@ -2573,9 +3232,126 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, } } +// Returns slice of vregs containing a given slice of elements, obtained from +// the result of a vector.extract or vector.extract_strided_slice op. +// +// Takes offsets and sizes describing the slice of elements. If their size is +// less than the rank of the input vector, they describe a prefix i.e. they +// apply to the first (majormost) dimensions and the remaining dimensions are +// not sliced. +// +// Args: +// - ctx: Rewrite context (for disassembling, which may create an op). +// - op: Source vector.extract or vector.extract_strided_slice op. +// - offsets: Prefix of offsets of slice of elements. Must have the same size +// as sizes. +// - sizes: Prefix of sizes of slice of elements. Must have the same size +// as offsets. +// - layout_in: Layout of src_vector. +// - layout_out: Layout that will be used to reassemble the slice (by caller). +// Used only to check that the reassembling is valid. +FailureOr> vector_extract_slice_impl( + RewriteContext &ctx, Operation &op, const ArrayRef sizes, + const ArrayRef offsets, const VectorLayout &layout_in, + const VectorLayout &layout_out) { + if (layout_in.tiling() != layout_out.tiling() || + layout_in.bitwidth() != layout_out.bitwidth()) { + return op.emitOpError( + "Expected layout_in and layout_out tiling and packing to match"); + } + + // Both extract_strided_slice and extract have their input vector at index 0 + // and a single result. 
+ CHECK((isa(op))); + auto src_vector = cast>(op.getOperand(0)); + auto result = cast>(op.getResult(0)); + + const VectorType dst_ty = result.getType(); + if (layout_in.implicit_dim() != layout_out.implicit_dim() && + !(layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && + layout_out.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor && + dst_ty.getRank() == 1)) { + return op.emitOpError( + "Unexpected change in implicit dimension that may not be a no-op"); + } + + const ArrayRef src_vector_shape = src_vector.getType().getShape(); + const int64_t src_vector_rank = src_vector_shape.size(); + const int64_t num_indices = offsets.size(); + TPU_ASSERT_EQ_OP(num_indices, sizes.size()); + + SmallVector full_sizes; + full_sizes.reserve(src_vector_rank + layout_in.num_implicit_dims()); + full_sizes.append(sizes.begin(), sizes.end()); + full_sizes.append(src_vector_shape.begin() + num_indices, + src_vector_shape.end()); + layout_in.insertImplicit(full_sizes, 1); + + SmallVector full_offsets; + full_offsets.reserve(src_vector_rank + layout_in.num_implicit_dims()); + full_offsets.append(offsets.begin(), offsets.end()); + full_offsets.append(src_vector_rank - num_indices, 0); + layout_in.insertImplicit(full_offsets, 0); + + // We currently only support no-op cases - that is, those where we effectively + // just extract a slice of vregs without doing any operations (e.g. shifts) on + // them. + // TODO(tlongeri): VectorLayout enforces that the offsets must fall in the + // first tile of each vreg. That means a no-op would not result in a valid + // layout if the index offset falls within a different tile in the vreg. Do we + // want to loosen this restriction or add shifts? This is the only non-no-op + // that might make sense to support - otherwise we should expect + // infer-vector-layout to assign no-op layouts and have the burden of any + // shifts that might be needed later fall on relayout. 
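+  // Illustrative example (assuming a 32-bit layout with an (8, 128) vreg slice
+  // and zero in/out offsets): a second-minor element offset of 16 is a no-op,
+  // since (16 + 0) % 8 == 0, while an offset of 3 would require a shift and is
+  // rejected below.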
+ for (auto [index_offset, in_offset, vreg_slice, out_offset] : llvm::zip_equal( + ArrayRef(full_offsets).take_back(2), layout_in.offsets(), + layout_in.vregSlice(ctx.target_shape), layout_out.offsets())) { + if (in_offset.has_value() != out_offset.has_value()) { + return op.emitOpError( + "Unexpected mismatch in replication between input and output " + "layouts"); + } + if (in_offset.has_value() && + (index_offset + *in_offset) % vreg_slice != *out_offset) { + return op.emitOpError("Not implemented: Only no-op tiles"); + } + } + + const std::array vreg_slice = + layout_in.vregSlice(ctx.target_shape); + SmallVector slice_tiled_starts(full_offsets); + *(slice_tiled_starts.end() - 2) = + (layout_in.offsets()[0].value_or(0) + *(full_offsets.end() - 2)) / + vreg_slice[0]; + *(slice_tiled_starts.end() - 1) = + (layout_in.offsets()[1].value_or(0) + *(full_offsets.end() - 1)) / + vreg_slice[1]; + layout_in.eraseImplicit(slice_tiled_starts); + SmallVector slice_tiled_limits(full_offsets); + for (int64_t i = 0; i < full_offsets.size() - layout_in.layout_rank(); ++i) { + slice_tiled_limits[i] += full_sizes[i]; + } + *(slice_tiled_limits.end() - 2) = + llvm::divideCeil(layout_in.offsets()[0].value_or(0) + + *(full_offsets.end() - 2) + *(full_sizes.end() - 2), + vreg_slice[0]); + *(slice_tiled_limits.end() - 1) = + llvm::divideCeil(layout_in.offsets()[1].value_or(0) + + *(full_offsets.end() - 1) + *(full_sizes.end() - 1), + vreg_slice[1]); + layout_in.eraseImplicit(slice_tiled_limits); + + OpBuilder builder(&op); + FAILUREOR_ASSIGN_OR_RETURN( + const xla::Array input_tiles, + disassemble(builder, layout_in, src_vector, ctx.target_shape)); + return input_tiles.Slice(slice_tiled_starts, slice_tiled_limits); +} + LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { + ImplicitLocOpBuilder builder(op.getLoc(), &op); vector::ExtractOp extract_op = cast(op); if (extract_op.hasDynamicPosition()) { return op.emitOpError("Not implemented: dynamic indices"); @@ -2584,32 +3360,58 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(layouts_out.size(), 1); TPU_ASSERT_OP(layouts_in.front().has_value()); const VectorLayout &layout_in = *layouts_in.front(); - if (layouts_out.front().has_value()) { - return op.emitOpError("Not implemented: Only scalar results supported"); - } if (layout_in.bitwidth() != 32) { return op.emitOpError( "Not implemented: Only 32-bit vector.extract supported"); } - if (layout_in.offsets() != LayoutOffsets{0, 0}) { - return op.emitOpError("Not implemented: Unsupported layout"); - } - ImplicitLocOpBuilder builder(op.getLoc(), &op); - for (int64_t i : extract_op.getStaticPosition()) { - if (i != 0) { - return op.emitOpError("Not implemented: Only 0 indices supported"); + const VectorType res_vty = + dyn_cast(extract_op.getResult().getType()); + if (res_vty != nullptr) { + TPU_ASSERT_OP(layouts_out.front().has_value()); + const VectorLayout &layout_out = *layouts_out.front(); + const int64_t num_indices = extract_op.getStaticPosition().size(); + const SmallVector sizes(num_indices, 1); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array dst_vregs, + vector_extract_slice_impl(ctx, *extract_op, sizes, + extract_op.getStaticPosition(), layout_in, + *layouts_out.front())); + // Squeeze leading singleton dimensions. 
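+    // The extract drops the leading `num_indices` dimensions, so the
+    // corresponding leading vreg-array dimensions are all 1 and can simply be
+    // reshaped away before reassembly.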
+ TPU_ASSERT_EQ_OP(res_vty.getRank(), + extract_op.getSourceVectorType().getRank() - num_indices); + TPU_ASSERT_OP( + llvm::all_of(toArrayRef(dst_vregs.dimensions()).take_front(num_indices), + [](const int64_t d) { return d == 1; })); + // Copy dims to temporary before passing to xla::Array::Reshape - it cannot + // take a pointer to its own data. + dst_vregs.Reshape(SmallVector( + toArrayRef(dst_vregs.dimensions()).drop_front(num_indices))); + op.replaceAllUsesWith( + assemble(builder, res_vty, layout_out, dst_vregs, ctx.target_shape) + .getOperation()); + op.erase(); + return success(); + } else { + for (int64_t i : extract_op.getStaticPosition()) { + if (i != 0) { + return op.emitOpError( + "Not implemented: Only 0 indices supported for scalar results"); + } } + if (layout_in.offsets() != LayoutOffsets{0, 0}) { + return op.emitOpError("Not implemented: Unsupported layout"); + } + FAILUREOR_ASSIGN_OR_RETURN( + const xla::Array vregs, + disassemble(builder, layout_in, extract_op.getVector(), + ctx.target_shape)); + TPU_ASSERT_GT_OP(vregs.num_elements(), 0); + extract_op.replaceAllUsesWith( + builder + .create(op.getLoc(), *vregs.data(), + ArrayRef{0, 0}) + .getResult()); } - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array vregs, - disassemble(builder, layout_in, extract_op.getVector(), - ctx.target_shape)); - TPU_ASSERT_GT_OP(vregs.num_elements(), 0); - extract_op.replaceAllUsesWith( - builder - .create(op.getLoc(), *vregs.data(), - ArrayRef{0, 0}) - .getResult()); extract_op.erase(); return success(); } @@ -2671,23 +3473,8 @@ LogicalResult vector_extract_strided_slice_rule( TPU_ASSERT_OP(layouts_out.front().has_value()); const VectorLayout &layout_in = *layouts_in.front(); const VectorLayout &layout_out = *layouts_out.front(); - if (!layout_in.hasNaturalTopology(ctx.target_shape)) { - return op.emitOpError("Not implemented: Unsupported input layout"); - } - if (layout_out != layout_in) { - return op.emitOpError("Not implemented: Unsupported output layout"); - } - OpBuilder builder(&op); - vector::ExtractStridedSliceOp extract_strided_slice_op = - cast(op); - const ArrayRef tiled_dims = - extract_strided_slice_op.getVector().getType().getShape().take_back(2); - if (tiled_dims[0] % layout_in.tiling()[0] != 0 || - tiled_dims[1] % layout_in.tiling()[1] != 0) { - return op.emitOpError( - "Not implemented: Extract strides slices only works with operands with " - "sizes that are multiples of the native tiling"); - } + ImplicitLocOpBuilder builder(op.getLoc(), &op); + auto extract_strided_slice_op = cast(op); auto I64ArrayToSmallVector = [&](const ArrayAttr array_attr) { return llvm::map_to_vector(array_attr, [](Attribute attr) { @@ -2695,37 +3482,18 @@ LogicalResult vector_extract_strided_slice_rule( }); }; - // We currently only support zero-offset, tile-aligned slices. This implies - // the output layout is merely a slice of the input layout, without needing to - // modify physical any of the vregs' layouts. 
- const SmallVector offsets = - I64ArrayToSmallVector(extract_strided_slice_op.getOffsets()); - for (const int64_t offset : ArrayRef(offsets).take_back(2)) { - if (offset != 0) { - return extract_strided_slice_op.emitOpError( - "Not implemented: Only tile-aligned slices supported"); - } - } - - const SmallVector slice_sizes = - I64ArrayToSmallVector(extract_strided_slice_op.getSizes()); - SmallVector slice_tiled_limits = - layout_in.tileArrayShape(slice_sizes, ctx.target_shape); - TPU_ASSERT_EQ_OP(slice_tiled_limits.size(), offsets.size()); - for (size_t i = 0; i < slice_tiled_limits.size(); ++i) { - slice_tiled_limits[i] += offsets[i]; - } FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array input_tiles, - disassemble(builder, layout_in, extract_strided_slice_op.getVector(), - ctx.target_shape)); - const xla::Array dst_tiles = - input_tiles.Slice(offsets, slice_tiled_limits); - const VectorType dst_ty = extract_strided_slice_op.getResult().getType(); - extract_strided_slice_op.replaceAllUsesWith( - assemble(builder, dst_ty, layout_out, dst_tiles, ctx.target_shape) - .getOperation()); - extract_strided_slice_op.erase(); + const xla::Array dst_vregs, + vector_extract_slice_impl( + ctx, *extract_strided_slice_op, + I64ArrayToSmallVector(extract_strided_slice_op.getSizes()), + I64ArrayToSmallVector(extract_strided_slice_op.getOffsets()), + layout_in, layout_out)); + op.replaceAllUsesWith(assemble(builder, + extract_strided_slice_op.getResult().getType(), + layout_out, dst_vregs, ctx.target_shape) + .getOperation()); + op.erase(); return success(); } @@ -2792,12 +3560,10 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, case vector::CombiningKind::ADD: neutral = builder.getF32FloatAttr(0); break; - case vector::CombiningKind::MAXNUMF: case vector::CombiningKind::MAXIMUMF: { // TODO(b/322836633): The semantics of maximumf don't match the lowering // for older TPU versions because older TPU versions don't respect the - // -0.0 vs +0.0 ordering. Keeping MAXNUMF for backward compatibility of - // serialized artifacts. + // -0.0 vs +0.0 ordering. 
neutral = builder.getFloatAttr( builder.getF32Type(), APFloat::getInf(APFloat::IEEEsingle(), /*Negative=*/true)); @@ -2891,7 +3657,6 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, case vector::CombiningKind::ADD: tpu_kind = tpu::ReductionKind::SUM; break; - case vector::CombiningKind::MAXNUMF: case vector::CombiningKind::MAXIMUMF: tpu_kind = tpu::ReductionKind::MAX; break; @@ -2989,87 +3754,61 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, using Tiling = std::array; const VectorLayout &layout_in = *layouts_in.front(); const VectorLayout &layout_out = *layouts_out.front(); + TPU_ASSERT_EQ_OP( + layout_in.bitwidth(), + layout_out.bitwidth()); // This should be guaranteed through MLIR + // verifier plus our layoutIsValidForValue check ImplicitLocOpBuilder builder(op.getLoc(), &op); auto shape_cast_op = cast(op); const VectorType src_ty = shape_cast_op.getSourceVectorType(); const ArrayRef src_shape = src_ty.getShape(); const VectorType dst_ty = shape_cast_op.getResultVectorType(); const ArrayRef dst_shape = dst_ty.getShape(); - const int layout_rank = layout_in.layout_rank(); bool no_op = false; - // TODO(tlongeri): It looks like this could probably be simplified by using - // VectorLayout::implicitShape() - if (layout_in == layout_out && src_ty.getShape().take_back(layout_rank) == - dst_ty.getShape().take_back(layout_rank)) { - no_op = true; - } else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && - layout_out.implicit_dim() == - VectorLayout::ImplicitDim::kSecondMinor && - layout_in.hasNativeTiling(ctx.target_shape) && - layout_in.tiling() == layout_out.tiling() && - layout_in.offsets() == layout_out.offsets() && - *(src_shape.end() - 1) == *(dst_shape.end() - 1) && - *(src_shape.end() - 2) == 1) { - no_op = true; - } else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && - layout_out.implicit_dim() == VectorLayout::ImplicitDim::kMinor && - layout_in.hasNaturalTopology(ctx.target_shape) && - layout_in.tiling() == layout_out.tiling() && - layout_in.offsets() == layout_out.offsets() && - src_shape == - ArrayRef(layout_out.implicitShape(dst_shape))) { - no_op = true; - } else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kMinor && - layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone && - layout_out.hasNaturalTopology(ctx.target_shape) && - layout_in.tiling() == layout_out.tiling() && - layout_in.offsets() == layout_out.offsets() && - dst_shape == - ArrayRef(layout_in.implicitShape(src_shape))) { + const std::array src_tiled_dims = + layout_in.getImplicitTiledDims(src_shape, 1); + const std::array dst_tiled_dims = + layout_out.getImplicitTiledDims(dst_shape, 1); + const std::array src_vreg_slice = + layout_in.vregSlice(ctx.target_shape); + const std::array dst_vreg_slice = + layout_out.vregSlice(ctx.target_shape); + if (layout_in.tiling() == layout_out.tiling() && + layout_in.offsets() == layout_out.offsets() && + src_tiled_dims == dst_tiled_dims) { no_op = true; } else if ( // Fold or unfold sublane dim, but keeping a whole number of // vregs. 
layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone && - layout_in.offsets() == LayoutOffsets{0, 0} && - layout_out.offsets() == LayoutOffsets{0, 0} && + layout_in.offsets()[0] == 0 && + layout_in.offsets() == layout_out.offsets() && layout_in.tiling() == layout_out.tiling() && - layout_in.tiling()[1] == ctx.target_shape[1] && *(dst_shape.end() - 1) == *(src_shape.end() - 1) && - *(dst_shape.end() - 2) % layout_in.tiling()[0] == 0 && - *(src_shape.end() - 2) % layout_in.tiling()[0] == 0) { + *(dst_shape.end() - 2) % dst_vreg_slice[0] == 0 && + *(src_shape.end() - 2) % src_vreg_slice[0] == 0) { no_op = true; } else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone && layout_in.offsets() == layout_out.offsets() && layout_in.offsets() == LayoutOffsets{0, 0} && - layout_in.tiling() == Tiling{1, ctx.target_shape[1]} && - layout_out.hasNaturalTopology(ctx.target_shape) && - *(dst_shape.end() - 1) != *(src_shape.end() - 1) && - *(dst_shape.end() - 1) == ctx.target_shape[1] && - *(dst_shape.end() - 2) % ctx.target_shape[0] == 0 && - *(src_shape.end() - 1) % - (ctx.target_shape[0] * ctx.target_shape[1]) == - 0 && - (*(src_shape.end() - 2) == 1 || - *(src_shape.end() - 2) % ctx.target_shape[0] == 0)) { - // Shapecast (..., m * 128) -> (..., 128). + layout_in.tiling()[0] == 1 && + layout_out.hasNativeTiling(ctx.target_shape) && + *(dst_shape.end() - 1) == dst_vreg_slice[1] && + *(dst_shape.end() - 2) % dst_vreg_slice[0] == 0 && + *(src_shape.end() - 1) % src_vreg_slice[1] == 0) { + // Shapecast (..., m * 128 * packing) -> (..., 128). no_op = true; } else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone && layout_in.offsets() == LayoutOffsets{0, 0} && layout_out.offsets() == LayoutOffsets{0, 0} && - layout_in.hasNaturalTopology(ctx.target_shape) && - layout_out.tiling() == Tiling{1, ctx.target_shape[1]} && - *(src_shape.end() - 1) != *(dst_shape.end() - 1) && - *(src_shape.end() - 1) == ctx.target_shape[1] && - *(src_shape.end() - 2) % ctx.target_shape[0] == 0 && - *(dst_shape.end() - 1) % - (ctx.target_shape[0] * ctx.target_shape[1]) == - 0 && - (*(dst_shape.end() - 2) == 1 || - *(dst_shape.end() - 2) % ctx.target_shape[0] == 0)) { - // Shapecast (..., 128) -> (..., m * 128). + layout_in.hasNativeTiling(ctx.target_shape) && + layout_out.tiling()[0] == 1 && + *(src_shape.end() - 1) == src_vreg_slice[1] && + *(src_shape.end() - 2) % src_vreg_slice[0] == 0 && + *(dst_shape.end() - 1) % dst_vreg_slice[1] == 0) { + // Shapecast (..., 128) -> (..., m * 128 * packing). no_op = true; } FAILUREOR_ASSIGN_OR_RETURN( @@ -3194,7 +3933,9 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, FAILUREOR_ASSIGN_OR_RETURN( const Tiling memref_tiling, getMemRefTiling(store_op.getBase(), ctx.target_shape)); - if (to_store_layout.tiling() != memref_tiling) { + if (memref_tiling != to_store_layout.tiling() && + !(memref_tiling[0] == 1 && to_store_layout.tiling()[0] == 1 && + memref_tiling[1] % to_store_layout.tiling()[1] == 0)) { // Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes). // TODO(b/295393167): need to support strided store for bitwidth < 32. 
if (to_store_layout.bitwidth() != 32 || @@ -3387,11 +4128,6 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, "Not implemented: Non-native or offset layout unsupported"); } const int64_t transpose_unit_size = ctx.target_shape[1]; - for (const int64_t s : src_ty.getShape().take_back(2)) { - if (s % transpose_unit_size != 0) { - return transpose_op->emitOpError("Not implemented: Padded transpose"); - } - } if (ctx.hardware_generation < 4 && layout_in.bitwidth() != 32) { return transpose_op->emitOpError( "Not implemented: TPUs before v4 only support 32-bit transposes"); @@ -3430,8 +4166,8 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, src_slice_ends.append(incremented_batch_idx.begin(), incremented_batch_idx.end()); src_slice_ends.append({(src_row + 1) * vregs_per_tile, src_col_end}); - xla::Array src_tile_vregs = - src_vregs.Slice(src_slice_starts, src_slice_ends); + xla::Array src_tile_vregs = src_vregs.Slice( + src_slice_starts, src_slice_ends, /*out_of_bounds_ok=*/true); // Drop leading singleton (batch) dimensions to have a shape that conforms // with the vreg array shape specified by layout_in, as expected by assemble src_tile_vregs.Reshape( @@ -3462,12 +4198,12 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, const ArrayRef batch_sizes = dst_ty.getShape().take_front(num_batch_dims); SmallVector batch_idx(num_batch_dims); + const int64_t tile_rows = + xla::CeilOfRatio(*(src_ty.getShape().end() - 2), transpose_unit_size); + const int64_t num_col_tiles = + xla::CeilOfRatio(*(src_ty.getShape().end() - 1), transpose_unit_size); do { - const int64_t tile_rows = - *(src_ty.getShape().end() - 2) / transpose_unit_size; for (int64_t src_row = 0; src_row < tile_rows; ++src_row) { - const int64_t num_col_tiles = - *(src_ty.getShape().end() - 1) / transpose_unit_size; if (can_batch) { const int64_t num_batch_tiles = num_col_tiles / 2; for (int64_t src_col = 0; src_col < num_batch_tiles; ++src_col) { @@ -3494,6 +4230,43 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, transpose_op->erase(); return success(); } + +LogicalResult prng_random_bits_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + TPU_ASSERT_EQ_OP(layouts_in.size(), 0); + TPU_ASSERT_EQ_OP(layouts_out.size(), 1); + TPU_ASSERT_OP(layouts_out.front().has_value()); + + const VectorLayout &layout_out = *layouts_out.front(); + tpu::PRNGRandomBitsOp rng_op = cast(op); + if (layout_out != VectorLayout(32, {0, 0}, ctx.target_shape, + VectorLayout::ImplicitDim::kNone)) { + return op.emitOpError( + "Unsupported output layout for ") << rng_op->getName(); + } + OpBuilder builder(op.getContext()); + builder.setInsertionPointAfter(&op); + + VectorType vty = rng_op.getResult().getType(); + TPU_ASSERT_OP(vty.getElementType().isInteger()); + // Only 32-bit output supported currently. 
+ TPU_ASSERT_OP(vty.getElementType().getIntOrFloatBitWidth() == 32); + xla::Array tiles( + layout_out.tileArrayShape(vty.getShape(), ctx.target_shape)); + VectorType tile_ty = VectorType::get(ctx.target_shape, vty.getElementType()); + tiles.Each([&](absl::Span tile_idxs, Value * v) { + *v = builder.create(op.getLoc(), tile_ty); + }); + const RollVectorsOp roll_vectors_op = + assemble(builder, vty, layout_out, tiles, ctx.target_shape); + rng_op->replaceUsesWithIf(roll_vectors_op, [&](OpOperand &operand) { + return operand.getOwner() != roll_vectors_op; + }); + rng_op->erase(); + return success(); +} + const llvm::StringMap &rules() { static auto rules = new llvm::StringMap{ {arith::ConstantOp::getOperationName(), arith_constant_rule}, @@ -3503,20 +4276,26 @@ const llvm::StringMap &rules() { {arith::TruncIOp::getOperationName(), arith_trunci_rule}, {func::ReturnOp::getOperationName(), func_return_rule}, {scf::ForOp::getOperationName(), scf_for_rule}, + {scf::WhileOp::getOperationName(), scf_while_rule}, + {scf::ConditionOp::getOperationName(), scf_condition_rule}, {scf::IfOp::getOperationName(), scf_if_rule}, {scf::YieldOp::getOperationName(), scf_yield_rule}, {tpu::RotateOp::getOperationName(), tpu_rotate_rule}, + {tpu::DynamicRotateOp::getOperationName(), tpu_dynamic_rotate_rule}, {tpu::ConcatenateOp::getOperationName(), tpu_concatenate_rule}, {tpu::IotaOp::getOperationName(), tpu_iota_rule}, {tpu::GatherOp::getOperationName(), tpu_gather_rule}, {tpu::LoadOp::getOperationName(), tpu_load_rule}, + {tpu::StoreOp::getOperationName(), tpu_store_rule}, + {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule}, + {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule}, {tpu::MatmulOp::getOperationName(), tpu_matmul_rule}, {tpu::RegionOp::getOperationName(), tpu_region_rule}, {tpu::RepeatOp::getOperationName(), tpu_repeat_rule}, - {tpu::StoreOp::getOperationName(), tpu_store_rule}, {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule}, {tpu::TraceOp::getOperationName(), tpu_trace_rule}, {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule}, + {tpu::PRNGRandomBitsOp::getOperationName(), prng_random_bits_rule}, {vector::BroadcastOp::getOperationName(), vector_broadcast_rule}, {vector::ContractionOp::getOperationName(), vector_contract_rule}, {vector::ExtractOp::getOperationName(), vector_extract_rule}, @@ -3535,9 +4314,16 @@ const llvm::StringMap &rules() { RollVectorsOp assemble(OpBuilder &builder, VectorType vty, const VectorLayout &layout, const xla::Array &vals, - const std::array target_shape) { - CHECK(vals.dimensions() == - layout.tileArrayShape(vty.getShape(), target_shape)); + const std::array target_shape, + const bool use_implicit_shape) { + // TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of + // having `tileArrayShape` and `tileArrayImplicitShape`. + SmallVector vreg_array_shape = + layout.tileArrayImplicitShape(vty.getShape(), target_shape); + if (!use_implicit_shape) { + layout.eraseImplicit(vreg_array_shape); + } + CHECK(vals.dimensions() == vreg_array_shape); CHECK_GT(vals.num_elements(), 0); Location loc = vals.begin()->getLoc(); auto op = @@ -3558,8 +4344,8 @@ RollVectorsOp assemble(OpBuilder &builder, VectorType vty, // An ndarray of MLIR values representing the tiling of val given by layout. 
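+// If use_implicit_shape is true, the returned vreg array keeps the layout's
+// implicit singleton dims in its shape; otherwise they are erased.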
FailureOr> disassemble( OpBuilder &builder, const VectorLayout &layout, - const TypedValue val, - const std::array target_shape) { + const TypedValue val, const std::array target_shape, + const bool use_implicit_shape) { // TODO(tlongeri): Remove default const auto vty = val.getType(); const auto op_result = dyn_cast(val); if (op_result == nullptr) { @@ -3573,8 +4359,13 @@ FailureOr> disassemble( TPU_ASSERT_LOC(val.getLoc(), def_layout.has_value()); TPU_ASSERT_LOC(val.getLoc(), def_layout->generalizes(layout, vty.getShape(), target_shape)); + // TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of + // having `tileArrayShape` and `tileArrayImplicitShape`. SmallVector layout_shape = - layout.tileArrayShape(vty.getShape(), target_shape); + layout.tileArrayImplicitShape(vty.getShape(), target_shape); + if (!use_implicit_shape) { + layout.eraseImplicit(layout_shape); + } if (auto roll_vectors_op = dyn_cast(op)) { return XlaArrayFromShapeAndValues(layout_shape, roll_vectors_op->getOperands()); @@ -3625,6 +4416,7 @@ Value selectTilesFromRotatedRowVregs( Value left_partial_vreg = selectTilesFromRotatedRowVregs( builder, rotated_row_vregs, start_src_col, mid_src_col, first_dst_tile_sublane_offset, dst_layout, target_shape); + Location loc = left_partial_vreg.getLoc(); const int64_t left_tiles_count = mid_src_col - start_src_col + 1; const int64_t right_first_dst_tile_sublane_offset = @@ -3637,12 +4429,23 @@ Value selectTilesFromRotatedRowVregs( right_first_dst_tile_sublane_offset, dst_layout, target_shape); const IntegerType i1 = builder.getI1Type(); - const auto mask_vreg_ty = - dst_layout.packing() > 1 - ? VectorType::get(ArrayRef{target_shape[0], target_shape[1], - dst_layout.packing()}, - i1) - : VectorType::get(target_shape, i1); + // We never need to select partial sublanes, even for packed data. + const auto mask_vreg_ty = VectorType::get(target_shape, i1); + auto i32_vreg = VectorType::get(target_shape, builder.getI32Type()); + auto select_32bit = [&](Value sublane_mask, Value left, Value right) { + // Always do the selects on 32-bit granularity for maximum HW compatibility. 
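To make the comment above concrete, here is a standalone numpy model (not the Mosaic implementation) of a per-sublane select on packed 16-bit vregs done at 32-bit granularity: reinterpret both operands as 32-bit words, select whole words, then reinterpret back. Toy 4x4 vregs stand in for real (8, 128) ones.

    import numpy as np

    left = np.arange(16, dtype=np.int16).reshape(4, 4)            # toy packed 16-bit vreg
    right = (100 + np.arange(16, dtype=np.int16)).reshape(4, 4)
    sublane_mask = np.array([True, True, False, False])           # first two sublanes come from `left`

    left32 = left.view(np.int32)                                  # each 32-bit word packs two i16 values
    right32 = right.view(np.int32)
    out32 = np.where(sublane_mask[:, None], left32, right32)      # select whole 32-bit words only
    out = out32.view(np.int16)                                    # reinterpret back to the packed type
    print(out)

Because the mask never splits a 32-bit word, packed subelements are never torn apart, which is why the packing-sized mask vreg type could be dropped above.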
+ Type vreg_ty = left.getType(); + if (dst_layout.packing() != 1) { + left = builder.create(loc, i32_vreg, left); + right = builder.create(loc, i32_vreg, right); + } + Value result = + builder.create(loc, sublane_mask, left, right); + if (dst_layout.packing() != 1) { + result = builder.create(loc, vreg_ty, result); + } + return result; + }; auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder, left_partial_vreg.getLoc()); @@ -3670,9 +4473,7 @@ Value selectTilesFromRotatedRowVregs( boundIdxConst(0)}, ArrayRef{boundIdxConst(right_first_dst_tile_sublane_offset), boundIdxConst(target_shape[1])}); - return builder.create(left_partial_vreg.getLoc(), - sublanes_mask, left_partial_vreg, - right_partial_vreg); + return select_32bit(sublanes_mask, left_partial_vreg, right_partial_vreg); } auto sublanes_mask = builder.create( @@ -3681,9 +4482,7 @@ Value selectTilesFromRotatedRowVregs( boundIdxConst(0)}, ArrayRef{boundIdxConst(first_dst_tile_sublane_offset), boundIdxConst(target_shape[1])}); - return builder.create(left_partial_vreg.getLoc(), - sublanes_mask, right_partial_vreg, - left_partial_vreg); + return select_32bit(sublanes_mask, right_partial_vreg, left_partial_vreg); } // Retiles across vregs to match the destination layout when the sublane tiling @@ -3995,7 +4794,7 @@ FailureOr> relayout( *(src_tiles.dimensions().end() - 2) == 1)) && dst.offsets()[1] == 0 && src.tiling() == std::array{1, 128} && dst.tiling() == std::array{8, 128}) { - xla::Array src_tiles_retiled( + xla::Array src_tiles_retiled( dst.tileArrayShape(vty.getShape(), target_shape)); src_tiles_retiled.Each([&](absl::Span idx, Value *tile) { for (int dst_sl_idx = 0; dst_sl_idx < 8; ++dst_sl_idx) { @@ -4053,73 +4852,151 @@ FailureOr> relayout( }); src = dst; src_tiles = std::move(src_tiles_retiled); - } else if ( // TODO(b/265133506): Generalize retiling to general 16-bit types - // (might need to use a different unpacking op). - // (8,128) -> (16,128) tiling change for packed 16-bit types. + } else if ( // TODO(b/265133506): Generalize retiling. + // (8,128) -> (8 * packing,128) tiling change for packed type. src.implicit_dim() == VectorLayout::ImplicitDim::kNone && - dst.implicit_dim() == VectorLayout::ImplicitDim::kNone && - vty.getElementTypeBitWidth() == 16 && src.offsets() == dst.offsets() && + dst.implicit_dim() == VectorLayout::ImplicitDim::kNone && bitwidth < 32 && + 32 % bitwidth == 0 && src.offsets() == dst.offsets() && src.tiling() == std::array{8, 128} && - dst.tiling() == std::array{16, 128}) { + dst.tiling() == std::array{8 * dst.packing(), 128}) { const VectorLayout new_src(src.bitwidth(), src.offsets(), dst.tiling()); xla::Array src_tiles_retiled( new_src.tileArrayShape(vty.getShape(), target_shape)); + int vty_packing = dst.packing(); + VectorType vreg_x32 = + vty.getElementType().isSignlessInteger() + ? 
VectorType::get(target_shape, builder.getI32Type()) + : VectorType::get(target_shape, builder.getF32Type()); src_tiles_retiled.Each([&](absl::Span idx, Value *tile) { + const int vreg_part = idx.back() % vty_packing; + SmallVector parts; + parts.reserve(vty_packing); SmallVector src_idx(idx.begin(), idx.end()); - src_idx[src_idx.size() - 2] *= 2; - src_idx[src_idx.size() - 1] /= 2; - Value src_row1 = src_tiles(src_idx); - if (src_idx[src_idx.size() - 2] + 1 < - src_tiles.dim(src_tiles.num_dimensions() - 2)) { - ++src_idx[src_idx.size() - 2]; + src_idx[src_idx.size() - 2] *= vty_packing; + src_idx[src_idx.size() - 1] /= vty_packing; + for (int i = 0; i < vty_packing; ++i) { + parts.push_back(builder.create( + v.getLoc(), vreg_x32, src_tiles(src_idx), vreg_part)); + if (src_idx[src_idx.size() - 2] < + src_tiles.dim(src_tiles.num_dimensions() - 2) - 1) { + ++src_idx[src_idx.size() - 2]; + } } - Value src_row2 = src_tiles(src_idx); - const int vreg_part = idx[idx.size() - 1] % 2; - - VectorType vreg_x32 = - vty.getElementType().isSignlessInteger() - ? VectorType::get(target_shape, builder.getI32Type()) - : VectorType::get(target_shape, builder.getF32Type()); - auto half_row1 = builder.create( - v.getLoc(), vreg_x32, src_row1, vreg_part); - auto half_row2 = builder.create( - v.getLoc(), vreg_x32, src_row2, vreg_part); *tile = builder.create( - v.getLoc(), src_row1.getType(), ValueRange{half_row1, half_row2}); + v.getLoc(), src_tiles.begin()->getType(), parts, + tpu::PackFormat::kCompressed); }); src = new_src; src_tiles = std::move(src_tiles_retiled); - } else if ( // (8,128) -> (32,128) tiling change for packed 8-bit integers. + } else if ( // Handle retiling from (1, 128 * packing) to (packing, 128) for + // packed data. + // We do compressed unpacking followed by interleaved packing. + // TODO(tlongeri): This can be used as a first step before using + // a generalized retiling where we only move sublanes around + // (without packing/unpacking). + // TODO(tlongeri): Interleaved unpacking followed by interleaved + // packing (but with different pairings) might also be + // interesting if the next step is a retile, since we can also + // match corresponding elements without shifting. It's just that + // the tiles are not adjacent (no contiguous vreg slice). src.implicit_dim() == VectorLayout::ImplicitDim::kNone && - dst.implicit_dim() == VectorLayout::ImplicitDim::kNone && - vty.getElementType() == builder.getI8Type() && - src.offsets() == dst.offsets() && - src.tiling() == std::array{8, 128} && - dst.tiling() == std::array{32, 128}) { - const VectorLayout new_src(src.bitwidth(), src.offsets(), dst.tiling()); + dst.implicit_dim() == VectorLayout::ImplicitDim::kNone && bitwidth < 32 && + 32 % bitwidth == 0 && src.offsets() == dst.offsets() && + src.tiling() == std::array{1, 128 * packing} && + dst.tiling() == std::array{packing, 128}) { + // To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of + // 4 sublanes and 2 lanes (this is convenient for to keep the example small + // yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling. + // + // The vreg slice is 1 x 16, that is, the vreg contains the data for a + // 1 x 16 window of the logical shape. + // + // [a b c d e f g h i j k l m n o p] -> vreg 1 + // [A B C D E F G H I J K L M N O P] -> vreg 2 + // + // Note: we support multiple vregs per row of the logical shape, but we use + // one here just to keep the example small. 
+ // + // When we do a compressed unpack, the resulting vregs effectively have a + // tiling of (1, 2) and cover a vreg slice of 1 x 8 logical elements. + // + // [a b c d e f g h] -> vreg 1, part 1 [i j k l m n o p] -> vreg 1, part 2 + // [A B C D E F G H] -> vreg 2, part 1 [I J K L M N O P] -> vreg 2, part 2 + // + // It is clear that if combine vreg 1, part 1 and vreg 2, part 1 we get data + // that covers a 2 x 8 vreg slice. Note, however, that we will have to mind + // the internal ordering of the vreg. + // + // [a b c d e f g h [i j k l m n o p + // A B C D E F G H] -> new vreg 1 I J K L M N O P] -> new vreg 2 + // + // To see if we can get the right internal ordering that we need for (2, 2) + // tiling, let's break new vreg 1 into (1, 2) rows, which correspond to + // sublanes when unpacked and half-sublanes when packed. + // + // [(a b) (c d) (e f) (g h) + // (A B) (C D) (E F) (G H)] + // + // The sublane order for the vreg parts is [(a b) (c d) ...] for vreg 1, + // part 1 and [(A B) (C D) ...] for vreg 2, part 1. + // + // The desired half-sublane order, for packed (2, 2) tiling, is + // [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before + // moving to the next one. This is exactly an interleaving of the sublanes + // of the vreg parts. + const VectorLayout new_src(src.bitwidth(), src.offsets(), + std::array{packing, 128}); xla::Array src_tiles_retiled( new_src.tileArrayShape(vty.getShape(), target_shape)); - VectorType vreg_i32 = - getNativeVregType(builder.getI32Type(), target_shape).value(); + const VectorType vreg_x32 = + vty.getElementType().isSignlessInteger() + ? VectorType::get(target_shape, builder.getI32Type()) + : VectorType::get(target_shape, builder.getF32Type()); src_tiles_retiled.Each([&](absl::Span idx, Value *tile) { - const int vreg_part = idx.back() % 4; - std::array parts; - SmallVector src_idx(idx.begin(), idx.end()); - src_idx[src_idx.size() - 2] *= 4; - src_idx[src_idx.size() - 1] /= 4; - for (int i = 0; i < 4; ++i) { - parts[i] = builder.create( - v.getLoc(), vreg_i32, src_tiles(src_idx), vreg_part); - if (src_idx[src_idx.size() - 2] < - src_tiles.dim(src_tiles.num_dimensions() - 2) - 1) { - ++src_idx[src_idx.size() - 2]; - } + SmallVector parts; + parts.reserve(packing); + SmallVector src_idx(toArrayRef(idx)); + *(src_idx.end() - 2) *= packing; + const int64_t vreg_part = *(src_idx.end() - 1) % packing; + *(src_idx.end() - 1) /= packing; + for (int i = 0; i < packing; ++i) { + parts.push_back(builder.create( + v.getLoc(), vreg_x32, src_tiles(src_idx), vreg_part)); + if (*(src_idx.end() - 2) < *(src_tiles.dimensions().end() - 2)) { + ++*(src_idx.end() - 2); + } // The rest is padding, so just pick any of the input parts (but not + // an arbitrary vreg so we don't add an extra dependency). } *tile = builder.create( - v.getLoc(), src_tiles.begin()->getType(), parts); + v.getLoc(), src_tiles.begin()->getType(), parts, + tpu::PackFormat::kInterleaved); }); src = new_src; src_tiles = std::move(src_tiles_retiled); + } else if ( // Handle retiling from (8, 128, -2) to (8, 128) for 32-bit data. + // This drops the implicit second minor dimension. 
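Returning to the compressed-unpack / interleaved-pack diagram above: a standalone numpy model (toy sizes, not the Mosaic code) of the two sublane orderings it distinguishes. Here part1 and part2 play the roles of "vreg 1, part 1" and "vreg 2, part 1", with one row per sublane:

    import numpy as np

    part1 = np.array(["ab", "cd", "ef", "gh"])                # vreg 1, part 1 after compressed unpack
    part2 = np.array(["AB", "CD", "EF", "GH"])                # vreg 2, part 1 after compressed unpack

    compressed = np.concatenate([part1, part2])               # [ab cd ef gh AB CD EF GH]
    interleaved = np.stack([part1, part2], axis=1).ravel()    # [ab AB cd CD ef EF gh GH]
    print(interleaved)

The interleaved result is exactly the half-sublane order [(a b) (A B) (c d) (C D) ...] called for above, which is why this path packs with tpu::PackFormat::kInterleaved rather than kCompressed.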
+ src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor && + dst.implicit_dim() == VectorLayout::ImplicitDim::kNone && + src.bitwidth() == 32 && src.offsets() == dst.offsets() && + src.offsets() == LayoutOffsets{0, 0} && src.tiling() == dst.tiling() && + src.tiling() == std::array{8, 128}) { + xla::Array src_tiles_retiled( + dst.tileArrayShape(vty.getShape(), target_shape)); + src_tiles_retiled.Each( + [&](const absl::Span idx, Value *tile) { + for (int dst_sl_idx = 0; dst_sl_idx < 8; ++dst_sl_idx) { + SmallVector src_idx(idx.begin(), idx.end()); + auto second_minor_idx = idx.size() - 2; + src_idx[second_minor_idx] = 8 * idx[second_minor_idx] + dst_sl_idx; + if (src_idx[second_minor_idx] >= src_tiles.dim(second_minor_idx)) { + break; + } + *tile = copy_one_sublane(builder, src_tiles(src_idx), 0, *tile, + dst_sl_idx, target_shape); + } + }); + src = dst; + src_tiles = std::move(src_tiles_retiled); } if (isSupportedReducedSublanesRetile(src, dst, target_shape)) { @@ -4184,8 +5061,8 @@ FailureOr> relayout( v.getLoc(), bits_vreg_ty, DenseElementsAttr::get(bits_vreg_ty, shift_bits)); dst_tiles.Each([&](absl::Span /*idx*/, Value *tile) { - auto bit_tile = - builder.create(v.getLoc(), bits_vreg_ty, *tile); + auto bit_tile = builder.create( + v.getLoc(), bits_vreg_ty, *tile); Operation *shift_tile; if (subelem_diff > 0) { shift_tile = @@ -4197,7 +5074,7 @@ FailureOr> relayout( } *tile = builder .create(v.getLoc(), tile->getType(), - shift_tile->getResult(0)) + shift_tile->getResult(0)) .getResult(); return absl::OkStatus(); }); @@ -4216,38 +5093,56 @@ FailureOr> relayout( return emitError(v.getLoc(), "Not implemented: Both columns and rows are shifted"); } - if (col_diff < 0) { - return emitError(v.getLoc(), "Not implemented: Shifts to the left"); - } if (bitwidth != 32 || tiling != target_shape) { return emitError(v.getLoc(), "Not implemented: Only 32-bit column shifts for " "native layouts supported"); } - const int64_t sublane_diff = col_diff; TPU_ASSERT_GE_LOC(v.getLoc(), src_tiles.num_dimensions(), 1); std::optional maybe_create_mask; - if (src_tiles.dimensions()[src_tiles.num_dimensions() - 1] > 1) { + if (*(src_tiles.dimensions().end() - 1) > 1) { + int64_t lane_start, lane_end; + if (col_diff > 0) { + lane_start = 0; + lane_end = col_diff; + } else { // col_diff < 0 + lane_start = target_shape[1] + col_diff; + lane_end = target_shape[1]; + } auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder, v.getLoc()); maybe_create_mask = builder.create( v.getLoc(), VectorType::get(target_shape, builder.getI1Type()), - ValueRange{boundIdxConst(0), boundIdxConst(0)}, + ValueRange{boundIdxConst(0), boundIdxConst(lane_start)}, ValueRange{boundIdxConst(target_shape[0]), - boundIdxConst(col_diff)}); + boundIdxConst(lane_end)}); } - src_tiles.Each([&](absl::Span idx, Value tile) { - Value rot_tile = - builder - .create(v.getLoc(), tile, - /*amount=*/sublane_diff, - /*dimension=*/1, /*stride=*/nullptr, - /*stride_dimension=*/nullptr) - .getResult(); - if (idx[idx.size() - 1] != 0) { - SmallVector prev_idx(idx.begin(), idx.end()); - --prev_idx[idx.size() - 1]; - Value prev_rot_tile = dst_tiles(prev_idx); + src_tiles.Each([&](absl::Span idx, Value *tile) { + *tile = builder + .create(v.getLoc(), *tile, + /*amount=*/col_diff < 0 + ? 
target_shape[1] + col_diff + : col_diff, + /*dimension=*/1, /*stride=*/nullptr, + /*stride_dimension=*/nullptr) + .getResult(); + }); + src_tiles.Each([&](absl::Span idx, Value rot_tile) { + Value prev_rot_tile; + if (col_diff > 0) { + if (*(idx.end() - 1) != 0) { + SmallVector prev_idx(idx.begin(), idx.end()); + --*(prev_idx.end() - 1); + prev_rot_tile = src_tiles(prev_idx); + } + } else { // col_diff < 0 + if (*(idx.end() - 1) != *(src_tiles.dimensions().end() - 1) - 1) { + SmallVector prev_idx(idx.begin(), idx.end()); + ++*(prev_idx.end() - 1); + prev_rot_tile = src_tiles(prev_idx); + } + } + if (prev_rot_tile != nullptr) { rot_tile = builder.create( v.getLoc(), maybe_create_mask->getResult(), prev_rot_tile, rot_tile); @@ -4370,12 +5265,14 @@ struct ApplyVectorLayoutPass : public impl::ApplyVectorLayoutPassBase { ApplyVectorLayoutPass(int hardware_generation_, int lane_count_, int sublane_count_, int mxu_contracting_size_, - int mxu_noncontracting_size_) { + int mxu_noncontracting_size_, + int max_sublanes_in_scratch_) { hardware_generation = hardware_generation_; sublane_count = sublane_count_; lane_count = lane_count_; mxu_contracting_size = mxu_contracting_size_; mxu_noncontracting_size = mxu_noncontracting_size_; + max_sublanes_in_scratch = max_sublanes_in_scratch_; } void runOnOperation() override { // Fail if hardware_generation has not been set from the default value. @@ -4387,7 +5284,8 @@ struct ApplyVectorLayoutPass RewriteContext ctx{func, hardware_generation, {sublane_count, lane_count}, - {mxu_contracting_size, mxu_noncontracting_size}}; + {mxu_contracting_size, mxu_noncontracting_size}, + max_sublanes_in_scratch}; if (failed(applyLayoutFunc(ctx, func))) { signalPassFailure(); return; @@ -4397,10 +5295,11 @@ struct ApplyVectorLayoutPass std::unique_ptr> createApplyVectorLayoutPass( int hardware_generation, int lane_count, int sublane_count, - int mxu_contracting_size, int mxu_noncontracting_size) { + int mxu_contracting_size, int mxu_noncontracting_size, + int max_sublanes_in_scratch) { return std::make_unique( hardware_generation, lane_count, sublane_count, mxu_contracting_size, - mxu_noncontracting_size); + mxu_noncontracting_size, max_sublanes_in_scratch); } } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h index aaaa96fe5a0e..547a8a00c10c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h @@ -21,18 +21,22 @@ struct RewriteContext { const int hardware_generation; const std::array target_shape = {8, 128}; const std::array mxu_shape = {128, 128}; + const int max_sublanes_in_scratch = 0; MLIRContext *getMLIRContext() { return func.getContext(); } }; +// TODO(tlongeri): Remove default values for use_implicit_shape. RollVectorsOp assemble(OpBuilder &builder, VectorType vty, const VectorLayout &layout, const xla::Array &vals, - std::array target_shape); + std::array target_shape, + bool use_implicit_shape = false); FailureOr> disassemble(OpBuilder &builder, const VectorLayout &layout, TypedValue val, - std::array target_shape); + std::array target_shape, + bool use_implicit_shape = false); // Rewrites the operation according to its layout annotations. 
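The column-shift relayout above now also supports negative col_diff (shifts to the left): each vreg is rotated along its lanes, and the wrapped-around lanes are then blended in from the neighboring vreg under a lane mask. A standalone numpy model with a toy lane count, for illustration only:

    import numpy as np

    lanes = 8                                           # toy lane count; real vregs have 128 lanes
    v0, v1 = np.arange(lanes), 100 + np.arange(lanes)   # two lane-adjacent vregs of one row
    col_diff = -3                                       # shift left by three lanes

    amount = lanes + col_diff if col_diff < 0 else col_diff
    rot0, rot1 = np.roll(v0, amount), np.roll(v1, amount)

    # For a left shift, lanes [lanes + col_diff, lanes) spill over from the *next* vreg.
    mask = np.arange(lanes) >= lanes + col_diff
    out0 = np.where(mask, rot1, rot0)
    print(out0)                                         # [  3   4   5   6   7 100 101 102]

For a positive col_diff the mask instead covers lanes [0, col_diff) and the spill-over comes from the previous vreg, matching the pre-existing right-shift behaviour.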
// diff --git a/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc index 1eb7c12a828a..5478c64f9944 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc @@ -22,12 +22,9 @@ limitations under the License. #include "llvm/ADT/StringMap.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Types.h" @@ -55,9 +52,11 @@ rule_type as_generic_rule(void (*rule)(Op)) { void assertIsValidSubwindow(Operation *op, mlir::ValueRange base_indices, ArrayRef window_shape, - ArrayRef full_shape) { + ArrayRef full_shape, + ArrayRef strides = {}) { if (base_indices.size() != window_shape.size() || - base_indices.size() != full_shape.size()) { + base_indices.size() != full_shape.size() || + (!strides.empty() && base_indices.size() != strides.size())) { return; // Malformed op. } if (base_indices.empty()) { @@ -68,14 +67,15 @@ void assertIsValidSubwindow(Operation *op, mlir::ValueRange base_indices, for (auto [dim, access] : llvm::enumerate(llvm::zip(base_indices, window_shape, full_shape))) { auto [idx, size, bound] = access; + int64_t stride = strides.empty() ? 1 : strides[dim]; Value positive = builder.create( arith::CmpIPredicate::sge, idx, builder.create(builder.getIntegerAttr(idx_type, 0))); Value in_bounds = builder.create( - arith::CmpIPredicate::sle, + arith::CmpIPredicate::slt, builder.create( idx, builder.create( - builder.getIntegerAttr(idx_type, size))), + builder.getIntegerAttr(idx_type, (size - 1) * stride))), builder.create( builder.getIntegerAttr(idx_type, bound))); std::string msg; @@ -107,6 +107,21 @@ void tpu_memref_slice_rule(tpu::MemRefSliceOp op) { /*full_shape=*/op.getMemRef().getType().getShape()); } +void tpu_strided_load_rule(tpu::StridedLoadOp op) { + assertIsValidSubwindow(op, op.getIndices(), + /*window_shape=*/op.getResult().getType().getShape(), + /*full_shape=*/op.getBase().getType().getShape(), + /*strides=*/op.getStrides()); +} + +void tpu_strided_store_rule(tpu::StridedStoreOp op) { + assertIsValidSubwindow( + op, op.getIndices(), + /*window_shape=*/op.getValueToStore().getType().getShape(), + /*full_shape=*/op.getBase().getType().getShape(), + /*strides=*/op.getStrides()); +} + const llvm::StringMap &rules() { static auto rules = new llvm::StringMap{ // TODO: tpu::LoadOp, tpu::StoreOp @@ -114,6 +129,10 @@ const llvm::StringMap &rules() { {vector::StoreOp::getOperationName(), as_generic_rule(vector_store_rule)}, {tpu::MemRefSliceOp::getOperationName(), as_generic_rule(tpu_memref_slice_rule)}, + {tpu::StridedLoadOp::getOperationName(), + as_generic_rule(tpu_strided_load_rule)}, + {tpu::StridedStoreOp::getOperationName(), + as_generic_rule(tpu_strided_store_rule)}, }; return *rules; } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index 5efd4496d1a9..54add5fe469e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -23,6 +23,7 @@ 
#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/layout.h" @@ -33,6 +34,23 @@ namespace mlir::tpu { #define GEN_PASS_DEF_INFERMEMREFLAYOUTPASS #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" +SmallVector ComputeTileStrides(MemRefType memref_ty, + int64_t leading_tile_rows) { + SmallVector tile_strides(memref_ty.getRank()); + int64_t stride = 1; + for (int i = memref_ty.getRank() - 1; i >= 0; --i) { + tile_strides[i] = stride; + if (i == memref_ty.getRank() - 1) { + stride *= llvm::divideCeil(memref_ty.getShape()[i], 128); + } else if (i == memref_ty.getRank() - 2) { + stride *= llvm::divideCeil(memref_ty.getShape()[i], leading_tile_rows); + } else { + stride *= memref_ty.getShape()[i]; + } + } + return tile_strides; +} + // Returns the number of 128-element groups in a tile. // // Arguments: @@ -55,7 +73,8 @@ int getTilingFactor(const int num_128s, const int hardware_generation, } FailureOr inferLayout(MemRefType memref_ty, - const int hardware_generation) { + const int hardware_generation, + int64_t leading_tile_rows = 0) { if (auto tiled_layout_attr = dyn_cast(memref_ty.getLayout())) { return tiled_layout_attr; @@ -91,11 +110,14 @@ FailureOr inferLayout(MemRefType memref_ty, } return TiledLayoutAttr::get(memref_ty.getContext(), tiles, {1}); } + // memref.getRank() > 1 const ArrayRef shape = memref_ty.getShape(); const int64_t second_minor = shape[shape.size() - 2]; - const int64_t leading_tile_rows = - getTilingFactor(second_minor, hardware_generation, bitwidth); + if (leading_tile_rows == 0) { + leading_tile_rows = + getTilingFactor(second_minor, hardware_generation, bitwidth); + } SmallVector tiles{xla::Tile({leading_tile_rows, 128})}; if (bitwidth != 32) { if (!llvm::has_single_bit(bitwidth) || bitwidth > 32) { @@ -105,19 +127,7 @@ FailureOr inferLayout(MemRefType memref_ty, } tiles.push_back(xla::Tile({32 / bitwidth, 1})); } - SmallVector tile_strides(memref_ty.getRank()); - int64_t stride = 1; - for (int i = memref_ty.getRank() - 1; i >= 0; --i) { - tile_strides[i] = stride; - if (i == memref_ty.getRank() - 1) { - stride *= (memref_ty.getShape()[i] + 127) / 128; - } else if (i == memref_ty.getRank() - 2) { - stride *= (memref_ty.getShape()[i] + leading_tile_rows - 1) / - leading_tile_rows; - } else { - stride *= memref_ty.getShape()[i]; - } - } + auto tile_strides = ComputeTileStrides(memref_ty, leading_tile_rows); return TiledLayoutAttr::get(memref_ty.getContext(), tiles, tile_strides); } return emitError(UnknownLoc::get(memref_ty.getContext()), @@ -149,7 +159,8 @@ LogicalResult checkTiles(MLIRContext *mlir_ctx, } FailureOr inferMemref(MemRefType memref, - const int hardware_generation) { + const int hardware_generation, + int64_t leading_tile_rows) { if (isa(memref.getElementType())) { const Attribute semaphore_mem = tpu::MemorySpaceAttr::get( memref.getContext(), MemorySpace::kSemaphoreMem); @@ -169,8 +180,9 @@ FailureOr inferMemref(MemRefType memref, tpu::MemorySpaceAttr::get(memref.getContext(), MemorySpace::vmem); const Attribute memory_space = memref.getMemorySpace() == nullptr ? 
vmem : memref.getMemorySpace(); - FAILUREOR_ASSIGN_OR_RETURN(const TiledLayoutAttr layout, - inferLayout(memref, hardware_generation)); + FAILUREOR_ASSIGN_OR_RETURN( + const TiledLayoutAttr layout, + inferLayout(memref, hardware_generation, leading_tile_rows)); const ArrayRef tiles = layout.getTiles(); if (failed(checkTiles(memref.getContext(), tiles))) { @@ -244,14 +256,24 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation) { Block &entry = f.getBody().front(); SmallVector new_arg_types; auto builder = OpBuilder::atBlockBegin(&entry); - for (BlockArgument arg : entry.getArguments()) { + for (int i = 0; i < entry.getNumArguments(); ++i) { + BlockArgument arg = entry.getArgument(i); const auto memref_ty = dyn_cast(arg.getType()); if (memref_ty == nullptr) { new_arg_types.push_back(arg.getType()); continue; } - FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty, - inferMemref(memref_ty, hardware_generation)); + int64_t leading_tile_rows = 0; + auto leading_tile_rows_attr = + f.getArgAttrOfType(i, kLeadingTileRows); + if (leading_tile_rows_attr != nullptr) { + leading_tile_rows = leading_tile_rows_attr.getInt(); + f.removeArgAttr(i, kLeadingTileRows); + } + + FAILUREOR_ASSIGN_OR_RETURN( + const MemRefType new_memref_ty, + inferMemref(memref_ty, hardware_generation, leading_tile_rows)); arg.setType(new_memref_ty); new_arg_types.push_back(arg.getType()); if (memref_ty != new_memref_ty) { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h index 724a09fffa19..2ad0afbb690d 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h @@ -1,12 +1,17 @@ #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ +#include + #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LogicalResult.h" namespace mlir::tpu { -FailureOr inferMemref(MemRefType memref, int hardware_generation); +FailureOr inferMemref(MemRefType memref, int hardware_generation, + int64_t leading_tile_rows = 0); + +const std::string_view kLeadingTileRows = "leading_tile_rows"; } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index b91e4832be72..fe02b4270a40 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" @@ -42,10 +43,12 @@ limitations under the License. 
#include "absl/log/check.h" #include "absl/log/log.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/include/mlir/IR/Attributes.h" #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/include/mlir/IR/OpDefinition.h" +#include "mlir/include/mlir/IR/Visitors.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "xla/layout.h" @@ -178,9 +181,8 @@ class VectorLayoutInferer { TPU_CHECK_OP(static_cast(in_ty) == static_cast(out_ty), "Input and output are not both vectors?"); if (in_ty) { - TPU_CHECK_OP(in_ty.getElementTypeBitWidth() == 1 && - out_ty.getElementTypeBitWidth() == 32, - "Only 1 bit -> 32 bit extensison supported"); + TPU_CHECK_OP(in_ty.getElementTypeBitWidth() == 1, + "Only extending i1 is supported"); } if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { return failure(); @@ -192,11 +194,7 @@ class VectorLayoutInferer { auto rhs_ty = dyn_cast(any_op.getOperand(1).getType()); TPU_CHECK_OP(static_cast(lhs_ty) == static_cast(rhs_ty), "Only one side of cmp is a vector?"); - if (lhs_ty) { - TPU_CHECK_OP(lhs_ty.getElementTypeBitWidth() == kNativeBitwidth && - rhs_ty.getElementTypeBitWidth() == kNativeBitwidth, - "Only 32-bit cmp supported"); - } + // TODO(tlongeri): Check that TPU generation supports comparison. if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { return failure(); } @@ -220,10 +218,18 @@ class VectorLayoutInferer { if (infer(op).failed()) { return failure(); } + } else if (auto op = dyn_cast(any_op)) { + if (infer(op).failed()) { + return failure(); + } } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); } + } else if (auto op = dyn_cast(any_op)) { + if (infer(op).failed()) { + return failure(); + } } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); @@ -232,11 +238,19 @@ class VectorLayoutInferer { if (infer(op).failed()) { return failure(); } - } else if (auto op = dyn_cast(any_op)) { + } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); } - } else if (auto op = dyn_cast(any_op)) { + } else if (auto op = dyn_cast(any_op)) { + if (infer(op).failed()) { + return failure(); + } + } else if (auto op = dyn_cast(any_op)) { + if (infer(op).failed()) { + return failure(); + } + } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); } @@ -264,6 +278,10 @@ class VectorLayoutInferer { if (infer(op).failed()) { return failure(); } + } else if (auto op = dyn_cast(any_op)) { + if (infer(op).failed()) { + return failure(); + } } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); @@ -411,19 +429,7 @@ class VectorLayoutInferer { auto then_yield = op.thenBlock()->getTerminator(); TPU_CHECK_OP(then_yield->getOperandTypes() == op->getResultTypes(), "scf if results and then branch yield operands do not match"); - SmallVector result_layout; - result_layout.reserve(then_yield->getNumOperands()); - for (const auto &operand : then_yield->getOperands()) { - if (operand.getType().isSignlessIntOrIndexOrFloat()) { - result_layout.push_back(kNoLayout); - } else if (isa(operand.getType())) { - result_layout.push_back(getLayout(operand)); - } else { - op.emitOpError("unsupported scf.yield type"); - return failure(); - } - } - + auto then_yield_in_layouts = getLayoutFromOperands(then_yield); if (auto else_block = 
op.elseBlock()) { if (inferBlock(*else_block, match_yield).failed()) { op.emitOpError("failed to infer layout for else branch"); @@ -438,32 +444,53 @@ class VectorLayoutInferer { auto else_yield = op.elseBlock()->getTerminator(); TPU_CHECK_OP(else_yield->getOperandTypes() == op->getResultTypes(), "scf if results and else branch yield operands do not match"); - - // Check each layout of the yield in else branch and override the - // result_layout if else branch's yield layout is less general. For example, - // if we yield offset (*, *) in then branch and offset (*, 0) in else - // branch, the result offset should be (*, 0). - for (int i = 0; i < else_yield->getNumOperands(); ++i) { - const auto &operand = else_yield->getOperand(i); - if (!isa(operand.getType())) { - continue; - } - auto shape = dyn_cast(operand.getType()).getShape(); - auto layout = getLayout(operand); - CHECK(result_layout[i].has_value() && layout.has_value()); - result_layout[i] = - VectorLayout::join(result_layout[i].value(), layout.value(), shape); - if (!result_layout[i].has_value()) { - op.emitOpError( - "failed to find a compatible layout in then and else branch for " - "output ") - << i; - return failure(); + auto else_yield_in_layouts = getLayoutFromOperands(else_yield); + // Find a compatible layout from then and else branches for each reuslt. For + // example, if we yield offset (*, *) in then branch and offset (*, 0) in + // else branch, the result offset should be (*, 0). + SmallVector out_layouts; + out_layouts.reserve(op->getNumResults()); + int out_idx = 0; + for (auto [then_layout, else_layout, result] : llvm::zip_equal( + then_yield_in_layouts, else_yield_in_layouts, op.getResults())) { + if (auto vty = dyn_cast(result.getType())) { + if (!then_layout.has_value()) { + return op.emitOpError( + "expected a vector layout for then yield input ") + << out_idx; + } + if (!else_layout.has_value()) { + return op.emitOpError( + "expected a vector layout for else yield input ") + << out_idx; + } + auto compatible_layout = VectorLayout::join( + then_layout.value(), else_layout.value(), vty.getShape()); + // If no compatible layout is found in layouts for then and else + // branches, the output layout falls back to a normalized layout which + // has offsets 0 and the native tiling. + if (!compatible_layout.has_value()) { + compatible_layout = VectorLayout( + then_layout->bitwidth(), {0, 0}, + nativeTiling(then_layout->bitwidth()), ImplicitDim::kNone); + } + out_layouts.push_back(compatible_layout); + } else { + if (then_layout.has_value()) { + return op.emitOpError("expected no layout for then yield input ") + << out_idx; + } + if (else_layout.has_value()) { + return op.emitOpError("expected no layout for else yield input ") + << out_idx; + } + out_layouts.push_back(kNoLayout); } + ++out_idx; } - setInLayout(then_yield, result_layout); - setInLayout(else_yield, result_layout); - setOutLayout(op, result_layout); + setInLayout(then_yield, out_layouts); + setInLayout(else_yield, out_layouts); + setOutLayout(op, out_layouts); return success(); } @@ -481,53 +508,215 @@ class VectorLayoutInferer { op->getNumOperands() == 3 + op.getNumResults(), "expected num_operands is equal to 3 + num_results in scf.for"); - SmallVector in_layouts; - in_layouts.reserve(op->getNumOperands()); - in_layouts.push_back(kNoLayout); // Lower bound. - in_layouts.push_back(kNoLayout); // Upper bound. - in_layouts.push_back(kNoLayout); // Step. 
- for (const auto &arg : op.getInitArgs()) { - if (arg.getType().isSignlessIntOrIndexOrFloat()) { - in_layouts.push_back(kNoLayout); - } else if (isa(arg.getType())) { - auto layout = getLayout(arg); - in_layouts.push_back(layout); + auto in_layouts = getLayoutFromOperands(op); + // Drop the input layouts for lower bound, upper bound. But keep the layout + // for step because it matches with induction variable in arguments. + auto arg_layouts = ArrayRef(in_layouts).drop_front(2); + if (assumeLayoutsForBlockArgs(*op.getBody(), arg_layouts).failed() || + inferBlock(*op.getBody(), match_yield).failed()) { + return op.emitOpError( + "failed to infer layout with initial layouts for body in " + "scf.for op"); + } + auto yield_op = op.getBody()->getTerminator(); + auto yield_in_layouts = getLayoutFromOperands(yield_op); + + SmallVector out_layouts; + out_layouts.reserve(op->getNumResults()); + int out_idx = 0; + bool require_reinfer = false; + for (auto [in_layout, yield_layout, result] : + llvm::zip_equal(arg_layouts.drop_front( + 1), // Drop the layout for induction variable. + yield_in_layouts, op.getResults())) { + if (auto vty = dyn_cast(result.getType())) { + if (!in_layout.has_value()) { + return op.emitOpError("expected a vector layout for input ") + << out_idx; + } + if (!yield_layout.has_value()) { + return op.emitOpError("expected a vector layout for yield input ") + << out_idx; + } + auto compatible_layout = VectorLayout::join( + in_layout.value(), yield_layout.value(), vty.getShape()); + // If no compatible layout is found in layouts for input and + // yield, the output layout falls back to a normalized layout which + // has offsets 0 and the native tiling. + if (!compatible_layout.has_value()) { + compatible_layout = VectorLayout(in_layout->bitwidth(), {0, 0}, + nativeTiling(in_layout->bitwidth()), + ImplicitDim::kNone); + } + if (!require_reinfer && + (compatible_layout.value() != in_layout.value() || + compatible_layout.value() != yield_layout.value())) { + require_reinfer = true; + } + out_layouts.push_back(compatible_layout); } else { - op.emitOpError() << "unsupported arg type " << arg.getType() - << " in scf::for"; - return failure(); + if (in_layout.has_value()) { + return op.emitOpError("expected no layout for input ") << out_idx; + } + if (yield_layout.has_value()) { + return op.emitOpError("expected no layout for yield input ") + << out_idx; + } + out_layouts.push_back(kNoLayout); } + ++out_idx; } - ArrayRef out_layouts = ArrayRef(in_layouts).drop_front(3); - // Use tpu.assume_layout to annotate every block argument with the layout of - // the corresponding operand in forOp and replace all uses of the block - // argument with the result of tpu.assume_layout. - ImplicitLocOpBuilder builder = - ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getBody()); - - // Drop the induction_variable and layouts of bounds+step (respectively). - for (auto [iter_arg, layout] : llvm::zip_equal( - op.getBody()->getArguments().drop_front(1), out_layouts)) { - if (!dyn_cast(iter_arg.getType())) { - continue; + if (require_reinfer) { + // Force same layouts in input layout but skip the first 3 layouts for + // lower bound, upper bound and step. + std::copy(out_layouts.begin(), out_layouts.end(), in_layouts.begin() + 3); + + // Terminator in the loop will carry layouts to the next loop but + // the loop's block args' layouts are determined by the initial inputs. We + // need to force the same layouts for all in order to make layouts be + // consistent across all branches. 
To ensure that, we need to reprocess + // layout inference for the entire body with the final consolidated + // layout. + clearBlockLayouts(*op.getBody()); + if (assumeLayoutsForBlockArgs(*op.getBody(), + ArrayRef(in_layouts).drop_front(2)) + .failed() || + inferBlock(*op.getBody(), match_yield).failed()) { + return op.emitOpError( + "failed to infer layout with compatible layouts for body in " + "scf.for op"); } - auto assume_layout_op = - builder.create(iter_arg.getType(), iter_arg); - setLayout(assume_layout_op, layout, layout); - iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) { - return operand.getOwner() != assume_layout_op; - }); } + setInLayout(yield_op, out_layouts); + setLayout(op, in_layouts, out_layouts); + return success(); + } - if (inferBlock(*op.getBody(), match_yield).failed()) { - return failure(); + LogicalResult infer(scf::WhileOp op) { + static LogicalResult (*match_condition)(Operation *) = [](Operation *op) { + TPU_CHECK_OP(isa(op), "expected condition terminator"); + return success(); + }; + static LogicalResult (*match_yield)(Operation *) = [](Operation *op) { + TPU_CHECK_OP(isa(op), "expected yield terminator"); + return success(); + }; + TPU_CHECK_OP(op.getNumRegions() == 2, "expected two blocks for scf.while"); + + SmallVector in_layouts = getLayoutFromOperands(op); + + if (assumeLayoutsForBlockArgs(*op.getBeforeBody(), in_layouts).failed() || + inferBlock(*op.getBeforeBody(), match_condition).failed()) { + return op.emitOpError( + "failed to infer layout with initial layouts for before body in " + "scf.while op"); } - auto yield_op = op.getBody()->getTerminator(); + + if (assumeLayoutsForBlockArgs(*op.getAfterBody(), in_layouts).failed() || + inferBlock(*op.getAfterBody(), match_yield).failed()) { + return op.emitOpError( + "failed to infer layout with initial layouts for after body in " + "scf.while op"); + } + + auto *cond_op = op.getBeforeBody()->getTerminator(); + auto cond_in_layouts = getLayoutFromOperands(cond_op); + auto *yield_op = op.getAfterBody()->getTerminator(); + auto yield_in_layouts = getLayoutFromOperands(yield_op); + + // Find a compatible layout from condition body and loop body for each + // reuslt. For example, if we yield offset (*, *) in condition body and + // offset (*, 0) in loop body, the result offset should be (*, 0). + SmallVector out_layouts; + out_layouts.reserve(op->getNumResults()); + int out_idx = 0; + bool require_reinfer = false; + for (auto [in_layout, cond_layout, yield_layout, result] : llvm::zip_equal( + in_layouts, ArrayRef(cond_in_layouts).drop_front(1), + yield_in_layouts, op.getResults())) { + if (auto vty = dyn_cast(result.getType())) { + if (!in_layout.has_value()) { + return op.emitOpError("expected a vector layout for whileOp input ") + << out_idx; + } + if (!cond_layout.has_value()) { + return op.emitOpError("expected a vector layout for condition input ") + << out_idx + 1; // ConditionOp's first input is 1 bit bool. + } + if (!yield_layout.has_value()) { + return op.emitOpError("expected a vector layout for yield input ") + << out_idx; + } + auto compatible_layout = VectorLayout::join( + cond_layout.value(), yield_layout.value(), vty.getShape()); + if (compatible_layout.has_value()) { + compatible_layout = VectorLayout::join( + in_layout.value(), compatible_layout.value(), vty.getShape()); + } + // If no compatible layout is found in layouts for input, condition and + // yield, the output layout falls back to a normalized layout which + // has offsets 0 and the native tiling. 
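The layout reconciliation used for scf.for above and scf.while here can be summarized by a small standalone Python sketch of how per-dimension offsets are joined. This is a deliberately simplified model of VectorLayout::join plus the documented fallback (tiling, bitwidth, and implicit dims are ignored), not the real implementation:

    def join_offset(a, b):
        # None plays the role of a replicated ("*") offset; a concrete offset wins over it.
        if a == b or b is None:
            return a
        if a is None:
            return b
        return "incompatible"

    def joined_offsets(in_offsets, cond_offsets, yield_offsets):
        joined = [join_offset(join_offset(c, y), i)
                  for i, c, y in zip(in_offsets, cond_offsets, yield_offsets)]
        if "incompatible" in joined:
            return [0, 0]        # fall back to the normalized layout: offsets (0, 0), native tiling
        return joined

    print(joined_offsets([None, 0], [None, None], [None, 0]))   # [None, 0]
    print(joined_offsets([0, 0], [0, 1], [0, 0]))               # [0, 0]  (fallback)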
+ if (!compatible_layout.has_value()) { + compatible_layout = VectorLayout(in_layout->bitwidth(), {0, 0}, + nativeTiling(in_layout->bitwidth()), + ImplicitDim::kNone); + } + if (!require_reinfer && + (compatible_layout.value() != in_layout.value() || + compatible_layout.value() != cond_layout.value() || + compatible_layout.value() != yield_layout.value())) { + require_reinfer = true; + } + out_layouts.push_back(compatible_layout); + } else { + if (in_layout.has_value()) { + return op.emitOpError("expected no layout for whileOp input ") + << out_idx; + } + if (cond_layout.has_value()) { + return op.emitOpError("expected no layout for condition input ") + << out_idx + 1; // ConditionOp's first input is 1 bit bool. + } + if (yield_layout.has_value()) { + return op.emitOpError("expected no layout for yield input ") + << out_idx; + } + out_layouts.push_back(kNoLayout); + } + ++out_idx; + } + if (require_reinfer) { + clearBlockLayouts(*op.getBeforeBody()); + clearBlockLayouts(*op.getAfterBody()); + // Terminator in the loop will carry layouts to the next loop but + // the loop's block args' layouts are determined by the initial inputs. We + // need to force the same layouts for all in order to make layouts be + // consistent across all branches. To ensure that, we need to reprocess + // layout inference for the entire body with the final consolidated + // layout. + if (assumeLayoutsForBlockArgs(*op.getBeforeBody(), out_layouts) + .failed() || + inferBlock(*op.getBeforeBody(), match_condition).failed()) { + return op.emitOpError( + "failed to infer layout with compatible layouts for before body in " + "scf.while op"); + } + if (assumeLayoutsForBlockArgs(*op.getAfterBody(), out_layouts).failed() || + inferBlock(*op.getAfterBody(), match_yield).failed()) { + return op.emitOpError( + "failed to infer layout with compatible layouts for after body in " + "scf.while op"); + } + } + std::copy(out_layouts.begin(), out_layouts.end(), + cond_in_layouts.begin() + 1); // Skip the first 1 bit bool. + setInLayout(cond_op, cond_in_layouts); setInLayout(yield_op, out_layouts); - setLayout(op, in_layouts, out_layouts); + setLayout(op, out_layouts, out_layouts); return success(); } + // TODO(b/347016737): deprecate the static rotate. LogicalResult infer(tpu::RotateOp op) { auto bitwidth = op.getType().getElementTypeBitWidth(); if (bitwidth != 32) { @@ -542,6 +731,21 @@ class VectorLayoutInferer { return success(); } + LogicalResult infer(tpu::DynamicRotateOp op) { + auto bitwidth = op.getType().getElementTypeBitWidth(); + // TODO(b/347067057): Support dynamic rotate with packed dtype. + if (bitwidth != 32) { + NYI("Rotate with non-32-bit data"); + } + if (op.getType().getRank() < 2) { + NYI("Unsupported 1D shape"); + } + auto layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), + ImplicitDim::kNone); + setLayout(op, {layout, kNoLayout}, layout); + return success(); + } + LogicalResult infer(tpu::ConcatenateOp op) { TPU_CHECK_OP(!op.getSources().empty(), "Need at least one vector to concatenate"); @@ -554,15 +758,20 @@ class VectorLayoutInferer { } auto res_ty = op.getResult().getType(); int8_t bitwidth = res_ty.getElementTypeBitWidth(); - if (bitwidth != 32) { - NYI("Support concatenation with non 32-bit data"); + auto layout = getLayout(op.getSources().front()); + // When concatenating vectors with replicated offsets, we want to reset the + // replicated offset to zero. Because we are not sure if the replicated + // value from each vector are same. 
+ layout = VectorLayout( + layout->bitwidth(), + {layout->offsets()[0].value_or(0), layout->offsets()[1].value_or(0)}, + layout->tiling(), layout->implicit_dim()); + if (dimension >= res_rank - 2) { + layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), + ImplicitDim::kNone); } - auto layout = (dimension >= res_rank - 2) - ? VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), - ImplicitDim::kNone) - : getLayout(op.getSources().front()); SmallVector in_layouts(op->getNumOperands(), layout); - setLayout(op, in_layouts, in_layouts.back()); + setLayout(op, in_layouts, layout); return success(); } @@ -581,6 +790,39 @@ class VectorLayoutInferer { return success(); } + LogicalResult infer(tpu::StridedLoadOp op) { + auto vty = op.getResult().getType(); + int8_t bitwidth = vty.getElementTypeBitWidth(); + if (bitwidth != 32) { + NYI("Strided load with non 32-bit data"); + } + if (vty.getRank() < 2) { + NYI("Strided load with 1D vector"); + } + SmallVector in_layout(op->getNumOperands(), kNoLayout); + setLayout(op, in_layout, + VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), + ImplicitDim::kNone)); + return success(); + } + + LogicalResult infer(tpu::StridedStoreOp op) { + auto vty = op.getValueToStore().getType(); + int8_t bitwidth = vty.getElementTypeBitWidth(); + if (bitwidth != 32) { + NYI("Strided store with non 32-bit data"); + } + if (vty.getRank() < 2) { + NYI("Strided store with 1D vector"); + } + auto store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), + ImplicitDim::kNone); + SmallVector in_layout{op->getNumOperands(), kNoLayout}; + in_layout[0] = store_layout; + setInLayout(op, in_layout); + return success(); + } + LogicalResult infer(tpu::MatmulOp op) { return inferMatmul(op); } LogicalResult infer(tpu::StoreOp op) { @@ -727,47 +969,42 @@ class VectorLayoutInferer { TPU_CHECK_OP(res_ty.getRank() >= 2, "result rank below 2D unsupported"); auto some_layout = getLayout(op.getSource()); TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); + auto &layout = *some_layout; // Since we can only do sublane broadcasts in the (8, 128) tiling, we // should always use that when sublane broadcasting is required. - if (src_ty.getDimSize(src_ty.getRank() - 2) != - res_ty.getDimSize(res_ty.getRank() - 2)) { - if (some_layout->bitwidth() != kNativeBitwidth) { + if (*(src_ty.getShape().end() - 2) != *(res_ty.getShape().end() - 2)) { + if (layout.bitwidth() != kNativeBitwidth) { NYI("Only 32-bit broadcasts supported"); } - LayoutOffsets offsets = some_layout->offsets(); + LayoutOffsets offsets = layout.offsets(); // At the moment relayout can only produce replicated sublanes when // converting to (8, 128) if the input was in (1, 128) tiling - if (some_layout->tiling()[0] == 1) { + if (layout.tiling()[0] == 1) { offsets[0] = std::nullopt; } - *some_layout = VectorLayout(some_layout->bitwidth(), offsets, - default_tiling_, some_layout->implicit_dim()); + layout = VectorLayout(layout.bitwidth(), offsets, default_tiling_, + layout.implicit_dim()); } - auto &layout = *some_layout; if (layout.implicit_dim() != ImplicitDim::kNone) { VectorLayout layout_2d(layout.bitwidth(), layout.offsets(), layout.tiling(), ImplicitDim::kNone); if (layout_2d.equivalentTo(layout, src_ty.getShape(), target_shape_)) { + // TODO(b/342237796): Stop preferring 2D layouts (if given the choice) + // and defer the work, if any, to relayout. 
layout = layout_2d; - } else { - op.emitOpError() << "Only 2D layouts supported"; - return failure(); } } auto src_tiled_shape = src_ty.getShape().take_back(2); auto dst_tiled_shape = res_ty.getShape().take_back(2); LayoutOffsets offsets = layout.offsets(); - if (layout.bitwidth() == kNativeBitwidth && - layout.tiling() == default_tiling_) { - for (int i = 0; i < 2; ++i) { - if (src_tiled_shape[i] != dst_tiled_shape[i]) { - offsets[i] = std::nullopt; - } + for (int i = 0; i < 2; ++i) { + if (src_tiled_shape[i] != dst_tiled_shape[i]) { + offsets[i] = std::nullopt; } } - setLayout(op, some_layout, + setLayout(op, layout, VectorLayout(layout.bitwidth(), offsets, layout.tiling(), - ImplicitDim::kNone)); + layout.implicit_dim())); return success(); } op.emitOpError("unsupported broadcast source type"); @@ -817,10 +1054,37 @@ class VectorLayoutInferer { "Only 32-bit types supported"); auto layout = getLayout(op.getVector()); TPU_CHECK_OP(layout.has_value(), "missing vector layout"); - setLayout(op, - VectorLayout(kNativeBitwidth, {0, 0}, layout->tiling(), - layout->implicit_dim()), - kNoLayout); + if (VectorType res_vty = dyn_cast(op.getResult().getType()); + res_vty != nullptr) { + if (res_vty.getRank() == 1 && + layout->implicit_dim() == ImplicitDim::kNone) { + const int64_t second_minor_idx = op.getStaticPosition().back(); + const LayoutOffset second_minor_offset = layout->offsets()[0]; + const LayoutOffset res_second_minor_offset = + second_minor_offset.has_value() + ? (*second_minor_offset + second_minor_idx) % + layout->vregSlice(target_shape_)[0] + : LayoutOffset(); + TPU_CHECK_OP(!res_second_minor_offset.has_value() || + *res_second_minor_offset < layout->tiling()[0], + "Not implemented: Slice does not start on the first tile " + "of a VReg"); + setLayout(op, layout, + VectorLayout(layout->bitwidth(), + {res_second_minor_offset, layout->offsets()[1]}, + layout->tiling(), ImplicitDim::kSecondMinor)); + } else { + TPU_CHECK_OP(layout->layout_rank() <= res_vty.getRank(), + "Internal error: Layout has too many dimensions for " + "vector type (invalid vector.extract?)") + setLayout(op, layout, layout); + } + } else { + setLayout(op, + VectorLayout(kNativeBitwidth, {0, 0}, layout->tiling(), + layout->implicit_dim()), + kNoLayout); + } return success(); } @@ -831,6 +1095,10 @@ class VectorLayoutInferer { "memref and vector rank mismatch"); int64_t rank = res_ty.getRank(); int8_t bitwidth = res_ty.getElementTypeBitWidth(); + if (kNativeBitwidth % bitwidth != 0) { + return op.emitOpError("Unsupported bitwidth"); + } + const int packing = kNativeBitwidth / bitwidth; auto maybe_tiling = verifyMemoryTiling(op, getMemRefLayout(op.getBase()).getTiles(), src_ty.getRank(), src_ty.getElementTypeBitWidth()); @@ -862,14 +1130,14 @@ class VectorLayoutInferer { } if (rank == 1) { TPU_CHECK_OP(tiling.size() == 1, "Expected 1D tiling in 1D loads"); + const int64_t lane_tiling = packing * target_shape_[1]; auto tile = tiling.front(); - TPU_CHECK_OP(tile % target_shape_[1] == 0, - "Unsupported tiling for 1D load"); + TPU_CHECK_OP(tile % lane_tiling == 0, "Unsupported tiling for 1D load"); CHECK_EQ(tile_offsets.size(), 1); // TODO(apaszke): We could generate replicated loads for short values. 
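The 1-D load case above derives its lane tiling from the packing of the element type rather than assuming 32-bit data, and the setLayout just below then takes the tile offset modulo that lane tiling. A tiny standalone Python sketch with assumed example values:

    lane_count = 128                            # target_shape_[1]
    bitwidth = 16                               # hypothetical packed element type
    packing = 32 // bitwidth                    # 2 elements per 32-bit subword
    lane_tiling = packing * lane_count          # 256 elements per 1-D tile
    tile_offset = 300                           # hypothetical offset within the tiled memref
    lane_offset = tile_offset % lane_tiling     # 44, the lane offset of the (1, lane_tiling) layout
    print(lane_tiling, lane_offset)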
setLayout(op, in_layout, - VectorLayout(bitwidth, {0, tile_offsets[0]}, {1, tile}, - ImplicitDim::kSecondMinor)); + VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling}, + {1, lane_tiling}, ImplicitDim::kSecondMinor)); } else { // rank >= 2 TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ loads"); CHECK_EQ(tile_offsets.size(), 2); @@ -919,22 +1187,35 @@ class VectorLayoutInferer { LogicalResult infer(vector::ExtractStridedSliceOp op) { auto input_layout = getLayout(op.getVector()); TPU_CHECK_OP(input_layout, "missing vector layout"); - TPU_CHECK_OP(input_layout->implicit_dim() == ImplicitDim::kNone, - "only 2D layouts supported"); - TPU_CHECK_OP(op.getType().getElementTypeBitWidth() == 32, - "Only 32-bit types supported"); - auto offsets = op.getOffsets().getValue(); - auto strides = op.getStrides().getValue(); - for (auto offset_attr : offsets.take_back(2)) { - int off = offset_attr.cast().getInt(); - TPU_CHECK_OP(off == 0, "Only zero-offset slices supported."); - } - for (auto stride : strides) { + auto offsets_attr = op.getOffsets().getValue(); + auto strides_attr = op.getStrides().getValue(); + auto offsets = llvm::map_to_vector(offsets_attr, [](auto attr) { + return cast(attr).getInt(); + }); + input_layout->insertImplicit(offsets, 0); + auto vreg_slice = input_layout->vregSlice(target_shape_); + LayoutOffsets new_layout_offsets; + if (input_layout->offsets()[0].has_value()) { + new_layout_offsets[0] = + (*(offsets.end() - 2) + *input_layout->offsets()[0]) % vreg_slice[0]; + } + if (input_layout->offsets()[1].has_value()) { + new_layout_offsets[1] = + (*(offsets.end() - 1) + *input_layout->offsets()[1]) % vreg_slice[1]; + } + TPU_CHECK_OP( + new_layout_offsets[0].value_or(0) < input_layout->tiling()[0] && + new_layout_offsets[1].value_or(0) < input_layout->tiling()[1], + "Not implemented: Resulting offsets are not in first tile within vreg"); + for (auto stride : strides_attr) { TPU_CHECK_OP(stride.cast().getInt() == 1, "Only trivial strides supported."); } - setLayout(op, input_layout, input_layout); + setLayout( + op, input_layout, + VectorLayout(input_layout->bitwidth(), new_layout_offsets, + input_layout->tiling(), input_layout->implicit_dim())); return success(); } @@ -1014,6 +1295,8 @@ class VectorLayoutInferer { auto some_src_layout = getLayout(op.getSource()); TPU_CHECK_OP(some_src_layout, "missing vector layout"); auto layout = *some_src_layout; + const unsigned bitwidth = src_ty.getElementTypeBitWidth(); + const std::array vreg_slice = layout.vregSlice(target_shape_); if (layout.implicit_dim() == ImplicitDim::kNone) { // Nothing changes in the last two dims. if (res_rank >= 2 && src_shape.take_back(2) == res_shape.take_back(2)) { @@ -1021,62 +1304,57 @@ class VectorLayoutInferer { return success(); } // Sublane (un)tiling. - if (res_rank >= 2 && layout.tiling()[1] == target_shape_[1] && - src_ty.getDimSize(src_ty.getRank() - 1) == - res_shape[res_shape.size() - 1] && - src_ty.getDimSize(src_ty.getRank() - 2) % layout.tiling()[0] == 0 && - res_shape[res_shape.size() - 2] % layout.tiling()[0] == 0) { - layout = VectorLayout(layout.bitwidth(), {0, 0}, layout.tiling(), - layout.implicit_dim()); + if (res_rank >= 2 && *(src_shape.end() - 1) == *(res_shape.end() - 1) && + *(src_shape.end() - 2) % vreg_slice[0] == 0 && + *(res_shape.end() - 2) % vreg_slice[0] == 0) { + // TODO(b/343808585): We shouldn't force second minor offset to 0 when + // unfolding, it's still a no-op, but we need to add + // support in apply-vector-layout. 
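The sublane (un)tiling check above now tests divisibility against the full vreg slice rather than just the tiling, so packed types are covered as well. A standalone Python sketch of the condition, assuming zero offsets and a 16-bit layout whose vreg slice spans 16 rows:

    def sublane_retile_is_noop(src_shape, res_shape, vreg_slice_rows):
        # A reshape that only regroups whole vreg slices along the second-minor
        # dimension moves no data between vregs.
        return (src_shape[-1] == res_shape[-1]
                and src_shape[-2] % vreg_slice_rows == 0
                and res_shape[-2] % vreg_slice_rows == 0)

    print(sublane_retile_is_noop((4, 16, 128), (64, 128), 16))   # True
    print(sublane_retile_is_noop((4, 12, 128), (48, 128), 16))   # False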
+ layout = VectorLayout(layout.bitwidth(), {0, layout.offsets()[1]}, + layout.tiling(), layout.implicit_dim()); setLayout(op, layout, layout); return success(); } + const auto native_tiling = nativeTiling(bitwidth); // Lane (un)tiling. - if (layout.tiling()[1] == target_shape_[1] && - src_ty.getDimSize(src_ty.getRank() - 1) != + if (src_ty.getDimSize(src_ty.getRank() - 1) != res_shape[res_shape.size() - 1] && src_ty.getDimSize(src_ty.getRank() - 1) % layout.tiling()[1] == 0 && res_shape[res_shape.size() - 1] % layout.tiling()[1] == 0) { - // TODO(jevinjiang): support shapecast along lane with any bitwidth. - if (src_ty.getElementTypeBitWidth() != kNativeBitwidth) { - NYI("Shapecast along lane dimension when bitwidth is not 32"); - } - - // When we shapecast from input shape (..., m * target_shape_[1]) to - // output shape (..., target_shape_[1]), the reshape becomes no-op when - // input is densely packed with tiling (1, target_shape_[1]) and - // output has the native tiling. + const int packing = kNativeBitwidth / bitwidth; + const auto elements_per_vreg = native_tiling[0] * native_tiling[1]; + // When we shapecast from input shape + // (..., m * target_shape_[1] * packing) to output shape + // (..., target_shape_[1]), the reshape becomes no-op when input is + // densely packed with tiling (1, target_shape_[1] * packing) and output + // has the native tiling. if (*(res_shape.end() - 1) == target_shape_[1] && - *(res_shape.end() - 2) % target_shape_[0] == 0 && - *(src_shape.end() - 1) % (target_shape_[0] * target_shape_[1]) == - 0 && - (*(src_shape.end() - 2) == 1 || - *(src_shape.end() - 2) % target_shape_[0] == 0)) { - // Inferring in_layout to have tiling (1, 128) triggers any + *(res_shape.end() - 2) % native_tiling[0] == 0 && + *(src_shape.end() - 1) % elements_per_vreg == 0) { + // Inferring in_layout to have tiling (1, 128 * packing) triggers any // necessary relayout before shapecast. - setLayout(op, - VectorLayout(layout.bitwidth(), {0, 0}, - {1, target_shape_[1]}, ImplicitDim::kNone), - VectorLayout(layout.bitwidth(), {0, 0}, default_tiling_, - ImplicitDim::kNone)); + setLayout( + op, + VectorLayout(layout.bitwidth(), {0, 0}, + {1, target_shape_[1] * packing}, ImplicitDim::kNone), + VectorLayout(layout.bitwidth(), {0, 0}, native_tiling, + ImplicitDim::kNone)); return success(); } - // When we shapecast from input shape (..., target_shape_[1]) to - // output shape (..., m * target_shape_[1]), the reshape becomes no-op - // when input has the native tiling and output is densely packed with - // tiling (1, target_shape_[1]). + // When we shapecast from input shape (..., target_shape_[1]) to output + // shape (..., m * target_shape_[1] * packing), the reshape becomes + // no-op when input has the native tiling and output is densely packed + // with tiling (1, target_shape_[1] * packing). 
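The lane (un)tiling cases above are likewise generalized to packed types by working in units of a full vreg (native_tiling[0] * 128 elements). A standalone Python sketch of the fold-to-128-lanes condition from the first of the two cases, with assumed sublane and lane counts:

    def lane_fold_is_noop(src_shape, res_shape, bitwidth, lane_count=128, sublane_count=8):
        # Folding (..., m * lane_count * packing) -> (..., lane_count) is a layout no-op
        # when the input is densely packed with tiling (1, lane_count * packing) and the
        # output uses the native tiling (offsets assumed to be (0, 0)).
        packing = 32 // bitwidth
        native_rows = sublane_count * packing            # native tiling rows for this bitwidth
        elements_per_vreg = native_rows * lane_count
        return (res_shape[-1] == lane_count
                and res_shape[-2] % native_rows == 0
                and src_shape[-1] % elements_per_vreg == 0)

    print(lane_fold_is_noop((2, 4096), (2, 32, 128), bitwidth=16))   # True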
if (*(src_shape.end() - 1) == target_shape_[1] && - *(src_shape.end() - 2) % target_shape_[0] == 0 && - *(res_shape.end() - 1) % (target_shape_[0] * target_shape_[1]) == - 0 && - (*(res_shape.end() - 2) == 1 || - *(res_shape.end() - 2) % target_shape_[0] == 0)) { + *(src_shape.end() - 2) % native_tiling[0] == 0 && + *(res_shape.end() - 1) % elements_per_vreg == 0) { setLayout(op, - VectorLayout(layout.bitwidth(), {0, 0}, default_tiling_, + VectorLayout(layout.bitwidth(), {0, 0}, native_tiling, ImplicitDim::kNone), VectorLayout(layout.bitwidth(), {0, 0}, - {1, target_shape_[1]}, ImplicitDim::kNone)); + {1, target_shape_[1] * packing}, + ImplicitDim::kNone)); return success(); } @@ -1084,8 +1362,6 @@ class VectorLayoutInferer { op.emitOpError("unsupported shape cast"); return failure(); } - unsigned bitwidth = src_ty.getElementTypeBitWidth(); - auto native_tiling = nativeTiling(bitwidth); if (layout.tiling() != native_tiling) { layout = VectorLayout(bitwidth, layout.offsets(), native_tiling, layout.implicit_dim()); @@ -1096,7 +1372,6 @@ class VectorLayoutInferer { if (res_ty.getRank() >= 2) { // Squeeze out the sublane dim. if (layout_shape[0] == 1 && - res_shape.drop_back(1) == src_shape.drop_back(2) && res_shape.back() == src_shape.back()) { setLayout(op, layout, VectorLayout(bitwidth, layout.offsets(), layout.tiling(), @@ -1114,28 +1389,28 @@ class VectorLayoutInferer { return success(); } } else if (res_ty.getRank() == 1) { - bool all_one = true; - for (int64_t s : src_ty.getShape().drop_back(2)) { - all_one &= s == 1; - } - // Squeeze out everything, but lanes - if (layout_shape[0] == 1 && all_one && - res_ty.getShape().back() == layout_shape[1]) { + // All dimensions have been folded into a single one + + // Squeeze all but minor dimension + if (res_ty.getShape().back() == layout_shape[1]) { + // The condition implies that everything apart from the minor + // dimension is 1 in the source. setLayout(op, layout, VectorLayout(bitwidth, layout.offsets(), layout.tiling(), ImplicitDim::kSecondMinor)); return success(); } - // Squeeze out everything, but sublanes - if (layout_shape[1] == 1 && all_one && - res_ty.getShape().back() == layout_shape[0]) { - TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == kNativeBitwidth, - "only 32-bit shape casts supported"); + // Squeeze all but second minor dimension + if (res_ty.getShape().back() == layout_shape[0]) { + // The condition implies that everything apart from the second minor + // dimension is 1 in the source setLayout(op, layout, VectorLayout(kNativeBitwidth, layout.offsets(), layout.tiling(), ImplicitDim::kMinor)); return success(); } + // TODO(b/340625465): Add case where layout_shape is (1, 1) and we fold + // batch dimensions once we support 0-D layouts. } } else { // Nothing changes in the last dim. @@ -1143,22 +1418,23 @@ class VectorLayoutInferer { setLayout(op, layout, layout); return success(); } - TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == kNativeBitwidth, - "only 32-bit shape casts supported"); // Insert a singleton innermost dim. 
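The rank-1 shape-cast cases above pick an implicit dimension depending on whether the surviving dimension matches the lane (minor) or sublane (second minor) extent of the source layout. A simplified stand-in, not the real types or values:

#include <cstdint>
#include <iostream>

enum class ImplicitDim { kNone, kMinor, kSecondMinor };

ImplicitDim PickImplicitDim(int64_t result_dim, int64_t sublane_extent,
                            int64_t lane_extent) {
  if (result_dim == lane_extent) {
    return ImplicitDim::kSecondMinor;  // everything but lanes was squeezed
  }
  if (result_dim == sublane_extent) {
    return ImplicitDim::kMinor;  // everything but sublanes was squeezed
  }
  return ImplicitDim::kNone;  // not handled by these cases
}

int main() {
  // Assume a 32-bit layout occupying an (8, 128) vreg.
  std::cout << static_cast<int>(PickImplicitDim(128, 8, 128)) << "\n";  // 2 (kSecondMinor)
  std::cout << static_cast<int>(PickImplicitDim(8, 8, 128)) << "\n";    // 1 (kMinor)
  return 0;
}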
if (res_ty.getRank() == src_ty.getRank() + 1 && src_ty.getDimSize(src_rank - 1) == res_ty.getDimSize(res_rank - 2) && res_ty.getDimSize(res_rank - 1) == 1) { if (layout.implicit_dim() == ImplicitDim::kMinor) { setLayout(op, layout, - VectorLayout(kNativeBitwidth, layout.offsets(), - default_tiling_, ImplicitDim::kNone)); + VectorLayout(bitwidth, layout.offsets(), layout.tiling(), + ImplicitDim::kNone)); } else { + TPU_CHECK_OP(bitwidth == kNativeBitwidth, + "Insertion of minor dim that is not a no-op only " + "supported for 32-bit types"); TPU_CHECK_OP(layout.implicit_dim() == ImplicitDim::kSecondMinor, "unexpected implicit dim value"); setLayout(op, layout, - VectorLayout(kNativeBitwidth, {0, std::nullopt}, - default_tiling_, ImplicitDim::kNone)); + VectorLayout(bitwidth, {0, std::nullopt}, default_tiling_, + ImplicitDim::kNone)); } return success(); } @@ -1174,6 +1450,10 @@ class VectorLayoutInferer { "memref and vector rank mismatch"); int64_t rank = ref_ty.getRank(); int8_t bitwidth = store_ty.getElementTypeBitWidth(); + if (kNativeBitwidth % bitwidth != 0) { + return op.emitOpError("Unsupported bitwidth"); + } + const int packing = kNativeBitwidth / bitwidth; auto maybe_tiling = verifyMemoryTiling(op, getMemRefLayout(op.getBase()).getTiles(), ref_ty.getRank(), ref_ty.getElementTypeBitWidth()); @@ -1204,12 +1484,14 @@ class VectorLayoutInferer { } if (rank == 1) { TPU_CHECK_OP(tiling.size() == 1, "Expected 1D tiling in 1D store"); + const int64_t lane_tiling = packing * target_shape_[1]; auto tile = tiling.front(); - TPU_CHECK_OP(tile % target_shape_[1] == 0, + TPU_CHECK_OP(tile % lane_tiling == 0, "Unsupported 1D tiling for 1D store"); CHECK_EQ(tile_offsets.size(), 1); - store_layout = VectorLayout(bitwidth, {0, tile_offsets[0]}, {1, tile}, - ImplicitDim::kSecondMinor); + store_layout = + VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling}, + {1, lane_tiling}, ImplicitDim::kSecondMinor); } else { // rank >= 2 // NOLINT(readability-else-after-return) TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ store"); CHECK_EQ(tile_offsets.size(), 2); @@ -1252,44 +1534,32 @@ class VectorLayoutInferer { LogicalResult infer(vector::TransposeOp op) { auto permutation = op.getPermutation(); + TPU_CHECK_OP(permutation.size() > 1, + "Vector and scalar transpose should be a no-op and removed"); + auto some_layout = getLayout(op.getVector()); TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); auto &layout = *some_layout; auto src_ty = op.getSourceVectorType(); TPU_CHECK_OP(permutation.size() == src_ty.getRank(), "Transpose permutation has incorrect rank"); - if (layout.implicit_dim() == ImplicitDim::kNone) { - TPU_CHECK_OP((layout.offsets() == LayoutOffsets{0, 0}), - "Padded transposes unsupported"); - auto xlu_width = target_shape_[1]; - for (int64_t s : src_ty.getShape().take_back(2)) { - TPU_CHECK_OP(s % xlu_width == 0, "Padded transposes unsupported"); - } - for (auto dim : permutation.drop_back(2)) { - TPU_CHECK_OP( - dim < src_ty.getRank() - 2, - "Unsupported transpose permutation - minor dims into major"); - } - for (auto dim : permutation.take_back(2)) { - TPU_CHECK_OP( - dim >= src_ty.getRank() - 2, - "Unsupported transpose permutation - major dims into minor"); - } - Layout required_layout = some_layout; - if (permutation.size() < 2) { - return failure(); - } - // Require native tiling if we're going to use the XLU. 
- if (permutation[permutation.size() - 1] == permutation.size() - 2) { - auto native_tiling = nativeTiling(layout.bitwidth()); - required_layout = VectorLayout(layout.bitwidth(), layout.offsets(), - native_tiling, ImplicitDim::kNone); - } - setLayout(op, required_layout, required_layout); - return success(); + for (auto dim : permutation.drop_back(2)) { + TPU_CHECK_OP(dim < src_ty.getRank() - 2, + "Unsupported transpose permutation - minor dims into major"); } - op.emitOpError("Unsupported transpose"); - return failure(); + for (auto dim : permutation.take_back(2)) { + TPU_CHECK_OP(dim >= src_ty.getRank() - 2, + "Unsupported transpose permutation - major dims into minor"); + } + Layout required_layout = some_layout; + // Require native tiling if we're going to use the XLU. + if (permutation[permutation.size() - 1] == permutation.size() - 2) { + auto native_tiling = nativeTiling(layout.bitwidth()); + required_layout = VectorLayout(layout.bitwidth(), LayoutOffsets{0, 0}, + native_tiling, ImplicitDim::kNone); + } + setLayout(op, required_layout, required_layout); + return success(); } LogicalResult inferExt(Operation *op) { @@ -1312,34 +1582,32 @@ class VectorLayoutInferer { "Only extensions to 32-bit supported"); } auto &layout = *some_layout; - if (layout.implicit_dim() == ImplicitDim::kNone) { - // TODO(apaszke): Support native packed layouts here. - Layout src_layout; - Layout dst_layout; - // All layouts that subdivide the rows of the default tiling evenly - // can be handled uniformly with the default case, by preserving the - // tiling through the op. - if (default_tiling_[0] % layout.tiling()[0] == 0 && - default_tiling_[1] == layout.tiling()[1]) { - src_layout = layout; - } else { - src_layout = VectorLayout(layout.bitwidth(), layout.offsets(), - default_tiling_, ImplicitDim::kNone); - } + // TODO(apaszke): Support native packed layouts here. + Layout src_layout; + Layout dst_layout; + // All layouts that subdivide the rows of the default tiling evenly + // can be handled uniformly with the default case, by preserving the + // tiling through the op. + if (default_tiling_[0] % layout.tiling()[0] == 0 && + default_tiling_[1] == layout.tiling()[1]) { + src_layout = layout; dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(), - ImplicitDim::kNone); - setLayout(op, src_layout, dst_layout); - return success(); - } - if (layout.implicit_dim() == ImplicitDim::kSecondMinor) { - TPU_CHECK_OP(layout.tiling() == nativeTiling(16), "unsupported tiling"); - auto dst_layout = VectorLayout(32, layout.offsets(), default_tiling_, - layout.implicit_dim()); - setLayout(op, some_layout, dst_layout); - return success(); + layout.implicit_dim()); + } else if (layout.tiling() == + nativeTiling(src_ty.getElementTypeBitWidth())) { + // If the source is already in native tiling, we can unpack it directly. + src_layout = layout; + dst_layout = VectorLayout(32, layout.offsets(), default_tiling_, + layout.implicit_dim()); + } else { + // TODO(b/335863273): we should also reduce offsets. 
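The rewritten transpose rule above keeps two checks on the permutation and requires native tiling with zero offsets when the two minor dims are swapped (the XLU path). A standalone sketch of those predicates, with illustrative helper names:

#include <cstdint>
#include <iostream>
#include <vector>

bool IsSupportedTranspose(const std::vector<int64_t>& perm) {
  const int64_t rank = static_cast<int64_t>(perm.size());
  if (rank < 2) return false;  // scalar/vector transposes should already be no-ops
  for (int64_t i = 0; i < rank - 2; ++i) {
    if (perm[i] >= rank - 2) return false;  // minor dim moved into majors
  }
  for (int64_t i = rank - 2; i < rank; ++i) {
    if (perm[i] < rank - 2) return false;  // major dim moved into minors
  }
  return true;
}

bool NeedsXlu(const std::vector<int64_t>& perm) {
  // The last two dims are swapped, so the transpose unit is used and the
  // native tiling with zero offsets is required.
  return perm.back() == static_cast<int64_t>(perm.size()) - 2;
}

int main() {
  std::cout << IsSupportedTranspose({0, 2, 1}) << NeedsXlu({0, 2, 1}) << "\n";  // 11
  std::cout << IsSupportedTranspose({2, 1, 0}) << "\n";                         // 0
  return 0;
}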
+ src_layout = VectorLayout(layout.bitwidth(), layout.offsets(), + default_tiling_, layout.implicit_dim()); + dst_layout = VectorLayout(32, layout.offsets(), default_tiling_, + layout.implicit_dim()); } - op->emitOpError("unsupported extension layout"); - return failure(); + setLayout(op, src_layout, dst_layout); + return success(); } LogicalResult inferTrunc(Operation *op) { @@ -1355,27 +1623,24 @@ class VectorLayoutInferer { TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); if (dyn_cast(op)) { TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32 && - dst_ty.getElementTypeBitWidth() == 16, - "Only 32-bit to 16-bit truncation supported"); + (dst_ty.getElementTypeBitWidth() == 16 || + dst_ty.getElementTypeBitWidth() == 8), + "Only 32-bit to 8-bit or 16-bit truncation supported"); } else { TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32, "Only 32-bit truncation supported"); } auto &layout = *some_layout; - if (layout.implicit_dim() == ImplicitDim::kNone) { - bool select_native = allUsersRequireNativeTiling(op->getResult(0)); - auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_, - ImplicitDim::kNone); - auto dst_layout = VectorLayout( - dst_ty.getElementTypeBitWidth(), layout.offsets(), - select_native ? nativeTiling(dst_ty.getElementTypeBitWidth()) - : default_tiling_, - ImplicitDim::kNone); - setLayout(op, src_layout, dst_layout); - return success(); - } - op->emitOpError("unsupported truncation layout"); - return failure(); + bool select_native = allUsersRequireNativeTiling(op->getResult(0)); + auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_, + layout.implicit_dim()); + auto dst_layout = VectorLayout( + dst_ty.getElementTypeBitWidth(), layout.offsets(), + select_native ? nativeTiling(dst_ty.getElementTypeBitWidth()) + : default_tiling_, + layout.implicit_dim()); + setLayout(op, src_layout, dst_layout); + return success(); } LogicalResult inferElementwise(Operation *op, bool check_bitwidth = true) { @@ -1463,25 +1728,31 @@ class VectorLayoutInferer { } LogicalResult inferMatmul(Operation *op) { - auto get_unpadded_layout = - [&](Value v, std::optional major_multiple = std::nullopt, + auto get_operand_layout = + [&](Value v, llvm::StringRef operand_name, + std::optional major_multiple = std::nullopt, std::optional minor_multiple = std::nullopt) -> std::optional { - auto pad = getLayout(v); - if (!pad.has_value() || pad->implicit_dim() != ImplicitDim::kNone) { + auto layout = getLayout(v); + if (!layout.has_value()) { + op->emitOpError("Internal error: assert failed: Operand ") + << operand_name << " has no vector layout"; return std::nullopt; } auto vty = cast(v.getType()); auto tiling = nativeTiling(vty.getElementTypeBitWidth()); auto shape = vty.getShape().take_back(2); - if (pad->offsets()[0].value_or(0) != 0 || - pad->offsets()[1].value_or(0) != 0 || - shape[0] % major_multiple.value_or(tiling[0]) != 0 || + if (shape[0] % major_multiple.value_or(tiling[0]) != 0 || shape[1] % minor_multiple.value_or(tiling[1]) != 0) { + op->emitOpError("Matmul operand") + << operand_name << " must have a shape divisible by (" + << major_multiple.value_or(tiling[0]) << ", " + << minor_multiple.value_or(tiling[1]) << "), but got: (" << shape[0] + << ", " << shape[1] << ")"; return std::nullopt; } // Override tiling to match the native one. 
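The truncation rule above gives the narrower result the native tiling for its bitwidth only when every user of the result requires native tiling, and the default 32-bit tiling otherwise. A standalone stand-in with assumed (8, 128) defaults:

#include <array>
#include <cstdint>
#include <iostream>

std::array<int64_t, 2> NativeTiling(int bitwidth) {
  const int packing = 32 / bitwidth;
  return {8 * packing, 128};
}

std::array<int64_t, 2> TruncResultTiling(int dst_bitwidth,
                                         bool all_users_need_native) {
  return all_users_need_native ? NativeTiling(dst_bitwidth)
                               : std::array<int64_t, 2>{8, 128};
}

int main() {
  auto t = TruncResultTiling(/*dst_bitwidth=*/16, /*all_users_need_native=*/true);
  std::cout << t[0] << "x" << t[1] << "\n";  // 16x128
  t = TruncResultTiling(/*dst_bitwidth=*/8, /*all_users_need_native=*/false);
  std::cout << t[0] << "x" << t[1] << "\n";  // 8x128
  return 0;
}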
- return VectorLayout(pad->bitwidth(), pad->offsets(), tiling, + return VectorLayout(layout->bitwidth(), {0, 0}, tiling, ImplicitDim::kNone); }; auto res_ty = dyn_cast(op->getResult(0).getType()); @@ -1503,21 +1774,35 @@ class VectorLayoutInferer { rhs_major_multiple = 1; } in_layout[0] = - get_unpadded_layout(op->getOperand(0), lhs_major_multiple, 1); + get_operand_layout(op->getOperand(0), "lhs", lhs_major_multiple, 1); + if (!in_layout[0].has_value()) { + return failure(); + } in_layout[1] = - get_unpadded_layout(op->getOperand(1), rhs_major_multiple, 1); - in_layout[2] = get_unpadded_layout(op->getOperand(2), 1, 1); - for (Layout &l : in_layout) { - if (!l.has_value()) { - op->emitOpError("unsupported operand shapes or layouts"); - return failure(); - } + get_operand_layout(op->getOperand(1), "rhs", rhs_major_multiple, 1); + if (!in_layout[1].has_value()) { + return failure(); + } + in_layout[2] = get_operand_layout(op->getOperand(2), "result", 1, 1); + if (!in_layout[2].has_value()) { + return failure(); } setLayout(op, in_layout, VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_, ImplicitDim::kNone)); return success(); } + LogicalResult infer(tpu::PRNGRandomBitsOp op) { + auto res_ty = dyn_cast(op->getResult(0).getType()); + TPU_CHECK_OP(res_ty.getElementTypeBitWidth() == kNativeBitwidth, + "only 32-bit random bit generation supported"); + // TODO: b/342054464 - Support implicit dims for PRNGRandomBitsOp. + LayoutOffsets offsets = {0, 0}; + setOutLayout(op, VectorLayout( + kNativeBitwidth, offsets, nativeTiling(kNativeBitwidth), + ImplicitDim::kNone)); + return success(); + } bool allUsersRequireNativeTiling(Value x) { for (OpOperand &operand : x.getUses()) { @@ -1539,6 +1824,53 @@ class VectorLayoutInferer { return true; } + LogicalResult assumeLayoutsForBlockArgs(Block &block, + ArrayRef layouts) { + auto op = block.getParentOp(); + if (layouts.size() != block.getNumArguments()) { + return op->emitOpError( + "Block arguments must have the same number of layouts"); + } + // Use tpu.assume_layout to annotate every block argument with the layout of + // the corresponding operand and replace all uses of the block argument with + // the result of tpu.assume_layout. + ImplicitLocOpBuilder builder = + ImplicitLocOpBuilder::atBlockBegin(op->getLoc(), &block); + for (auto [iter_arg, layout] : + llvm::zip_equal(block.getArguments(), layouts)) { + if (!dyn_cast(iter_arg.getType())) { + continue; + } + if (llvm::any_of(iter_arg.getUsers(), [](Operation *user) { + return isa(user); + })) { + return op->emitOpError("Expected no assume layout for block arguments"); + } + auto assume_layout_op = + builder.create(iter_arg.getType(), iter_arg); + setLayout(assume_layout_op, layout, layout); + iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) { + return operand.getOwner() != assume_layout_op; + }); + } + return success(); + } + + void clearBlockLayouts(Block &block) { + block.walk([&](Operation *op) { + // We need to remove assume_layout ops in each block. Otherwise, we will + // create extra assume_layout ops for nested blocks. 
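The matmul operand check above now emits a descriptive error instead of a generic "unsupported operand shapes" message. A simplified standalone version of the divisibility check and its message, not the Mosaic helper itself:

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

std::optional<std::string> CheckMatmulOperandShape(
    const std::string& operand_name, int64_t major, int64_t minor,
    int64_t major_multiple, int64_t minor_multiple) {
  if (major % major_multiple != 0 || minor % minor_multiple != 0) {
    return "Matmul operand " + operand_name +
           " must have a shape divisible by (" +
           std::to_string(major_multiple) + ", " +
           std::to_string(minor_multiple) + "), but got: (" +
           std::to_string(major) + ", " + std::to_string(minor) + ")";
  }
  return std::nullopt;  // shape is acceptable
}

int main() {
  // Assume 32-bit operands with a native (8, 128) tiling requirement.
  if (auto err = CheckMatmulOperandShape("lhs", 130, 256, 8, 128)) {
    std::cout << *err << "\n";
  }
  return 0;
}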
+ if (auto assume_op = dyn_cast(op)) { + assume_op.getResult().replaceAllUsesWith(assume_op.getInput()); + assume_op->erase(); + return WalkResult::advance(); + } + op->removeAttr("in_layout"); + op->removeAttr("out_layout"); + return WalkResult::advance(); + }); + } + void setInLayout(Operation *op, ArrayRef in) { CHECK_EQ(in.size(), op->getNumOperands()) << Print(op); SmallVector in_attrs; @@ -1616,10 +1948,24 @@ class VectorLayoutInferer { return cast(out_attrs[result_index]).getLayout(); } + SmallVector getLayoutFromOperands(Operation *op) { + SmallVector layouts; + layouts.reserve(op->getNumOperands()); + for (const auto &operand : op->getOperands()) { + if (isa(operand.getType())) { + layouts.push_back(getLayout(operand)); + } else { + layouts.push_back(kNoLayout); + } + } + return layouts; + } + private: std::optional> verifyMemoryTiling( Operation *op, ArrayRef mem_tiling, int64_t rank, int8_t bitwidth) { + const int packing = kNativeBitwidth / bitwidth; if (bitwidth == 32) { if (mem_tiling.size() != 1) { op->emitOpError("Only one-level tiling supported for 32-bit loads"); @@ -1636,7 +1982,7 @@ class VectorLayoutInferer { } auto first = mem_tiling[0].dimensions(); auto second = mem_tiling[1].dimensions(); - if (first.size() != 1 || first[0] % target_shape_[1] != 0) { + if (first.size() != 1 || first[0] % (packing * target_shape_[1]) != 0) { op->emitOpError("Invalid first-level tile in 1D memory op"); return std::nullopt; } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc index 00996d8862aa..4d5e62049098 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc @@ -16,14 +16,31 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/include/mlir/Dialect/Math/IR/Math.h" +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/include/mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/include/mlir/IR/AffineMap.h" +#include "mlir/include/mlir/IR/Attributes.h" +#include "mlir/include/mlir/IR/BuiltinAttributes.h" +#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/include/mlir/IR/BuiltinTypes.h" #include "mlir/include/mlir/IR/DialectRegistry.h" +#include "mlir/include/mlir/IR/Matchers.h" #include "mlir/include/mlir/IR/Operation.h" +#include "mlir/include/mlir/IR/OperationSupport.h" #include "mlir/include/mlir/IR/PatternMatch.h" +#include "mlir/include/mlir/IR/Types.h" +#include "mlir/include/mlir/IR/Value.h" #include "mlir/include/mlir/Pass/Pass.h" #include "mlir/include/mlir/Support/LLVM.h" #include "mlir/include/mlir/Support/LogicalResult.h" @@ -49,9 +66,403 @@ struct VectorizationPattern } }; +// Check preconditions for `vector.transfer_read` rewrite patterns. 
+LogicalResult checkPreconditions(vector::TransferReadOp op, + PatternRewriter &rewriter) { + if (op.hasOutOfBoundsDim()) { + return rewriter.notifyMatchFailure(op, "out of bounds transfer dim"); + } + if (op.getMask()) { + return rewriter.notifyMatchFailure(op, "masked transfer"); + } + if (!op.getPermutationMap().isIdentity()) { + return rewriter.notifyMatchFailure(op, "non identity permutation map"); + } + SmallVector indices = {op.getIndices().begin(), op.getIndices().end()}; + if (absl::c_any_of( + indices, [](Value index) { return !isConstantIntValue(index, 0); })) { + return rewriter.notifyMatchFailure(op, "non zero indices"); + } + return success(); +} + +// Create a `vector.transfer_read` based on the original |op|, which succeeds +// the checkPreconditions() call. +vector::TransferReadOp createTransferReadOp(vector::TransferReadOp op, + Value source, + RankedTensorType source_ty, + PatternRewriter &rewriter) { + // We know from preconditions that there are no out of bound dims. + SmallVector in_bounds(source_ty.getRank(), true); + return rewriter.create( + op.getLoc(), + VectorType::get(source_ty.getShape(), source_ty.getElementType()), source, + SmallVector( + source_ty.getRank(), + rewriter.create(op.getLoc(), 0)), + AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(source_ty.getRank(), + op->getContext())), + rewriter.getBoolArrayAttr(in_bounds)); +} + +template +LogicalResult matchAndRewriteTransferOfExpandOrCollapseShape( + vector::TransferReadOp op, PatternRewriter &rewriter) { + if (failed(checkPreconditions(op, rewriter))) { + return failure(); + } + auto expand = op.getSource().template getDefiningOp(); + if (!expand) { + return rewriter.notifyMatchFailure( + op, "not a tensor.expand_shape/collapse_shape"); + } + if (auto result_type = dyn_cast(op.getType()); + !result_type || + result_type.getShape() != expand.getResultType().getShape()) { + return rewriter.notifyMatchFailure(op, "output type mismatch"); + } + auto expand_src_type = expand.getSrcType(); + // We know from preconditions that there are no out of bound dims. + SmallVector in_bounds(expand_src_type.getRank(), true); + rewriter.replaceOpWithNewOp( + op, op.getType(), + createTransferReadOp(op, expand.getSrc(), expand_src_type, rewriter)); + return success(); +} + +// Rewrite `vector.transfer_read(tensor.expand_shape)` as +// `vector.shape_cast(vector.transfer_read)`. +struct TransferReadOfExpandShape + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + return matchAndRewriteTransferOfExpandOrCollapseShape< + tensor::ExpandShapeOp>(op, rewriter); + } +}; + +// Rewrite `vector.transfer_read(tensor.collapse_shape)` as +// `vector.shape_cast(vector.transfer_read)`. +struct TransferReadOfCollapseShape + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + return matchAndRewriteTransferOfExpandOrCollapseShape< + tensor::CollapseShapeOp>(op, rewriter); + } +}; + +// Rewrite a `vector.transfer_read` of a dense tensor constant as a dense +// vector constant. 
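The transfer_read folding patterns above all share the same preconditions: no out-of-bounds dims, no mask, an identity permutation map, and all-zero indices. A standalone summary of that predicate over a simplified stand-in for the op:

#include <iostream>
#include <vector>

struct TransferReadSummary {
  bool has_out_of_bounds_dim;
  bool has_mask;
  bool identity_permutation;
  std::vector<int> constant_indices;
};

bool CanFoldTransferRead(const TransferReadSummary& op) {
  if (op.has_out_of_bounds_dim || op.has_mask || !op.identity_permutation) {
    return false;
  }
  for (int index : op.constant_indices) {
    if (index != 0) return false;  // only zero offsets are handled
  }
  return true;
}

int main() {
  std::cout << CanFoldTransferRead({false, false, true, {0, 0}}) << "\n";  // 1
  std::cout << CanFoldTransferRead({false, false, true, {0, 4}}) << "\n";  // 0
  return 0;
}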
+struct TransferReadOfConstant + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + DenseElementsAttr constant_elements; + Attribute constant_value; + if (matchPattern(op.getSource(), m_Constant(&constant_elements)) && + constant_elements.isSplat()) { + constant_value = constant_elements.getSplatValue(); + } else { + return rewriter.notifyMatchFailure(op, "not an arith.constant"); + } + rewriter.replaceOpWithNewOp( + op, op.getVectorType(), + DenseElementsAttr::get(op.getVectorType(), constant_value)); + return success(); + } +}; + +// Rewrite `vector.transfer_read(arith.select)` as `arith.select` with +// `transfer_read` applied to its operands. +struct TransferReadOfSelect : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(op, rewriter))) { + return failure(); + } + auto select = op.getSource().getDefiningOp(); + if (!select) { + return rewriter.notifyMatchFailure(op, "source not an arith.select"); + } + auto true_value_ty = + dyn_cast(select.getTrueValue().getType()); + if (!true_value_ty) { + return rewriter.notifyMatchFailure( + op, "true value is not a ranked tensor type"); + } + // We do not check the type of the false_value since the verifier enforces + // that types of true_value, false_value, and result match. + auto false_value_ty = + dyn_cast(select.getFalseValue().getType()); + auto condition_type = + dyn_cast(select.getCondition().getType()); + if (!condition_type) { + return rewriter.notifyMatchFailure( + op, "condition is not a ranked tensor type"); + } + auto transfer_read = [&](Value value, RankedTensorType type) { + return createTransferReadOp(op, value, type, rewriter); + }; + rewriter.replaceOpWithNewOp( + op, transfer_read(select.getCondition(), condition_type), + transfer_read(select.getTrueValue(), true_value_ty), + transfer_read(select.getFalseValue(), false_value_ty)); + return success(); + } +}; + +// Rewrite `vector.transfer_read(arith.cmpi)` as `arith.cmpi` with +// `transfer_read` applied to its operands. +struct TransferReadOfCmpI : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(op, rewriter))) { + return failure(); + } + auto cmp = op.getSource().getDefiningOp(); + if (!cmp) { + return rewriter.notifyMatchFailure(op, "source not an arith.cmpi"); + } + auto lhs_type = dyn_cast(cmp.getLhs().getType()); + if (!lhs_type) { + return rewriter.notifyMatchFailure(op, "lhs is not a ranked tensor type"); + } + auto rhs_type = dyn_cast(cmp.getRhs().getType()); + if (!rhs_type) { + return rewriter.notifyMatchFailure(op, "rhs is not a ranked tensor type"); + } + auto transfer_read = [&](Value value, RankedTensorType type) { + return createTransferReadOp(op, value, type, rewriter); + }; + rewriter.replaceOpWithNewOp( + op, cmp.getPredicate(), transfer_read(cmp.getLhs(), lhs_type), + transfer_read(cmp.getRhs(), rhs_type)); + return success(); + } +}; + +// Rewrite `vector.transfer_read(tensor.splat)` as `vector.broadcast`. 
+struct TransferReadOfSplat : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(op, rewriter))) { + return failure(); + } + auto splat = op.getSource().getDefiningOp(); + if (!splat) { + return rewriter.notifyMatchFailure(op, "source not a tensor.splat"); + } + if (!splat.getType().hasStaticShape()) { + return rewriter.notifyMatchFailure(op, "not statically shaped"); + } + rewriter.replaceOpWithNewOp(op, op.getVectorType(), + splat.getInput()); + return success(); + } +}; + +// List of operations that are covered by the supports_bf16_alu_instructions. +const auto kSupportedBf16Ops = absl::flat_hash_set( + {arith::AddFOp::getOperationName(), arith::SubFOp::getOperationName(), + arith::MulFOp::getOperationName(), arith::MaximumFOp::getOperationName(), + arith::MinimumFOp::getOperationName()}); + +// Rewrite operation with bf16 inputs/outputs into an operation with f32 +// inputs/outputs, where the inputs are extended and the outputs truncated. +// Non-bf16 operands remain unchanged. +// TODO(b/324596736): Extend the functionality to int8 and int16. +class GenericBitwidthConvert : public RewritePattern { + public: + explicit GenericBitwidthConvert(llvm::StringRef operation_name, + MLIRContext *ctx, + bool supports_bf16_alu_instructions) + : RewritePattern(operation_name, 0, ctx), + supports_bf16_alu_instructions_(supports_bf16_alu_instructions) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (supports_bf16_alu_instructions_ && + kSupportedBf16Ops.contains(op->getName().getStringRef())) { + return rewriter.notifyMatchFailure(op, "target supports bf16 operands"); + } + llvm::SmallVector extended_operands; + extended_operands.reserve(op->getOperands().size()); + Location loc = op->getLoc(); + bool has_bf16_operand = false; + for (Value operand : op->getOperands()) { + auto operand_type = dyn_cast(operand.getType()); + if (!operand_type) { + return rewriter.notifyMatchFailure(op, "operand not a vector"); + } + if (!operand_type.getElementType().isBF16()) { + // Add the operand as is and continue, since not all operands must be + // bf16, for example in the case of a select op. + extended_operands.push_back(operand); + continue; + } + has_bf16_operand = true; + extended_operands.push_back(rewriter.create( + loc, VectorType::get(operand_type.getShape(), rewriter.getF32Type()), + operand)); + } + // If there are no bf16 operands, then we do not need to rewrite the op. + if (!has_bf16_operand) { + return rewriter.notifyMatchFailure(op, "no bf16 operands"); + } + llvm::SmallVector new_results; + new_results.reserve(op->getResultTypes().size()); + for (Type result_ty : op->getResultTypes()) { + auto result_type = dyn_cast(result_ty); + if (!result_type) { + return rewriter.notifyMatchFailure(op, "result is not a vector"); + } + if (!result_type.getElementType().isBF16()) { + return rewriter.notifyMatchFailure(op, + "result element type is not bf16"); + } + new_results.push_back( + VectorType::get(result_type.getShape(), rewriter.getF32Type())); + } + OperationState state(loc, op->getName().getStringRef(), extended_operands, + new_results, op->getAttrs(), op->getSuccessors()); + Operation *new_op = rewriter.create(state); + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + new_op->getResults()); + return success(); + } + + private: + // Whether the target supports bf16 ALU instructions. 
+ const bool supports_bf16_alu_instructions_; +}; + +// Rewrite `vector.contraction` with bf16 accumulator and output into a +// contraction with f32 accumulator and output, where the accumulator is +// extended and the output truncated. For targets that do not support bf16 +// matmul, the lhs and rhs are extended to f32. +struct ContractionBitwidthConvert + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ContractionBitwidthConvert(bool supports_bf16_matmul, MLIRContext *ctx) + : OpRewritePattern(ctx), supports_bf16_matmul_(supports_bf16_matmul) {} + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + // The ContractionOp contract is that (1) lhs and rhs have same element + // type, and (2) the accumulator and result have the same element type. + + // If the target does not support bf16 matmul and we have bf16 operands, we + // need to extend the lhs and rhs to f32. + const bool extend_operands = + op.getLhsType().getElementType().isBF16() && !supports_bf16_matmul_; + // Determine if the accumulator is bf16 and hence needs to be extended to + // f32. + ShapedType acc_ty = dyn_cast(op.getAccType()); + if (acc_ty == nullptr) { + return rewriter.notifyMatchFailure(op, + "accumulator is not a shaped type"); + } + const bool extend_acc = acc_ty.getElementType().isBF16(); + + if (!extend_operands && !extend_acc) { + return rewriter.notifyMatchFailure(op, "no bf16 operands or accumulator"); + } + + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + if (extend_operands) { + lhs = rewriter.create( + op.getLoc(), + VectorType::get(op.getLhsType().getShape(), rewriter.getF32Type()), + lhs); + rhs = rewriter.create( + op.getLoc(), + VectorType::get(op.getRhsType().getShape(), rewriter.getF32Type()), + rhs); + } + + Value acc = op.getAcc(); + if (extend_acc) { + acc = rewriter.create( + op.getLoc(), + VectorType::get(acc_ty.getShape(), rewriter.getF32Type()), + op.getAcc()); + } + + vector::ContractionOp contraction = rewriter.create( + op.getLoc(), lhs, rhs, acc, op.getIndexingMaps(), op.getIteratorTypes(), + op.getKind()); + + if (extend_acc) { + rewriter.replaceOpWithNewOp( + op, dyn_cast(op.getResultType()), contraction); + } else { + rewriter.replaceOp(op, contraction); + } + return success(); + } + + private: + const bool supports_bf16_matmul_; +}; + +// Rewrite `vector.multi_dim_reduction` with bf16 source/accumulator/output into +// a multi_dim_reduction with f32 source/accumulator/output, where the source +// and accumulator are extended and the result is truncated. +// TODO(b/324596736): Make the rewrite conditional on the target supporting +// bf16 reductions. +struct MultiDimReductionBitwidthConvert + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MultiDimReductionOp op, + PatternRewriter &rewriter) const override { + // Below we rely on the contract that the source operand, accumulator, and + // result have the same element type. 
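The gating logic of GenericBitwidthConvert above can be summarized as: leave an op alone if the target supports bf16 ALU instructions and the op is in the allow-list, otherwise extend bf16 operands to f32 and truncate the results back. A standalone sketch using the MLIR op names of the allow-listed arith ops (the helper itself is illustrative):

#include <iostream>
#include <set>
#include <string>

bool NeedsF32Rewrite(const std::string& op_name, bool has_bf16_operand,
                     bool supports_bf16_alu) {
  static const std::set<std::string> kSupportedBf16Ops = {
      "arith.addf", "arith.subf", "arith.mulf", "arith.maximumf",
      "arith.minimumf"};
  if (!has_bf16_operand) return false;
  if (supports_bf16_alu && kSupportedBf16Ops.count(op_name)) return false;
  return true;  // extend operands to f32, truncate results back to bf16
}

int main() {
  std::cout << NeedsF32Rewrite("arith.addf", true, true) << "\n";   // 0
  std::cout << NeedsF32Rewrite("math.tanh", true, true) << "\n";    // 1
  std::cout << NeedsF32Rewrite("arith.addf", true, false) << "\n";  // 1
  return 0;
}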
+ auto src_ty = op.getSourceVectorType(); + if (!src_ty.getElementType().isBF16()) { + return rewriter.notifyMatchFailure(op, "not bf16 reduction"); + } + + auto res_ty = dyn_cast(op.getResult().getType()); + if (!res_ty) { + return rewriter.notifyMatchFailure(op, "not vector reduction"); + } + + auto reduction = rewriter.create( + op.getLoc(), + rewriter.create( + op.getLoc(), + VectorType::get(src_ty.getShape(), rewriter.getF32Type()), + op.getSource()), + rewriter.create( + op.getLoc(), + VectorType::get(res_ty.getShape(), rewriter.getF32Type()), + op.getAcc()), + op.getReductionMask(), op.getKind()); + rewriter.replaceOpWithNewOp(op, res_ty, reduction); + return success(); + } +}; + struct LinalgVectorizationPass : public impl::LinalgVectorizationPassBase { - LinalgVectorizationPass() = default; + explicit LinalgVectorizationPass( + const LinalgVectorizationPassOptions &options) + : impl::LinalgVectorizationPassBase(options) {} void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } @@ -70,11 +481,50 @@ struct LinalgVectorizationPass vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); + patterns.add(ctx); + // Pull in patterns to convert bf16 ops to f32 ops. + for (::llvm::StringLiteral unary_op_name : + {arith::NegFOp::getOperationName(), math::TanhOp::getOperationName(), + math::ExpOp::getOperationName(), math::AbsFOp::getOperationName(), + math::SinOp::getOperationName(), math::CosOp::getOperationName(), + math::SqrtOp::getOperationName(), math::RsqrtOp::getOperationName(), + math::LogOp::getOperationName(), math::Log1pOp::getOperationName(), + math::RoundOp::getOperationName(), + math::RoundEvenOp::getOperationName()}) { + patterns.add(unary_op_name, ctx, + supports_bf16_alu_instructions); + } + for (::llvm::StringLiteral binary_op_name : + {arith::MulFOp::getOperationName(), arith::DivFOp::getOperationName(), + arith::AddFOp::getOperationName(), arith::SubFOp::getOperationName(), + arith::MaximumFOp::getOperationName(), + arith::MinimumFOp::getOperationName(), + math::PowFOp::getOperationName()}) { + patterns.add(binary_op_name, ctx, + supports_bf16_alu_instructions); + } + for (::llvm::StringLiteral ternary_op_name : + {arith::SelectOp::getOperationName()}) { + patterns.add(ternary_op_name, ctx, + supports_bf16_alu_instructions); + } + patterns.add(supports_bf16_matmul, ctx); + patterns.add(ctx); // We do not want to apply the vector patterns above to the ops that are // unrelated to the original linalg op. 
SmallVector linalgOps; - func.walk([&](linalg::LinalgOp op) { linalgOps.push_back(op); }); + func.walk([&](Operation *op) { + if (dyn_cast(op) || + dyn_cast(op) || + dyn_cast(op) || + dyn_cast(op) || + dyn_cast(op)) { + linalgOps.push_back(op); + } + }); if (failed(applyOpPatternsAndFold(linalgOps, std::move(patterns)))) { return signalPassFailure(); } @@ -83,8 +533,12 @@ struct LinalgVectorizationPass } // namespace -std::unique_ptr> createLinalgVectorizationPass() { - return std::make_unique(); +std::unique_ptr> createLinalgVectorizationPass( + bool supports_bf16_alu_instructions, bool supports_bf16_matmul) { + LinalgVectorizationPassOptions options; + options.supports_bf16_alu_instructions = supports_bf16_alu_instructions; + options.supports_bf16_matmul = supports_bf16_matmul; + return std::make_unique(options); } } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 1c4d5f6c323b..ac2389d6c238 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -27,11 +27,13 @@ limitations under the License. #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" +#include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/OperationSupport.h" +#include "mlir/include/mlir/Support/LogicalResult.h" +#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { -#define GEN_PASS_DECL_MOSAICSERDEPASS #define GEN_PASS_DEF_MOSAICSERDEPASS #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" @@ -39,7 +41,7 @@ namespace { constexpr std::string_view kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; -constexpr int kVersion = 1; +constexpr int kVersion = 2; StringRef mangle(StringRef name, std::string* storage) { storage->clear(); @@ -57,6 +59,55 @@ std::optional demangle(StringRef name) { return name.drop_front(kMangledDialect.size()); } +using rule_type = std::function; + +LogicalResult enqueue_dma_rule(Operation* op, int version) { + // Added AttrSizedOperandSegments and core_id in version 2. + if (version < 2) { + if (op->getNumOperands() == 3) { // Local DMA. + op->setAttr( + OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 0, 1, 1, 0, 0})); + } else if (op->getNumOperands() == 5) { // Remote DMA. + op->setAttr( + OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 1, 1, 0})); + } else { + return op->emitError("Unexpected operand count in tpu.enqueue_dma: ") + << op->getNumOperands(); + } + } + return success(); +} + +LogicalResult semaphore_signal_rule(Operation* op, int version) { + // Added AttrSizedOperandSegments and core_id in version 2. + if (version < 2) { + if (op->getNumOperands() == 2) { // Local signal. + op->setAttr(OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0})); + } else if (op->getNumOperands() == 3) { // Remote signal. 
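As a usage note, the createLinalgVectorizationPass factory above now threads the two capability flags through pass options. A hypothetical call site; the include path for the declaration and the func-level nesting are assumptions, and the flag values are placeholders:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"  // assumed declaration site

void AddMosaicVectorization(mlir::PassManager& pm) {
  // A real caller would derive these flags from the target's capabilities.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::tpu::createLinalgVectorizationPass(
          /*supports_bf16_alu_instructions=*/true,
          /*supports_bf16_matmul=*/false));
}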
+ op->setAttr(OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0})); + } else { + return op->emitError("Unexpected operand count in tpu.semaphore_signal"); + } + } + return success(); +} + +const llvm::StringMap& upgrade_rules() { + static auto rules = new llvm::StringMap{ + {EnqueueDMAOp::getOperationName(), enqueue_dma_rule}, + {SemaphoreSignalOp::getOperationName(), semaphore_signal_rule}, + }; + return *rules; +} + struct MosaicSerdePass : public impl::MosaicSerdePassBase { using Base::Base; @@ -68,6 +119,7 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { signalPassFailure(); return; } + int version = kVersion; if (serialize) { module->setAttr( kVersionAttrName, @@ -81,16 +133,17 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { signalPassFailure(); return; } - if (version_attr.getValue() != kVersion) { + if (version_attr.getInt() > kVersion) { module->emitError("Unsupported Mosaic version: ") - << version_attr.getValue().getSExtValue(); + << version_attr.getInt(); signalPassFailure(); return; } + version = version_attr.getInt(); module->removeAttr(kVersionAttrName); } std::string name_storage; - auto result = module.walk([this, &name_storage](Operation* op) { + auto result = module.walk([this, &name_storage, version](Operation* op) { if (isa(op)) { // Don't mangle the ModuleOp itself. return WalkResult::advance(); } @@ -111,6 +164,13 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { op->emitError("Operation not in a serialized form"); return WalkResult::interrupt(); } + // Upgrade the op to the current version, if needed. + if (const auto rule = upgrade_rules().find(new_name->getStringRef()); + rule != upgrade_rules().end()) { + if (rule->second(op, version).failed()) { + return WalkResult::interrupt(); + } + } } auto new_op = Operation::create( op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(), diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index 9ae5f9a59619..4ffbf160b1c9 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -14,6 +14,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "absl/types/span.h" +#include "tsl/platform/statusor.h" // TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with // MLIR diagnostics? @@ -57,6 +58,12 @@ FailureOr getTypeBitwidth(Type ty) { if (auto bf16_ty = dyn_cast(ty)) { return 16; } + if (auto f8e5m2_ty = dyn_cast(ty)) { + return 8; + } + if (auto f8e4m3fn_ty = dyn_cast(ty)) { + return 8; + } return emitError(UnknownLoc::get(ty.getContext()), "Unsupported type: ") << ty; } diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD new file mode 100644 index 000000000000..20fcf2b4ce74 --- /dev/null +++ b/jaxlib/mosaic/gpu/BUILD @@ -0,0 +1,198 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
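The serde changes above bump the stable version to 2 and run per-op upgrade rules on modules serialized with an older version, rejecting anything newer than the current version. A simplified standalone stand-in for that registry (the rule bodies here only report success; the real rules patch in the missing operand-segment attributes):

#include <functional>
#include <iostream>
#include <map>
#include <string>

constexpr int kVersion = 2;
using UpgradeRule = std::function<bool(int serialized_version)>;

const std::map<std::string, UpgradeRule>& UpgradeRules() {
  static const auto* rules = new std::map<std::string, UpgradeRule>{
      {"tpu.enqueue_dma", [](int) { return true; }},
      {"tpu.semaphore_signal", [](int) { return true; }},
  };
  return *rules;
}

bool MaybeUpgrade(const std::string& op_name, int serialized_version) {
  if (serialized_version > kVersion) return false;  // too new, reject
  auto it = UpgradeRules().find(op_name);
  if (it == UpgradeRules().end()) return true;  // nothing to do
  return it->second(serialized_version);
}

int main() {
  std::cout << MaybeUpgrade("tpu.enqueue_dma", 1) << "\n";  // 1
  std::cout << MaybeUpgrade("tpu.enqueue_dma", 3) << "\n";  // 0
  return 0;
}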
+ +load("@rules_python//python:defs.bzl", "py_library") +load("//jaxlib:jax.bzl", "pybind_extension") + +package( + default_applicable_licenses = [], + default_visibility = ["//:__subpackages__"], +) + +py_library( + name = "mosaic_gpu", + data = [":libmosaic_gpu_runtime.so"], + deps = [":_mosaic_gpu_ext"], +) + +cc_library( + name = "passes", + srcs = [ + "launch_lowering.cc", + "passes.cc", + ], + hdrs = [ + "launch_lowering.h", + "pass_boilerplate.h", + "passes.h", + ], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToGPURuntimeTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +CAPI_SOURCES = [ + "integrations/c/passes.cc", +] + +CAPI_HEADERS = [ + "integrations/c/passes.h", +] + +cc_library( + name = "mlir_capi", + srcs = CAPI_SOURCES, + hdrs = CAPI_HEADERS, + deps = [ + ":passes", + "@llvm-project//mlir:CAPIIRHeaders", + ], +) + +# Header-only target, used when using the C API from a separate shared library. +cc_library( + name = "mlir_capi_headers", + hdrs = CAPI_HEADERS, + deps = [ + "@llvm-project//mlir:CAPIIRHeaders", + ], +) + +# Alwayslink target, used when exporting the C API from a shared library. +cc_library( + name = "mlir_capi_objects", + srcs = CAPI_SOURCES, + hdrs = CAPI_HEADERS, + deps = [ + ":passes", + "@llvm-project//mlir:CAPIIRObjects", + ], + alwayslink = True, +) + +cc_library( + name = "runtime", + srcs = ["runtime.cc"], + deps = [ + "@local_config_cuda//cuda:cuda_headers", + ], +) + +cc_library( + name = "custom_call", + srcs = ["custom_call.cc"], + deps = [ + ":passes", + "//jaxlib/cuda:cuda_vendor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToLLVMIRTranslation", + "@llvm-project//mlir:GPUTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MemRefToLLVM", + "@llvm-project//mlir:MemRefTransforms", + "@llvm-project//mlir:NVGPUDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:NVVMTarget", + "@llvm-project//mlir:NVVMToLLVM", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBToLLVM", + "@llvm-project//mlir:VectorDialect", + "@xla//xla/service:custom_call_status", + "@xla//xla/service:custom_call_target_registry", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + 
"@com_google_absl//absl/synchronization", + ], + alwayslink = True, +) + +pybind_extension( + name = "_mosaic_gpu_ext", + srcs = ["mosaic_gpu_ext.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + linkopts = select({ + "@xla//xla/python:use_jax_cuda_pip_rpaths": [ + "-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib", + ], + "//conditions:default": [], + }), + deps = [ + "//jaxlib:kernel_nanobind_helpers", + "//jaxlib/cuda:cuda_vendor", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/synchronization", + "@nanobind", + ], +) + +cc_binary( + name = "libmosaic_gpu_runtime.so", + srcs = ["runtime.cc"], + copts = ["-fvisibility=default"], + linkopts = select({ + "@xla//xla/python:use_jax_cuda_pip_rpaths": [ + "-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib", + ], + "//conditions:default": [], + }), + linkshared = 1, + tags = [ + "manual", + "notap", + ], + deps = [ + "@xla//xla/tsl/cuda:cudart", + "@local_config_cuda//cuda:cuda_headers", + ], +) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc new file mode 100644 index 000000000000..56b3d2312c19 --- /dev/null +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -0,0 +1,446 @@ +/* Copyright 2021 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "jaxlib/gpu/vendor.h" +#include "absl/base/optimization.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "llvm/include/llvm/ADT/SmallVector.h" +#include "llvm/include/llvm/Support/CodeGen.h" +#include "llvm/include/llvm/Support/TargetSelect.h" +#include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/include/mlir/Conversion/Passes.h" +#include "mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/include/mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/include/mlir/Dialect/Math/IR/Math.h" +#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/include/mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/include/mlir/ExecutionEngine/OptUtils.h" +#include "mlir/include/mlir/IR/AsmState.h" +#include "mlir/include/mlir/IR/DialectRegistry.h" +#include "mlir/include/mlir/IR/MLIRContext.h" +#include "mlir/include/mlir/Parser/Parser.h" +#include "mlir/include/mlir/Pass/PassManager.h" +#include "mlir/include/mlir/Pass/PassRegistry.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/include/mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/include/mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/include/mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "mlir/include/mlir/Transforms/Passes.h" +#include "jaxlib/mosaic/gpu/launch_lowering.h" +#include "jaxlib/mosaic/gpu/passes.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/custom_call_target_registry.h" + +namespace { + +using MosaicInitFunc = void(void****); +using MosaicHostFunc = void(void**); + +mlir::FailureOr GetPassPipeline( + mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target) { + static bool register_once = []() { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::registerCanonicalizer(); + mlir::registerCSE(); + 
mlir::registerStripDebugInfo(); + mlir::registerConvertNVGPUToNVVMPass(); + mlir::registerConvertVectorToSCF(); + mlir::registerSCFToControlFlow(); + mlir::registerConvertNVVMToLLVMPass(); + mlir::registerArithToLLVMConversionPass(); + mlir::registerConvertIndexToLLVMPass(); + mlir::registerConvertGpuOpsToNVVMOps(); + mlir::registerConvertMathToLLVMPass(); + mlir::registerConvertFuncToLLVMPass(); + mlir::registerConvertAffineToStandard(); + mlir::registerReconcileUnrealizedCasts(); + // TODO(apaszke): Only register the passes we actually use. + mlir::memref::registerMemRefPasses(); + mlir::registerConvertToLLVMPass(); + mlir::registerGPUPasses(); + mosaic::gpu::registerGpuLaunchLoweringPass(); + mosaic::gpu::registerConvertGpuToLLVMPass(); + return true; + }(); + (void)register_once; + return mlir::parsePassPipeline( + R"( + builtin.module( + convert-nvgpu-to-nvvm, + gpu-kernel-outlining{data-layout-str=}, + convert-vector-to-scf{full-unroll=false lower-tensors=false target-rank=1}, + convert-scf-to-cf, + convert-nvvm-to-llvm, + expand-strided-metadata, + nvvm-attach-target{O=3 chip=sm_90a fast=false features=+ptx80 ftz=false module= triple=nvptx64-nvidia-cuda}, + lower-affine, + convert-arith-to-llvm{index-bitwidth=0}, + convert-index-to-llvm{index-bitwidth=64}, + canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, + cse, + gpu.module(strip-debuginfo), + gpu.module(convert-gpu-to-nvvm{has-redux=false index-bitwidth=64 use-bare-ptr-memref-call-conv=false}), + gpu.module(canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}), + gpu.module(cse), + gpu.module(reconcile-unrealized-casts), + mosaic-convert-gpu-to-llvm, + gpu-module-to-binary{format=)" + + mlir::gpu::stringifyCompilationTarget(target).str() + R"(}, + convert-math-to-llvm{approximate-log1p=true}, + canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, + cse, + )" + + (target != mlir::gpu::CompilationTarget::Assembly ? 
"gpu-launch-lowering," + : "") + + R"( + convert-to-llvm, + reconcile-unrealized-casts + ) + )"); +} + +mlir::LogicalResult RunPasses(mlir::OpPassManager&& passes, + mlir::ModuleOp module) { + mlir::PassManager pm(module.getContext()); + *static_cast(&pm) = std::move(passes); + if (getenv("MOSAIC_GPU_DUMP_MLIR_PASSES") != nullptr) { + pm.enableIRPrinting(); + } + return pm.run(module); +} + +void InitContext(mlir::MLIRContext* context) { + mlir::DialectRegistry registry; + registry.insert(); + mlir::registerConvertNVVMToLLVMInterface(registry); + mlir::registerConvertComplexToLLVMInterface(registry); + mlir::registerConvertMemRefToLLVMInterface(registry); + mlir::registerConvertMathToLLVMInterface(registry); + mlir::registerConvertFuncToLLVMInterface(registry); + mlir::index::registerConvertIndexToLLVMInterface(registry); + mlir::cf::registerConvertControlFlowToLLVMInterface(registry); + mlir::ub::registerConvertUBToLLVMInterface(registry); + mlir::arith::registerConvertArithToLLVMInterface(registry); + mlir::registerConvertMemRefToLLVMInterface(registry); + mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); + mlir::NVVM::registerNVVMTargetInterfaceExternalModels(registry); + mlir::registerBuiltinDialectTranslation(registry); + mlir::registerGPUDialectTranslation(registry); + mlir::registerLLVMDialectTranslation(registry); + mlir::registerNVVMDialectTranslation(registry); + context->appendDialectRegistry(registry); + context->loadAllAvailableDialects(); +} + +absl::Status RunCUDATool(const char* tool, + const std::vector& args, + bool stderr_to_stdout = false) { + CHECK(!args.empty() && args.back() == nullptr); + const char * cuda_path_ptr = getenv("CUDA_ROOT"); + if (!cuda_path_ptr) return absl::InternalError("Failed to get CUDA_ROOT"); + std::string tool_path(cuda_path_ptr); + tool_path += "/bin/"; + tool_path += tool; + pid_t child_pid; + posix_spawn_file_actions_t file_actions; + if (posix_spawn_file_actions_init(&file_actions)) { + return absl::InternalError("Failed to initialize spawn file actions"); + } + if (posix_spawn_file_actions_adddup2(&file_actions, STDOUT_FILENO, + STDERR_FILENO)) { + return absl::InternalError("Failed to set up spawn file actions"); + } + // execv is guaranteed by POSIX to not modify the args (other than + // replacing the whole process image), so the const_cast is valid. + if (posix_spawn(&child_pid, tool_path.c_str(), &file_actions, nullptr, + const_cast(args.data()), environ)) { + return absl::InternalError("Process spawn failed"); + } + int status; + if (waitpid(child_pid, &status, 0) == -1) { + return absl::InternalError("Failed to wait for CUDA tool invocation"); + } + if (status != 0) return absl::InternalError("CUDA tool failed"); + if (posix_spawn_file_actions_destroy(&file_actions) != 0) { + return absl::InternalError("Failed to clean up after posix_spawn"); + } + return absl::OkStatus(); +} + +class TemporaryDirectory { + private: + TemporaryDirectory(std::string path) : path(std::move(path)) {} + // TODO(apaszke): Unlink in destructor. 
+ + public: + static absl::StatusOr Create() { + std::string pattern = "/tmp/mosaic-gpu-XXXXXX"; + if (mkdtemp(pattern.data()) == NULL) { + return absl::InternalError("Failed to create temporary directory"); + } + return TemporaryDirectory(std::move(pattern)); + } + + std::string_view GetPath() { return path; } + + private: + std::string path; +}; + +void DumpCompilationOutput(mlir::ModuleOp module) { + bool dump_ptx = getenv("MOSAIC_GPU_DUMP_PTX") != nullptr; + bool dump_ptxas = getenv("MOSAIC_GPU_DUMP_PTXAS") != nullptr; + bool dump_sass = getenv("MOSAIC_GPU_DUMP_SASS") != nullptr; + if (!dump_ptx && !dump_ptxas && !dump_sass) { + return; + } + + module = module.clone(); // Prevent accidental modification. + auto passes = GetPassPipeline(module.getContext(), + mlir::gpu::CompilationTarget::Assembly); + if (mlir::failed(passes) || + mlir::failed(RunPasses(std::move(*passes), module))) { + return; + } + for (mlir::Operation& op : module.getBody()->getOperations()) { + auto binary = mlir::dyn_cast(&op); + if (!binary) { continue; } + auto objects = binary.getObjects(); + if (objects.size() != 1) { + std::cerr << "Multiple objects per gpu.binary unsupported" << std::endl; + continue; + } + auto object = mlir::cast(*objects.begin()); + std::string ptx = object.getObject().getValue().str(); + if (dump_ptx) { + std::cout << ptx << std::endl; + } + if (!dump_ptxas && !dump_sass) { continue; } // We're done. + auto tmpdir = TemporaryDirectory::Create(); + if (!tmpdir.ok()) { + std::cerr << "Failed to create a temporary directory" << std::endl; + continue; + } + std::string ptx_path = std::string(tmpdir->GetPath()) + "/kernel.ptx"; + std::string elf_path = std::string(tmpdir->GetPath()) + "/kernel.o"; + // Dump PTX into a file. + std::ofstream ptx_out(ptx_path.c_str()); + if (!ptx_out) { + std::cerr << "Failed to write PTX to a file" << std::endl; + continue; + } + ptx_out << ptx << std::endl; + // Run ptxas to generate SASS. + std::vector ptxas_args = { + "ptxas", "--opt-level", "3", + "--gpu-name", "sm_90a", "--output-file", + elf_path.c_str(), ptx_path.c_str()}; + if (dump_ptxas) { + ptxas_args.push_back("-v"); + } + ptxas_args.push_back(nullptr); + if (auto status = RunCUDATool("ptxas", ptxas_args); !status.ok()) { + std::cerr << "ptxas invocation failed: " << status.message() << std::endl; + continue; + } + if (!dump_sass) { continue; } // We're done. + // Call nvdisasm to pretty-print SASS. + if (auto status = RunCUDATool( + "nvdisasm", {"nvdisasm", "-ndf", "-c", elf_path.c_str(), nullptr}); + !status.ok()) { + std::cerr << "nvdisasm invocation failed: " << status.message() + << std::endl; + continue; + } + } +} + +absl::StatusOr> Compile( + mlir::ModuleOp module) { + DumpCompilationOutput(module); + auto passes = GetPassPipeline(module.getContext(), + mlir::gpu::CompilationTarget::Binary); + if (mlir::failed(passes)) { + return absl::InternalError("Failed to construct pass pipeline"); + } + if (mlir::failed(RunPasses(std::move(*passes), module))) { + return absl::InternalError("Pass pipeline failed"); + } + + llvm::SmallVector runtime_lib; + if (const char* lib_path = getenv("MOSAIC_GPU_RUNTIME_LIB_PATH")) { + runtime_lib.emplace_back(lib_path); + } + // Create a transformer to run all LLVM optimization passes at the + // specified optimization level. 
+  mlir::ExecutionEngineOptions options;
+  options.transformer = mlir::makeOptimizingTransformer(3, 0, nullptr);
+  options.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive;
+  options.sharedLibPaths = runtime_lib;
+  auto maybe_execution_engine = mlir::ExecutionEngine::create(module, options);
+  if (!maybe_execution_engine) {
+    return absl::InternalError("Failed to compile kernel");
+  }
+  return std::move(*maybe_execution_engine);
+}
+
+class CompiledKernel {
+ public:
+  CompiledKernel(std::unique_ptr<mlir::ExecutionEngine> engine, void* ctx,
+                 void* scratch_addr, MosaicHostFunc* host_launch)
+      : engine_(std::move(engine)),
+        ctx_(ctx),
+        scratch_addr_(scratch_addr),
+        host_launch_(host_launch) {}
+
+  std::tuple<void*, void*, MosaicHostFunc*> GetHostLaunch() {
+    return std::make_tuple(ctx_, scratch_addr_, host_launch_);
+  }
+
+ private:
+  std::unique_ptr<mlir::ExecutionEngine> engine_;
+  void* ctx_;  // TODO(apaszke): Destroy this properly
+  void* scratch_addr_;
+  MosaicHostFunc* host_launch_;
+};
+
+std::pair<absl::flat_hash_map<uint64_t, CompiledKernel>*, absl::Mutex*>
+GetKernelCache() {
+  static absl::Mutex mutex;
+  static auto& context_cache =
+      *new absl::flat_hash_map<uint64_t, CompiledKernel>;
+  return std::make_pair(&context_cache, &mutex);
+}
+
+// Each compiled kernel has a unique init func, and each kernel is used from
+// a single HLO module. So it should be safe to not include the CUDA context
+// in the key.
+absl::StatusOr<std::tuple<void*, void*, MosaicHostFunc*>> CompileAndInit(
+    uint64_t kernel_id, const char* module) {
+  auto cache_and_mutex = GetKernelCache();
+  auto* cache = cache_and_mutex.first;
+  auto* mutex = cache_and_mutex.second;
+
+  {
+    // Fast path uses reader lock (as hash map look-up is relatively slow).
+    absl::ReaderMutexLock lock(mutex);
+    auto it = cache->find(kernel_id);
+    if (ABSL_PREDICT_TRUE(it != cache->end()))
+      return it->second.GetHostLaunch();
+  }
+
+  absl::MutexLock lock(mutex);
+  // We released the reader lock, another thread might have initialized it.
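+  // (Classic double-checked locking: re-check under the exclusive lock so
+  //  that only one thread compiles and inserts the kernel; everyone else
+  //  reuses the cached entry.)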
+  if (cache->find(kernel_id) == cache->end()) {
+    mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
+    InitContext(&context);
+    mlir::ParserConfig parse_config(&context);
+    auto module_op =
+        mlir::parseSourceString<mlir::ModuleOp>(module, parse_config);
+    if (!module_op) {
+      return absl::InternalError("Failed to parse module");
+    }
+    auto maybe_engine = Compile(*module_op);
+    if (!maybe_engine.ok()) {
+      return maybe_engine.status();
+    }
+    mlir::ExecutionEngine* execution_engine = maybe_engine->get();
+    auto main = execution_engine->lookupPacked("_mlir_ciface_main");
+    auto init = execution_engine->lookupPacked("_mlir_ciface_main_init");
+    if (!init || !main) {
+      return absl::InternalError("Failed to retrieve kernel function");
+    }
+    void* module_ptr = nullptr;
+    void* kernel_ptr = nullptr;
+    void** module_ptr_ptr = &module_ptr;
+    void** kernel_ptr_ptr = &kernel_ptr;
+    void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
+    reinterpret_cast(*init)(init_args);
+    CUmodule module = static_cast<CUmodule>(module_ptr);
+    CUdeviceptr scratch_addr;
+    cuModuleGetGlobal(&scratch_addr, nullptr, module, "global_scratch");
+    cache->insert_or_assign(
+        kernel_id,
+        CompiledKernel(std::move(*maybe_engine), kernel_ptr,
+                       reinterpret_cast<void*>(scratch_addr),
+                       reinterpret_cast<MosaicHostFunc*>(*main)));
+  }
+  return cache->at(kernel_id).GetHostLaunch();
+}
+
+void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
+                         size_t opaque_len, XlaCustomCallStatus* status) {
+  uint64_t kernel_id = *reinterpret_cast<uint64_t*>(opaque);
+  auto ctx_and_kernel = CompileAndInit(kernel_id, opaque + sizeof(uint64_t));
+  if (!ctx_and_kernel.ok()) {
+    XlaCustomCallStatusSetFailure(status,
+                                  ctx_and_kernel.status().message().data(),
+                                  ctx_and_kernel.status().message().size());
+    return;
+  }
+  void* args[4] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers,
+                   &std::get<1>(*ctx_and_kernel)};
+  std::get<2>(*ctx_and_kernel)(args);
+}
+
+XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall,
+                                         "CUDA");
+
+}  // namespace
diff --git a/jaxlib/cpu/ducc_fft.fbs b/jaxlib/mosaic/gpu/integrations/c/passes.cc
similarity index 54%
rename from jaxlib/cpu/ducc_fft.fbs
rename to jaxlib/mosaic/gpu/integrations/c/passes.cc
index a58e1dc7ca18..065d11fd33e1 100644
--- a/jaxlib/cpu/ducc_fft.fbs
+++ b/jaxlib/mosaic/gpu/integrations/c/passes.cc
@@ -1,4 +1,4 @@
-/* Copyright 2020 The JAX Authors.
+/* Copyright 2024 The JAX Authors.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,36 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
==============================================================================*/ -namespace jax; +#include "jaxlib/mosaic/gpu/integrations/c/passes.h" -enum DuccFftDtype : byte { - COMPLEX64 = 0, - COMPLEX128 = 1, -} +#include "jaxlib/mosaic/gpu/launch_lowering.h" -enum DuccFftType : byte { - C2C = 0, - C2R = 1, - R2C = 2, -} +extern "C" { -table DuccFftDescriptor { - dtype:DuccFftDtype; - fft_type:DuccFftType; - shape:[uint64]; - strides_in:[uint64]; - strides_out:[uint64]; - axes:[uint32]; - forward:bool; - scale:double; +void mlirMosaicGpuRegisterPasses() { + mosaic::gpu::registerGpuLaunchLoweringPass(); } -table DynamicDuccFftDescriptor { - ndims:uint32; - dtype:DuccFftDtype; - fft_type:DuccFftType; - axes:[uint32]; - forward:bool; } - -root_type DuccFftDescriptor; diff --git a/jaxlib/cpu/ducc_fft_kernels.h b/jaxlib/mosaic/gpu/integrations/c/passes.h similarity index 56% rename from jaxlib/cpu/ducc_fft_kernels.h rename to jaxlib/mosaic/gpu/integrations/c/passes.h index a3bf6cf46db0..901c39d68b77 100644 --- a/jaxlib/cpu/ducc_fft_kernels.h +++ b/jaxlib/mosaic/gpu/integrations/c/passes.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The JAX Authors. +/* Copyright 2024 The JAX Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,19 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_CPU_DUCC_FFT_KERNELS_H_ -#define JAXLIB_CPU_DUCC_FFT_KERNELS_H_ +#ifndef JAXLIB_MOSAIC_GPU_INTEGRATIONS_C_PASSES_H_ +#define JAXLIB_MOSAIC_GPU_INTEGRATIONS_C_PASSES_H_ -#include "xla/service/custom_call_status.h" +#include "mlir-c/Support.h" -namespace jax { +#ifdef __cplusplus +extern "C" { +#endif -// TODO(b/287702203): this must be kept until EOY 2023 for backwards -// of serialized functions using fft. -void DuccFft(void* out, void** in, XlaCustomCallStatus*); +MLIR_CAPI_EXPORTED void mlirMosaicGpuRegisterPasses(); -void DynamicDuccFft(void* out, void** in, XlaCustomCallStatus*); +#ifdef __cplusplus +} +#endif -} // namespace jax - -#endif // JAXLIB_CPU_DUCC_FFT_KERNELS_H_ +#endif // JAXLIB_MOSAIC_GPU_INTEGRATIONS_C_PASSES_H_ diff --git a/jaxlib/mosaic/gpu/launch_lowering.cc b/jaxlib/mosaic/gpu/launch_lowering.cc new file mode 100644 index 000000000000..f3a1ce4f439f --- /dev/null +++ b/jaxlib/mosaic/gpu/launch_lowering.cc @@ -0,0 +1,329 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The stock MLIR pipeline lowers gpu.launch_func into a sequence of +// instructions that load the kernel onto the GPU, run it and immediately unload +// it again. This has the correct semantics, but loading the kernel is both +// expensive and forces synchronization, which causes performance issues. 
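+// (Roughly: every launch would otherwise go through the MLIR GPU runtime
+// wrappers for module load, function lookup, kernel launch and module
+// unload.)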
+ +// This pass implements an alternative strategy, where each function containing +// a gpu.launch_func is split into two functions: one that preloads the kernel +// onto the GPU, and second one that consumes the handle produced by the +// first one. We call the first function at compile-time, while only the +// second one is used at run-time. + +// TODO(apaszke): Implement a third function that properly cleans up the +// resources allocated by the first function. + +#include +#include + +#include "llvm/include/llvm/ADT/STLExtras.h" +#include "llvm/include/llvm/ADT/StringRef.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/include/mlir/IR/Builders.h" +#include "mlir/include/mlir/IR/BuiltinAttributes.h" +#include "mlir/include/mlir/IR/BuiltinOps.h" +#include "mlir/include/mlir/IR/DialectRegistry.h" +#include "mlir/include/mlir/IR/Location.h" +#include "mlir/include/mlir/IR/SymbolTable.h" +#include "mlir/include/mlir/IR/TypeRange.h" +#include "mlir/include/mlir/IR/Value.h" +#include "mlir/include/mlir/IR/ValueRange.h" +#include "mlir/include/mlir/IR/Visitors.h" +#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/include/mlir/Pass/Pass.h" +#include "mlir/include/mlir/Pass/PassRegistry.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/include/mlir/Support/TypeID.h" + +namespace mosaic { +namespace gpu { + +namespace { + +mlir::Value packKernelArgs(mlir::OpBuilder &builder, + mlir::gpu::LaunchFuncOp launch) { + std::vector kernel_operand_types; + kernel_operand_types.reserve(launch.getNumKernelOperands()); + for (mlir::Value operand : launch.getKernelOperands()) { + kernel_operand_types.push_back(operand.getType()); + } + auto kernel_args_struct_ty = mlir::LLVM::LLVMStructType::getLiteral( + builder.getContext(), kernel_operand_types); + auto ptr_ty = mlir::LLVM::LLVMPointerType::get(builder.getContext()); + mlir::Value c1 = builder.create( + launch.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(1)); + mlir::Value kernel_args_struct = builder.create( + launch.getLoc(), ptr_ty, kernel_args_struct_ty, c1); + mlir::Value kernel_args_array = builder.create( + launch.getLoc(), ptr_ty, + mlir::LLVM::LLVMArrayType::get(builder.getI64Type(), + launch.getNumKernelOperands()), + c1); + + for (auto [i, operand] : llvm::enumerate(launch.getKernelOperands())) { + mlir::LLVM::GEPArg gep_arg(i); + mlir::Value storage_ptr = builder.create( + launch.getLoc(), ptr_ty, operand.getType(), kernel_args_struct, + gep_arg); + builder.create(launch.getLoc(), operand, storage_ptr); + mlir::Value array_slot_ptr = builder.create( + launch.getLoc(), ptr_ty, builder.getI64Type(), kernel_args_array, + gep_arg); + builder.create(launch.getLoc(), storage_ptr, + array_slot_ptr); + } + return kernel_args_array; +} + +void emitRuntimeDecls(mlir::ModuleOp module) { + auto ptr_ty = mlir::LLVM::LLVMPointerType::get(module.getContext()); + auto i32 = mlir::IntegerType::get(module.getContext(), 32); + auto i64 = mlir::IntegerType::get(module.getContext(), 64); + auto decl_builder = mlir::OpBuilder::atBlockBegin(module.getBody()); + decl_builder.create( + module.getLoc(), decl_builder.getStringAttr("mosaic_gpu_launch_kernel"), + mlir::FunctionType::get(module.getContext(), + 
{ptr_ty, i64, i64, i64, i64, i64, i64, i32, + ptr_ty, ptr_ty}, + {}), + decl_builder.getStringAttr("private"), /*arg_attr=*/nullptr, + /*res_attrs=*/nullptr); + decl_builder.create( + module.getLoc(), decl_builder.getStringAttr("mosaic_gpu_module_load"), + mlir::FunctionType::get(module.getContext(), {ptr_ty}, {ptr_ty}), + decl_builder.getStringAttr("private"), /*arg_attr=*/nullptr, + /*res_attrs=*/nullptr); + decl_builder.create( + module.getLoc(), decl_builder.getStringAttr("mosaic_gpu_get_function"), + mlir::FunctionType::get(module.getContext(), {ptr_ty, ptr_ty, i32}, + {ptr_ty}), + decl_builder.getStringAttr("private"), /*arg_attr=*/nullptr, + /*res_attrs=*/nullptr); +} + +void buildInitFunction(mlir::OpBuilder &module_builder, + mlir::func::FuncOp init_func, + llvm::StringRef kernel_name, + mlir::gpu::ObjectAttr object, + mlir::Value dynamic_smem_size) { + auto i32 = mlir::IntegerType::get(init_func.getContext(), 32); + auto ptr_ty = mlir::LLVM::LLVMPointerType::get(init_func.getContext()); + mlir::Location loc = init_func.getLoc(); + auto builder = mlir::OpBuilder::atBlockBegin(init_func.addEntryBlock()); + auto binary_global_decl = module_builder.create( + loc, + mlir::LLVM::LLVMArrayType::get(builder.getI8Type(), + object.getObject().size()), + /*is_constant=*/true, + /*linkage=*/mlir::LLVM::Linkage::Internal, + /*name=*/ + builder.getStringAttr(kernel_name.str() + "_kernel_binary"), + /*value=*/object.getObject()); + mlir::Value binary_addr = builder.create( + init_func.getLoc(), binary_global_decl); + mlir::Value module_handle = + builder + .create(loc, "mosaic_gpu_module_load", ptr_ty, + binary_addr) + .getResult(0); + + // TODO(apaszke): This will create duplicate globals if the kernel + // is called from multiple functions! + mlir::StringAttr kernel_name_global_name = + builder.getStringAttr(kernel_name.str() + "_name"); + auto kernel_name_global = module_builder.create( + loc, + mlir::LLVM::LLVMArrayType::get(builder.getI8Type(), + kernel_name.size() + 1), + /*is_constant=*/true, + /*linkage=*/mlir::LLVM::Linkage::Internal, + /*name=*/kernel_name_global_name, + /*value=*/ + builder.getStringAttr( + llvm::Twine(kernel_name).concat(llvm::Twine('\0')))); + mlir::Value kernel_name_ptr = + builder.create(loc, kernel_name_global); + mlir::Value used_smem = builder.create( + loc, i32, builder.getI32IntegerAttr(0)); + if (dynamic_smem_size) { + if (auto const_smem = + dynamic_smem_size.getDefiningOp()) { + used_smem = builder.create( + loc, i32, + builder.getI32IntegerAttr( + mlir::cast(const_smem.getValue()).getInt())); + } + } + mlir::Value kernel_handle = + builder + .create( + loc, "mosaic_gpu_get_function", ptr_ty, + mlir::ValueRange{module_handle, kernel_name_ptr, used_smem}) + .getResult(0); + builder.create(loc, module_handle, + init_func.getArgument(0)); + builder.create(loc, kernel_handle, + init_func.getArgument(1)); + builder.create(loc); +} + +mlir::LogicalResult launchPreloadedKernel(mlir::func::FuncOp func, + mlir::gpu::LaunchFuncOp launch, + mlir::Value kernel_handle) { + // Lower gpu.launch_func to a call to mgpuLaunchKernel. 
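+  // (More precisely, to the mosaic_gpu_launch_kernel shim declared in
+  //  emitRuntimeDecls; the emitted call is roughly
+  //    mosaic_gpu_launch_kernel(kernel_handle, gx, gy, gz, bx, by, bz,
+  //                             dynamic_smem, stream, packed_args)
+  //  with packed_args produced by packKernelArgs.)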
+ mlir::OpBuilder builder(launch); + mlir::Value dynamic_smem = launch.getDynamicSharedMemorySize(); + if (!dynamic_smem) { + dynamic_smem = builder.create( + launch.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(0)); + } + mlir::Value arg_ptr_array = packKernelArgs(builder, launch); + if (launch.hasClusterSize()) { + return launch.emitOpError("Clusters not supported yet."); + } + mlir::gpu::KernelDim3 grid = launch.getGridSizeOperandValues(); + mlir::gpu::KernelDim3 block = launch.getBlockSizeOperandValues(); + mlir::Value stream = launch.getAsyncObject(); + builder.create( + launch.getLoc(), "mosaic_gpu_launch_kernel", mlir::TypeRange{}, + mlir::ValueRange{kernel_handle, grid.x, grid.y, grid.z, block.x, block.y, + block.z, dynamic_smem, stream, arg_ptr_array}); + return mlir::success(); +} + +class GpuLaunchLoweringPass : public ::mlir::OperationPass { + public: + GpuLaunchLoweringPass() + : ::mlir::OperationPass( + ::mlir::TypeID::get()) {} + GpuLaunchLoweringPass(const GpuLaunchLoweringPass &other) + : ::mlir::OperationPass(other) {} + GpuLaunchLoweringPass &operator=(const GpuLaunchLoweringPass &) = delete; + GpuLaunchLoweringPass(GpuLaunchLoweringPass &&) = delete; + GpuLaunchLoweringPass &operator=(GpuLaunchLoweringPass &&) = delete; + ~GpuLaunchLoweringPass() = default; + + // Pass boilerplate... + static constexpr ::llvm::StringLiteral getArgumentName() { + return ::llvm::StringLiteral("gpu-launch-lowering"); + } + ::llvm::StringRef getArgument() const override { return getArgumentName(); } + ::llvm::StringRef getDescription() const override { return ""; } + static constexpr ::llvm::StringLiteral getPassName() { + return ::llvm::StringLiteral("GpuLaunchLoweringPass"); + } + ::llvm::StringRef getName() const override { return getPassName(); } + static bool classof(const ::mlir::Pass *pass) { + return pass->getTypeID() == ::mlir::TypeID::get(); + } + std::unique_ptr<::mlir::Pass> clonePass() const override { + return std::make_unique( + *static_cast(this)); + } + void getDependentDialects(::mlir::DialectRegistry ®istry) const override {} + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GpuLaunchLoweringPass) + + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + auto ptr_ty = mlir::LLVM::LLVMPointerType::get(module.getContext()); + emitRuntimeDecls(module); + for (mlir::Operation &op : *module.getBody()) { + if (auto func = mlir::dyn_cast(&op)) { + if (func.isDeclaration() || + !func->getAttr( + mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName())) { + continue; + } + auto module_builder = mlir::OpBuilder::atBlockBegin(module.getBody()); + auto init_func = module_builder.create( + op.getLoc(), func.getName().str() + "_init", + mlir::FunctionType::get(func->getContext(), {ptr_ty, ptr_ty}, {})); + init_func->setAttr(mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), + mlir::UnitAttr::get(func->getContext())); + bool had_launch = false; + auto result = getOperation()->walk([&](mlir::gpu::LaunchFuncOp launch) + -> mlir::WalkResult { + if (had_launch) { + launch->emitOpError("Only one launch per function supported."); + return mlir::WalkResult::interrupt(); + } + had_launch = true; + auto binary = + mlir::SymbolTable::lookupNearestSymbolFrom( + launch, launch.getKernelModuleName()); + if (!binary) { + launch.emitError("Failed to find the gpu.binary op for ") + << launch.getKernelModuleName(); + return mlir::WalkResult::interrupt(); + } + if (binary.getObjects().size() != 1) { + binary.emitOpError("Expected exactly one object in the binary."); + return 
mlir::WalkResult::interrupt(); + } + mlir::gpu::ObjectAttr object = + mlir::cast(*binary.getObjects().begin()); + if (object.getFormat() != mlir::gpu::CompilationTarget::Fatbin && + object.getFormat() != mlir::gpu::CompilationTarget::Binary) { + binary.emitOpError("Expected a binary or a fatbin object."); + return mlir::WalkResult::interrupt(); + } + + buildInitFunction(module_builder, init_func, + launch.getKernelName().getValue(), object, + launch.getDynamicSharedMemorySize()); + + // Add a new function argument for the kernel handle. + func.insertArgument(0, ptr_ty, + mlir::DictionaryAttr::get(func.getContext()), + mlir::UnknownLoc::get(func.getContext())); + mlir::Value kernel_handle = func.getArgument(0); + if (launchPreloadedKernel(func, launch, kernel_handle).failed()) { + return mlir::WalkResult::interrupt(); + } + launch.erase(); + + // TODO(apaszke): Generate a destructor function. + // builder.CreateCall(getModuleUnloadFn(), {moduleObject}); + + return mlir::WalkResult::advance(); + }); + if (!had_launch) { + init_func.erase(); + } + if (result == mlir::WalkResult::interrupt()) { + signalPassFailure(); + } + } + } + } +}; + +} // namespace + +void registerGpuLaunchLoweringPass() { + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return std::make_unique(); + }); +} + +} // namespace gpu +} // namespace mosaic diff --git a/jaxlib/mosaic/gpu/launch_lowering.h b/jaxlib/mosaic/gpu/launch_lowering.h new file mode 100644 index 000000000000..36b90f45b650 --- /dev/null +++ b/jaxlib/mosaic/gpu/launch_lowering.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_LAUNCH_LOWERING_H_ +#define JAXLIB_MOSAIC_GPU_LAUNCH_LOWERING_H_ + +namespace mosaic { +namespace gpu { + +void registerGpuLaunchLoweringPass(); + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_LAUNCH_LOWERING_H_ diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc new file mode 100644 index 000000000000..ec574de4368f --- /dev/null +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -0,0 +1,63 @@ +/* Copyright 2021 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "nanobind/nanobind.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "jaxlib/gpu/vendor.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/service/custom_call_status.h" + +namespace jax::cuda { +namespace { + +namespace nb = nanobind; + +void EventRecordCall(void* stream, void** buffers, char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto* event = reinterpret_cast(opaque); + if (gpuEventRecord(**event, reinterpret_cast(stream)) != + gpuSuccess) { + const char message[] = "Failed to record event"; + XlaCustomCallStatusSetFailure(status, message, sizeof(message)); + } +} + +NB_MODULE(_mosaic_gpu_ext, m) { + m.def("_gpu_event_create", []() { + gpuEvent_t* event = new gpuEvent_t(); + gpuEventCreate(event, GPU_EVENT_DEFAULT); + return reinterpret_cast(event); + }); + m.def("_gpu_event_destroy", [](uintptr_t event) { + gpuEventDestroy(*reinterpret_cast(event)); + }); + m.def("_gpu_event_elapsed", [](uintptr_t start_event, uintptr_t end_event) { + float elapsed_ms = -1; + if (gpuEventElapsedTime( + &elapsed_ms, *reinterpret_cast(start_event), + *reinterpret_cast(end_event)) != gpuSuccess) { + throw std::runtime_error("Failed to get elapsed time between events"); + } + return elapsed_ms; + }); + m.def("_record_event_capsule", + []() { return EncapsulateFunction(EventRecordCall); }); +} + +} // namespace +} // namespace jax::cuda diff --git a/jaxlib/mosaic/gpu/pass_boilerplate.h b/jaxlib/mosaic/gpu/pass_boilerplate.h new file mode 100644 index 000000000000..b0241fca97ab --- /dev/null +++ b/jaxlib/mosaic/gpu/pass_boilerplate.h @@ -0,0 +1,64 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_ +#define JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_ + +#include "mlir/include/mlir/IR/DialectRegistry.h" +#include "mlir/include/mlir/Pass/Pass.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/include/mlir/Support/TypeID.h" +namespace mosaic { +namespace gpu { + +template +class Pass : public ::mlir::OperationPass { + public: + Pass() : ::mlir::OperationPass(::mlir::TypeID::get()) {} + Pass(const Pass &other) : ::mlir::OperationPass(other) {} + Pass &operator=(const Pass &) = delete; + Pass(Pass &&) = delete; + Pass &operator=(Pass &&) = delete; + ~Pass() = default; + + static constexpr ::llvm::StringLiteral getArgumentName() { + return ::llvm::StringLiteral(Derived::kArgumentName); + } + ::llvm::StringRef getArgument() const override { return getArgumentName(); } + ::llvm::StringRef getDescription() const override { return ""; } + static constexpr ::llvm::StringLiteral getPassName() { + return ::llvm::StringLiteral(Derived::kPassName); + } + ::llvm::StringRef getName() const override { return getPassName(); } + static bool classof(const ::mlir::Pass *pass) { + return pass->getTypeID() == ::mlir::TypeID::get(); + } + std::unique_ptr<::mlir::Pass> clonePass() const override { + return std::make_unique(*static_cast(this)); + } + void getDependentDialects(::mlir::DialectRegistry ®istry) const override {} + + private: + using This = + Pass; // Can't have a comma in the macro instantiation + + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(This) +}; + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_ diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc new file mode 100644 index 000000000000..9c9d82ac4f3f --- /dev/null +++ b/jaxlib/mosaic/gpu/passes.cc @@ -0,0 +1,77 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/mosaic/gpu/passes.h" +#include +#include +#include + +#include "llvm/include/llvm/ADT/StringRef.h" +#include "mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/include/mlir/IR/BuiltinOps.h" +#include "mlir/include/mlir/IR/SymbolTable.h" +#include "mlir/include/mlir/Pass/PassRegistry.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/include/mlir/Transforms/DialectConversion.h" +#include "jaxlib/mosaic/gpu/pass_boilerplate.h" + +namespace mosaic { +namespace gpu { + +namespace { + +class ConvertGpuToLLVMPass + : public mosaic::gpu::Pass { + public: + using mosaic::gpu::Pass::Pass; + static constexpr llvm::StringLiteral kArgumentName = + "mosaic-convert-gpu-to-llvm"; + static constexpr llvm::StringLiteral kPassName = "ConvertGpuToLLVMPass"; + + void runOnOperation() override { + mlir::MLIRContext *ctx = &getContext(); + mlir::RewritePatternSet patterns(ctx); + mlir::LLVMTypeConverter converter(ctx); + mlir::ConversionTarget target(*ctx); + target.addLegalDialect(); + target.addLegalOp(); + target.addDynamicallyLegalOp( + [&](mlir::gpu::LaunchFuncOp op) -> bool { + return converter.isLegal(op->getOperandTypes()) && + converter.isLegal(op->getResultTypes()); + }); + auto symtab = mlir::SymbolTable(getOperation()); + mlir::populateGpuToLLVMConversionPatterns(converter, patterns, false); + if (mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)) + .failed()) { + signalPassFailure(); + } + } +}; + +} // namespace + +void registerConvertGpuToLLVMPass() { + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return std::make_unique(); + }); +} + +} // namespace gpu +} // namespace mosaic diff --git a/jaxlib/mosaic/gpu/passes.h b/jaxlib/mosaic/gpu/passes.h new file mode 100644 index 000000000000..bf7a804ee217 --- /dev/null +++ b/jaxlib/mosaic/gpu/passes.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_ +#define JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_ + +namespace mosaic { +namespace gpu { + +void registerConvertGpuToLLVMPass(); + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_ diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc new file mode 100644 index 000000000000..5f32f6e2bb81 --- /dev/null +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -0,0 +1,145 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include "third_party/gpus/cuda/include/cuda.h" + +extern "C" { + +void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, + int64_t elem_bytewidth, int64_t rank, + int64_t *sizes, int64_t *strides, + int64_t swizzle_bytes, int64_t *window_shape) { + CUtensorMapDataType data_type; + if (elem_bytewidth == 1) { + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if (elem_bytewidth == 2) { + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + } else if (elem_bytewidth == 4) { + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + } else if (elem_bytewidth == 8) { + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64; + } else { + fprintf(stderr, "Unsupported element size: %ld\n", elem_bytewidth); + abort(); + } + cuuint64_t tma_sizes[5] = {1, 1, 1, 1, 1}; + for (int i = 0; i < rank; ++i) { + tma_sizes[i] = static_cast(sizes[rank - i - 1]); + } + cuuint64_t tma_strides[5] = {1, 1, 1, 1, 1}; + if (strides[rank - 1] != 1) { + fprintf(stderr, "Minormost stride must be 1, but got %ld\n", + strides[rank - 1]); + abort(); + } + for (int i = 0; i < rank - 1; ++i) { // We skip the implicit minor stride. + tma_strides[i] = + static_cast(strides[rank - i - 2] * elem_bytewidth); + } + cuuint32_t tma_window_shape[5] = {1, 1, 1, 1, 1}; + for (int64_t i = 0; i < rank; ++i) { + tma_window_shape[i] = static_cast(window_shape[rank - i - 1]); + } + cuuint32_t element_strides[5] = {1, 1, 1, 1, 1}; + CUtensorMapSwizzle swizzle; + if (swizzle_bytes == 0) { + swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; + } else if (swizzle_bytes == 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } else if (swizzle_bytes == 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (swizzle_bytes == 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else { + fprintf(stderr, "Unsupported swizzle: %ld\n", swizzle_bytes); + abort(); + } + CUresult result = cuTensorMapEncodeTiled( + tma_desc, data_type, rank, base_addr, tma_sizes, tma_strides, + tma_window_shape, element_strides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, + CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + if (result != CUDA_SUCCESS) { + const char *ptr = nullptr; + cuGetErrorString(result, &ptr); + fprintf(stderr, "cuTensorMapEncodeTiled failed: %s\n", ptr); + abort(); + } +} + +void mosaic_gpu_memcpy_async_h2d(CUdeviceptr dst, void *src, uint64_t bytes, + CUstream stream) { + CUresult result = cuMemcpyHtoDAsync(dst, src, bytes, stream); + if (result != CUDA_SUCCESS) { + const char *ptr = nullptr; + cuGetErrorString(result, &ptr); + fprintf(stderr, "cuMemcpyAsync failed: %s\n", ptr); + abort(); + } +} + +void* mosaic_gpu_module_load(void *data) { + CUmodule module = nullptr; + if (auto result = cuModuleLoadData(&module, data); result != CUDA_SUCCESS) { + const char *ptr = nullptr; + cuGetErrorString(result, &ptr); + fprintf(stderr, "cuModuleLoadData failed: %s\n", ptr); + abort(); + } + return module; +} + +void *mosaic_gpu_get_function(CUmodule module, const char *name, + int32_t smem_bytes) { + CUfunction function = nullptr; + CUresult result = 
cuModuleGetFunction(&function, module, name); + if (result != CUDA_SUCCESS) { + const char *ptr = nullptr; + cuGetErrorString(result, &ptr); + fprintf(stderr, "cuModuleGetFunction failed: %s\n", ptr); + abort(); + } + if (smem_bytes) { + result = cuFuncSetAttribute( + function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_bytes); + if (result != CUDA_SUCCESS) { + const char *ptr = nullptr; + cuGetErrorString(result, &ptr); + fprintf(stderr, "cuFuncSetAttribute failed: %s\n", ptr); + abort(); + } + } + return function; +} + +void mosaic_gpu_launch_kernel(CUfunction function, int64_t grid_x, + int64_t grid_y, int64_t grid_z, int64_t block_x, + int64_t block_y, int64_t block_z, + int32_t smem_bytes, CUstream stream, + void **params) { + CUresult result = + cuLaunchKernel(function, grid_x, grid_y, grid_z, block_x, block_y, + block_z, smem_bytes, stream, params, nullptr); + if (result != CUDA_SUCCESS) { + const char *ptr = nullptr; + cuGetErrorString(result, &ptr); + fprintf(stderr, "cuLaunchKernel failed: %s\n", ptr); + abort(); + } +} + +} diff --git a/jaxlib/mosaic/python/BUILD b/jaxlib/mosaic/python/BUILD index 48268bfcf30a..639e61a89062 100644 --- a/jaxlib/mosaic/python/BUILD +++ b/jaxlib/mosaic/python/BUILD @@ -14,8 +14,8 @@ # Mosaic Python bindings -load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") load("@rules_python//python:defs.bzl", "py_library") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") gentbl_filegroup( name = "tpu_python_gen_raw", diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index ccb66bbc8f8e..13e63208632c 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -100,7 +100,7 @@ pybind_extension( "@local_config_rocm//rocm:hipblas", "@local_config_rocm//rocm:rocm_headers", "@nanobind", - "@tsl//tsl/python/lib/core:numpy", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -141,7 +141,7 @@ pybind_extension( "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", "@nanobind", - "@tsl//tsl/python/lib/core:numpy", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -189,7 +189,7 @@ pybind_extension( "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", "@nanobind", - "@tsl//tsl/python/lib/core:numpy", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -201,9 +201,13 @@ cc_library( ":hip_gpu_kernel_helpers", ":hip_lu_pivot_kernels_impl", ":hip_vendor", - "//jaxlib:kernel_helpers", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:rocm_headers", - "@xla//xla/service:custom_call_status", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", ], ) @@ -214,12 +218,49 @@ rocm_library( deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", + "@local_config_rocm//rocm:rocm_headers", + "@xla//xla/ffi/api:ffi", + ], +) + +cc_library( + name = "cholesky_update_kernel", + srcs = [ + "//jaxlib/gpu:cholesky_update_kernel.cc", + ], + hdrs = ["//jaxlib/gpu:cholesky_update_kernel.h"], + features = ["-use_header_modules"], + deps = [ + ":cholesky_update_kernel_impl", + ":hip_gpu_kernel_helpers", + ":hip_vendor", + ":hipsolver_kernels", "//jaxlib:kernel_helpers", + "@xla//xla/service:custom_call_status", + "@com_google_absl//absl/status", "@local_config_rocm//rocm:rocm_headers", + ], +) + +rocm_library( + name = "cholesky_update_kernel_impl", + srcs = [ + "//jaxlib/gpu:cholesky_update_kernel.cu.cc", + ], + hdrs = [ + 
"//jaxlib/gpu:cholesky_update_kernel.h", + ], + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_vendor", + ":hipsolver_kernels", + "//jaxlib:kernel_helpers", "@xla//xla/service:custom_call_status", + "@local_config_rocm//rocm:rocm_headers", ], ) + pybind_extension( name = "_linalg", srcs = ["//jaxlib/gpu:linalg.cc"], @@ -230,11 +271,13 @@ pybind_extension( features = ["-use_header_modules"], module_name = "_linalg", deps = [ + ":cholesky_update_kernel", ":hip_gpu_kernel_helpers", ":hip_lu_pivot_kernels", ":hip_lu_pivot_kernels_impl", ":hip_vendor", "//jaxlib:kernel_nanobind_helpers", + "@xla//xla/tsl/python/lib/core:numpy", "@local_config_rocm//rocm:rocm_headers", "@nanobind", ], @@ -295,8 +338,6 @@ cc_library( ":hip_vendor", ":triton_utils", "//jaxlib/gpu:triton_cc_proto", - "@xla//xla/service:custom_call_status", - "@xla//xla/stream_executor/gpu:asm_compiler", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -306,7 +347,9 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:env", + "@xla//xla/service:custom_call_status", + "@xla//xla/stream_executor/gpu:asm_compiler", + "@xla//xla/tsl/util:env_var", ], ) @@ -347,6 +390,7 @@ pybind_extension( "@nanobind", ], ) + py_library( name = "rocm_gpu_support", deps = [ @@ -358,3 +402,12 @@ py_library( ":_triton", ], ) + +py_library( + name = "gpu_only_test_deps", + # `if_rocm_is_configured` will default to `[]`. + deps = if_rocm_is_configured([ + ":rocm_gpu_support", + "//jaxlib:rocm_plugin_extension", + ]), +) diff --git a/jaxlib/rocm_plugin_extension.cc b/jaxlib/rocm_plugin_extension.cc new file mode 100644 index 000000000000..9b0743d27cd9 --- /dev/null +++ b/jaxlib/rocm_plugin_extension.cc @@ -0,0 +1,152 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include +#include + +#include "nanobind/nanobind.h" +#include "absl/status/status.h" +#include "rocm/include/hip/hip_runtime.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/ffi/api/c_api.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/py_client_gpu.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" + +namespace nb = nanobind; + +namespace xla { +namespace { +Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, + nb::capsule fn, int api_version, + XLA_FFI_Handler_Traits traits) { + if (c_api->extension_start == nullptr) { + return Unimplemented("The plugin does not have extension."); + } + const PJRT_Extension_Base* next = + reinterpret_cast(c_api->extension_start); + while (next != nullptr && + next->type != + PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { + next = next->next; + } + if (next == nullptr) { + return Unimplemented("The plugin does not have a custom call extension."); + } + + if (traits != 0) { + return Unimplemented("The plugin does not support custom call traits."); + } + + PJRT_Gpu_Register_Custom_Call_Args args; + args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; + args.function_name = fn_name.c_str(); + args.function_name_size = nb::len(fn_name); +#if PJRT_API_GPU_EXTENSION_VERSION >= 1 + args.api_version = api_version; +#endif + args.custom_call_function = static_cast(fn.data()); + RETURN_STATUS_IF_PJRT_ERROR( + reinterpret_cast(next)->custom_call(&args), + c_api); + return OkStatus(); +} + +nb::dict Registrations() { + nb::dict dict; + dict["xla_python_gpu_callback"] = + jax::EncapsulateFunction(xla::XlaPythonGpuCallback); + return dict; +} + +std::string ToString(hipError_t result) { +#define OSTREAM_ROCM_ERROR(__name) \ + case hipError##__name: \ + return "HIP_ERROR_" #__name; + + switch (result) { + OSTREAM_ROCM_ERROR(InvalidValue) + OSTREAM_ROCM_ERROR(OutOfMemory) + OSTREAM_ROCM_ERROR(NotInitialized) + OSTREAM_ROCM_ERROR(Deinitialized) + OSTREAM_ROCM_ERROR(NoDevice) + OSTREAM_ROCM_ERROR(InvalidDevice) + OSTREAM_ROCM_ERROR(InvalidImage) + OSTREAM_ROCM_ERROR(InvalidContext) + OSTREAM_ROCM_ERROR(InvalidHandle) + OSTREAM_ROCM_ERROR(NotFound) + OSTREAM_ROCM_ERROR(NotReady) + OSTREAM_ROCM_ERROR(NoBinaryForGpu) + + // Encountered an uncorrectable ECC error during execution. + OSTREAM_ROCM_ERROR(ECCNotCorrectable) + + // Load/store on an invalid address. Must reboot all context. + case 700: + return "ROCM_ERROR_ILLEGAL_ADDRESS"; + // Passed too many / wrong arguments, too many threads for register count. + case 701: + return "ROCM_ERROR_LAUNCH_OUT_OF_RESOURCES"; + + OSTREAM_ROCM_ERROR(ContextAlreadyInUse) + OSTREAM_ROCM_ERROR(PeerAccessUnsupported) + OSTREAM_ROCM_ERROR(Unknown) // Unknown internal error to ROCM. 
+ default: + return absl::StrCat("hipError_t(", static_cast(result), ")"); + } +} +} // namespace + +NB_MODULE(rocm_plugin_extension, m) { + tsl::ImportNumpy(); + m.def( + "register_custom_call_target", + [](nb::capsule c_api, nb::str fn_name, nb::capsule fn, + nb::str xla_platform_name, int api_version, + XLA_FFI_Handler_Traits traits) { + xla::ThrowIfError(RegisterCustomCallTarget( + static_cast(c_api.data()), fn_name, std::move(fn), + api_version, traits)); + }, + nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"), + nb::arg("xla_platform_name"), nb::arg("api_version") = 0, + nb::arg("traits") = 0); + m.def("registrations", &Registrations); + m.def( + "get_device_ordinal", + [](std::intptr_t data_value) { + if (data_value == 0) { + return 0; + } + int device_ordinal; + void* data_ptr = reinterpret_cast(data_value); + hipError_t result = + hipPointerGetAttribute(static_cast(&device_ordinal), + HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + reinterpret_cast(data_ptr)); + if (result != hipSuccess) { + LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << data_ptr + << ". Error: " << ToString(result); + } + return device_ordinal; + }, + nb::arg("data_value")); +} +} // namespace xla diff --git a/jaxlib/setup.py b/jaxlib/setup.py index a06a85eb28b0..adc3ba452111 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -59,41 +59,19 @@ def has_ext_modules(self): author='JAX team', author_email='jax-dev@google.com', packages=['jaxlib', 'jaxlib.xla_extension'], - python_requires='>=3.9', + python_requires='>=3.10', install_requires=[ 'scipy>=1.9', "scipy>=1.11.1; python_version>='3.12'", - 'numpy>=1.22', + 'numpy>=1.24', 'ml_dtypes>=0.2.0', ], - extras_require={ - 'cuda11_pip': [ - "nvidia-cublas-cu11>=11.11", - "nvidia-cuda-cupti-cu11>=11.8", - "nvidia-cuda-nvcc-cu11>=11.8", - "nvidia-cuda-runtime-cu11>=11.8", - "nvidia-cudnn-cu11>=8.8", - "nvidia-cufft-cu11>=10.9", - "nvidia-cusolver-cu11>=11.4", - "nvidia-cusparse-cu11>=11.7", - ], - 'cuda12_pip': [ - "nvidia-cublas-cu12", - "nvidia-cuda-cupti-cu12", - "nvidia-cuda-nvcc-cu12", - "nvidia-cuda-runtime-cu12", - "nvidia-cudnn-cu12>=8.9", - "nvidia-cufft-cu12", - "nvidia-cusolver-cu12", - "nvidia-cusparse-cu12", - ], - }, url='https://github.com/google/jax', license='Apache-2.0', classifiers=[ - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ], package_data={ 'jaxlib': [ @@ -104,21 +82,27 @@ def has_ext_modules(self): 'cuda/*', 'cuda/nvvm/libdevice/libdevice*', 'mosaic/*.py', + 'mosaic/gpu/*.so', 'mosaic/python/*.py', 'mosaic/python/*.so', 'mlir/*.py', + 'mlir/*.pyi', 'mlir/dialects/*.py', + 'mlir/dialects/gpu/*.py', + 'mlir/dialects/gpu/passes/*.py', 'mlir/extras/*.py', 'mlir/_mlir_libs/*.dll', 'mlir/_mlir_libs/*.dylib', 'mlir/_mlir_libs/*.so', 'mlir/_mlir_libs/*.pyd', 'mlir/_mlir_libs/*.py', + 'mlir/_mlir_libs/*.pyi', 'rocm/*', 'triton/*.py', 'triton/*.pyi', 'triton/*.pyd', 'triton/*.so', + 'include/xla/ffi/api/*.h', ], 'jaxlib.xla_extension': ['*.pyi'], }, diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index ed7d129df86b..089cba21dc7b 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -32,6 +32,9 @@ py_binary( "//jaxlib:setup.py", "@xla//xla/python:xla_client.py", "@xla//xla/python:xla_extension", + "@xla//xla/ffi/api:c_api.h", + "@xla//xla/ffi/api:api.h", + "@xla//xla/ffi/api:ffi.h", ] + if_windows([ "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", ]) + if_cuda([ @@ -42,7 +45,10 @@ 
py_binary( ]), deps = [ "//jax/tools:build_utils", - "@bazel_tools//tools/python/runfiles" + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_wheel//:pkg", + "@pypi_setuptools//:pkg", ], ) @@ -55,12 +61,31 @@ py_test( ], ) +cc_binary( + name = "pjrt_c_api_gpu_plugin.so", + linkopts = [ + "-Wl,--version-script,$(location @xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds)", + "-Wl,--no-undefined", + ], + linkshared = True, + deps = [ + "//jaxlib/mosaic/gpu:custom_call", + "@xla//xla/pjrt/c:pjrt_c_api_gpu", + "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds", + "@xla//xla/service:gpu_plugin", + ] + if_cuda([ + "@xla//xla/stream_executor:cuda_platform", + ]) + if_rocm([ + "@xla//xla/stream_executor:rocm_platform", + ]), +) + py_binary( name = "build_gpu_plugin_wheel", srcs = ["build_gpu_plugin_wheel.py"], data = [ "LICENSE.txt", - "@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so", + ":pjrt_c_api_gpu_plugin.so", ] + if_cuda([ "//jaxlib:version", "//jaxlib/cuda:cuda_gpu_support", @@ -68,28 +93,47 @@ py_binary( "//jax_plugins/cuda:setup.py", "//jax_plugins/cuda:__init__.py", "@local_config_cuda//cuda:cuda-nvvm", + ]) + if_rocm([ + "//jaxlib:version", + "//jaxlib/rocm:rocm_gpu_support", + "//jax_plugins/rocm:pyproject.toml", + "//jax_plugins/rocm:setup.py", + "//jax_plugins/rocm:__init__.py", ]), deps = [ "//jax/tools:build_utils", - "@bazel_tools//tools/python/runfiles" + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_wheel//:pkg", + "@pypi_setuptools//:pkg", ], ) py_binary( - name = "build_cuda_kernels_wheel", - srcs = ["build_cuda_kernels_wheel.py"], + name = "build_gpu_kernels_wheel", + srcs = ["build_gpu_kernels_wheel.py"], data = [ "LICENSE.txt", ] + if_cuda([ + "//jaxlib/mosaic/gpu:mosaic_gpu", "//jaxlib:cuda_plugin_extension", "//jaxlib:version", "//jaxlib/cuda:cuda_gpu_support", "//jax_plugins/cuda:plugin_pyproject.toml", "//jax_plugins/cuda:plugin_setup.py", "@local_config_cuda//cuda:cuda-nvvm", + ]) + if_rocm([ + "//jaxlib:rocm_plugin_extension", + "//jaxlib:version", + "//jaxlib/rocm:rocm_gpu_support", + "//jax_plugins/rocm:plugin_pyproject.toml", + "//jax_plugins/rocm:plugin_setup.py", ]), deps = [ "//jax/tools:build_utils", - "@bazel_tools//tools/python/runfiles" + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_wheel//:pkg", + "@pypi_setuptools//:pkg", ], ) diff --git a/jaxlib/tools/LICENSE.txt b/jaxlib/tools/LICENSE.txt index 71123c553fab..6c7416993e2d 100644 --- a/jaxlib/tools/LICENSE.txt +++ b/jaxlib/tools/LICENSE.txt @@ -4331,34 +4331,6 @@ Copyright 2019 The TensorFlow Authors. All rights reserved. See the License for the specific language governing permissions and limitations under the License. --------------------------------------------------------------------------------- -License for the FFT components of ducc0: -Copyright (C) 2010-2022 Max-Planck-Society -All rights reserved. - -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright notice, this - list of conditions and the following disclaimer in the documentation and/or - other materials provided with the distribution. 
-* Neither the name of the copyright holder nor the names of its contributors may - be used to endorse or promote products derived from this software without - specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -------------------------------------------------------------------------------- License for pybind11: Copyright (c) 2016 Wenzel Jakob , All rights reserved. diff --git a/jaxlib/tools/build_cuda_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py similarity index 60% rename from jaxlib/tools/build_cuda_kernels_wheel.py rename to jaxlib/tools/build_gpu_kernels_wheel.py index 34280ff1ffbf..28d2806a7da9 100644 --- a/jaxlib/tools/build_cuda_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -43,16 +43,24 @@ "--cpu", default=None, required=True, help="Target CPU architecture. Required." ) parser.add_argument( - "--cuda_version", + "--platform_version", default=None, required=True, - help="Target CUDA version. Required.", + help="Target CUDA/ROCM version. Required.", ) parser.add_argument( "--editable", action="store_true", - help="Create an 'editable' jax cuda plugin build instead of a wheel.", + help="Create an 'editable' jax cuda/rocm plugin build instead of a wheel.", ) +parser.add_argument( + "--enable-cuda", + default=False, + help="Should we build with CUDA enabled? 
Requires CUDA and CuDNN.") +parser.add_argument( + "--enable-rocm", + default=False, + help="Should we build with ROCM enabled?") args = parser.parse_args() r = runfiles.Create() @@ -70,7 +78,7 @@ def write_setup_cfg(sources_path, cpu): """) -def prepare_wheel( +def prepare_wheel_cuda( sources_path: pathlib.Path, *, cpu, cuda_version ): """Assembles a source tree for the cuda kernel wheel in `sources_path`.""" @@ -90,10 +98,6 @@ def prepare_wheel( write_setup_cfg(sources_path, cpu) plugin_dir = sources_path / f"jax_cuda{cuda_version}_plugin" - copy_runfiles( - dst_dir=plugin_dir / "nvvm" / "libdevice", - src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"], - ) copy_runfiles( dst_dir=plugin_dir, src_files=[ @@ -106,19 +110,64 @@ def prepare_wheel( f"__main__/jaxlib/cuda/_triton.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", f"__main__/jaxlib/cuda_plugin_extension.{pyext}", + f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", + "__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so", + "__main__/jaxlib/version.py", + ], + ) + +def prepare_wheel_rocm( + sources_path: pathlib.Path, *, cpu, rocm_version +): + """Assembles a source tree for the rocm kernel wheel in `sources_path`.""" + copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) + + copy_runfiles( + "__main__/jax_plugins/rocm/plugin_pyproject.toml", + dst_dir=sources_path, + dst_filename="pyproject.toml", + ) + copy_runfiles( + "__main__/jax_plugins/rocm/plugin_setup.py", + dst_dir=sources_path, + dst_filename="setup.py", + ) + build_utils.update_setup_with_rocm_version(sources_path, rocm_version) + write_setup_cfg(sources_path, cpu) + + plugin_dir = sources_path / f"jax_rocm{rocm_version}_plugin" + copy_runfiles( + dst_dir=plugin_dir, + src_files=[ + f"__main__/jaxlib/rocm/_solver.{pyext}", + f"__main__/jaxlib/rocm/_blas.{pyext}", + f"__main__/jaxlib/rocm/_linalg.{pyext}", + f"__main__/jaxlib/rocm/_prng.{pyext}", + f"__main__/jaxlib/rocm/_sparse.{pyext}", + f"__main__/jaxlib/rocm/_triton.{pyext}", + f"__main__/jaxlib/rocm_plugin_extension.{pyext}", "__main__/jaxlib/version.py", ], ) # Build wheel for cuda kernels -tmpdir = tempfile.TemporaryDirectory(prefix="jax_cuda_plugin") +if args.enable_rocm: + tmpdir = tempfile.TemporaryDirectory(prefix="jax_rocm_plugin") +else: + tmpdir = tempfile.TemporaryDirectory(prefix="jax_cuda_plugin") sources_path = tmpdir.name try: os.makedirs(args.output_path, exist_ok=True) - prepare_wheel( - pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.cuda_version - ) - package_name = f"jax cuda{args.cuda_version} plugin" + if args.enable_cuda: + prepare_wheel_cuda( + pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.platform_version + ) + package_name = f"jax cuda{args.platform_version} plugin" + elif args.enable_rocm: + prepare_wheel_rocm( + pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version + ) + package_name = f"jax rocm{args.platform_version} plugin" if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 7e178e3ad2a4..73cb8a9e020d 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Script that builds a jax cuda plugin wheel, intended to be run via bazel run -# as part of the jax cuda plugin build process. 
+# Script that builds a jax cuda/rocm plugin wheel, intended to be run via bazel run +# as part of the jax cuda/rocm plugin build process. # Most users should not run this script directly; use build.py instead. @@ -49,16 +49,24 @@ "--cpu", default=None, required=True, help="Target CPU architecture. Required." ) parser.add_argument( - "--cuda_version", + "--platform_version", default=None, required=True, - help="Target CUDA version. Required.", + help="Target CUDA/ROCM version. Required.", ) parser.add_argument( "--editable", action="store_true", - help="Create an 'editable' jax cuda plugin build instead of a wheel.", + help="Create an 'editable' jax cuda/rocm plugin build instead of a wheel.", ) +parser.add_argument( + "--enable-cuda", + default=False, + help="Should we build with CUDA enabled? Requires CUDA and CuDNN.") +parser.add_argument( + "--enable-rocm", + default=False, + help="Should we build with ROCM enabled?") args = parser.parse_args() r = runfiles.Create() @@ -100,24 +108,62 @@ def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version): ], ) copy_runfiles( - "xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so", + "__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so", dst_dir=plugin_dir, dst_filename="xla_cuda_plugin.so", ) +def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): + """Assembles a source tree for the ROCm wheel in `sources_path`.""" + copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) + + plugin_dir = sources_path / "jax_plugins" / f"xla_rocm{rocm_version}" + copy_runfiles( + dst_dir=sources_path, + src_files=[ + "__main__/jax_plugins/rocm/pyproject.toml", + "__main__/jax_plugins/rocm/setup.py", + ], + ) + build_utils.update_setup_with_rocm_version(sources_path, rocm_version) + write_setup_cfg(sources_path, cpu) + copy_runfiles( + dst_dir=plugin_dir, + src_files=[ + "__main__/jax_plugins/rocm/__init__.py", + "__main__/jaxlib/version.py", + ], + ) + copy_runfiles( + "__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so", + dst_dir=plugin_dir, + dst_filename="xla_rocm_plugin.so", + ) + + tmpdir = None sources_path = args.sources_path if sources_path is None: - tmpdir = tempfile.TemporaryDirectory(prefix="jaxcudapjrt") + tmpdir = tempfile.TemporaryDirectory(prefix="jaxgpupjrt") sources_path = tmpdir.name try: os.makedirs(args.output_path, exist_ok=True) - prepare_cuda_plugin_wheel( - pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.cuda_version - ) - package_name = "jax cuda plugin" + + if args.enable_cuda: + prepare_cuda_plugin_wheel( + pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.platform_version + ) + package_name = "jax cuda plugin" + elif args.enable_rocm: + prepare_rocm_plugin_wheel( + pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version + ) + package_name = "jax rocm plugin" + else: + raise ValueError("Unsupported backend. Choose either 'cuda' or 'rocm'.") + if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 4ef295c39411..62864f7ad30d 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -46,7 +46,7 @@ "--jaxlib_git_hash", default="", required=True, - help="Git hash. Empty if unknown. Optional.", + help="Git hash. Empty if unknown. Required.", ) parser.add_argument( "--cpu", default=None, required=True, help="Target CPU architecture. Required." 
@@ -57,11 +57,11 @@ help="Create an 'editable' jaxlib build instead of a wheel.", ) parser.add_argument( - "--include_gpu_plugin_extension", - # args.include_gpu_plugin_extension is True when - # --include_gpu_plugin_extension is in the command + "--skip_gpu_kernels", + # args.skip_gpu_kernels is True when + # --skip_gpu_kernels is in the command action="store_true", - help="Whether to include gpu plugin extension.", + help="Whether to skip gpu kernels in jaxlib.", ) args = parser.parse_args() @@ -100,6 +100,8 @@ def patch_copy_mlir_import(src_file, dst_dir): _XLA_EXTENSION_STUBS = [ "__init__.pyi", + "ifrt_programs.pyi", + "ifrt_proxy.pyi", "jax_jit.pyi", "ops.pyi", "outfeed_receiver.pyi", @@ -167,7 +169,7 @@ def write_setup_cfg(sources_path, cpu): ) -def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extension): +def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): """Assembles a source tree for the wheel in `sources_path`.""" copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) @@ -193,7 +195,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi f"__main__/jaxlib/utils.{pyext}", "__main__/jaxlib/lapack.py", "__main__/jaxlib/hlo_helpers.py", - "__main__/jaxlib/ducc_fft.py", "__main__/jaxlib/gpu_prng.py", "__main__/jaxlib/gpu_linalg.py", "__main__/jaxlib/gpu_rnn.py", @@ -203,7 +204,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi "__main__/jaxlib/gpu_sparse.py", "__main__/jaxlib/version.py", "__main__/jaxlib/xla_client.py", - f"__main__/jaxlib/xla_extension.{pyext}", + f"xla/xla/python/xla_extension.{pyext}", ], ) # This file is required by PEP-561. It marks jaxlib as package containing @@ -216,15 +217,10 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi dst_dir=jaxlib_dir / "cpu", src_files=[ f"__main__/jaxlib/cpu/_lapack.{pyext}", - f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", ], ) - if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not include_gpu_plugin_extension: - copy_runfiles( - dst_dir=jaxlib_dir / "cuda" / "nvvm" / "libdevice", - src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"], - ) + if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not skip_gpu_kernels: copy_runfiles( dst_dir=jaxlib_dir / "cuda", src_files=[ @@ -238,7 +234,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi f"__main__/jaxlib/cuda/_versions.{pyext}", ], ) - if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"): + if exists(f"__main__/jaxlib/rocm/_solver.{pyext}") and not skip_gpu_kernels: copy_runfiles( dst_dir=jaxlib_dir / "rocm", src_files=[ @@ -268,7 +264,9 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi dst_dir=jaxlib_dir / "mlir", src_files=[ "__main__/jaxlib/mlir/ir.py", + "__main__/jaxlib/mlir/ir.pyi", "__main__/jaxlib/mlir/passmanager.py", + "__main__/jaxlib/mlir/passmanager.pyi", ], ) copy_runfiles( @@ -289,6 +287,14 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi "__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py", "__main__/jaxlib/mlir/dialects/_vector_enum_gen.py", "__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", + "__main__/jaxlib/mlir/dialects/_gpu_enum_gen.py", + "__main__/jaxlib/mlir/dialects/_gpu_ops_gen.py", + "__main__/jaxlib/mlir/dialects/_nvgpu_enum_gen.py", + "__main__/jaxlib/mlir/dialects/_nvgpu_ops_gen.py", + "__main__/jaxlib/mlir/dialects/_nvvm_enum_gen.py", + 
"__main__/jaxlib/mlir/dialects/_nvvm_ops_gen.py", + "__main__/jaxlib/mlir/dialects/_llvm_enum_gen.py", + "__main__/jaxlib/mlir/dialects/_llvm_ops_gen.py", "__main__/jaxlib/mlir/dialects/arith.py", "__main__/jaxlib/mlir/dialects/builtin.py", "__main__/jaxlib/mlir/dialects/chlo.py", @@ -300,6 +306,9 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi "__main__/jaxlib/mlir/dialects/sparse_tensor.py", "__main__/jaxlib/mlir/dialects/stablehlo.py", "__main__/jaxlib/mlir/dialects/vector.py", + "__main__/jaxlib/mlir/dialects/nvgpu.py", + "__main__/jaxlib/mlir/dialects/nvvm.py", + "__main__/jaxlib/mlir/dialects/llvm.py", ], ) copy_runfiles( @@ -308,6 +317,19 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi "__main__/jaxlib/mlir/extras/meta.py", ], ) + copy_runfiles( + dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu", + src_files=[ + "__main__/jaxlib/mlir/dialects/gpu/__init__.py", + ], + ) + copy_runfiles( + dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu" / "passes", + src_files=[ + "__main__/jaxlib/mlir/dialects/gpu/passes/__init__.py", + ], + ) + if build_utils.is_windows(): capi_so = "__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll" @@ -329,6 +351,10 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsGPU.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsLLVM.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsNVGPU.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_mlirGPUPasses.{pyext}", ] + ( [] @@ -355,6 +381,14 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi "__main__/jaxlib/triton/_triton_ops_gen.py", dst_dir=triton_dir ) + copy_runfiles( + dst_dir=jaxlib_dir / "include" / "xla" / "ffi" / "api", + src_files=[ + "xla/xla/ffi/api/c_api.h", + "xla/xla/ffi/api/api.h", + "xla/xla/ffi/api/ffi.h", + ], + ) tmpdir = None sources_path = args.sources_path @@ -367,7 +401,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi prepare_wheel( pathlib.Path(sources_path), cpu=args.cpu, - include_gpu_plugin_extension=args.include_gpu_plugin_extension, + skip_gpu_kernels=args.skip_gpu_kernels, ) package_name = "jaxlib" if args.editable: diff --git a/jaxlib/triton/dialect.py b/jaxlib/triton/dialect.py index 315f352651c7..1bbb565b69b2 100644 --- a/jaxlib/triton/dialect.py +++ b/jaxlib/triton/dialect.py @@ -55,12 +55,13 @@ def __init__( self, operands: Sequence[ir.Value], axis: int, + reverse: bool = False, *, loc: ir.Location | None = None, ip: ir.InsertionPoint | None = None, ): return_types = [op.type for op in operands] - super().__init__(return_types, operands, axis, loc=loc, ip=ip) + super().__init__(return_types, operands, axis, reverse, loc=loc, ip=ip) # TODO(slebedev): Consider overriding instead. diff --git a/platform_mappings b/platform_mappings new file mode 100644 index 000000000000..56dc745d759f --- /dev/null +++ b/platform_mappings @@ -0,0 +1,11 @@ +platforms: +# Maps "--platforms=//tools/toolchains/cross_compile/config:darwin_x86_64" +# to "--cpu=darwin". + @xla//tools/toolchains/cross_compile/config:darwin_x86_64 + --cpu=darwin + +flags: + # Maps "--cpu=darwin" to + # "--platforms=//tools/toolchains/cross_compile/config:darwin_x86_64". 
+ --cpu=darwin + @xla//tools/toolchains/cross_compile/config:darwin_x86_64 diff --git a/pyproject.toml b/pyproject.toml index 0a5873d89e16..21b32fc92306 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,19 +6,19 @@ build-backend = "setuptools.build_meta" show_error_codes = true disable_error_code = "attr-defined, name-defined, annotation-unchecked" no_implicit_optional = true +warn_unused_ignores = true [[tool.mypy.overrides]] module = [ "absl.*", "colorama.*", - "importlib_metadata.*", + "filelock.*", "IPython.*", "numpy.*", "opt_einsum.*", "scipy.*", "libtpu.*", "jaxlib.mlir.*", - "iree.*", "rich.*", "optax.*", "flatbuffers.*", @@ -40,17 +40,10 @@ module = [ "jax.experimental.jax2tf.tests.flax_models", "jax.experimental.jax2tf.tests.back_compat_testdata", "setuptools.*", + "jax_cuda12_plugin.*", ] ignore_missing_imports = true -[[tool.mypy.overrides]] -module = [ - "jax.interpreters.autospmd", - "jax.lax.lax_parallel", - "jax._src.internal_test_util.test_harnesses", -] -ignore_errors = true - [tool.pytest.ini_options] markers = [ "multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators", @@ -58,32 +51,18 @@ markers = [ ] filterwarnings = [ "error", - "ignore:The hookimpl.*:DeprecationWarning", - "ignore:No GPU/TPU found, falling back to CPU.:UserWarning", - "ignore:xmap is an experimental feature and probably has bugs!", - "ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning", - "ignore:can't resolve package from __spec__ or __package__:ImportWarning", - "ignore:Using or importing the ABCs.*:DeprecationWarning", - "ignore:numpy.ufunc size changed", - "ignore:.*experimental feature", - "ignore:The distutils.* is deprecated.*:DeprecationWarning", - "default:Error reading persistent compilation cache entry for 'jit__lambda_'", - "default:Error writing persistent compilation cache entry for 'jit__lambda_'", - "ignore:backend and device argument on jit is deprecated.*:DeprecationWarning", - # TODO(skyewm): remove when jaxlib >= 0.4.12 is released (needs - # https://github.com/openxla/xla/commit/fb9dc3db0999bf14c78d95cb7c3aa6815221ddc7) - "ignore:ml_dtypes.float8_e4m3b11 is deprecated.", - "ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning", - "ignore:np.find_common_type is deprecated.*:DeprecationWarning", - "ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning", + "default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'", + "default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'", + "default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning", + "default:jax.xla_computation is deprecated. Please use the AOT APIs.*:DeprecationWarning", # TODO(jakevdp): remove when array_api_tests stabilize # start array_api_tests-related warnings - "ignore:The numpy.array_api submodule is still experimental.*:UserWarning", - "ignore:case not machine-readable.*:UserWarning", - "ignore:not machine-readable.*:UserWarning", - "ignore:Special cases found for .* but none were parsed.*:UserWarning", + "default:.*not machine-readable.*:UserWarning", + "default:Special cases found for .* but none were parsed.*:UserWarning", + "default:.*is not JSON-serializable. Using the repr instead.", # end array_api_tests-related warnings - "ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning", + # TODO(slebedev): Remove once we migrate all pl.BlockSpec usages in JAX. 
+ "default:BlockSpec now expects .*:DeprecationWarning", ] doctest_optionflags = [ "NUMBER", @@ -122,6 +101,11 @@ exclude = [ "build", "__pycache__", ] +line-length = 88 +indent-width = 2 +target-version = "py310" + +[tool.ruff.lint] ignore = [ # Unnecessary collection call "C408", @@ -133,9 +117,9 @@ ignore = [ "F841", # Raise with from clause inside except block "B904", + # Zip without explicit strict parameter + "B905", ] -line-length = 88 -indent-width = 2 select = [ "B9", "C", @@ -147,12 +131,11 @@ select = [ "E227", "E228", ] -target-version = "py39" -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] max-complexity = 18 -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # F811: Redefinition of unused name. "docs/autodidax.py" = ["F811"] # Note: we don't use jax/*.py because this matches contents of jax/_src diff --git a/setup.py b/setup.py index 1c35aa0d859c..cc2c75ab7ff4 100644 --- a/setup.py +++ b/setup.py @@ -12,23 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from distutils import spawn import importlib import os -import subprocess -import sys from setuptools import setup, find_packages project_name = 'jax' -_current_jaxlib_version = '0.4.25' +_current_jaxlib_version = '0.4.30' # The following should be updated with each new jaxlib release. -_latest_jaxlib_version_on_pypi = '0.4.25' -_available_cuda11_cudnn_versions = ['86'] -_default_cuda11_cudnn_version = '86' -_default_cuda12_cudnn_version = '89' -_libtpu_version = '0.1.dev20240224' +_latest_jaxlib_version_on_pypi = '0.4.30' +_libtpu_version = '0.1.dev20240617' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( @@ -39,27 +33,13 @@ def load_version_module(pkg_path): _version_module = load_version_module(project_name) __version__ = _version_module._get_version_for_build() +_jax_version = _version_module._version # JAX version, with no .dev suffix. _cmdclass = _version_module._get_cmdclass(project_name) _minimum_jaxlib_version = _version_module._minimum_jaxlib_version with open('README.md', encoding='utf-8') as f: _long_description = f.read() -if 'PROTOC' in os.environ and os.path.exists(os.environ['PROTOC']): - protoc = os.environ['PROTOC'] -else: - protoc = spawn.find_executable('protoc') - -def generate_proto(source): - if not protoc or not os.path.exists(source): - return - protoc_command = [protoc, '-I.', '--python_out=.', source] - if subprocess.call(protoc_command) != 0: - sys.exit(-1) - -generate_proto("jax/experimental/australis/executable.proto") -generate_proto("jax/experimental/australis/petri.proto") - setup( name=project_name, version=__version__, @@ -71,27 +51,23 @@ def generate_proto(source): author_email='jax-dev@google.com', packages=find_packages(exclude=["examples", "jax/src/internal_test_util"]), package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]}, - python_requires='>=3.9', + python_requires='>=3.10', install_requires=[ + f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}', 'ml_dtypes>=0.2.0', - 'numpy>=1.22', - "numpy>=1.23.2; python_version>='3.11'", + 'numpy>=1.24', "numpy>=1.26.0; python_version>='3.12'", 'opt_einsum', 'scipy>=1.9', "scipy>=1.11.1; python_version>='3.12'", - # Required by xla_bridge.discover_pjrt_plugins for forwards compat with - # Python versions < 3.10. Can be dropped when 3.10 is the minimum - # required Python version. - 'importlib_metadata>=4.6;python_version<"3.10"', ], extras_require={ # Minimum jaxlib version; used in testing. 
'minimum-jaxlib': [f'jaxlib=={_minimum_jaxlib_version}'], - # CPU-only jaxlib can be installed via: - # $ pip install jax[cpu] - 'cpu': [f'jaxlib=={_current_jaxlib_version}'], + # A CPU-only jax doesn't require any extras, but we keep this extra + # around for compatibility. + 'cpu': [], # Used only for CI builds that install JAX from github HEAD. 'ci': [f'jaxlib=={_latest_jaxlib_version_on_pypi}'], @@ -99,96 +75,41 @@ def generate_proto(source): # Cloud TPU VM jaxlib can be installed via: # $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 'tpu': [ - f'jaxlib=={_current_jaxlib_version}', + f'jaxlib>={_current_jaxlib_version},<={_jax_version}', f'libtpu-nightly=={_libtpu_version}', 'requests', # necessary for jax.distributed.initialize ], - # $ pip install jax[australis] - 'australis': ['protobuf>=3.13,<4'], - - # CUDA installations require adding the JAX CUDA releases URL, e.g., - # Cuda installation defaulting to a CUDA and Cudnn version defined above. - # $ pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - 'cuda': [f"jaxlib=={_current_jaxlib_version}+cuda11.cudnn{_default_cuda11_cudnn_version}"], - - 'cuda11_pip': [ - f"jaxlib=={_current_jaxlib_version}+cuda11.cudnn{_default_cuda11_cudnn_version}", - "nvidia-cublas-cu11>=11.11", - "nvidia-cuda-cupti-cu11>=11.8", - "nvidia-cuda-nvcc-cu11>=11.8", - "nvidia-cuda-runtime-cu11>=11.8", - "nvidia-cudnn-cu11>=8.8", - "nvidia-cufft-cu11>=10.9", - "nvidia-cusolver-cu11>=11.4", - "nvidia-cusparse-cu11>=11.7", - "nvidia-nccl-cu11>=2.18.3", + 'cuda': [ + f"jaxlib=={_current_jaxlib_version}", + f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", ], - 'cuda12_pip': [ - f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}", - "nvidia-cublas-cu12>=12.3.4.1", - "nvidia-cuda-cupti-cu12>=12.3.101", - "nvidia-cuda-nvcc-cu12>=12.3.107", - "nvidia-cuda-runtime-cu12>=12.3.101", - "nvidia-cudnn-cu12>=8.9.7.29", - "nvidia-cufft-cu12>=11.0.12.1", - "nvidia-cusolver-cu12>=11.5.4.101", - "nvidia-cusparse-cu12>=12.2.0.103", - "nvidia-nccl-cu12>=2.19.3", - # nvjitlink is not a direct dependency of JAX, but it is a transitive - # dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages - # do not have a version constraint on their dependencies, so the - # package doesn't get upgraded even though not doing that can cause - # problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196) - # Until NVIDIA add version constraints, add an version constraint - # here. - "nvidia-nvjitlink-cu12>=12.3.101", + 'cuda12': [ + f"jaxlib=={_current_jaxlib_version}", + f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", ], - 'cuda12': [ + # Deprecated alias for cuda12, kept to avoid breaking users who wrote + # cuda12_pip in their CI. + 'cuda12_pip': [ f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin=={_current_jaxlib_version}", - "nvidia-cublas-cu12>=12.3.4.1", - "nvidia-cuda-cupti-cu12>=12.3.101", - "nvidia-cuda-nvcc-cu12>=12.3.107", - "nvidia-cuda-runtime-cu12>=12.3.101", - "nvidia-cudnn-cu12>=8.9.7.29", - "nvidia-cufft-cu12>=11.0.12.1", - "nvidia-cusolver-cu12>=11.5.4.101", - "nvidia-cusparse-cu12>=12.2.0.103", - "nvidia-nccl-cu12>=2.19.3", - # nvjitlink is not a direct dependency of JAX, but it is a transitive - # dependency via, for example, cuSOLVER. 
NVIDIA's cuSOLVER packages - # do not have a version constraint on their dependencies, so the - # package doesn't get upgraded even though not doing that can cause - # problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196) - # Until NVIDIA add version constraints, add an version constraint - # here. - "nvidia-nvjitlink-cu12>=12.3.101", + f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", ], # Target that does not depend on the CUDA pip wheels, for those who want # to use a preinstalled CUDA. - 'cuda11_local': [ - f"jaxlib=={_current_jaxlib_version}+cuda11.cudnn{_default_cuda11_cudnn_version}", - ], 'cuda12_local': [ - f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}", + f"jaxlib=={_current_jaxlib_version}", + f"jax-cuda12-plugin=={_current_jaxlib_version}", ], - - # CUDA installations require adding jax releases URL; e.g. - # $ pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - # $ pip install jax[cuda11_cudnn86] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - **{f'cuda11_cudnn{cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda11.cudnn{cudnn_version}" - for cudnn_version in _available_cuda11_cudnn_versions} }, url='https://github.com/google/jax', license='Apache-2.0', classifiers=[ - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ], zip_safe=False, ) diff --git a/tests/BUILD b/tests/BUILD index 9c8ca93103b7..b5a9b916d134 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_python//python:defs.bzl", "py_test") load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", @@ -20,7 +21,6 @@ load( "py_deps", "pytype_test", ) -load("@rules_python//python:defs.bzl", "py_test") licenses(["notice"]) @@ -35,6 +35,7 @@ jax_test( name = "api_test", srcs = ["api_test.py"], shard_count = 10, + tags = ["test_cpu_thunks"], ) jax_test( @@ -54,6 +55,7 @@ py_test( deps = [ "//jax", "//jax:experimental_array_api", + "//jax:test_util", ] + py_deps("absl/testing"), ) @@ -73,6 +75,11 @@ jax_test( }, ) +jax_test( + name = "config_test", + srcs = ["config_test.py"], +) + jax_test( name = "core_test", srcs = ["core_test.py"], @@ -137,6 +144,7 @@ jax_test( shard_count = { "tpu": 20, "cpu": 20, + "gpu": 10, }, ) @@ -233,6 +241,9 @@ jax_test( shard_count = { "tpu": 5, }, + deps = [ + "//jax:experimental", + ], ) jax_test( @@ -256,6 +267,9 @@ jax_test( jax_test( name = "layout_test", srcs = ["layout_test.py"], + backend_tags = { + "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. + }, tags = ["multiaccelerator"], ) @@ -305,6 +319,9 @@ jax_test( jax_test( name = "array_test", srcs = ["array_test.py"], + backend_tags = { + "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. 
+ }, tags = ["multiaccelerator"], deps = [ "//jax:experimental", @@ -336,6 +353,7 @@ jax_test( jax_test( name = "infeed_test", srcs = ["infeed_test.py"], + tags = ["test_cpu_thunks"], deps = [ "//jax:experimental_host_callback", ], @@ -345,6 +363,7 @@ jax_test( name = "jax_jit_test", srcs = ["jax_jit_test.py"], main = "jax_jit_test.py", + tags = ["test_cpu_thunks"], ) py_test( @@ -542,6 +561,21 @@ jax_test( ] + py_deps("numpy"), ) +jax_test( + name = "lax_metal_test", + srcs = ["lax_metal_test.py"], + disable_backends = [ + "cpu", + "gpu", + "tpu", + ], + tags = ["notap"], + deps = [ + "//jax:internal_test_util", + "//jax:lax_reference", + ] + py_deps("numpy"), +) + jax_test( name = "lax_autodiff_test", srcs = ["lax_autodiff_test.py"], @@ -615,6 +649,11 @@ jax_test( }, ) +jax_test( + name = "cholesky_update_test", + srcs = ["cholesky_update_test.py"], +) + jax_test( name = "metadata_test", srcs = ["metadata_test.py"], @@ -651,6 +690,7 @@ jax_test( name = "nn_test", srcs = ["nn_test.py"], shard_count = { + "cpu": 10, "tpu": 10, "gpu": 10, }, @@ -676,6 +716,7 @@ jax_test( backend_tags = { "tpu": [ "noasan", # Times out under asan. + "requires-mem:16g", # Under tsan on 2x2 this test exceeds the default 12G memory limit. ], }, shard_count = { @@ -731,7 +772,11 @@ jax_test( name = "pytorch_interoperability_test", srcs = ["pytorch_interoperability_test.py"], disable_backends = ["tpu"], + # The following cases are disabled because they time out in Google's CI, mostly because the + # CUDA kernels in Torch take a very long time to compile. disable_configs = [ + "gpu_p100", # Pytorch P100 build times out in Google's CI. + "gpu_a100", # Pytorch A100 build times out in Google's CI. "gpu_h100", # Pytorch H100 build times out in Google's CI. ], tags = [ @@ -746,6 +791,14 @@ jax_test( jax_test( name = "qdwh_test", srcs = ["qdwh_test.py"], + backend_tags = { + "tpu": [ + "noasan", # Times out + "nomsan", # Times out + "notsan", # Times out + ], + }, + shard_count = 10, ) jax_test( @@ -784,9 +837,12 @@ jax_test( "notsan", # Times out ], }, + backend_variant_args = { + "gpu": ["--jax_num_generated_cases=40"], + }, shard_count = { "cpu": 40, - "gpu": 30, + "gpu": 40, "tpu": 40, }, tags = ["noasan"], # Times out @@ -960,6 +1016,29 @@ jax_test( ] + py_deps("scipy"), ) +jax_test( + name = "sparse_nm_test", + srcs = ["sparse_nm_test.py"], + config_tags_overrides = { + "gpu_a100": { + "ondemand": False, # Include in presubmit. + }, + }, + disable_backends = [ + "cpu", + "gpu", + "tpu", + ], + enable_configs = [ + "gpu_a100", + "gpu_h100", + ], + deps = [ + "//jax:experimental_sparse", + "//jax:pallas_gpu", + ], +) + jax_test( name = "sparsify_test", srcs = ["sparsify_test.py"], @@ -1014,6 +1093,11 @@ jax_test( main = "third_party/scipy/line_search_test.py", ) +jax_test( + name = "blocked_sampler_test", + srcs = ["blocked_sampler_test.py"], +) + py_test( name = "tree_util_test", srcs = ["tree_util_test.py"], @@ -1071,6 +1155,16 @@ py_test( ], ) +py_test( + name = "lru_cache_test", + srcs = ["lru_cache_test.py"], + deps = [ + "//jax", + "//jax:lru_cache", + "//jax:test_util", + ] + py_deps("filelock"), +) + jax_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.py"], @@ -1241,6 +1335,11 @@ jax_test( deps = py_deps("hypothesis"), ) +jax_test( + name = "mutable_array_test", + srcs = ["mutable_array_test.py"], +) + jax_test( name = "for_loop_test", srcs = ["for_loop_test.py"], @@ -1294,7 +1393,7 @@ jax_test( disable_configs = [ "gpu_a100", # Numerical precision problems. 
], - shard_count = 8, + shard_count = 15, deps = [ "//jax:rnn", ], @@ -1328,13 +1427,9 @@ py_test( ], ) -py_test( +jax_test( name = "logging_test", srcs = ["logging_test.py"], - deps = [ - "//jax", - "//jax:test_util", - ], ) jax_test( @@ -1371,7 +1466,6 @@ jax_test( ], deps = [ "//jax:internal_test_harnesses", - "//jax/experimental/export", ], ) @@ -1380,6 +1474,7 @@ jax_test( srcs = ["export_harnesses_multi_platform_test.py"], disable_configs = [ "gpu_a100", # TODO(b/269593297): matmul precision issues + "gpu_h100", # Scarce resources. ], shard_count = { "cpu": 40, @@ -1394,7 +1489,6 @@ jax_test( ], deps = [ "//jax:internal_test_harnesses", - "//jax/experimental/export", ], ) @@ -1415,12 +1509,33 @@ jax_test( "tpu", "cpu", ], - shard_count = 4, + shard_count = { + "gpu": 4, + }, + tags = ["multiaccelerator"], deps = [ "//jax:fused_attention_stablehlo", ], ) +py_test( + name = "pretty_printer_test", + srcs = ["pretty_printer_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + +py_test( + name = "sourcemap_test", + srcs = ["sourcemap_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + exports_files( [ "api_test.py", diff --git a/tests/ann_test.py b/tests/ann_test.py index ab35ce0c5392..1d704c725c61 100644 --- a/tests/ann_test.py +++ b/tests/ann_test.py @@ -23,9 +23,7 @@ from jax import lax from jax._src import test_util as jtu -from jax import config - -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() ignore_jit_of_pmap_warning = partial( jtu.ignore_warning,message=".*jit-of-pmap.*") diff --git a/tests/aot_test.py b/tests/aot_test.py index dacfa620c628..bca0d66ed384 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -17,7 +17,6 @@ import unittest from absl.testing import absltest import jax -from jax import config from jax._src import core from jax._src import test_util as jtu from jax._src.lib import xla_client as xc @@ -31,7 +30,7 @@ from jax.sharding import PartitionSpec as P import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() prev_xla_flags = None diff --git a/tests/api_test.py b/tests/api_test.py index b1e13f815c64..146340f89a95 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -16,6 +16,7 @@ import collections import collections.abc +from collections.abc import Callable import concurrent.futures from contextlib import contextmanager import copy @@ -33,7 +34,7 @@ import subprocess import sys import types -from typing import Callable, NamedTuple +from typing import NamedTuple import unittest import weakref @@ -53,6 +54,7 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src import debugging +from jax._src import pjit as pjit_lib from jax._src.ad_checkpoint import saved_residuals from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe @@ -229,7 +231,9 @@ def f(x, y, z): def test_jit_device(self): device = jax.devices()[-1] - x = jit(lambda x: x, device=device)(3.) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + x = jit(lambda x: x, device=device)(3.) 
_check_instance(self, x) self.assertEqual(x.devices(), {device}) @@ -260,10 +264,12 @@ def test_jit_default_device(self, module): with jax.default_device(test_device): # Explicit `device` or `backend` argument to jit overrides default_device - self.assertEqual( - module(f, device=system_default_device)(1).devices(), - system_default_devices) - out = module(f, backend="cpu")(1) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + self.assertEqual( + module(f, device=system_default_device)(1).devices(), + system_default_devices) + out = module(f, backend="cpu")(1) self.assertEqual(next(iter(out.devices())).platform, "cpu") # Sticky input device overrides default_device @@ -312,20 +318,16 @@ def i(): jit(f, **{argnum_type: (0, 1, -3)}) # Out of bounds without *args - # with self.assertRaises(ValueError): - with self.assertWarns(SyntaxWarning): + with self.assertRaises(ValueError): jit(f, **{argnum_type: (0, 1, 3)}) - # with self.assertRaises(ValueError): - with self.assertWarns(SyntaxWarning): + with self.assertRaises(ValueError): jit(f, **{argnum_type: (0, 1, -4)}) - # with self.assertRaises(ValueError): - with self.assertWarns(SyntaxWarning): + with self.assertRaises(ValueError): jit(g, **{argnum_type: (0, 1, 3)}) - # with self.assertRaises(ValueError): - with self.assertWarns(SyntaxWarning): + with self.assertRaises(ValueError): jit(g, **{argnum_type: (0, 1, -3)}) # Out of bounds with *args @@ -351,8 +353,7 @@ def h(a, /, b, c, *args, **kwargs): jit(f, **{argnum_type: ("b", "c")}) # Undefined arg without **kwargs - # with self.assertRaises(ValueError): - with self.assertWarns(SyntaxWarning): + with self.assertRaises(ValueError): jit(f, **{argnum_type: ("b", "c", "not_defined")}) # Undefined arg with **kwargs @@ -362,13 +363,11 @@ def h(a, /, b, c, *args, **kwargs): jit(h, **{argnum_type: ("b", "c", "not_defined")}) # Positional only - # with self.assertRaises(ValueError): - with self.assertWarns(SyntaxWarning): + with self.assertRaises(ValueError): jit(h, **{argnum_type: ("a", "c")}) # Var positional - # with self.assertRaises(ValueError): - with self.assertWarns(SyntaxWarning): + with self.assertRaises(ValueError): jit(h, **{argnum_type: ("args", "c")}) def test_jit_with_many_args_works(self): @@ -499,7 +498,7 @@ def f(inp1, inp2, inp3): self.assertDeleted(z) def test_resolve_argnums_signature_fail(self): - api_util.resolve_argnums(int, None, None, None, None) # doesn't crash + api_util.resolve_argnums(int, None, None, None, None, None) # doesn't crash @jtu.device_supports_buffer_donation() def test_donate_argnames_with_args(self): @@ -558,6 +557,42 @@ def add(x, y): result = f(x, x) result.block_until_ready() + @parameterized.named_parameters( + ('argnames', {'donate_argnames': ('z', 'y')}), + ('argnums', {'donate_argnums': (0, 1)}) + ) + def test_dict_donation(self, jit_kwargs): + @partial(jax.jit, **jit_kwargs) + def f(z, y, x): + return z, y, x + + z = {'c': 3.} + y = {'b': 2.} + x = {'a': 1.} + + _, kwargs_info = f.lower(z=z, y=y, x=x).args_info + self.assertTrue(kwargs_info['z']['c'].donated) + self.assertTrue(kwargs_info['y']['b'].donated) + self.assertFalse(kwargs_info['x']['a'].donated) + + @parameterized.named_parameters( + ('argnames', {'donate_argnames': ('z', 'y')}), + ('argnums', {'donate_argnums': (0, 1)}) + ) + def test_dict_donation_args_kwargs(self, jit_kwargs): + @partial(jax.jit, **jit_kwargs) + def f(z, y, x): + return z, y, x + + z = {'c': 3.} + y = {'b': 2.} + x = {'a': 1.} + + args_info, kwargs_info = f.lower(z, y=y, 
x=x).args_info + self.assertTrue(args_info[0]['c'].donated) + self.assertTrue(kwargs_info['y']['b'].donated) + self.assertFalse(kwargs_info['x']['a'].donated) + def test_intersecting_static_and_donate_argnames(self): with self.assertRaisesRegex( ValueError, "static_argnames and donate_argnames cannot intersect"): @@ -674,6 +709,13 @@ def test_trivial_computations(self): self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer()) self.assertEqual(z2, 1) + def test_print_token_buffer_error(self): + token = jax.lax.create_token() + with self.assertRaisesRegex( + RuntimeError, "Cannot convert a token-shape buffer to a numpy array." + ): + token._buf._value + def test_trivial_computations_with_tokens(self): @jit def noop(arr, token): @@ -681,8 +723,12 @@ def noop(arr, token): arr = jnp.ones(10) token = jax.lax.create_token() + _, out_token = noop(arr, token) - self.assertEqual(token, noop(arr, token)[1]) + self.assertIsInstance(token, core.Token) + self.assertIsInstance(out_token, core.Token) + # Different token objects. + self.assertIsNot(token, out_token) def test_jit_bad_input(self): def f(x): @@ -815,8 +861,10 @@ def test_cpp_jitted_function_returns_PyBuffer(self): @jtu.skip_on_devices("cpu") def test_explicit_backend(self, module): f = lambda x: x + 1 - jitted_f = module(f, backend=jtu.device_under_test()) - jitted_f_cpu = module(f, backend="cpu") + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + jitted_f = module(f, backend=jtu.device_under_test()) + jitted_f_cpu = module(f, backend="cpu") result = jitted_f(1.) result_cpu = jitted_f_cpu(1.) @@ -831,8 +879,10 @@ def test_explicit_backend(self, module): def test_device_to_device_copy_between_backends(self, module): # b/186624243 f = lambda x: x + 1 - jitted_f = module(f, backend=jtu.device_under_test()) - jitted_f_cpu = module(f, backend="cpu") + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + jitted_f = module(f, backend=jtu.device_under_test()) + jitted_f_cpu = module(f, backend="cpu") x = np.arange(30).reshape(1, 10, 3) result = jitted_f(x) @@ -843,6 +893,8 @@ def test_device_to_device_copy_between_backends(self, module): self.assertAllClose(result_cpu_2, x + 4) @jtu.skip_on_devices("cpu") + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def test_mismatched_nested_backends(self): @partial(jax.jit, backend=jtu.device_under_test()) def f(x): @@ -1227,6 +1279,13 @@ def f(x, y, *args, **kwargs): self.assertIn("kwargs['z']", hlo_str) self.assertIn("kwargs['w']", hlo_str) + hlo_str = mlir.module_to_string( + lowered.compiler_ir('stablehlo'), + enable_debug_info=False, + ) + for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"): + self.assertNotIn(s, hlo_str) + @parameterized.parameters([0, 2, [(0, 2)]]) def test_jit_lower_arg_info_static_argnums(self, static_argnums): def f(x, y, *args, **kwargs): @@ -1242,6 +1301,10 @@ def f(x, y, *args, **kwargs): self.assertIn("kwargs['z']", hlo_str) self.assertIn("kwargs['w']", hlo_str) + hlo_str = mlir.module_to_string(ir, enable_debug_info=False) + for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"): + self.assertNotIn(s, hlo_str) + @parameterized.parameters(['a', 'b', [('a', 'b')]]) def test_jit_lower_arg_info_static_argnames(self, static_argnames): def f(x, y, *args, **kwargs): @@ -1259,6 +1322,13 @@ def f(x, y, *args, **kwargs): self.assertNotIn("kwargs['a']", hlo_str) 
self.assertNotIn("kwargs['b']", hlo_str) + hlo_str = mlir.module_to_string(ir, enable_debug_info=False) + for s in ( + "\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", + "kwargs['w']", "kwargs['a']", "kwargs['b']" + ): + self.assertNotIn(s, hlo_str) + def test_jit_lower_result_info(self): def f(x, y, z): return {'a': x, 'b': [y]} @@ -1412,7 +1482,6 @@ def f(x): with self.assertRaisesRegex(core.ConcretizationTypeError, "Abstract tracer value"): jax.jit(f)(x) - @parameterized.named_parameters( ('grad', jax.grad), ('jacfwd', jax.jacfwd), @@ -1688,7 +1757,7 @@ def test_device_put_and_get(self): def test_device_put_sharding(self): mesh = jax.sharding.Mesh(jax.devices(), ('x',)) - s = jax.sharding.NamedSharding(mesh, P('x')) + s = jax.NamedSharding(mesh, P('x')) x = jnp.arange(len(jax.devices())) y = jax.device_put(x, s) @@ -1714,9 +1783,9 @@ def test_device_put_sharding_tree(self): mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y")) - s1 = jax.sharding.NamedSharding(mesh, P("x")) - s2 = jax.sharding.NamedSharding(mesh, P("y")) - s3 = jax.sharding.NamedSharding(mesh, P("x", "y")) + s1 = jax.NamedSharding(mesh, P("x")) + s2 = jax.NamedSharding(mesh, P("y")) + s3 = jax.NamedSharding(mesh, P("x", "y")) x = jnp.arange(2) y = jnp.arange(2) + 10 @@ -1992,7 +2061,6 @@ def f(x, u): self.assertEqual(fwd(rev(f, 0), 1)(x, u).shape, (5, 2)) self.assertEqual(fwd(fwd(f, 0), 1)(x, u).shape, (5, 2)) - def test_large_device_constant(self): ans = jit(lambda x: 2 * x)(jnp.ones(int(2e6))) # doesn't crash self.assertAllClose(ans, np.ones(int(2e6)) * 2., check_dtypes=False) @@ -2158,12 +2226,10 @@ def f(x, y): return x + y def test_vjp_mismatched_arguments(self): _, pullback = api.vjp(lambda x, y: x * y, np.float32(3), np.float32(4)) self.assertRaisesRegex( - TypeError, - "Tree structure of cotangent input.*does not match", + ValueError, "unexpected tree structure", lambda: pullback((np.float32(7), np.float32(100)))) self.assertRaisesRegex( - TypeError, - "Type of cotangent input to vjp pullback.*is not the expected tangent type", + ValueError, "unexpected JAX type", lambda: pullback(np.float16(42))) def test_vjp_bad_cotangent_shape(self): @@ -2172,9 +2238,7 @@ def test_vjp_bad_cotangent_shape(self): def f_jax(x, y): return jnp.matmul(x, y) res, pullback = jax.vjp(f_jax, x, y) - with self.assertRaisesRegex( - ValueError, - "Shape of cotangent input to vjp pullback function .* must be the same as the shape of corresponding primal input .*"): + with self.assertRaisesRegex(ValueError, "unexpected JAX type"): pullback(np.ones((2, 4), dtype=np.float32)) def test_jvp_jit_cached(self): @@ -2374,6 +2438,20 @@ def test_block_until_ready_function(self): self.assertAllClose(pytree[0], jnp.array(1.), check_dtypes=False) self.assertAllClose(pytree[1], np.ones(3), check_dtypes=False) + def test_block_until_ready_numpy_arrays(self): + pytree = (np.ones(1), np.ones(2)) + pytree = jax.block_until_ready(pytree) + self.assertAllClose(pytree[0], np.ones(1), check_dtypes=False) + self.assertAllClose(pytree[1], np.ones(2), check_dtypes=False) + + def test_block_until_ready_mixed(self): + pytree = (device_put(1.), device_put(2.), np.ones(3), 4) + pytree = jax.block_until_ready(pytree) + self.assertAllClose(pytree[0], jnp.array(1.), check_dtypes=False) + self.assertAllClose(pytree[1], jnp.array(2.), check_dtypes=False) + self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False) + self.assertEqual(pytree[3], 4) + def test_devicearray_weakref_friendly(self): x = device_put(1.) 
y = weakref.ref(x) @@ -2512,7 +2590,7 @@ def fun(x, y): def test_eval_shape_trace_cache_share(self): def f(x): - return x * 2 + return x inp = np.arange(8) @@ -2520,8 +2598,32 @@ def f(x): jax.eval_shape(f, inp) jax.jit(f)(inp) - # one for `f` and another for mul (`x * 2`) which is jitted. - self.assertEqual(count[0], 2) + self.assertEqual(count[0], 1) + + @unittest.skipIf(xla_extension_version <= 273, "requires jaxlib 0.4.31") + def test_jit_infer_params_cache(self): + def f(x): + return x + + f_jit = jax.jit(f) + + def g(x): + x = f_jit(x) # noqa: F821 + x = f_jit(x) # noqa: F821 + return x + + g_jit = jax.jit(g) + + inp = np.arange(8) + with jtu.count_jit_infer_params_cache_miss() as count: + g_jit(inp) + + self.assertDictEqual(count, {f: 1, g: 1}) + cache_size = pjit_lib._infer_params_cached.cache_info().currsize + del count, f, f_jit, g, g_jit + # Cache should only keep a weak reference to f and g. + self.assertLess(pjit_lib._infer_params_cached.cache_info().currsize, + cache_size, msg=pjit_lib._infer_params_cached.cache_keys()) def test_eval_shape_out_shardings(self): s = jax.sharding.SingleDeviceSharding(jax.devices()[0]) @@ -2638,14 +2740,14 @@ def test_vjp_of_int_index(self): self.assertEqual(tangent_i, np.zeros(shape=(), dtype=float0)) def test_vjp_of_int_shapes(self): - out, fn_vjp = api.vjp(lambda x: lax.reshape(x, (2, 2)), np.ones((4, 1), - dtype=int)) - tangent, = fn_vjp(out) + out, fn_vjp = api.vjp( + lambda x: lax.reshape(x, (2, 2)), np.ones((4, 1), dtype=int)) + tangent, = fn_vjp(np.zeros((2, 2), dtypes.float0)) self.assertArraysEqual(tangent, np.zeros(shape=(4, 1), dtype=float0)) def test_jit_vjp_of_int(self): primal, fn_vjp = api.vjp(lambda x, y: x+y, 2, 1) - tangent_x, tangent_i = jax.jit(fn_vjp)(1) + tangent_x, tangent_i = jax.jit(fn_vjp)(np.zeros((), dtypes.float0)) self.assertEqual(primal, 3) self.assertEqual(tangent_x, np.zeros(shape=(), dtype=float0)) self.assertEqual(tangent_i, np.zeros(shape=(), dtype=float0)) @@ -2878,6 +2980,7 @@ def fn(x): axis_env = [(axis_name, jax.local_device_count())] _ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x) + @jtu.unaccelerate_getattr_deprecation(jax, 'xla_computation') def test_xla_computation_axis_env(self): def fn(x): z = x * jax.lax.axis_index('i').astype(jnp.float32) @@ -2980,7 +3083,6 @@ def test_vmap_in_axes_list(self): x = jnp.zeros(3) y = jnp.arange(3.) - def f(dct, x, y): return dct['a'] + dct['b'] + x + y @@ -3710,7 +3812,7 @@ def test_jit_returning_token(self): self.assertIsInstance(x, core.Token) def test_jit_capturing_token(self): - tok = core.token + tok = jax.lax.create_token() _, y = jax.jit(lambda x: (x + 2, tok))(7) self.assertIsInstance(y, core.Token) @@ -3992,6 +4094,16 @@ def f(): return jnp.exp(dtype(0)) f() # doesn't error + def test_vmap_make_jaxpr_close_over_tracer(self): + def run(inp): + def f(x, y): + return x + y + g = lambda x: f(x, inp) + jaxpr = jax.make_jaxpr(g)(1) + return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1) + + jax.vmap(run)(jnp.arange(2)) # doesn't crash + def test_large_python_ints(self): with self.assertRaises(OverflowError): jnp.multiply(2 ** 100, 3.) 
@@ -4269,9 +4381,14 @@ def f(x, y): g = jax.grad(f, argnums=-1) g(x, y) # doesn't crash + @unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31") def test_jit_negative_static_argnums(self): - g = jax.jit(lambda x, y: x * y, static_argnums=-1) - g(1, 2) # doesn't crash + @partial(jax.jit, static_argnums=-1) + def g(x, y): + assert isinstance(y, int) + return x * y + for i in range(3): # Loop verifies we exercise both Python and C++ dispatch + self.assertEqual(2 * i, g(2, i), msg=i) def test_fastpath_cache_confusion(self): # https://github.com/google/jax/issues/12542 @@ -4283,7 +4400,6 @@ def a(x): def b(x): return a(x) - @jax.jit def g(x): return x, x @@ -4306,7 +4422,6 @@ def a(): # note nullary function, still staged out though def b(x): return a() - @jax.jit def g(x): return x, x @@ -4395,6 +4510,10 @@ def f(i): jax.clear_caches() self.assertEqual(f._cache_size, 0) + def test_invalid_value_device_put(self): + with self.assertRaisesRegex(ValueError, r".*Received invalid value.*"): + jax.device_put(jnp.arange(8), 'cpu') + def test_clear_cache(self): @jax.jit def add(x): @@ -4484,6 +4603,21 @@ def f(x, y): _, msg = cm.output self.assertIn('another function defined on the same line', msg) + def test_cache_miss_explanations_unpacks_transforms(self): + # Tests that the explain_tracing_cache_miss() function does not throw an + # error when unpacking `transforms` with a length greater than 3. + @jax.jit + def f(key): + return jax.random.truncated_normal(key, 1, 1, dtype=jax.numpy.float32) + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(jax.random.key(seed=123)) + + self.assertLen(cm.output, 5) + for msg in cm.output: + self.assertIn("TRACING CACHE MISS", msg) + @parameterized.named_parameters([ {"testcase_name": f"{dtype}", "dtype": dtype} for dtype in jtu.dtypes.custom_floats]) @@ -4554,7 +4688,6 @@ def test_mesh_creation_error_message(self): with self.assertRaisesRegex(ValueError, "ndim of its first argument"): jax.sharding.Mesh(jax.devices(), ("x", "y")) - @unittest.skipIf(xla_extension_version < 222, 'jaxlib version too old') def test_jit_boundmethod_reference_cycle(self): class A: def __init__(self): @@ -4565,6 +4698,95 @@ def foo(self): gc.collect() assert a() is None + def test_forwarding_bug(self): + # Test for issue #20267. + def f(x): + + @jax.jit + def inner(a, x): + return a, jnp.exp(x) + + return inner(0.0, x)[0] + jax.grad(f)(1.) 
# don't crash + + @parameterized.parameters(it.product(range(4), repeat=3)) + @jtu.run_on_devices("cpu") + def test_jit_forwarding_correctness(self, seed, num_input_fwd, num_output_fwd): + num_args = 3 + rng = np.random.RandomState(seed) + in_perm = rng.permutation(num_args) + out_perm = rng.permutation(num_args) + + @jax.jit + def f(inputs): + inputs = [inputs[i] for i in in_perm] + outputs = inputs[:num_input_fwd] + [ + jnp.exp(inputs[i]) if i < num_output_fwd else jnp.sin(inputs[i]) + for i in range(num_args - num_input_fwd)] + return [outputs[i] for i in out_perm] + + jtu.check_grads(f, (list(jnp.arange(float(num_args))),), order=1, + modes=['rev'], atol=1e-3, rtol=1e-3) + + @jtu.run_on_devices("cpu") + def test_inner_jit_forwarding_happens(self): + jaxpr = jax.make_jaxpr(lambda: jax.jit(lambda x: x)(3))() + self.assertLen(jaxpr.jaxpr.outvars, 1) + self.assertIsInstance(jaxpr.jaxpr.outvars[0], core.Literal) + self.assertEqual(jaxpr.jaxpr.outvars[0].val, 3) + + @parameterized.parameters(range(8)) + @jtu.run_on_devices("cpu") + def test_inner_jit_forwarding_correctness(self, num_input_fwd): + num_args = 8 + rng = np.random.RandomState(0) + + @jax.jit + def f(inputs): + inputs = [inputs[i] for i in rng.permutation(num_args)] + outputs = (inputs[:num_input_fwd] + + [jnp.sin(inputs[i]) for i in range(num_args - num_input_fwd)]) + return [outputs[i] for i in rng.permutation(num_args)] + + f2 = jax.jit(f) + inputs = list(jnp.arange(float(num_args))) + expected = f(inputs) + ans = f2(inputs) + for a, b in zip(ans, expected): + self.assertAllClose(a, b) + + def test_inner_jit_forwarded_consts_stay_const(self): + out = jax.jit(lambda: int(jax.jit(lambda x: x)(3)))() # don't crash + self.assertEqual(out, 3) + + def test_lowering_platform_aot(self): + @jax.jit + def f(x): + return x * 2 + + f.trace(jnp.arange(8)).lower(lowering_platforms=('tpu',)) # doesn't crash + + def test_no_double_dots_in_error_message(self): + @jax.jit + def f(x): + return 1 if x > 0 else 0 + + with self.assertRaisesRegex(TracerBoolConversionError, r"with shape bool\[\]\.[^\.]"): + f(0) + + def test_inlined_literals_with_error(self): + @jax.jit + def f(): + @partial(jax.jit, inline=True) + def g(): + return jnp.sin(1.) + if g() > 0: + return 1. + return 0. + + with self.assertRaisesRegex(TracerBoolConversionError, "Attempted boolean"): + f() + class RematTest(jtu.JaxTestCase): @@ -5526,6 +5748,16 @@ def f(x): res_avals = saved_residuals(f, jnp.ones((2, 2))) self.assertLen(res_avals, 1) + def test_name_saveable_input(self): + @partial(jax.remat, policy=lambda p, *_, **__: 'mul' in str(p)) + def f(x): + x = checkpoint_name(x * x, 'foo') + x = x * x + return x + + res = saved_residuals(f, 3.) 
+ self.assertStartsWith(res[1][1], "named 'foo'") + def test_name_denylist(self): def f(x): y = checkpoint_name(jnp.multiply(2., 2.), 'y') @@ -6216,7 +6448,6 @@ def f(x): p:f32[] = add n l in (p,) } ) - linear=(False, False, False, False) ] e a a c d in (f,) }""" jaxpr = api.make_jaxpr(f)(jnp.float32(3.)) @@ -8244,10 +8475,10 @@ def foo_bwd(_, g): self.assertRaisesRegex( TypeError, re.escape( - "Custom VJP rule must produce an output with the same container " + "Custom VJP bwd rule must produce an output with the same container " "(pytree) structure as the args tuple of the primal function, " "and in particular must produce a tuple of length equal to the " - "number of arguments to the primal function, but got VJP output " + "number of arguments to the primal function, but got bwd output " "structure {} for primal input structure {}.".format( jax.tree.structure((1, 1)), jax.tree.structure((1,))) @@ -8266,7 +8497,7 @@ def foo_bwd(_, g): return 2. * g # Should be a tuple f.defvjp(foo_fwd, foo_bwd) - with self.assertRaisesRegex(TypeError, "Custom VJP rule .* must produce a tuple"): + with self.assertRaisesRegex(TypeError, "Custom VJP bwd rule .* must produce a tuple"): api.grad(f)(3.) def test_fwd_rule_primal_out_type_doesnt_match_primal_error_message(self): @@ -8996,6 +9227,26 @@ def bwd_snd(_, g): gx, = vjp(x) self.assertArraysAllClose(gx, zero) + def test_symbolic_zero_custom_vjp_bwd_shape_error(self): + @jax.custom_vjp + def f(x, y, z): + return x, y, z + + def fwd(x, y, z): + return f(x.value, y.value, z.value), None + + def bwd(_, gs): + x_bar, y_bar, z_bar = gs + return y_bar, x_bar, z_bar # swapped! + + f.defvjp(fwd, bwd, symbolic_zeros=True) + + with self.assertRaisesRegex( + ValueError, + r'Consider just returning a None here'): + jax.grad(lambda x, y, z: f(x, y, z)[2].sum())( + jnp.ones(1), jnp.ones(2), jnp.ones(3)) + @parameterized.named_parameters( ('jit_vmap', True, True), ('jit', True, False), @@ -9251,6 +9502,60 @@ def f_bwd(_, z_bar): jax.grad(f)((1.0, (2.0, None))) # don't crash + def test_bwd_rule_shape_mismatch(self): + @jax.custom_vjp + def foo(x, y): + return x + + def foo_fwd(x, y): + return x, None + + def foo_bwd(_, g): + return jnp.zeros(3), jnp.zeros(3) + + foo.defvjp(foo_fwd, foo_bwd) + + with self.assertRaisesRegex( + ValueError, + r'output\[1\] the bwd rule produced an output of shape/dtype float..\[3\]'): + jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4)) + + def test_bwd_rule_shape_mismatch_disable(self): + # TODO(mattjj): remove this test when the config option is removed + @jax.custom_vjp + def foo(x, y): + return x + + def foo_fwd(x, y): + return x, None + + def foo_bwd(_, g): + return jnp.zeros(3), jnp.zeros(3) + + foo.defvjp(foo_fwd, foo_bwd) + + try: + jax.config.update('jax_custom_vjp_disable_shape_check', True) + jax.grad(lambda x, y: foo(x, y).sum(), 1)(jnp.ones(3), jnp.ones(4)) + finally: + jax.config.update('jax_custom_vjp_disable_shape_check', False) + + def test_bwd_rule_can_produce_list_or_tuple(self): + @jax.custom_vjp + def f(x, y): + return x * y + + def f_fwd(x, y): + return f(x, y), (x, y) + + def f_bwd(xy, g): + x, y = xy + return [g * y, x * g] # list, not tuple + + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)(1., 2.) # don't crash + def transpose_unary(f, x_example): def transposed(y): @@ -9722,8 +10027,7 @@ def tp(r, t): return 2 * t / r return x + fn(y, x) def cond_wrap(f): - return lambda i, x: lax.cond(i > 0, f, lambda x: x, x, - linear=(True,)) + return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) i = 7. 
x = jnp.ones(2) * 6. @@ -9747,8 +10051,7 @@ def tp(r, t): return 2 * fn(r, t) return x + fn(y, x) def cond_wrap(f): - return lambda i, x: lax.cond(i > 0, f, lambda x: x, x, - linear=(True,)) + return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) i = 7. x = jnp.ones(2) * 6. @@ -10312,8 +10615,8 @@ def test_pmap_nested_donate_ignored(self): class NamedCallTest(jtu.JaxTestCase): + @jtu.unaccelerate_getattr_deprecation(jax, 'xla_computation') def test_default_name(self): - @api.named_call def my_test_function(x): return x**2 @@ -10484,9 +10787,11 @@ def wsc_as_noop(ctx, operand, *args, **kwargs): rules = ((jax.lax.sharding_constraint_p, wsc_as_noop),) lowered_ir = ( jax.jit(f) - .lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16), - _experimental_lowering_parameters=mlir.LoweringParameters( - override_lowering_rules=rules)).as_text()) + .trace(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16)) + .lower(_private_parameters=mlir.LoweringParameters( + override_lowering_rules=rules)) + .as_text() + ) self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir) diff --git a/tests/api_util_test.py b/tests/api_util_test.py index 7b7a479dbf14..46bed8c86b8a 100644 --- a/tests/api_util_test.py +++ b/tests/api_util_test.py @@ -16,12 +16,12 @@ import itertools as it from absl.testing import absltest from absl.testing import parameterized +import jax from jax._src import api_util from jax import numpy as jnp from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ApiUtilTest(jtu.JaxTestCase): @@ -43,7 +43,8 @@ def test_donation_vector(self): expected += (False,) self.assertEqual( expected, - api_util.donation_vector(donate_argnums, (), args, kwargs)) + api_util.donation_vector(donate_argnums, (), + jax.tree.structure((args, kwargs)))) @parameterized.parameters( ((0,), (0,)), diff --git a/tests/array_api_test.py b/tests/array_api_test.py index 0d4893e4939e..dcb33b9bc57f 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -21,9 +21,11 @@ from types import ModuleType -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax -from jax import config +import jax.numpy as jnp +from jax._src import config, test_util as jtu +from jax._src.dtypes import _default_types, canonicalize_dtype from jax.experimental import array_api config.parse_flags_with_absl() @@ -58,12 +60,15 @@ 'broadcast_to', 'can_cast', 'ceil', + 'clip', 'complex128', 'complex64', 'concat', 'conj', + 'copysign', 'cos', 'cosh', + 'cumulative_sum', 'divide', 'e', 'empty', @@ -85,6 +90,7 @@ 'full_like', 'greater', 'greater_equal', + 'hypot', 'iinfo', 'imag', 'inf', @@ -112,9 +118,12 @@ 'matmul', 'matrix_transpose', 'max', + 'maximum', 'mean', 'meshgrid', 'min', + 'minimum', + 'moveaxis', 'multiply', 'nan', 'negative', @@ -130,11 +139,14 @@ 'prod', 'real', 'remainder', + 'repeat', 'reshape', 'result_type', 'roll', 'round', + 'searchsorted', 'sign', + 'signbit', 'sin', 'sinh', 'sort', @@ -149,6 +161,7 @@ 'tan', 'tanh', 'tensordot', + 'tile', 'tril', 'triu', 'trunc', @@ -160,6 +173,7 @@ 'unique_counts', 'unique_inverse', 'unique_values', + 'unstack', 'var', 'vecdot', 'where', @@ -219,19 +233,121 @@ class ArrayAPISmokeTest(absltest.TestCase): """Smoke test for the array API.""" def test_main_namespace(self): - self.assertSetEqual(names(array_api), MAIN_NAMESPACE) + self.assertContainsSubset(MAIN_NAMESPACE, names(array_api)) def test_linalg_namespace(self): - self.assertSetEqual(names(array_api.linalg), 
LINALG_NAMESPACE) + self.assertContainsSubset(LINALG_NAMESPACE, names(array_api.linalg)) def test_fft_namespace(self): - self.assertSetEqual(names(array_api.fft), FFT_NAMESPACE) + self.assertContainsSubset(FFT_NAMESPACE, names(array_api.fft)) def test_array_namespace_method(self): x = array_api.arange(20) self.assertIsInstance(x, jax.Array) self.assertIs(x.__array_namespace__(), array_api) +class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase): + + info = array_api.__array_namespace_info__() + + def setUp(self): + super().setUp() + self._boolean = self.build_dtype_dict(["bool"]) + self._signed = self.build_dtype_dict(["int8", "int16", "int32"]) + self._unsigned = self.build_dtype_dict(["uint8", "uint16", "uint32"]) + self._floating = self.build_dtype_dict(["float32"]) + self._complex = self.build_dtype_dict(["complex64"]) + if config.enable_x64.value: + self._signed["int64"] = jnp.dtype("int64") + self._unsigned["uint64"] = jnp.dtype("uint64") + self._floating["float64"] = jnp.dtype("float64") + self._complex["complex128"] = jnp.dtype("complex128") + self._integral = self._signed | self._unsigned + self._numeric = ( + self._signed | self._unsigned | self._floating | self._complex + ) + def build_dtype_dict(self, dtypes): + out = {} + for name in dtypes: + out[name] = jnp.dtype(name) + return out + + def test_capabilities_info(self): + capabilities = self.info.capabilities() + assert capabilities["boolean indexing"] + assert not capabilities["data-dependent shapes"] + + def test_default_device_info(self): + assert self.info.default_device() is None + + def test_devices_info(self): + assert self.info.devices() == jax.devices() + + def test_default_dtypes_info(self): + _default_dtypes = { + "real floating": "f", + "complex floating": "c", + "integral": "i", + "indexing": "i", + } + target_dict = { + dtype_name: canonicalize_dtype( + _default_types.get(kind) + ) for dtype_name, kind in _default_dtypes.items() + } + assert self.info.default_dtypes() == target_dict + + @parameterized.parameters( + "bool", "signed integer", "real floating", + "complex floating", "integral", "numeric", None, + (("real floating", "complex floating"),), + (("integral", "signed integer"),), + (("integral", "bool"),), + ) + def test_dtypes_info(self, kind): + + info_dict = self.info.dtypes(kind=kind) + control = { + "bool":self._boolean, + "signed integer":self._signed, + "unsigned integer":self._unsigned, + "real floating":self._floating, + "complex floating":self._complex, + "integral": self._integral, + "numeric": self._numeric + } + target_dict = {} + if kind is None: + target_dict = control["numeric"] | self._boolean + elif isinstance(kind, tuple): + target_dict = {} + for _kind in kind: + target_dict |= control[_kind] + else: + target_dict = control[kind] + assert info_dict == target_dict + +class ArrayAPIErrors(absltest.TestCase): + """Test that our array API implementations raise errors where required""" + + # TODO(micky774): Remove when jnp.clip deprecation is completed + # (began 2024-4-2) and default behavior is Array API 2023 compliant + def test_clip_complex(self): + x = array_api.arange(5, dtype=array_api.complex64) + complex_msg = "Complex values have no ordering and cannot be clipped" + with self.assertRaisesRegex(ValueError, complex_msg): + array_api.clip(x) + + with self.assertRaisesRegex(ValueError, complex_msg): + array_api.clip(x, max=x) + + x = array_api.arange(5, dtype=array_api.int32) + with self.assertRaisesRegex(ValueError, complex_msg): + array_api.clip(x, min=-1+5j) + + with 
self.assertRaisesRegex(ValueError, complex_msg): + array_api.clip(x, max=-1+5j) + if __name__ == '__main__': absltest.main() diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 006dea8acd4d..c2cd4c0f968d 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -22,8 +22,6 @@ from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import test_util as jtu -from jax._src import xla_bridge as xb -from jax._src.lib import xla_extension_version import numpy as np @@ -73,25 +71,51 @@ def setUp(self): @jtu.sample_product( shape=all_shapes, dtype=dlpack_dtypes, - gpu=[False, True], + copy=[False, True, None], + use_stream=[False, True], ) - def testJaxRoundTrip(self, shape, dtype, gpu): - if xb.using_pjrt_c_api(): - self.skipTest("DLPack support is incomplete in the PJRT C API") # TODO(skyewm) + @jtu.run_on_devices("gpu") + def testJaxRoundTrip(self, shape, dtype, copy, use_stream): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) - if gpu and jtu.test_device_matches(["cpu"]): - raise unittest.SkipTest("Skipping GPU test case on CPU") - device = jax.devices("gpu" if gpu else "cpu")[0] - x = jax.device_put(np, device) - dlpack = jax.dlpack.to_dlpack(x) - y = jax.dlpack.from_dlpack(dlpack) - self.assertEqual(y.devices(), {device}) - self.assertAllClose(np.astype(x.dtype), y) + def _check_copy(x: jax.Array, y: jax.Array, expect_copy): + copied = x.unsafe_buffer_pointer() != y.unsafe_buffer_pointer() + assert copied == expect_copy, f"Expected {'a' if expect_copy else 'no'} copy" + + # Check if the source device is preserved + x = jax.device_put(np, jax.devices("cpu")[0]) + device = jax.devices("gpu")[0] + y = jax.device_put(x, device) + dl_device = y.__dlpack_device__() + if use_stream: + stream = tuple(y.devices())[0].get_stream_for_external_ready_events() + dlpack = jax.dlpack.to_dlpack(y, copy=copy, stream=stream) + else: + dlpack = jax.dlpack.to_dlpack(y, copy=copy) + z = jax.dlpack.from_dlpack(dlpack) + + self.assertEqual(z.devices(), {device}) + self.assertAllClose(np.astype(x.dtype), z) self.assertRaisesRegex(RuntimeError, - "DLPack tensor may be consumed at most once", - lambda: jax.dlpack.from_dlpack(dlpack)) + "DLPack tensor may be consumed at most once", + lambda: jax.dlpack.from_dlpack(dlpack)) + + if shape in nonempty_array_shapes: + _check_copy(y, z, bool(copy)) + + # Check if the destination device can be specified + make_dlpack = lambda: x.__dlpack__(dl_device=dl_device, copy=copy) + if copy == False: + self.assertRaisesRegex(ValueError, "copy=False", make_dlpack) + return + + z = jax.dlpack.from_dlpack(make_dlpack()) + self.assertEqual(z.devices(), {device}) + self.assertAllClose(x, z) + + if shape in nonempty_array_shapes: + _check_copy(x, z, True) @jtu.sample_product( shape=all_shapes, @@ -119,8 +143,6 @@ def testJaxArrayRoundTrip(self, shape, dtype, gpu): ) @unittest.skipIf(not tf, "Test requires TensorFlow") def testTensorFlowToJax(self, shape, dtype): - if xb.using_pjrt_c_api(): - self.skipTest("DLPack support is incomplete in the PJRT C API") if (not config.enable_x64.value and dtype in [jnp.int64, jnp.uint64, jnp.float64]): raise self.skipTest("x64 types are disabled by jax_enable_x64") @@ -163,8 +185,6 @@ def testJaxToTensorFlow(self, shape, dtype): @unittest.skipIf(not tf, "Test requires TensorFlow") def testTensorFlowToJaxInt64(self): - if xb.using_pjrt_c_api(): - self.skipTest("DLPack support is incomplete in the PJRT C API") # See 
https://github.com/google/jax/issues/11895 x = jax.dlpack.from_dlpack( tf.experimental.dlpack.to_dlpack(tf.ones((2, 3), tf.int64))) @@ -174,18 +194,26 @@ def testTensorFlowToJaxInt64(self): @jtu.sample_product( shape=all_shapes, dtype=numpy_dtypes, + copy=[False, True], ) - def testNumpyToJax(self, shape, dtype): + def testNumpyToJax(self, shape, dtype, copy): rng = jtu.rand_default(self.rng()) x_np = rng(shape, dtype) - x_jax = jnp.from_dlpack(x_np) - self.assertAllClose(x_np, x_jax) + device = jax.devices()[0] + _from_dlpack = lambda: jnp.from_dlpack(x_np, device=device, copy=copy) + if jax.default_backend() == 'gpu' and not copy: + self.assertRaisesRegex( + ValueError, + r"Specified .* which requires a copy", + _from_dlpack + ) + else: + self.assertAllClose(x_np, _from_dlpack()) @jtu.sample_product( shape=all_shapes, dtype=numpy_dtypes, ) - @unittest.skipIf(numpy_version < (1, 23, 0), "Requires numpy 1.23 or newer") @jtu.run_on_devices("cpu") # NumPy only accepts cpu DLPacks def testJaxToNumpy(self, shape, dtype): rng = jtu.rand_default(self.rng()) @@ -193,7 +221,6 @@ def testJaxToNumpy(self, shape, dtype): x_np = np.from_dlpack(x_jax) self.assertAllClose(x_np, x_jax) - @unittest.skipIf(xla_extension_version < 221, "Requires newer jaxlib") def testNondefaultLayout(self): # Generate numpy array with nonstandard layout a = np.arange(4).reshape(2, 2) @@ -208,7 +235,6 @@ def testNondefaultLayout(self): class CudaArrayInterfaceTest(jtu.JaxTestCase): @jtu.skip_on_devices("cuda") - @unittest.skipIf(xla_extension_version < 228, "Requires newer jaxlib") def testCudaArrayInterfaceOnNonCudaFails(self): x = jnp.arange(5) self.assertFalse(hasattr(x, "__cuda_array_interface__")) @@ -219,7 +245,6 @@ def testCudaArrayInterfaceOnNonCudaFails(self): _ = x.__cuda_array_interface__ @jtu.run_on_devices("cuda") - @unittest.skipIf(xla_extension_version < 233, "Requires newer jaxlib") def testCudaArrayInterfaceOnShardedArrayFails(self): devices = jax.local_devices() if len(devices) <= 1: @@ -251,7 +276,6 @@ def testCudaArrayInterfaceWorks(self, shape, dtype): self.assertEqual(z.__array_interface__["typestr"], a["typestr"]) @jtu.run_on_devices("cuda") - @unittest.skipIf(xla_extension_version < 228, "Requires newer jaxlib") def testCudaArrayInterfaceBfloat16Fails(self): rng = jtu.rand_default(self.rng()) x = rng((2, 2), jnp.bfloat16) @@ -274,7 +298,6 @@ def testJaxToCuPy(self, shape, dtype): z.__cuda_array_interface__["data"][0]) self.assertAllClose(x, cupy.asnumpy(z)) - @unittest.skipIf(xla_extension_version < 237, "Requires newer jaxlib") @jtu.sample_product( shape=all_shapes, dtype=jtu.dtypes.supported(cuda_array_interface_dtypes), @@ -291,16 +314,12 @@ def testCuPyToJax(self, shape, dtype): z.__cuda_array_interface__["data"][0]) self.assertAllClose(np.asarray(z), cupy.asnumpy(y)) - @unittest.skipIf(xla_extension_version < 237, "Requires newer jaxlib") @jtu.sample_product( shape=all_shapes, dtype=jtu.dtypes.supported(cuda_array_interface_dtypes), ) @jtu.run_on_devices("cuda") def testCaiToJax(self, shape, dtype): - # TODO(b/324133505) enable this test for PJRT C API - if xb.using_pjrt_c_api(): - self.skipTest("CUDA Array Interface support is incomplete in the PJRT C API") rng = jtu.rand_default(self.rng()) x = rng(shape, dtype) diff --git a/tests/array_test.py b/tests/array_test.py index 7c8d4c355333..a8c119cfa82e 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -15,7 +15,6 @@ import contextlib import math -import os import unittest from absl.testing import absltest @@ -25,51 +24,38 @@ import jax 
import jax.numpy as jnp from jax._src import core +from jax._src import deprecations from jax._src import dispatch from jax._src import op_shardings from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc from jax._src.util import safe_zip +from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import (_op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map, - NamedSharding, GSPMDSharding) + NamedSharding, GSPMDSharding, + PositionalSharding) from jax.experimental.pjit import pjit from jax.experimental import multihost_utils from jax.sharding import PartitionSpec as P from jax._src import array from jax._src import prng -from jax import config -config.parse_flags_with_absl() - - -prev_xla_flags = None +jax.config.parse_flags_with_absl() with contextlib.suppress(ImportError): import pytest pytestmark = pytest.mark.multiaccelerator - # Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() + def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xb.get_backend.cache_clear() - -# Reset to previous configuration in case other test modules will be run. + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xb.get_backend.cache_clear() + _exit_stack.close() def create_array(shape, sharding, global_data=None): @@ -328,24 +314,28 @@ def test_zeros_like(self): self.assertTrue(dispatch.is_single_device_sharding(out.sharding)) def test_wrong_num_arrays(self): + if jax.device_count() < 4: + self.skipTest('Requires more than 4 devices') shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) - devices = jax.local_devices()[:8] # Taking up to 8 devices + mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + devices = jax.local_devices()[:2] # Taking up to 2 devices s = jax.sharding.NamedSharding(mesh, P('x', 'y')) inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) di_map = s.devices_indices_map(shape) bufs = [jax.device_put(inp_data[di_map[d]], d) for d in devices] with self.assertRaisesRegex( ValueError, - r'Expected 8 per-device arrays \(this is how many devices are addressable ' - r'by the sharding\), but got 4'): - array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs[:4], committed=True) + r'Expected 2 per-device arrays \(this is how many devices are addressable ' + r'by the sharding\), but got 1'): + array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs[:1], committed=True) + for buf, d in zip(list(bufs), jax.local_devices()[2:4]): + bufs.append(jax.device_put(buf, d)) with self.assertRaisesRegex( ValueError, - r'Expected 8 per-device arrays \(this is how many devices are addressable ' - r'by the sharding\), but got 16'): - array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True) + r'Expected 2 per-device arrays \(this is how many devices are addressable ' + r'by the sharding\), but got 4'): + array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True) def 
test_arrays_not_in_device_assignment(self): if jax.device_count() < 4: @@ -365,21 +355,6 @@ def test_arrays_not_in_device_assignment(self): "in the sharding."): array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True) - def test_more_devices_in_sharding_than_arrays(self): - shape = (8, 2) - mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) - # Sharding device ids = {0, 1} - s = jax.sharding.NamedSharding(mesh, P('x')) - inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) - # _arrays device ids = {0, 0} - bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)] - with self.assertRaisesRegex( - ValueError, - "Addressable devices and per-device arrays devices do not match. " - r"Sharding contains devices \{1\} that are not present in per-device " - "arrays."): - array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True) - def test_different_devices_in_arrays_than_sharding(self): if jax.device_count() < 3: self.skipTest('Requires more than 3 devices') @@ -398,6 +373,22 @@ def test_different_devices_in_arrays_than_sharding(self): "in the sharding."): array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True) + def test_duplicated_devices_in_arrays(self): + if xc._version <= 274: + self.skipTest('Test requires jaxlib version 275') + shape = (8, 2) + mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + # Sharding device ids = {0, 1} + s = jax.sharding.NamedSharding(mesh, P('x')) + inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) + # _arrays device ids = {0, 2} + bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)] + with self.assertRaisesRegex( + ValueError, + 'When making an array from single-device arrays, the input arrays must' + ' be from distinct devices'): + array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True) + @parameterized.named_parameters( ("mesh_x_y", P("x", "y"), (2, 2)), ("mesh_x", P("x"), (2, 4)), @@ -611,6 +602,22 @@ def test_array_addressable_shards(self): x = jnp.array([1, 2, 3]) self.assertIsInstance(x.addressable_data(0), array.ArrayImpl) + def test_array_not_hashable(self): + x = jnp.arange(4) + with self.assertRaisesRegex(TypeError, "unhashable type"): + hash(x) + + @jax.jit + def check_tracer_hash(x): + self.assertIsInstance(hash(x), int) + + if deprecations.is_accelerated('tracer-hash'): + with self.assertRaisesRegex(TypeError, "unhashable type"): + check_tracer_hash(x) + else: + with self.assertWarnsRegex(FutureWarning, "unhashable type"): + check_tracer_hash(x) + def test_shape_dtype_struct_sharding_jit(self): mesh = jtu.create_global_mesh((8,), ('x')) s = jax.sharding.NamedSharding(mesh, P('x')) @@ -795,6 +802,37 @@ def test_shards_have_correct_dtype(self, dtype): for shard in x.addressable_shards: self.assertEqual(shard.data.dtype, dtype) + def test_make_array_from_callback_global_array(self): + mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + sharding = jax.sharding.NamedSharding(mesh, P()) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, sharding) + + out = jax.make_array_from_callback(np_inp.shape, sharding, + lambda idx: arr[idx]) + self.assertArraysEqual(out, arr) + self.assertEqual(out.sharding, sharding) + + sharding2 = NamedSharding(mesh, P('x', 'y')) + arr2 = jax.device_put(np_inp, sharding2) + out2 = jax.make_array_from_callback(np_inp.shape, sharding2, + lambda idx: arr2[idx]) + self.assertArraysEqual(out2, arr2) + self.assertEqual(out2.sharding, sharding2) + + def 
test_make_array_from_process_data_single_host_data_sharding(self): + data = np.ones((1, 512)) + mesh = jtu.create_global_mesh((1, 1), ('x', 'unused')) + sharding_spec = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec('x') + ) + global_shape = data.shape + result = jax.make_array_from_process_local_data( + sharding_spec, data, global_shape + ) + self.assertIsInstance(result, jax.Array) + self.assertEqual(result.shape, data.shape) + self.assertEqual(result.sharding, sharding_spec) class ShardingTest(jtu.JaxTestCase): @@ -815,6 +853,15 @@ def test_mesh_pspec_sharding_interface(self): self.assertListEqual(hlo_sharding.tile_assignment_devices(), [0, 2, 4, 6, 1, 3, 5, 7]) + def test_util_clear_cache(self): + mesh = jtu.create_global_mesh((1,), ('x',)) + s = NamedSharding(mesh, P()) + s.devices_indices_map((8,)) + jax.clear_caches() + s.devices_indices_map((8,)) + c = common_devices_indices_map.cache_info() + self.assertEqual(c.currsize, 1) + @parameterized.named_parameters( ("mesh_x_y", P("x", "y")), ("mesh_x", P("x")), @@ -892,7 +939,7 @@ def test_is_compatible_error(self): r"Sharding NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), " r"spec=PartitionSpec\(None, \('mdl',\), None, None\).*\) is only " "valid for values of rank at least 4, but was applied to a value of rank 2"): - new_mps.is_compatible_aval(shape) + new_mps.check_compatible_aval(shape) def test_is_subclass(self): # array version of api_test.py::APITest::test_is_subclass @@ -917,6 +964,10 @@ def test_gspmd_sharding_repr(self): # memory kind also appears in the repr but only for TPU. self.assertIn('GSPMDSharding({replicated}', repr(s2)) + def test_positional_sharding_fully_replicated(self): + sharding = PositionalSharding(jax.devices()) + jax.device_put(jnp.array(1), sharding.replicate()) # doesn't crash + @parameterized.named_parameters( ("mesh_x_y", P("x", "y"), (4, 2), (), False), ("mesh_x", P("x"), (4, 2), (1,), False), @@ -946,6 +997,17 @@ def test_positional_sharding_op_sharding_lowering( devices_sharding.shard_shape(value_shape)) self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2)) + def test_positional_sharding_aval_compatible(self): + if jax.device_count() < 2: + self.skipTest('Requires >=2 devices') + sharding = PositionalSharding(jax.devices()).reshape(1, jax.device_count()) + x = jax.random.uniform(jax.random.key(42), (256, 20, 1000)) + with self.assertRaisesRegex( + ValueError, + 'Sharding PositionalSharding.*is only valid for values of rank 2, but' + ' was applied to a value of rank 3'): + jax.lax.with_sharding_constraint(x, sharding) + @parameterized.named_parameters( ("2d_mesh_x_y", (4, 2), P("x", "y")), ("2d_mesh_x", (4, 2), P("x")), @@ -1148,7 +1210,7 @@ def test_scalar_input_wrong_pspec(self): with self.assertRaisesRegex( ValueError, r"For scalars the PartitionSpec should be P()"): - s.is_compatible_aval(shape) + s.check_compatible_aval(shape) def test_mesh_caching_during_construction(self): if jax.device_count() < 2: @@ -1192,6 +1254,18 @@ def f(x): with self.assertRaisesRegex(ValueError, msg): jax.jit(f)(x) + def test_make_array_from_single_device_arrays_bad_inputs(self): + x = jnp.arange(10) + mesh = jtu.create_global_mesh((2,), ('x',)) + s = jax.sharding.NamedSharding(mesh, P('x')) + x = jax.device_put(x, s) + + msg = ("When making an array from single-device arrays the input arrays " + "must have one shard each. 
An argument array had 2 shard\\(s\\).") + with self.assertRaisesRegex(ValueError, msg): + jax.make_array_from_single_device_arrays(x.shape, s, [x, x]) + + def test_gspmd_sharding_hash_eq(self): mesh = jtu.create_global_mesh((1, 1, 1), ('x', 'y', 'z')) ns = NamedSharding(mesh, P('x', 'y', 'z')) @@ -1223,18 +1297,18 @@ def f(x): x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i]) # check computation is fully partitioned and without any communication - jax.config.update('jax_threefry_partitionable', True) - unopt_txt = f.lower(x).as_text(dialect='hlo') - opt_txt = f.lower(x).compile().as_text() - self.assertIn( f'[{n}]', unopt_txt) - self.assertNotIn(f'[{n}]', opt_txt) - self.assertNotIn('all-reduce', opt_txt) - self.assertNotIn('collective-permute', opt_txt) - - # check against single-device reference - y = f(x) - y_ref1 = f(jax.device_put(x, jax.devices()[0])) - self.assertArraysEqual(y, y_ref1) + with jax.threefry_partitionable(True): + unopt_txt = f.lower(x).as_text(dialect='hlo') + opt_txt = f.lower(x).compile().as_text() + self.assertIn( f'[{n}]', unopt_txt) + self.assertNotIn(f'[{n}]', opt_txt) + self.assertNotIn('all-reduce', opt_txt) + self.assertNotIn('collective-permute', opt_txt) + + # check against single-device reference + y = f(x) + y_ref1 = f(jax.device_put(x, jax.devices()[0])) + self.assertArraysEqual(y, y_ref1) @parameterized.named_parameters( {"testcase_name": f"_{mesh_shape}_{pspec}", @@ -1255,23 +1329,23 @@ def f(x): s = jax.sharding.NamedSharding(mesh, pspec) n = math.prod(global_shape) - global_x = jnp.arange(n).astype('uint32').reshape(global_shape) + global_x = np.arange(n).astype('uint32').reshape(global_shape) x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i]) # check computation is fully partitioned and without any communication - jax.config.update('jax_threefry_partitionable', True) - unopt_txt = f.lower(x).as_text(dialect='hlo') - opt_txt = f.lower(x).compile().as_text() - global_shape_fmt = ','.join(str(x) for x in global_shape) - self.assertIn( f'[{global_shape_fmt}]', unopt_txt) - self.assertNotIn(f'[{global_shape_fmt}]', opt_txt) - self.assertNotIn('all-reduce', opt_txt) - self.assertNotIn('collective-permute', opt_txt) - - # check against single-device reference - y = f(x) - y_ref1 = f(jax.device_put(x, jax.devices()[0])) - self.assertArraysEqual(y, y_ref1) + with jax.threefry_partitionable(True): + unopt_txt = f.lower(x).as_text(dialect='hlo') + opt_txt = f.lower(x).compile().as_text() + global_shape_fmt = ','.join(str(x) for x in global_shape) + self.assertIn( f'[{global_shape_fmt}]', unopt_txt) + self.assertNotIn(f'[{global_shape_fmt}]', opt_txt) + self.assertNotIn('all-reduce', opt_txt) + self.assertNotIn('collective-permute', opt_txt) + + # check against single-device reference + y = f(x) + y_ref1 = f(jax.device_put(x, jax.devices()[0])) + self.assertArraysEqual(y, y_ref1) if __name__ == '__main__': diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 1f372907596e..5c834f314270 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -45,7 +45,6 @@ class Thing: class AttrsTest(jtu.JaxTestCase): - @parameterized.parameters([True, False]) def test_jit_basic(self, jit: bool): thing = Thing(1.0) @@ -67,6 +66,100 @@ def double_it() -> None: double_it() self.assertEqual(thing.x, 16.0) + @parameterized.parameters([True, False]) + def test_jit_basic_tree(self, jit: bool): + thing = Thing((1.0, 2.0)) + + def double_it() -> None: + (cur_x, cur_y) = jax_getattr(thing, "x") + jax_setattr(thing, "x", 
(cur_x * 2, cur_y * 2)) + + if jit: + double_it = jax.jit(double_it) + + self.assertEqual(thing.x, (1.0, 2.0)) + double_it() + self.assertEqual(thing.x, (2.0, 4.0)) + double_it() + self.assertEqual(thing.x, (4.0, 8.0)) + double_it() + self.assertEqual(thing.x, (8.0, 16.0)) + double_it() + self.assertEqual(thing.x, (16.0, 32.0)) + + @parameterized.parameters([True, False]) + def test_jit_basic_tree_changes(self, jit: bool): + thing = Thing(None) + count = 0 + + def double_it() -> None: + nonlocal count + count += 1 + maybe_x = jax_getattr(thing, "x") + x = 1.0 if maybe_x is None else maybe_x + jax_setattr(thing, "x", 2 * x) + + if jit: + double_it = jax.jit(double_it) + + self.assertEqual(thing.x, None) + double_it() + self.assertEqual(thing.x, 2.0) + self.assertEqual(count, 1) + double_it() + self.assertEqual(thing.x, 4.0) + self.assertEqual(count, 2) + double_it() + self.assertEqual(thing.x, 8.0) + self.assertEqual(count, 2 + (not jit)) + + def test_jit_basic_tree_changes_multiple(self): + thing1 = Thing(None) + thing2 = Thing(0) + count = 0 + + @jax.jit + def double_it() -> None: + nonlocal count + count += 1 + + x1 = jax_getattr(thing1, "x") + if x1 is None: + jax_setattr(thing1, 'x', (None,)) + elif isinstance(x1, tuple): + # depend on a new value + jax_setattr(thing1, 'x', jax_getattr(thing2, 'x') + 1) + else: + jax_setattr(thing2, 'x', jax_getattr(thing1, 'x')) + jax_setattr(thing1, 'x', None) + + self.assertEqual(thing1.x, None) + self.assertEqual(thing2.x, 0) + double_it() + self.assertEqual(thing1.x, (None,)) + self.assertEqual(thing2.x, 0) + self.assertEqual(count, 1) + double_it() + self.assertEqual(thing1.x, 1) + self.assertEqual(thing2.x, 0) + self.assertEqual(count, 2) + double_it() + self.assertEqual(thing1.x, None) + self.assertEqual(thing2.x, 1) + self.assertEqual(count, 3) + double_it() + self.assertEqual(thing1.x, (None,)) + self.assertEqual(thing2.x, 1) + self.assertEqual(count, 3) + double_it() + self.assertEqual(thing1.x, 2) + self.assertEqual(thing2.x, 1) + self.assertEqual(count, 3) + double_it() + self.assertEqual(thing1.x, None) + self.assertEqual(thing2.x, 2) + self.assertEqual(count, 3) + def test_jit_nesting_basic(self): thing = Thing(1.0) diff --git a/tests/batching_test.py b/tests/batching_test.py index 3bcd4c5216cc..4d912bfca206 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -14,10 +14,11 @@ from __future__ import annotations +from collections.abc import Callable from contextlib import contextmanager from functools import partial import itertools as it -from typing import Any, Callable, TypeVar, Union +from typing import Any, TypeVar, Union import numpy as np from absl.testing import absltest @@ -37,8 +38,7 @@ from jax.interpreters import batching from jax.tree_util import register_pytree_node -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # These are 'manual' tests for batching (vmap). The more exhaustive, more @@ -962,7 +962,7 @@ def body_fn(uk): u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key)) return u - with jax.enable_key_reuse_checks(False): + with jax.debug_key_reuse(False): print(vmap(f)(random.split(random.PRNGKey(0), 2))) # no crash def testEmptyTuples(self): diff --git a/tests/blocked_sampler_test.py b/tests/blocked_sampler_test.py new file mode 100644 index 000000000000..1f8f2b645f06 --- /dev/null +++ b/tests/blocked_sampler_test.py @@ -0,0 +1,90 @@ +# Copyright 2024 The JAX Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import numpy as jnp +from jax._src import blocked_sampler +from jax._src import config +from jax._src import test_util as jtu +import numpy as np + + +config.parse_flags_with_absl() + + +def call_kernel( + kernel, + grid: tuple[int, int], + transpose_grid: bool, + *args + ): + """Calls a kernel over a grid and concatenates results to a single array.""" + if transpose_grid: + grid = (grid[1], grid[0]) + m, n = grid + return jnp.concatenate([ + jnp.concatenate([ + kernel(i, j, *args) for j in range(n)], axis=1) + for i in range(m)], axis=0) + + +def uniform_kernel(i: int, j: int, total_size, block_size, tile_size): + """Uniform random sampling kernel function.""" + global_key = jax.random.key(0) + keys = blocked_sampler.blocked_fold_in(global_key, + total_size=total_size, + block_size=block_size, + tile_size=tile_size, + block_index=(i, j)) + return blocked_sampler.sample_block(jax.random.uniform, + keys, + block_size=block_size, + tile_size=tile_size, + minval=0.0, maxval=1.0) + + +class BlockedSamplerTest(jtu.JaxTestCase): + + @parameterized.named_parameters( + dict(testcase_name='8x128_vs_16x256', total_size=(32, 256), + block_size_a=(8, 128), block_size_b=(16, 256), + tile_size=(8, 128), transpose_grid=False), + dict(testcase_name='transpose_8x128_vs_16x256', total_size=(32, 256), + block_size_a=(8, 128), block_size_b=(16, 256), + tile_size=(8, 128), transpose_grid=True), + dict(testcase_name='8x128_vs_32x128', total_size=(32, 128), + block_size_a=(8, 128), block_size_b=(32, 128), + tile_size=(8, 128), transpose_grid=False), + dict(testcase_name='16x256_vs_32x128', total_size=(32, 256), + block_size_a=(16, 256), block_size_b=(32, 128), + tile_size=(8, 128), transpose_grid=False), + ) + def test_block_shape_invariance(self, total_size, block_size_a, + block_size_b, tile_size, transpose_grid): + grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a)) + result_a = call_kernel( + uniform_kernel, grid_a, transpose_grid, + total_size, block_size_a, tile_size) + + grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b)) + result_b = call_kernel( + uniform_kernel, grid_b, transpose_grid, + total_size, block_size_b, tile_size) + np.testing.assert_array_equal(result_a, result_b) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 33fc02009353..508dbacc2a98 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -155,6 +155,23 @@ def test_different_computations(self): cache_key.get(computation2, devices, compile_options, backend), ) + def test_different_device_assignment(self): + computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() + devices = np.array([[jax.local_devices()[0]]]) + compile_options_1 = compiler.get_compile_options( + num_replicas=1, num_partitions=1, device_assignment=np.array([[0]]) + ) + compile_options_2 
= compiler.get_compile_options( + num_replicas=1, num_partitions=1, device_assignment=np.array([[1]]) + ) + backend = xla_bridge.get_backend() + hash_1 = cache_key.get(computation, devices, compile_options_1, backend) + hash_2 = cache_key.get(computation, devices, compile_options_2, backend) + if backend.platform == "gpu": + self.assertEqual(hash_1, hash_2) + else: + self.assertNotEqual(hash_1, hash_2) + @parameterized.parameters([False, True]) def test_identical_computations_different_metadata(self, include_metadata): f = lambda x, y: lax.mul(lax.add(x, y), 2) diff --git a/tests/cholesky_update_test.py b/tests/cholesky_update_test.py new file mode 100644 index 000000000000..63f732dcd55d --- /dev/null +++ b/tests/cholesky_update_test.py @@ -0,0 +1,72 @@ +# Copyright 2024 The JAX Authors. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from absl.testing import absltest + +from jax import numpy as jnp +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lax import linalg as lax_linalg +import numpy as np + +config.parse_flags_with_absl() + +class CholeskyUpdateTest(jtu.JaxTestCase): + + @jtu.sample_product( + shape=[ + (128, 128), + ], + dtype=[jnp.float32, jnp.float64], + ) + def testUpperOnes(self, shape, dtype): + """A test with a (mildly) ill-conditioned matrix.""" + if dtype is jnp.float64 and not config.enable_x64.value: + self.skipTest("Test disabled for x32 mode") + r_upper = jnp.triu(jnp.ones(shape)).astype(dtype) + w = jnp.arange(1, shape[0] + 1).astype(dtype) + new_matrix = r_upper.T @ r_upper + jnp.outer(w, w) + new_cholesky = jnp.linalg.cholesky(new_matrix, upper=True) + + updated = lax_linalg.cholesky_update(r_upper, w) + + atol = 1e-6 if (dtype is jnp.float64) else 2e-2 + jtu._assert_numpy_allclose(updated, new_cholesky, atol=atol) + + @jtu.sample_product( + shape=[ + (128, 128), + ], + dtype=[jnp.float32, jnp.float64], + ) + def testRandomMatrix(self, shape, dtype): + if dtype is jnp.float64 and not config.enable_x64.value: + self.skipTest("Test disabled for x32 mode") + rng = jtu.rand_default(self.rng()) + a = rng(shape, np.float64) + pd_matrix = jnp.array(a.T @ a).astype(dtype) + old_cholesky = jnp.linalg.cholesky(pd_matrix, upper=True) + + w = rng((shape[0],), np.float64) + w = jnp.array(w).astype(dtype) + + new_matrix = pd_matrix + jnp.outer(w, w) + new_cholesky = jnp.linalg.cholesky(new_matrix, upper=True) + updated = lax_linalg.cholesky_update(old_cholesky, w) + atol = 1e-6 if dtype == jnp.float64 else 1e-3 + jtu._assert_numpy_allclose(updated, new_cholesky, atol=atol) + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/clear_backends_test.py b/tests/clear_backends_test.py index f149e3c90f6d..9ea9cac3a72c 100644 --- a/tests/clear_backends_test.py +++ b/tests/clear_backends_test.py @@ -14,13 +14,12 @@ """Tests for release_backend_clients.""" from absl.testing import absltest - import jax -from jax import config +from jax._src import api from jax._src import test_util as jtu from jax._src import 
xla_bridge as xb -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ClearBackendsTest(jtu.JaxTestCase): @@ -29,7 +28,7 @@ def test_clear_backends(self): g = jax.jit(lambda x, y: x * y) self.assertEqual(g(1, 2), 2) self.assertNotEmpty(xb.get_backend().live_executables()) - jax.clear_backends() + api.clear_backends() self.assertEmpty(xb.get_backend().live_executables()) self.assertEqual(g(1, 2), 2) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index b9085fd9e28b..a3e9c623db4c 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from collections import Counter from functools import partial import math -import os import platform -import tempfile -from collections import Counter import unittest from unittest import mock from unittest import SkipTest @@ -34,10 +34,12 @@ from jax._src import config from jax._src import distributed from jax._src import monitoring +from jax._src import path as pathlib from jax._src import test_util as jtu from jax._src import xla_bridge +from jax._src.compilation_cache_interface import CacheInterface from jax._src.lib import xla_client -from jax.experimental.maps import xmap +from jax._src.maps import xmap from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P import numpy as np @@ -61,270 +63,266 @@ def increment_event_count(event): _counts[event] += 1 +class InMemoryCache(CacheInterface): + """An in-memory cache for testing purposes.""" + + # not used, but required by `CacheInterface` + _path = pathlib.Path() + + def __init__(self): + self._cache: dict[str, bytes] = {} + + def get(self, key: str) -> bytes | None: + return self._cache.get(key) + + def put(self, key: str, value: bytes) -> None: + self._cache[key] = value + + def clear(self) -> None: + self._cache.clear() + + def __len__(self) -> int: + return len(self._cache) + + +def count_cache_items() -> int: + return 0 if cc._cache is None else len(cc._cache) + + +def clear_cache() -> None: + if cc._cache is not None: + cc._cache.clear() + + +class CompilationCacheTestCase(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + cc.reset_cache() + cc._cache = InMemoryCache() + + def tearDown(self): + cc.reset_cache() + super().tearDown() + + @jtu.with_config( jax_enable_compilation_cache=True, jax_raise_persistent_cache_errors=True, jax_persistent_cache_min_compile_time_secs=0, jax_persistent_cache_min_entry_size_bytes=0, ) -class CompilationCacheTest(jtu.JaxTestCase): - +class CompilationCacheTest(CompilationCacheTestCase): def setUp(self): super().setUp() - # TODO(b/323256224): Add back support for CPU together with extra fields in - # a cache key with underlying hardware features. 
- supported_platforms = ["tpu", "gpu"] + supported_platforms = ["tpu", "gpu", "cpu"] if not jtu.test_device_matches(supported_platforms): raise SkipTest( "serialize executable only works on " + ",".join(supported_platforms) ) - cc.reset_cache() - - def tearDown(self): - cc.reset_cache() - super().tearDown() - def test_get_no_executable(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() - devices = np.array([[jax.local_devices()[0]]]) - compile_options = compiler.get_compile_options( - num_replicas=1, num_partitions=1 - ) - backend = xla_bridge.get_backend() - key = cc.get_cache_key(computation, devices, compile_options, backend) - executable, compile_time = cc.get_executable_and_time( - key, compile_options, backend) - self.assertIsNone(executable) - self.assertIsNone(compile_time) + computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() + devices = np.array([[jax.local_devices()[0]]]) + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + key = cc.get_cache_key(computation, devices, compile_options, backend) + executable, compile_time = cc.get_executable_and_time( + key, compile_options, backend) + self.assertIsNone(executable) + self.assertIsNone(compile_time) def test_diff_executables(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - computation1 = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()) - computation2 = str(jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()) - compile_options = compiler.get_compile_options( - num_replicas=1, num_partitions=1 - ) - backend = xla_bridge.get_backend() - executable1 = backend.compile(computation1, compile_options) - executable2 = backend.compile(computation2, compile_options) - cc.put_executable_and_time( - "key1", "computation1", executable1, backend, FAKE_COMPILE_TIME) - cc.put_executable_and_time( - "key2", "computation2", executable2, backend, FAKE_COMPILE_TIME) - self.assertNotEqual( - cc.get_executable_and_time("key1", compile_options, backend)[0], - cc.get_executable_and_time("key2", compile_options, backend)[0] - ) + computation1 = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()) + computation2 = str(jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()) + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + executable1 = backend.compile(computation1, compile_options) + executable2 = backend.compile(computation2, compile_options) + cc.put_executable_and_time( + "key1", "computation1", executable1, backend, FAKE_COMPILE_TIME) + cc.put_executable_and_time( + "key2", "computation2", executable2, backend, FAKE_COMPILE_TIME) + self.assertNotEqual( + cc.get_executable_and_time("key1", compile_options, backend)[0], + cc.get_executable_and_time("key2", compile_options, backend)[0] + ) def test_put_executable(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - computation = ( - jax.jit(lambda x, y: x + y) - .lower(np.int32(1), np.int32(1)) - .compiler_ir() - ) - devices = np.array([[jax.local_devices()[0]]]) - compile_options = compiler.get_compile_options( - num_replicas=1, num_partitions=1 - ) - backend = xla_bridge.get_backend() - executable = backend.compile(str(computation), compile_options) - key = cc.get_cache_key(computation, devices, compile_options, backend) - cc.put_executable_and_time( - key, 
"alambda", executable, backend, FAKE_COMPILE_TIME) - executable_retrieved, compile_time_retrieved = cc.get_executable_and_time( - key, compile_options, backend) - inputs_to_executable = ( - np.array(1, dtype=np.int32), - np.array(2, dtype=np.int32), - ) - expected = xla_client.execute_with_python_values( - executable, inputs_to_executable, backend - ) - actual = xla_client.execute_with_python_values( - executable_retrieved, inputs_to_executable, backend - ) - self.assertEqual(expected, actual) - self.assertEqual(FAKE_COMPILE_TIME, compile_time_retrieved) + computation = ( + jax.jit(lambda x, y: x + y) + .lower(np.int32(1), np.int32(1)) + .compiler_ir() + ) + devices = np.array([[jax.local_devices()[0]]]) + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + executable = backend.compile(str(computation), compile_options) + key = cc.get_cache_key(computation, devices, compile_options, backend) + cc.put_executable_and_time( + key, "alambda", executable, backend, FAKE_COMPILE_TIME) + executable_retrieved, compile_time_retrieved = cc.get_executable_and_time( + key, compile_options, backend) + inputs_to_executable = ( + np.array(1, dtype=np.int32), + np.array(2, dtype=np.int32), + ) + expected = xla_client.execute_with_python_values( + executable, inputs_to_executable, backend + ) + actual = xla_client.execute_with_python_values( + executable_retrieved, inputs_to_executable, backend + ) + self.assertEqual(expected, actual) + self.assertEqual(FAKE_COMPILE_TIME, compile_time_retrieved) def test_pmap(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - f = pmap(lambda x: x - lax.psum(x, "i"), axis_name="i") - x = np.arange(jax.device_count(), dtype=np.int64) - f(x) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 1) - x = np.arange(jax.device_count(), dtype=np.float32) - f(x) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 2) - # TODO: create a test for calling pmap with the same input more than once + f = pmap(lambda x: x - lax.psum(x, "i"), axis_name="i") + x = np.arange(jax.device_count(), dtype=np.int64) + f(x) + self.assertEqual(count_cache_items(), 1) + x = np.arange(jax.device_count(), dtype=np.float32) + f(x) + self.assertEqual(count_cache_items(), 2) + # TODO: create a test for calling pmap with the same input more than once def test_jit(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - f = jit(lambda x: x * x) - f(1) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 1) - f(1.0) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 2) + f = jit(lambda x: x * x) + f(1) + self.assertEqual(count_cache_items(), 1) + f(1.0) + self.assertEqual(count_cache_items(), 2) def test_xla_autofdo_profile_version(self): original_profile_version = config.jax_xla_profile_version.value - with (tempfile.TemporaryDirectory() as tmpdir, - config.jax_xla_profile_version(original_profile_version + 1)): - cc.set_cache_dir(tmpdir) + with config.jax_xla_profile_version(original_profile_version + 1): f = jit(lambda x: x * x) f(1) - files_in_cache_directory = os.listdir(tmpdir) - self.assertLen(files_in_cache_directory, 1) + self.assertEqual(count_cache_items(), 1) # Clear the cache directory, then update the profile version and execute # again. The in-memory caches should be invalidated and a new persistent # cache entry created. 
- os.unlink(os.path.join(tmpdir, files_in_cache_directory[0])) + clear_cache() with config.jax_xla_profile_version(original_profile_version + 2): f(1) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 1) + self.assertEqual(count_cache_items(), 1) @jtu.with_mesh([("x", 2)]) def test_pjit(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - - @partial(pjit, in_shardings=(P("x"), P("x")), out_shardings=None) - def f(x, y): - return x + y - - shape = (8, 8) - x = np.arange(math.prod(shape), dtype=np.int64).reshape(shape) - f(x, x + 1) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 1) - x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) - f(x, x + 1) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 2) + @partial(pjit, in_shardings=(P("x"), P("x")), out_shardings=None) + def f(x, y): + return x + y + + shape = (8, 8) + x = np.arange(math.prod(shape), dtype=np.int64).reshape(shape) + f(x, x + 1) + self.assertEqual(count_cache_items(), 1) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) + f(x, x + 1) + self.assertEqual(count_cache_items(), 2) @jtu.with_mesh([("x", 2)]) def test_xmap(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - - def f(x): - return x * 2 - - devices = np.array(jax.local_devices()[:2]) - if devices.size < 2: - raise SkipTest("Test requires 2 devices") - x = np.arange(8, dtype=np.int64).reshape((2, 2, 2)) - xmap( - f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"} - )(x) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 1) - x = np.arange(8, dtype=np.float32).reshape((2, 2, 2)) - xmap( - f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"} - )(x) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 2) + def f(x): + return x * 2 + + devices = np.array(jax.local_devices()[:2]) + if devices.size < 2: + raise SkipTest("Test requires 2 devices") + x = np.arange(8, dtype=np.int64).reshape((2, 2, 2)) + xmap( + f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"} + )(x) + self.assertEqual(count_cache_items(), 1) + x = np.arange(8, dtype=np.float32).reshape((2, 2, 2)) + xmap( + f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"} + )(x) + self.assertEqual(count_cache_items(), 2) def test_cache_write_warning(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - f = jit(lambda x: x * x) + f = jit(lambda x: x * x) - with ( - config.raise_persistent_cache_errors(False), - mock.patch.object(cc._get_cache().__class__, "put") as mock_put, - warnings.catch_warnings(record=True) as w, - ): - mock_put.side_effect = RuntimeError("test error") - self.assertEqual(f(2), 4) - self.assertLen(w, 1) - self.assertIn( - ( - "Error writing persistent compilation cache entry " - "for 'jit__lambda_': RuntimeError: test error" - ), - str(w[0].message), - ) + backend = xla_bridge.get_backend() + with ( + config.raise_persistent_cache_errors(False), + mock.patch.object(cc._get_cache(backend).__class__, "put") as mock_put, + warnings.catch_warnings(record=True) as w, + ): + mock_put.side_effect = RuntimeError("test error") + self.assertEqual(f(2).item(), 4) + if len(w) != 1: + print("Warnings:", [str(w_) for w_ in w], flush=True) + self.assertLen(w, 1) + self.assertIn( + ( + "Error writing persistent compilation cache entry " + "for 
'jit__lambda_': RuntimeError: test error" + ), + str(w[0].message), + ) def test_cache_read_warning(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - f = jit(lambda x: x * x) + f = jit(lambda x: x * x) - with ( - config.raise_persistent_cache_errors(False), - mock.patch.object(cc._get_cache().__class__, "get") as mock_get, - warnings.catch_warnings(record=True) as w, - ): - mock_get.side_effect = RuntimeError("test error") - self.assertEqual(f(2), 4) - if len(w) > 1: - print("Warnings:", [str(w_) for w_ in w], flush=True) - self.assertLen(w, 1) - self.assertIn( - ( - "Error reading persistent compilation cache entry " - "for 'jit__lambda_': RuntimeError: test error" - ), - str(w[0].message), - ) + backend = xla_bridge.get_backend() + with ( + config.raise_persistent_cache_errors(False), + mock.patch.object(cc._get_cache(backend).__class__, "get") as mock_get, + warnings.catch_warnings(record=True) as w, + ): + mock_get.side_effect = RuntimeError("test error") + # Calling assertEqual with the jitted f will generate two PJIT + # executables: Equal and the lambda function itself. + self.assertEqual(f(2).item(), 4) + if len(w) != 1: + print("Warnings:", [str(w_) for w_ in w], flush=True) + self.assertLen(w, 1) + self.assertIn( + ( + "Error reading persistent compilation cache entry " + "for 'jit__lambda_': RuntimeError: test error" + ), + str(w[0].message), + ) def test_min_entry_size(self): with ( - tempfile.TemporaryDirectory() as tmpdir, config.persistent_cache_min_compile_time_secs(0), config.persistent_cache_min_entry_size_bytes(1048576), # 1MiB ): - cc.set_cache_dir(tmpdir) - jit(lambda x: x + 1)(1) - files_in_cache = len(os.listdir(tmpdir)) - self.assertEqual(files_in_cache, 0) + self.assertEqual(count_cache_items(), 0) def test_min_compile_time(self): with ( - tempfile.TemporaryDirectory() as tmpdir, config.persistent_cache_min_compile_time_secs(2), config.persistent_cache_min_entry_size_bytes(0), ): - cc.set_cache_dir(tmpdir) - # Mock time to progress in small intervals so compilation time is small. with mock.patch("time.monotonic", side_effect=np.arange(0, 10, 0.1)): jit(lambda x: x + 1)(1) - files_in_cache = len(os.listdir(tmpdir)) - self.assertEqual(files_in_cache, 0) + self.assertEqual(count_cache_items(), 0) # Mock time to progress in large intervals so compilation time is large. with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)): jit(lambda x: x + 2)(1) - files_in_cache = len(os.listdir(tmpdir)) - self.assertEqual(files_in_cache, 1) + self.assertEqual(count_cache_items(), 1) # This is perhaps related to mocking time.monotonic? @unittest.skipIf(platform.system() == "Windows", "Test fails on Windows") def test_cache_saving_metric(self): with ( - tempfile.TemporaryDirectory() as tmpdir, config.persistent_cache_min_compile_time_secs(2), config.persistent_cache_min_entry_size_bytes(0), ): - cc.set_cache_dir(tmpdir) - durations = Counter() # Map metric name to time duration. 
def append_metric_duration(metric, duration): durations[metric] += duration @@ -354,29 +352,24 @@ def append_metric_duration(metric, duration): durations["/jax/compilation_cache/compile_time_saved_sec"], 0) def test_task_using_cache_metric(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - count_before_first_use = _counts[ - "/jax/compilation_cache/tasks_using_cache"] - jit(lambda x: x + 1)(1) - count_after_first_use = _counts[ - "/jax/compilation_cache/tasks_using_cache"] - self.assertEqual(count_after_first_use, count_before_first_use + 1) - - # Verify that the count is incremented only once per task. - jit(lambda x: x + 3)(3) - count_after_second_use = _counts[ - "/jax/compilation_cache/tasks_using_cache"] - self.assertEqual(count_after_second_use, count_after_first_use) + count_before_first_use = _counts[ + "/jax/compilation_cache/tasks_using_cache"] + jit(lambda x: x + 1)(1) + count_after_first_use = _counts[ + "/jax/compilation_cache/tasks_using_cache"] + self.assertEqual(count_after_first_use, count_before_first_use + 1) + + # Verify that the count is incremented only once per task. + jit(lambda x: x + 3)(3) + count_after_second_use = _counts[ + "/jax/compilation_cache/tasks_using_cache"] + self.assertEqual(count_after_second_use, count_after_first_use) def test_compile_requests_use_cache_metric(self): previous_counts = Counter(_counts) - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - - jit(lambda x: x + 1)(1) - jit(lambda x: x + 2)(1) - jit(lambda x: x + 1)(1) + jit(lambda x: x + 1)(1) + jit(lambda x: x + 2)(1) + jit(lambda x: x + 1)(1) self.assertEqual( _counts["/jax/compilation_cache/compile_requests_use_cache"] @@ -387,12 +380,9 @@ def test_compile_requests_use_cache_metric(self): def test_cache_misses_metric(self, min_entry_size): previous_counts = Counter(_counts) with ( - tempfile.TemporaryDirectory() as tmpdir, config.persistent_cache_min_compile_time_secs(2), config.persistent_cache_min_entry_size_bytes(min_entry_size), ): - cc.set_cache_dir(tmpdir) - # Mock time to create a long compilation time and make cache misses. with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)): jit(lambda x: x + 1)(1) @@ -412,12 +402,9 @@ def test_cache_misses_metric(self, min_entry_size): def test_cache_hits_metric(self): previous_counts = Counter(_counts) with ( - tempfile.TemporaryDirectory() as tmpdir, config.persistent_cache_min_compile_time_secs(2), config.persistent_cache_min_entry_size_bytes(0), ): - cc.set_cache_dir(tmpdir) - # Mock time to create a long compilation time, cache saved. 
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)): jit(lambda x: x + 1)(1) @@ -431,38 +418,39 @@ def test_cache_hits_metric(self): @parameterized.parameters(0, 1) def test_cache_write_with_process_restriction(self, process_id): with ( - tempfile.TemporaryDirectory() as tmpdir, config.persistent_cache_min_compile_time_secs(0), config.persistent_cache_min_entry_size_bytes(0), mock.patch.object(distributed.global_state, "process_id", process_id), ): - cc.set_cache_dir(tmpdir) - jit(lambda x: x + 1)(1) - files_in_directory = len(os.listdir(tmpdir)) + files_in_directory = count_cache_items() if process_id == 0: self.assertEqual(files_in_directory, 1) elif process_id == 1: self.assertEqual(files_in_directory, 0) + def test_backend_serialization_deserialization(self): + backend = xla_bridge.get_backend() + executable = ( + jax.jit(lambda x, y: x + y) + .lower(np.array(1.), np.array(1.)) + .compile() + .runtime_executable() + ) + serialized_executable = backend.serialize_executable(executable) + deserialized_executable = backend.deserialize_executable( + serialized_executable, None) + self.assertEqual( + executable.fingerprint, deserialized_executable.fingerprint) + @jtu.with_config( jax_enable_compilation_cache=False, jax_persistent_cache_min_compile_time_secs=0, jax_persistent_cache_min_entry_size_bytes=0, ) -class CompilationCacheDisabledTest(jtu.JaxTestCase): - - def setUp(self): - super().setUp() - - cc.reset_cache() - - def tearDown(self): - cc.reset_cache() - super().tearDown() - +class CompilationCacheDisabledTest(CompilationCacheTestCase): # If the cache is disabled, there should be no files in the cache directory. # A call to set_cache_dir() does not affect this. def test_jit(self): @@ -471,15 +459,10 @@ def test_jit(self): # 2. Flag is enabled by JaxTestCase for some test configs # (see test_util.py). # We need the flag disabled for this test, so disable it below. - with ( - tempfile.TemporaryDirectory() as tmpdir, - config.enable_compilation_cache(False), - ): - cc.set_cache_dir(tmpdir) + with config.enable_compilation_cache(False): f = jit(lambda x: x * x) f(1) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 0) + self.assertEqual(count_cache_items(), 0) if __name__ == "__main__": diff --git a/tests/config_test.py b/tests/config_test.py new file mode 100644 index 000000000000..0f49d988a46c --- /dev/null +++ b/tests/config_test.py @@ -0,0 +1,73 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest + +import jax +from jax._src import test_util as jtu +from jax._src import config + +jax.config.parse_flags_with_absl() + + +jax_test_enum_config = config.enum_state( + name='jax_test_enum_config', + enum_values=['default', 'xxx', 'yyy'], + default='default', + help='Configuration only used for tests.') + + +class ConfigTest(jtu.JaxTestCase): + def test_config_setting_via_update(self): + self.assertEqual(jax_test_enum_config.value, 'default') + + jax.config.update('jax_test_enum_config', 'xxx') + self.assertEqual(jax_test_enum_config.value, 'xxx') + + jax.config.update('jax_test_enum_config', 'yyy') + self.assertEqual(jax_test_enum_config.value, 'yyy') + + jax.config.update('jax_test_enum_config', 'default') + self.assertEqual(jax_test_enum_config.value, 'default') + + def test_config_setting_via_context(self): + self.assertEqual(jax_test_enum_config.value, 'default') + + with jax_test_enum_config('xxx'): + self.assertEqual(jax_test_enum_config.value, 'xxx') + + with jax_test_enum_config('yyy'): + self.assertEqual(jax_test_enum_config.value, 'yyy') + + self.assertEqual(jax_test_enum_config.value, 'xxx') + + self.assertEqual(jax_test_enum_config.value, 'default') + + def test_config_update_validation(self): + self.assertEqual(jax_test_enum_config.value, 'default') + with self.assertRaisesRegex(ValueError, 'new enum value must be in.*'): + jax.config.update('jax_test_enum_config', 'invalid') + # Error should raise before changing the value + self.assertEqual(jax_test_enum_config.value, 'default') + + def test_config_context_validation(self): + self.assertEqual(jax_test_enum_config.value, 'default') + with self.assertRaisesRegex(ValueError, 'new enum value must be in.*'): + with jax_test_enum_config('invalid'): + pass + self.assertEqual(jax_test_enum_config.value, 'default') + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/core_test.py b/tests/core_test.py index 788f61db943d..c75fa3614559 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -406,7 +406,7 @@ def test_check_jaxpr_cond_correct(self): def test_check_jaxpr_jit_invalid(self): jaxpr = make_jaxpr(jax.jit(lambda x, y: x + 1))(1., 2.).jaxpr pjit_eqn, = jaxpr.eqns - jaxpr._eqns[0] = pjit_eqn._replace(invars=()) + jaxpr._eqns[0] = pjit_eqn.replace(invars=()) self.assertRaisesRegex( core.JaxprTypeError, '0 operands cannot call jaxpr with 2 inputs', @@ -750,16 +750,12 @@ def g(x): return x core.check_jaxpr(jaxpr) def test_check_jaxpr_key_reuse(self): - with config.enable_key_reuse_checks(True): - try: - from jax.experimental.key_reuse import KeyReuseError - except ImportError: - self.skipTest("Test requires jax.experimental.key_reuse") + with config.debug_key_reuse(True): def f(seed): key = jax.random.key(seed) return jax.random.uniform(key) + jax.random.normal(key) with jax.enable_checks(True): - with self.assertRaises(KeyReuseError): + with self.assertRaises(jax.errors.KeyReuseError): jax.jit(f)(0) diff --git a/tests/custom_linear_solve_test.py b/tests/custom_linear_solve_test.py index 2c3d2a258a56..830526826059 100644 --- a/tests/custom_linear_solve_test.py +++ b/tests/custom_linear_solve_test.py @@ -28,8 +28,7 @@ import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def high_precision_dot(a, b): diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py index 036f912a3c7a..75ff39630705 100644 --- 
a/tests/custom_object_test.py +++ b/tests/custom_object_test.py @@ -18,9 +18,9 @@ import numpy as np +import jax import jax.numpy as jnp from jax import jit, lax, make_jaxpr -from jax import config from jax.interpreters import mlir from jax.interpreters import xla @@ -34,7 +34,7 @@ xc = xla_client xb = xla_bridge -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # TODO(jakevdp): use a setup/teardown method to populate and unpopulate all the # dictionaries associated with the following objects. diff --git a/tests/custom_root_test.py b/tests/custom_root_test.py index 88dee90aad9c..6a7eaab17657 100644 --- a/tests/custom_root_test.py +++ b/tests/custom_root_test.py @@ -25,8 +25,7 @@ import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def high_precision_dot(a, b): diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index ff49abba8798..1cc342e26784 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -23,23 +23,15 @@ from jax._src import api from jax._src import test_util as jtu from jax import numpy as jnp -from jax.experimental import pjit, maps +from jax.experimental import pjit +from jax._src.maps import xmap -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() +@jtu.with_config(jax_debug_nans=True) class DebugNaNsTest(jtu.JaxTestCase): - def setUp(self): - super().setUp() - self.cfg = config._read("jax_debug_nans") - config.update("jax_debug_nans", True) - - def tearDown(self): - config.update("jax_debug_nans", self.cfg) - super().tearDown() - def testSinc(self): # Regression test for #6936 self.assertEqual(jnp.sinc(0.0), 1.0) @@ -65,8 +57,8 @@ def testJitComputationNaN(self): ans = jax.jit(lambda x: 0. / x)(A) ans.block_until_ready() + @jax.debug_nans(False) def testJitComputationNaNContextManager(self): - config.update("jax_debug_nans", False) A = jnp.array(0.) f = jax.jit(lambda x: 0. / x) ans = f(A) @@ -125,7 +117,7 @@ def testPmapNoNaN(self): @jtu.ignore_warning(message=".*is an experimental.*") def testXmap(self): - f = maps.xmap( + f = xmap( lambda x: 0. / x, in_axes=["i"], out_axes=["i"], @@ -205,17 +197,9 @@ def f(x, y): jax.jit(f)(inp, inp) +@jtu.with_config(jax_debug_infs=True) class DebugInfsTest(jtu.JaxTestCase): - def setUp(self): - super().setUp() - self.cfg = config._read("jax_debug_infs") - config.update("jax_debug_infs", True) - - def tearDown(self): - config.update("jax_debug_infs", self.cfg) - super().tearDown() - def testSingleResultPrimitiveNoInf(self): A = jnp.array([[1., 2.], [2., 3.]]) ans = jnp.tanh(A) diff --git a/tests/debugger_test.py b/tests/debugger_test.py index 0faebc668c0b..18693a7bb2c3 100644 --- a/tests/debugger_test.py +++ b/tests/debugger_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
from collections.abc import Sequence +import contextlib import io import re import textwrap @@ -21,14 +22,13 @@ from absl.testing import absltest import jax -from jax import config from jax.experimental import pjit from jax._src import debugger from jax._src import test_util as jtu import jax.numpy as jnp import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]: fake_stdin = io.StringIO() @@ -41,16 +41,13 @@ def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringI def _format_multiline(text): return textwrap.dedent(text).lstrip() -prev_xla_flags = None +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(2) + _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - prev_xla_flags() + _exit_stack.close() foo = 2 @@ -111,7 +108,7 @@ def f(x): return y expected = _format_multiline(r""" Entering jdb: - (jdb) array(2., dtype=float32) + (jdb) Array(2., dtype=float32) (jdb) """) f(jnp.array(2., jnp.float32)) jax.effects_barrier() @@ -127,7 +124,7 @@ def f(x): return y expected = _format_multiline(r""" Entering jdb: - (jdb) (array(2., dtype=float32), array(3., dtype=float32)) + (jdb) (Array(2., dtype=float32), Array(3., dtype=float32)) (jdb) """) f(jnp.array(2., jnp.float32)) jax.effects_barrier() @@ -197,7 +194,7 @@ def g\(x\): -> y = f\(x\) return jnp\.exp\(y\) .* - \(jdb\) array\(2\., dtype=float32\) + \(jdb\) Array\(2\., dtype=float32\) \(jdb\) > .*debugger_test\.py\([0-9]+\) def f\(x\): y = jnp\.sin\(x\) @@ -226,9 +223,9 @@ def g(x): return jnp.exp(y) expected = _format_multiline(r""" Entering jdb: - (jdb) array(3., dtype=float32) + (jdb) Array(3., dtype=float32) (jdb) Entering jdb: - (jdb) array(6., dtype=float32) + (jdb) Array(6., dtype=float32) (jdb) """) g(jnp.array(2., jnp.float32)) jax.effects_barrier() @@ -237,14 +234,9 @@ def g(x): def test_debugger_works_with_vmap(self): stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"]) - # On TPU, the breakpoints can be reordered inside of vmap but can be fixed - # by ordering sends. - # TODO(sharadmv): change back to ordered = False when sends are ordered - ordered = jax.default_backend() == "tpu" - def f(x): y = x + 1. - debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=ordered, + debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True, backend="cli") return 2. 
* y @@ -255,9 +247,9 @@ def g(x): return jnp.exp(y) expected = _format_multiline(r""" Entering jdb: - (jdb) array(1., dtype=float32) + (jdb) Array(1., dtype=float32) (jdb) Entering jdb: - (jdb) array(2., dtype=float32) + (jdb) Array(2., dtype=float32) (jdb) """) g(jnp.arange(2., dtype=jnp.float32)) jax.effects_barrier() @@ -280,9 +272,9 @@ def g(x): return jnp.exp(y) expected = _format_multiline(r""" Entering jdb: - \(jdb\) array\(.*, dtype=float32\) + \(jdb\) Array\(.*, dtype=float32\) \(jdb\) Entering jdb: - \(jdb\) array\(.*, dtype=float32\) + \(jdb\) Array\(.*, dtype=float32\) \(jdb\) """) g(jnp.arange(2., dtype=jnp.float32)) jax.effects_barrier() @@ -308,7 +300,7 @@ def g(x): out_shardings=jax.sharding.PartitionSpec("dev"), ) with jax.sharding.Mesh(np.array(jax.devices()), ["dev"]): - arr = (1 + np.arange(8)).astype(np.int32) + arr = (1 + jnp.arange(8)).astype(np.int32) expected = _format_multiline(r""" Entering jdb: \(jdb\) {} diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index f61417c16af0..c00253385792 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections +import contextlib import functools import textwrap import unittest @@ -19,14 +20,13 @@ from absl.testing import absltest import jax from jax import lax -from jax import config -from jax.experimental import maps from jax.experimental import pjit from jax.interpreters import pxla from jax._src import ad_checkpoint from jax._src import debugging from jax._src import dispatch from jax._src import test_util as jtu +from jax._src.maps import xmap import jax.numpy as jnp import numpy as np @@ -35,23 +35,20 @@ except ModuleNotFoundError: rich = None -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() debug_print = debugging.debug_print def _format_multiline(text): return textwrap.dedent(text).lstrip() -prev_xla_flags = None +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(2) + _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) -# Reset to previous configuration in case other test modules will be run. 
def tearDownModule(): - prev_xla_flags() + _exit_stack.close() class DummyDevice: def __init__(self, platform, id): @@ -171,7 +168,7 @@ def f(x): with jtu.capture_stdout() as output: f(np.array(2, np.int32)) jax.effects_barrier() - self.assertEqual(output(), f"x: {str(dict(foo=np.array(2, np.int32)))}\n") + self.assertEqual(output(), f"x: {str(dict(foo=jnp.array(2, np.int32)))}\n") def test_debug_print_should_use_default_layout(self): data = np.array( @@ -795,7 +792,7 @@ def foo(x): idx = lax.axis_index('foo') debug_print("{idx}: {x}", idx=idx, x=x) return jnp.mean(x, axis=['foo']) - out = maps.xmap(foo, in_axes=['foo'], out_axes=[...])(x) + out = xmap(foo, in_axes=['foo'], out_axes=[...])(x) debug_print("Out: {}", out) return out mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev']) @@ -808,13 +805,14 @@ def foo(x): lines = ["0: 0", "1: 2", "2: 4", "3: 6", "4: 8", "5: 10", "6: 12", "7: 14", "Out: 7.0", ""] jax.effects_barrier() - self._assertLinesEqual(output(), "\n".join(lines)) + + self._assertLinesEqual(output(), "\n".join(lines)) def test_unordered_print_with_xmap(self): def f(x): debug_print("{}", x, ordered=False) - f = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu', - axis_resources={'a': 'dev'}) + f = xmap(f, in_axes=['a'], out_axes=None, backend='cpu', + axis_resources={'a': 'dev'}) with jax.sharding.Mesh(np.array(jax.devices()), ['dev']): with jtu.capture_stdout() as output: f(np.arange(40)) diff --git a/tests/deprecation_test.py b/tests/deprecation_test.py index 02d9c427c9a1..2476a6208e91 100644 --- a/tests/deprecation_test.py +++ b/tests/deprecation_test.py @@ -15,12 +15,13 @@ import warnings from absl.testing import absltest +from jax._src import deprecations from jax._src import test_util as jtu from jax._src.internal_test_util import deprecation_module as m class DeprecationTest(absltest.TestCase): - def testDeprecation(self): + def testModuleDeprecation(self): with warnings.catch_warnings(): warnings.simplefilter("error") self.assertEqual(m.x, 42) @@ -35,6 +36,24 @@ def testDeprecation(self): "module .* has no attribute 'w'"): _ = m.w + def testNamedDeprecation(self): + some_unique_id = "some-unique-id" + try: + deprecations.register(some_unique_id) + self.assertFalse(deprecations.is_accelerated(some_unique_id)) + deprecations.accelerate(some_unique_id) + self.assertTrue(deprecations.is_accelerated(some_unique_id)) + finally: + deprecations.unregister(some_unique_id) + + msg = f"deprecation_id={some_unique_id!r} not registered" + with self.assertRaisesRegex(ValueError, msg): + deprecations.accelerate(some_unique_id) + with self.assertRaisesRegex(ValueError, msg): + deprecations.is_accelerated(some_unique_id) + with self.assertRaisesRegex(ValueError, msg): + deprecations.unregister(some_unique_id) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 4f563876cb91..4712e6aec652 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -27,6 +27,7 @@ import jax from jax import numpy as jnp +from jax._src import earray from jax._src import config from jax._src import dtypes from jax._src import test_util as jtu @@ -66,8 +67,8 @@ all_dtypes = (bool_dtypes + signed_dtypes + unsigned_dtypes + float_dtypes + complex_dtypes) -scalar_types = [jnp.bool_, jnp.int8, jnp.int16, jnp.int32, jnp.int64, - jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64, +scalar_types = [jnp.bool_, jnp.int4, jnp.int8, jnp.int16, jnp.int32, jnp.int64, + jnp.uint4, jnp.uint8, jnp.uint16, jnp.uint32, 
jnp.uint64, jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64, jnp.complex64, jnp.complex128] @@ -93,6 +94,7 @@ _EXPECTED_CANONICALIZE_X32[np.longlong] = np.int32 UINT_DTYPES = { + 4: jnp.uint4, 8: np.uint8, 16: np.uint16, 32: np.uint32, @@ -283,13 +285,15 @@ def testIsSubdtype(self): self.assertTrue(dtypes.issubdtype(np.dtype(t).type, t)) self.assertTrue(dtypes.issubdtype(t, np.dtype(t).type)) self.assertTrue(dtypes.issubdtype(t, np.dtype(t))) - if t != jnp.bfloat16: - for category in [np.generic, jnp.inexact, jnp.integer, jnp.signedinteger, - jnp.unsignedinteger, jnp.floating, jnp.complexfloating]: - self.assertEqual(dtypes.issubdtype(t, category), - np.issubdtype(np.dtype(t).type, category)) - self.assertEqual(dtypes.issubdtype(t, category), - np.issubdtype(np.dtype(t).type, category)) + if t in [jnp.int4, jnp.uint4, jnp.bfloat16]: + # These dtypes have no equivalent in NumPy. + continue + for category in [np.generic, jnp.inexact, jnp.integer, jnp.signedinteger, + jnp.unsignedinteger, jnp.floating, jnp.complexfloating]: + self.assertEqual(dtypes.issubdtype(t, category), + np.issubdtype(np.dtype(t).type, category)) + self.assertEqual(dtypes.issubdtype(t, category), + np.issubdtype(np.dtype(t).type, category)) def testIsSubdtypeExtended(self): self.assertTrue(dtypes.issubdtype(dtypes.extended, dtypes.extended)) @@ -554,6 +558,73 @@ def inner_bwd(prev_scale, grads): _, new_scale = jax.jit(jax.grad(outer, (0, 1)))(jnp.float32(3.14), scale) self.assertAllClose(new_scale, jnp.float32(1.0)) + def test_check_dtype_non_hashable(self): + # regression test for issue with checking non-hashable custom dtype + class MyDtype: + __hash__ = None + dtype = np.dtype('float32') + dtypes.check_user_dtype_supported(MyDtype()) + + def test_check_dtype_array(self): + x = jnp.arange(4) + msg = "Passing an array as a dtype argument is deprecated" + with self.assertWarnsRegex(DeprecationWarning, msg): + dtypes.check_user_dtype_supported(x) + with self.assertWarnsRegex(DeprecationWarning, msg): + jax.jit(dtypes.check_user_dtype_supported)(x) + + +class EArrayTest(jtu.JaxTestCase): + + @parameterized.parameters([True, False]) + def test_extended_dtypes_at_rest(self, jit): + # Test a trivial isomorphic-to-float32 extended dtype working with EArray + from jax._src import core + from jax._src.interpreters import pxla + + class foo(dtypes.extended): pass + + class FooTyRules: + + @staticmethod + def convert_to(foo_dtype, target_dtype): + return True + + @staticmethod + def physical_element_aval(foo_dtype): + return core.ShapedArray((), dtypes.dtype('float32')) + + @staticmethod + def global_sharded_result_handler(aval, out_sharding, committed): + phys_sharding = out_sharding # unlike KeyTyRules, assume same shape + phys_aval = core.physical_aval(aval) + phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] + phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) + return lambda bufs: earray.EArray(aval, phys_handler(bufs)) + + @dataclasses.dataclass(frozen=True) + class FooTy(dtypes.ExtendedDType): + name: str = 'foo' + _rules: type = FooTyRules + type: type = foo + + # Can we make one? + def f(x): + return jax.lax.convert_element_type(x, FooTy()) + if jit: + f = jax.jit(f) + x = f(jnp.arange(3, dtype='float32')) # don't crash + self.assertIsInstance(x.dtype, FooTy) + + # Can we consume one?
+ def g(x): + self.assertIsInstance(x.dtype, FooTy) + return x + if jit: + g = jax.jit(g) + y = g(x) + self.assertIsInstance(y.dtype, FooTy) + class TestPromotionTables(jtu.JaxTestCase): diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index c704d7e10b16..7aa1e18b2fce 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -23,7 +23,6 @@ import jax import jax.numpy as jnp from jax import lax -from jax import config from jax.interpreters import batching import jax._src.lib @@ -31,7 +30,7 @@ from jax._src import core from jax._src import test_util as jtu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() @jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") @@ -233,7 +232,7 @@ def test_closing_over_dynamic_shape(self): def f(n): m = 2 * n x = jnp.zeros(m) - return jax.jit(lambda: x)() + return jax.jit(jnp.sin)(x) # { lambda ; a:i32[]. let # b:i32[] = mul a 2 @@ -518,7 +517,7 @@ def test_jit_abstracted_axes_staging3(self): self.assertIs(e.aval.shape[0], d) def test_jit_abstracted_axes_return_polymorphic_shape(self): - f = jax.jit(lambda x: x, abstracted_axes=('n',)) + f = jax.jit(lambda x: jnp.sin(x), abstracted_axes=('n',)) jaxpr = jax.make_jaxpr(f)(jnp.arange(3)) # doesn't crash # { lambda ; a:i32[3]. let # b:i32[3] = pjit[ @@ -629,6 +628,20 @@ def test_shape_validation(self): with self.assertRaisesRegex(TypeError, msg): jax.make_jaxpr(jnp.ones)(jnp.ones((2, 2))) + def test_matmul_two_arg(self): + def f(x, y): + return jnp.matmul(x, y) + + jaxpr = jax.make_jaxpr(f, abstracted_axes=({0: 'a_0', 1: 'a_1'}, {0: 'a_1', 1: 'a_2'}),)(jnp.ones((8, 4)), jnp.ones((4, 8))) + + def test_matmul_two_arg_size_mismatch_name_validation(self): + def f(x, y): + return jnp.matmul(x, y) + + with self.assertRaisesRegex(TypeError, + re.escape("Provided size 4 for a_1 does not match prior associated name for a_1 : 8")): + jaxpr = jax.make_jaxpr(f, abstracted_axes=({0: 'a_0', 1: 'a_1'}, {0: 'a_1', 1: 'a_2'}),)(jnp.ones((8, 4)), jnp.ones((8, 4))) + @unittest.skip("Test does not work with jax.Array") @jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") class DynamicShapeAutodiffTest(jtu.JaxTestCase): @@ -1078,15 +1091,13 @@ def f(x, y): @unittest.skip("Test does not work with jax.Array") @jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") class DynamicShapeExecutionTest(jtu.JaxTestCase): - @jtu.run_on_devices("iree") - def test_jit_basic_iree(self): + def test_jit_basic(self): @jax.jit def f(i): return jnp.sum(jnp.ones(i, dtype='float32')) self.assertAllClose(f(3), jnp.array(3., dtype='float32'), check_dtypes=True) - @jtu.run_on_devices("iree") - def test_jit_basic_iree_2(self): + def test_jit_basic_2(self): count = 0 @partial(jax.jit, abstracted_axes=('n',)) @@ -1101,9 +1112,8 @@ def f(x): self.assertAllClose(y, 6., check_dtypes=False) self.assertEqual(count, 1) - @jtu.run_on_devices("iree") - def test_jit_polymorphic_output_iree(self): - # like test_jit_basic_iree, but without the jnp.sum! + def test_jit_polymorphic_output(self): + # like test_jit_basic, but without the jnp.sum! 
count = 0 @jax.jit @@ -1125,7 +1135,6 @@ def f(x): # x: f32[n, 4] f(np.ones((5, 4), dtype=np.float32)) # TODO: add assertions - @jtu.run_on_devices("iree") def test_reshape(self): @partial(jax.jit, abstracted_axes=({0: 'n'},)) def f(x): # x: f32[n, 4] @@ -1134,7 +1143,6 @@ def f(x): # x: f32[n, 4] f(np.ones((5, 4), dtype=np.float32)) # TODO: add assertions - @jtu.run_on_devices("iree") def test_nested(self): @jax.jit def nested_f(x): # f32[h, v] -> f32[h, v] @@ -1147,7 +1155,6 @@ def f(x): # f32[h, w] -> f32[h, w] f(np.ones((3, 5), dtype=np.float32)) # TODO: add assertions - @jtu.run_on_devices("iree") def test_nested_arange(self): def nested_f(x): # f32[h, v] -> f32[h, v] # A nested call that needs to compute with shapes @@ -1159,7 +1166,6 @@ def f(x): # f32[h, w] -> f32[h, w] f(np.ones((3, 5), dtype=np.float32)) # TODO: add assertions - @jtu.run_on_devices("iree") def test_transpose(self): # see also https://github.com/iree-org/iree-jax/issues/57 @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},)) @@ -1169,7 +1175,6 @@ def f(x): # f32[h, w] -> f32[w, h] f(np.ones((3, 5), dtype=np.float32)) # doesn't crash # TODO: add assertions - @jtu.run_on_devices("iree") def test_matmul(self): @partial(jax.jit, abstracted_axes=({0: 'w', 1: 'w'},)) def f(x): # f32[w, w] -> f32[w, w] @@ -1178,7 +1183,6 @@ def f(x): # f32[w, w] -> f32[w, w] f(np.ones((5, 5), dtype=np.float32)) # TODO: add assertions - @jtu.run_on_devices("iree") def test_matmul_shape_error(self): @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},)) def f(x): # f32[h, w] -> error @@ -1189,7 +1193,6 @@ def f(x): # f32[h, w] -> error re.escape("dot_general requires contracting dimensions to have the same shape, got")): f(np.ones((5, 5), dtype=np.float32)) - @jtu.run_on_devices("iree") @unittest.skip("TODO: investigate failure") def test_cond(self): @partial(jax.jit, abstracted_axes=({0: 'w', 1: 'w'},)) @@ -1200,7 +1203,6 @@ def f(x): # f32[w, w] -> f32[w, w] f(np.ones((5, 5), dtype=np.float32)) # TODO: add assertions - @jtu.run_on_devices("iree") def test_arange(self): @partial(jax.jit, abstracted_axes=({0: 'w'},)) def f(x): # f32[w] -> f32[w] @@ -1208,8 +1210,6 @@ def f(x): # f32[w] -> f32[w] f(np.ones((5,), dtype=np.float32)) # TODO: add assertions - @unittest.skip('failing w/ iree error') - @jtu.run_on_devices("iree") def test_broadcast(self): @partial(jax.jit, abstracted_axes=({0: 'w'},)) def f(x): # f32[w] -> f32[w, w] @@ -1217,7 +1217,6 @@ def f(x): # f32[w] -> f32[w, w] f(np.ones((5,), dtype=np.float32)) # TODO: add assertions - @jtu.run_on_devices("iree") def test_zeros(self): @partial(jax.jit, abstracted_axes=({0: 'w'},)) def f(x): # f32[w] -> f32[w] @@ -1225,8 +1224,6 @@ def f(x): # f32[w] -> f32[w] f(np.ones((5,), dtype=np.float32)) # TODO: add assertions - @unittest.skip('failing w/ iree error') - @jtu.run_on_devices("iree") def test_stack(self): @partial(jax.jit, abstracted_axes=({0: 'w'},)) def f(x): @@ -1235,8 +1232,7 @@ def f(x): f(np.ones((5,), dtype=np.float32)) # TODO: add assertions - @jtu.run_on_devices("iree") - def test_jit_dependent_pair_output_iree(self): + def test_jit_dependent_pair_output(self): # Like the above 'polymorhpic output' test, but now with a `2 * n`! 
count = 0 @@ -1276,18 +1272,15 @@ def body(i, _): expected = jnp.cumsum(x) self.assertAllClose(ans, expected, check_dtypes=False) - @jtu.run_on_devices("iree") def test_jit_of_broadcast(self): x = jax.jit(jnp.ones)(3) self.assertAllClose(x, jnp.ones(3)) - @jtu.run_on_devices("iree") def test_jit_of_broadcast2(self): x = jax.jit(lambda n: jnp.ones(2 * n))(3) self.assertAllClose(x, jnp.ones(2 * 3)) - @jtu.run_on_devices("iree") - def test_mlp_autodiff_dynamic_batch_iree(self): + def test_mlp_autodiff_dynamic_batch(self): count = 0 def predict(params, inputs): @@ -1467,7 +1460,6 @@ def f(x): return x[0] f.lower(jnp.zeros((3, 4))).compiler_ir() # doesn't crash - @jtu.run_on_devices("iree") def test_slicing_basic_execute(self): @partial(jax.jit, abstracted_axes=(None, 'n')) def f(x): diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 8a44ac1699e9..d510aab9d789 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -19,6 +19,7 @@ import dataclasses from functools import partial import itertools +import logging import math from absl.testing import absltest, parameterized @@ -27,10 +28,10 @@ import jax from jax import lax -from jax.experimental.export import _export +from jax._src.export import _export + from jax._src.internal_test_util import export_back_compat_test_util as bctu -from jax._src.internal_test_util.export_back_compat_test_data import cpu_ducc_fft from jax._src.internal_test_util.export_back_compat_test_data import cpu_cholesky_lapack_potrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_eig_lapack_geev from jax._src.internal_test_util.export_back_compat_test_data import cuda_eigh_cusolver_syev @@ -50,6 +51,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import tpu_stablehlo_dynamic_reduce_window from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_rng_bit_generator from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_top_k +from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_approx_top_k from jax.experimental import pjit from jax.experimental.shard_map import shard_map @@ -60,13 +62,20 @@ from jax._src import config from jax._src import test_util as jtu -from jax._src.lib import version as jaxlib_version +from jax._src.lib import cuda_versions +from jax._src.lib import xla_extension_version config.parse_flags_with_absl() +def _is_required_cusolver_version_satisfied(required_version): + if cuda_versions is None: + return False + return cuda_versions.cusolver_get_version() >= required_version -@jtu.with_config(jax_legacy_prng_key='allow', - jax_enable_key_reuse_checks=False) +@jtu.with_config(jax_legacy_prng_key="allow", + jax_debug_key_reuse=False, + jax_include_full_tracebacks_in_locations=False, + jax_threefry_gpu_kernel_lowering=True) class CompatTest(bctu.CompatTestBase): def test_dummy(self): # Tests the testing mechanism. 
Let this test run on all platforms @@ -101,7 +110,6 @@ def test_custom_call_coverage(self): # Add here all the testdatas that should cover the targets guaranteed # stable covering_testdatas = [ - cpu_ducc_fft.data_2023_03_17, cpu_ducc_fft.data_2023_06_14, cpu_cholesky_lapack_potrf.data_2023_06_19, cpu_eig_lapack_geev.data_2023_06_19, cpu_eigh_lapack_syev.data_2023_03_17, @@ -119,6 +127,7 @@ def test_custom_call_coverage(self): stablehlo_dynamic_rng_bit_generator.data_2023_06_17, stablehlo_dynamic_top_k.data_2023_07_16, stablehlo_dynamic_top_k.data_2023_08_11, # with shape_assertion + stablehlo_dynamic_approx_top_k.data_2024_05_30, ] # Some of the above are nested structures. covering_testdatas = itertools.chain( @@ -131,6 +140,8 @@ def test_custom_call_coverage(self): covered_targets = covered_targets.union({ "tf.call_tf_function", # tested in jax2tf/tests/back_compat_tf_test.py "tpu_custom_call", # tested separately + "__gpu$xla.gpu.triton", # tested in pallas/export_back_compat_pallas_test.py + "cu_threefry2x32_ffi", # TODO(b/338022728) add the actual backwards compatibility test }) not_covered = targets_to_cover.difference(covered_targets) self.assertEmpty(not_covered, @@ -138,28 +149,6 @@ def test_custom_call_coverage(self): "stable but are not covered by any tests: " f"{not_covered}")) - def test_ducc_fft(self): - def func(x): - return lax.fft(x, fft_type="fft", fft_lengths=(4,)) - - # An old lowering, with ducc_fft. We keep it for 6 months. - data = self.load_testdata(cpu_ducc_fft.data_2023_03_17) - if jaxlib_version <= (0, 4, 20): - expect_current_custom_calls = ["dynamic_ducc_fft"] - else: - # We have changed the lowering for fft since we saved this data. - # FFT no longer lowers to a custom call. - expect_current_custom_calls = [] - - self.run_one_test(func, data, - expect_current_custom_calls=expect_current_custom_calls) - - # A newer lowering, with dynamic_ducc_fft. - data = self.load_testdata(cpu_ducc_fft.data_2023_06_14) - # FFT no longer lowers to a custom call. - self.run_one_test(func, data, - expect_current_custom_calls=expect_current_custom_calls) - def cholesky_input(self, shape, dtype): a = jtu.rand_default(self.rng())(shape, dtype) return np.matmul(a, np.conj(np.swapaxes(a, -1, -2))) @@ -303,6 +292,11 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): def test_cuda_eigh_cusolver_syev(self, dtype_name="f32", variant="syevj"): if not config.enable_x64.value and dtype_name == "f64": self.skipTest("Test disabled for x32 mode") + if (jtu.test_device_matches(["cuda"]) and + _is_required_cusolver_version_satisfied(11600)): + # The underlying problem is that this test assumes the workspace size can be + # queried from an older version of cuSOLVER and then be used in a newer one. 
+ self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized") # For lax.linalg.eigh dtype = dict(f32=np.float32, f64=np.float64)[dtype_name] size = dict(syevj=8, syevd=36)[variant] @@ -586,6 +580,8 @@ def func(): self.run_one_test(func, data) def test_cuda_threefry2x32(self): + logging.info("test_cuda_threefry2x32: xla_extension_version: %s", + xla_extension_version) def func(x): return jax.random.uniform(x, (2, 4), dtype=np.float32) @@ -595,7 +591,7 @@ def func(x): def test_sharding(self): # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2: - self.skipTest("Test runs only on TPU with at least 2 devices") + self.skipTest("Test runs only on TPU with at least 2 devices") # Must use exactly 2 devices for expected outputs from ppermute devices = jax.devices()[:2] @@ -720,6 +716,44 @@ def check_top_k_results(res_run, res_expected, *, rtol, atol): polymorphic_shapes=("_, b",), check_results=check_top_k_results) + def test_dynamic_approx_top_k(self): + # stablehlo.dynamic_approx_top_k is used temporarily for a approx_top_k + # with dynamism + # This is the input that was used to generate the test_data + _ = np.arange(24, dtype=np.float32) + + def func(a): # a: f32[b + 4] + return lax.approx_max_k(a, k=a.shape[0] - 4) + + data = self.load_testdata(stablehlo_dynamic_approx_top_k.data_2024_05_30) + + def check_top_k_results(res_run, res_expected, *, rtol, atol): + a = data.inputs[0] + # The order of the results may be different, but should be the same ones + values_expected, _ = res_expected + values_run, indices_run = res_run + # Check that indices are correct + self.assertAllClose( + values_run, + a[indices_run], + atol=atol, + rtol=rtol, + ) + self.assertAllClose( + np.sort(values_run), np.sort(values_expected), atol=atol, rtol=rtol + ) + + self.run_one_test( + func, + data, + polymorphic_shapes=("b + 4,",), + check_results=check_top_k_results, + expect_current_custom_calls=[ + "stablehlo.dynamic_approx_top_k", + "shape_assertion", + ], + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index 92e3f0aa8fc6..e75961df363b 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -21,9 +21,9 @@ from __future__ import annotations +from collections.abc import Callable import math import re -from typing import Callable from absl import logging from absl.testing import absltest @@ -31,19 +31,25 @@ import numpy as np import jax +from jax import export from jax import lax +from jax._src import config from jax._src import test_util as jtu -from jax.experimental import export from jax._src.internal_test_util import test_harnesses +from jax._src.lib import version as jaxlib_version +from jax import random def make_disjunction_regexp(*parts: str) -> re.Pattern[str]: - return re.compile("(" + "|".join(parts) + ")") + if not parts: + return re.compile("matches_no_test") + else: + return re.compile("(" + "|".join(parts) + ")") # TODO(necula): Failures to be investigated (on GPU). _known_failures_gpu = make_disjunction_regexp( - # Failures due to failure to export custom call targets for GPU, these - # targets do not have backwards compatibility tests. + # Failures on GPU due to failure to export custom call targets, these + # involve GPU custom call targets withoutbackwards compatibility tests. 
"custom_linear_solve_", "lu_", "svd_", @@ -54,9 +60,11 @@ def make_disjunction_regexp(*parts: str) -> re.Pattern[str]: # CUDA lowering. _skip_cuda_lowering_unless_have_gpus = make_disjunction_regexp( "svd_", "lu_", "eigh_", "qr_", "custom_linear_", "tridiagonal_solve_", + # TODO(b/350111820): random should work once we enable FFI threefry2x32 "random_", ) + class PrimitiveTest(jtu.JaxTestCase): @classmethod @@ -84,7 +92,7 @@ def setUpClass(cls): @test_harnesses.parameterized( test_harnesses.all_harnesses, include_jax_unimpl=False, - #one_containing="", + # one_containing="", ) @jtu.ignore_warning( category=UserWarning, @@ -136,9 +144,12 @@ def export_and_compare_to_native( if d.platform not in unimplemented_platforms ] logging.info("Using devices %s", [str(d) for d in devices]) - # lowering_platforms uses "cuda" instead of "gpu" + # lowering_platforms uses "cuda" or "rocm" instead of "gpu" + gpu_platform = "cuda" + if jtu.is_device_rocm(): + gpu_platform = "rocm" lowering_platforms: list[str] = [ - p if p != "gpu" else "cuda" + p if p != "gpu" else gpu_platform for p in ("cpu", "gpu", "tpu") if p not in unimplemented_platforms ] @@ -149,7 +160,8 @@ def export_and_compare_to_native( ) logging.info("Exporting harness for %s", lowering_platforms) - exp = export.export(func_jax, lowering_platforms=lowering_platforms)(*args) + exp = export.export(jax.jit(func_jax), + lowering_platforms=lowering_platforms)(*args) for device in devices: if device.platform in skip_run_on_platforms: @@ -161,7 +173,7 @@ def export_and_compare_to_native( logging.info("Running harness natively on %s", device) native_res = func_jax(*device_args) logging.info("Running exported harness on %s", device) - exported_res = export.call_exported(exp)(*device_args) + exported_res = exp.call(*device_args) if tol is not None: logging.info(f"Using non-standard tolerance {tol}") self.assertAllClose(native_res, exported_res, atol=tol, rtol=tol) @@ -193,6 +205,22 @@ def test_all_gather(self, *, dtype): x = (x % 2).astype(np.bool_) self.export_and_compare_to_native(f, x) + def test_random_with_threefry_gpu_kernel_lowering(self): + # TODO(b/350111820): fix the FFI registration mechanism + self.skipTest("b/350111820: fix the FFI registration mechanism") + if jaxlib_version < (0, 4, 31): + self.skipTest("jaxlib.version < 0.4.31") + # On GPU we use a custom call for threefry2x32 + with config.threefry_gpu_kernel_lowering(True): + # TODO(b/338022728): clean up forward compatibility mode. + with config.export_ignore_forward_compatibility(True): + def f(x): + return random.gamma(random.key(42), x) + + shape = (4, 5) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) + self.export_and_compare_to_native(f, x) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/export_test.py b/tests/export_test.py index e8106deae29b..64ff806af617 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
from __future__ import annotations +from collections.abc import Callable, Sequence import contextlib import dataclasses import functools @@ -25,8 +26,7 @@ import jax from jax import lax from jax import numpy as jnp -from jax.experimental import export -from jax.experimental.export import _export +from jax import export from jax.experimental import pjit from jax.experimental.shard_map import shard_map from jax.sharding import NamedSharding @@ -45,17 +45,22 @@ import numpy as np +# ruff: noqa: F401 +try: + import flatbuffers + CAN_SERIALIZE = True +except (ModuleNotFoundError, ImportError): + CAN_SERIALIZE = False + config.parse_flags_with_absl() -prev_xla_flags = None +_exit_stack = contextlib.ExitStack() + def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(2) + _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - prev_xla_flags() + _exit_stack.close() ### Setup for testing lowering with effects @dataclasses.dataclass(frozen=True) @@ -140,24 +145,22 @@ def _testing_multi_platform_fun_expected(x, ] -def get_exported(fun, vjp_order=0, - **export_kwargs): +def get_exported(fun: Callable, vjp_order=0, + **export_kwargs) -> Callable[[...], export.Exported]: """Like export.export but with serialization + deserialization.""" def serde_exported(*fun_args, **fun_kwargs): exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs) - serialized = export.serialize(exp, vjp_order=vjp_order) - return export.deserialize(serialized) + if CAN_SERIALIZE: + serialized = exp.serialize(vjp_order=vjp_order) + return export.deserialize(serialized) + else: + return exp return serde_exported -class JaxExportTest(jtu.JaxTestCase): - def override_serialization_version(self, version_override: int): - version = config.jax_serialization_version.value - if version != version_override: - self.enter_context(config.jax_serialization_version(version_override)) - logging.info( - "Using JAX serialization version %s", - config.jax_serialization_version.value) +# Run tests with the maximum supported version by default +@jtu.with_config(jax_export_calling_convention_version=export.maximum_supported_calling_convention_version) +class JaxExportTest(jtu.JaxTestCase): @classmethod def setUpClass(cls): @@ -171,19 +174,15 @@ def setUpClass(cls): cls.platforms.append(backend) super().setUpClass() - def setUp(self): - super().setUp() - # Run tests with the maximum supported version by default - self.override_serialization_version( - export.maximum_supported_serialization_version) - def test_basic_export_only(self): + @jax.jit def my_fun(x): return jnp.sin(x) exp = get_exported(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32)) self.assertEqual("my_fun", exp.fun_name) - self.assertEqual((export.default_lowering_platform(),), - exp.lowering_platforms) + expected_lowering_platform = xb.canonicalize_platform(jax.default_backend()) + self.assertEqual((expected_lowering_platform,), + exp.platforms) self.assertEqual(jax.tree.flatten(((1,), {}))[1], exp.in_tree) self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals) self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals) @@ -194,10 +193,10 @@ def test_pytree_export_only(self): def f(a_b_pair, *, a, b): return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) - exp = get_exported(f, lowering_platforms=("cpu",))((a, 
b), a=a, b=b) + exp = get_exported(jax.jit(f), lowering_platforms=("cpu",))((a, b), a=a, b=b) a_aval = core.ShapedArray(a.shape, a.dtype) b_aval = core.ShapedArray(b.shape, b.dtype) - self.assertEqual(exp.lowering_platforms, ("cpu",)) + self.assertEqual(exp.platforms, ("cpu",)) args = ((a, b),) kwargs = dict(a=a, b=b) self.assertEqual(exp.in_tree, jax.tree.flatten((args, kwargs))[1]) @@ -210,16 +209,78 @@ def test_basic(self): x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x) - f1 = export.call_exported(exp_f) - self.assertAllClose(f(x), f1(x)) + self.assertAllClose(f(x), exp_f.call(x)) + + def test_jit_static_arg(self): + + with self.subTest("static_argnames"): + + @functools.partial(jax.jit, static_argnames=["c"]) + def f(x, *, c): + return c * jnp.sin(x) + + x = np.arange(4, dtype=np.float32) + exp_f = get_exported(f)(x, c=0.1) + + self.assertAllClose(f(x, c=0.1), exp_f.call(x)) + + with self.subTest("static_argnums"): + + @functools.partial(jax.jit, static_argnums=[1]) + def g(x, c): + return c * jnp.sin(x) + + x = np.arange(4, dtype=np.float32) + exp_g = get_exported(g)(x, 0.1) + + self.assertAllClose(g(x, 0.1), exp_g.call(x)) + + def test_export_error_no_jit(self): + # Can export a lambda, without jit + with self.assertRaisesRegex(ValueError, + "Function to be exported must be the result of `jit`"): + _ = export.export(lambda x: jnp.sin(x)) + + @jtu.ignore_warning(category=DeprecationWarning, + message="The jax.experimental.export module is deprecated") + def test_export_experimental_back_compat(self): + if not CAN_SERIALIZE: + self.skipTest("serialization disabled") + from jax.experimental import export + # Can export a lambda, without jit + exp = export.export(lambda x: jnp.sin(x))(.1) + self.assertAllClose(exp.call(1.), np.sin(1.)) + + blob = export.serialize(exp, vjp_order=1) + rehydrated = export.deserialize(blob) + + self.assertAllClose(export.call(exp)(1.), np.sin(1.)) + self.assertAllClose(export.call_exported(exp)(1.), np.sin(1.)) def test_call_exported_lambda(self): # When we export a lambda, the exported.fun_name is not a valid MLIR function name - f = lambda x: jnp.sin(x) + f = jax.jit(lambda x: jnp.sin(x)) x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x) - f1 = export.call_exported(exp_f) - self.assertAllClose(f(x), f1(x)) + self.assertAllClose(f(x), exp_f.call(x)) + + def test_call_name_conflict(self): + @jax.jit + def inner(x): + # The lowering will contain a _where private function + return jnp.where(x > 0, jnp.ones_like(x), jnp.zeros_like(x)) + + x = jnp.arange(-20, 20, dtype=np.int32) + exp_inner = export.export(inner)(x) + self.assertIn("@_where(", str(exp_inner.mlir_module())) + + @jax.jit + def outer(x): + # There should be no conflict on _where + x = exp_inner.call(x) + return inner(x) + + export.export(outer)(x) def test_call_twice_exported(self): def f(x): return jnp.sin(x) @@ -227,19 +288,18 @@ def f(x): return jnp.sin(x) @jax.jit def f1(x): - exp_f = get_exported(f)(x) - return export.call_exported(exp_f)(x) + export.call_exported(exp_f)(x) + exp_f = get_exported(jax.jit(f))(x) + return exp_f.call(x) + exp_f.call(x) self.assertAllClose(2. 
* f(x), f1(x)) def test_unused_args(self): - f = lambda x, y: jnp.sin(x) + f = jax.jit(lambda x, y: jnp.sin(x)) x = np.arange(4, dtype=np.float32) y = np.arange(6, dtype=np.float32) exp_f = get_exported(f)(x, y) - f1 = export.call_exported(exp_f) - self.assertAllClose(f(x, y), f1(x, y)) + self.assertAllClose(f(x, y), exp_f.call(x, y)) def test_pytree(self): a = np.arange(4, dtype=np.float32) @@ -247,43 +307,50 @@ def test_pytree(self): def f(a_b_pair, a, b): return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) - exp_f = get_exported(f)((a, b), a=a, b=b) - f1 = export.call_exported(exp_f) + exp_f = get_exported(jax.jit(f))((a, b), a=a, b=b) self.assertAllClose(f((a, b), a=a, b=b), - f1((a, b), a=a, b=b)) + exp_f.call((a, b), a=a, b=b)) def test_error_wrong_intree(self): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c a = b = c = np.arange(4, dtype=np.float32) - exp_f = get_exported(f)((a, b), c=c) + exp_f = get_exported(jax.jit(f))((a, b), c=c) with self.assertRaisesRegex( ValueError, "The invocation args and kwargs must have the same pytree structure"): - export.call_exported(exp_f)(a, b, c=(a, b)) + exp_f.call(a, b, c=(a, b)) def test_error_wrong_avals(self): def f(a, *, b): # a: f32[4] and b: f32[4] return jnp.sin(a) + jnp.cos(b) f32_4 = np.arange(4, dtype=np.float32) - exp_f = get_exported(f)(f32_4, b=f32_4) + exp_f = get_exported(jax.jit(f))(f32_4, b=f32_4) with self.assertRaisesRegex(ValueError, r"Shape mismatch for args\[0\].shape\[0\]"): - export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4) + exp_f.call(np.arange(6, dtype=np.float32), b=f32_4) with self.assertRaisesRegex(ValueError, r"Shape mismatch for kwargs\['b'\].shape\[0\]"): - export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32)) + exp_f.call(f32_4, b=np.arange(6, dtype=np.float32)) with self.assertRaisesRegex(ValueError, r"Rank mismatch for args\[0\]"): - export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4) + exp_f.call(f32_4.reshape((1, 4)), b=f32_4) with self.assertRaisesRegex(ValueError, r"Dtype mismatch for args\[0\]"): - export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4) + exp_f.call(f32_4.astype(np.float16), b=f32_4) + + def test_default_export_platform(self): + test_platform = jtu.device_under_test() + if test_platform == "gpu": + test_platform = "rocm" if jtu.is_device_rocm() else "cuda" + self.assertEqual(export.default_export_platform(), test_platform) + exp = export.export(jnp.sin)(1.) 
+ self.assertEqual(exp.platforms, (export.default_export_platform(),)) @jtu.parameterized_filterable( testcase_name=lambda kw: kw["platform"], @@ -297,14 +364,14 @@ def test_error_wrong_platform(self, platform): raise unittest.SkipTest("Uninteresting scenario") with self.assertRaisesRegex( - ValueError, "The exported function .* was lowered for platform"): - export.call_exported(exp_f)(a) + ValueError, "Function .* was exported for platform"): + exp_f.call(a) # Now try with the platform check disabled exp_f_no_platform_check = get_exported( jnp.sin, lowering_platforms=(platform,), disabled_checks=[export.DisabledSafetyCheck.platform()])(a) - res = export.call_exported(exp_f_no_platform_check)(a) + res = exp_f_no_platform_check.call(a) self.assertAllClose(res, jnp.sin(a)) @jtu.parameterized_filterable( @@ -327,30 +394,67 @@ def test_primitive_lowering(ctx, arg): with self.assertRaisesRegex(ValueError, "Cannot serialize code with custom calls whose targets .*"): get_exported( - lambda a: a + test_primitive.bind(a) + jax.jit(lambda a: a + test_primitive.bind(a)) )(a) # Now try again with the safety check disabled exp = get_exported( - lambda a: a + test_primitive.bind(a), + jax.jit(lambda a: a + test_primitive.bind(a)), disabled_checks=[export.DisabledSafetyCheck.custom_call("disallowed_call_target")] )(a) self.assertIn("disallowed_call_target", exp.mlir_module()) + def test_lowering_parameters_for_export(self): + # Test that we propagate properly the LoweringParameters.for_export + test_primitive = core.Primitive("_test_primitive_for_export") + test_primitive.def_abstract_eval(lambda in_aval: in_aval) + # Store here the context for lowering + context = {} + def test_primitive_lowering(ctx, arg): + context["for_export"] = ctx.module_context.lowering_parameters.for_export + context["export_ignore_forward_compatibility"] = ctx.module_context.lowering_parameters.export_ignore_forward_compatibility + return mlir.hlo.AddOp(arg, arg).results + + mlir.register_lowering(test_primitive, test_primitive_lowering) + self.addCleanup(lambda: mlir.register_lowering(test_primitive, None)) + + f = jax.jit(test_primitive.bind) + a = np.arange(3, dtype=np.float32) + context.clear() + res = f(a) # Works with JIT + self.assertAllClose(res, a + a) + self.assertEqual(context, + dict(for_export=False, + export_ignore_forward_compatibility=False)) + context.clear() + f.lower(a) # Works with most AOT + # The above was cached + self.assertEqual(context, {}) + _ = export.export(f)(a) + self.assertEqual(context, + dict(for_export=True, + export_ignore_forward_compatibility=False)) + context.clear() + with config.export_ignore_forward_compatibility(True): + _ = export.export(f)(a) + self.assertEqual(context, + dict(for_export=True, + export_ignore_forward_compatibility=True)) + def test_grad(self): f = lambda x: jnp.sum(jnp.sin(x)) x = np.arange(4, dtype=np.float32) - exp_f = get_exported(f, vjp_order=1)(x) + exp_f = get_exported(jax.jit(f), vjp_order=1)(x) - f1 = export.call_exported(exp_f) + f1 = exp_f.call self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x)) def test_higher_order_grad(self): f = lambda x: x ** 3 x = np.float32(4.) 
- exp_f = get_exported(f, vjp_order=3)(x) + exp_f = get_exported(jax.jit(f), vjp_order=3)(x) - f1 = export.call_exported(exp_f) + f1 = exp_f.call self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x)) self.assertAllClose(jax.grad(jax.grad(f))(x), @@ -376,8 +480,8 @@ def f(xi, xf): self.assertAllClose(res, (xi_ct, xf_ct)) (f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct)) - exp = get_exported(f, vjp_order=2)(xi, xf) - fr = export.call_exported(exp) + exp = get_exported(jax.jit(f), vjp_order=2)(xi, xf) + fr = exp.call res = fr(xi, xf) self.assertAllClose(res, (f_outi, f_outf)) @@ -404,14 +508,14 @@ def f(a_b_pair, *, a, b): a = np.arange(4, dtype=np.float32) b = np.arange(6, dtype=np.float32) - exp_f = get_exported(f, vjp_order=1)((a, b), a=a, b=b) + exp_f = get_exported(jax.jit(f), vjp_order=1)((a, b), a=a, b=b) out_ct = f((a, b), a=a, b=b) # The output has the right structure as the cotangent def f1_jax(a, b): # For VJP, make a function without kwargs res = f((a, b), a=a, b=b) return res def f1_exp(a, b): # For VJP, make a function without kwargs - res = export.call_exported(exp_f)((a, b), a=a, b=b) + res = exp_f.call((a, b), a=a, b=b) return res jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct) exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct) @@ -421,15 +525,15 @@ def test_roundtrip(self): def f1(x): return jnp.sin(x) a = np.arange(4, dtype=np.float32) - exp_f1 = get_exported(f1)(a) + exp_f1 = get_exported(jax.jit(f1))(a) def f2(x): - res1 = export.call_exported(exp_f1)(x) - res2 = export.call_exported(exp_f1)(res1) + res1 = exp_f1.call(x) + res2 = exp_f1.call(res1) return jnp.cos(res2) - exp_f2 = get_exported(f2)(a) + exp_f2 = get_exported(jax.jit(f2))(a) self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))), - export.call_exported(exp_f2)(a)) + exp_f2.call(a)) def test_poly_export_only(self): a = np.arange(12, dtype=np.float32).reshape((3, 4)) @@ -437,7 +541,7 @@ def f(a, b): # a: f32[2w,h] b: f32[w,h] return jnp.concatenate([a, b], axis=0) scope = export.SymbolicScope() - exp = get_exported(f)( + exp = get_exported(jax.jit(f))( jax.ShapeDtypeStruct(export.symbolic_shape("(2*w, h)", scope=scope), a.dtype), jax.ShapeDtypeStruct(export.symbolic_shape("(w, h)", scope=scope), a.dtype)) self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape)) @@ -446,13 +550,13 @@ def f(a, b): # a: f32[2w,h] b: f32[w,h] # Peek at the module module_str = exp.mlir_module() - self.assertEqual(config.jax_serialization_version.value >= 7, + self.assertEqual(config.jax_export_calling_convention_version.value >= 7, "shape_assertion" in module_str) self.assertIn("jax.uses_shape_polymorphism = true", module_str) wrapped_main_expected_re = ( r"@_wrapped_jax_export_main\(" - r"%arg0: tensor {jax.global_constant = \"h\"}.*" - r"%arg1: tensor {jax.global_constant = \"w\"}.*" + r"%arg0: tensor {jax.global_constant = \"h\".*" + r"%arg1: tensor {jax.global_constant = \"w\".*" r"%arg2: tensor<\?x\?xf32>" ) self.assertRegex(module_str, wrapped_main_expected_re) @@ -476,7 +580,7 @@ def f(a0, a1, *, ak): return jnp.concatenate([a0, a1, ak], axis=0) a_poly_spec = jax.ShapeDtypeStruct(export.symbolic_shape("(w, h)"), a.dtype) - exp = get_exported(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec) + exp = get_exported(jax.jit(f))(a_poly_spec, a_poly_spec, ak=a_poly_spec) self.assertEqual("(w, h)", str(exp.in_avals[0].shape)) self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape)) @@ -493,31 +597,56 @@ def f(x, y): "Invalid mixing of symbolic scopes when exporting f.*" r"Expected current \(from args\[0\]\) scope .*" r"and found for 'w' \(args\[1\]\) scope 
.*", re.DOTALL)): - get_exported(f)(x_poly_spec, y_poly_spec) - + get_exported(jax.jit(f))(x_poly_spec, y_poly_spec) + + def test_poly_export_callable_with_no_name(self): + # This was reported by a user + class MyCallable: + def __call__(self, x): + return jnp.sin(x) + + # This makes it look like a jitted-function + def lower(self, x, _experimental_lowering_parameters=None): + return jax.jit(self.__call__).lower( + x, + _experimental_lowering_parameters=_experimental_lowering_parameters) + + def trace(self, x, _experimental_lowering_parameters=None): + return jax.jit(self.__call__).trace( + x, + _experimental_lowering_parameters=_experimental_lowering_parameters) + + a, = export.symbolic_shape("a,") + # No error + _ = get_exported(jax.jit(MyCallable()))( + jax.ShapeDtypeStruct((a, a), dtype=np.float32) + ) @jtu.parameterized_filterable( kwargs=[ dict(v=v) - for v in range(export.minimum_supported_serialization_version - 1, - export.maximum_supported_serialization_version + 2)]) + for v in range(export.minimum_supported_calling_convention_version - 1, + export.maximum_supported_calling_convention_version + 2)]) def test_poly_basic_versions(self, v: int): - self.override_serialization_version(v) - with contextlib.ExitStack() as e: - if not (export.minimum_supported_serialization_version <= v - <= export.maximum_supported_serialization_version): - e.enter_context(self.assertRaisesRegex( - ValueError, - f"The requested jax_serialization version {v} is outside the range of supported versions")) - - exp = get_exported(jnp.sin)( - jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32)) - x = np.arange(30, dtype=np.float32).reshape((5, 6)) - res = export.call_exported(exp)(x) - self.assertAllClose(res, np.sin(x)) + with config.jax_export_calling_convention_version(v): + logging.info( + "Using JAX calling convention version %s", + config.jax_export_calling_convention_version.value) + with contextlib.ExitStack() as e: + if not (export.minimum_supported_calling_convention_version <= v + <= export.maximum_supported_calling_convention_version): + e.enter_context(self.assertRaisesRegex( + ValueError, + f"The requested export calling convention version {v} is outside the range of supported versions")) + + exp = get_exported(jnp.sin)( + jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32)) + x = np.arange(30, dtype=np.float32).reshape((5, 6)) + res = exp.call(x) + self.assertAllClose(res, np.sin(x)) # A function is exported with f32[poly_spec] and is called with different arg - # shapes. We use export.call_exported and we also run the shape check + # shapes. We use export.call and we also run the shape check # module. 
@jtu.parameterized_filterable( testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}", # type: ignore @@ -554,9 +683,9 @@ def f(x): # x: f32[poly_spec] return jnp.reshape(x, (-1, x.shape[1])) disabled_checks = () - exp_f = get_exported(f, disabled_checks=disabled_checks)( + exp_f = get_exported(jax.jit(f), disabled_checks=disabled_checks)( jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), np.float32)) - self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12") + self.assertEqual(exp_f.uses_global_constants, poly_spec != "3,4,12") arg = np.arange(np.prod(arg_shape), dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12] @@ -565,7 +694,7 @@ def f(x): # x: f32[poly_spec] stack.push(self.assertRaisesRegex(Exception, expect_error)) assert core.is_constant_shape(arg.shape) - res = export.call_exported(exp_f)(arg) + res = exp_f.call(arg) if not expect_error: self.assertAllClose(res, f(arg)) @@ -656,35 +785,35 @@ def inner(x): # x: inner_poly_spec arg = np.arange(np.prod(arg_shape), dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12] - inner_exp = get_exported(inner)( + inner_exp = get_exported(jax.jit(inner))( jax.ShapeDtypeStruct(export.symbolic_shape(inner_poly_spec), np.float32)) - self.assertEqual(inner_exp.uses_shape_polymorphism, + self.assertEqual(inner_exp.uses_global_constants, (inner_poly_spec != "3,4,12")) def outer(x): # x: outer_poly_spec # Use an addition to test that the shapes are refined properly for the # result of the call_exported. - return export.call_exported(inner_exp)(x) + inner(x) + return inner_exp.call(x) + inner(x) with contextlib.ExitStack() as stack: if expect_error_outer_exp is not None: stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp)) # Call it after exporting again, with polymorphic shapes - outer_exp = get_exported(outer)( + outer_exp = get_exported(jax.jit(outer))( jax.ShapeDtypeStruct(export.symbolic_shape(outer_poly_spec), arg.dtype)) if expect_error_outer_exp is not None: return - self.assertEqual(outer_exp.uses_shape_polymorphism, + self.assertEqual(outer_exp.uses_global_constants, (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12")) with contextlib.ExitStack() as stack: if expect_error_run is not None: stack.push(self.assertRaisesRegex(Exception, expect_error_run)) - res = export.call_exported(outer_exp)(arg) + res = outer_exp.call(arg) if expect_error_run is not None: return @@ -706,7 +835,7 @@ def outer(x): # x: outer_poly_spec "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer poly_spec="(a + 2*b, a, a + b + c)", @@ -715,7 +844,7 @@ def outer(x): # x: outer_poly_spec "Division had remainder 1 when computing the value of 'b'. " "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . 
" - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(8, 2, 6), # a = 2, b = 3 - inconsistency poly_spec="(a + 2*b, a, a + b)", @@ -725,7 +854,7 @@ def outer(x): # x: outer_poly_spec "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c poly_spec="(2 * a + b, a, c * c)", @@ -734,7 +863,7 @@ def outer(x): # x: outer_poly_spec "We can only solve linear uni-variate constraints. " "Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). " "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." )), ]) def test_shape_constraints_errors(self, *, @@ -746,20 +875,21 @@ def f_jax(x): # x: f32[a + 2*b, a, a + b + c] with contextlib.ExitStack() as stack: if expect_error is not None: stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error))) - exp = get_exported(f_jax)( + exp = get_exported(jax.jit(f_jax))( jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), x.dtype)) - export.call_exported(exp)(x) + exp.call(x) def test_poly_booleans(self): # For booleans we use a special case ConvertOp to cast to and from # dynamic shapes arguments. 
+ @jax.jit def f_jax(x): # x: bool[b] return jnp.logical_not(x) x = np.array([True, False, True, False], dtype=np.bool_) exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype)) - res = export.call_exported(exp)(x) + res = exp.call(x) self.assertAllClose(f_jax(x), res) @jtu.parameterized_filterable( @@ -774,13 +904,14 @@ def test_poly_numeric_dtypes(self, dtype=np.int32): "int4", "uint4"}: self.skipTest(f"TODO: serialization not supported for {str(dtype)}") + @jax.jit def f_jax(x): return x + x x = np.arange(6, dtype=dtype) exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype)) - res = export.call_exported(exp)(x) + res = exp.call(x) self.assertAllClose(f_jax(x), res) def test_poly_expressions(self): @@ -790,6 +921,7 @@ def output_shape(b): return (b + b, b - b, b * b, (b + 13) // b, (b + 13) % b, core.max_dim(b - 5, 0)) + @jax.jit def f(x): # x: f32[b] b = x.shape[0] return jnp.ones(output_shape(b), dtype=x.dtype) @@ -797,15 +929,43 @@ def f(x): # x: f32[b] exp = get_exported(f)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype)) # Call with static shapes - res = export.call_exported(exp)(x) + res = exp.call(x) self.assertAllClose(res, f(x)) # Now re-export with shape polymorphism x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a"), x.dtype) - exp2 = get_exported(export.call_exported(exp))(x_spec) + exp2 = get_exported(jax.jit(exp.call))(x_spec) a = exp2.in_avals[0].shape[0] self.assertEqual(exp2.out_avals[0].shape, output_shape(a)) + def test_with_donation(self): + f = jax.jit(jnp.sin, donate_argnums=(0,)) + x = np.arange(3, dtype=np.float32) + exp = export.export(f)(x) + + def caller(x): + y = exp.call(x) + return x + y + res = jax.jit(caller)(x) + self.assertAllClose(res, x + np.sin(x)) + + def test_poly_call_pmap(self): + if len(jax.devices()) < 2: + self.skipTest("Need at least 2 devices") + def f(x): # x: f32[a, 4] + return x + jnp.arange(x.shape[0], dtype=x.dtype).reshape((x.shape[0], 1)) + + a, = export.symbolic_shape("a") + exp = export.export(jax.jit(f))( + jax.ShapeDtypeStruct((a, 4), np.float32)) + f_exp = exp.call + x_jit = np.arange(12, dtype=np.float32).reshape((3, 4)) + res_jit = jax.jit(f_exp)(x_jit) + self.assertAllClose(res_jit, f(x_jit)) + x_pmap = np.arange(24, dtype=np.float32).reshape((2, 3, 4)) + res_pmap = jax.pmap(f_exp)(x_pmap) + self.assertAllClose(res_pmap, jnp.stack([f(x) for x in x_pmap])) + def test_with_sharding(self): nr_devices = 2 if len(jax.devices()) < nr_devices: @@ -837,27 +997,147 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] # We apply the out_shardings for f_jax r".*custom_call @Sharding\(%1\).*mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*", re.DOTALL) - hlo = jax.jit(export.call_exported(exp)).lower(a_device).as_text() + hlo = jax.jit(exp.call).lower(a_device).as_text() self.assertRegex(hlo, expected_re) - res_exported = export.call_exported(exp)(a_device) + res_exported = exp.call(a_device) self.assertAllClose(res_native, res_exported) # Test error reporting with self.assertRaisesRegex( - NotImplementedError, - "Exported module .* was lowered for 2 devices and is called in a context with 1 device"): - _ = export.call_exported(exp)(a) + ValueError, + "Function .* was exported for 2 devices and is called in a context with 1 device"): + _ = exp.call(a) with self.assertRaisesRegex( - NotImplementedError, - "Exported module .* was lowered for 2 devices and is called in a context with 1 device"): + ValueError, + "Function .* was exported for 2 devices and is called in a context 
with 1 device"): mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",)) _ = jax.jit( - export.call_exported(exp), + exp.call, in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),) )(a) + def test_input_shardings_unused_args(self): + nr_devices = 2 + if len(jax.devices()) < nr_devices: + self.skipTest("Need at least 2 devices") + devices = jax.devices()[0:nr_devices] + export_mesh = Mesh(np.array(devices), + axis_names=("x",)) + a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4)) + + f = jax.jit(lambda x, y: jnp.sin(x), + in_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),), + None), + out_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),))) + exp = get_exported(f)(a, a) + + # We can use other devices and other meshes for running + run_devices = devices[::-1] + run_mesh = Mesh(run_devices, "a") + run_input_shardings = exp.in_shardings_jax(run_mesh) + a_run = jax.device_put(a, run_input_shardings[0]) + b_run = jax.device_put(a, run_input_shardings[1]) + res = exp.call(a_run, b_run) + self.assertEqual(res.addressable_shards[0].device, run_devices[0]) + self.assertEqual(res.addressable_shards[1].device, run_devices[1]) + + def test_call_with_different_no_of_devices(self): + if jax.local_device_count() < 2: + self.skipTest("Need at least 2 devices") + + @jax.jit + def f_without_shardings(x): + return jnp.sum(x ** 2, axis=0) + + a = jnp.arange(jax.local_device_count() * 10, dtype=np.float32).reshape( + (jax.local_device_count(), 10) + ) + res_native = f_without_shardings(a) + exp = get_exported(f_without_shardings)(a) + self.assertEqual(exp.nr_devices, 1) + + run_devices = jax.local_devices() + run_mesh = Mesh(run_devices, "i") + b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) + + res_exported = exp.call(b) + self.assertAllClose(res_native, res_exported) + + def test_call_with_different_no_of_devices_error_has_in_shardings(self): + if jax.local_device_count() < 2: + self.skipTest("Need at least 2 devices") + + mesh_1 = Mesh(jax.local_devices()[:1], "i") + @functools.partial(pjit.pjit, + in_shardings=NamedSharding(mesh_1, P("i"))) + def f_with_sharding(x): + return jnp.sum(x ** 2, axis=0) + + a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape( + (jax.device_count(), 10) + ) + exp = get_exported(f_with_sharding)(a) + self.assertEqual(exp.nr_devices, 1) + + run_devices = jax.local_devices() + run_mesh = Mesh(run_devices, "i") + b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) + + with self.assertRaisesRegex( + ValueError, + "Function .* was exported for 1 devices and is called in a " + f"context with {jax.local_device_count()} devices.* function contains " + "non-replicated sharding annotations"): + exp.call(b) + + def test_call_with_different_no_of_devices_pmap(self): + if len(jax.devices()) < 2: + self.skipTest("Need at least 2 devices") + + @jax.jit + def f_jax(x): + return jnp.sum(x ** 2, axis=0) + + a = jnp.arange(100, dtype=jnp.float32).reshape((1, 100)) + res_native = f_jax(a) + exp = get_exported(f_jax)(a) + self.assertEqual(exp.nr_devices, 1) + + b = jnp.arange(jax.device_count() * 100, dtype=jnp.float32).reshape( + (-1, 1, 100) + ) + res_exported = jax.pmap(exp.call)(b) + self.assertAllClose(res_native, res_exported[0]) + + def test_call_with_different_no_of_devices_error_has_sharding_constraint(self): + if jax.device_count() < 2: + self.skipTest("Need at least 2 devices") + + mesh_1 = Mesh(jax.local_devices()[:1], "i") + @jax.jit + def f_with_sharding(x): + x = 
jax.lax.with_sharding_constraint(x, NamedSharding(mesh_1, P("i"))) + return jnp.sum(x ** 2, axis=0) + + a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape( + (jax.device_count(), 10) + ) + exp = get_exported(f_with_sharding)(a) + self.assertEqual(exp.nr_devices, 1) + + run_devices = jax.local_devices() + run_mesh = Mesh(run_devices, "i") + b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) + + with self.assertRaisesRegex( + ValueError, + "Function .* was exported for 1 devices and is called in a " + f"context with {jax.local_device_count()} devices.* function contains " + "non-replicated sharding annotations"): + exp.call(b) + @jtu.parameterized_filterable( kwargs=[ dict(testcase_name=f"_poly={poly}", poly=poly) @@ -882,7 +1162,7 @@ def f_jax(b): # b: f32[2, 4] perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(b, "x", perm=perm) - args_specs = export.symbolic_args_specs((a,), polymorphic_shapes=poly) + args_specs = export.symbolic_args_specs((a,), poly) exp = get_exported(f_jax)(*args_specs) # Test JAX native execution @@ -894,10 +1174,10 @@ def f_jax(b): # b: f32[2, 4] self.assertLen(res_jax.addressable_shards, len(devices)) # Test reloaded execution. - f_r = export.call_exported(exp) + f_r = exp.call with self.assertRaisesRegex( Exception, - "Exported module .* was lowered for 2 devices and is " + "Function .* was exported for 2 devices and is " "called in a context with 1 devices"): _ = f_r(a) # A is all on the default device @@ -918,24 +1198,26 @@ def f_jax(b): # b: f32[2, 4] @jtu.parameterized_filterable( kwargs=[ dict(in_shardings=in_shardings, out_shardings=out_shardings, - with_mesh=with_mesh) + with_mesh_context=with_mesh_context) for in_shardings in ("missing", None, "P") for out_shardings in ("missing", None, "P") - for with_mesh in (True, False) + for with_mesh_context in (True, False) ]) def test_grad_with_sharding(self, in_shardings="P", out_shardings=None, - with_mesh=False): + with_mesh_context=False): if len(jax.devices()) < 2: self.skipTest("Test requires at least 2 devices") x_shape = (10, 20) x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + # The input has shape f32[10,20] and output f32[20,10] in order to + # distinguish them in the HLO. 
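Aside, before the gradient-sharding test continues below: the tests above all follow the round trip that this change migrates to, namely building an Exported from a jitted function and invoking it through exp.call rather than the removed export.call_exported wrapper. A minimal sketch of that pattern, assuming the jax.export module used throughout this file; the function, shape, and dimension names are illustrative only.

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax import export

    @jax.jit
    def f(x):  # x: f32[b, 4]
      return jnp.sin(x)

    # Export with a symbolic batch dimension, then call with a concrete shape.
    b, = export.symbolic_shape("b")
    exp = export.export(f)(jax.ShapeDtypeStruct((b, 4), np.float32))
    x = np.ones((3, 4), np.float32)
    np.testing.assert_allclose(exp.call(x), np.sin(x))  # exp.call replaces call_exported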
def f_jax(x): # x: f32[10,20] -> f32[20,10] return jnp.sin(x.T) mesh = Mesh(jax.devices()[:2], "d") pjit_kwargs = {} # Use NamedShardings if we don't have a mesh_context - if with_mesh: + if with_mesh_context: sharding_None_d = P(None, "d") sharding_d_None = P("d", None) else: @@ -951,7 +1233,7 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10] f_jax_pjit = pjit.pjit(f_jax, **pjit_kwargs) with contextlib.ExitStack() as stack: - if with_mesh: + if with_mesh_context: stack.enter_context(mesh) # Serialize higher-order gradiends exp = get_exported(f_jax_pjit, vjp_order=2)(x) @@ -961,50 +1243,62 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10] vjp_module_str = str(exp_vjp.mlir_module()) + # The MHLO attributes of the args and the result of the main function + # Arg0 are the primal inputs, arg1 are the output cotangent, res is the input cotangent + arg0_attrs, arg1_attrs, res_attrs = re.search( + r"func.func public @main\(%arg0: tensor<10x20xf32> (.*)" + r", %arg1: tensor<20x10xf32> (.*)" + r"\) -> \(tensor<10x20xf32> (.*)", # the result + vjp_module_str).groups() + if in_shardings == "P": + self.assertRegex(arg0_attrs, re.escape("{devices=[1,2]<=[2]}")) + self.assertRegex(res_attrs, re.escape("{devices=[1,2]<=[2]}")) primal_in_sharding = "{devices=[1,2]<=[2]}" else: primal_in_sharding = "{replicated}" + if with_mesh_context: + self.assertRegex(arg0_attrs, re.escape("replicated")) + self.assertRegex(res_attrs, re.escape("replicated")) + else: + # If there is no mesh context, we have used NamedSharding(None) + # and then the sharding is unspecified! + self.assertNotIn("mhlo.sharding", arg0_attrs) + self.assertNotIn("mhlo.sharding", res_attrs) + if out_shardings == "P": + self.assertRegex(arg1_attrs, re.escape("{devices=[2,1]<=[2]}")) primal_out_sharding = "{devices=[2,1]<=[2]}" else: primal_out_sharding = "{replicated}" + if with_mesh_context: + self.assertRegex(arg1_attrs, re.escape("replicated")) + else: + self.assertNotIn("mhlo.sharding", arg1_attrs) - # TODO(b/326476605): Change the condition below if required. - if in_shardings == "P": - main = re.compile( - r"func.func public @main\(%arg0: tensor<10x20xf32>.*" - "mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\"" - r".*%arg1: tensor<20x10xf32>.*" - "mhlo.sharding = \"" + re.escape(primal_out_sharding) + "\"" - # result - r".*->.*\(tensor<10x20xf32>.*" - "mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\"") - self.assertRegex(vjp_module_str, main) - - # Custom calls for the primal input shape all match primal_in_sharding - primal_in_calls = re.findall( + # Sharding custom calls for the primal input shape all match primal_in_sharding + primal_in_sharding_calls = re.findall( r"custom_call @Sharding.*mhlo.sharding = \"(.+)\".*:.*tensor<10x20xf32>", vjp_module_str) self.assertTrue( - all(s == primal_in_sharding for s in primal_in_calls), - primal_in_calls + all(s == primal_in_sharding for s in primal_in_sharding_calls), + primal_in_sharding_calls ) # Custom calls for the primal output shape all match primal_out_sharding - primal_out_calls = re.findall( + primal_out_sharding_calls = re.findall( r"custom_call @Sharding.*mhlo.sharding = \"(.+)\".*:.*tensor<20x10xf32>", vjp_module_str) self.assertTrue( - all(s == primal_out_sharding for s in primal_out_calls), - primal_in_calls + all(s == primal_out_sharding for s in primal_out_sharding_calls), + primal_out_sharding_calls ) # Call the exported gradient functions. In order to set the device context # we replicate the inputs. 
If we don't use a mesh context and there are # no shardings on inputs or outputs, then we have serialized for one # device. - if in_shardings != "P" and out_shardings != "P" and not with_mesh: + if in_shardings != "P" and out_shardings != "P" and not with_mesh_context: self.assertEqual(exp_vjp.nr_devices, 1) self.assertEqual(exp_vjp2.nr_devices, 1) call_mesh = Mesh(jax.devices()[:1], "e") @@ -1013,14 +1307,14 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10] self.assertEqual(exp_vjp2.nr_devices, 2) call_mesh = Mesh(jax.devices()[:2], "e") - g1 = pjit.pjit(export.call_exported(exp_vjp), + g1 = pjit.pjit(exp_vjp.call, in_shardings=(NamedSharding(call_mesh, None), NamedSharding(call_mesh, None)))(x, x.T) _, f_jax_vjp = jax.vjp(f_jax, x) xbar = f_jax_vjp(x.T) self.assertAllClose(xbar, g1) - g2 = pjit.pjit(export.call_exported(exp_vjp2), + g2 = pjit.pjit(exp_vjp2.call, in_shardings=(NamedSharding(call_mesh, None), NamedSharding(call_mesh, None), NamedSharding(call_mesh, None)))(x, x.T, x) @@ -1028,11 +1322,37 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10] xbar2, = f_jax_vjp2((x,)) self.assertAllClose(xbar2, g2[1]) + def test_grad_sharding_different_mesh(self): + # Export and serialize with two similar meshes, the only difference being + # the order of the devices. grad and serialization should not fail. + # https://github.com/google/jax/issues/21314 + def f(x): + return jnp.sum(x * 2.) + + mesh = Mesh(jax.local_devices(), "i") + mesh_rev = Mesh(list(reversed(jax.local_devices())), "i") + shardings = NamedSharding(mesh, jax.sharding.PartitionSpec(("i",))) + shardings_rev = NamedSharding(mesh_rev, jax.sharding.PartitionSpec(("i",))) + input_no_shards = jnp.ones(shape=(jax.local_device_count(),)) + input = jnp.ones(shape=(jax.local_device_count(),), device=shardings) + input_rev = jax.device_put(input_no_shards, device=shardings_rev) + + exp = export.export(pjit.pjit(f, in_shardings=shardings))(input) + exp_rev = export.export(pjit.pjit(f, in_shardings=shardings_rev))(input_no_shards) + + if CAN_SERIALIZE: + _ = exp.serialize(vjp_order=1) + _ = exp_rev.serialize(vjp_order=1) + + g = jax.grad(exp_rev.call)(input_rev) + g_rev = jax.grad(exp.call)(input) + self.assertAllClose(g, g_rev) + def test_multi_platform(self): x = np.arange(8, dtype=np.float32) - exp = get_exported(_testing_multi_platform_func, - lowering_platforms=("tpu", "cpu", "cuda"))(x) - self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda")) + exp = get_exported(jax.jit(_testing_multi_platform_func), + lowering_platforms=("tpu", "cpu", "cuda", "rocm"))(x) + self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "rocm")) module_str = str(exp.mlir_module()) expected_main_re = ( r"@main\(" @@ -1046,22 +1366,22 @@ def test_multi_platform(self): # Call with argument placed on different plaforms for platform in self.__class__.platforms: x_device = jax.device_put(x, jax.devices(platform)[0]) - res_exp = export.call_exported(exp)(x_device) + res_exp = exp.call(x_device) self.assertAllClose( res_exp, _testing_multi_platform_fun_expected(x, platform=platform)) def test_multi_platform_nested(self): x = np.arange(5, dtype=np.float32) - exp = get_exported(lambda x: _testing_multi_platform_func(jnp.sin(x)), - lowering_platforms=("cpu", "tpu", "cuda"))(x) - self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda")) + exp = get_exported(jax.jit(lambda x: _testing_multi_platform_func(jnp.sin(x))), + lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(x) + self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda", "rocm")) # Now serialize the 
call to the exported using a different sequence of # lowering platforms, but included in the lowering platforms for the # nested exported. - exp2 = get_exported(export.call_exported(exp), - lowering_platforms=("cpu", "cuda"))(x) + exp2 = get_exported(jax.jit(exp.call), + lowering_platforms=("cpu", "cuda", "rocm"))(x) # Ensure that we do not have multiple lowerings of the exported function exp2_module_str = str(exp2.mlir_module()) @@ -1072,39 +1392,130 @@ def test_multi_platform_nested(self): for platform in self.__class__.platforms: if platform == "tpu": continue x_device = jax.device_put(x, jax.devices(platform)[0]) - res_exp = export.call_exported(exp2)(x_device) + res_exp = exp2.call(x_device) self.assertAllClose( res_exp, _testing_multi_platform_fun_expected(np.sin(x), platform=platform)) def test_multi_platform_nested_inside_single_platform_export(self): x = np.arange(5, dtype=np.float32) - exp = get_exported(_testing_multi_platform_func, - lowering_platforms=("cpu", "tpu", "cuda"))(x) - self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda")) + exp = get_exported(jax.jit(_testing_multi_platform_func), + lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(x) + self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda", "rocm")) # Now serialize the call for the current platform. - exp2 = get_exported(export.call_exported(exp))(x) + exp2 = get_exported(jax.jit(exp.call))(x) module_str = str(exp2.mlir_module()) self.assertIn("jax.uses_shape_polymorphism = true", module_str) - res2 = export.call_exported(exp2)(x) + res2 = exp2.call(x) self.assertAllClose(res2, _testing_multi_platform_fun_expected(x)) + def test_multi_platform_mlir_lower_fun_with_platform_specific_primitives(self): + # A primitive with multiple lowering rules, which themselves involve + # tracing primitives with per-platform rules, using mlir.lower_fun. + # This situation arises for Pallas lowering. 
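Aside: the multi-platform tests above build a single Exported that carries lowerings for several backends; exp.platforms records them, and exp.call runs the lowering that matches the platform it is invoked on. A sketch of that usage with illustrative shapes, assuming the platforms= spelling used in test_multi_platform_with_donation (the older lowering_platforms= spelling also appears in this file). The per-platform lowering registration exercised in the test that follows builds on the same machinery.

    x_spec = jax.ShapeDtypeStruct((8,), np.float32)
    exp = export.export(jax.jit(jnp.sin), platforms=("cpu", "tpu", "cuda"))(x_spec)
    assert exp.platforms == ("cpu", "tpu", "cuda")

    x = np.arange(8, dtype=np.float32)
    y = exp.call(x)  # runs the lowering for the platform of the current device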
+ def times_n_lowering(n: int, ctx: mlir.LoweringRuleContext, + x: mlir.ir.Value) -> Sequence[mlir.ir.Value]: + # Lowering n * x + res = x + for i in range(n - 1): + res = mlir.hlo.AddOp(res, x) + return res.results + + times_2 = core.Primitive("__testing_times_2") # x2 for cpu + times_2.def_abstract_eval(lambda x: x) + # Define lowering rules only for the relevant platforms, ensure there + # is no error about missing lowering rules + mlir.register_lowering(times_2, functools.partial(times_n_lowering, 2), + "cpu") + + times_3 = core.Primitive("__testing_times_3") # x3 for cuda and rocm + times_3.def_abstract_eval(lambda x: x) + + mlir.register_lowering(times_3, functools.partial(times_n_lowering, 3), + "rocm") + mlir.register_lowering(times_3, functools.partial(times_n_lowering, 3), + "cuda") + + times_4 = core.Primitive("__testing_times_4") # x4 for tpu + times_4.def_abstract_eval(lambda x: x) + mlir.register_lowering(times_4, functools.partial(times_n_lowering, 4), + "tpu") + + times_2_or_3 = core.Primitive("__testing_times_2_or_3") # x2 for cpu, x3 for cuda and rocm + times_2_or_3.def_abstract_eval(lambda x: x) + mlir.register_lowering(times_2_or_3, + mlir.lower_fun(times_2.bind, + multiple_results=False), "cpu") + + mlir.register_lowering(times_2_or_3, + mlir.lower_fun(times_3.bind, + multiple_results=False), "rocm") + mlir.register_lowering(times_2_or_3, + mlir.lower_fun(times_3.bind, + multiple_results=False), "cuda") + + times_2_or_3_or_4 = core.Primitive("__testing_times_2_or_3_or_4") # x2 for cpu, x3 for cuda and rocm, x4 for tpu + times_2_or_3_or_4.def_abstract_eval(lambda x: x) + times_2_or_3_or_4_lowering_cpu_gpu = mlir.lower_fun(times_2_or_3.bind, + multiple_results=False) + + for platform in ["cpu", "cuda", "rocm"]: + mlir.register_lowering(times_2_or_3_or_4, + times_2_or_3_or_4_lowering_cpu_gpu, + platform) + mlir.register_lowering(times_2_or_3_or_4, mlir.lower_fun(times_4.bind, + multiple_results=False), + "tpu") + + @jax.jit + def f(x): + return times_2_or_3_or_4.bind(x) + x = np.float32(42.) 
+ exp = export.export(f, lowering_platforms=["cpu", "cuda", "rocm", "tpu"])(x) + expected = x * np.float32(dict(cpu=2, gpu=3, tpu=4)[jtu.device_under_test()]) + self.assertAllClose(exp.call(x), expected) + + def test_multi_platform_unknown_platform(self): + x = np.arange(8, dtype=np.float32) + exp = get_exported(jax.jit(jnp.sin), + lowering_platforms=("tpu", "cpu", "cuda", "other"))(x) + self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "other")) + + + def test_multi_platform_with_donation(self): + f = jax.jit(jnp.sin, donate_argnums=(0,)) + x = np.arange(3, dtype=np.float32) + exp = export.export(f, platforms=["cpu", "tpu"])(x) + if jtu.device_under_test() not in ["cpu", "tpu"]: + self.skipTest("other platform") + + def caller(x): + y = exp.call(x) + return x + y + res = jax.jit(caller)(x) + self.assertAllClose(res, x + np.sin(x)) + + with self.assertRaisesRegex( + NotImplementedError, + "In multi-platform lowering either all or no lowering platforms should support donation"): + export.export(f, platforms=["cpu", "tpu", "other"])(x) + def test_multi_platform_and_poly(self): if jtu.test_device_matches(["gpu"]): # The export is not applicable to GPU raise unittest.SkipTest("Not intended for running on GPU") - exp = get_exported(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,)), - lowering_platforms=("cpu", "tpu"))( + exp = get_exported(jax.jit(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,))), + lowering_platforms=("cpu", "tpu"))( jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), np.float32) ) x = np.arange(12, dtype=np.float32).reshape((3, 4)) - res = export.call_exported(exp)(x) + res = exp.call(x) self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,))) # Now serialize the call to the exported - exp2 = get_exported(export.call_exported(exp))(x) - res2 = export.call_exported(exp2)(x) + exp2 = get_exported(jax.jit(exp.call))(x) + res2 = exp2.call(x) self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,))) def test_multi_platform_and_sharding(self): @@ -1120,7 +1531,7 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] res_native = f_jax(a) exp = get_exported(f_jax, - lowering_platforms=("cpu", "tpu", "cuda"))(a) + lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(a) # Call with argument placed on different plaforms for platform in self.__class__.platforms: @@ -1129,209 +1540,190 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] continue run_mesh = Mesh(run_devices, ("x",)) a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, None)) - res_exp = export.call_exported(exp)(a_device) + res_exp = exp.call(a_device) self.assertArraysAllClose(res_native, res_exp) @jtu.parameterized_filterable( kwargs=[ dict(v=v) - for v in range(export.minimum_supported_serialization_version, - export.maximum_supported_serialization_version + 1)]) + for v in range(export.minimum_supported_calling_convention_version, + export.maximum_supported_calling_convention_version + 1)]) def test_ordered_effects_basic(self, *, v: int): - self.override_serialization_version(v) - x = np.arange(3, dtype=np.float32) - def f_jax(x): # x: f32[3] - # Test also the calling convention for inner functions - def f_jax_inner(x): + with config.jax_export_calling_convention_version(v): + logging.info( + "Using JAX serialization version %s", + config.jax_export_calling_convention_version.value) + x = np.arange(3, dtype=np.float32) + def f_jax(x): # x: f32[3] + # Test also the calling convention for inner functions + def f_jax_inner(x): + return ( + 
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") + + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1")) return ( - testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") + - testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1")) - return ( - 10. + - jax.jit(f_jax_inner)(x) + - testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") + - testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") - ) + 10. + + jax.jit(f_jax_inner)(x) + + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") + + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") + ) - exp = get_exported(f_jax)(x) - if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + exp = get_exported(jax.jit(f_jax))(x) self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], - sorted(str(e) for e in exp.ordered_effects)) + sorted(str(e) for e in exp.ordered_effects)) self.assertEqual(["ForTestingUnorderedEffect1()"], - [str(e) for e in exp.unordered_effects]) - else: - self.assertEqual([], [str(e) for e in exp.ordered_effects]) - self.assertEqual([], [str(e) for e in exp.unordered_effects]) - mlir_module_str = str(exp.mlir_module()) - - # Inner functions use stablehlo.token for all versions - inner_fun_expected_re = ( - r"func.func private @f_jax_inner\(" - r"%arg0: !stablehlo.token .*jax.token = true.*" - r"%arg1: tensor<3xf32>.*->.*" - # Results - r"!stablehlo.token .*jax.token = true.*" - r"tensor<3xf32>" - ) - self.assertRegex(mlir_module_str, inner_fun_expected_re) + [str(e) for e in exp.unordered_effects]) + mlir_module_str = str(exp.mlir_module()) + + # Inner functions use stablehlo.token for all versions + inner_fun_expected_re = ( + r"func.func private @f_jax_inner\(" + r"%arg0: !stablehlo.token .*jax.token = true.*" + r"%arg1: tensor<3xf32>.*->.*" + # Results + r"!stablehlo.token .*jax.token = true.*" + r"tensor<3xf32>" + ) + self.assertRegex(mlir_module_str, inner_fun_expected_re) + + # The wrapped_main function takens tokens after version 9, and takes + # i1[0] before version 9. + wrapped_main_expected_re = ( + r"@_wrapped_jax_export_main\(" + r"%arg0: !stablehlo.token .*jax.token = true.*" + r"%arg1: !stablehlo.token .*jax.token = true.*->.*" + # Results + r"!stablehlo.token .*jax.token = true.*" + r"!stablehlo.token .*jax.token = true.*") + self.assertRegex(mlir_module_str, wrapped_main_expected_re) - # The wrapped_main function takens tokens after version 9, and takes - # i1[0] before version 9. 
- wrapped_main_expected_re = ( - r"@_wrapped_jax_export_main\(" - r"%arg0: !stablehlo.token .*jax.token = true.*" - r"%arg1: !stablehlo.token .*jax.token = true.*->.*" - # Results - r"!stablehlo.token .*jax.token = true.*" - r"!stablehlo.token .*jax.token = true.*") - if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") - self.assertRegex(mlir_module_str, wrapped_main_expected_re) - - if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - # The main function does not have tokens - self.assertNotRegex(mlir_module_str, r"@main.*token") - else: # The main function takes tokens and has the same type as the wrapped main main_expected_re = wrapped_main_expected_re.replace("@_wrapped_jax_export_main", "@main") self.assertRegex(mlir_module_str, main_expected_re) - # Now call the exported from a function that uses its own effects - def f_outer(x): - return ( - testing_primitive_with_effect_p.bind( - x, effect_class_name="ForTestingOrderedEffect2") + - testing_primitive_with_effect_p.bind( - x, effect_class_name="ForTestingUnorderedEffect1") + - export.call_exported(exp)(x)) - - lowered_outer = jax.jit(f_outer).lower(x) - if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - self.assertEqual(["ForTestingOrderedEffect2()"], - [str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]]) - else: + # Now call the exported from a function that uses its own effects + def f_outer(x): + return ( + testing_primitive_with_effect_p.bind( + x, effect_class_name="ForTestingOrderedEffect2") + + testing_primitive_with_effect_p.bind( + x, effect_class_name="ForTestingUnorderedEffect1") + + exp.call(x)) + + lowered_outer = jax.jit(f_outer).lower(x) self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], - sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"])) - self.assertEqual(["ForTestingUnorderedEffect1()"], - sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]])) + sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"])) + self.assertEqual(["ForTestingUnorderedEffect1()"], + sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]])) - mlir_outer_module_str = str(lowered_outer.compiler_ir()) - if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - main_expected_re = main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") + mlir_outer_module_str = str(lowered_outer.compiler_ir()) self.assertRegex(mlir_outer_module_str, main_expected_re) - res = jax.jit(f_outer)(x) - self.assertAllClose(2. * 2. * x + 10. + 4. * 2. * x, res) + res = jax.jit(f_outer)(x) + self.assertAllClose(2. * 2. * x + 10. + 4. * 2. * x, res) @jtu.parameterized_filterable( - kwargs=[ - dict(v=v) - for v in range(export.minimum_supported_serialization_version, - export.maximum_supported_serialization_version + 1)]) + kwargs=[ + dict(v=v) + for v in range(export.minimum_supported_calling_convention_version, + export.maximum_supported_calling_convention_version + 1)]) def test_ordered_effects_poly(self, *, v: int): - self.override_serialization_version(v) - x = np.arange(12, dtype=np.float32).reshape((3, 4)) - def f_jax(x): # x: f32[b1, b2] - return 10. 
+ testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") - exp = get_exported(f_jax)(jax.ShapeDtypeStruct( - export.symbolic_shape("b2, b1"), x.dtype)) - mlir_module_str = str(exp.mlir_module()) - wrapped_main_expected_re = ( - r"@_wrapped_jax_export_main\(" - r"%arg0: tensor {jax.global_constant = \"b1\"}.*, " - r"%arg1: tensor {jax.global_constant = \"b2\"}.*, " - r"%arg2: !stablehlo.token {jax.token = true}.*, " - r"%arg3: tensor<\?x\?xf32>.*\) -> \(" - # Results - r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)") - if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") - self.assertRegex(mlir_module_str, wrapped_main_expected_re) - - if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - # The main function does not have tokens - self.assertNotRegex(mlir_module_str, r"@main.*token") - else: + with config.jax_export_calling_convention_version(v): + logging.info( + "Using JAX serialization version %s", + config.jax_export_calling_convention_version.value) + x = np.arange(12, dtype=np.float32).reshape((3, 4)) + def f_jax(x): # x: f32[b1, b2] + return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") + exp = get_exported(jax.jit(f_jax))(jax.ShapeDtypeStruct( + export.symbolic_shape("b2, b1"), x.dtype)) + mlir_module_str = str(exp.mlir_module()) + wrapped_main_expected_re = ( + r"@_wrapped_jax_export_main\(" + r"%arg0: tensor {jax.global_constant = \"b1\".* " + r"%arg1: tensor {jax.global_constant = \"b2\".* " + r"%arg2: !stablehlo.token {jax.token = true.* " + r"%arg3: tensor<\?x\?xf32>.*\) -> \(" + # Results + r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") + self.assertRegex(mlir_module_str, wrapped_main_expected_re) + main_expected_re = ( r"@main\(" - r"%arg0: !stablehlo.token {jax.token = true}.*, " + r"%arg0: !stablehlo.token {jax.token = true.*, " r"%arg1: tensor<\?x\?xf32>.*\) -> \(" # Results - r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)") + r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") self.assertRegex(mlir_module_str, main_expected_re) - res = export.call_exported(exp)(x) - self.assertAllClose(10. + 2. * x, res) + res = exp.call(x) + self.assertAllClose(10. + 2. * x, res) @jtu.parameterized_filterable( kwargs=[ dict(v=v) - for v in range(export.minimum_supported_serialization_version, - export.maximum_supported_serialization_version + 1)]) + for v in range(export.minimum_supported_calling_convention_version, + export.maximum_supported_calling_convention_version + 1)]) def test_ordered_effects_multi_platform_and_poly(self, *, v: int): - self.override_serialization_version(v) - if jtu.device_under_test() == "gpu": - # The export is not applicable to GPU - raise unittest.SkipTest("Not intended for running on GPU") - x = np.ones((3, 4), dtype=np.float32) - def f_jax(x): # x: f32[b1, b2] - return 10. 
+ _testing_multi_platform_func(x, - effect_class_name="ForTestingOrderedEffect1") - exp = get_exported( - f_jax, - lowering_platforms=("cpu", "tpu") - )(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype)) - mlir_module_str = str(exp.mlir_module()) - wrapped_main_expected_re = ( - r"@_wrapped_jax_export_main\(" - r"%arg0: tensor {jax.global_constant = \"_platform_index\"}.*, " - r"%arg1: tensor {jax.global_constant = \"b1\"}.*, " - r"%arg2: tensor {jax.global_constant = \"b2\"}.*, " - r"%arg3: !stablehlo.token {jax.token = true}.*, " - r"%arg4: tensor<\?x\?xf32>.*\) -> \(" - # Results - r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)") - if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") - self.assertRegex(mlir_module_str, wrapped_main_expected_re) - - if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - # The main function does not have tokens - self.assertNotRegex(mlir_module_str, r"@main.*token") - else: + with config.jax_export_calling_convention_version(v): + logging.info( + "Using JAX serialization version %s", + config.jax_export_calling_convention_version.value) + if jtu.device_under_test() == "gpu": + # The export is not applicable to GPU + raise unittest.SkipTest("Not intended for running on GPU") + x = np.ones((3, 4), dtype=np.float32) + def f_jax(x): # x: f32[b1, b2] + return 10. + _testing_multi_platform_func(x, + effect_class_name="ForTestingOrderedEffect1") + exp = get_exported( + jax.jit(f_jax), + lowering_platforms=("cpu", "tpu") + )(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype)) + mlir_module_str = str(exp.mlir_module()) + wrapped_main_expected_re = ( + r"@_wrapped_jax_export_main\(" + r"%arg0: tensor {jax.global_constant = \"_platform_index\".*, " + r"%arg1: tensor {jax.global_constant = \"b1\".*, " + r"%arg2: tensor {jax.global_constant = \"b2\".*, " + r"%arg3: !stablehlo.token {jax.token = true.*, " + r"%arg4: tensor<\?x\?xf32>.*\) -> \(" + # Results + r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") + self.assertRegex(mlir_module_str, wrapped_main_expected_re) + main_expected_re = ( r"@main\(" - r"%arg0: tensor {jax.global_constant = \"_platform_index\"}.*, " - r"%arg1: !stablehlo.token {jax.token = true}.*, " + r"%arg0: tensor {jax.global_constant = \"_platform_index\".*, " + r"%arg1: !stablehlo.token {jax.token = true.*, " r"%arg2: tensor<\?x\?xf32>.*\) -> \(" # Results - r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)") + r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") self.assertRegex(mlir_module_str, main_expected_re) - res = export.call_exported(exp)(x) - self.assertAllClose(10. + _testing_multi_platform_fun_expected(x), - res) + res = exp.call(x) + self.assertAllClose(10. 
+ _testing_multi_platform_fun_expected(x), + res) @jtu.parameterized_filterable( kwargs=[ dict(v=v) - for v in range(export.minimum_supported_serialization_version, - export.maximum_supported_serialization_version + 1)]) + for v in range(export.minimum_supported_calling_convention_version, + export.maximum_supported_calling_convention_version + 1)]) def test_ordered_effects_with_donation(self, *, v: int): - self.override_serialization_version(v) - x = np.arange(3, dtype=np.float32) + with config.jax_export_calling_convention_version(v): + logging.info( + "Using JAX serialization version %s", + config.jax_export_calling_convention_version.value) - def f_jax(x): - return testing_primitive_with_effect_p.bind( - x, effect_class_name="ForTestingOrderedEffect1" - ) + x = np.arange(3, dtype=np.float32) - f_jax = jax.jit(f_jax, donate_argnums=(0,)) - exp = export.export(f_jax)(x) - mlir_module_str = str(exp.mlir_module()) - if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 0") - self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1") - else: + def f_jax(x): + return testing_primitive_with_effect_p.bind( + x, effect_class_name="ForTestingOrderedEffect1" + ) + + f_jax = jax.jit(f_jax, donate_argnums=(0,)) + exp = export.export(f_jax)(x) + mlir_module_str = str(exp.mlir_module()) self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 1") self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1") @@ -1346,13 +1738,40 @@ def f_jax(x): ) ]) def test_ordered_effects_error(self, *, name: str, expect_error: str): + if not CAN_SERIALIZE: + # These errors arise during serialization + self.skipTest("serialization is disabled") x = np.ones((3, 4), dtype=np.float32) def f_jax(x): return 10. + _testing_multi_platform_func( x, effect_class_name="ForTestingOrderedEffect" + name) with self.assertRaisesRegex(Exception, expect_error): - _ = get_exported(f_jax)(jax.ShapeDtypeStruct((3, 4), x.dtype)) + _ = get_exported(jax.jit(f_jax))(jax.ShapeDtypeStruct((3, 4), x.dtype)) + + @jtu.parameterized_filterable( + kwargs=[ + {"m": 5, "k": 4, "n": 3, "group_sizes": [5]}, + {"m": 10, "k": 9, "n": 8, "group_sizes": [3, 7]}, + ]) + def test_ragged_dot(self, m, k, n, group_sizes): + def f_jax(x, y, gs): + return jax.lax.ragged_dot(x, y, gs) + dtype = np.float32 + group_sizes = np.array(group_sizes, dtype=np.int32) + lhs = np.arange(m * k, dtype=dtype).reshape((m, k)) + num_groups = group_sizes.shape[0] + rhs = np.arange(num_groups * k * n, dtype=dtype).reshape((num_groups, k, n)) + res_native = f_jax(lhs, rhs, group_sizes) + + exp_f = get_exported(jax.jit(f_jax))( + jax.ShapeDtypeStruct(lhs.shape, dtype=lhs.dtype), + jax.ShapeDtypeStruct(rhs.shape, dtype=rhs.dtype), + jax.ShapeDtypeStruct(group_sizes.shape, dtype=group_sizes.dtype), + ) + res_exported = exp_f.call(lhs, rhs, group_sizes) + self.assertAllClose(res_native, res_exported) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/extend_test.py b/tests/extend_test.py index 37f5c911d821..c419878c8c3a 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -12,19 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. 
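Aside: test_ragged_dot above round-trips jax.lax.ragged_dot through export, and the grouped-matmul semantics are easy to miss from the shapes alone. A NumPy reference sketch of the operation the test exercises, assuming group_sizes sums to m as in both test cases; the helper name is illustrative.

    import numpy as np

    def ragged_dot_reference(lhs, rhs, group_sizes):
      # lhs: (m, k), rhs: (g, k, n), group_sizes: (g,) with sum(group_sizes) == m.
      # Rows of lhs belonging to group i are contracted with rhs[i].
      m = lhs.shape[0]
      out = np.zeros((m, rhs.shape[-1]), dtype=lhs.dtype)
      start = 0
      for i, size in enumerate(group_sizes):
        out[start:start + size] = lhs[start:start + size] @ rhs[i]
        start += size
      return out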
-from absl.testing import absltest +import os + +import numpy as np +import unittest +from absl.testing import absltest, parameterized import jax +from jax import lax import jax.extend as jex import jax.numpy as jnp from jax._src import abstract_arrays +from jax._src import api +from jax._src import core from jax._src import linear_util from jax._src import prng from jax._src import test_util as jtu +from jax._src.interpreters import mlir +from jax._src.lib import xla_extension_version +from jax._src.lib.mlir import ir +from jax._src.extend import ffi -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ExtendTest(jtu.JaxTestCase): @@ -39,6 +49,7 @@ def test_symbols(self): self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl) # Assume these are tested elsewhere, only check equivalence + self.assertIs(jex.backend.clear_backends, api.clear_backends) self.assertIs(jex.core.array_types, abstract_arrays.array_types) self.assertIs(jex.linear_util.StoreException, linear_util.StoreException) self.assertIs(jex.linear_util.WrappedFun, linear_util.WrappedFun) @@ -81,5 +92,83 @@ def no_rule(*args, **kwargs): self.assertEqual(impl, jax.random.key_impl(k)) +class FfiTest(jtu.JaxTestCase): + + def testHeadersExist(self): + base_dir = os.path.join(jex.ffi.include_dir(), "xla", "ffi", "api") + for header in ["c_api.h", "api.h", "ffi.h"]: + self.assertTrue(os.path.exists(os.path.join(base_dir, header))) + + @parameterized.parameters( + [True, int(1), float(5.0), + np.int32(-5), np.float32(0.5)]) + def testIrAttribute(self, value): + with mlir.make_ir_context(), ir.Location.unknown(): + const = mlir.ir_constant(value) + attr = ffi._ir_attribute(value) + assert const.type.element_type == attr.type + + @parameterized.parameters([True, 1, 5.0, "param", np.float32(0.5)]) + def testParams(self, param): + prim = core.Primitive("test_ffi") + prim.def_abstract_eval(lambda *args, **kwargs: args[0]) + mlir.register_lowering(prim, jex.ffi.ffi_lowering("test_ffi")) + + # TODO(dfm): Currently testing that lowering works with different types of + # parameters, but we should probably actually check the emitted HLO. 
+ func = jax.jit(lambda *args: prim.bind(*args, param=param)) + func.lower(jnp.linspace(0, 5, 10)) + + @jtu.sample_product( + shape=[(1,), (4,), (5,)], + dtype=(np.int32,), + ) + @jtu.run_on_devices("gpu") + @unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31") + def testFfiCall(self, shape, dtype): + pivots_size = shape[-1] + permutation_size = 2 * pivots_size + pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype) + pivots = jnp.broadcast_to(pivots, shape) + expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size) + actual = ffi_call_lu_pivots_to_permutation(pivots, permutation_size) + self.assertArraysEqual(actual, expected) + + @jtu.sample_product( + shape=[(1,), (4,), (5,)], + dtype=(np.int32,), + vectorized=(False, True), + ) + @jtu.run_on_devices("gpu") + @unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31") + def testFfiCallBatching(self, shape, dtype, vectorized): + shape = (10,) + shape + pivots_size = shape[-1] + permutation_size = 2 * pivots_size + pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype) + pivots = jnp.broadcast_to(pivots, shape) + expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size) + actual = jax.vmap(lambda x: ffi_call_lu_pivots_to_permutation( + x, permutation_size, vectorized=vectorized))(pivots) + self.assertArraysEqual(actual, expected) + + +# TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation` +# custom call target because that's the only one in jaxlib that uses the +# new FFI interface. Once more are available, consider using something that +# can be run on multiple platforms. +def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True): + return jex.ffi.ffi_call( + "cu_lu_pivots_to_permutation", + jax.ShapeDtypeStruct( + shape=pivots.shape[:-1] + (permutation_size,), + dtype=pivots.dtype, + ), + pivots, + permutation_size=np.int32(permutation_size), + vectorized=vectorized, + ) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/for_loop_test.py b/tests/for_loop_test.py index 641f12ff0e1c..b79c233e6f2e 100644 --- a/tests/for_loop_test.py +++ b/tests/for_loop_test.py @@ -24,8 +24,7 @@ from jax._src.lax.control_flow import for_loop import jax.numpy as jnp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def remat_of_for_loop(nsteps, body, state, **kwargs): return jax.remat(lambda state: for_loop.for_loop(nsteps, body, state, @@ -143,7 +142,7 @@ def body(i, refs): key = jax.random.PRNGKey(0) x = jax.random.normal(key, (8,)) - np.testing.assert_allclose(cumsum(x), jnp.cumsum(x)) + np.testing.assert_allclose(cumsum(x), jnp.cumsum(x), rtol=1e-6) def for_body_swap(i, refs): a_ref, b_ref = refs @@ -297,7 +296,7 @@ def step(x, _): A = jnp.zeros((3, 3)) # The second DUS was unnecessarily replicating A across time. # We check XLA because _scan_impl is "underneath" the jaxpr language. 
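Aside, before the assertion below: the for_loop change just after this comment replaces the deprecated jax.xla_computation(...).as_hlo_text() with the ahead-of-time lowering API, so the HLO text now comes from the lowered object. A minimal sketch of that replacement with an illustrative loss function.

    import jax
    import jax.numpy as jnp

    def loss(A):
      return jnp.sum(A @ A)

    A = jnp.zeros((3, 3))
    lowered = jax.jit(jax.grad(loss)).lower(A)
    hlo_text = lowered.as_text("hlo")  # was: jax.xla_computation(jax.grad(loss))(A).as_hlo_text()
    compiled = lowered.compile()       # the same lowering can also be compiled ahead of time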
- s = str(jax.xla_computation(jax.grad(loss))(A).as_hlo_text()) + s = jax.jit(jax.grad(loss)).lower(A).as_text('hlo') assert s.count("dynamic-update-slice(") < 2 @_for_loop_impls diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index 802f23c56d41..9cf7ba80dff5 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -14,9 +14,10 @@ from functools import partial from absl.testing import absltest -from typing import Optional import os -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true' + +os.environ["XLA_FLAGS"] = \ + "--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true" import numpy as np import jax @@ -25,36 +26,57 @@ from jax.sharding import PartitionSpec, NamedSharding from jax._src import config from jax._src import test_util as jtu -from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention +from jax._src.cudnn.fused_attention_stablehlo import ( + dot_product_attention, + check_is_flash_attention, + check_cudnn_version, + check_compute_capability, + MaskType, + AttentionLayout, +) config.parse_flags_with_absl() Array = jnp.ndarray def sdpa_train(query: Array, - key: Array, - value: Array, - grad: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, - scale: float = 0.5, - is_causal_mask: bool = False, - dropout_rate: float = 0.1) -> Array: + key: Array, + value: Array, + grad: Array, + bias: Array | None = None, + mask: Array | None = None, + scale: float = 0.5, + mask_type: MaskType = MaskType.NO_MASK, + is_bnth: bool = False, + dropout_rate: float = 0.1) -> Array: if mask is not None: # convert bool mask to dtype mask mask = mask.astype(query.dtype) + if mask_type == MaskType.PADDING: + if is_bnth: + B, _, S, _ = query.shape + else: + B, S, _, _ = query.shape + q_seqlen = kv_seqlen = jnp.full((B,), S // 2, jnp.int32) + else: + q_seqlen = kv_seqlen = None out, sdpa_vjp = jax.vjp( - partial(dot_product_attention, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate), - query, key, value, bias, mask) - query_grad, key_grad, value_grad, _, _ = sdpa_vjp(grad) + partial(dot_product_attention, scale=scale, mask_type=mask_type, + dropout_rate=dropout_rate, + qkv_layout="BNTH" if is_bnth else "BTNH"), + query, key, value, bias, mask, q_seqlen, kv_seqlen) + query_grad, key_grad, value_grad, bias_grad, _, _, _ = sdpa_vjp(grad) + if bias is not None and len(bias.shape) == 3: + # has dbias + return out, (query_grad, key_grad, value_grad, bias_grad) return out, (query_grad, key_grad, value_grad) def sdpa_ref(query: Array, key: Array, value: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, + bias: Array | None = None, + mask: Array | None = None, scale: float = 0.5, - is_causal_mask: bool = False, + mask_type: MaskType = MaskType.NO_MASK, dropout_rate: float = 0.1) -> Array: def get_large_negative_number(input_t): @@ -64,77 +86,116 @@ def get_large_negative_number(input_t): elif jnp.issubdtype(dtype, jnp.integer): dtype_max = jnp.iinfo(dtype).max else: - raise ValueError('Unsupported dtype for inputs.') + raise ValueError("Unsupported dtype for inputs.") large_negative_number = jnp.asarray(-0.7 * dtype_max, dtype=dtype) return large_negative_number - def get_causal_mask(input_t): - large_negative_number = get_large_negative_number(input_t) - t = input_t.shape[2] + def get_causal_mask(logits): + large_negative_number = get_large_negative_number(logits) + t = 
logits.shape[-2] col_idx = jax.lax.broadcasted_iota(np.int32, (t, t), 1) row_idx = jax.lax.broadcasted_iota(np.int32, (t, t), 0) - mask = (row_idx < col_idx).astype(input_t.dtype) * large_negative_number - return mask[jnp.newaxis, jnp.newaxis, :, :] + mask = (row_idx < col_idx).astype(logits.dtype) * large_negative_number + return mask[(*([jnp.newaxis]*(len(logits.shape) - 2)), ...)] - attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) + def get_padding_mask(logits): + S, T = logits.shape[-2:] + # temp WAR as cuDNN has a bug for subtraction between two large negative value + large_negative_number = jnp.array(-2 << 40, dtype=logits.dtype) + q_padding = (jax.lax.iota(np.int32, S) >= S // 2).reshape((S, 1)) + kv_padding = (jax.lax.iota(np.int32, T) >= T // 2).reshape((1, T)) + combined_padding = \ + (q_padding + kv_padding).astype(logits.dtype) * large_negative_number + return jax.lax.broadcast(combined_padding, logits.shape[:-2]) + + def get_encoded_padding_mask(encoded): + S = encoded.shape[1] + encoded_padding = (jax.lax.iota(np.int32, S) < S // 2).astype(encoded.dtype) + return jax.lax.broadcast_in_dim( + encoded_padding, encoded.shape, broadcast_dimensions=[1]) + + B, T, qN, H = query.shape + _, _, kN, _ = key.shape + logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) if scale != 1.0: - attn_weights = attn_weights * scale - if is_causal_mask: - bias = get_causal_mask(attn_weights) + logits = logits * scale + if mask_type == MaskType.CAUSAL: + bias = get_causal_mask(logits) + elif mask_type == MaskType.PADDING: + bias = get_padding_mask(logits) if bias is not None: - attn_weights = attn_weights + bias.astype(attn_weights.dtype) + if bias.shape != logits.shape: + bias = jnp.broadcast_to(bias, logits.shape) + logits = logits + bias.astype(logits.dtype) if mask is not None: - large_negative_number = get_large_negative_number(attn_weights) - attn_weights = jax.lax.select(mask, attn_weights, jax.lax.broadcast(large_negative_number, attn_weights.shape)) - attn_weights = jax.nn.softmax(attn_weights) + large_negative_number = get_large_negative_number(logits) + logits = jax.lax.select( + mask, logits, jax.lax.broadcast(large_negative_number, logits.shape)) + probs = jax.nn.softmax(logits, axis=-1) if dropout_rate > 0.: keep_prob = 1.0 - dropout_rate dropout_rng = jax.random.key(0) - keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) - attn_weights = jax.lax.select(keep, attn_weights / keep_prob, jnp.zeros_like(attn_weights)) - - return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) + keep = jax.random.bernoulli(dropout_rng, keep_prob, probs.shape) + probs = jax.lax.select(keep, probs / keep_prob, jnp.zeros_like(probs)) + encoded = jnp.einsum("bhqk,bkhd->bqhd", probs, value) + if mask_type == MaskType.PADDING: + # cuDNN padding mask generation will mask out output accordingly + # make sure the behavior is the same + encoded_mask = get_encoded_padding_mask(encoded) + encoded = encoded * encoded_mask + return encoded def sdpa_train_ref(query: Array, key: Array, value: Array, grad: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, + bias: Array | None = None, + mask: Array | None = None, scale: float = 0.5, - is_causal_mask: bool = False, + mask_type: MaskType = MaskType.NO_MASK, dropout_rate: float = 0.1) -> Array: out_ref, sdpa_vjp_ref = jax.vjp( - partial(sdpa_ref, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate), + partial( + sdpa_ref, scale=scale, mask_type=mask_type, dropout_rate=dropout_rate), query, key, value, bias, 
mask) - query_grad_ref, key_grad_ref, value_grad_ref, _, _ = sdpa_vjp_ref(grad) + query_grad_ref, key_grad_ref, value_grad_ref, bias_grad_ref, _ = sdpa_vjp_ref(grad) + if bias is not None and len(bias.shape) == 3: + return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref, bias_grad_ref) return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) class DotProductAttentionTest(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if jax.device_count() < 4: + self.skipTest("Requires more than 4 devices.") + try: + cudnn_version = check_cudnn_version() + check_compute_capability((80, 90)) + except RuntimeError as e: + self.skipTest(str(e)) + return + if cudnn_version < 8904: + self.skipTest("Requires >= cuDNN 8.9.4") + @jtu.sample_product( batch_size=[4], - seq_len=[256, 1024], + seq_len=[1024], num_heads=[8], head_dim=[64, 128], use_bias=[False, True], - use_mask=[False, True], - is_causal_mask=[False], + mask_type=[MaskType.NO_MASK], dropout_rate=[0, 0.5], scale=[0.5], dtype=[jnp.float16, jnp.bfloat16] ) @jtu.run_on_devices("cuda") def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int, - head_dim: int, use_bias: bool, use_mask: bool, is_causal_mask: bool, + head_dim: int, use_bias: bool, mask_type: MaskType, dropout_rate: float, scale: float, dtype: jnp.dtype): - if seq_len == 256 and is_causal_mask: - self.skipTest("Fused attention does not support mask generation.") - if seq_len == 256 and head_dim == 128: - self.skipTest("Fused attention does not support head dim = 128.") if len(jax.local_devices()) <= 4: self.skipTest("Require at least 4 devices to run sharding tests.") - k1, k2, k3, k4, k5, k6 = jax.random.split(jax.random.key(0), 6) + k1, k2, k3, k4, k5 = jax.random.split(jax.random.key(0), 5) query = jax.random.normal( k1, (batch_size, seq_len, num_heads, head_dim), dtype=dtype) key = jax.random.normal( @@ -148,27 +209,22 @@ def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int, k5, (batch_size, num_heads, seq_len, seq_len), dtype=dtype) else: bias = None - if use_mask: - mask = jax.random.bernoulli( - k5, 0.5, (batch_size, num_heads, seq_len, seq_len)) - else: - mask = None + mask = None devices = np.array(jax.local_devices()[:4]) devices = devices.reshape((2, 2)) - with Mesh(devices, ('dp', 'tp')) as mesh: - qkv_spec = PartitionSpec('dp', None, 'tp', None) + with Mesh(devices, ("dp", "tp")) as mesh: + qkv_spec = PartitionSpec("dp", None, "tp", None) qkv_sharding = NamedSharding(mesh, qkv_spec) if bias is not None: - bias_spec = PartitionSpec('dp', 'tp', None, None) + bias_spec = PartitionSpec("dp", "tp", None, None) else: bias_spec = PartitionSpec() if mask is not None: - mask_spec = PartitionSpec('dp', 'tp', None, None) + mask_spec = PartitionSpec("dp", "tp", None, None) else: mask_spec = PartitionSpec() bias_sharding = NamedSharding(mesh, bias_spec) mask_sharding = NamedSharding(mesh, mask_spec) - replicated = NamedSharding(mesh, PartitionSpec()) query = jax.device_put(query, qkv_sharding) key = jax.device_put(key, qkv_sharding) value = jax.device_put(value, qkv_sharding) @@ -177,30 +233,196 @@ def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int, if mask is not None: mask = jax.device_put(mask, mask_sharding) grad = jax.device_put(grad, qkv_sharding) - in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding, mask_sharding) - out_shardings = (replicated, (qkv_sharding, qkv_sharding, qkv_sharding)) + in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, + qkv_sharding, bias_sharding, mask_sharding) + 
out_shardings = (qkv_sharding, (qkv_sharding, qkv_sharding, qkv_sharding)) jitted_sdpa_train = jax.jit( - partial(sdpa_train, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate), + partial( + sdpa_train, scale=scale, mask_type=mask_type, + dropout_rate=dropout_rate), in_shardings=in_shardings, out_shardings=out_shardings ) jitted_sdpa_train_ref = jax.jit( - partial(sdpa_train_ref, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate), + partial( + sdpa_train_ref, scale=scale, mask_type=mask_type, + dropout_rate=dropout_rate), in_shardings=in_shardings, out_shardings=out_shardings ) - out, (query_grad, key_grad, value_grad) = jitted_sdpa_train(query, key, value, grad, bias, mask) - out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = jitted_sdpa_train_ref(query, key, value, grad, bias, mask) + out, (query_grad, key_grad, value_grad) = \ + jitted_sdpa_train(query, key, value, grad, bias, mask) + out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ + jitted_sdpa_train_ref(query, key, value, grad, bias, mask) self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) if seq_len > 512: # query_grad in flash attention is not deterministic - self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) + self.assertArraysAllClose( + query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) else: - self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose( + query_grad_ref, query_grad, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose( + key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose( + value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + + @jtu.run_on_devices("cuda") + def test_sdpa_inference(self): + k1, k2, k3 = jax.random.split(jax.random.key(0), 3) + query = jax.random.normal( + k1, (4, 1024, 4, 64), dtype=jnp.bfloat16) + key = jax.random.normal( + k2, (4, 1024, 4, 64), dtype=jnp.bfloat16) + value = jax.random.normal( + k3, (4, 1024, 4, 64), dtype=jnp.bfloat16) + + devices = np.array(jax.local_devices()[:4]) + devices = devices.reshape((2, 2)) + with Mesh(devices, ("dp", "tp")) as mesh: + qkv_spec = PartitionSpec("dp", None, "tp", None) + qkv_sharding = NamedSharding(mesh, qkv_spec) + replicated = NamedSharding(mesh, PartitionSpec()) + in_shardings = ( + qkv_sharding, qkv_sharding, qkv_sharding, replicated, replicated) + out_shardings = qkv_sharding + query = jax.device_put(query, qkv_sharding) + key = jax.device_put(key, qkv_sharding) + value = jax.device_put(value, qkv_sharding) + jitted_sdpa_inference = jax.jit( + partial( + dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK, + dropout_rate=0), + in_shardings=in_shardings, + out_shardings=out_shardings + ) + + jitted_sdpa_inference_ref = jax.jit( + partial( + sdpa_ref, scale=1.0, mask_type=MaskType.NO_MASK, dropout_rate=0), + in_shardings=in_shardings, + out_shardings=out_shardings + ) + + out = jitted_sdpa_inference(query, key, value, None, None) + out_ref = jitted_sdpa_inference_ref(query, key, value, None, None) + self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) + + @jtu.run_on_devices("cuda") + def test_sdpa_var_seq(self): + k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) + query = jax.random.normal( + k1, (4, 1024, 4, 64), dtype=jnp.bfloat16) + key = jax.random.normal( + k2, (4, 1024, 4, 64), dtype=jnp.bfloat16) + value = 
jax.random.normal( + k3, (4, 1024, 4, 64), dtype=jnp.bfloat16) + grad = jax.random.normal( + k4, (4, 1024, 4, 64), dtype=jnp.bfloat16) + jitted_sdpa_train = jax.jit( + partial( + sdpa_train, scale=1.0, mask_type=MaskType.PADDING, dropout_rate=0), + ) + + jitted_sdpa_train_ref = jax.jit( + partial( + sdpa_train_ref, scale=1.0, mask_type=MaskType.PADDING, dropout_rate=0), + ) + + out, (query_grad, key_grad, value_grad) = \ + jitted_sdpa_train(query, key, value, grad, None, None) + out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ + jitted_sdpa_train_ref(query, key, value, grad, None, None) + self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) + self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + + @jtu.run_on_devices("cuda") + def test_sdpa_broadcast_bias_and_dbias(self): + try: + cudnn_version = check_cudnn_version() + check_compute_capability((90,)) + except RuntimeError as e: + self.skipTest(str(e)) + return + if cudnn_version < 8906: + self.skipTest("Requires >= cuDNN 8.9.6") + k1, k2, k3, k4, k5 = jax.random.split(jax.random.key(0), 5) + query = jax.random.normal( + k1, (2, 1024, 4, 64), dtype=jnp.bfloat16) + key = jax.random.normal( + k2, (2, 1024, 4, 64), dtype=jnp.bfloat16) + value = jax.random.normal( + k3, (2, 1024, 4, 64), dtype=jnp.bfloat16) + grad = jax.random.normal( + k4, (2, 1024, 4, 64), dtype=jnp.bfloat16) + bias = jax.random.normal( + k5, (4, 1024, 1024), dtype=jnp.bfloat16) + jitted_sdpa_train = jax.jit( + partial( + sdpa_train, scale=1.0, mask_type=MaskType.NO_MASK, dropout_rate=0), + ) + + jitted_sdpa_train_ref = jax.jit( + partial( + sdpa_train_ref, scale=1.0, mask_type=MaskType.NO_MASK, dropout_rate=0), + ) + + out, (query_grad, key_grad, value_grad, bias_grad) = \ + jitted_sdpa_train(query, key, value, grad, bias, None) + out_ref, (query_grad_ref, key_grad_ref, value_grad_ref, bias_grad_ref) = \ + jitted_sdpa_train_ref(query, key, value, grad, bias, None) + self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) + self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(bias_grad_ref, bias_grad, rtol=1e-5, atol=1e-5) + + @jtu.run_on_devices("cuda") + def test_layouts(self): + dtype = "bfloat16" + B, T, N, H = 4, 1024, 8, 128 + S = T + k0, k1, k2, k3 = jax.random.split(jax.random.key(123), 4) + query = jax.random.normal(k0, (B, T, N, H), dtype=dtype) + key = jax.random.normal(k1, (B, S, N, H), dtype=dtype) + value = jax.random.normal(k2, (B, S, N, H), dtype=dtype) + grad = jax.random.normal(k3, (B, T, N, H), dtype=dtype) + + btnh_fn = jax.jit(partial(sdpa_train_ref, scale=.5, + mask_type=MaskType.CAUSAL, dropout_rate=0.0)) + out_ref, (dq_ref, dk_ref, dv_ref) = btnh_fn(query, key, value, grad) + + def _cvt(x): + return jnp.einsum("BTNH->BNTH", x) + def _cvt_back(x): + return jnp.einsum("BNTH->BTNH", x) + bnth_fn = jax.jit(partial(sdpa_train, scale=.5, mask_type=MaskType.CAUSAL, + is_bnth=True, dropout_rate=0.0)) + out, (dq, dk, dv) = bnth_fn(_cvt(query), _cvt(key), _cvt(value), _cvt(grad)) + + self.assertArraysAllClose(out_ref, _cvt_back(out)) + self.assertArraysAllClose(dq_ref, _cvt_back(dq)) + self.assertArraysAllClose(dk_ref, _cvt_back(dk)) + 
self.assertArraysAllClose(dv_ref, _cvt_back(dv)) + + def test_sdpa_utils(self): + test_cases = [ + (1, 257, 64, 8905, False, True), + (1, 1024, 64, 8905, False, False), + (1024, 1024, 64, 8905, False, False), + (1024, 1024, 128, 8905, False, False), + ] + + for k in test_cases: + sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training = k + query = jnp.empty((4, sql_q, 4, head_dim)) + key = jnp.empty((4, sql_v, 4, head_dim)) + check_is_flash_attention( + query, key, AttentionLayout.BNTH, cudnn_version, has_bias, is_training) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/generated_fun_test.py b/tests/generated_fun_test.py index e96f100b46d4..a288e1a5f19a 100644 --- a/tests/generated_fun_test.py +++ b/tests/generated_fun_test.py @@ -22,11 +22,11 @@ import itertools as it import jax.numpy as jnp +import jax from jax import jit, jvp, vjp import jax._src.test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() npr.seed(0) diff --git a/tests/gpu_memory_flags_test.py b/tests/gpu_memory_flags_test.py index 21cdae2da1af..308fff257348 100644 --- a/tests/gpu_memory_flags_test.py +++ b/tests/gpu_memory_flags_test.py @@ -13,14 +13,12 @@ # limitations under the License. import os -import sys import unittest from absl.testing import absltest import jax from jax._src import config from jax._src import test_util as jtu -from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -28,15 +26,11 @@ class GpuMemoryAllocationTest(absltest.TestCase): # This test must be run in its own subprocess. - @unittest.skipIf( - "pytest" in sys.modules, - "Test must run in an isolated process", - ) + @jtu.skip_under_pytest("Test must run in an isolated process") @unittest.skipIf( "XLA_PYTHON_CLIENT_ALLOCATOR" in os.environ, "Test does not work if the python client allocator has been overriden", ) - @unittest.skipIf(xla_extension_version < 225, "jaxlib version too old") def test_gpu_memory_allocation(self): falsey_values = ("0", "False", "false") preallocate = ( diff --git a/tests/heap_profiler_test.py b/tests/heap_profiler_test.py index 6d3468e95ac7..240eec1c8fba 100644 --- a/tests/heap_profiler_test.py +++ b/tests/heap_profiler_test.py @@ -17,11 +17,10 @@ import jax import jax._src.xla_bridge as xla_bridge -from jax import config import jax._src.test_util as jtu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class HeapProfilerTest(unittest.TestCase): diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 0ac0933719e5..5988b2774408 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -14,7 +14,8 @@ from __future__ import annotations -from collections.abc import Sequence +import contextlib +from collections.abc import Callable, Sequence from functools import partial import itertools import logging @@ -22,7 +23,6 @@ import re import threading import time -from typing import Callable import unittest from unittest import skip, SkipTest @@ -30,7 +30,6 @@ import jax from jax import ad_checkpoint -from jax import config from jax import dtypes from jax import lax from jax import numpy as jnp @@ -42,11 +41,13 @@ from jax._src import test_util as jtu from jax._src.lib import xla_client +from jax.experimental.host_callback import _deprecated_id_print as hcb_id_print + xops = xla_client.ops import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class _TestingOutputStream: @@ -91,8 
+92,10 @@ def reset(self): def fun1(a): """Function used for several `id_tap` tests.""" - y = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream) - y = hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y) + y = hcb_id_print(a * 2., what="a * 2", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) + y = hcb_id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y, + callback_flavor=hcb.CallbackFlavor.DEBUG) return y ** 2 # Some computation to make the gradient interesting @@ -107,7 +110,7 @@ def maybe_print(do_print: bool, device_index: int = 0): """Conditionally print on testing_string""" if do_print: - return hcb.id_print( + return hcb_id_print( arg, what=what, output_stream=testing_stream, @@ -193,18 +196,13 @@ def helper_log_ir(name, logging.info(f"Optimized HLO[{name}]: {jax_optimized_hlo}") -prev_xla_flags = None - +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(2) - + _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - prev_xla_flags() + _exit_stack.close() def assertMultiDeviceOutputEqual(tst: jtu.JaxTestCase, @@ -233,15 +231,28 @@ def replace_device_name(m) -> str: return assertMultiLineStrippedEqual(tst, expected, what) +class HostCallbackImportsTest(jtu.JaxTestCase): + @jtu.ignore_warning( + category=DeprecationWarning, + message="The host_callback APIs are deprecated") + def test_deprecated_imports(self): + if hasattr(hcb, "id_print"): + id_print = hcb.id_print + self.assertIs(id_print, hcb_id_print) + class HostCallbackTapTest(jtu.JaxTestCase): def setUp(self): - super().setUp() + # skipping here skips teardown, so do this before super().setUp(). if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") if xla_bridge.using_pjrt_c_api(): raise SkipTest("host_callback not implemented in PJRT C API") - + super().setUp() + self.enter_context(jtu.ignore_warning( + category=DeprecationWarning, message="The host_callback APIs are deprecated")) + self.enter_context(jtu.ignore_warning( + category=DeprecationWarning, message="backend and device argument")) testing_stream.reset() testing_stream._test_method_name = self._testMethodName self.old_flags = os.getenv("XLA_FLAGS", "") @@ -253,6 +264,10 @@ def tearDown(self) -> None: hcb.barrier_wait("HostCallbackTapTest.tearDown") super().tearDown() + def supported_only_in_legacy_mode(self): + if not hcb._HOST_CALLBACK_LEGACY.value: + self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") + def test_tap_eval(self): self.assertAllClose((5. * 2.) ** 2, fun1(5.)) hcb.barrier_wait() @@ -264,7 +279,7 @@ def test_tap_eval(self): def test_tap_with_tuple_results(self): def func2(x): - x1, y1 = hcb.id_print((x * 2., x * 3.), output_stream=testing_stream) + x1, y1 = hcb_id_print((x * 2., x * 3.), output_stream=testing_stream) return x1 + y1 self.assertEqual(3. * (2. + 3.), func2(3.)) @@ -275,7 +290,7 @@ def func2(x): def test_tap_with_dict_results(self): def func2(x): - res = hcb.id_print(dict(a=x * 2., b=x * 3.), output_stream=testing_stream) + res = hcb_id_print(dict(a=x * 2., b=x * 3.), output_stream=testing_stream) return res["a"] + res["b"] self.assertEqual(3. * (2. 
+ 3.), func2(3.)) @@ -285,7 +300,7 @@ def func2(x): def test_tap_with_result(self): def func2(x): - x1 = hcb.id_print((x * 2., x * 3.), result=x * 4., + x1 = hcb_id_print((x * 2., x * 3.), result=x * 4., output_stream=testing_stream) return x1 @@ -320,8 +335,9 @@ def func2(x): testing_stream.output) def test_tap_with_device(self): + self.supported_only_in_legacy_mode() def func2(x): - x1 = hcb.id_print((x * 2., x * 3.), result=x * 4., + x1 = hcb_id_print((x * 2., x * 3.), result=x * 4., output_stream=testing_stream, tap_with_device=True) return x1 @@ -335,34 +351,46 @@ def func2(x): def test_tap_eval_exception(self): if not hcb._HOST_CALLBACK_OUTFEED.value: raise SkipTest("TODO: implement error handling for customcall") + # Simulate a tap error def tap_err(*args, **kwargs): raise ValueError("Some user message") def func(x): - x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream) + x1 = hcb_id_print(x + 1, what="x1", output_stream=testing_stream) x2 = hcb.id_tap(tap_err, x1 + 1) - x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream) + x3 = hcb_id_print(x2 + 1, what="x3", output_stream=testing_stream) return x3 - with self.assertRaisesRegex( - hcb.CallbackException, - re.compile("There were exceptions during callback processing. Last one was:.*" - "ValueError: Some user message", re.DOTALL)): + if hcb._HOST_CALLBACK_LEGACY.value: + ctx = self.assertRaisesRegex( + hcb.CallbackException, + re.compile("There were exceptions during callback processing. Last one was:.*" + "ValueError: Some user message", re.DOTALL)) + else: + ctx = self.assertRaisesRegex(Exception, "Some user message") + + with ctx: func(0) hcb.barrier_wait() - # We should have received everything before the error - assertMultiLineStrippedEqual(self, """ - what: x1 - 1 - what: x3 - 3""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + # We should have received everything before the error + assertMultiLineStrippedEqual(self, """ + what: x1 + 1 + what: x3 + 3""", testing_stream.output) + else: + # We should have received everything before the error + assertMultiLineStrippedEqual(self, """ + what: x1 + 1""", testing_stream.output) def test_tap_empty(self): """Tap empty arrays.""" - hcb.id_print((), output_stream=testing_stream) - hcb.id_print((1., np.ones((2, 0))), what="second", output_stream=testing_stream) + hcb_id_print((), output_stream=testing_stream) + hcb_id_print((1., np.ones((2, 0))), what="second", output_stream=testing_stream) hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ ( ) @@ -370,7 +398,7 @@ def test_tap_empty(self): ( 1.00 [] )""", testing_stream.output) def test_tap_jit_simple(self): - jit_fun1 = jax.jit(lambda x: 3. * hcb.id_print( + jit_fun1 = jax.jit(lambda x: 3. * hcb_id_print( 2. * x, what="here", output_stream=testing_stream)) self.assertAllClose(6. 
* 5., jit_fun1(5.)) hcb.barrier_wait() @@ -380,7 +408,7 @@ def test_tap_jit_simple(self): def test_tap_jit_no_invars(self): def func(): # jitted function does not take arguments - return hcb.id_print(42, output_stream=testing_stream) + return hcb_id_print(42, output_stream=testing_stream) self.assertAllClose(42, jax.jit(func)()) hcb.barrier_wait() @@ -389,7 +417,7 @@ def func(): # jitted function does not take arguments def test_tap_jit_multiple_invars(self): def func(x1, x2): - return hcb.id_print(x1 + x2, output_stream=testing_stream) + return hcb_id_print(x1 + x2, output_stream=testing_stream) self.assertAllClose(42, jax.jit(func)(40, 2)) hcb.barrier_wait() @@ -398,7 +426,7 @@ def func(x1, x2): def test_tap_jit_constant(self): def func(x): - return hcb.id_print(42, result=x, output_stream=testing_stream) + return hcb_id_print(42, result=x, output_stream=testing_stream) self.assertAllClose(5, jax.jit(func)(5)) hcb.barrier_wait() @@ -407,13 +435,17 @@ def func(x): def test_tap_jit_sequence1(self): def func(x): - x1 = hcb.id_print(x, where="1", output_stream=testing_stream) - return hcb.id_print(x1 + 1, where="2", output_stream=testing_stream) + x1 = hcb_id_print(x, where="1", output_stream=testing_stream) + return hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) logging.info("%s: %s", self._testMethodName, jax.make_jaxpr(func)(1)) - logging.info("%s: %s", self._testMethodName, - jax.xla_computation(func, backend=jtu.device_under_test())(1).as_hlo_text()) + logging.info( + "%s: %s", + self._testMethodName, + jax.jit(func) + .trace(1) + .lower(lowering_platforms=(jtu.device_under_test(),)).as_text("hlo")) self.assertEqual(2, jax.jit(func)(1)) hcb.barrier_wait() @@ -427,8 +459,8 @@ def test_tap_jit2(self): """A sequence of JIT.""" def func(x): - x1 = hcb.id_print(x, where="1", output_stream=testing_stream) - x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream) + x1 = hcb_id_print(x, where="1", output_stream=testing_stream) + x2 = hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) return x2 self.assertEqual(2, jax.jit(func)(1)) @@ -448,8 +480,8 @@ def test_tap_jit_result_unused(self): """We can id_print even if we don't use the result.""" def func(x): - hcb.id_print(x, where="1", output_stream=testing_stream) - hcb.id_print(x + 1, where="2", output_stream=testing_stream) + hcb_id_print(x, where="1", output_stream=testing_stream) + hcb_id_print(x + 1, where="2", output_stream=testing_stream) return x + 1 self.assertEqual(2, jax.jit(func)(1)) @@ -467,14 +499,14 @@ def func(x): def test_tap_jit_nested(self): def func(x): - x1 = hcb.id_print(x, where="1", output_stream=testing_stream) + x1 = hcb_id_print(x, where="1", output_stream=testing_stream) def func_nested(x): - x2 = hcb.id_print(x + 1, where="nested", output_stream=testing_stream) + x2 = hcb_id_print(x + 1, where="nested", output_stream=testing_stream) return x2 x3 = jax.jit(func_nested)(x1) - return hcb.id_print(x3 + 1, where="3", output_stream=testing_stream) + return hcb_id_print(x3 + 1, where="3", output_stream=testing_stream) self.assertEqual(3, jax.jit(func)(1)) hcb.barrier_wait() @@ -488,11 +520,12 @@ def func_nested(x): def test_tap_jit_devices(self): """Running on multiple devices.""" + self.supported_only_in_legacy_mode() logging.info("%s: has devices %s", self._testMethodName, local_devices()) def func(x, device_id): - x1 = hcb.id_print(x, dev=str(device_id), output_stream=testing_stream) - x2 = hcb.id_print(x1 + 1, dev=str(device_id), output_stream=testing_stream) + x1 = hcb_id_print(x, 
dev=str(device_id), output_stream=testing_stream) + x2 = hcb_id_print(x1 + 1, dev=str(device_id), output_stream=testing_stream) return x2 for d in local_devices(): @@ -614,16 +647,16 @@ def test_tap_cond(self, with_jit=False): """A conditional""" def func(x): - x1 = hcb.id_print(x, where="1", output_stream=testing_stream) - x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream) + x1 = hcb_id_print(x, where="1", output_stream=testing_stream) + x2 = hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) x4 = lax.cond(x % 2 == 0, - lambda x: hcb.id_print(x, where="cond_t", + lambda x: hcb_id_print(x, where="cond_t", output_stream=testing_stream), - lambda x: hcb.id_print(-1, where="cond_f", result=x, + lambda x: hcb_id_print(-1, where="cond_f", result=x, output_stream=testing_stream), x2 + 1) - x5 = hcb.id_print(x4 + 1, where="end", output_stream=testing_stream) + x5 = hcb_id_print(x4 + 1, where="end", output_stream=testing_stream) return x5 transform = jax.jit if with_jit else lambda f: f @@ -642,21 +675,21 @@ def func(x): @jtu.sample_product(with_jit=[True, False]) def test_tap_while_cond(self, with_jit=False): def func(x): - x1 = hcb.id_print(x, where="1", output_stream=testing_stream) - x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream) + x1 = hcb_id_print(x, where="1", output_stream=testing_stream) + x2 = hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) def body(x): - x3 = hcb.id_print(x, where="w_b_1", output_stream=testing_stream) + x3 = hcb_id_print(x, where="w_b_1", output_stream=testing_stream) x4 = lax.cond(x % 2 == 0, - lambda x: hcb.id_print(x, where="w_b_t", + lambda x: hcb_id_print(x, where="w_b_t", output_stream=testing_stream), - lambda x: hcb.id_print(-1, where="w_b_f", + lambda x: hcb_id_print(-1, where="w_b_f", result=x, output_stream=testing_stream), x3 + 1) - return hcb.id_print(x4, where="w_b_2", output_stream=testing_stream) + return hcb_id_print(x4, where="w_b_2", output_stream=testing_stream) x10 = lax.while_loop(lambda x: x <= 3, body, x2) - res = hcb.id_print(x10, where="end", output_stream=testing_stream) + res = hcb_id_print(x10, where="end", output_stream=testing_stream) return res transform = jax.jit if with_jit else lambda f: f @@ -686,14 +719,14 @@ def test_tap_jit_while_pred_tap(self): """While with printing in the conditional.""" def func(x): - x1 = hcb.id_print(x, where="1") - x10 = lax.while_loop(lambda x: hcb.id_print(x < 3, + x1 = hcb_id_print(x, where="1") + x10 = lax.while_loop(lambda x: hcb_id_print(x < 3, where="w_p", output_stream=testing_stream), - lambda x: hcb.id_print(x + 1, where="w_b", + lambda x: hcb_id_print(x + 1, where="w_b", output_stream=testing_stream), x1) - res = hcb.id_print(x10, where="3", output_stream=testing_stream) + res = hcb_id_print(x10, where="3", output_stream=testing_stream) return res self.assertEqual(3, jax.jit(func)(1)) @@ -716,19 +749,19 @@ def func(x): @jtu.sample_product(with_jit=[True, False]) def test_tap_scan_cond(self, with_jit=True): def func(x): - x1 = hcb.id_print(x, where="1", output_stream=testing_stream) - x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream) + x1 = hcb_id_print(x, where="1", output_stream=testing_stream) + x2 = hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) def body(c, x): - x3 = hcb.id_print(x, where="s_1", output_stream=testing_stream) + x3 = hcb_id_print(x, where="s_1", output_stream=testing_stream) x4 = lax.cond(x % 2 == 0, - lambda x: hcb.id_print(x, where="s_t", output_stream=testing_stream), - lambda x: 
hcb.id_print(-1, where="s_f", result=x, output_stream=testing_stream), + lambda x: hcb_id_print(x, where="s_t", output_stream=testing_stream), + lambda x: hcb_id_print(-1, where="s_f", result=x, output_stream=testing_stream), x3 + 1) - return (c, hcb.id_print(x4, where="s_2", output_stream=testing_stream)) + return (c, hcb_id_print(x4, where="s_2", output_stream=testing_stream)) _, x10 = lax.scan(body, x2, jnp.arange(3)) - res = hcb.id_print(x10, where="10", output_stream=testing_stream) + res = hcb_id_print(x10, where="10", output_stream=testing_stream) return res if with_jit: @@ -777,7 +810,7 @@ def test_tap_jit_dtypes(self, nr_args=2, dtype=jnp.int16, shape=(2,)): args = [jnp.arange(np.prod(shape), dtype=dtype).reshape(shape)] if nr_args > 1: args = args * nr_args - jit_fun1 = jax.jit(lambda xs: hcb.id_print( + jit_fun1 = jax.jit(lambda xs: hcb_id_print( xs, a_new_test="************", testcase_name=f"{shape=}_{dtype=}_{nr_args=}")) @@ -787,11 +820,11 @@ def test_tap_jit_dtypes(self, nr_args=2, dtype=jnp.int16, shape=(2,)): def test_tap_jit_large(self): arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1)) - jax.jit(hcb.id_print)(arg) + jax.jit(hcb_id_print)(arg) def test_tap_jit_several_together(self): arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5)) - jax.jit(lambda x, y: hcb.id_print((x, y, x * 2)))(arg, jnp.ones(100, dtype=jnp.int32)) + jax.jit(lambda x, y: hcb_id_print((x, y, x * 2)))(arg, jnp.ones(100, dtype=jnp.int32)) def test_tap_jit_interleaving(self): # Several jit's without data dependencies; they may interfere @@ -825,24 +858,29 @@ def tap_err(*args, **kwargs): raise NotImplementedError def func(x): - x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream) + x1 = hcb_id_print(x + 1, what="x1", output_stream=testing_stream) x2 = hcb.id_tap(tap_err, x1 + 1) - x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream) + x3 = hcb_id_print(x2 + 1, what="x3", output_stream=testing_stream) return x3 - res = jax.jit(func)(0) # No error yet - with self.assertRaises(hcb.CallbackException): - hcb.barrier_wait() - - # Even though the receiver thread raised, the main thread should still - # return 3. - self.assertEqual(3, res) - # We should have received all others - assertMultiLineStrippedEqual(self, """ - what: x1 - 1 - what: x3 - 3""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + res = jax.jit(func)(0) # No error yet + with self.assertRaises(hcb.CallbackException): + hcb.barrier_wait() + + # Even though the receiver thread raised, the main thread should still + # return 3. + self.assertEqual(3, res) + # We should have received all others + assertMultiLineStrippedEqual(self, """ + what: x1 + 1 + what: x3 + 3""", testing_stream.output) + else: + with self.assertRaisesRegex(Exception, "NotImplementedError"): + res = jax.jit(func)(0) + hcb.barrier_wait() def test_tap_while(self): """Executing while, even without JIT uses compiled code""" @@ -851,7 +889,7 @@ def test_tap_while(self): def func(x): return lax.while_loop( lambda c: c[1] < 5, - lambda c: (y, hcb.id_print(c[1], output_stream=testing_stream) + 1), + lambda c: (y, hcb_id_print(c[1], output_stream=testing_stream) + 1), (x, 1)) func(y) @@ -877,8 +915,9 @@ def test_tap_jvp(self): def test_tap_grad_primal_unused(self): # The output of id_print is not needed for backwards pass def func(x): - return 2. * hcb.id_print(x * 3., what="x * 3", - output_stream=testing_stream) + return 2. 
* hcb_id_print(x * 3., what="x * 3", + output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) grad_func = jax.grad(func) arg = jnp.float32(5.) @@ -886,21 +925,22 @@ def func(x): # making the Jaxpr does not print anything hcb.barrier_wait() - treedef = jax.tree.structure(arg) - assertMultiLineStrippedEqual( - self, f""" - {{ lambda ; a:f32[]. let - b:f32[] = mul a 3.00 - c:f32[] = outside_call[ - arg_treedef={treedef} - callback=... - device_index=0 - identity=True - ] b - _:f32[] = mul 2.00 c - d:f32[] = mul 2.00 1.00 - e:f32[] = mul d 3.00 - in (e,) }}""", jaxpr) + if hcb._HOST_CALLBACK_LEGACY.value: + treedef = jax.tree.structure(arg) + assertMultiLineStrippedEqual( + self, f""" + {{ lambda ; a:f32[]. let + b:f32[] = mul a 3.00 + c:f32[] = outside_call[ + arg_treedef={treedef} + callback=... + device_index=0 + identity=True + ] b + _:f32[] = mul 2.00 c + d:f32[] = mul 2.00 1.00 + e:f32[] = mul d 3.00 + in (e,) }}""", jaxpr) assertMultiLineStrippedEqual(self, "", testing_stream.output) testing_stream.reset() @@ -914,9 +954,11 @@ def func(x): def test_tap_grad_simple(self): def func(x): - y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream) - return x * hcb.id_print(y * 3., what="y * 3", - output_stream=testing_stream) + y = hcb_id_print(x * 2., what="x * 2", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) + return x * hcb_id_print(y * 3., what="y * 3", + output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) grad_func = jax.grad(func) @@ -931,7 +973,8 @@ def func(x): def test_tap_grad_grad(self): def func(x): - y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream) + y = hcb_id_print(x * 2., what="x * 2", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return x * (y * 3.) grad_func = jax.grad(jax.grad(func)) @@ -950,9 +993,10 @@ def func(x): def test_tap_grad_pytree(self): def func(x): - x4, x5 = hcb.id_print((x * 2., x * 3.), what="pair", + x4, x5 = hcb_id_print((x * 2., x * 3.), what="pair", result=(x * 4., x * 5.), - output_stream=testing_stream) + output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return x4 + 2. * x5 x = jnp.float32(5.) @@ -967,15 +1011,18 @@ def func(x): def test_tap_jvp_float0(self): def f(x, yint): - x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint)) + x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint), + callback_flavor=hcb.CallbackFlavor.DEBUG) return x * yint res = jax.jvp(f, (2., 3), (0.2, np.zeros((), dtypes.float0))) self.assertAllClose((6., 0.6), res) def test_tap_grad_float0(self): + def func(x, yint): - x, yint = hcb.id_print((x, yint), what="pair", output_stream=testing_stream) + x, yint = hcb_id_print((x, yint), what="pair", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return x * yint.astype(x.dtype) grad_func = jax.grad(func) @@ -993,7 +1040,8 @@ def test_tap_grad_float0_result(self): x = (np.array([.7, .8], dtype=np.float32), np.array([11, 12, 13], dtype=np.int32)) def f_jax(x): - x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important + x = hcb_id_print(x, result=x, output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important return (3. 
* x[0], x[1]) def f_jax_vjp(x): @@ -1015,7 +1063,8 @@ def test_tap_higher_order_grad_float0_result(self): x = (np.array([.7, .8], dtype=np.float32), np.array([11, 12, 13], dtype=np.int32)) def f_jax(x): - x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important + x = hcb_id_print(x, result=x, output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important return (jnp.sin(x[0]), x[1]) def wrap_vjp(f, args, res_f_of_args): @@ -1059,32 +1108,52 @@ def test_tap_vmap(self): vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) vmap_fun1(vargs) hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (0,)})] what: a * 2 - [ 8.00 10.00] - transforms: [('batch', {'batch_dims': (0,)})] what: y * 3 - [24.00 30.00]""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + assertMultiLineStrippedEqual(self, """ + transforms: [('batch', {'batch_dims': (0,)})] what: a * 2 + [ 8.00 10.00] + transforms: [('batch', {'batch_dims': (0,)})] what: y * 3 + [24.00 30.00]""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + what: a * 2 + 8.00 + what: a * 2 + 10.00 + what: y * 3 + 24.00 + what: y * 3 + 30.00 + """, testing_stream.output) def test_tap_vmap_not_batched(self): x = 3. def func(y): # x is not mapped, y is mapped - _, y = hcb.id_print((x, y), output_stream=testing_stream) + _, y = hcb_id_print((x, y), output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return x + y vmap_func = jax.vmap(func) vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) _ = vmap_func(vargs) hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (None, 0)})] - ( 3.00 [4.00 5.00] )""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + assertMultiLineStrippedEqual(self, """ + transforms: [('batch', {'batch_dims': (None, 0)})] + ( 3.00 [4.00 5.00] )""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + ( 3.00 4.00 ) + ( 3.00 5.00 ) + """, testing_stream.output) def test_tap_vmap_vmap(self): # A 2D tensor with x[i, j] = i + j using 2 vmap def sum(x, y): - return hcb.id_print(x + y, output_stream=testing_stream) + return hcb_id_print(x + y, output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) def sum_rows(xv, y): return jax.vmap(sum, in_axes=(0, None))(xv, y) @@ -1097,22 +1166,44 @@ def sum_all(xv, yv): # assertMultiLineStrippedEqual(self, "", str(jax.make_jaxpr(sum_all)(xv, yv))) _ = sum_all(xv, yv) hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})] - [[0 1 2 3 4] - [1 2 3 4 5] - [2 3 4 5 6]]""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + assertMultiLineStrippedEqual(self, """ + transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})] + [[0 1 2 3 4] + [1 2 3 4 5] + [2 3 4 5 6]]""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + 0 + 1 + 2 + 1 + 2 + 3 + 2 + 3 + 4 + 3 + 4 + 5 + 4 + 5 + 6 + """, testing_stream.output) def test_tap_vmap_while(self): """Vmap of while.""" def func(x): # like max(x, 2) - x1 = hcb.id_print(x, where="before:x", output_stream=testing_stream) + x1 = hcb_id_print(x, where="before:x", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) x2 = lax.while_loop( - lambda x: x < 2, lambda x: hcb.id_print( - x + 1, where="body:x+1", output_stream=testing_stream), x1) - res = 
hcb.id_print(x2, where="after:x", output_stream=testing_stream) + lambda x: x < 2, lambda x: hcb_id_print( + x + 1, where="body:x+1", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG), x1) + res = hcb_id_print(x2, where="after:x", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return res inputs = np.arange(5, dtype=np.int32) @@ -1121,72 +1212,93 @@ def func(x): jax.jit(jax.vmap(func))(inputs), check_dtypes=False) hcb.barrier_wait() - assertMultiLineStrippedEqual( - self, """ - transforms: [('batch', {'batch_dims': (0,)})] where: before:x - [0 1 2 3 4] - transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 - [1 2 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 - [2 3 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: after:x - [2 2 2 3 4]""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + assertMultiLineStrippedEqual( + self, """ + transforms: [('batch', {'batch_dims': (0,)})] where: before:x + [0 1 2 3 4] + transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 + [1 2 3 4 5] + transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 + [2 3 3 4 5] + transforms: [('batch', {'batch_dims': (0,)})] where: after:x + [2 2 2 3 4]""", testing_stream.output) + else: + pass # order of vmaps is not guaranteed def test_tap_vmap_while_tap_cond(self): """Vmap of while, with a tap in the conditional.""" def func(x): # like max(x, 2) - x1 = hcb.id_print(x, where="1", output_stream=testing_stream) - x2 = lax.while_loop(lambda x: hcb.id_print(x < 2, where="w_c", - output_stream=testing_stream), - lambda x: hcb.id_print(x + 1, where="w_b", - output_stream=testing_stream), + x1 = hcb_id_print(x, where="1", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) + x2 = lax.while_loop(lambda x: hcb_id_print(x < 2, where="w_c", + output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG), + lambda x: hcb_id_print(x + 1, where="w_b", + output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG), x1) - res = hcb.id_print(x2, where="3", output_stream=testing_stream) + res = hcb_id_print(x2, where="3", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return res inputs = np.arange(5, dtype=np.int32) res = jax.jit(jax.vmap(func))(inputs) hcb.barrier_wait() self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False) - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (0,)})] where: 1 - [0 1 2 3 4] - transforms: [('batch', {'batch_dims': (0,)})] where: w_c - [ True True False False False] - transforms: [('batch', {'batch_dims': (0,)})] where: w_b - [1 2 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: w_c - [ True False False False False] - transforms: [('batch', {'batch_dims': (0,)})] where: w_b - [2 3 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: w_c - [False False False False False] - transforms: [('batch', {'batch_dims': (0,)})] where: 3 - [2 2 2 3 4]""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + assertMultiLineStrippedEqual(self, """ + transforms: [('batch', {'batch_dims': (0,)})] where: 1 + [0 1 2 3 4] + transforms: [('batch', {'batch_dims': (0,)})] where: w_c + [ True True False False False] + transforms: [('batch', {'batch_dims': (0,)})] where: w_b + [1 2 3 4 5] + transforms: [('batch', {'batch_dims': (0,)})] where: w_c + [ True False False False False] + transforms: [('batch', {'batch_dims': (0,)})] where: w_b + [2 3 3 4 5] + 
transforms: [('batch', {'batch_dims': (0,)})] where: w_c + [False False False False False] + transforms: [('batch', {'batch_dims': (0,)})] where: 3 + [2 2 2 3 4]""", testing_stream.output) + else: + pass # order of vmap is not guaranteed def test_tap_transforms_doc(self): # Examples from the documentation def power3(x): y = x * x # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) + hcb_id_print((x, y), what="x,x^2", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return y * x print(f"impl = {power3(3.)}") hcb.barrier_wait() - expected = """ - what: x,x^2 - ( 3. 9. )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + what: x,x^2 + ( 3. 9. )""" + else: + expected = """ + what: x,x^2 + ( 3.0 9.0 )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}") hcb.barrier_wait() - expected = """ - what: x,x^2 - ( 3. 9. )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + what: x,x^2 + ( 3. 9. )""" + else: + expected = """ + what: x,x^2 + ( 3.0 9.0 )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() @@ -1197,32 +1309,41 @@ def print_tangents(arg): @print_tangents.defjvp def print_tangents_jvp(primals, tangents): arg_dot, = tangents - hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream) + hcb_id_print(arg_dot, what="tangents", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return primals, tangents def power3_with_tangents(x): y = x * x # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) + hcb_id_print((x, y), what="x,x^2", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) print_tangents((x, y)) return y * x print(f"jvp = {jax.jvp(power3_with_tangents, (3.,), (0.1,))}") hcb.barrier_wait() - expected = """ - what: x,x^2 - ( 3. 9. ) - what: tangents - ( 0.1 0.6 )""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + what: x,x^2 + ( 3. 9. ) + what: tangents + ( 0.1 0.6 )""" + self.assertMultiLineStrippedEqual(expected, testing_stream.output) + testing_stream.reset() print(f"grad = {jax.grad(power3)(3.)}") hcb.barrier_wait() # Only the primals by default - expected = """ - what: x,x^2 - ( 3. 9. )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + what: x,x^2 + ( 3. 9. )""" + else: + expected = """ + what: x,x^2 + ( 3.0 9.0 )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() @@ -1236,7 +1357,8 @@ def print_cotangents_fwd(arg): return print_cotangents(arg), None # f_bwd: (residual, CT b) -> [CT a] def print_cotangents_bwd(residual, ct_b): - hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream) + hcb_id_print(ct_b, what="cotangents", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return ct_b, print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd) @@ -1244,18 +1366,26 @@ def print_cotangents_bwd(residual, ct_b): def power3_with_cotangents(x): y = x * x # Print both 'x' and 'x^2'. Must pack as a tuple. 
- hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) + hcb_id_print((x, y), what="x,x^2", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) # Must use the output of print_cotangents (x1, y1) = print_cotangents((x, y)) return y1 * x1 print(f"grad = {jax.grad(power3_with_cotangents)(3.)}") hcb.barrier_wait() - expected = """ - what: x,x^2 - ( 3. 9. ) - what: cotangents - ( 9. 3. )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + what: x,x^2 + ( 3. 9. ) + what: cotangents + ( 9. 3. )""" + else: + expected = """ + what: x,x^2 + ( 3.0 9.0 ) + what: cotangents + ( 9.0 3.0 )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() @@ -1263,50 +1393,89 @@ def power3_with_cotangents(x): print(f"vmap = {jax.vmap(power3)(np.array([2., 3.]))}") hcb.barrier_wait() - expected = """ - transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [2. 3.] [4. 9.] )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 + ( [2. 3.] [4. 9.] )""" + else: + expected = """ + what: x,x^2 + ( 2.0 4.0 ) + what: x,x^2 + ( 3.0 9.0 ) + """ self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}") hcb.barrier_wait() - expected = """ - transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [2. 3.] [4. 9.] )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 + ( [2. 3.] [4. 9.] )""" + else: + expected = """ + what: x,x^2 + ( 2.0 4.0 ) + what: x,x^2 + ( 3.0 9.0 ) + """ self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() print(f"vmap o grad {jax.vmap(jax.grad(power3_with_cotangents))(np.array([2., 3.]))}") hcb.barrier_wait() - expected = """ - transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [2. 3.] [4. 9.] ) - transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents - ( [4. 9.] [2. 3.] )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 + ( [2. 3.] [4. 9.] ) + transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents + ( [4. 9.] [2. 3.] )""" + else: + expected = """ + what: x,x^2 + ( 2.0 4.0 ) + what: x,x^2 + ( 3.0 9.0 ) + what: cotangents + ( 4.0 2.0 ) + what: cotangents + ( 9.0 3.0 ) + """ self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() print(f"grad o remat = {jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)}") hcb.barrier_wait() - expected = """ - what: x,x^2 - ( 3. 9. ) - what: x,x^2 - ( 27. 729. ) - what: x,x^2 - ( 3. 9. )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + what: x,x^2 + ( 3. 9. ) + what: x,x^2 + ( 27. 729. ) + what: x,x^2 + ( 3. 9. )""" + else: + expected = """ + what: x,x^2 + ( 3.0 9.0 ) + what: x,x^2 + ( 27.0 729.0 ) + what: x,x^2 + ( 3.0 9.0 ) + """ self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() def test_tap_pmap(self): + self.supported_only_in_legacy_mode() if len(local_devices()) < 2: raise SkipTest("test requires at least 2 devices") def power3(x): y = x * x # Print both 'x' and 'x^2'. Must pack as a tuple. 
- _, y = hcb.id_print((x, y), + _, y = hcb_id_print((x, y), what="x,x^2", output_stream=testing_stream, tap_with_device=True) @@ -1326,6 +1495,7 @@ def power3(x): ( 4 16 )""") def test_tap_pmap_vmap(self): + self.supported_only_in_legacy_mode() # A matrix M[ij] = i * 10 + j nr_devices = len(local_devices()) shape = (nr_devices, 3) @@ -1353,6 +1523,7 @@ def fun1(x, do_print=False): # x: i32 def test_tap_pmap_pmap_vmap(self): # A matrix M[ijk] = i * 100 + j * 10 + k + self.supported_only_in_legacy_mode() nr_devices = len(local_devices()) if nr_devices % 2 != 0: raise SkipTest("test works only on even number of devices") @@ -1386,6 +1557,7 @@ def fun1(x, do_print=False): # x: f32 def test_tap_pmap_pmap_extra(self): """pmap of a pmap surrounded by extra code.""" # A matrix M[ij] = i * 10 + j + self.supported_only_in_legacy_mode() nr_devices = len(local_devices()) if nr_devices != 2: raise SkipTest("test works only on 2 devices") @@ -1419,6 +1591,7 @@ def fun(xv, do_print=False): [[203.00 205.00 207.00]]""") def test_tap_jvp_pmap_vmap(self): + self.supported_only_in_legacy_mode() # A matrix M[ijk] = i * 100 + j * 10 * k nr_devices = len(local_devices()) shape = (nr_devices, 2, 3) @@ -1445,6 +1618,7 @@ def fun(xv, do_print=False): [220.00 222.00 224.00]]""") def test_tap_vmap_pmap(self): + self.supported_only_in_legacy_mode() # A matrix M[ijk] = i * 100 + j * 10 * k nr_devices = len(local_devices()) shape = (2, nr_devices, 3) @@ -1472,6 +1646,7 @@ def fun(xv, do_print=False): @ignore_jit_of_pmap_warning() def test_tap_jit_pmap_extra(self): """jit of a pmap surrounded by extra code.""" + self.supported_only_in_legacy_mode() # A matrix M[ij] = i * 10 + j nr_devices = len(local_devices()) assert nr_devices in (1, 2) @@ -1540,6 +1715,7 @@ def fun2(cond, xv, do_print=False): @jtu.sample_product(device_index=[0, 1]) def test_tap_pjit(self, device_index=0): + self.supported_only_in_legacy_mode() if (device_index != 0 and not hcb._HOST_CALLBACK_OUTFEED.value and jtu.test_device_matches(["cpu"])): @@ -1560,7 +1736,7 @@ def test_tap_pjit(self, device_index=0): @partial(jax.named_call, name="fun1") # for xprof debugging def fun1(x): z = jnp.dot(x, y) - return hcb.id_print(z, what="z", + return hcb_id_print(z, what="z", output_stream=testing_stream, tap_with_device=True, device_index=device_index) @@ -1589,17 +1765,17 @@ def fun1(x): def test_tap_scan_custom_jvp(self): """custom JVP, inside scan. This exercises the custom_jvp_call_jaxpr primitives.""" - + self.supported_only_in_legacy_mode() @jax.custom_jvp def f(x): - return x * hcb.id_print(x, output_stream=testing_stream, what="x") + return x * hcb_id_print(x, output_stream=testing_stream, what="x") @f.defjvp def f_jvp(primals, tangents): x, = primals x_dot, = tangents primal_out = f(x) - tangent_out = 3. * x * hcb.id_print(x_dot, output_stream=testing_stream, what="x_dot") + tangent_out = 3. * x * hcb_id_print(x_dot, output_stream=testing_stream, what="x_dot") return primal_out, tangent_out def g(x): @@ -1633,10 +1809,10 @@ def g(x): def test_tap_scan_custom_vjp(self): """custom VJP, inside scan. 
This exercises the custom_vjp_call_jaxpr primitives.""" - + self.supported_only_in_legacy_mode() @jax.custom_vjp def f(x): - return x * hcb.id_print(x, output_stream=testing_stream, what="x") + return x * hcb_id_print(x, output_stream=testing_stream, what="x") # f_fwd: a -> (b, residual) def f_fwd(x): @@ -1644,7 +1820,7 @@ def f_fwd(x): # f_bwd: (residual, CT b) -> [CT a] def f_bwd(residual, ct_b): - return residual * hcb.id_print(ct_b, output_stream=testing_stream, what="ct_b"), + return residual * hcb_id_print(ct_b, output_stream=testing_stream, what="ct_b"), f.defvjp(f_fwd, f_bwd) @@ -1682,7 +1858,7 @@ def test_tap_callback_delay(self): def func(x): for i in range(5): - x = hcb.id_print(x * i, what="x times i") + x = hcb_id_print(x * i, what="x times i") return x jax.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) @@ -1692,7 +1868,7 @@ def test_tap_callback_delay_barrier(self): def func(x): for i in range(1, 4): - x = hcb.id_print(x * i, what=f"x times {i}", output_stream=testing_stream) + x = hcb_id_print(x * i, what=f"x times {i}", output_stream=testing_stream) return x jax.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) @@ -1741,12 +1917,12 @@ def test_tap_error_different_shapes(self): comp, token, 123, [xops.Constant(comp, np.zeros((2, 3), dtype=np.float32))], 0) with self.assertRaisesRegex( - RuntimeError, ".*does not match previous shape element_type.*"): + RuntimeError, ".*does not match previous shape .*\n?element_type.*"): hcb._callback_handler_data.receiver.add_outfeed( comp, token, 123, [xops.Constant(comp, np.zeros((2, 3), dtype=np.int32))], 0) with self.assertRaisesRegex( - RuntimeError, ".*does not match previous shape element_type.*"): + RuntimeError, ".*does not match previous shape .*\n?element_type.*"): hcb._callback_handler_data.receiver.add_outfeed( comp, token, 123, [xops.Constant(comp, np.zeros((2,), dtype=np.float32))], 0) @@ -1773,7 +1949,7 @@ def test_tap_odeint(self): from jax.experimental.ode import odeint def f(x, t, k): - x = hcb.id_print(x) + x = hcb_id_print(x, callback_flavor=hcb.CallbackFlavor.DEBUG) return -k * x def loss(k=1.0): @@ -1785,7 +1961,8 @@ def loss(k=1.0): def test_tap_remat_0(self): def f(i, k): - x = hcb.id_print(k + i, output_stream=testing_stream) + x = hcb_id_print(k + i, output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return k * x def loss(k): @@ -1804,10 +1981,11 @@ def loss(k): use_remat=["old", "new", "none"], ) def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"): + self.supported_only_in_legacy_mode() if use_remat == "old": raise SkipTest() def f(x): - id_print_result = hcb.id_print(x, output_stream=testing_stream) + id_print_result = hcb_id_print(x, output_stream=testing_stream) if use_result: x = id_print_result return 3. * x @@ -1867,11 +2045,16 @@ class HostCallbackCallTest(jtu.JaxTestCase): """Tests for hcb.call""" def setUp(self): - super().setUp() + # skipping here skips teardown, so do this before super().setUp(). 
if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") if xla_bridge.using_pjrt_c_api(): raise SkipTest("host_callback not implemented in PJRT C API") + super().setUp() + self.enter_context(jtu.ignore_warning( + category=DeprecationWarning, message="The host_callback APIs are deprecated")) + self.enter_context(jtu.ignore_warning( + category=DeprecationWarning, message="backend and device argument")) testing_stream.reset() testing_stream._test_method_name = self._testMethodName @@ -1880,6 +2063,10 @@ def tearDown(self) -> None: hcb.barrier_wait("HostCallbackCallTest.tearDown") super().tearDown() + def supported_only_in_legacy_mode(self): + if not hcb._HOST_CALLBACK_LEGACY.value: + self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") + def call_log_testing_stream(self, func, arg, *, result_shape, name=""): """Call `func` and log inputs and outputs to the testing stream""" @@ -1916,6 +2103,7 @@ def fun(x): with jtu.count_primitive_compiles() as count: for _ in range(3): self.assertAllClose(2 * arg, fun(arg)) + r = jax.make_jaxpr(fun)(arg) self.assertEqual(count[0], 1) @jtu.sample_product( @@ -2124,6 +2312,7 @@ def fun2(m): helper_print_optimized_hlo(fun2, m) def test_call_with_device(self): + self.supported_only_in_legacy_mode() def callback_func(x, device=None): testing_stream.write(f"device: {device}\n Called with {x}") return x @@ -2139,6 +2328,7 @@ def func(x): Called with 3.00""") def test_call_pmap(self): + self.supported_only_in_legacy_mode() # Works for 1 or 2 devices def callback_func(x, device=None): testing_stream.write(f"device: {device}\n Called with {x}") @@ -2163,10 +2353,14 @@ def test_call_vmap(self): def f_outside(x): return x def fun(x): - return hcb.call(f_outside, x, result_shape=x) + return hcb.call(f_outside, x, result_shape=x, + callback_flavor=hcb.CallbackFlavor.PURE) - with self.assertRaisesRegex(NotImplementedError, - "batching rules are implemented only for id_tap, not for call"): + if hcb._HOST_CALLBACK_LEGACY.value: + with self.assertRaisesRegex(NotImplementedError, + "batching rules are implemented only for id_tap, not for call"): + jax.vmap(fun)(np.ones((2, 3))) + else: jax.vmap(fun)(np.ones((2, 3))) @jtu.sample_product(device_index=[0, 1]) @@ -2256,6 +2450,7 @@ def helper_check_callback_errors(self, thunk: Callable, hcb.barrier_wait("Waiting for error") def test_call_error_callback_throws_exception(self): + self.supported_only_in_legacy_mode() def f_outside(x): raise ValueError("user exception") def fun(x): @@ -2265,6 +2460,7 @@ def fun(x): "ValueError: user exception") def test_call_error_callback_returns_unexpected_shape(self): + self.supported_only_in_legacy_mode() def fun(x): return hcb.call(lambda x: (x, x), x, result_shape=x) @@ -2272,6 +2468,7 @@ def fun(x): "Callback func .* should have returned a result with pytree") def test_call_error_then_compute(self): + self.supported_only_in_legacy_mode() # Continue computation on device after error def f_outside(x): raise ValueError("user exception") @@ -2283,7 +2480,9 @@ def fun(x): "ValueError: user exception") -def call_jax_other_device(jax_outside_fun, arg, *, device): +def call_jax_other_device( + jax_outside_fun, arg, *, device, + callback_flavor: hcb.CallbackFlavor = hcb.CallbackFlavor.IO_CALLBACK): """Calls a JAX function on a specific device with simple support for reverse AD. 
Functions whose name starts with "jax_outside" are called on another device, @@ -2296,7 +2495,8 @@ def run_jax_outside_fun(arg): @jax.custom_vjp def make_call(arg): return hcb.call(run_jax_outside_fun, arg, - result_shape=jax.eval_shape(jax_outside_fun, arg)) + result_shape=jax.eval_shape(jax_outside_fun, arg), + callback_flavor=callback_flavor) # Define the fwd and bwd custom_vjp functions def make_call_vjp_fwd(arg): @@ -2323,6 +2523,8 @@ class CallJaxTest(jtu.JaxTestCase): """Tests using `call_jax_other_device`.""" def setUp(self): + if not hcb._HOST_CALLBACK_LEGACY.value: + self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") if xla_bridge.using_pjrt_c_api(): @@ -2336,6 +2538,9 @@ def setUp(self): raise SkipTest("Test needs at least two devices. On CPU use XLA_FLAGS=--xla_force_host_platform_device_count=2") self.outside_device = jax.devices("cpu")[1] super().setUp() + self.enter_context(jtu.ignore_warning( + category=DeprecationWarning, message="The host_callback APIs are deprecated")) + def test_jax_impl(self): def f_jax(x): @@ -2403,6 +2608,12 @@ def setUp(self): if xla_bridge.using_pjrt_c_api(): raise SkipTest("host_callback not implemented in PJRT C API") super().setUp() + self.enter_context(jtu.ignore_warning( + category=DeprecationWarning, message="The host_callback APIs are deprecated")) + + def supported_only_in_legacy_mode(self): + if not hcb._HOST_CALLBACK_LEGACY.value: + self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") def assertRewrite(self, expected: str, func: Callable, args: Sequence, has_input_token=True, has_output_token=True): @@ -2443,7 +2654,7 @@ def test_simple_outfeed(self): callback=... has_token=True identity=True ] b d e - in (c, f, g) }""", lambda x: hcb.id_print(x + x), [0]) + in (c, f, g) }""", lambda x: hcb_id_print(x + x), [0]) def test_simple_outfeed_without_input_token(self): self.assertRewrite(""" @@ -2455,7 +2666,7 @@ def test_simple_outfeed_without_input_token(self): callback=... has_token=True identity=True ] c e f - in (d,) }""", lambda x1, x2: hcb.id_print(x1 + x2), [1, 2], + in (d,) }""", lambda x1, x2: hcb_id_print(x1 + x2), [1, 2], has_input_token=False, has_output_token=False) def test_simple_outfeed_without_input_token_nor_invars(self): @@ -2467,13 +2678,13 @@ def test_simple_outfeed_without_input_token_nor_invars(self): callback=... has_token=True identity=True ] 42 b c - in (a,) }""", lambda: hcb.id_print(42), [], + in (a,) }""", lambda: hcb_id_print(42), [], has_input_token=False, has_output_token=False) def test_multiple_tap_without_dependencies(self): def f(x): - hcb.id_print(x, what="x") - hcb.id_print(x + 1, what="x + 1") + hcb_id_print(x, what="x") + hcb_id_print(x + 1, what="x + 1") return 2 self.assertRewrite(""" @@ -2494,7 +2705,7 @@ def test_cond(self): def func(x, z): return lax.cond(z > 0, (1, 2), lambda a: (a[0], jnp.zeros(5)), - z, lambda a: (hcb.id_print(a), y)) + z, lambda a: (hcb_id_print(a), y)) self.assertRewrite(""" { lambda a ; b c h i. @@ -2510,8 +2721,7 @@ def func(x, z): { lambda ; f_ a b c g h. 
let d = broadcast_in_dim[ broadcast_dimensions=( ) shape=(5,) ] 0.00 - in (a, d, g, h) } ) - linear=(False, False, False, False, False, False) ] e a 1 2 c h i + in (a, d, g, h) } ) ] e a 1 2 c h i in (f, g, j, k) }""", func, [y, 5]) def test_while(self): @@ -2522,7 +2732,7 @@ def func(x): # x: f32[5] # c: (f32[5], f32) return lax.while_loop(lambda c: c[1] < jnp.sum(c[0] + ct_cond), - lambda c: (ct_body, hcb.id_print(c[1]) + 1.), + lambda c: (ct_body, hcb_id_print(c[1]) + 1.), (x, np.float32(1.))) self.assertRewrite(""" @@ -2550,8 +2760,8 @@ def test_while_pred_outfeed(self): ct_cond = jnp.ones(2) # captured const for the conditional def func(x): - return lax.while_loop(lambda c: hcb.id_print(ct_cond, result=c[1]) < 5, - lambda c: (ct_body, hcb.id_print(c[1]) + 1), + return lax.while_loop(lambda c: hcb_id_print(ct_cond, result=c[1]) < 5, + lambda c: (ct_body, hcb_id_print(c[1]) + 1), (x, 1)) self.assertRewrite(""" @@ -2601,7 +2811,7 @@ def test_scan(self): y = jnp.ones(5) # captured const def func(x): - return lax.scan(lambda c, a: (hcb.id_print(c), y), (1, 2), x) + return lax.scan(lambda c, a: (hcb_id_print(c), y), (1, 2), x) self.assertRewrite(""" { lambda a ; b f g. @@ -2624,17 +2834,17 @@ def func(x): def test_scan_custom_jvp(self): """custom JVP, inside scan. This exercises the custom_jvp_call_jaxpr primitives.""" - + self.supported_only_in_legacy_mode() @jax.custom_jvp def f(x): - return x * hcb.id_print(x) + return x * hcb_id_print(x) @f.defjvp def f_jvp(primals, tangents): x, = primals x_dot, = tangents primal_out = f(x) - tangent_out = 3. * x * hcb.id_print(x_dot) + tangent_out = 3. * x * hcb_id_print(x_dot) return primal_out, tangent_out def g(x): @@ -2706,10 +2916,10 @@ def g(x): def test_scan_custom_vjp(self): """custom VJP, inside scan. This exercises the custom_vjp_call_jaxpr primitives.""" - + self.supported_only_in_legacy_mode() @jax.custom_vjp def f(x): - return x * hcb.id_print(x) + return x * hcb_id_print(x) # f_fwd: a -> (b, residual) def f_fwd(x): @@ -2717,7 +2927,7 @@ def f_fwd(x): # f_bwd: (residual, CT b) -> [CT a] def f_bwd(residual, ct_b): - return residual * hcb.id_print(ct_b), + return residual * hcb_id_print(ct_b), f.defvjp(f_fwd, f_bwd) @@ -2792,7 +3002,7 @@ def g(x): def test_remat_loop(self): def f(k, x): - x = hcb.id_print(k + x) + x = hcb_id_print(k + x) return -k * x def loss(k): @@ -2849,8 +3059,9 @@ def step(acc, step_nr): in (c, d, e) }""", tap_scalar, [np.int32(3)]) def test_pmap(self): + self.supported_only_in_legacy_mode() def f(xv): - jax.pmap(lambda x: jnp.sin(hcb.id_print(x, tap_with_device=True)), + jax.pmap(lambda x: jnp.sin(hcb_id_print(x, tap_with_device=True)), axis_name="i")(xv) self.assertRewrite(""" diff --git a/tests/host_callback_to_tf_test.py b/tests/host_callback_to_tf_test.py index c8858e14084c..fe80c90ace68 100644 --- a/tests/host_callback_to_tf_test.py +++ b/tests/host_callback_to_tf_test.py @@ -18,15 +18,15 @@ This is separate from host_callback_test because it needs a TF dependency. 
""" -from typing import Callable +from collections.abc import Callable import unittest from absl.testing import absltest from absl.testing import parameterized import jax -from jax import config from jax import numpy as jnp +from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge from jax.experimental import host_callback as hcb @@ -53,7 +53,8 @@ def tf_to_numpy(t): return hcb.call(lambda arg: tf.nest.map_structure(tf_to_numpy, tf_fun(arg)), - arg, result_shape=result_shape) + arg, result_shape=result_shape, + callback_flavor=hcb.CallbackFlavor.DEBUG) def call_tf_simple_ad(tf_fun: Callable, arg, *, result_shape): @@ -166,12 +167,17 @@ def setUp(self): raise unittest.SkipTest("host_callback not implemented in PJRT C API") super().setUp() + def supported_only_in_legacy_mode(self): + if not hcb._HOST_CALLBACK_LEGACY.value: + self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") + @parameterized.named_parameters( dict( testcase_name=f"_{ad=}", ad=ad) for ad in CALL_TF_IMPLEMENTATIONS.keys()) def test_impl(self, ad="simple"): + self.supported_only_in_legacy_mode() call_tf = CALL_TF_IMPLEMENTATIONS[ad] def f_jax(x): @@ -192,21 +198,27 @@ def f_outside(x): for ad in CALL_TF_IMPLEMENTATIONS.keys() if ad != "none") def test_grad(self, ad="simple"): + self.supported_only_in_legacy_mode() call_tf = CALL_TF_IMPLEMENTATIONS[ad] def f_jax(x): return 3. * jnp.sin(2. * x) def f_outside(x): - return 3. * call_tf(tf.math.sin, 2. * x, result_shape=x) + return 3. * call_tf( + lambda x: tf.cast(tf.math.sin(x), tf.float32), 2. * x, + result_shape=jax.ShapeDtypeStruct((), np.float32)) - x = 4. - self.assertAllClose(f_jax(x), f_outside(x)) + x = np.float32(4.) + self.assertAllClose(f_jax(x), f_outside(x), + check_dtypes=False) grad_f = jax.grad(f_outside)(x) - self.assertAllClose(jax.grad(f_jax)(x), grad_f) + self.assertAllClose(jax.grad(f_jax)(x), grad_f, + check_dtypes=False) def test_grad_pytree(self): + self.supported_only_in_legacy_mode() call_tf = call_tf_full_ad def f_jax(xy): @@ -215,15 +227,19 @@ def f_jax(xy): def f_outside(xy): dict_ab = call_tf( - lambda xy: dict(a=2. * xy[0], b=xy[0] * xy[1]), + lambda xy: dict(a=tf.cast(2. * xy[0], np.float32), + b=tf.cast(xy[0] * xy[1], np.float32)), xy, - result_shape=dict(a=xy[0], b=xy[1])) + result_shape=dict(a=jax.ShapeDtypeStruct((), np.float32), + b=jax.ShapeDtypeStruct((), np.float32))) return 3. * dict_ab["a"] + 4. * dict_ab["b"] xy = (5., 6.) - self.assertAllClose(f_jax(xy), f_outside(xy)) + self.assertAllClose(f_jax(xy), f_outside(xy), + check_dtypes=False) res_jax = jax.grad(f_jax)(xy) - self.assertAllClose(res_jax, jax.grad(f_outside)(xy)) + self.assertAllClose(res_jax, jax.grad(f_outside)(xy), + check_dtypes=False) @parameterized.named_parameters( dict( @@ -231,6 +247,7 @@ def f_outside(xy): degree=degree) for degree in [1, 2, 3, 4]) def test_higher_order_grad(self, degree=4): + self.supported_only_in_legacy_mode() call_tf = call_tf_full_ad def f_jax(x): diff --git a/tests/image_test.py b/tests/image_test.py index 6204ec91cd5c..f3cd56ed7622 100644 --- a/tests/image_test.py +++ b/tests/image_test.py @@ -24,8 +24,6 @@ from jax import numpy as jnp from jax._src import test_util as jtu -from jax import config - # We use TensorFlow and PIL as reference implementations. 
try: import tensorflow as tf @@ -37,7 +35,7 @@ except ImportError: PIL_Image = None -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() float_dtypes = jtu.dtypes.all_floating inexact_dtypes = jtu.dtypes.inexact diff --git a/tests/infeed_test.py b/tests/infeed_test.py index 37592d52fa49..ba47d2417f94 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -19,7 +19,6 @@ from absl.testing import absltest import jax from jax import lax, numpy as jnp -from jax import config from jax.experimental import host_callback as hcb from jax._src import core from jax._src import xla_bridge @@ -27,7 +26,7 @@ import jax._src.test_util as jtu import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class InfeedTest(jtu.JaxTestCase): @@ -78,7 +77,7 @@ def f(x): @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. def testInfeedThenOutfeed(self): - hcb.stop_outfeed_receiver() + hcb._deprecated_stop_outfeed_receiver() @jax.jit def f(x): @@ -100,7 +99,7 @@ def f(x): self.assertAllClose(out, y + np.float32(1)) def testInfeedThenOutfeedInALoop(self): - hcb.stop_outfeed_receiver() + hcb._deprecated_stop_outfeed_receiver() def doubler(_, token): y, token = lax.infeed( diff --git a/tests/jax_to_ir_test.py b/tests/jax_to_ir_test.py index 05e2ce40b763..ec67ac3dc631 100644 --- a/tests/jax_to_ir_test.py +++ b/tests/jax_to_ir_test.py @@ -119,10 +119,12 @@ def test_parse_shape_str(self): self.assertParsedShape('f32[]', [], jnp.float32) self.assertParsedShape('f32[1,2,3]', [1, 2, 3], jnp.float32) self.assertParsedShape('pred[1]', [1], jnp.bool_) + self.assertParsedShape('s4[1]', [1], jnp.int4) self.assertParsedShape('s8[1]', [1], jnp.int8) self.assertParsedShape('s16[1]', [1], jnp.int16) self.assertParsedShape('s32[1]', [1], jnp.int32) self.assertParsedShape('s64[1]', [1], jnp.int64) + self.assertParsedShape('u4[1]', [1], jnp.uint4) self.assertParsedShape('u8[1]', [1], jnp.uint8) self.assertParsedShape('u16[1]', [1], jnp.uint16) self.assertParsedShape('u32[1]', [1], jnp.uint32) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 255822a26109..88fd7a334048 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import functools +import contextlib import threading import unittest @@ -19,7 +19,6 @@ import jax import jax.numpy as jnp from jax import lax -from jax.experimental import maps from jax.experimental import pjit from jax._src import ad_checkpoint from jax._src import dispatch @@ -32,6 +31,7 @@ from jax._src.interpreters import ad from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe +from jax._src.maps import xmap import numpy as np config.parse_flags_with_absl() @@ -124,7 +124,7 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out out_op, token_out, _ = mlir.emit_python_callback( ctx, callback, token_in, list(args), list(ctx.avals_in), - list(ctx.avals_out), True) + list(ctx.avals_out), has_side_effect=True) if token_out: ctx.set_tokens_out(ctx.tokens_in.update_tokens(mlir.TokenSet({effect: token_out}))) @@ -133,18 +133,13 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out mlir.register_lowering(callback_p, callback_effect_lowering) -prev_xla_flags = None - +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(2) - + _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - prev_xla_flags() + _exit_stack.close() class JaxprEffectsTest(jtu.JaxTestCase): @@ -275,7 +270,7 @@ def f(x): effect_p.bind(effect=foo_effect) effect_p.bind(effect=bar_effect) return x - f = maps.xmap(f, in_axes=['a'], out_axes=['a']) + f = xmap(f, in_axes=['a'], out_axes=['a']) jaxpr = jax.make_jaxpr(f)(jnp.arange(jax.local_device_count())) self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect}) @@ -373,55 +368,6 @@ def f(x): 'incorrect set of output token.'): f.lower(2.) - def test_lowering_ordered_effect_should_create_tokens(self): - - def effect_lowering(ctx, *, effect): - ctx.set_tokens_out(ctx.tokens_in) - return [] - mlir.register_lowering(effect_p, effect_lowering) - - @jax.jit - def f(x): - effect_p.bind(effect=foo_effect) - return x + 1. - module = f.lower(2.).compiler_ir() - main = module.body.operations[0] - first_op = main.body.blocks[0].operations[0] - self.assertIn('hlo.create_token', first_op.operation.name) - - @jax.jit - def f(x): - effect_p.bind(effect=foo_effect) - effect_p.bind(effect=foo2_effect) - return x + 1. - module = f.lower(2.).compiler_ir() - main = module.body.operations[0] - first_op = main.body.blocks[0].operations[0] - self.assertIn('hlo.create_token', first_op.operation.name) - second_op = main.body.blocks[0].operations[1] - self.assertIn('hlo.create_token', second_op.operation.name) - - @jax.jit - def f(x): - effect_p.bind(effect=foo_effect) - return x + 1. - module = f.lower(2.).compiler_ir() - main = module.body.operations[0] - first_op = main.body.blocks[0].operations[0] - self.assertIn('hlo.create_token', first_op.operation.name) - - @jax.jit - def f(x): - effect_p.bind(effect=foo_effect) - effect_p.bind(effect=foo2_effect) - return x + 1. 
- module = f.lower(2.).compiler_ir() - main = module.body.operations[0] - first_op = main.body.blocks[0].operations[0] - self.assertIn('hlo.create_token', first_op.operation.name) - second_op = main.body.blocks[0].operations[1] - self.assertIn('hlo.create_token', second_op.operation.name) - def test_nontrivial_lowering_with_ordered_effect_should_consume_token(self): mlir.register_lowering(effect_p, function_effect_lowering) @@ -432,12 +378,11 @@ def f(x): return x + 1. module = f.lower(2.).compiler_ir() main = module.body.operations[0] - first_op = main.body.blocks[0].operations[0] - self.assertIn('hlo.create_token', first_op.operation.name) - second_op = main.body.blocks[0].operations[1] - self.assertEqual(second_op.operation.name, "func.call") - self.assertEqual(str(second_op.attributes["callee"]), "@effect") - self.assertEqual(second_op.operands[0].owner, first_op) + call_op = main.body.blocks[0].operations[0] + + self.assertEqual(call_op.operation.name, 'func.call') + self.assertEqual(str(call_op.attributes['callee']), '@effect') + func = module.body.operations[1] self.assertEqual(func.name.value, "effect") self.assertIn('hlo.token', str(func.type.inputs[0])) @@ -477,23 +422,24 @@ def f(x): self.assertLen(list(result_types), 1) self.assertEqual(str(result_types[0]), 'tensor') - def test_lowered_jaxpr_with_ordered_effects_takes_in_dummy_inputs(self): + def test_lowered_jaxpr_with_ordered_effects_takes_token_inputs(self): @jax.jit def f(x): effect_p.bind(effect=foo_effect) return x + 1. module = f.lower(1.).compiler_ir() input_types = module.body.operations[0].type.inputs - # First argument should be dummy token + token_type = '!stablehlo.token' + # First argument should be a token self.assertLen(list(input_types), 2) - self.assertEqual(str(input_types[0]), 'tensor<0xi1>') + self.assertEqual(str(input_types[0]), token_type) - # First output should be dummy token + # First output should be a token result_types = module.body.operations[0].type.results self.assertLen(list(result_types), 2) - self.assertEqual(str(result_types[0]), 'tensor<0xi1>') + self.assertEqual(str(result_types[0]), token_type) - def test_lowered_jaxpr_with_multiple_ordered_effects_takes_in_dummy_inputs(self): + def test_lowered_jaxpr_with_multiple_ordered_effects_takes_in_tokens(self): @jax.jit def f(x): effect_p.bind(effect=foo_effect) @@ -501,16 +447,17 @@ def f(x): return x + 1. 
module = f.lower(1.).compiler_ir() input_types = module.body.operations[0].type.inputs - # First two arguments should be dummy values + token_type = '!stablehlo.token' + # First two arguments should be token values self.assertLen(list(input_types), 3) - self.assertEqual(str(input_types[0]), 'tensor<0xi1>') - self.assertEqual(str(input_types[1]), 'tensor<0xi1>') + self.assertEqual(str(input_types[0]), token_type) + self.assertEqual(str(input_types[1]), token_type) - # First two outputs should be dummy values + # First two outputs should be token values result_types = module.body.operations[0].type.results self.assertLen(list(result_types), 3) - self.assertEqual(str(result_types[0]), 'tensor<0xi1>') - self.assertEqual(str(result_types[1]), 'tensor<0xi1>') + self.assertEqual(str(result_types[0]), token_type) + self.assertEqual(str(result_types[1]), token_type) def test_can_lower_and_run_jaxpr_with_ordered_effects(self): @jax.jit @@ -614,23 +561,22 @@ def log_value(x): log.append(x) return () - @functools.partial(jax.jit, device=jax.devices()[0]) + @jax.jit def f(x): # Expensive computation x = x.dot(x) x = jnp.log(x.sum()) return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) - @functools.partial(jax.jit, device=jax.devices()[1]) + @jax.jit def g(x): return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) - f(jnp.ones((500, 500))) - g(3.) - f(jnp.ones((500, 500))) - g(3.) - f(jnp.ones((500, 500))) - g(3.) + x = jax.device_put(jnp.ones((500, 500)), jax.devices()[0]) + y = jax.device_put(3., jax.devices()[1]) + for _ in range(3): + f(x) + g(y) jax.effects_barrier() f_, g_ = float(jnp.log(1.25e8)), 3. expected_log = [f_, g_, f_, g_, f_, g_] diff --git a/tests/jaxpr_util_test.py b/tests/jaxpr_util_test.py index a8df5a325b3d..4597ce6bd7d5 100644 --- a/tests/jaxpr_util_test.py +++ b/tests/jaxpr_util_test.py @@ -61,8 +61,8 @@ def f(x, y): def test_primitives_by_shape(self): def f(x, y): def sub(x, y): - return jnp.sum(jnp.array([x, y])), y - s, _ = jit(sub)(x, y) + return jnp.sum(jnp.array([x, y])) + s = jit(sub)(x, y) return jnp.sin(s) + jnp.cos(y) hist = jaxpr_util.primitives_by_shape(make_jaxpr(f)(1., 1.).jaxpr) @@ -74,7 +74,7 @@ def sub(x, y): f'cos :: float{t}[]', f'reduce_sum :: float{t}[]', f'concatenate :: float{t}[2]', - f'pjit :: float{t}[] *', + f'pjit :: float{t}[]', ] for k in shapes: self.assertEqual(hist[k], 1) diff --git a/tests/jet_test.py b/tests/jet_test.py index c72057246ebc..b1e2ef3f8380 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -29,8 +29,7 @@ from jax.experimental.jet import jet, fact, zero_series from jax import lax -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def jvp_taylor(fun, primals, series): # Computes the Taylor series the slow way, with nested jvp. 
@@ -95,6 +94,8 @@ def _convert(x): check_dtypes=check_dtypes) @jtu.skip_on_devices("tpu") + # Default tolerance too tight on A100 after openxla/xla@a58070090 + @jax.default_matmul_precision("float32") def test_dot(self): M, K, N = 2, 3, 4 order = 3 @@ -242,6 +243,8 @@ def test_floor(self): self.unary_check(jnp.floor) @jtu.skip_on_devices("tpu") def test_ceil(self): self.unary_check(jnp.ceil) @jtu.skip_on_devices("tpu") + def test_trunc(self): self.unary_check(jnp.trunc) + @jtu.skip_on_devices("tpu") def test_round(self): self.unary_check(lax.round) @jtu.skip_on_devices("tpu") def test_sign(self): self.unary_check(lax.sign) diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index d164290ec48c..286088eebe48 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -21,13 +21,15 @@ from jax import core import jax.numpy as jnp from jax._src import prng +from jax._src import random from jax._src import test_util as jtu +from jax.errors import KeyReuseError from jax.experimental.key_reuse._core import ( - assert_consumed, assert_unconsumed, consume, consume_p) -from jax.experimental.key_reuse import _core, KeyReuseError + assert_consumed, assert_unconsumed, consume, consume_p, + Source, Sink, Forward, KeyReuseSignature) +from jax.experimental.key_reuse import _core -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() key = jax.eval_shape(jax.random.key, 0) @@ -36,7 +38,7 @@ primitives_with_static_signatures = { consume_p: (consume, key), - prng.reuse_key_p: (prng.reuse_key, key), + random.random_clone_p: (random.clone, key), prng.random_bits_p: (jax.random.bits, key), # prng.random_fold_in_p: (jax.random.fold_in, key, 2), prng.random_seed_p: (jax.random.key, 0), @@ -64,7 +66,7 @@ def apply_unknown_primitive(key): @jtu.with_config( jax_enable_custom_prng=False, - jax_enable_key_reuse_checks=False) + jax_debug_key_reuse=False) class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase): def check_key_reuse(self, *args): return _core.check_key_reuse(*args) @@ -91,12 +93,12 @@ def f(key): assert_consumed(key2) self.check_key_reuse(f, jax.random.key(0)) - def test_reuse_key(self): + def test_random_clone(self): def f(key): assert_unconsumed(key) consume(key) assert_consumed(key) - key2 = prng.reuse_key(key) + key2 = jax.random.clone(key) assert_unconsumed(key2) self.check_key_reuse(f, jax.random.key(0)) @@ -206,6 +208,18 @@ def f(key): assert_consumed(key2) self.check_key_reuse(f, jax.random.key(0)) + def test_concatenate(self): + def f(key1, key2): + assert_unconsumed(key1) + assert_unconsumed(key2) + keys = jax.lax.concatenate([key1, key2], dimension=0) + assert_consumed(key1) + assert_consumed(key2) + assert_unconsumed(keys) + key1 = jax.random.split(jax.random.key(0)) + key2 = jax.random.split(jax.random.key(1)) + self.check_key_reuse(f, key1, key2) + def test_slice(self): def f(keys): assert_unconsumed(keys) @@ -329,10 +343,16 @@ def test_jaxpr_type_signature(self, primitive): func, *args = primitives_with_static_signatures[primitive] signature = _core.key_reuse_signatures[primitive] jaxpr = jax.make_jaxpr(func)(*args) - self.assertEqual(signature, _core.get_jaxpr_type_signature(jaxpr.jaxpr)) + self.assertEqual(signature, _core.jaxpr_type_signature(jaxpr.jaxpr)) + + @parameterized.parameters(*primitives_with_static_signatures) + def test_function_type_signature(self, primitive): + func, *args = primitives_with_static_signatures[primitive] + signature = _core.key_reuse_signatures[primitive] + self.assertEqual(signature, 
_core.function_type_signature(func, *args)) -@jtu.with_config(jax_enable_key_reuse_checks=False) +@jtu.with_config(jax_debug_key_reuse=False) class KeyReuseIntegrationTest(jtu.JaxTestCase): random_bits_error = "In random_bits, argument [0-9]+ is already consumed.*" random_split_error = "In random_split, argument [0-9]+ is already consumed.*" @@ -586,25 +606,43 @@ def f_good(x, key): self.check_key_reuse(jax.grad(f_good), x, key) -@jtu.with_config(jax_enable_key_reuse_checks=True) -class KeyReuseEager(jtu.JaxTestCase): +@jtu.with_config(jax_debug_key_reuse=True) +class KeyReuseEagerTest(jtu.JaxTestCase): jit_msg = "Previously-consumed key passed to jit-compiled function at index 0" eager_bits_msg = "Previously-consumed key passed to random_bits at index 0" traced_bits_msg = "In random_bits, argument 0 is already consumed." + def test_clone_eager(self): + key = jax.random.key(0) + key2 = jax.random.clone(key) + self.assertIsNot(key, key2) + + _ = jax.random.uniform(key) + self.assertTrue(key._consumed) + self.assertFalse(key2._consumed) + def test_simple_reuse_nojit(self): key = jax.random.key(0) - _ = jax.random.bits(key) with jax.disable_jit(): + _ = jax.random.bits(key) with self.assertRaisesRegex(KeyReuseError, self.eager_bits_msg): _ = jax.random.bits(key) def test_simple_key_reuse_jit(self): key = jax.random.key(0) - _ = jax.random.bits(key) + _ = jax.jit(jax.random.bits)(key) with self.assertRaisesRegex(KeyReuseError, self.jit_msg): _ = jax.jit(jax.random.bits)(key) + def test_closed_over_key_reuse_jit(self): + key = jax.random.key(0) + @jax.jit + def f(): + return jax.random.uniform(key) + _ = f() + with self.assertRaisesRegex(KeyReuseError, self.jit_msg): + _ = f() + def test_key_reuse_within_jit(self): @jax.jit def f(): @@ -614,9 +652,85 @@ def f(): f() +class KeyReuseImplementationTest(jtu.JaxTestCase): + + def assertEquivalent(self, a, b): + self.assertEqual(a, b) + self.assertEqual(hash(a), hash(b)) + + def assertNotEquivalent(self, a, b): + self.assertNotEqual(a, b) + self.assertNotEqual(hash(a), hash(b)) + + def test_source_sink_immutability(self): + mask = np.array([True, False]) + orig_mask_writeable = mask.flags.writeable + + sink = Sink(0, mask) + source = Source(0, mask) + + self.assertFalse(sink.mask.flags.writeable) + self.assertFalse(source.mask.flags.writeable) + self.assertEqual(mask.flags.writeable, orig_mask_writeable) + + with self.assertRaises(ValueError): + sink.idx = 1 + with self.assertRaises(ValueError): + sink.mask = True + with self.assertRaises(ValueError): + source.idx = 1 + with self.assertRaises(ValueError): + source.mask = True + + def test_source_sink_forward_equivalence_semantics(self): + + true_mask = np.array([True, True]) + false_mask = np.array([False, False]) + mixed_mask = np.array([True, False]) + + self.assertEquivalent(Source(0), Source(0, True)) + self.assertEquivalent(Source(0, True), Source(0, true_mask)) + self.assertEquivalent(Source(0, False), Source(0, false_mask)) + self.assertEquivalent(Source(0, mixed_mask), Source(0, mixed_mask)) + self.assertNotEquivalent(Source(0), Source(1)) + self.assertNotEquivalent(Source(0), Source(0, False)) + self.assertNotEquivalent(Source(0), Source(0, mixed_mask)) + + self.assertEquivalent(Sink(0), Sink(0, True)) + self.assertEquivalent(Sink(0, True), Sink(0, true_mask)) + self.assertEquivalent(Sink(0, False), Sink(0, false_mask)) + self.assertEquivalent(Sink(0, mixed_mask), Sink(0, mixed_mask)) + self.assertNotEquivalent(Sink(0), Sink(1)) + self.assertNotEquivalent(Sink(0), Sink(0, False)) + 
self.assertNotEquivalent(Sink(0), Sink(0, mixed_mask)) + + self.assertNotEquivalent(Source(0), Sink(0)) + + self.assertEquivalent(Forward(0, 1), Forward(0, 1)) + self.assertNotEquivalent(Forward(0, 1), Forward(1, 0)) + + def test_signature_equality_semantics(self): + self.assertEquivalent( + KeyReuseSignature(Sink(0), Source(1), Forward(1, 0)), + KeyReuseSignature(Forward(1, 0), Source(1), Sink(0))) + self.assertEquivalent( + KeyReuseSignature(), KeyReuseSignature()) + self.assertNotEquivalent( + KeyReuseSignature(Source(0)), KeyReuseSignature(Sink(0))) + + def test_reprs(self): + self.assertEqual(repr(Sink(0)), "Sink(0)") + self.assertEqual(repr(Source(0)), "Source(0)") + self.assertEqual(repr(Forward(0, 1)), "Forward(0, 1)") + self.assertEqual(repr(KeyReuseSignature(Sink(1), Source(0))), + "KeyReuseSignature(Sink(1), Source(0))") + self.assertEqual(repr(KeyReuseSignature(Sink(1), Sink(0))), + "KeyReuseSignature(Sink(0), Sink(1))") + + @jtu.with_config(jax_enable_checks=False) -class KeyReuseGlobalFlags(jtu.JaxTestCase): +class KeyReuseGlobalFlagsTest(jtu.JaxTestCase): def test_key_reuse_flag(self): @jax.jit @@ -629,14 +743,14 @@ def f_good(key): key = jax.random.key(0) - with jax.enable_key_reuse_checks(False): + with jax.debug_key_reuse(False): f_good(key) f_bad(key) # No failure f_bad.clear_cache() f_good.clear_cache() - with jax.enable_key_reuse_checks(True): + with jax.debug_key_reuse(True): f_good(key) with self.assertRaisesRegex(KeyReuseError, "In random_bits.*"): f_bad(key) diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index 630b08cc3694..ab3a183177f6 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -31,8 +31,7 @@ from jax._src.util import NumpyComplexWarning from jax.test_util import check_grads -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() compatible_shapes = [[(3,)], diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index f5218aff1a45..3e4d1fbf3e18 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -35,15 +35,15 @@ from jax._src import test_util as jtu from jax import tree_util from jax._src.util import unzip2 -from jax.experimental import maps from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp from jax._src.lax import control_flow as lax_control_flow from jax._src.lax.control_flow import for_loop +from jax._src.interpreters import mlir +from jax._src.maps import xmap -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # Some tests are useful for testing both lax.cond and lax.switch. This function @@ -96,6 +96,7 @@ def scan_with_remat_for(f, *args, **kwargs): SCAN_IMPLS_WITH_FOR = [ (lax.scan, 'unroll1'), (partial(lax.scan, unroll=2), 'unroll2'), + (partial(lax.scan, _split_transpose=True), 'split_transpose'), (scan_with_new_checkpoint , 'new_checkpoint'), (scan_with_new_checkpoint2, 'new_checkpoint2'), (scan_with_for, 'for_loop'), @@ -130,6 +131,15 @@ def scan_reference(f, init, xs): ignore_jit_of_pmap_warning = partial( jtu.ignore_warning, message=".*jit-of-pmap.*") +# A JAX primitive whose lowering is a custom call to a non-existent function. 
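+# It is used below by test_platform_dependent_with_non_existent_custom_call to
+# check that lax.platform_dependent branches targeting other platforms are
+# lowered into the HLO but never executed on the current platform.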
+prim_non_existent_custom_call = core.Primitive("__testing_non_existent_custom_call") +prim_non_existent_custom_call.def_abstract_eval(lambda x_aval: x_aval) +mlir.register_lowering( + prim_non_existent_custom_call, + lambda ctx, x: mlir.hlo.CustomCallOp( + [x.type], [x], + call_target_name=mlir.ir.StringAttr.get("__testing_non_existent_custom_call")).results) + class LaxControlFlowTest(jtu.JaxTestCase): @@ -1732,8 +1742,8 @@ def f(c, a): rtol = {np.float32: 2e-5, np.float64: 1e-13} atol = {np.float32: 6e-2, np.float64: 1e-13} else: - rtol = {np.float32: 2e-5, np.float64: 1e-13} - atol = {np.float32: 5e-5, np.float64: 1e-13} + rtol = {np.float32: 2e-4, np.float64: 1e-13} + atol = {np.float32: 8e-5, np.float64: 1e-13} if jit_f: f = jax.jit(f) @@ -1951,6 +1961,13 @@ def testScanBodyCarryTypeMismatchErrors(self): x[2]), None), (jnp.array(0, 'int32'),) * 3, None, length=1) + @jax.enable_checks(False) + def testScanInvalidUnrollRaises(self): + with self.assertRaisesRegex(ValueError, "`unroll` must be"): + jax.lax.scan(lambda x, _: (x, x), 0, jnp.arange(5), unroll=-1) + with self.assertRaisesRegex(ValueError, "`unroll` must be"): + jax.lax.scan(lambda x, _: (x, x), 0, jnp.arange(5), unroll=0) + @parameterized.named_parameters( {"testcase_name": f"_{scan_name}", "scan": scan_impl} @@ -2315,7 +2332,7 @@ def step(x, i): A = jnp.zeros((3, 3)) # The second DUS was unnecessarily replicating A across time. # We check XLA because _scan_impl is "underneath" the jaxpr language. - s = str(jax.xla_computation(jax.grad(loss))(A).as_hlo_text()) + s = jax.jit(jax.grad(loss)).lower(A).as_text('hlo') assert s.count("dynamic-update-slice(") < 2 def testScanLengthArg(self): @@ -2410,8 +2427,16 @@ def f(c, a): # but HLO should grow due to unrolling self.assertLess( - len(str(jax.xla_computation(scan)(c, xs).as_hlo_text())), - len(str(jax.xla_computation(scan_unrolled)(c, xs).as_hlo_text()))) + len(str(jax.jit(scan).lower(c, xs).as_text('hlo'))), + len(str(jax.jit(scan_unrolled).lower(c, xs).as_text('hlo')))) + + def test_scan_xs_none(self): + def f(h, _): + return h + 1, None + + length = 20 + h, _ = lax.scan(f, 0, length=length) + self.assertEqual(h, length) def test_disable_jit_cond_with_vmap(self): # https://github.com/google/jax/issues/3093 @@ -2503,7 +2528,7 @@ def f(c, a): scan_fun = lambda c, xs: lax.scan(f, c, xs) def new_jaxpr(): - jaxpr = jax.make_jaxpr(scan_fun)(c, xs).jaxpr + jaxpr = jax.make_jaxpr(partial(scan_fun))(c, xs).jaxpr scan = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'scan') return jaxpr, scan @@ -2537,22 +2562,6 @@ def new_jaxpr(): 'tuple of ClosedJaxpr required: (4, 2)'), lambda: core.check_jaxpr(jaxpr)) - jaxpr, eqn = new_jaxpr() - eqn.params['linear'] = (4, 2) - self.assertRaisesRegex( - core.JaxprTypeError, - re.escape('invalid cond param linear of type tuple, ' - 'tuple of bool required: (4, 2)'), - lambda: core.check_jaxpr(jaxpr)) - - jaxpr, eqn = new_jaxpr() - eqn.params['linear'] = 'multi\nline' - self.assertRaisesRegex( - core.JaxprTypeError, - r'invalid cond param linear of type str, ' - r'tuple of bool required:\r?\nmulti\r?\nline', - lambda: core.check_jaxpr(jaxpr)) - def test_cond_transformation_rule_with_consts(self): # https://github.com/google/jax/pull/9731 @@ -2712,7 +2721,7 @@ def body(carry): i, x = carry return i + 1, x + lax.psum(y, 'b') return lax.while_loop(cond, body, (0, z))[1] - maps.xmap(f, axis_sizes=dict(a=2, b=10), out_axes=(['a']), in_axes={})(1.) + xmap(f, axis_sizes=dict(a=2, b=10), out_axes=(['a']), in_axes={})(1.) 
def test_while_loop_fixed_point_with_batched_pred_and_consts(self): def f(i, x): @@ -2829,6 +2838,44 @@ def f(x): self.assertNotIn(" sine", hlo) self.assertIn(" cosine", hlo) + def test_platform_dependent_with_non_existent_custom_call(self): + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Only for CPU") + + def f(x): + # One use with the bad custom call on a different platform branch + x1 = lax.platform_dependent(x, + cpu=jnp.sin, + other=prim_non_existent_custom_call.bind) + # and with the bad custom call in the default branch + x2 = lax.platform_dependent(x, + cpu=jnp.sin, + default=prim_non_existent_custom_call.bind) + # and one use where the current platform is the default + x3 = lax.platform_dependent(x, + other=prim_non_existent_custom_call.bind, + default=jnp.sin) + return x1 + x2 + x3 + + x = np.arange(3, dtype=np.float32) + hlo = str(jax.jit(f).lower(x).compiler_ir()) + occurrences = re.findall(prim_non_existent_custom_call.name, hlo) + self.assertLen(occurrences, 3) + + res_eager = f(x) + self.assertAllClose(res_eager, 3. * np.sin(x)) + res_jit = jax.jit(f)(x) + self.assertAllClose(res_jit, 3 * np.sin(x)) + + res_vmap = jax.vmap(f)(x) + self.assertAllClose(res_vmap, 3. * np.sin(x)) + + _, res_jvp = jax.jvp(f, (x,), (np.full(x.shape, .1, dtype=x.dtype),)) + self.assertAllClose(res_jvp, .3 * np.cos(x)) + + res_grad = jax.grad(f)(1.) + self.assertAllClose(res_grad, 3. * np.cos(1.)) + def test_platform_dependent_multiple_identical_branches(self): x = np.arange(3, dtype=np.float32) def f(x): @@ -2898,6 +2945,46 @@ def f(x): self.assertEqual(expect_a_dot, " dot(" in hlo) self.assertEqual(not expect_a_dot, " while(" in hlo) + def test_scan_lowering_doesnt_introduce_singleton(self): + b = 4 + i = 2 + + def scan(y): + def body(carry, x): + return carry, jnp.dot(x, x) + return jax.lax.scan(body, 1.0, y, unroll=False) + + fn = jax.jit(scan) + + init = np.array(np.arange(b * i * i), dtype=np.float32).reshape((b, i, i)) + hlo_text = fn.lower(init).as_text('hlo') + self.assertNotIn('4,1,2,2', hlo_text) + + def test_cond_vmap_forwarding_doesnt_promote(self): + def f(x, y): + x, y = jax.lax.cond( + x < 3, + lambda x, y: (x * 2, y), + lambda x, y: (x * 3, y), + x, y + ) + return x, y + + x = jnp.arange(3) + y = jnp.array(3.) + + x2, y2 = jax.vmap(f, in_axes=(0, None), out_axes=(0, None))(x, y) # don't crash + + assert x is not x2 + assert y is y2 + + def test_cond_casting(self): + x = 1.0 + identity = lambda x: x + + y = lax.cond(True, identity, identity, x) + self.assertEqual(y, x) + self.assertIsInstance(y, jax.Array) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py new file mode 100644 index 000000000000..dab26d86c0a2 --- /dev/null +++ b/tests/lax_metal_test.py @@ -0,0 +1,5773 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
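+
+# Tests exercising the jax.numpy API on the jax-metal (METAL) plugin backend.
+# These tests are skipped when the jax-metal plugin is not installed.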
+ +from __future__ import annotations + +from array import array as make_python_array +import collections +import copy +from functools import partial +import io +import itertools +import math +import platform +from typing import Union, cast +import unittest +from unittest import SkipTest + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np +try: + import numpy_dispatch +except ImportError: + numpy_dispatch = None + +import jax +import jax.ops +from jax import lax +from jax import numpy as jnp +from jax.sharding import SingleDeviceSharding + +from jax._src import array +from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src import test_util as jtu +from jax._src.lax import lax as lax_internal + +from jax._src.util import safe_zip, NumpyComplexWarning + +try: + from jax_plugins import metal_plugin +except ImportError: + metal_plugin = None + +config.parse_flags_with_absl() + +numpy_version = jtu.numpy_version() + +nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] +nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes +one_dim_array_shapes = [(1,), (6,), (12,)] +empty_array_shapes = [(0,), (0, 4), (3, 0),] +broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)] + +scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE] +array_shapes = nonempty_array_shapes + empty_array_shapes +nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes +nonempty_shapes = scalar_shapes + nonempty_array_shapes +all_shapes = scalar_shapes + array_shapes + +float_dtypes = jtu.dtypes.all_floating +complex_dtypes = jtu.dtypes.complex +int_dtypes = jtu.dtypes.all_integer +unsigned_dtypes = jtu.dtypes.all_unsigned +bool_dtypes = jtu.dtypes.boolean +default_dtypes = float_dtypes + int_dtypes +inexact_dtypes = float_dtypes + complex_dtypes +number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes +all_dtypes = number_dtypes + bool_dtypes + +NO_VALUE = object() + +python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_] + +# uint64 is problematic because with any uint type it promotes to float: +int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64] + +def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False, + axis=None, **kwds): + # Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0 + result = np.unique(ar, return_index=return_index, return_inverse=return_inverse, + return_counts=return_counts, axis=axis, **kwds) + if jtu.numpy_version() >= (2, 0, 0) or np.ndim(ar) == 1 or not return_inverse: + return result + + idx = 2 if return_index else 1 + inverse_indices = result[idx] + if axis is None: + inverse_indices = inverse_indices.reshape(np.shape(ar)) + else: + inverse_indices = np.expand_dims(inverse_indices, [i for i in range(np.ndim(ar)) if i != axis]) + return (*result[:idx], inverse_indices, *result[idx + 1:]) + + +def _indexer_with_default_outputs(indexer, use_defaults=True): + """Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs""" + class Indexer: + @partial(jtu.with_jax_dtype_defaults, use_defaults=use_defaults) + def __getitem__(self, *args): + return indexer.__getitem__(*args) + return Indexer() + +def _valid_dtypes_for_shape(shape, dtypes): + # Not all (shape, dtype) pairs are valid. In particular, Python scalars only + # have one type in each category (float, bool, etc.) 
+ if shape is jtu.PYTHON_SCALAR_SHAPE: + return [t for t in dtypes if t in python_scalar_dtypes] + return dtypes + +def _shape_and_dtypes(shapes, dtypes): + for shape in shapes: + for dtype in _valid_dtypes_for_shape(shape, dtypes): + yield (shape, dtype) + +def _compatible_shapes(shape): + if np.ndim(shape) == 0 or shape in scalar_shapes: + return [shape] + return (shape[n:] for n in range(len(shape) + 1)) + +OpRecord = collections.namedtuple( + "OpRecord", + ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes", + "test_name", "check_dtypes", "tolerance", "inexact", "kwargs"]) + +def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, + test_name=None, check_dtypes=True, + tolerance=None, inexact=False, kwargs=None): + test_name = test_name or name + return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes, + test_name, check_dtypes, tolerance, inexact, kwargs) + + +JAX_ARGMINMAX_RECORDS = [ + op_record("argmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []), + op_record("argmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []), + op_record("nanargmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []), + op_record("nanargmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []), +] + +def _shapes_are_broadcast_compatible(shapes): + try: + lax.broadcast_shapes(*(() if s in scalar_shapes else s for s in shapes)) + except ValueError: + return False + else: + return True + +def _shapes_are_equal_length(shapes): + return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) + +@unittest.skipIf(metal_plugin == None, "Tests require jax-metal plugin.") +class LaxBackedNumpyTests(jtu.JaxTestCase): + """Tests for LAX-backed Numpy implementation.""" + + def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): + def f(): + out = [rng(shape, dtype or jnp.float_) + for shape, dtype in zip(shapes, dtypes)] + if np_arrays: + return out + return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a + for a in out] + return f + + @parameterized.parameters( + [dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32, + jnp.int8, jnp.int16, jnp.int32, jnp.int64, + jnp.float16, jnp.float32] + if dtype == dtypes.canonicalize_dtype(dtype)]) + def testDtypeWrappers(self, dtype): + arr = dtype(0) + self.assertIsInstance(arr, jax.Array) + self.assertEqual(arr.dtype, np.dtype(dtype)) + self.assertArraysEqual(arr, 0, check_dtypes=False) + + # No copy primitive is generated + jaxpr = jax.make_jaxpr(dtype)(0) + prims = [eqn.primitive for eqn in jaxpr.eqns] + self.assertEqual(prims, [lax.convert_element_type_p]) # No copy generated. + + def testBoolDtypeAlias(self): + self.assertIs(jnp.bool, jnp.bool_) + + @jtu.sample_product( + dtype=float_dtypes + [object], + allow_pickle=[True, False], + ) + def testLoad(self, dtype, allow_pickle): + if dtype == object and not allow_pickle: + self.skipTest("dtype=object requires allow_pickle=True") + rng = jtu.rand_default(self.rng()) + arr = rng((10), dtype) + with io.BytesIO() as f: + jnp.save(f, arr) + f.seek(0) + arr_out = jnp.load(f, allow_pickle=allow_pickle) + self.assertArraysEqual(arr, arr_out, allow_object_dtype=True) + + @unittest.skip("Jax-metal fail.") + def testArrayEqualExamples(self): + # examples from the array_equal() docstring. 
+ self.assertTrue(jnp.array_equal([1, 2], [1, 2])) + self.assertTrue(jnp.array_equal(np.array([1, 2]), np.array([1, 2]))) + self.assertFalse(jnp.array_equal([1, 2], [1, 2, 3])) + self.assertFalse(jnp.array_equal([1, 2], [1, 4])) + + a = np.array([1, np.nan]) + self.assertFalse(jnp.array_equal(a, a)) + self.assertTrue(jnp.array_equal(a, a, equal_nan=True)) + + a = np.array([1 + 1j]) + b = a.copy() + a.real = np.nan + b.imag = np.nan + self.assertTrue(jnp.array_equal(a, b, equal_nan=True)) + + def testArrayEquivExamples(self): + # examples from the array_equiv() docstring. + self.assertTrue(jnp.array_equiv([1, 2], [1, 2])) + self.assertFalse(jnp.array_equiv([1, 2], [1, 3])) + with jax.numpy_rank_promotion('allow'): + self.assertTrue(jnp.array_equiv([1, 2], [[1, 2], [1, 2]])) + self.assertFalse(jnp.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]])) + self.assertFalse(jnp.array_equiv([1, 2], [[1, 2], [1, 3]])) + + def testArrayModule(self): + if numpy_dispatch is None: + raise SkipTest('requires https://github.com/seberg/numpy-dispatch') + + jnp_array = jnp.array(1.0) + np_array = np.array(1.0) + + module = numpy_dispatch.get_array_module(jnp_array) + self.assertIs(module, jnp) + + module = numpy_dispatch.get_array_module(jnp_array, np_array) + self.assertIs(module, jnp) + + def f(x): + module = numpy_dispatch.get_array_module(x) + self.assertIs(module, jnp) + return x + jax.jit(f)(jnp_array) + jax.grad(f)(jnp_array) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in list(range(-len(shape), len(shape)))], + discont=[None, "pi", 2], + period=["2pi", "pi"], + dtype=default_dtypes, + ) + def testUnwrap(self, shape, dtype, axis, discont, period): + special_vals = {"pi": np.pi, "2pi": 2 * np.pi} + period = special_vals.get(period, period) + discont = special_vals.get(discont, discont) + + rng = jtu.rand_default(self.rng()) + + def np_fun(x): + dtype = None + if x.dtype == dtypes.bfloat16: + dtype = x.dtype + x = x.astype(np.float32) + out = np.unwrap(x, axis=axis, discont=discont, period=period) + return out if dtype is None else out.astype(dtype) + + jnp_fun = partial(jnp.unwrap, axis=axis, discont=discont, period=period) + if not dtypes.issubdtype(dtype, np.inexact): + # This case requires implicit dtype promotion + jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2}) + self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1}) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in list(range(-len(shape), len(shape))) + [None]], + dtype=all_dtypes, + ) + def testCountNonzero(self, shape, dtype, axis): + rng = jtu.rand_some_zero(self.rng()) + np_fun = lambda x: np.count_nonzero(x, axis) + jnp_fun = lambda x: jnp.count_nonzero(x, axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) + def testNonzero(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False) + + @jtu.sample_product( + [dict(shape=shape, fill_value=fill_value) + for shape in nonempty_array_shapes + for fill_value in [None, -1, shape or (1,)] + ], + 
dtype=all_dtypes, + size=[1, 5, 10], + ) + def testNonzeroSize(self, shape, dtype, size, fill_value): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + def np_fun(x): + result = np.nonzero(x) + if size <= len(result[0]): + return tuple(arg[:size] for arg in result) + else: + fillvals = fill_value if np.ndim(fill_value) else len(result) * [fill_value or 0] + return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) + for fval, arg in safe_zip(fillvals, result)) + jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value) + with jtu.ignore_warning(category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*"): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + def testFlatNonzero(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + np_fun = jtu.ignore_warning( + category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*")(np.flatnonzero) + jnp_fun = jnp.flatnonzero + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + + # JIT compilation requires specifying the size statically: + jnp_fun = lambda x: jnp.flatnonzero(x, size=np.size(x) // 2) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=nonempty_array_shapes, + dtype=all_dtypes, + fill_value=[None, -1, 10, (-1,), (10,)], + size=[1, 5, 10], + ) + def testFlatNonzeroSize(self, shape, dtype, size, fill_value): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") + def np_fun(x): + result = np.flatnonzero(x) + if size <= len(result): + return result[:size] + else: + fill_val = fill_value or 0 + return np.concatenate([result, np.full(size - len(result), fill_val, result.dtype)]) + jnp_fun = lambda x: jnp.flatnonzero(x, size=size, fill_value=fill_value) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) + def testArgWhere(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False) + + # JIT compilation requires specifying a size statically. Full test of this + # behavior is in testNonzeroSize(). 
+ jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, fill_value=fill_value) + for shape in nonempty_array_shapes + for fill_value in [None, -1, shape or (1,)] + ], + dtype=all_dtypes, + size=[1, 5, 10], + ) + def testArgWhereSize(self, shape, dtype, size, fill_value): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + def np_fun(x): + result = np.argwhere(x) + if size <= len(result): + return result[:size] + else: + fillvals = fill_value if np.ndim(fill_value) else result.shape[-1] * [fill_value or 0] + return np.empty((size, 0), dtype=int) if np.ndim(x) == 0 else np.stack([np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) + for fval, arg in safe_zip(fillvals, result.T)]).T + jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value) + + with jtu.ignore_warning(category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*"): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name), + shape=shape, dtype=dtype, axis=axis, rng_factory=rec.rng_factory) + for rec in JAX_ARGMINMAX_RECORDS + for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes) + for axis in range(-len(shape), len(shape))], + keepdims=[False, True], + ) + def testArgMinMax(self, np_op, jnp_op, rng_factory, shape, dtype, axis, keepdims): + rng = rng_factory(self.rng()) + if dtype == np.complex128 and jtu.test_device_matches(["gpu"]): + raise unittest.SkipTest("complex128 reductions not supported on GPU") + if "nan" in np_op.__name__ and dtype == jnp.bfloat16: + raise unittest.SkipTest("NumPy doesn't correctly handle bfloat16 arrays") + kwds = {"keepdims": True} if keepdims else {} + + np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=axis, **kwds)) + jnp_fun = partial(jnp_op, axis=axis, **kwds) + + args_maker = lambda: [rng(shape, dtype)] + try: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + except ValueError as e: + if str(e) == "All-NaN slice encountered": + self.skipTest("JAX doesn't support checking for all-NaN slices") + else: + raise + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(name=rec.name, np_op=getattr(np, rec.name), + jnp_op=getattr(jnp, rec.name)) + for rec in JAX_ARGMINMAX_RECORDS], + ) + def testArgMinMaxEmpty(self, name, np_op, jnp_op): + name = name[3:] if name.startswith("nan") else name + msg = f"attempt to get {name} of an empty sequence" + with self.assertRaisesRegex(ValueError, msg): + jnp_op(np.array([])) + with self.assertRaisesRegex(ValueError, msg): + jnp_op(np.zeros((2, 0)), axis=1) + np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=0)) + jnp_fun = partial(jnp_op, axis=0) + args_maker = lambda: [np.zeros((2, 0))] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes) + for lhs_shape, rhs_shape, axes in [ + [(2,), (2,), (-1, -1, -1, None)], # scalar output + [(2, 4), (2, 4), (-1, -1, -1, 0)], # 2D vectors + [(3, 4), (3, 4), (-1, -1, -1, 0)], # 3D vectors + [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)], # broadcasting + [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)], # different axes + [(6, 1, 3), (5, 3), (-1, -1, -1, None)], # more broadcasting + [(6, 1, 2), (5, 3), (-1, -1, -1, None)], # mixed 2D and 3D 
vectors + [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)], # axes/broadcasting + [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)], # axisc should do nothing + [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)] # same as before + ]], + lhs_dtype=number_dtypes, + rhs_dtype=number_dtypes, + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + axisa, axisb, axisc, axis = axes + jnp_fun = lambda a, b: jnp.cross(a, b, axisa, axisb, axisc, axis) + # Note: 2D inputs to jnp.cross are deprecated in numpy 2.0. + @jtu.ignore_warning(category=DeprecationWarning, + message="Arrays of 2-dimensional vectors are deprecated.") + def np_fun(a, b): + a = a.astype(np.float32) if lhs_dtype == jnp.bfloat16 else a + b = b.astype(np.float32) if rhs_dtype == jnp.bfloat16 else b + out = np.cross(a, b, axisa, axisb, axisc, axis) + return out.astype(jnp.promote_types(lhs_dtype, rhs_dtype)) + tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15} + tol = max(jtu.tolerance(lhs_dtype, tol_spec), + jtu.tolerance(rhs_dtype, tol_spec)) + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) + + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) + for lhs_shape, rhs_shape in [ + ((3, 3), ()), + ((), (3, 3)), + ((4, 5), (5,)), + ((6,), (6, 4)), + ((3, 4), (4, 5)), + ((4, 3, 2), (2,)), + ((2,), (3, 2, 4)), + ((4, 3, 2), (2, 5)), + ((5, 2), (3, 2, 4)), + ((2, 3, 4), (5, 4, 1))]], + lhs_dtype=float_dtypes,#number_dtypes, + rhs_dtype=float_dtypes,#number_dtypes, + ) + @jax.default_matmul_precision("float32") + def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + tol = {np.float16: 1e-2, np.float32: 2e-5, np.float64: 1e-14, + np.complex128: 1e-14} + if (lhs_dtype in [np.float16, jnp.bfloat16] and + rhs_dtype in [np.float16, jnp.bfloat16]): + tol = 1e-2 + def np_dot(x, y): + x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x + y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y + return np.dot(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype)) + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol) + self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol) + + @jtu.sample_product( + lhs_dtype=number_dtypes, + rhs_dtype=number_dtypes, + ) + @jax.numpy_dtype_promotion('standard') + def testMixedPrecisionDot(self, lhs_dtype, rhs_dtype): + # This test confirms that jnp.dot lowers to a single dot_general call, + # avoiding explicit type casting of inputs and outputs. 
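+    # (A convert_element_type on the dot_general output is the only other
+    # primitive permitted by the assertion below.)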
+ lhs = jax.ShapeDtypeStruct((5,), lhs_dtype) + rhs = jax.ShapeDtypeStruct((5,), rhs_dtype) + jaxpr = jax.make_jaxpr(jnp.dot)(lhs, rhs) + prims = [eqn.primitive for eqn in jaxpr.eqns] + self.assertIn(prims, [ + [lax.dot_general_p], + [lax.dot_general_p, lax.convert_element_type_p] + ]) + + @jtu.sample_product( + [dict(name=name, lhs_shape=lhs_shape, rhs_shape=rhs_shape) + for name, lhs_shape, rhs_shape in [ + ("vector-vector", (3,), (3,)), + ("matrix-vector", (3, 3), (3,)), + ("vector-matrix", (3,), (3, 3)), + ("matrix-matrix", (3, 3), (3, 3)), + ("vector-tensor", (3,), (5, 3, 2)), + ("tensor-vector", (5, 3, 2), (2,)), + ("matrix-tensor", (5, 2), (3, 2, 4)), + ("tensor-matrix", (5, 2, 3), (3, 2)), + ("tensor-tensor", (5, 3, 4), (5, 4, 1)), + ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]], + lhs_dtype=float_dtypes, #number_dtypes, + rhs_dtype=float_dtypes, #number_dtypes, + ) + @jax.default_matmul_precision("float32") + def testMatmul(self, name, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): + rng = jtu.rand_default(self.rng()) + def np_fun(x, y): + dtype = jnp.promote_types(lhs_dtype, rhs_dtype) + return np.matmul(x, y).astype(dtype) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12, + np.complex128: 1e-12} + + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol) + self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol) + + @jtu.sample_product( + lhs_batch=broadcast_compatible_shapes, + rhs_batch=broadcast_compatible_shapes, + axis_size=[2, 4], + axis=range(-2, 2), + dtype=float_dtypes,#number_dtypes, + ) + @jax.default_matmul_precision("float32") + def testVecdot(self, lhs_batch, rhs_batch, axis_size, axis, dtype): + # Construct vecdot-compatible shapes. 
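+    # An axis of length axis_size is inserted into each batch shape at position
+    # `axis`; negative axes count from the end of the resulting shape.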
+ size = min(len(lhs_batch), len(rhs_batch)) + axis = int(np.clip(axis, -size - 1, size)) + if axis >= 0: + lhs_shape = (*lhs_batch[:axis], axis_size, *lhs_batch[axis:]) + rhs_shape = (*rhs_batch[:axis], axis_size, *rhs_batch[axis:]) + else: + laxis = axis + len(lhs_batch) + 1 + lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:]) + raxis = axis + len(rhs_batch) + 1 + rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:]) + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + @jtu.promote_like_jnp + def np_fn(x, y, axis=axis): + f = jtu.numpy_vecdot if jtu.numpy_version() < (2, 0, 0) else np.vecdot + return f(x, y, axis=axis).astype(x.dtype) + jnp_fn = partial(jnp.vecdot, axis=axis) + tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12, + np.complex64: 1E-3, np.complex128: 1e-12} + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) + self._CompileAndCheck(jnp_fn, args_maker, tol=tol) + + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes) + for lhs_shape, rhs_shape, axes in [ + [(3,), (), 0], + [(2, 3, 4), (5, 6, 7), 0], # from issue #740 + [(2, 3, 4), (3, 4, 5, 6), 2], + [(2, 3, 4), (5, 4, 3, 6), [1, 2]], + [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]], + [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]], + ]], + lhs_dtype=float_dtypes,#number_dtypes, + rhs_dtype=float_dtypes,#number_dtypes, + ) + @jax.default_matmul_precision("float32") + def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + jnp_fun = lambda a, b: jnp.tensordot(a, b, axes) + def np_fun(a, b): + a = a if lhs_dtype != jnp.bfloat16 else a.astype(np.float32) + b = b if rhs_dtype != jnp.bfloat16 else b.astype(np.float32) + dtype = jnp.promote_types(lhs_dtype, rhs_dtype) + return np.tensordot(a, b, axes).astype(dtype) + tol = {np.float16: 1e-1, np.float32: 1e-3, np.float64: 1e-12, + np.complex64: 1e-3, np.complex128: 1e-12} + + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, tol=tol) + + def testTensordotErrors(self): + a = self.rng().random((3, 2, 2)) + b = self.rng().random((2,)) + self.assertRaisesRegex( + TypeError, "Number of tensordot axes.*exceeds input ranks.*", + lambda: jnp.tensordot(a, b, axes=2)) + + self.assertRaisesRegex( + TypeError, "tensordot requires axes lists to have equal length.*", + lambda: jnp.tensordot(a, b, axes=([0], [0, 1]))) + + self.assertRaisesRegex( + TypeError, "tensordot requires both axes lists to be either ints, tuples or lists.*", + lambda: jnp.tensordot(a, b, axes=('bad', 'axes'))) + + self.assertRaisesRegex( + TypeError, "tensordot axes argument must be an int, a pair of ints, or a pair of lists.*", + lambda: jnp.tensordot(a, b, axes='badaxes')) + + @jtu.sample_product( + element_shape=all_shapes, + test_shape=all_shapes, + dtype=default_dtypes, + invert=[False, True], + ) + def testIsin(self, element_shape, test_shape, dtype, invert): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] + jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert) + np_fun = lambda e, t: np.isin(e, t, invert=invert) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype1=[s for s in default_dtypes 
if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=all_shapes, + shape2=all_shapes, + ) + def testSetdiff1d(self, shape1, shape2, dtype1, dtype2): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np.setdiff1d, jnp.setdiff1d, args_maker) + + @unittest.skip("JAx-metal fail.") + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=all_shapes, + shape2=all_shapes, + size=[1, 5, 10], + fill_value=[None, -1], + ) + def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + def np_fun(arg1, arg2): + result = np.setdiff1d(arg1, arg2) + if size <= len(result): + return result[:size] + else: + return np.pad(result, (0, size-len(result)), constant_values=fill_value or 0) + def jnp_fun(arg1, arg2): + return jnp.setdiff1d(arg1, arg2, size=size, fill_value=fill_value) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=nonempty_nonscalar_array_shapes, + shape2=nonempty_nonscalar_array_shapes, + ) + def testUnion1d(self, shape1, shape2, dtype1, dtype2): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + def np_fun(arg1, arg2): + dtype = jnp.promote_types(arg1.dtype, arg2.dtype) + return np.union1d(arg1, arg2).astype(dtype) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp.union1d, args_maker) + + @unittest.skip("Jax-metal fail.") + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=nonempty_nonscalar_array_shapes, + shape2=nonempty_nonscalar_array_shapes, + size=[1, 5, 10], + fill_value=[None, -1], + ) + def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + def np_fun(arg1, arg2): + dtype = jnp.promote_types(arg1.dtype, arg2.dtype) + result = np.union1d(arg1, arg2).astype(dtype) + fv = result.min() if fill_value is None else fill_value + if size <= len(result): + return result[:size] + else: + return np.concatenate([result, np.full(size - len(result), fv, result.dtype)]) + def jnp_fun(arg1, arg2): + return jnp.union1d(arg1, arg2, size=size, fill_value=fill_value) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=all_shapes, + shape2=all_shapes, + assume_unique=[False, True], + ) + def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique) + def np_fun(ar1, ar2): + if 
assume_unique: + # pre-flatten the arrays to match with jax implementation + ar1 = np.ravel(ar1) + ar2 = np.ravel(ar2) + return np.setxor1d(ar1, ar2, assume_unique) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=all_shapes, + shape2=all_shapes, + assume_unique=[False, True], + return_indices=[False, True], + ) + def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, + return_indices): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + jnp_fun = lambda ar1, ar2: jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + np_fun = lambda ar1, ar2: np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, lhs_dtype=lhs_dtype, + rhs_shape=rhs_shape, rhs_dtype=rhs_dtype) + # TODO(phawkins): support integer dtypes too. + for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) + for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) + if len(jtu._dims_of_shape(lhs_shape)) == 0 + or len(jtu._dims_of_shape(rhs_shape)) == 0 + or lhs_shape[-1] == rhs_shape[-1]], + ) + @jax.default_matmul_precision("float32") + def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + def np_fun(lhs, rhs): + lhs = lhs if lhs_dtype != jnp.bfloat16 else lhs.astype(np.float32) + rhs = rhs if rhs_dtype != jnp.bfloat16 else rhs.astype(np.float32) + dtype = jnp.promote_types(lhs_dtype, rhs_dtype) + return np.inner(lhs, rhs).astype(dtype) + jnp_fun = lambda lhs, rhs: jnp.inner(lhs, rhs) + tol_spec = {np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-13, + np.complex64: 1e-5} + tol = max(jtu.tolerance(lhs_dtype, tol_spec), + jtu.tolerance(rhs_dtype, tol_spec)) + # TODO(phawkins): there are float32/float64 disagreements for some inputs. 
+ with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol) + + @unittest.skip("MLIR translation rule for primitive 'eigh' not found for platform METAL.") + @jtu.sample_product( + dtype=[dt for dt in float_dtypes if dt not in [jnp.float16, jnp.bfloat16]], + shape=[shape for shape in one_dim_array_shapes if shape != (1,)], + deg=[1, 2, 3], + rcond=[None, -1, 10e-3, 10e-5, 10e-10], + full=[False, True], + w=[False, True], + cov=[False, True, "unscaled"], + ) + @jax.default_matmul_precision("float32") + def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov): + rng = jtu.rand_default(self.rng()) + tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5} + tol = jtu.tolerance(dtype, tol_spec) + _w = lambda a: abs(a) if w else None + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] + jnp_fun = lambda x, y, a: jnp.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov) + np_fun = jtu.ignore_warning( + message="Polyfit may be poorly conditioned*")(lambda x, y, a: np.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov)) + + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol) + + args = args_maker() + if not full: + args = args_maker() + try: + np_out = np_fun(*args) + except ValueError: + return # https://github.com/numpy/numpy/issues/22380 + jnp_out = jnp_fun(*args) + self.assertAllClose(np_out, jnp_out, atol=tol, rtol=tol, + check_dtypes=False) + else: + # Don't compare the residuals because jnp.linalg.lstsq acts slightly + # differently to remain `jit`-compatible. + np_p, _, nrank, nsingular_values, nrcond = np_fun(*args) + jp_p, _, jrank, jsingular_values, jrcond = jnp_fun(*args) + self.assertAllClose( + (np_p, nrank, nsingular_values, nrcond), + (jp_p, jrank, jsingular_values, jrcond), + atol=tol, rtol=tol, check_dtypes=False) + + @jtu.sample_product( + [dict(a_min=a_min, a_max=a_max) + for a_min, a_max in [(-1, None), (None, 1), (-0.9, 1), + (-np.ones(1), None), + (None, np.ones(1)), + (np.full(1, -0.9), np.ones(1))] + ], + shape=all_shapes, + dtype=number_dtypes, + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
+ @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testClipStaticBounds(self, shape, dtype, a_min, a_max): + if np.issubdtype(dtype, np.unsignedinteger): + a_min = None if a_min is None else abs(a_min) + a_max = None if a_max is None else abs(a_max) + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max) + jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, dtype=dtype) + for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)], + decimals=[0, 1, -2], + ) + def testRoundStaticDecimals(self, shape, dtype, decimals): + rng = jtu.rand_default(self.rng()) + if jnp.issubdtype(dtype, np.integer) and decimals < 0: + self.skipTest("Integer rounding with decimals < 0 not implemented") + np_fun = lambda x: np.round(x, decimals=decimals) + jnp_fun = lambda x: jnp.round(x, decimals=decimals) + args_maker = lambda: [rng(shape, dtype)] + tol = {jnp.bfloat16: 5e-2, np.float16: 1e-2} + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=check_dtypes, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes, + atol=tol, rtol=tol) + + @jtu.sample_product(jit=[False, True]) + def testOperatorRound(self, jit): + jround = jax.jit(round, static_argnums=1) if jit else round + self.assertAllClose(round(np.float32(7.532), 1), + jround(jnp.float32(7.5), 1)) + self.assertAllClose(round(np.float32(1.234), 2), + jround(jnp.float32(1.234), 2)) + self.assertAllClose(round(np.float32(1.234)), + jround(jnp.float32(1.234)), check_dtypes=False) + self.assertAllClose(round(np.float32(7.532), 1), + jround(jnp.array(7.5, jnp.float32), 1)) + self.assertAllClose(round(np.float32(1.234), 2), + jround(jnp.array(1.234, jnp.float32), 2)) + self.assertAllClose(round(np.float32(1.234)), + jround(jnp.array(1.234, jnp.float32)), + check_dtypes=False) + + def testRoundMethod(self): + # https://github.com/google/jax/issues/15190 + (jnp.arange(3.) 
/ 5.).round() # doesn't crash + + @jtu.sample_product(shape=[(5,), (5, 2)]) + def testOperatorReversed(self, shape): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, 'float32')] + np_fun = lambda x: np.array(list(reversed(x))) + jnp_fun = lambda x: jnp.array(list(reversed(x))) + + self._CompileAndCheck(jnp_fun, args_maker) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + [dict(mode=mode, shape=shape, dtype=dtype, + pad_width=pad_width, constant_values=constant_values) + for mode, shapes in [ + ('constant', all_shapes), + ('wrap', nonempty_shapes), + ('edge', nonempty_shapes), + ] + for shape, dtype in _shape_and_dtypes(shapes, all_dtypes) + for constant_values in [ + # None is used for modes other than 'constant' + None, + # constant + 0, 1, + # (constant,) + (0,), (2.718,), + # ((before_const, after_const),) + ((0, 2),), ((-1, 3.14),), + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i / 2, -3.14 * i) for i in range(len(shape))), + ] + for pad_width in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 0),), + # (before, after) (not in the docstring but works in numpy) + (2, 0), (0, 0), + # (pad,) + (1,), (2,), + # pad + 0, 1, + ] + if (pad_width != () and constant_values != () and + ((mode == 'constant' and constant_values is not None) or + (mode != 'constant' and constant_values is None)))], + ) + def testPad(self, shape, dtype, mode, pad_width, constant_values): + if np.issubdtype(dtype, np.unsignedinteger): + constant_values = jax.tree.map(abs, constant_values) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if constant_values is None: + np_fun = partial(np.pad, pad_width=pad_width, mode=mode) + jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode) + else: + np_fun = partial(np.pad, pad_width=pad_width, mode=mode, + constant_values=constant_values) + jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, + constant_values=constant_values) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(mode=mode, shape=shape, dtype=dtype, + pad_width=pad_width, stat_length=stat_length) + for mode in ['maximum', 'minimum', 'mean', 'median'] + for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes) + for pad_width in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 0),), + # (before, after) (not in the docstring but works in numpy) + (2, 0), (0, 0), + # (pad,) + (1,), (2,), + # pad + 0, 1, + ] + for stat_length in [ + None, + # ((before_1, after_1), ..., (before_N, after_N)) + tuple(((i % 3 + 1), ((i + 1) % 3) + 1) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 2),), + # (before, after) (not in the docstring but works in numpy) + (1, 1), (3, 4), + # (pad,) + (1,), (2,), + # pad + 1, 2 + ] + if (pad_width != () and stat_length != () and + not (dtype in bool_dtypes and mode == 'mean'))], + ) + def testPadStatValues(self, shape, dtype, mode, pad_width, stat_length): + if mode == 'median' and np.issubdtype(dtype, np.complexfloating): + self.skipTest("median statistic is not supported for dtype=complex.") + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + np_fun = partial(np.pad, 
pad_width=pad_width, mode=mode, stat_length=stat_length) + jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, stat_length=stat_length) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, dtype=dtype, + pad_width=pad_width, reflect_type=reflect_type) + for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes) + for pad_width in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 3),), + # (before, after) (not in the docstring but works in numpy) + (2, 1), (1, 2), + # (pad,) + (1,), (2,), (3,), + # pad + 0, 5, 7, 10 + ] + for reflect_type in ['even', 'odd'] + if (pad_width != () and + # following types lack precision when calculating odd values + (reflect_type != 'odd' or dtype not in [np.bool_, np.float16, jnp.bfloat16]))], + mode=['symmetric', 'reflect'] + ) + def testPadSymmetricAndReflect(self, shape, dtype, mode, pad_width, reflect_type): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + np_fun = partial(np.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type) + jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE, + tol={np.float32: 1e-3, np.complex64: 1e-3}) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, dtype=dtype, pad_width=pad_width, end_values=end_values) + for shape, dtype in _shape_and_dtypes(nonempty_shapes, default_dtypes + complex_dtypes) + for pad_width in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 0),), + # (before, after) (not in the docstring but works in numpy) + (2, 0), (0, 0), + # (pad,) + (1,), (2,), + # pad + 0, 1, + ] + for end_values in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2.0, 3.14),), + # (before, after) (not in the docstring but works in numpy) + (0, 0), (-8.0, 2.0), + # (end_values,) + (1,), (2,), + # end_values + 0, 1, 100, 10.0, 3.5, 4.2, -5, -3 + ] + if (pad_width != () and end_values != () and + # following types lack precision + dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16])], + ) + def testPadLinearRamp(self, shape, dtype, pad_width, end_values): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + np_fun = partial(np.pad, pad_width=pad_width, mode="linear_ramp", + end_values=end_values) + jnp_fun = partial(jnp.pad, pad_width=pad_width, mode="linear_ramp", + end_values=end_values) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) + self._CompileAndCheck(jnp_fun, args_maker) + + def testPadEmpty(self): + arr = np.arange(6).reshape(2, 3) + + pad_width = ((2, 3), (3, 1)) + np_res = np.pad(arr, pad_width=pad_width, mode="empty") + jnp_res = jnp.pad(arr, pad_width=pad_width, mode="empty") + + np.testing.assert_equal(np_res.shape, jnp_res.shape) + np.testing.assert_equal(arr, np_res[2:-3, 3:-1]) + np.testing.assert_equal(arr, jnp_res[2:-3, 3:-1]) + np.testing.assert_equal(np_res[2:-3, 3:-1], jnp_res[2:-3, 3:-1]) + + def testPadKwargs(self): + modes = { 
+ 'constant': {'constant_values': 0}, + 'edge': {}, + 'linear_ramp': {'end_values': 0}, + 'maximum': {'stat_length': None}, + 'mean': {'stat_length': None}, + 'median': {'stat_length': None}, + 'minimum': {'stat_length': None}, + 'reflect': {'reflect_type': 'even'}, + 'symmetric': {'reflect_type': 'even'}, + 'wrap': {}, + 'empty': {} + } + arr = jnp.array([1, 2, 3]) + pad_width = 1 + + for mode in modes.keys(): + allowed = modes[mode] + not_allowed = {} + for kwargs in modes.values(): + if kwargs != allowed: + not_allowed.update(kwargs) + + # Test if allowed keyword arguments pass + jnp.pad(arr, pad_width, mode, **allowed) + # Test if prohibited keyword arguments of other modes raise an error + match = f"unsupported keyword arguments for mode '{mode}'" + for key, value in not_allowed.items(): + with self.assertRaisesRegex(ValueError, match): + jnp.pad(arr, pad_width, mode, **{key: value}) + + # Test if unsupported mode raise error. + unsupported_modes = [1, None, "foo"] + for mode in unsupported_modes: + match = f"Unimplemented padding mode '{mode}' for np.pad." + with self.assertRaisesRegex(NotImplementedError, match): + jnp.pad(arr, pad_width, mode) + + def testPadFunction(self): + def np_pad_with(vector, pad_width, iaxis, kwargs): + pad_value = kwargs.get('padder', 10) + vector[:pad_width[0]] = pad_value + vector[-pad_width[1]:] = pad_value + + def jnp_pad_with(vector, pad_width, iaxis, kwargs): + pad_value = kwargs.get('padder', 10) + vector = vector.at[:pad_width[0]].set(pad_value) + vector = vector.at[-pad_width[1]:].set(pad_value) + return vector + + arr = np.arange(6).reshape(2, 3) + np_res = np.pad(arr, 2, np_pad_with) + jnp_res = jnp.pad(arr, 2, jnp_pad_with) + np.testing.assert_equal(np_res, jnp_res) + + arr = np.arange(24).reshape(2, 3, 4) + np_res = np.pad(arr, 1, np_pad_with, padder=100) + jnp_res = jnp.pad(arr, 1, jnp_pad_with, padder=100) + np.testing.assert_equal(np_res, jnp_res) + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(arr.shape, arr.dtype)] + jnp_fun = partial(jnp.pad, pad_width=1, mode=jnp_pad_with) + self._CompileAndCheck(jnp_fun, args_maker) + + def testPadWithNumpyPadWidth(self): + a = jnp.array([1, 2, 3, 4, 5]) + f = jax.jit( + partial( + jnp.pad, + pad_width=np.asarray((2, 3)), + mode="constant", + constant_values=(4, 6))) + + np.testing.assert_array_equal( + f(a), + np.pad( + a, + pad_width=np.asarray((2, 3)), + mode="constant", + constant_values=(4, 6))) + + def testPadWeakType(self): + x = jnp.array(1.0)[None] + for mode in ['constant', 'edge', 'linear_ramp', 'maximum', 'mean', 'median', + 'minimum', 'reflect', 'symmetric', 'wrap', 'empty']: + y = jnp.pad(x, 0, mode=mode) + self.assertTrue(dtypes.is_weakly_typed(y)) + + @jtu.sample_product( + [dict(shape=shape, dtype=dtype) + for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)], + reps=[(), (2,), (3, 4), (2, 3, 4), (1, 0, 2)], + ) + def testTile(self, shape, dtype, reps): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np.tile(arg, reps) + jnp_fun = lambda arg: jnp.tile(arg, reps) + + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + def testExtract(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, jnp.float32), rng(shape, dtype)] + self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker) + + 
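+  # np.piecewise accepts either len(condlist) or len(condlist) + 1 functions;
+  # the extra entry is the default used where no condition holds, which is why
+  # nfunc is sampled from [ncond, ncond + 1] below.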
@jtu.sample_product( + [dict(ncond=ncond, nfunc=nfunc) + for ncond in [1, 2, 3] + for nfunc in [ncond, ncond + 1] + ], + shape=all_shapes, + dtype=all_dtypes) + def testPiecewise(self, shape, dtype, ncond, nfunc): + rng = jtu.rand_default(self.rng()) + rng_bool = jtu.rand_int(self.rng(), 0, 2) + funclist = [lambda x: x - 1, 1, lambda x: x, 0][:nfunc] + args_maker = lambda: (rng(shape, dtype), [rng_bool(shape, bool) for i in range(ncond)]) + np_fun = partial(np.piecewise, funclist=funclist) + jnp_fun = partial(jnp.piecewise, funclist=funclist) + + if dtype == np.bool_: + # The `x - 1` above uses type promotion. + jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + # This is a higher-order function, so the cache miss check will fail. + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, check_cache_misses=False) + + def testPiecewiseRecompile(self): + def g(x): + g.num_traces += 1 + return x + g.num_traces = 0 + x = jnp.arange(10.0) + for i in range(5): + jnp.piecewise(x, [x < 0], [g, 0.]) + self.assertEqual(g.num_traces, 1) + + @jtu.sample_product( + [dict(shape=shape, perm=perm) + for shape in array_shapes + for perm in [ + None, + tuple(np.random.RandomState(0).permutation(np.zeros(shape).ndim)), + tuple(np.random.RandomState(0).permutation( + np.zeros(shape).ndim) - np.zeros(shape).ndim) + ] + ], + dtype=default_dtypes, + arg_type=["splat", "value"], + ) + def testTransposeTuple(self, shape, dtype, perm, arg_type): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if arg_type == "value": + np_fun = lambda x: x.transpose(perm) + jnp_fun = lambda x: jnp.array(x).transpose(perm) + else: + np_fun = lambda x: x.transpose(*(perm or ())) + jnp_fun = lambda x: jnp.array(x).transpose(*(perm or ())) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @jtu.sample_product( + shape=array_shapes, + dtype=default_dtypes, + ) + def testPermuteDims(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + axes = self.rng().permutation(len(shape)) + np_fun = partial(getattr(np, "permute_dims", np.transpose), axes=axes) + jnp_fun = partial(jnp.permute_dims, axes=axes) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @jtu.sample_product( + shape=[s for s in array_shapes if len(s) >= 2], + dtype=default_dtypes, + use_property=[True, False] + ) + def testMatrixTranspose(self, shape, dtype, use_property): + if use_property: + jnp_fun = lambda x: jnp.asarray(x).mT + else: + jnp_fun = jnp.matrix_transpose + if hasattr(np, 'matrix_transpose'): + np_fun = np.matrix_transpose + else: + np_fun = lambda x: np.swapaxes(x, -1, -2) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + a_shape=one_dim_array_shapes, + trim=["f", "b", "fb"], + ) + def testTrimZeros(self, a_shape, dtype, trim): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(a_shape, dtype)] + np_fun = lambda arg1: np.trim_zeros(arg1, trim) + jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + + @unittest.skip("Jax-metal don't support map 
op.") + @jtu.sample_product( + rank=(1, 2), + dtype=default_dtypes, + a_shape=one_dim_array_shapes, + ) + @jax.default_matmul_precision("float32") + def testPoly(self, a_shape, dtype, rank): + if dtype in (np.float16, jnp.bfloat16, np.int16): + self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.") + elif rank == 2 and not jtu.test_device_matches(["cpu"]): + self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.") + rng = jtu.rand_default(self.rng()) + tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 } + if jtu.test_device_matches(["tpu"]): + tol[np.int32] = tol[np.float32] = 1e-1 + tol = jtu.tolerance(dtype, tol) + args_maker = lambda: [rng(a_shape * rank, dtype)] + self._CheckAgainstNumpy(np.poly, jnp.poly, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp.poly, args_maker, check_dtypes=True, rtol=tol, atol=tol) + + @unittest.skip("Jax-metal don't support map op.") + @jtu.sample_product( + dtype=default_dtypes, + a_shape=one_dim_array_shapes, + b_shape=one_dim_array_shapes, + ) + def testPolyAdd(self, a_shape, b_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1, arg2: np.polyadd(arg1, arg2) + jnp_fun = lambda arg1, arg2: jnp.polyadd(arg1, arg2) + args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @unittest.skip("Jax-metal don't support map op.") + @jtu.sample_product( + dtype=default_dtypes, + a_shape=one_dim_array_shapes, + b_shape=one_dim_array_shapes, + ) + def testPolySub(self, a_shape, b_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1, arg2: np.polysub(arg1, arg2) + jnp_fun = lambda arg1, arg2: jnp.polysub(arg1, arg2) + args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @unittest.skip("Jax-metal don't support map op.") + @jtu.sample_product( + [dict(order=order, k=k, dtype=dtype) + for dtype in default_dtypes + for order in range(5) + for k in [np.arange(order, dtype=dtype), np.ones(1, dtype), None]], + a_shape=one_dim_array_shapes, + ) + def testPolyInt(self, a_shape, order, k, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1: np.polyint(arg1, m=order, k=k) + jnp_fun = lambda arg1: jnp.polyint(arg1, m=order, k=k) + args_maker = lambda: [rng(a_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @unittest.skip("Jax-metal don't support map op.") + @jtu.sample_product( + dtype=default_dtypes, + a_shape=one_dim_array_shapes, + order=list(range(5)), + ) + def testPolyDer(self, a_shape, order, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1: np.polyder(arg1, m=order) + jnp_fun = lambda arg1: jnp.polyder(arg1, m=order) + args_maker = lambda: [rng(a_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @parameterized.parameters(['int', 'np.int', 'jnp.int']) + def testIntegerPower(self, ptype): + p = {'int': 2, 'np.int': np.int32(2), 'jnp.int': jnp.int32(2)}[ptype] + jaxpr = jax.make_jaxpr(lambda x1: jnp.power(x1, p))(1) + eqns = jaxpr.jaxpr.eqns + self.assertLen(eqns, 1) + 
self.assertEqual(eqns[0].primitive, lax.integer_pow_p) + + @jtu.sample_product( + x=[-1, 0, 1], + y=[0, 32, 64, 128], + ) + def testIntegerPowerOverflow(self, x, y): + # Regression test for https://github.com/google/jax/issues/5987 + args_maker = lambda: [x, y] + self._CheckAgainstNumpy(np.power, jnp.power, args_maker) + self._CompileAndCheck(jnp.power, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in [None] + list(range(len(shape))) + ], + dtype=all_dtypes, + ) + def testCompress(self, shape, dtype, axis): + rng = jtu.rand_some_zero(self.rng()) + if shape in scalar_shapes or len(shape) == 0: + cond_shape = (0,) + elif axis is None: + cond_shape = (math.prod(shape),) + else: + cond_shape = (shape[axis],) + + args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)] + + np_fun = partial(np.compress, axis=axis) + jnp_fun = partial(jnp.compress, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + shape=[(2, 3)], + dtype=int_dtypes, + # condition entries beyond axis size must be zero. + condition=[[1], [1, 0, 0, 0, 0, 0, 0]], + axis=[None, 0, 1], + ) + def testCompressMismatchedShapes(self, shape, dtype, condition, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [np.array(condition), rng(shape, dtype)] + np_fun = partial(np.compress, axis=axis) + jnp_fun = partial(jnp.compress, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in array_shapes + for axis in [None] + list(range(len(shape))) + ], + dtype=all_dtypes, + ) + def testCompressMethod(self, shape, dtype, axis): + rng = jtu.rand_some_zero(self.rng()) + if shape in scalar_shapes or len(shape) == 0: + cond_shape = (0,) + elif axis is None: + cond_shape = (math.prod(shape),) + else: + cond_shape = (shape[axis],) + + args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)] + + np_fun = lambda condition, x: np.compress(condition, x, axis=axis) + jnp_fun = lambda condition, x: x.compress(condition, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + [dict(base_shape=base_shape, axis=axis) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in (None, *range(-len(base_shape)+1, len(base_shape))) + ], + arg_dtypes=[ + arg_dtypes + for num_arrs in [3] + for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, num_arrs) + ], + dtype=[None] + default_dtypes, + ) + def testConcatenate(self, axis, dtype, base_shape, arg_dtypes): + rng = jtu.rand_default(self.rng()) + wrapped_axis = 0 if axis is None else axis % len(base_shape) + shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] + for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] + @jtu.promote_like_jnp + def np_fun(*args, dtype=dtype): + dtype = dtype or args[0].dtype + args = [x if x.dtype != jnp.bfloat16 else x.astype(np.float32) + for x in args] + return np.concatenate(args, axis=axis, dtype=dtype, casting='unsafe') + jnp_fun = lambda *args: jnp.concatenate(args, axis=axis, dtype=dtype) + + def args_maker(): + return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] + + with jtu.strict_promotion_if_dtypes_match(arg_dtypes): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in [(4, 1), (4, 3), (4, 5, 6)] + for axis in [None] + list(range(1 - 
len(shape), len(shape) - 1)) + ], + dtype=all_dtypes, + ) + def testConcatenateArray(self, shape, dtype, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda x: np.concatenate(x, axis=axis) + jnp_fun = lambda x: jnp.concatenate(x, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testConcatenateAxisNone(self): + # https://github.com/google/jax/issues/3419 + a = jnp.array([[1, 2], [3, 4]]) + b = jnp.array([[5]]) + jnp.concatenate((a, b), axis=None) + + def testConcatenateScalarAxisNone(self): + arrays = [np.int32(0), np.int32(1)] + self.assertArraysEqual(jnp.concatenate(arrays, axis=None), + np.concatenate(arrays, axis=None)) + + @jtu.sample_product( + [dict(base_shape=base_shape, axis=axis) + for base_shape in [(), (4,), (3, 4), (2, 3, 4)] + for axis in (None, *range(-len(base_shape)+1, len(base_shape))) + ], + dtype=default_dtypes, + ) + def testConcat(self, axis, base_shape, dtype): + rng = jtu.rand_default(self.rng()) + wrapped_axis = 0 if axis is None else axis % len(base_shape) + shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] + for size in [3, 1, 4]] + @jtu.promote_like_jnp + def np_fun(*args): + if jtu.numpy_version() >= (2, 0, 0): + return np.concat(args, axis=axis) + else: + return np.concatenate(args, axis=axis) + jnp_fun = lambda *args: jnp.concat(args, axis=axis) + args_maker = lambda: [rng(shape, dtype) for shape in shapes] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(base_shape=base_shape, axis=axis) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(-len(base_shape)+1, len(base_shape))], + arg_dtypes=itertools.combinations_with_replacement(default_dtypes, 2) + ) + def testAppend(self, axis, base_shape, arg_dtypes): + rng = jtu.rand_default(self.rng()) + wrapped_axis = axis % len(base_shape) + shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] + for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] + def np_fun(arr, values): + arr = arr.astype(np.float32) if arr.dtype == jnp.bfloat16 else arr + values = (values.astype(np.float32) if values.dtype == jnp.bfloat16 + else values) + out = np.append(arr, values, axis=axis) + return out.astype(jnp.promote_types(*arg_dtypes)) + jnp_fun = lambda arr, values: jnp.append(arr, values, axis=axis) + + def args_maker(): + return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] + + with jtu.strict_promotion_if_dtypes_match(arg_dtypes): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis, idx=idx) + for shape in nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + for idx in (range(-math.prod(shape), math.prod(shape)) + if axis is None else + range(-shape[axis], shape[axis]))], + dtype=all_dtypes, + ) + def testDeleteInteger(self, shape, dtype, idx, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, idx, axis=axis) + jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ], + 
dtype=all_dtypes, + slc=[slice(None), slice(1, 3), slice(1, 5, 2)], + ) + def testDeleteSlice(self, shape, dtype, axis, slc): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, slc, axis=axis) + jnp_fun = lambda arg: jnp.delete(arg, slc, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ], + dtype=all_dtypes, + idx_shape=all_shapes, + ) + def testDeleteIndexArray(self, shape, dtype, axis, idx_shape): + rng = jtu.rand_default(self.rng()) + max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] + idx = jtu.rand_int(self.rng(), low=-max_idx, high=max_idx)(idx_shape, int) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, idx, axis=axis) + jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ], + dtype=all_dtypes, + idx_shape=all_shapes, + ) + def testDeleteUniqueIndices(self, shape, dtype, axis, idx_shape): + rng = jtu.rand_default(self.rng()) + max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] + idx_size = np.zeros(idx_shape).size + if idx_size > max_idx: + self.skipTest("Too many indices to be unique") + def args_maker(): + x = rng(shape, dtype) + idx = self.rng().choice(max_idx, idx_shape, replace=False) + return x, idx + np_fun = partial(np.delete, axis=axis) + jnp_fun = partial(jnp.delete, axis=axis, assume_unique_indices=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ], + dtype=all_dtypes, + ) + def testDeleteMaskArray(self, shape, dtype, axis): + rng = jtu.rand_default(self.rng()) + mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] + mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, mask, axis=axis) + jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skip("JAX-metal fail.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ], + dtype=all_dtypes, + ) + def testInsertInteger(self, shape, dtype, axis): + x = jnp.empty(shape) + max_ind = x.size if axis is None else x.shape[axis] + rng = jtu.rand_default(self.rng()) + i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind) + args_maker = lambda: [rng(shape, dtype), i_rng((), np.int32), rng((), dtype)] + np_fun = lambda *args: np.insert(*args, axis=axis) + jnp_fun = lambda *args: jnp.insert(*args, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skip("Jax-metal fail.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in 
nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ], + dtype=all_dtypes, + ) + def testInsertSlice(self, shape, dtype, axis): + x = jnp.empty(shape) + max_ind = x.size if axis is None else x.shape[axis] + rng = jtu.rand_default(self.rng()) + i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind) + slc = slice(i_rng((), jnp.int32).item(), i_rng((), jnp.int32).item()) + args_maker = lambda: [rng(shape, dtype), rng((), dtype)] + np_fun = lambda x, val: np.insert(x, slc, val, axis=axis) + jnp_fun = lambda x, val: jnp.insert(x, slc, val, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @parameterized.parameters([ + [[[1, 1], [2, 2], [3, 3]], 1, 5, None], + [[[1, 1], [2, 2], [3, 3]], 1, 5, 1], + [[[1, 1], [2, 2], [3, 3]], 1, [1, 2, 3], 1], + [[[1, 1], [2, 2], [3, 3]], [1], [[1],[2],[3]], 1], + [[1, 1, 2, 2, 3, 3], [2, 2], [5, 6], None], + [[1, 1, 2, 2, 3, 3], slice(2, 4), [5, 6], None], + [[1, 1, 2, 2, 3, 3], [2, 2], [7.13, False], None], + [[[0, 1, 2, 3], [4, 5, 6, 7]], (1, 3), 999, 1] + ]) + def testInsertExamples(self, arr, index, values, axis): + # Test examples from the np.insert docstring + args_maker = lambda: ( + np.asarray(arr), index if isinstance(index, slice) else np.array(index), + np.asarray(values), axis) + self._CheckAgainstNumpy(np.insert, jnp.insert, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_array_shapes + for axis in range(-len(shape), len(shape)) + ], + dtype=default_dtypes, + out_dims=[0, 1, 2], + ) + def testApplyAlongAxis(self, shape, dtype, axis, out_dims): + def func(x, out_dims): + if out_dims == 0: + return x.sum(dtype=x.dtype) + elif out_dims == 1: + return x * x[0] + elif out_dims == 2: + return x[:, None] + x[None, :] + else: + raise NotImplementedError(f"{out_dims=}") + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arr: np.apply_along_axis(func, axis, arr, out_dims=out_dims) + jnp_fun = lambda arr: jnp.apply_along_axis(func, axis, arr, out_dims=out_dims) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + atol={dtypes.bfloat16: 2e-2}) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axes=axes) + for shape in nonempty_shapes + for axes in itertools.combinations(range(len(shape)), 2) + ], + func=["sum"], + keepdims=[True, False], + # Avoid low-precision types in sum() + dtype=[dtype for dtype in default_dtypes + if dtype not in [np.float16, jnp.bfloat16]], + ) + def testApplyOverAxes(self, shape, dtype, func, keepdims, axes): + f = lambda x, axis: getattr(x, func)(axis=axis, keepdims=keepdims, dtype=dtype) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: (rng(shape, dtype),) + np_fun = lambda a: np.apply_over_axes(f, a, axes) + jnp_fun = lambda a: jnp.apply_over_axes(f, a, axes) + self._CompileAndCheck(jnp_fun, args_maker) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, dtype=dtype, axis=axis) + for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) + for axis in [None] + list(range(-len(shape), max(1, len(shape)))) + ], + repeats=[0, 1, 2], + fixed_size=[False, True], + ) + def testRepeat(self, axis, shape, dtype, repeats, fixed_size): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np.repeat(arg, repeats=repeats, axis=axis) + np_fun = jtu.promote_like_jnp(np_fun) + if fixed_size: + total_repeat_length = np.repeat(np.zeros(shape), 
repeats, axis).shape[axis or 0] + jnp_fun = lambda arg, rep: jnp.repeat(arg, repeats=rep, axis=axis, + total_repeat_length=total_repeat_length) + jnp_args_maker = lambda: [rng(shape, dtype), repeats] + clo_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis, + total_repeat_length=total_repeat_length) + clo_fun_args_maker = lambda: [rng(shape, dtype)] + self._CompileAndCheck(jnp_fun, jnp_args_maker) + self._CheckAgainstNumpy(np_fun, clo_fun, clo_fun_args_maker) + else: + # Now repeats is in a closure, so a constant. + jnp_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testRepeatScalarFastPath(self): + a = jnp.array([1,2,3,4]) + f = lambda a: jnp.repeat(a, repeats=2) + jaxpr = jax.make_jaxpr(f)(a) + self.assertLessEqual(len(jaxpr.jaxpr.eqns), 6) + + @unittest.skip("jax-metal fail to convert sort op.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in [None] + list(range(len(shape)))], + dtype=number_dtypes, + return_index=[False, True], + return_inverse=[False, True], + return_counts=[False, True], + ) + def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + extra_args = (return_index, return_inverse, return_counts) + use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False + np_fun = jtu.with_jax_dtype_defaults(lambda x: np_unique_backport(x, *extra_args, axis=axis), use_defaults) + jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueAll(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if jtu.numpy_version() < (2, 0, 0): + np_fun = partial(np_unique_backport, return_index=True, return_inverse=True, return_counts=True) + else: + np_fun = np.unique_all + self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueCounts(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if jtu.numpy_version() < (2, 0, 0): + np_fun = lambda x: np.unique(x, return_counts=True) + else: + np_fun = np.unique_counts + self._CheckAgainstNumpy(jnp.unique_counts, np_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueInverse(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if jtu.numpy_version() < (2, 0, 0): + np_fun = partial(np_unique_backport, return_inverse=True) + else: + np_fun = np.unique_inverse + self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueValues(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if jtu.numpy_version() < (2, 0, 0): + np_fun = np.unique + else: + np_fun = np.unique_values + self._CheckAgainstNumpy(jnp.unique_values, np_fun, args_maker) + + @unittest.skip("jax-metal fail.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_array_shapes + for axis in [None] + list(range(len(shape)))], + dtype=number_dtypes, + size=[1, 5, 10], + 
fill_value=[None, 0, "slice"], + ) + def testUniqueSize(self, shape, dtype, axis, size, fill_value): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + kwds = dict(axis=axis, return_index=True, return_inverse=True, return_counts=True) + + if fill_value == "slice": + if axis is None: + fill_value = rng((), dtype) + else: + fill_value = rng(shape[:axis] + shape[axis + 1:], dtype) + elif fill_value is not None: + fill_value = np.array(fill_value).astype(dtype) + + @partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True)) + def np_fun(x, fill_value=fill_value): + u, ind, inv, counts = np_unique_backport(x, **kwds) + axis = kwds['axis'] + if axis is None: + x = x.ravel() + axis = 0 + + n_unique = u.shape[axis] + if size <= u.shape[axis]: + slc = (slice(None),) * axis + (slice(size),) + u, ind, counts = u[slc], ind[:size], counts[:size] + else: + extra = (0, size - n_unique) + pads = [(0, 0)] * u.ndim + pads[axis] = extra + u = np.pad(u, pads, constant_values=0) + slices = [slice(None)] * u.ndim + slices[axis] = slice(1) + if fill_value is None: + fill_value = u[tuple(slices)] + elif np.ndim(fill_value): + fill_value = lax.expand_dims(fill_value, (axis,)) + slices[axis] = slice(n_unique, None) + u[tuple(slices)] = fill_value + ind = np.pad(ind, extra, constant_values=ind[0]) + counts = np.pad(counts, extra, constant_values=0) + return u, ind, inv, counts + + jnp_fun = lambda x: jnp.unique(x, size=size, fill_value=fill_value, **kwds) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skip("jax-metal fail.") + @jtu.sample_product(dtype=inexact_dtypes) + def testUniqueNans(self, dtype): + def args_maker(): + x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan] + if np.issubdtype(dtype, np.complexfloating): + x = [complex(i, j) for i, j in itertools.product(x, repeat=2)] + return [np.array(x, dtype=dtype)] + + kwds = dict(return_index=True, return_inverse=True, return_counts=True) + jnp_fun = partial(jnp.unique, **kwds) + def np_fun(x): + dtype = x.dtype + # numpy unique fails for bfloat16 NaNs, so we cast to float64 + if x.dtype == jnp.bfloat16: + x = x.astype('float64') + u, *rest = np.unique(x, **kwds) + return (u.astype(dtype), *rest) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @unittest.skip("jax-metal fail.") + @jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False]) + def testUniqueEqualNan(self, dtype, equal_nan): + shape = (20,) + rng = jtu.rand_some_nan(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + def np_fun(x): + dtype = x.dtype + # numpy unique fails for bfloat16 NaNs, so we cast to float64 + if x.dtype == jnp.bfloat16: + x = x.astype('float64') + return np.unique(x, equal_nan=equal_nan).astype(dtype) + jnp_fun = partial(jnp.unique, equal_nan=equal_nan) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product(fixed_size=[False, True]) + def testNonScalarRepeats(self, fixed_size): + ''' + Following numpy test suite from `test_repeat` at + https://github.com/numpy/numpy/blob/main/numpy/core/tests/test_multiarray.py + ''' + tol = 1e-5 + + def test_single(m, args_maker, repeats, axis): + lax_ans = jnp.repeat(m, repeats, axis) + numpy_ans = np.repeat(m, repeats, axis) + + self.assertAllClose(lax_ans, numpy_ans, rtol=tol, atol=tol) + if fixed_size: + + # Calculate expected size of the repeated axis. 
+        rep_length = np.repeat(np.zeros_like(m), repeats, axis).shape[axis or 0]
+        jnp_fun = lambda arg, rep: jnp.repeat(
+            arg, repeats=rep, axis=axis, total_repeat_length=rep_length)
+      else:
+        jnp_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis)
+      self._CompileAndCheck(jnp_fun, args_maker)
+
+    m = jnp.array([1,2,3,4,5,6])
+    if fixed_size:
+      args_maker = lambda: [m, repeats]
+    else:
+      args_maker = lambda: [m]
+
+    for repeats in [2, jnp.array([1,3,0,1,1,2]), jnp.array([1,3,2,1,1,2]), jnp.array([2])]:
+      test_single(m, args_maker, repeats, axis=None)
+      test_single(m, args_maker, repeats, axis=0)
+
+    m_rect = m.reshape((2,3))
+    if fixed_size:
+      args_maker = lambda: [m_rect, repeats]
+    else:
+      args_maker = lambda: [m_rect]
+
+    for repeats in [2, jnp.array([2,1]), jnp.array([2])]:
+      test_single(m_rect, args_maker, repeats, axis=0)
+
+    for repeats in [2, jnp.array([1,3,2]), jnp.array([2])]:
+      test_single(m_rect, args_maker, repeats, axis=1)
+
+  def testIssue2330(self):
+    '''
+    Make sure the return value of jnp.concatenate is a jax.ndarray and is side-effect safe.
+    '''
+    def attempt_sideeffect(x):
+      x = [x]
+      x = jnp.concatenate(x)
+      x -= 1.
+      return x
+
+    np_input = np.ones(1)
+    jnp_input = jnp.ones(1)
+    expected_np_input_after_call = np.ones(1)
+    expected_jnp_input_after_call = jnp.ones(1)
+
+    out = jnp.concatenate([np_input])
+    self.assertIs(type(out), array.ArrayImpl)
+
+    attempt_sideeffect(np_input)
+    attempt_sideeffect(jnp_input)
+
+    self.assertAllClose(np_input, expected_np_input_after_call)
+    self.assertAllClose(jnp_input, expected_jnp_input_after_call)
+
+  @jtu.sample_product(
+    mode=['full', 'same', 'valid'],
+    op=['convolve', 'correlate'],
+    dtype=float_dtypes,  # number_dtypes
+    xshape=one_dim_array_shapes,
+    yshape=one_dim_array_shapes,
+  )
+  def testConvolutions(self, xshape, yshape, dtype, mode, op):
+    jnp_op = getattr(jnp, op)
+    np_op = getattr(np, op)
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
+    precision = lax.Precision.HIGHEST if jtu.test_device_matches(["tpu"]) else None
+    jnp_fun = partial(jnp_op, mode=mode, precision=precision)
+    def np_fun(x, y):
+      return np_op(x, y, mode=mode).astype(dtypes.to_inexact_dtype(dtype))
+    tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14,
+           np.complex128: 1e-14}
+    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol)
+    self._CompileAndCheck(jnp_fun, args_maker)
+
+  @jtu.sample_product(
+    mode=['full', 'same', 'valid'],
+    op=['convolve', 'correlate'],
+    dtype=float_dtypes,  # number_dtypes
+    xshape=one_dim_array_shapes,
+    yshape=one_dim_array_shapes,
+  )
+  @jtu.skip_on_devices("cuda", "rocm")  # backends don't support all dtypes.
+ def testConvolutionsPreferredElementType(self, xshape, yshape, dtype, mode, op): + jnp_op = getattr(jnp, op) + np_op = getattr(np, op) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] + precision = lax.Precision.HIGHEST if jtu.test_device_matches(["tpu"]) else None + jnp_fun = partial(jnp_op, mode=mode, precision=precision, + preferred_element_type=dtype) + def np_fun(x, y): + return np_op(x, y, mode=mode).astype(dtype) + tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14, + np.complex128: 1e-14} + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in [None] + list(range(-len(shape), len(shape)))], + op=["cumsum", "cumprod"], + dtype=all_dtypes, + out_dtype=[dtype for dtype in default_dtypes if dtype != np.float16], + ) + def testCumSumProd(self, axis, shape, dtype, out_dtype, op): + jnp_op = getattr(jnp, op) + np_op = getattr(np, op) + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype) + np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=RuntimeWarning, + message="overflow encountered.*")(np_fun) + jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype) + jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + + args_maker = lambda: [rng(shape, dtype)] + + tol_thresholds = {dtypes.bfloat16: 4e-2} + tol = max(jtu.tolerance(dtype, tol_thresholds), + jtu.tolerance(out_dtype, tol_thresholds)) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in [None] + list(range(-len(shape), len(shape)))], + op=["nancumsum", "nancumprod"], + dtype=all_dtypes, + out_dtype=default_dtypes, + ) + def testNanCumSumProd(self, axis, shape, dtype, out_dtype, op): + jnp_op = getattr(jnp, op) + np_op = getattr(np, op) + rng = jtu.rand_some_nan(self.rng()) + np_fun = partial(np_op, axis=axis, dtype=out_dtype) + np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=RuntimeWarning, + message="overflow encountered.*")(np_fun) + jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype) + jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + + args_maker = lambda: [rng(shape, dtype)] + + tol_thresholds = {dtypes.bfloat16: 4e-2, np.float16: 3e-3} + tol = max(jtu.tolerance(dtype, tol_thresholds), + jtu.tolerance(out_dtype, tol_thresholds)) + if dtype != jnp.bfloat16: + # numpy functions do not properly handle bfloat16 + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @unittest.skip("Jax-metal fail on testEye2") + @jtu.sample_product( + dtype=default_dtypes, + n=[0, 4], + m=[None, 0, 1, 3, 4], + k=[*range(-4, 4), -2**100, 2**100], + ) + def testEye(self, n, m, k, dtype): + np_fun = lambda: np.eye(n, M=m, k=k, dtype=dtype) + jnp_fun = lambda: jnp.eye(n, M=m, k=k, dtype=dtype) + args_maker = lambda: [] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + n=[0, 4], + m=[None, 0, 1, 3, 4], + k=range(-4, 4), + ) + def testTri(self, m, n, k, dtype): + np_fun = lambda: 
np.tri(n, M=m, k=k, dtype=dtype) + jnp_fun = lambda: jnp.tri(n, M=m, k=k, dtype=dtype) + args_maker = lambda: [] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + shape=[shape for shape in all_shapes if len(shape) >= 2], + op=["tril", "triu"], + k=list(range(-3, 3)), + ) + def testTriLU(self, dtype, shape, op, k): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: getattr(np, op)(arg, k=k) + jnp_fun = lambda arg: getattr(jnp, op)(arg, k=k) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + n=range(5), + k=range(-3, 3), + m=[None, *range(5)], + ) + def testTrilIndices(self, n, k, m): + np_fun = lambda n, k, m: np.tril_indices(n, k=k, m=m) + jnp_fun = lambda n, k, m: jnp.tril_indices(n, k=k, m=m) + args_maker = lambda: [n, k, m] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + n=range(5), + k=range(-3, 3), + m=[None, *range(5)], + ) + def testTriuIndices(self, n, k, m): + np_fun = lambda n, k, m: np.triu_indices(n, k=k, m=m) + jnp_fun = lambda n, k, m: jnp.triu_indices(n, k=k, m=m) + args_maker = lambda: [n, k, m] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + shape=[(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)], + k=[-1, 0, 1], + ) + def testTriuIndicesFrom(self, shape, dtype, k): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arr, k: np.triu_indices_from(arr, k=k) + jnp_fun = lambda arr, k: jnp.triu_indices_from(arr, k=k) + args_maker = lambda: [rng(shape, dtype), k] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + shape=[(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)], + k=[-1, 0, 1], + ) + def testTrilIndicesFrom(self, shape, dtype, k): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arr, k: np.tril_indices_from(arr, k=k) + jnp_fun = lambda arr, k: jnp.tril_indices_from(arr, k=k) + args_maker = lambda: [rng(shape, dtype), k] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + a_shape=[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (1, 2), (0, 2), (2, 3), (2, 2, 2), (2, 2, 2, 2)], + val_shape=[(), (1,), (2,), (1, 2), (3, 2)], + ) + def testFillDiagonal(self, dtype, a_shape, val_shape): + rng = jtu.rand_default(self.rng()) + + def np_fun(a, val): + a_copy = a.copy() + np.fill_diagonal(a_copy, val) + return a_copy + + jnp_fun = partial(jnp.fill_diagonal, inplace=False) + args_maker = lambda : [rng(a_shape, dtype), rng(val_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + ndim=[0, 1, 4], + n=[0, 1, 7], + ) + def testDiagIndices(self, ndim, n): + np.testing.assert_equal(jtu.with_jax_dtype_defaults(np.diag_indices)(n, ndim), + jnp.diag_indices(n, ndim)) + + @jtu.sample_product( + dtype=default_dtypes, + shape=[(1,1), (2,2), (3,3), (4,4), (5,5)], + ) + def testDiagIndicesFrom(self, dtype, shape): + rng = jtu.rand_default(self.rng()) + np_fun = jtu.with_jax_dtype_defaults(np.diag_indices_from) + jnp_fun = jnp.diag_indices_from + args_maker = lambda : [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + shape=[shape for shape in all_shapes 
if len(shape) in (1, 2)], + k=list(range(-4, 4)), + ) + def testDiag(self, shape, dtype, k): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np.diag(arg, k) + jnp_fun = lambda arg: jnp.diag(arg, k) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + shape=all_shapes, + k=list(range(-4, 4)), + ) + def testDiagFlat(self, shape, dtype, k): + rng = jtu.rand_default(self.rng()) + # numpy has inconsistencies for scalar values + # https://github.com/numpy/numpy/issues/16477 + # jax differs in that it treats scalars values as length-1 arrays + np_fun = lambda arg: np.diagflat(np.atleast_1d(arg), k) + jnp_fun = lambda arg: jnp.diagflat(arg, k) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @unittest.skip("jax-metal fail.") + @jtu.sample_product( + dtype=default_dtypes, + a1_shape=one_dim_array_shapes, + a2_shape=one_dim_array_shapes, + ) + def testPolyMul(self, a1_shape, a2_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1, arg2: np.polymul(arg1, arg2) + jnp_fun_np = lambda arg1, arg2: jnp.polymul(arg1, arg2, trim_leading_zeros=True) + jnp_fun_co = lambda arg1, arg2: jnp.polymul(arg1, arg2) + args_maker = lambda: [rng(a1_shape, dtype), rng(a2_shape, dtype)] + tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13} + self._CheckAgainstNumpy(np_fun, jnp_fun_np, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp_fun_co, args_maker, check_dtypes=False) + + @unittest.skip("jax-metal fail.") + @jtu.sample_product( + dtype=[dtype for dtype in default_dtypes + if dtype not in (np.float16, jnp.bfloat16)], + a_shape=one_dim_array_shapes, + b_shape=one_dim_array_shapes, + ) + def testPolyDiv(self, a_shape, b_shape, dtype): + rng = jtu.rand_default(self.rng()) + + @jtu.ignore_warning(category=RuntimeWarning, message="divide by zero.*") + @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") + @jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*") + def np_fun(arg1, arg2): + q, r = np.polydiv(arg1, arg2) + while r.size < max(arg1.size, arg2.size): # Pad residual to same size + r = np.pad(r, (1, 0), 'constant') + return q, r + + def jnp_fun(arg1, arg2): + q, r = jnp.polydiv(arg1, arg2, trim_leading_zeros=True) + while r.size < max(arg1.size, arg2.size): # Pad residual to same size + r = jnp.pad(r, (1, 0), 'constant') + return q, r + + args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] + tol = { + dtypes.bfloat16: 2e-1, + np.float16: 2e-1, + np.float32: 5e-2, + np.float64: 5e-7 + } + + jnp_compile = jnp.polydiv # Without trim_leading_zeros (trim_zeros make it unable to be compiled by XLA) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp_compile, args_maker, check_dtypes=True, atol=tol, rtol=tol) + + @jtu.sample_product( + [dict(shape=shape, axis1=axis1, axis2=axis2) + for shape in [shape for shape in all_shapes if len(shape) >= 2] + for axis1 in range(-len(shape), len(shape)) + for axis2 in [a for a in range(-len(shape), len(shape)) + if a % len(shape) != axis1 % len(shape)] + ], + dtype=default_dtypes, + offset=list(range(-4, 4)), + ) + def testDiagonal(self, shape, dtype, offset, axis1, axis2): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: 
np.diagonal(arg, offset, axis1, axis2) + jnp_fun = lambda arg: jnp.diagonal(arg, offset, axis1, axis2) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + n=list(range(4)), + ) + def testIdentity(self, n, dtype): + np_fun = lambda: np.identity(n, dtype) + jnp_fun = lambda: jnp.identity(n, dtype) + args_maker = lambda: [] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skip("jax-metal crash.") + @jtu.sample_product( + shape=nonempty_shapes, + period=[None, 0.59], + left=[None, 0], + right=[None, 1], + # Note: skip 8-bit and 16-bit types due to insufficient precision. + dtype=jtu.dtypes.integer + jtu.dtypes.floating, + target_dtype=jtu.dtypes.inexact, + ) + def testInterp(self, shape, dtype, period, left, right, target_dtype): + rng = jtu.rand_default(self.rng(), scale=10) + kwds = dict(period=period, left=left, right=right) + np_fun = partial(np.interp, **kwds) + jnp_fun = partial(jnp.interp, **kwds) + + args_maker = lambda: [rng(shape, dtype), np.unique(rng((100,), dtype))[:20], + rng((20,), target_dtype)] + + with jtu.strict_promotion_if_dtypes_match([dtype, target_dtype]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + rtol=3e-3, atol=1e-3) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skip("jax-metal crash.") + @jtu.sample_product([ + dict(x=0.5, left='extrapolate', expected=5), + dict(x=1.5, left='extrapolate', expected=15), + dict(x=3.5, left='extrapolate', expected=30), + dict(x=3.9, right='extrapolate', expected=39), + ]) + def testInterpExtrapoate(self, x, expected, **kwargs): + xp = jnp.array([1.0, 2.0, 3.0]) + fp = jnp.array([10.0, 20.0, 30.0]) + actual = jnp.interp(x, xp, fp, **kwargs) + self.assertAlmostEqual(actual, expected) + + def testInterpErrors(self): + with self.assertRaisesWithLiteralMatch( + ValueError, + 'xp and fp must be one-dimensional arrays of equal size' + ): + jnp.interp(0.0, jnp.arange(2.0), jnp.arange(3.0)) + with self.assertRaisesWithLiteralMatch( + ValueError, + "the only valid string value of `left` is 'extrapolate', but got: 'interpolate'" + ): + jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), left='interpolate') + with self.assertRaisesWithLiteralMatch( + ValueError, + "the only valid string value of `right` is 'extrapolate', but got: 'interpolate'" + ): + jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), right='interpolate') + with self.assertRaisesWithLiteralMatch( + ValueError, + "jnp.interp: complex x values not supported." + ): + jnp.interp(1j, 1j * np.arange(3.0), np.arange(3.0)) + with self.assertRaisesRegex( + ValueError, + "period must be a scalar; got" + ): + jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), period=np.array([1.0])) + + @jtu.sample_product( + period=[None, 0.59], + left=[None, 0], + right=[None, 1], + dtype=jtu.dtypes.floating, + ) + def testInterpGradNan(self, dtype, period, left, right): + kwds = dict(period=period, left=left, right=right) + jnp_fun = partial(jnp.interp, **kwds) + # Probe values of x and xp that are close to zero and close together. 
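+    # With xp entries this close together, the interpolation weights divide by
+    # near-zero differences; the finiteness check below guards against NaN/Inf
+    # leaking into the gradient.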
+ x = dtype(np.exp(np.linspace(-90, -20, 1000))) + g = jax.grad(lambda z: jnp.sum(jnp_fun(z, z, jnp.ones_like(z))))(x) + np.testing.assert_equal(np.all(np.isfinite(g)), True) + + @jtu.sample_product( + [dict(x1_shape=x1_shape, x2_shape=x2_shape) + for x1_shape, x2_shape in filter(_shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(array_shapes, 2)) + ], + x1_rng_factory=[jtu.rand_some_inf_and_nan, jtu.rand_some_zero], + x2_rng_factory=[partial(jtu.rand_int, low=-1075, high=1024)], + x1_dtype=default_dtypes, + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testLdexp(self, x1_shape, x1_dtype, x2_shape, x1_rng_factory, x2_rng_factory): + x1_rng = x1_rng_factory(self.rng()) + x2_rng = x2_rng_factory(self.rng()) + + @jtu.ignore_warning(category=RuntimeWarning, message="overflow.*") + def np_fun(x1, x2): + out_dtype = dtypes.to_inexact_dtype(x1.dtype) + return np.ldexp(x1.astype(out_dtype), x2) + + jnp_fun = jnp.ldexp + args_maker = lambda: [x1_rng(x1_shape, x1_dtype), + x2_rng(x2_shape, np.int32)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + rng_factory=[ + jtu.rand_some_inf_and_nan, + jtu.rand_some_zero, + partial(jtu.rand_not_small, offset=1e8), + ], + shape=all_shapes, + dtype=default_dtypes, + ) + def testFrexp(self, shape, dtype, rng_factory): + # integer types are converted to float64 in numpy's implementation + if (dtype not in [jnp.bfloat16, np.float16, np.float32] + and not config.enable_x64.value): + self.skipTest("Only run float64 testcase when float64 is enabled.") + rng = rng_factory(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + def np_frexp(x): + mantissa, exponent = np.frexp(x) + # NumPy is inconsistent between Windows and Linux/Mac on what the + # value of exponent is if the input is infinite. Normalize to the Linux + # behavior. + exponent = np.where(np.isinf(mantissa), np.zeros_like(exponent), exponent) + return mantissa, exponent + self._CheckAgainstNumpy(np_frexp, jnp.frexp, args_maker, + check_dtypes=np.issubdtype(dtype, np.inexact)) + self._CompileAndCheck(jnp.frexp, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis1=axis1, axis2=axis2) + for shape in [shape for shape in all_shapes if len(shape) >= 2] + for axis1 in range(-len(shape), len(shape)) + for axis2 in range(-len(shape), len(shape)) + if (axis1 % len(shape)) != (axis2 % len(shape)) + ], + dtype=default_dtypes, + out_dtype=[None] + number_dtypes, + offset=list(range(-4, 4)), + ) + def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2): + rng = jtu.rand_default(self.rng()) + def np_fun(arg): + if out_dtype == jnp.bfloat16: + return np.trace(arg, offset, axis1, axis2, np.float32).astype(jnp.bfloat16) + else: + return np.trace(arg, offset, axis1, axis2, out_dtype) + jnp_fun = lambda arg: jnp.trace(arg, offset, axis1, axis2, out_dtype) + args_maker = lambda: [rng(shape, dtype)] + # TODO: Fails with uint8/uint16 output dtypes (integer overflow?) 
+ if out_dtype not in (np.uint8, np.uint16, np.uint32): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + #unittest.skip("jax-metal fail with empty vshape.") + @jtu.sample_product( + ashape=[(15,), (16,), (17,)], + vshape= [(5,), (5, 5)],#[(), (5,), (5, 5)], + side=['left', 'right'], + dtype= number_dtypes, + method=['sort', 'scan', 'scan_unrolled', 'compare_all'], + ) + def testSearchsorted(self, ashape, vshape, side, dtype, method): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [np.sort(rng(ashape, dtype)), rng(vshape, dtype)] + def np_fun(a, v): + return np.searchsorted(a, v, side=side).astype('int32') + jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side, method=method) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skipIf( + platform.system() == "Windows", + "Under Windows, NumPy throws if 2**32 is converted to an int32" + ) + def testSearchsortedDtype(self): + # Test that for large arrays, int64 indices are used. We test this + # via abstract evaluation to avoid allocating a large array in tests. + a_int32 = core.ShapedArray((np.iinfo(np.int32).max,), np.float32) + a_int64 = core.ShapedArray((np.iinfo(np.int32).max + 1,), np.float32) + v = core.ShapedArray((), np.float32) + + out_int32 = jax.eval_shape(jnp.searchsorted, a_int32, v) + self.assertEqual(out_int32.dtype, np.int32) + + if config.enable_x64.value: + out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) + self.assertEqual(out_int64.dtype, np.int64) + elif jtu.numpy_version() < (2, 0, 0): + with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"): + with jtu.ignore_warning(category=DeprecationWarning, + message="NumPy will stop allowing conversion.*"): + out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) + else: + with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"): + with self.assertRaisesRegex(OverflowError, "Python integer 2147483648 out of bounds.*"): + out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) + + @unittest.skip("Jax-metal fail.") + @jtu.sample_product( + dtype=inexact_dtypes, + side=['left', 'right'], + method=['sort', 'scan', 'compare_all'], + ) + def testSearchsortedNans(self, dtype, side, method): + if np.issubdtype(dtype, np.complexfloating): + raise SkipTest("Known failure for complex inputs; see #9107") + x = np.array([-np.inf, -1.0, 0.0, -0.0, 1.0, np.inf, np.nan, -np.nan], dtype=dtype) + # The sign bit should not matter for 0.0 or NaN, so argsorting the above should be + # equivalent to argsorting the following: + x_equiv = np.array([0, 1, 2, 2, 3, 4, 5, 5]) + + if jnp.issubdtype(dtype, jnp.complexfloating): + x = np.array([complex(r, c) for r, c in itertools.product(x, repeat=2)]) + x_equiv = np.array([complex(r, c) for r, c in itertools.product(x_equiv, repeat=2)]) + + fun = partial(jnp.searchsorted, side=side, method=method) + self.assertArraysEqual(fun(x, x), fun(x_equiv, x_equiv)) + self.assertArraysEqual(jax.jit(fun)(x, x), fun(x_equiv, x_equiv)) + + @jtu.sample_product( + xshape=[(20,), (5, 4)], + binshape=[(1,), (5,)], + right=[True, False], + reverse=[True, False], + dtype=default_dtypes, + ) + def testDigitize(self, xshape, binshape, right, reverse, dtype): + order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:] + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]] + np_fun = lambda x, bins: np.digitize(x, bins, 
right=right).astype('int32') + jnp_fun = lambda x, bins: jnp.digitize(x, bins, right=right) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtypes=[ + [np.float32], + [np.float32, np.float32], + [np.float32, np.int32, np.float32], + [np.float32, np.int64, np.float32], + [np.float32, np.int32, np.float64], + ], + shape=[(), (2,), (3, 4), (1, 5)], + array_input=[True, False], + ) + def testColumnStack(self, shape, dtypes, array_input): + rng = jtu.rand_default(self.rng()) + if array_input: + args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] + else: + args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] + np_fun = jtu.promote_like_jnp(np.column_stack) + jnp_fun = jnp.column_stack + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in [(), (2,), (3, 4), (1, 100)] + for axis in range(-len(shape), len(shape) + 1) + ], + dtypes=[ + [np.float32], + [np.float32, np.float32], + [np.float32, np.int32, np.float32], + [np.float32, np.int64, np.float32], + [np.float32, np.int32, np.float64], + ], + array_input=[True, False], + out_dtype=[np.float32, np.int32], + ) + def testStack(self, shape, axis, dtypes, array_input, out_dtype): + rng = jtu.rand_default(self.rng()) + if array_input: + args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] + else: + args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] + + np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) + + jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + op=["hstack", "vstack", "dstack"], + dtypes=[ + [np.float32], + [np.float32, np.float32], + [np.float32, np.int32, np.float32], + [np.float32, np.int64, np.float32], + [np.float32, np.int32, np.float64], + ], + shape=[(), (2,), (3, 4), (1, 100), (2, 3, 4)], + array_input=[True, False], + out_dtype=[np.float32, np.int32], + ) + def testHVDStack(self, shape, op, dtypes, array_input, out_dtype): + rng = jtu.rand_default(self.rng()) + if array_input: + args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] + else: + args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] + + if op == "dstack": + np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype)) + else: + np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype, + casting='unsafe') + + jnp_fun = partial(getattr(jnp, op), dtype=out_dtype) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(name=name, **kwds) + for name in ['blackman', 'bartlett', 'hamming', 'hanning', 'kaiser'] + for kwds in ([dict(beta=1), dict(beta=0.5)] if name == 'kaiser' else [{}]) + ], + size = [0, 1, 5, 10], + ) + def testWindowFunction(self, name, size, **kwds): + jnp_fun = partial(getattr(jnp, name), size, **kwds) + np_fun = jtu.with_jax_dtype_defaults(partial(getattr(np, name), size, **kwds)) + args_maker = lambda: [] + tol = ( + 5e-6 if jtu.test_device_matches(['tpu']) and name == 'kaiser' else None + ) + self._CheckAgainstNumpy(jnp_fun, 
np_fun, args_maker, atol=tol, rtol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, fill_value_shape=fill_value_shape) + for shape in array_shapes + [3, np.array(7, dtype=np.int32)] + for fill_value_shape in _compatible_shapes(shape)], + fill_value_dtype=default_dtypes, + out_dtype=[None] + default_dtypes, + ) + def testFull(self, shape, fill_value_dtype, fill_value_shape, out_dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda fill_value: np.full(shape, fill_value, dtype=out_dtype) + jnp_fun = lambda fill_value: jnp.full(shape, fill_value, dtype=out_dtype) + args_maker = lambda: [rng(fill_value_shape, fill_value_dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, dtype=dtype, axis=axis) + for shape, dtype in _shape_and_dtypes(nonempty_nonscalar_array_shapes, default_dtypes) + for axis in list(range(-len(shape), max(1, len(shape)))) + ], + prepend=[None, 1, 0], + append=[None, 1, 0], + n=[0, 1, 2], + ) + def testDiff(self, shape, dtype, n, axis, prepend, append): + prepend = np.zeros(shape, dtype=dtype) if prepend == 0 else prepend + append = np.zeros(shape, dtype=dtype) if append == 0 else append + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + def np_fun(x, n=n, axis=axis, prepend=prepend, append=append): + if prepend is None: + prepend = np._NoValue + elif not np.isscalar(prepend) and prepend.dtype == jnp.bfloat16: + prepend = prepend.astype(np.float32) + + if append is None: + append = np._NoValue + elif not np.isscalar(append) and append.dtype == jnp.bfloat16: + append = append.astype(np.float32) + + if x.dtype == jnp.bfloat16: + return np.diff(x.astype(np.float32), n=n, axis=axis, prepend=prepend, append=append).astype(jnp.bfloat16) + else: + return np.diff(x, n=n, axis=axis, prepend=prepend, append=append) + + jnp_fun = lambda x: jnp.diff(x, n=n, axis=axis, prepend=prepend, append=append) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + def testDiffPrepoendScalar(self): + # Regression test for https://github.com/google/jax/issues/19362 + x = jnp.arange(10) + result_jax = jnp.diff(x, prepend=x[0], append=x[-1]) + + x = np.array(x) + result_numpy = np.diff(x, prepend=x[0], append=x[-1]) + + self.assertArraysEqual(result_jax, result_numpy) + + @jtu.sample_product( + op=["zeros", "ones"], + shape=[2, (), (2,), (3, 0), np.array((4, 5, 6), dtype=np.int32), + np.array(4, dtype=np.int32)], + dtype=all_dtypes, + ) + def testZerosOnes(self, op, shape, dtype): + np_op = getattr(np, op) + jnp_op = getattr(jnp, op) + args_maker = lambda: [] + np_op = partial(np_op, shape, dtype) + jnp_op = partial(jnp_op, shape, dtype) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + def testOnesWithInvalidShape(self): + with self.assertRaises(TypeError): + jnp.ones((-1, 1)) + + def test_full_like_commited(self): + x = jnp.array((1, 2, 3), dtype=np.int32) + self.assertFalse(x._committed) + self.assertFalse(lax.full_like(x, 1.1)._committed) + x = jax.device_put(x, jax.devices()[-1]) + self.assertTrue(x._committed) + y = lax.full_like(x, 1.1) + self.assertTrue(y._committed) + self.assertEqual(x.sharding, y.sharding) + + def test_zeros_like_with_explicit_device_and_jitted(self): + x = jnp.array((1, 2, 3), dtype=np.int32) + x = jax.device_put(x, jax.devices()[0]) + zeros_like_with_device = 
partial(jnp.zeros_like, device=jax.devices()[0]) + y = jax.jit(zeros_like_with_device)(x) + self.assertEqual(x.shape, y.shape) + self.assertEqual(y.sharding, SingleDeviceSharding(jax.devices()[0])) + + @jtu.sample_product( + [dict(shape=shape, out_shape=out_shape, fill_value_shape=fill_value_shape) + for shape in array_shapes + for out_shape in [None] + array_shapes + for fill_value_shape in _compatible_shapes(shape if out_shape is None else out_shape) + ], + in_dtype=default_dtypes, + fill_value_dtype=default_dtypes, + out_dtype=default_dtypes, + ) + def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_dtype, out_shape): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x, fill_value: np.full_like( + x, fill_value, dtype=out_dtype, shape=out_shape) + jnp_fun = lambda x, fill_value: jnp.full_like( + x, fill_value, dtype=out_dtype, shape=out_shape) + args_maker = lambda: [rng(shape, in_dtype), rng(fill_value_shape, fill_value_dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=array_shapes, + out_shape=[None] + array_shapes, + in_dtype=default_dtypes, + func=["ones_like", "zeros_like"], + out_dtype=default_dtypes, + ) + def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: getattr(np, func)(x, dtype=out_dtype, shape=out_shape) + jnp_fun = lambda x: getattr(jnp, func)(x, dtype=out_dtype, shape=out_shape) + args_maker = lambda: [rng(shape, in_dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full], + shape=array_shapes, + dtype=default_dtypes, + ) + def testArrayCreationWithDevice(self, func, shape, dtype): + device = jax.devices()[-1] + kwds = {'fill_value': 1} if func is jnp.full else {} + out = func(**kwds, shape=shape, dtype=dtype, device=device) + self.assertEqual(out.devices(), {device}) + + @jtu.sample_product( + func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full], + shape=array_shapes, + dtype=default_dtypes, + ) + def testArrayCreationWithSharding(self, func, shape, dtype): + sharding = SingleDeviceSharding(jax.devices()[-1]) + kwds = {'fill_value': 1} if func is jnp.full else {} + out = func(**kwds, shape=shape, dtype=dtype, device=sharding) + self.assertEqual(out.sharding, sharding) + + @jtu.sample_product( + func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], + shape=array_shapes, + dtype=default_dtypes, + ) + def testFullLikeWithDevice(self, func, shape, dtype): + device = jax.devices()[-1] + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + kwds = {'fill_value': 1} if func is jnp.full_like else {} + + with self.subTest('device from keyword'): + out = func(x, **kwds, device=device) + self.assertEqual(out.devices(), {device}) + + with self.subTest('device from input array'): + out2 = func(out, **kwds) + self.assertEqual(out2.devices(), out.devices()) + + @jtu.sample_product( + func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], + shape=array_shapes, + dtype=default_dtypes, + ) + def testFullLikeWithSharding(self, func, shape, dtype): + sharding = SingleDeviceSharding(jax.devices()[-1]) + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + kwds = {'fill_value': 1} if func is jnp.full_like else {} + + with self.subTest('device from keyword'): + out = func(x, **kwds, device=sharding) + 
self.assertEqual(out.sharding, sharding) + + with self.subTest('device from input array'): + out2 = func(out, **kwds) + self.assertEqual(out2.devices(), out.devices()) + + def testDuckTypedLike(self): + x = jax.ShapeDtypeStruct((1, 2, 3), np.dtype("int32")) + self.assertArraysEqual(jnp.zeros_like(x), jnp.zeros(x.shape, x.dtype)) + self.assertArraysEqual(jnp.ones_like(x), jnp.ones(x.shape, x.dtype)) + self.assertArraysEqual(jnp.empty_like(x), jnp.empty(x.shape, x.dtype)) + self.assertArraysEqual(jnp.full_like(x, 2), jnp.full(x.shape, 2, x.dtype)) + + @jtu.sample_product( + [dict(func=func, args=args) + for func, args in [("full_like", (-100,)), ("ones_like", ()), ("zeros_like", ())] + ], + shape=array_shapes, + #in_dtype=[np.int32, np.float32, np.complex64], + in_dtype=[np.int32, np.float32], + weak_type=[True, False], + out_shape=[None, (), (10,)], + out_dtype=[None, float], + ) + def testZerosOnesFullLikeWeakType(self, func, args, shape, in_dtype, weak_type, out_shape, out_dtype): + rng = jtu.rand_default(self.rng()) + x = lax_internal._convert_element_type(rng(shape, in_dtype), + weak_type=weak_type) + fun = lambda x: getattr(jnp, func)(x, *args, dtype=out_dtype, shape=out_shape) + expected_weak_type = weak_type and (out_dtype is None) + self.assertEqual(dtypes.is_weakly_typed(fun(x)), expected_weak_type) + self.assertEqual(dtypes.is_weakly_typed(jax.jit(fun)(x)), expected_weak_type) + + @jtu.sample_product( + funcname=["array", "asarray"], + dtype=[int, float, None], + val=[0, 1], + input_type=[int, float, np.int32, np.float32], + ) + def testArrayWeakType(self, funcname, input_type, val, dtype): + func = lambda x: getattr(jnp, funcname)(x, dtype=dtype) + fjit = jax.jit(func) + val = input_type(val) + expected_weak_type = dtype is None and input_type in set(dtypes._weak_types) + self.assertEqual(dtypes.is_weakly_typed(func(val)), expected_weak_type) + self.assertEqual(dtypes.is_weakly_typed(fjit(val)), expected_weak_type) + + @jtu.sample_product( + shape=nonempty_nonscalar_array_shapes, + #dtype=[int, float, complex], + dtype=[int, float], + weak_type=[True, False], + slc=[slice(None), slice(0), slice(3), 0, ...], + ) + def testSliceWeakTypes(self, shape, dtype, weak_type, slc): + rng = jtu.rand_default(self.rng()) + x = lax_internal._convert_element_type(rng(shape, dtype), + weak_type=weak_type) + op = lambda x: x[slc] + self.assertEqual(op(x).aval.weak_type, weak_type) + self.assertEqual(jax.jit(op)(x).aval.weak_type, weak_type) + + @jtu.sample_product( + [dict(shape=shape, axis=axis, num_sections=num_sections) + for shape, axis, num_sections in [ + ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2), + ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)] + ], + dtype=default_dtypes, + ) + def testSplitStaticInt(self, shape, num_sections, axis, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.split(x, num_sections, axis=axis) + jnp_fun = lambda x: jnp.split(x, num_sections, axis=axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis, num_sections=num_sections) + # All testcases split the specified axis unequally + for shape, axis, num_sections in [ + ((3,), 0, 2), ((12,), 0, 5), ((12, 4), 0, 7), ((12, 4), 1, 3), + ((2, 3, 5), -1, 2), ((2, 4, 4), -2, 3), ((7, 2, 2), 0, 3)] + ], + dtype=default_dtypes, + ) + def testArraySplitStaticInt(self, shape, num_sections, axis, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = 
lambda x: np.array_split(x, num_sections, axis=axis) + jnp_fun = lambda x: jnp.array_split(x, num_sections, axis=axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testSplitTypeError(self): + # If we pass an ndarray for indices_or_sections -> no error + self.assertEqual(3, len(jnp.split(jnp.zeros(3), jnp.array([1, 2])))) + + CONCRETIZATION_MSG = "Abstract tracer value encountered where concrete value is expected." + with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG): + # An abstract tracer for idx + jax.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), idx))(2.) + with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG): + # A list including an abstract tracer + jax.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), [2, idx]))(2.) + + # A concrete tracer -> no error + jax.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), idx), + (2.,), (1.,)) + # A tuple including a concrete tracer -> no error + jax.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), (1, idx.astype(np.int32))), + (2.,), (1.,)) + + @jtu.sample_product( + shape=[(5,), (5, 5)], + dtype=number_dtypes, + bins=[10, np.arange(-5, 6), np.array([-5, 0, 3])], + range=[None, (0, 0), (0, 10)], + weights=[True, False], + ) + def testHistogramBinEdges(self, shape, dtype, bins, range, weights): + rng = jtu.rand_default(self.rng()) + _weights = lambda w: abs(w) if weights else None + np_fun = lambda a, w, r: np.histogram_bin_edges(a, bins=bins, range=r, + weights=_weights(w)) + jnp_fun = lambda a, w, r: jnp.histogram_bin_edges(a, bins=bins, range=r, + weights=_weights(w)) + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), range] + tol = {jnp.bfloat16: 2E-2, np.float16: 1E-2} + # linspace() compares poorly to numpy when using bfloat16 + if dtype != jnp.bfloat16: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, + atol=tol, rtol=tol) + + @jtu.sample_product( + shape=[(5,), (4, 5)], + dtype=default_dtypes, + # We only test explicit integer-valued bin edges because in other cases + # rounding errors lead to flaky tests. 
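# A small standalone sketch of the comparison set up below: with explicit
# integer-valued bin edges (the only case exercised here, to avoid rounding
# flakiness) jnp.histogram and np.histogram should bin identically.
import jax.numpy as jnp
import numpy as np

data = np.array([-4.5, -1.0, 0.0, 0.5, 2.5, 2.5, 4.9])
edges = np.array([-5, 0, 3, 5])                # explicit integer bin edges
np_counts, _ = np.histogram(data, bins=edges)
jnp_counts, _ = jnp.histogram(data, bins=edges)
print(np_counts, jnp_counts)                   # both count [2 4 1]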
+ bins=[np.arange(-5, 6), np.array([-5, 0, 3])], + density=[True, False], + weights=[True, False], + ) + def testHistogram(self, shape, dtype, bins, density, weights): + rng = jtu.rand_default(self.rng()) + _weights = lambda w: abs(w) if weights else None + def np_fun(a, w): + # Numpy can't handle bfloat16 + a = a.astype('float32') if a.dtype == jnp.bfloat16 else a + w = w.astype('float32') if w.dtype == jnp.bfloat16 else w + return np.histogram(a, bins=bins, density=density, weights=_weights(w)) + jnp_fun = lambda a, w: jnp.histogram(a, bins=bins, density=density, + weights=_weights(w)) + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)] + tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=[(5,), (12,)], + dtype=int_dtypes, + bins=[2, [2, 2], [np.array([0, 1, 3, 5]), np.array([0, 2, 3, 4, 6])]], + weights=[False, True], + density=[False, True], + range=[None, [(-1, 1), None], [(-1, 1), (-2, 2)]], + ) + def testHistogram2d(self, shape, dtype, bins, weights, density, range): + rng = jtu.rand_default(self.rng()) + _weights = lambda w: abs(w) if weights else None + np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")( + lambda a, b, w: np.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range)) + jnp_fun = lambda a, b, w: jnp.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range) + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] + tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} + # np.searchsorted errors on bfloat16 with + # "TypeError: invalid type promotion with custom data type" + with np.errstate(divide='ignore', invalid='ignore'): + if dtype != jnp.bfloat16: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=[(5, 3), (10, 3)], + dtype=int_dtypes, + bins=[(2, 2, 2), [np.array([-5, 0, 4]), np.array([-4, -1, 2]), np.array([-6, -1, 4])]], + weights=[False, True], + density=[False, True], + range=[None, [(-1, 1), None, None], [(-1, 1), (-2, 2), (-3, 3)]], + ) + def testHistogramdd(self, shape, dtype, bins, weights, density, range): + rng = jtu.rand_default(self.rng()) + _weights = lambda w: abs(w) if weights else None + np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")( + lambda a, w: np.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range)) + jnp_fun = lambda a, w: jnp.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range) + args_maker = lambda: [rng(shape, dtype), rng((shape[0],), dtype)] + tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} + # np.searchsorted errors on bfloat16 with + # "TypeError: invalid type promotion with custom data type" + if dtype != jnp.bfloat16: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis, num_sections=num_sections) + for shape, axis, num_sections in [ + ((12, 4), 0, 4), ((12,), 1, 2), + ((2, 3, 4), 2, 2), ((4, 3, 4), 0, 2)]], + dtype=default_dtypes, + ) + def testHVDSplit(self, shape, num_sections, axis, dtype): + rng = jtu.rand_default(self.rng()) + def fn(module, axis): + if axis == 0: + return module.vsplit + elif axis == 1: + return module.hsplit + else: + assert axis == 2 
+ return module.dsplit + + np_fun = lambda x: fn(np, axis)(x, num_sections) + jnp_fun = lambda x: fn(jnp, axis)(x, num_sections) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(arg_shape=arg_shape, out_shape=out_shape) + for arg_shape, out_shape in [ + (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)), + ((), (1, 1, 1)), + ((7, 0), (0, 42, 101)), + ((3, 4), 12), + ((3, 4), (12,)), + ((3, 4), -1), + ((2, 1, 4), (-1,)), + ((2, 2, 4), (2, 8)) + ] + ], + dtype=default_dtypes, + order=["C", "F"], + ) + def testReshape(self, arg_shape, out_shape, dtype, order): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.reshape(x, out_shape, order=order) + jnp_fun = lambda x: jnp.reshape(x, out_shape, order=order) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(arg_shape=arg_shape, out_shape=out_shape) + for arg_shape, out_shape in [ + ((7, 0), (0, 42, 101)), + ((2, 1, 4), (-1,)), + ((2, 2, 4), (2, 8)) + ] + ], + dtype=default_dtypes, + ) + def testReshapeMethod(self, arg_shape, out_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.reshape(x, out_shape) + jnp_fun = lambda x: x.reshape(*out_shape) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(arg_shape=arg_shape, out_shape=out_shape) + for arg_shape, out_shape in itertools.product(all_shapes, array_shapes)], + dtype=default_dtypes, + ) + def testResize(self, arg_shape, out_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.resize(x, out_shape) + jnp_fun = lambda x: jnp.resize(x, out_shape) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(arg_shape=arg_shape, dim=dim) + for arg_shape in [(), (3,), (3, 4)] + for dim in (list(range(-len(arg_shape)+1, len(arg_shape))) + + [np.array(0), np.array(-1), (0,), [np.array(0)], + (len(arg_shape), len(arg_shape) + 1)]) + ], + dtype=default_dtypes, + ) + def testExpandDimsStaticDim(self, arg_shape, dtype, dim): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.expand_dims(x, dim) + jnp_fun = lambda x: jnp.expand_dims(x, dim) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CompileAndCheck(jnp_fun, args_maker) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + def testExpandDimsRepeatedAxisError(self): + x = jnp.ones((2, 3)) + self.assertRaisesRegex( + ValueError, 'repeated axis.*', + lambda: jnp.expand_dims(x, [1, 1])) + self.assertRaisesRegex( + ValueError, 'repeated axis.*', + lambda: jnp.expand_dims(x, [3, -1])) + + # ensure this is numpy's behavior too, so that we remain consistent + x = np.ones((2, 3)) + self.assertRaisesRegex( + ValueError, 'repeated axis.*', + lambda: np.expand_dims(x, [1, 1])) + self.assertRaisesRegex( + ValueError, 'repeated axis.*', + lambda: np.expand_dims(x, [3, -1])) + + @jtu.sample_product( + [dict(arg_shape=arg_shape, ax1=ax1, ax2=ax2) + for arg_shape, ax1, ax2 in [ + ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2), + ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)] + ], + dtype=default_dtypes, + ) + def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2): + rng = jtu.rand_default(self.rng()) + np_fun = 
lambda x: np.swapaxes(x, ax1, ax2) + jnp_fun = lambda x: jnp.swapaxes(x, ax1, ax2) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(arg_shape=arg_shape, ax=ax) + for arg_shape, ax in [ + ((3, 1), None), + ((3, 1), 1), + ((3, 1), -1), + ((3, 1), np.array(1)), + ((1, 3, 1), (0, 2)), + ((1, 3, 1), (0,)), + ((1, 4, 1), (np.array(0),))] + ], + dtype=default_dtypes, + ) + def testSqueeze(self, arg_shape, dtype, ax): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.squeeze(x, ax) + jnp_fun = lambda x: jnp.squeeze(x, ax) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testArrayFromMasked(self): + args_maker = lambda: [np.ma.array([1, 2], mask=[True, False])] + # Like np.array, jnp.array strips the mask from masked array inputs. + self._CheckAgainstNumpy(np.array, jnp.array, args_maker) + # Under JIT, masked arrays are flagged as invalid. + with self.assertRaisesRegex(ValueError, "numpy masked arrays are not supported"): + jax.jit(jnp.asarray)(*args_maker()) + + @jtu.sample_product( + [dict(arg=arg, dtype=dtype, ndmin=ndmin) + for arg, dtypes in [ + ([True, False, True], all_dtypes), + (3., all_dtypes), + ([1, 2, 3], all_dtypes), + (np.array([1, 2, 3], dtype=np.int64), all_dtypes), + ([1., 2., 3.], all_dtypes), + ([[1, 2], [3, 4], [5, 6]], all_dtypes), + ([[1, 2.], [3, 4], [5, 6]], all_dtypes), + ([[1., 2j], [3., 4.], [5., 6.]], complex_dtypes), + ([[3, np.array(2, dtype=jnp.float_), 1], + np.arange(3., dtype=jnp.float_)], all_dtypes), + ] + for dtype in [None] + dtypes + for ndmin in [None, np.ndim(arg), np.ndim(arg) + 1, np.ndim(arg) + 2] + ], + ) + def testArray(self, arg, ndmin, dtype): + args_maker = lambda: [arg] + canonical_dtype = dtypes.canonicalize_dtype(dtype or np.array(arg).dtype) + if ndmin is not None: + np_fun = partial(np.array, ndmin=ndmin, dtype=canonical_dtype) + jnp_fun = partial(jnp.array, ndmin=ndmin, dtype=dtype) + else: + np_fun = partial(np.array, dtype=canonical_dtype) + jnp_fun = partial(jnp.array, dtype=dtype) + + # We are testing correct canonicalization behavior here, so we turn off the + # permissive canonicalization logic in the test harness. + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + canonicalize_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product(copy=[None, True, False]) + def testAsarrayCopy(self, copy): + x_jax = jnp.arange(4) + x_np = np.arange(4) + x_list = [0, 1, 2, 3] + x_buf = make_python_array('l', x_list) + + func = partial(jnp.asarray, copy=copy) + self.assertArraysEqual(x_jax, func(x_jax)) + self.assertArraysEqual(x_jax, func(x_list), check_dtypes=False) + + if copy is False and jax.default_backend() != 'cpu': + # copy=False is strict: it must raise if the input supports the buffer protocol + # but a copy is still required. 
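# Standalone sketch of the copy=False contract exercised below, assuming a
# non-CPU default backend where a host NumPy buffer cannot be adopted zero-copy:
# the conversion must raise rather than silently copy.
import jax.numpy as jnp
import numpy as np

host_buf = np.arange(4)
try:
    jnp.asarray(host_buf, copy=False)    # zero-copy not possible off-CPU
except ValueError as err:
    print("copy=False refused:", err)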
+ self.assertRaises(ValueError, func, x_np) + self.assertRaises(ValueError, func, x_buf) + else: + self.assertArraysEqual(x_jax, func(x_np), check_dtypes=False) + self.assertArraysEqual(x_jax, func(x_buf), check_dtypes=False) + + @unittest.skip("Jax-metal doesn't support all dtypes.") + @jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*") + def testArrayDtypeInference(self): + def _check(obj, out_dtype, weak_type): + dtype_reference = np.array(obj, dtype=out_dtype) + + out = jnp.array(obj) + self.assertDtypesMatch(out, dtype_reference) + self.assertEqual(dtypes.is_weakly_typed(out), weak_type) + + out_jit = jax.jit(jnp.array)(obj) + self.assertDtypesMatch(out_jit, dtype_reference) + self.assertEqual(dtypes.is_weakly_typed(out_jit), weak_type) + + # Python scalars become 64-bit weak types. + _check(1, np.int64, True) + _check(1.0, np.float64, True) + _check(1.0j, np.complex128, True) + + # Lists become strongly-typed defaults. + _check([1], jnp.int64, False) + _check([1.0], jnp.float64, False) + _check([1.0j], jnp.complex128, False) + + # Lists of weakly-typed objects become strongly-typed defaults. + _check([jnp.array(1)], jnp.int64, False) + _check([jnp.array(1.0)], jnp.float64, False) + _check([jnp.array(1.0j)], jnp.complex128, False) + + # Lists of strongly-typed objects maintain their strong type. + _check([jnp.int64(1)], np.int64, False) + _check([jnp.float64(1)], np.float64, False) + _check([jnp.complex128(1)], np.complex128, False) + + # Mixed inputs use JAX-style promotion. + # (regression test for https://github.com/google/jax/issues/8945) + _check([0, np.int16(1)], np.int16, False) + _check([0.0, np.float16(1)], np.float16, False) + + @jtu.sample_product( + dtype=all_dtypes, + func=["array", "copy", "copy.copy", "copy.deepcopy"], + ) + def testArrayCopy(self, dtype, func): + x = jnp.ones(10, dtype=dtype) + if func == "copy.deepcopy": + copy_func = copy.deepcopy + elif func == "copy.copy": + copy_func = copy.copy + else: + copy_func = getattr(jnp, func) + + x_view = jnp.asarray(x) + x_view_jit = jax.jit(jnp.asarray)(x) + x_copy = copy_func(x) + x_copy_jit = jax.jit(copy_func)(x) + + _ptr = lambda x: x.unsafe_buffer_pointer() + + self.assertEqual(_ptr(x), _ptr(x_view)) + self.assertNotEqual(_ptr(x), _ptr(x_view_jit)) + self.assertNotEqual(_ptr(x), _ptr(x_copy)) + self.assertNotEqual(_ptr(x), _ptr(x_copy_jit)) + + x.delete() + + self.assertTrue(x_view.is_deleted()) + self.assertFalse(x_view_jit.is_deleted()) + + self.assertFalse(x_copy.is_deleted()) + self.assertFalse(x_copy_jit.is_deleted()) + + def testArrayCopyAutodiff(self): + f = lambda x: jnp.array(x, copy=True) + + x = jnp.ones(10) + xdot = jnp.ones(10) + y, ydot = jax.jvp(f, (x,), (xdot,)) + self.assertIsNot(x, y) + self.assertIsNot(xdot, ydot) + + ybar = jnp.ones(10) + y, f_vjp = jax.vjp(f, x) + xbar, = f_vjp(ybar) + self.assertIsNot(x, y) + self.assertIsNot(xbar, ybar) + + def testArrayCopyVmap(self): + f = lambda x: jnp.array(x, copy=True) + x = jnp.ones(10) + y = jax.vmap(f)(x) + self.assertIsNot(x, y) + + def testArrayUnsupportedDtypeError(self): + with self.assertRaisesRegex(TypeError, + "JAX only supports number and bool dtypes.*"): + jnp.array(3, [('a','<i4'),('b','<i4')]) + + def testAbstractionErrorMessage(self): + + @jax.jit + def g(x): + if x > 0.: + return x * 2 + else: + return x + 2 + + self.assertRaises(jax.errors.ConcretizationTypeError, lambda: g(3.)) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in [(3,), (2, 3)] + for axis in list(range(-len(shape), len(shape))) + [None] + [tuple(range(len(shape)))] # Test negative axes and tuples + ], + 
dtype=default_dtypes, + ) + def testFlip(self, shape, dtype, axis): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + jnp_op = lambda x: jnp.flip(x, axis) + np_op = lambda x: np.flip(x, axis) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + shape=[(3,), (2, 3), (3, 2, 4)], + dtype=default_dtypes, + ) + def testFlipud(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + jnp_op = lambda x: jnp.flipud(x) + np_op = lambda x: np.flipud(x) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + shape=[(3, 2), (2, 3), (3, 2, 4)], + dtype=default_dtypes, + ) + def testFliplr(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + jnp_op = lambda x: jnp.fliplr(x) + np_op = lambda x: np.fliplr(x) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axes=axes) + for shape, axes in [ + [(2, 3), (0, 1)], + [(2, 3), (1, 0)], + [(4, 3, 2), (0, 2)], + [(4, 3, 2), (2, 1)], + ] + ], + k=range(-3, 4), + dtype=default_dtypes, + ) + def testRot90(self, shape, dtype, k, axes): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + jnp_op = lambda x: jnp.rot90(x, k, axes) + np_op = lambda x: np.rot90(x, k, axes) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + # TODO(mattjj): test infix operator overrides + + def testRavel(self): + rng = self.rng() + args_maker = lambda: [rng.randn(3, 4).astype("float32")] + self._CompileAndCheck(lambda x: x.ravel(), args_maker) + + @jtu.sample_product( + shape=nonempty_nonscalar_array_shapes, + order=['C', 'F'], + mode=['wrap', 'clip', 'raise'], + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testRavelMultiIndex(self, shape, order, mode): + # generate indices in each dimension with a few out of bounds. + rngs = [jtu.rand_int(self.rng(), low=-1, high=dim + 1) + for dim in shape] + # generate multi_indices of different dimensions that broadcast. + args_maker = lambda: [tuple(rng(ndim * (3,), jnp.int_) + for ndim, rng in enumerate(rngs))] + def np_fun(x): + try: + return np.ravel_multi_index(x, shape, order=order, mode=mode) + except ValueError as err: + if str(err).startswith('invalid entry'): + # sentinel indicating expected error. + return -999 + else: + raise + def jnp_fun(x): + try: + return jnp.ravel_multi_index(x, shape, order=order, mode=mode) + except ValueError as err: + if str(err).startswith('invalid entry'): + # sentinel indicating expected error. + return -999 + else: + raise + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + if mode == 'raise': + msg = ("The error occurred because ravel_multi_index was jit-compiled " + "with mode='raise'. 
Use mode='wrap' or mode='clip' instead.") + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + jax.jit(jnp_fun)(*args_maker()) + else: + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + ashape=((), (4,), (3, 4)), + cshapes=[ + [(), (4,)], + [(3, 4), (4,), (3, 1)] + ], + adtype=int_dtypes, + cdtype=default_dtypes, + mode=['wrap', 'clip', 'raise'], + ) + def testChoose(self, ashape, adtype, cshapes, cdtype, mode): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(ashape, adtype), [rng(s, cdtype) for s in cshapes]] + def np_fun(a, c): + try: + return np.choose(a, c, mode=mode) + except ValueError as err: + if mode == 'raise' and str(err).startswith('invalid entry'): + return -999 # sentinel indicating expected error. + else: + raise + def jnp_fun(a, c): + try: + return jnp.choose(a, c, mode=mode) + except ValueError as err: + if mode == 'raise' and str(err).startswith('invalid entry'): + return -999 # sentinel indicating expected error. + else: + raise + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + if mode == 'raise': + msg = ("The error occurred because jnp.choose was jit-compiled" + " with mode='raise'. Use mode='wrap' or mode='clip' instead.") + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + jax.jit(jnp_fun)(*args_maker()) + else: + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=nonempty_nonscalar_array_shapes, + dtype=int_dtypes, + idx_shape=all_shapes, + ) + def testUnravelIndex(self, shape, idx_shape, dtype): + size = math.prod(shape) + rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3) + + def np_fun(index, shape): + # JAX's version outputs the same dtype as the input in the typical case + # where shape is weakly-typed. + out_dtype = index.dtype + # Adjust out-of-bounds behavior to match jax's documented behavior. 
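# A quick standalone illustration of the documented jnp.unravel_index behavior
# that np_fun emulates here: out-of-range flat indices are clipped into
# [-size, size - 1] and negative indices wrap, instead of raising as in NumPy.
import jax.numpy as jnp

shape = (3, 4)                          # size = 12
print(jnp.unravel_index(11, shape))     # (2, 3): in range, matches NumPy
print(jnp.unravel_index(99, shape))     # clipped to 11 -> (2, 3)
print(jnp.unravel_index(-1, shape))     # wraps to 11   -> (2, 3)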
+ index = np.clip(index, -size, size - 1) + index = np.where(index < 0, index + size, index) + return [i.astype(out_dtype) for i in np.unravel_index(index, shape)] + + jnp_fun = jnp.unravel_index + args_maker = lambda: [rng(idx_shape, dtype), shape] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + from_dtype=['int32', 'float32'], + to_dtype=['int32', 'float32', None], + use_method=[True, False], + ) + def testAstype(self, from_dtype, to_dtype, use_method): + rng = self.rng() + args_maker = lambda: [rng.randn(3, 4).astype(from_dtype)] + if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0 + np_op = lambda x: np.astype(x, to_dtype) + else: + np_op = lambda x: np.asarray(x).astype(to_dtype) + if use_method: + jnp_op = lambda x: jnp.asarray(x).astype(to_dtype) + else: + jnp_op = lambda x: jnp.astype(x, to_dtype) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @unittest.skip("Jax-metal don't support all dtypes") + def testAstypeInt4(self): + # Test converting from int4 to int8 + x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4) + args_maker = lambda: [x] + np_op = lambda x: np.asarray(x).astype(jnp.int8) + jnp_op = lambda x: jnp.asarray(x).astype(jnp.int8) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + # Test converting from int8 to int4 + x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int8) + args_maker = lambda: [x] + np_op = lambda x: np.asarray(x).astype(jnp.int4) + jnp_op = lambda x: jnp.asarray(x).astype(jnp.int4) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + shape=array_shapes, + dtype=all_dtypes, + ) + def testNbytes(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + np_op = lambda x: np.asarray(x).nbytes + jnp_op = lambda x: jnp.asarray(x).nbytes + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + shape=array_shapes, + dtype=all_dtypes, + ) + def testItemsize(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + np_op = lambda x: np.asarray(x).itemsize + jnp_op = lambda x: jnp.asarray(x).itemsize + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + shape=nonempty_array_shapes, + dtype=all_dtypes, + num_args=[0, 1, "all"], + use_tuple=[True, False] + ) + def testItem(self, shape, dtype, num_args, use_tuple): + rng = jtu.rand_default(self.rng()) + size = math.prod(shape) + + if num_args == 0: + args = () + elif num_args == 1: + args = (self.rng().randint(0, size),) + else: + args = tuple(self.rng().randint(0, s) for s in shape) + args = (args,) if use_tuple else args + + np_op = lambda x: np.asarray(x).item(*args) + jnp_op = lambda x: jnp.asarray(x).item(*args) + args_maker = lambda: [rng(shape, dtype)] + + if size != 1 and num_args == 0: + with self.assertRaises(ValueError): + jnp_op(*args_maker()) + else: + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + + @jtu.sample_product( + # Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs. 
+ shape=[(0,), (32,), (2, 16)], + a_dtype=all_dtypes, + dtype=(*all_dtypes, None) if config.enable_x64.value else all_dtypes, + ) + def testView(self, shape, a_dtype, dtype): + if jtu.test_device_matches(["tpu"]): + if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]: + self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.") + # It is possible to fill bool arrays with arbitrary bits (not just 0/1 + # bytes), but the behavior is implementation-defined. We therefore only test + # the well-defined case. + rng = (jtu.rand_bool if a_dtype == np.bool_ else jtu.rand_fullrange)( + self.rng() + ) + args_maker = lambda: [rng(shape, a_dtype)] + np_op = lambda x: np.asarray(x).view(dtype) + jnp_op = lambda x: jnp.asarray(x).view(dtype) + # Above may produce signaling nans; ignore warnings from invalid values. + with np.errstate(invalid='ignore'): + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product([ + {'a_dtype': a_dtype, 'dtype': dtype} + for a_dtype in all_dtypes + for dtype in all_dtypes + if np.dtype(a_dtype).itemsize == np.dtype(dtype).itemsize + ]) + def testViewScalar(self, a_dtype, dtype): + if jtu.test_device_matches(["tpu"]): + if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]: + self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.") + rng = jtu.rand_fullrange(self.rng()) + args_maker = lambda: [jnp.array(rng((), a_dtype))] + np_op = lambda x: np.asarray(x).view(dtype) + jnp_op = lambda x: jnp.asarray(x).view(dtype) + # Above may produce signaling nans; ignore warnings from invalid values. + with np.errstate(invalid='ignore'): + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + def testPathologicalFloats(self): + args_maker = lambda: [np.array([ + 0b_0111_1111_1000_0000_0000_0000_0000_0000, # inf + 0b_1111_1111_1000_0000_0000_0000_0000_0000, # -inf + 0b_0111_1111_1100_0000_0000_0000_0000_0000, # qnan + 0b_1111_1111_1100_0000_0000_0000_0000_0000, # -qnan + 0b_0111_1111_1000_0000_0000_0000_0000_0001, # snan + 0b_1111_1111_1000_0000_0000_0000_0000_0001, # -snan + 0b_0111_1111_1000_0000_0000_1100_0000_0000, # nonstandard nan + 0b_1111_1111_1000_0000_0000_1100_0000_0000, # -nonstandard nan + 0b_0000_0000_0000_0000_0000_0000_0000_0000, # zero + 0b_1000_0000_0000_0000_0000_0000_0000_0000, # -zero + ], dtype='uint32')] + + np_op = lambda x: np.asarray(x).view('float32').view('uint32') + jnp_op = lambda x: jnp.asarray(x).view('float32').view('uint32') + + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + # TODO(mattjj): test other ndarray-like method overrides + + def testNpMean(self): + # from https://github.com/google/jax/issues/125 + x = jnp.eye(3, dtype=float) + 0. + ans = np.mean(x) + self.assertAllClose(ans, np.array(1./3), check_dtypes=False) + + def testArangeOnFloats(self): + np_arange = jtu.with_jax_dtype_defaults(np.arange) + # from https://github.com/google/jax/issues/145 + self.assertAllClose(np_arange(0.0, 1.0, 0.1), + jnp.arange(0.0, 1.0, 0.1)) + # from https://github.com/google/jax/issues/3450 + self.assertAllClose(np_arange(2.5), + jnp.arange(2.5)) + self.assertAllClose(np_arange(0., 2.5), + jnp.arange(0., 2.5)) + + def testArangeTypes(self): + # Test that arange() output type is equal to the default types. 
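# Standalone sketch of the dtype rules this test spells out, assuming the default
# configuration (jax_enable_x64 off, so jnp.int_ -> int32 and jnp.float_ -> float32):
import jax.numpy as jnp

print(jnp.arange(10).dtype)                   # int32   (default integer type)
print(jnp.arange(10.0).dtype)                 # float32 (default float type)
print(jnp.arange(10, dtype='uint16').dtype)   # uint16  (explicit dtype wins)
print(jnp.arange(0, 10, 1.0).dtype)           # float32 (a float argument promotes)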
+ int_ = dtypes.canonicalize_dtype(jnp.int_) + float_ = dtypes.canonicalize_dtype(jnp.float_) + + self.assertEqual(jnp.arange(10).dtype, int_) + self.assertEqual(jnp.arange(10.).dtype, float_) + self.assertEqual(jnp.arange(10, dtype='uint16').dtype, np.uint16) + #self.assertEqual(jnp.arange(10, dtype='bfloat16').dtype, jnp.bfloat16) + + self.assertEqual(jnp.arange(0, 10, 1).dtype, int_) + with jax.numpy_dtype_promotion('standard'): + self.assertEqual(jnp.arange(0, 10, 1.).dtype, float_) + self.assertEqual(jnp.arange(0., 10, 1).dtype, float_) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonzerodim_shapes + for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) + ], + stable=[True, False], + dtype=all_dtypes, + ) + def testSort(self, dtype, shape, axis, stable): + rng = jtu.rand_some_equal(self.rng()) if stable else jtu.rand_some_inf_and_nan(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + kwds = {} if axis is NO_VALUE else {'axis': axis} + + def np_fun(arr): + # Note: numpy sort fails on NaN and Inf values with bfloat16 + dtype = arr.dtype + if arr.dtype == jnp.bfloat16: + arr = arr.astype('float32') + # TODO(jakevdp): switch to stable=stable when supported by numpy. + result = np.sort(arr, kind='stable' if stable else None, **kwds) + with jtu.ignore_warning(category=RuntimeWarning, message='invalid value'): + return result.astype(dtype) + jnp_fun = partial(jnp.sort, stable=stable, **kwds) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testSortStableDescending(self): + # TODO(jakevdp): test directly against np.sort when descending is supported. + x = jnp.array([0, 1, jnp.nan, 0, 2, jnp.nan, -jnp.inf, jnp.inf]) + x_sorted = jnp.array([-jnp.inf, 0, 0, 1, 2, jnp.inf, jnp.nan, jnp.nan]) + argsorted_stable = jnp.array([6, 0, 3, 1, 4, 7, 2, 5]) + argsorted_rev_stable = jnp.array([2, 5, 7, 4, 1, 0, 3, 6]) + + self.assertArraysEqual(jnp.sort(x), x_sorted) + self.assertArraysEqual(jnp.sort(x, descending=True), lax.rev(x_sorted, [0])) + self.assertArraysEqual(jnp.argsort(x), argsorted_stable) + self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable) + + @unittest.skip("Jax-metal don't support complex.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in one_dim_array_shapes + for axis in [None] + ], + dtype=all_dtypes, + ) + def testSortComplex(self, dtype, shape, axis): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np.sort_complex, jnp.sort_complex, args_maker, + check_dtypes=False) + self._CompileAndCheck(jnp.sort_complex, args_maker) + + @unittest.skip("Jax-metal fail to convert sort op.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in (-1, *range(len(shape) - 1)) + ], + dtype=all_dtypes, + input_type=[np.array, tuple], + ) + def testLexsort(self, dtype, shape, input_type, axis): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [input_type(rng(shape, dtype))] + jnp_op = lambda x: jnp.lexsort(x, axis=axis) + np_op = jtu.with_jax_dtype_defaults(lambda x: np.lexsort(x, axis=axis)) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @unittest.skip("JAX-metal crash.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonzerodim_shapes + for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) + ], + dtype=all_dtypes, + ) + def testArgsort(self, 
dtype, shape, axis): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + kwds = {} if axis is NO_VALUE else {'axis': axis} + + @jtu.with_jax_dtype_defaults + def np_fun(arr): + # Note: numpy sort fails on NaN and Inf values with bfloat16 + if arr.dtype == jnp.bfloat16: + arr = arr.astype('float32') + # TODO(jakevdp): switch to stable=True when supported by numpy. + return np.argsort(arr, kind='stable', **kwds) + jnp_fun = partial(jnp.argsort, stable=True, **kwds) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skip("JAX-metal crash.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) + ], + descending=[True, False], + dtype=all_dtypes, + ) + def testArgsortUnstable(self, dtype, shape, axis, descending): + # We cannot directly compare unstable argsorts, so instead check that indexed values match. + rng = jtu.rand_some_equal(self.rng()) + x = rng(shape, dtype) + kwds = {} if axis is NO_VALUE else {'axis': axis} + expected = jnp.sort(x, descending=descending, stable=False, **kwds) + indices = jnp.argsort(x, descending=descending, stable=False, **kwds) + if axis is None: + actual = jnp.ravel(x)[indices] + else: + actual = jnp.take_along_axis(x, indices, axis=-1 if axis is NO_VALUE else axis) + self.assertArraysEqual(actual, expected) + + @jtu.sample_product( + [{'shape': shape, 'axis': axis, 'kth': kth} + for shape in nonzerodim_shapes + for axis in range(-len(shape), len(shape)) + for kth in range(-shape[axis], shape[axis])], + dtype=default_dtypes, + ) + def testPartition(self, shape, dtype, axis, kth): + rng = jtu.rand_default(self.rng()) + arg = rng(shape, dtype) + jnp_output = jnp.partition(arg, axis=axis, kth=kth) + np_output = np.partition(arg, axis=axis, kth=kth) + + # Assert that pivot point is equal: + self.assertArraysEqual( + lax.index_in_dim(jnp_output, axis=axis, index=kth), + lax.index_in_dim(np_output, axis=axis, index=kth)) + + # Assert remaining values are correctly partitioned: + self.assertArraysEqual( + lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis), + lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis)) + self.assertArraysEqual( + lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis), + lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis)) + + #@unittest.skipIf(jtu.device_under_test=="METAL", "Jax-metal fail on empty dim shape.") + @jtu.sample_product( + [{'shape': shape, 'axis': axis, 'kth': kth} + for shape in nonempty_shapes# nonzerodim_shapes + for axis in range(-len(shape), len(shape)) + for kth in range(-shape[axis], shape[axis])], + dtype=default_dtypes, + ) + def testArgpartition(self, shape, dtype, axis, kth): + rng = jtu.rand_default(self.rng()) + arg = rng(shape, dtype) + + jnp_output = jnp.argpartition(arg, axis=axis, kth=kth) + np_output = np.argpartition(arg, axis=axis, kth=kth) + + # Assert that all indices are present + self.assertArraysEqual(jnp.sort(jnp_output, axis), np.sort(np_output, axis), check_dtypes=False) + + # Because JAX & numpy may treat duplicates differently, we must compare values + # rather than indices. 
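# Standalone sketch of why values rather than raw indices are compared below:
# with duplicate entries the two libraries may legally return different index
# orders, but the value gathered at the pivot position must agree.
import jax.numpy as jnp
import numpy as np

arr = np.array([3, 1, 3, 0, 1])
kth = 2
np_idx = np.argpartition(arr, kth)
jnp_idx = jnp.argpartition(jnp.asarray(arr), kth)
print(np_idx, jnp_idx)                                   # index orders may differ
print(arr[np_idx[kth]], arr[np.asarray(jnp_idx)[kth]])   # both print 1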
+ getvals = lambda x, ind: x[ind] + for ax in range(arg.ndim): + if ax != range(arg.ndim)[axis]: + getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax) + jnp_values = getvals(arg, jnp_output) + np_values = getvals(arg, np_output) + + # Assert that pivot point is equal: + self.assertArraysEqual( + lax.index_in_dim(jnp_values, axis=axis, index=kth), + lax.index_in_dim(np_values, axis=axis, index=kth)) + + # Assert remaining values are correctly partitioned: + self.assertArraysEqual( + lax.sort(lax.slice_in_dim(jnp_values, start_index=0, limit_index=kth, axis=axis), dimension=axis), + lax.sort(lax.slice_in_dim(np_values, start_index=0, limit_index=kth, axis=axis), dimension=axis)) + self.assertArraysEqual( + lax.sort(lax.slice_in_dim(jnp_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis), + lax.sort(lax.slice_in_dim(np_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis)) + + @jtu.sample_product( + [dict(shifts=shifts, axis=axis) + for shifts, axis in [ + (3, None), + (1, 1), + ((3,), (0,)), + ((-2,), (-2,)), + ((1, 2), (0, -1)), + ((4, 2, 5, 5, 2, 4), None), + (100, None), + ] + ], + dtype=all_dtypes, + shape=[(3, 4), (3, 4, 5), (7, 4, 0)], + ) + def testRoll(self, shape, dtype, shifts, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), np.array(shifts)] + jnp_op = partial(jnp.roll, axis=axis) + np_op = partial(np.roll, axis=axis) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + dtype=all_dtypes, + shape=[(1, 2, 3, 4)], + axis=[-3, 0, 2, 3], + start=[-4, -1, 2, 4], + ) + def testRollaxis(self, shape, dtype, start, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + jnp_op = partial(jnp.rollaxis, axis=axis, start=start) + np_op = partial(np.rollaxis, axis=axis, start=start) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @unittest.skip("jax-metal generates a different result from cpu.") + @jtu.sample_product( + dtype=[np.uint8, np.bool_], + bitorder=['big', 'little'], + shape=[(1, 2, 3, 4)], + axis=[None, 0, 1, -2, -1], + ) + def testPackbits(self, shape, dtype, axis, bitorder): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder) + np_op = partial(np.packbits, axis=axis, bitorder=bitorder) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + dtype=[np.uint8], + bitorder=['big', 'little'], + shape=[(1, 2, 3, 4)], + axis=[None, 0, 1, -2, -1], + count=[None, 20], + ) + def testUnpackbits(self, shape, dtype, axis, bitorder, count): + rng = jtu.rand_int(self.rng(), 0, 256) + args_maker = lambda: [rng(shape, dtype)] + jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder, count=count) + np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder, count=count) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + #@unittest.skip("jax-metal generates a different result from cpu.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in [(3,), (3, 4), (3, 4, 5)] + for axis in itertools.chain(range(-len(shape), len(shape)), + [cast(Union[int, None], None)]) + ], + index_shape=scalar_shapes + [(3,), (2, 1, 3)], + dtype=all_dtypes, + index_dtype=int_dtypes, + #mode=[None, 'wrap', 'clip'], + mode=[None, 'wrap'], + ) + 
def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode): + def args_maker(): + x = rng(shape, dtype) + i = rng_indices(index_shape, index_dtype) + return x, i + + rng = jtu.rand_default(self.rng()) + if mode is None: + rng_indices = jtu.rand_int(self.rng(), -shape[axis or 0], shape[axis or 0]) + else: + rng_indices = jtu.rand_int(self.rng(), -5, 5) + jnp_op = lambda x, i: jnp.take(x, i, axis=axis, mode=mode) + np_op = lambda x, i: np.take(x, i, axis=axis, mode=mode) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + def testTakeEmpty(self): + np.testing.assert_array_equal( + jnp.array([], dtype=jnp.float32), + jnp.take(jnp.array([], jnp.float32), jnp.array([], jnp.int32))) + + np.testing.assert_array_equal( + jnp.ones((2, 0, 4), dtype=jnp.float32), + jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32), jnp.array([], jnp.int32), + axis=1)) + + with self.assertRaisesRegex(IndexError, "non-empty jnp.take"): + jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32), + jnp.array([0], jnp.int32), axis=1) + + def testTakeOptionalArgs(self): + x = jnp.arange(5.0) + ind = jnp.array([0, 2, 4, 6]) + expected = jnp.array([0.0, 2.0, 4.0, 10.0], dtype=x.dtype) + actual = jnp.take(x, ind, unique_indices=True, + indices_are_sorted=True, fill_value=10.0) + self.assertArraysEqual(expected, actual) + + @jtu.sample_product( + [dict(x_shape=x_shape, i_shape=i_shape, axis=axis) + for x_shape, i_shape in filter( + _shapes_are_equal_length, + filter(_shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(nonempty_nonscalar_array_shapes, 2))) + for axis in itertools.chain(range(len(x_shape)), [-1], + [cast(Union[int, None], None)]) + ], + dtype=default_dtypes, + index_dtype=int_dtypes, + ) + def testTakeAlongAxis(self, x_shape, i_shape, dtype, index_dtype, axis): + rng = jtu.rand_default(self.rng()) + + i_shape = list(i_shape) + if axis is None: + i_shape = [math.prod(i_shape)] + else: + # Test the case where the size of the axis doesn't necessarily broadcast. 
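# Standalone sketch of the shape rule exercised here: along the indexed axis the
# index array may have any length (6 below, versus 2 in the source); only the
# remaining axes need to broadcast.
import jax.numpy as jnp

x = jnp.array([[10, 20],
               [30, 40]])                      # shape (2, 2)
idx = jnp.array([[0, 1, 0, 1, 1, 0],
                 [1, 0, 1, 0, 0, 1]])          # shape (2, 6), axis=1 oversized
print(jnp.take_along_axis(x, idx, axis=1).shape)   # (2, 6)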
+ i_shape[axis] *= 3 + def args_maker(): + x = rng(x_shape, dtype) + n = math.prod(x_shape) if axis is None else x_shape[axis] + if np.issubdtype(index_dtype, np.unsignedinteger): + index_rng = jtu.rand_int(self.rng(), 0, n) + else: + index_rng = jtu.rand_int(self.rng(), -n, n) + i = index_rng(i_shape, index_dtype) + return x, i + + jnp_op = lambda x, i: jnp.take_along_axis(x, i, axis=axis) + + if hasattr(np, "take_along_axis"): + np_op = lambda x, i: np.take_along_axis(x, i, axis=axis) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): + # https://github.com/google/jax/issues/5088 + h = jtu.rand_default(self.rng())((256, 256, 100), np.float32) + g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8) + q0 = jnp.take_along_axis(h, g, axis=-1) + q1 = np.take_along_axis( h, g, axis=-1) + np.testing.assert_equal(q0, q1) + + @unittest.skip("Jax-metal fail.") + def testTakeAlongAxisOutOfBounds(self): + x = jnp.arange(10, dtype=jnp.float32) + idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11]) + out = jnp.take_along_axis(x, idx, axis=0) + expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan, + jnp.nan], np.float32) + np.testing.assert_array_equal(expected_fill, out) + out = jnp.take_along_axis(x, idx, axis=0, mode="fill") + np.testing.assert_array_equal(expected_fill, out) + + expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32) + out = jnp.take_along_axis(x, idx, axis=0, mode="clip") + np.testing.assert_array_equal(expected_clip, out) + + def testTakeAlongAxisRequiresIntIndices(self): + x = jnp.arange(5) + idx = jnp.array([3.], jnp.float32) + with self.assertRaisesRegex( + TypeError, + "take_along_axis indices must be of integer type, got float32"): + jnp.take_along_axis(x, idx, axis=0) + + def testTakeAlongAxisWithEmptyArgs(self): + # take_along_axis should allow us to gather an empty list of indices from + # an empty input axis without raising a shape error. + x = jnp.ones((4, 0, 3), dtype=jnp.int32) + np.testing.assert_array_equal(x, jnp.take_along_axis(x, x, axis=1)) + + @jtu.sample_product( + dtype=inexact_dtypes, + shape=[0, 5], + n=[2, 4], + increasing=[False, True], + ) + def testVander(self, shape, dtype, n, increasing): + rng = jtu.rand_default(self.rng()) + def np_fun(arg): + arg = arg.astype(np.float32) if dtype == jnp.bfloat16 else arg + return np.vander(arg, N=n, increasing=increasing) + jnp_fun = lambda arg: jnp.vander(arg, N=n, increasing=increasing) + args_maker = lambda: [rng([shape], dtype)] + # np.vander seems to return float64 for all floating types. We could obey + # those semantics, but they seem like a bug. 
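+    # Hence compare with check_dtypes=False and a relaxed tolerance below.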
+ self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + tol={np.float32: 1e-3, np.complex64: 1e-3}) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False) + + @jtu.sample_product( + shape=array_shapes, + dtype=all_dtypes, + ) + def testNanToNum(self, shape, dtype): + rng = jtu.rand_some_inf_and_nan(self.rng()) + dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type + def np_fun(x): + if dtype == jnp.bfloat16: + x = np.where(np.isnan(x), dtype(0), x) + x = np.where(np.isposinf(x), jnp.finfo(dtype).max, x) + x = np.where(np.isneginf(x), jnp.finfo(dtype).min, x) + return x + else: + return np.nan_to_num(x).astype(dtype) + + args_maker = lambda: [rng(shape, dtype)] + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + self._CheckAgainstNumpy(np_fun, jnp.nan_to_num, args_maker, + check_dtypes=check_dtypes) + self._CompileAndCheck(jnp.nan_to_num, args_maker, + check_dtypes=check_dtypes) + + @jtu.sample_product( + [dict(shapes=shapes, dtypes=dtypes) + for shapes, dtypes in ( + ((), ()), + (((7,),), (np.int32,)), + (((3,), (4,)), (np.int32, np.int32)), + (((3,), (1,), (4,)), (np.int32, np.int32, np.int32)), + ) + ], + ) + def testIx_(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype) + for shape, dtype in zip(shapes, dtypes)] + self._CheckAgainstNumpy(np.ix_, jnp.ix_, args_maker) + self._CompileAndCheck(jnp.ix_, args_maker) + + @jtu.sample_product( + dimensions=[(), (2,), (3, 0), (4, 5, 6)], + dtype=number_dtypes, + sparse=[True, False], + ) + def testIndices(self, dimensions, dtype, sparse): + def args_maker(): return [] + np_fun = partial(np.indices, dimensions=dimensions, + dtype=dtype, sparse=sparse) + jnp_fun = partial(jnp.indices, dimensions=dimensions, + dtype=dtype, sparse=sparse) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=nonzerodim_shapes, dtype=all_dtypes, + ) + def testWhereOneArgument(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False) + + # JIT compilation requires specifying a size statically. Full test of + # this behavior is in testNonzeroSize(). + jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2) + + with jtu.ignore_warning(category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*"): + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shapes=filter(_shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(all_shapes, 3)), + dtypes=itertools.combinations_with_replacement(all_dtypes, 3), + ) + def testWhereThreeArgument(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, shapes, dtypes) + def np_fun(cond, x, y): + return jtu.promote_like_jnp(partial(np.where, cond))(x, y) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(np_fun, jnp.where, args_maker) + self._CompileAndCheck(jnp.where, args_maker) + + def testWhereExtraCode(self): + def f(x): + return jnp.where(x > 0, x, -x) + + # Test no comparison literal True/False in jaxpr, and hence no comparison to + # literals + jaxpr = jax.make_jaxpr(jax.grad(f))(3.) 
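+    # Neither boolean literal should appear anywhere in the printed jaxpr: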
+ self.assertNotIn('False', str(jaxpr)) + self.assertNotIn('True', str(jaxpr)) + + def testWhereScalarPromotion(self): + x = jnp.where(jnp.array([True, False]), 3, + jnp.ones((2,), dtype=jnp.float32)) + self.assertEqual(x.dtype, np.dtype(np.float32)) + + @jtu.sample_product( + [dict(n=n, shapes=shapes) + for n in range(1, 3) + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(all_shapes, 2 * n + 1)) + ], + # To avoid forming the full product of shapes and dtypes we always sample + # maximal set of dtypes. + dtypes=itertools.combinations_with_replacement(all_dtypes, 3), + ) + def testSelect(self, n, shapes, dtypes): + dtypes = dtypes[:n+1] + rng = jtu.rand_default(self.rng()) + n = len(dtypes) - 1 + def args_maker(): + condlist = [rng(shape, np.bool_) for shape in shapes[:n]] + choicelist = [rng(shape, dtype) + for shape, dtype in zip(shapes[n:-1], dtypes[:n])] + default = rng(shapes[-1], dtypes[-1]) + return condlist, choicelist, default + # TODO(phawkins): float32/float64 type mismatches + @jax.numpy_dtype_promotion('standard') + def np_fun(condlist, choicelist, default): + choicelist = [x if jnp.result_type(x) != jnp.bfloat16 + else x.astype(np.float32) for x in choicelist] + dtype = jnp.result_type(default, *choicelist) + return np.select(condlist, + [np.asarray(x, dtype=dtype) for x in choicelist], + np.asarray(default, dtype=dtype)) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(np_fun, jnp.select, args_maker, + check_dtypes=False) + self._CompileAndCheck(jnp.select, args_maker, + rtol={np.float64: 1e-7, np.complex128: 1e-7}) + + def testIssue330(self): + x = jnp.full((1, 1), jnp.array([1])[0]) # doesn't crash + self.assertEqual(x[0, 0], 1) + + def testScalarDtypePromotion(self): + orig_numpy_result = (1 + np.eye(1, dtype=np.float32)).dtype + jax_numpy_result = (1 + jnp.eye(1, dtype=jnp.float32)).dtype + self.assertEqual(orig_numpy_result, jax_numpy_result) + + def testSymmetrizeDtypePromotion(self): + x = np.eye(3, dtype=np.float32) + orig_numpy_result = ((x + x.T) / 2).dtype + + x = jnp.eye(3, dtype=jnp.float32) + jax_numpy_result = ((x + x.T) / 2).dtype + self.assertEqual(orig_numpy_result, jax_numpy_result) + + # NOTE(mattjj): I disabled this test when removing lax._safe_mul because + # introducing the convention 0 * inf = 0 leads to silently wrong results in + # some cases. 
See this comment for details: + # https://github.com/google/jax/issues/1052#issuecomment-514083352 + # def testIssue347(self): + # # https://github.com/google/jax/issues/347 + # def test_fail(x): + # x = jnp.sqrt(jnp.sum(x ** 2, axis=1)) + # ones = jnp.ones_like(x) + # x = jnp.where(x > 0.5, x, ones) + # return jnp.sum(x) + # x = jnp.array([[1, 2], [3, 4], [0, 0]], dtype=jnp.float64) + # result = jax.grad(test_fail)(x) + # assert not np.any(np.isnan(result)) + + def testIssue453(self): + # https://github.com/google/jax/issues/453 + a = np.arange(6) + 1 + ans = jnp.reshape(a, (3, 2), order='F') + expected = np.reshape(a, (3, 2), order='F') + self.assertAllClose(ans, expected) + + @jtu.sample_product( + #dtype=[int, float, bool, complex], + dtype=[int, float, bool], + op=["atleast_1d", "atleast_2d", "atleast_3d"], + ) + def testAtLeastNdLiterals(self, dtype, op): + # Fixes: https://github.com/google/jax/issues/634 + np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype]) + jnp_fun = lambda arg: getattr(jnp, op)(arg) + args_maker = lambda: [dtype(2)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=[(0,), (5,), (10,)], + dtype=int_dtypes, + weights=[True, False], + minlength=[0, 20], + length=[None, 8], + ) + def testBincount(self, shape, dtype, weights, minlength, length): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: (rng(shape, dtype), (rng(shape, 'float32') if weights else None)) + + def np_fun(x, *args): + x = np.clip(x, 0, None) # jnp.bincount clips negative values to zero. + out = np.bincount(x, *args, minlength=minlength) + if length and length > out.size: + return np.pad(out, (0, length - out.size)) + return out[:length] + jnp_fun = partial(jnp.bincount, minlength=minlength, length=length) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + if length is not None: + self._CompileAndCheck(jnp_fun, args_maker) + + def testBincountNegative(self): + # Test that jnp.bincount ignores negative values. 
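+    # (Per testBincount above, jnp.bincount clips negative values to zero,
+    # which the NumPy reference below mirrors.)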
+ x_rng = jtu.rand_int(self.rng(), -100, 100) + w_rng = jtu.rand_uniform(self.rng()) + shape = (1000,) + x = x_rng(shape, 'int32') + w = w_rng(shape, 'float32') + + xn = np.array(x) + xn[xn < 0] = 0 + wn = np.array(w) + np_result = np.bincount(xn[xn >= 0], wn[xn >= 0]) + jnp_result = jnp.bincount(x, w) + self.assertAllClose(np_result, jnp_result, check_dtypes=False) + + @jtu.sample_product( + input=[ + 3, + [3], + [np.array(3)], + [np.array([3])], + [[np.array(3)]], + [[np.array([3])]], + [3, 4, 5], + [ + [np.eye(2, dtype=np.int32) * 2, np.zeros((2, 3), dtype=np.int32)], + [np.ones((3, 2), dtype=np.int32), np.eye(3, dtype=np.int32) * 3], + ], + [np.array([1, 2, 3]), np.array([2, 3, 4]), 10], + [np.ones((2, 2), dtype=np.int32), np.zeros((2, 2), dtype=np.int32)], + [[np.array([1, 2, 3])], [np.array([2, 3, 4])]], + ], + ) + def testBlock(self, input): + args_maker = lambda: [input] + self._CheckAgainstNumpy(np.block, jnp.block, args_maker) + self._CompileAndCheck(jnp.block, args_maker) + + def testLongLong(self): + self.assertAllClose(np.int64(7), jax.jit(lambda x: x)(np.longlong(7))) + + @jtu.ignore_warning(category=UserWarning, + message="Explicitly requested dtype.*") + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testArange(self): + # test cases inspired by dask tests at + # https://github.com/dask/dask/blob/main/dask/array/tests/test_creation.py#L92 + np_arange = jtu.with_jax_dtype_defaults(np.arange) + self.assertAllClose(jnp.arange(77), + np_arange(77)) + self.assertAllClose(jnp.arange(2, 13), + np_arange(2, 13)) + self.assertAllClose(jnp.arange(4, 21, 9), + np_arange(4, 21, 9)) + self.assertAllClose(jnp.arange(53, 5, -3), + np_arange(53, 5, -3)) + self.assertAllClose(jnp.arange(77, dtype=float), + np_arange(77, dtype=float)) + self.assertAllClose(jnp.arange(2, 13, dtype=int), + np_arange(2, 13, dtype=int)) + self.assertAllClose(jnp.arange(0, 1, -0.5), + np_arange(0, 1, -0.5)) + + self.assertRaises(TypeError, lambda: jnp.arange()) + + # test that jnp.arange(N) doesn't instantiate an ndarray + self.assertNotEqual(type(jnp.arange(77)), type(np.arange(77))) + self.assertEqual(type(jnp.arange(77)), type(lax.iota(np.int32, 77))) + + # test that jnp.arange(N, dtype=int32) doesn't instantiate an ndarray + self.assertNotEqual(type(jnp.arange(77, dtype=jnp.int32)), + type(np.arange(77, dtype=np.int32))) + self.assertEqual(type(jnp.arange(77, dtype=jnp.int32)), + type(lax.iota(np.int32, 77))) + + def testArangeJit(self): + ans = jax.jit(lambda: jnp.arange(5))() + expected = jtu.with_jax_dtype_defaults(np.arange)(5) + self.assertAllClose(ans, expected) + + @jtu.sample_product(args=[(5,), (0, 5)]) + def testArangeJaxpr(self, args): + jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args))() + self.assertEqual(len(jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) + + @unittest.skip("Jax-metal don't support complex.") + def testIssue830(self): + a = jnp.arange(4, dtype=jnp.complex64) + self.assertEqual(a.dtype, jnp.complex64) + + def testIssue728(self): + np_eye = jtu.with_jax_dtype_defaults(np.eye) + self.assertAllClose(jnp.eye(5000), np_eye(5000)) + self.assertEqual(0, np.sum(jnp.eye(1050) - np_eye(1050))) + + def testIssue746(self): + jnp.arange(12).reshape(3, 4) # doesn't crash + + def testIssue764(self): + x = jnp.linspace(190, 200, 4) + f = jax.grad(lambda x: jnp.sum(jnp.tanh(x))) + # Expected values computed with autograd in float64 precision. 
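+    # (d/dx tanh(x) = sech(x)**2 ~ 4*exp(-2*x) for large x, so the gradient at
+    # x ~ 190 is of order 4*exp(-380) ~ 3.7e-165, matching the values below.)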
+    expected = np.array([3.71669453e-165, 4.72999108e-168, 6.01954653e-171,
+                         7.66067839e-174], np.float64)
+    self.assertAllClose(f(x), expected, check_dtypes=False)
+
+  # Test removed because tie_in is deprecated.
+  # def testIssue776(self):
+  #   """Tests that the scatter-add transpose rule instantiates symbolic zeros."""
+  #   def f(u):
+  #     y = jnp.ones_like(u, shape=10).at[np.array([2, 4, 5])].add(u)
+  #     # The transpose rule for lax.tie_in returns a symbolic zero for its first
+  #     # argument.
+  #     return lax.tie_in(y, 7.)
+
+  #   self.assertAllClose(np.zeros(3,), jax.grad(f)(np.ones(3,)))
+
+  # NOTE(mattjj): I disabled this test when removing lax._safe_mul because this
+  # is a numerical stability issue that should be solved with a custom jvp rule
+  # of the sigmoid function being differentiated here, not by safe_mul.
+  # def testIssue777(self):
+  #   x = jnp.linspace(-200, 0, 4, dtype=np.float32)
+  #   f = jax.grad(lambda x: jnp.sum(1 / (1 + jnp.exp(-x))))
+  #   self.assertAllClose(f(x), np.array([0., 0., 0., 0.25], dtype=np.float32))
+
+  # np.nan is omitted from the special values below because jax-metal fails on
+  # tanh with np.nan.
+  @jtu.sample_product(
+    dtype=float_dtypes,
+    op=("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan",
+        "sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh", "exp",
+        "log", "expm1", "log1p"),
+  )
+  def testMathSpecialFloatValues(self, op, dtype):
+    np_op = getattr(np, op)
+    np_op = jtu.ignore_warning(category=RuntimeWarning,
+                               message="invalid value.*")(np_op)
+    np_op = jtu.ignore_warning(category=RuntimeWarning,
+                               message="divide by zero.*")(np_op)
+    np_op = jtu.ignore_warning(category=RuntimeWarning,
+                               message="overflow.*")(np_op)
+
+    jnp_op = getattr(jnp, op)
+    dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type
+    for x in (-np.inf, -100., -2., -1., 0., 1., 2., 100., np.inf,
+              jnp.finfo(dtype).max, np.sqrt(jnp.finfo(dtype).max),
+              np.sqrt(jnp.finfo(dtype).max) * 2.):  # np.nan omitted (see note above)
+      x = dtype(x)
+      expected = np_op(x)
+      actual = jnp_op(x)
+      tol = jtu.tolerance(dtype, {np.float32: 1e-3, np.float64: 1e-7})
+      self.assertAllClose(expected, actual, atol=tol,
+                          rtol=tol)
+
+  def testIssue956(self):
+    self.assertRaises(TypeError, lambda: jnp.ndarray((1, 1)))
+
+  def testIssue967(self):
+    self.assertRaises(TypeError, lambda: jnp.zeros(1.5))
+
+  @jtu.sample_product(
+    shape=[(5,), (10, 5), (4, 10)],
+    dtype=number_dtypes,
+    rowvar=[True, False],
+  )
+  @jax.default_matmul_precision("float32")
+  def testCorrCoef(self, shape, dtype, rowvar):
+    rng = jtu.rand_default(self.rng())
+    def args_maker():
+      ok = False
+      while not ok:
+        x = rng(shape, dtype)
+        ok = not np.any(np.isclose(np.std(x), 0.0))
+      return (x,)
+    np_fun = partial(np.corrcoef, rowvar=rowvar)
+    np_fun = jtu.ignore_warning(
+      category=RuntimeWarning, message="invalid value encountered.*")(np_fun)
+    jnp_fun = partial(jnp.corrcoef, rowvar=rowvar)
+    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
+    self._CompileAndCheck(jnp_fun, args_maker)
+
+  @jtu.sample_product(
+    [dict(dtype=dtype, end_dtype=end_dtype, begin_dtype=begin_dtype,
+          shape=shape, begin_shape=begin_shape, end_shape=end_shape)
+     for dtype in number_dtypes
+     for end_dtype in [None] + [dtype]
+     for begin_dtype in [None] + [dtype]
+     for shape in [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]
+     for begin_shape in (
+       [None] if begin_dtype is None
+       else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE])
+     for end_shape in (
+       [None] if end_dtype is None
+       else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE])
+    ],
+  )
+  def testEDiff1d(self, shape,
dtype, end_shape, end_dtype, begin_shape,
+                  begin_dtype):
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(shape, dtype),
+                          (None if end_dtype is None else rng(end_shape, end_dtype)),
+                          (None if begin_dtype is None else rng(begin_shape, begin_dtype))]
+    np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin)
+    jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin)
+    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
+    self._CompileAndCheck(jnp_fun, args_maker)
+
+  def testEDiff1dWithDtypeCast(self):
+    rng = jtu.rand_default(self.rng())
+    shape = jtu.NUMPY_SCALAR_SHAPE
+    dtype = jnp.float32
+    end_dtype = jnp.int32
+    args_maker = lambda: [rng(shape, dtype), rng(shape, end_dtype), rng(shape, dtype)]
+    np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin)
+    jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin)
+    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
+    self._CompileAndCheck(jnp_fun, args_maker)
+
+  @jtu.sample_product(
+    shapes=[(), (5,), (5, 3)],
+    dtype=number_dtypes,
+    indexing=['xy', 'ij'],
+    sparse=[True, False],
+  )
+  def testMeshGrid(self, shapes, dtype, indexing, sparse):
+    rng = jtu.rand_default(self.rng())
+    args_maker = self._GetArgsMaker(rng, [(x,) for x in shapes],
+                                    [dtype] * len(shapes))
+    np_fun = partial(np.meshgrid, indexing=indexing, sparse=sparse)
+    jnp_fun = partial(jnp.meshgrid, indexing=indexing, sparse=sparse)
+    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
+    self._CompileAndCheck(jnp_fun, args_maker)
+
+  def testMgrid(self):
+    # wrap indexer for appropriate dtype defaults.
+    np_mgrid = _indexer_with_default_outputs(np.mgrid)
+    assertAllEqual = partial(self.assertAllClose, atol=0, rtol=0)
+    assertAllEqual(np_mgrid[()], jnp.mgrid[()])
+    assertAllEqual(np_mgrid[:4], jnp.mgrid[:4])
+    assertAllEqual(np_mgrid[:4,], jnp.mgrid[:4,])
+    assertAllEqual(np_mgrid[:4], jax.jit(lambda: jnp.mgrid[:4])())
+    assertAllEqual(np_mgrid[:5, :5], jnp.mgrid[:5, :5])
+    assertAllEqual(np_mgrid[:3, :2], jnp.mgrid[:3, :2])
+    assertAllEqual(np_mgrid[1:4:2], jnp.mgrid[1:4:2])
+    assertAllEqual(np_mgrid[1:5:3, :5], jnp.mgrid[1:5:3, :5])
+    assertAllEqual(np_mgrid[:3, :2, :5], jnp.mgrid[:3, :2, :5])
+    assertAllEqual(np_mgrid[:3:2, :2, :5], jnp.mgrid[:3:2, :2, :5])
+    # Corner cases
+    assertAllEqual(np_mgrid[:], jnp.mgrid[:])
+    # When the step length is a complex number, floating-point rounding means
+    # the jnp and np values may differ slightly.
+    atol = 1e-6
+    rtol = 1e-6
+    self.assertAllClose(np_mgrid[-1:1:5j],
+                        jnp.mgrid[-1:1:5j],
+                        atol=atol,
+                        rtol=rtol)
+    self.assertAllClose(np_mgrid[3:4:7j],
+                        jnp.mgrid[3:4:7j],
+                        atol=atol,
+                        rtol=rtol)
+    self.assertAllClose(np_mgrid[1:6:8j, 2:4],
+                        jnp.mgrid[1:6:8j, 2:4],
+                        atol=atol,
+                        rtol=rtol)
+    # Non-integer steps
+    self.assertAllClose(np_mgrid[0:3.5:0.5],
+                        jnp.mgrid[0:3.5:0.5],
+                        atol=atol,
+                        rtol=rtol)
+    self.assertAllClose(np_mgrid[1.3:4.2:0.3],
+                        jnp.mgrid[1.3:4.2:0.3],
+                        atol=atol,
+                        rtol=rtol)
+    # abstract tracer value for jnp.mgrid slice
+    with self.assertRaisesRegex(core.ConcretizationTypeError,
+                                "slice start of jnp.mgrid"):
+      jax.jit(lambda a, b: jnp.mgrid[a:b])(0, 2)
+
+  def testOgrid(self):
+    # wrap indexer for appropriate dtype defaults.
+ np_ogrid = _indexer_with_default_outputs(np.ogrid) + def assertSequenceOfArraysEqual(xs, ys): + self.assertIsInstance(xs, (list, tuple)) + self.assertIsInstance(ys, (list, tuple)) + self.assertEqual(len(xs), len(ys)) + for x, y in zip(xs, ys): + self.assertArraysEqual(x, y) + + self.assertArraysEqual(np_ogrid[:5], jnp.ogrid[:5]) + self.assertArraysEqual(np_ogrid[:5], jax.jit(lambda: jnp.ogrid[:5])()) + self.assertArraysEqual(np_ogrid[1:7:2], jnp.ogrid[1:7:2]) + # List of arrays + assertSequenceOfArraysEqual(np_ogrid[:5,], jnp.ogrid[:5,]) + assertSequenceOfArraysEqual(np_ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3]) + assertSequenceOfArraysEqual(np_ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3]) + assertSequenceOfArraysEqual(np_ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11]) + # Corner cases + self.assertArraysEqual(np_ogrid[:], jnp.ogrid[:]) + # Complex number steps + atol = 1e-6 + rtol = 1e-6 + self.assertAllClose(np_ogrid[-1:1:5j], + jnp.ogrid[-1:1:5j], + atol=atol, + rtol=rtol) + # Non-integer steps + self.assertAllClose(np_ogrid[0:3.5:0.3], + jnp.ogrid[0:3.5:0.3], + atol=atol, + rtol=rtol) + self.assertAllClose(np_ogrid[1.2:4.8:0.24], + jnp.ogrid[1.2:4.8:0.24], + atol=atol, + rtol=rtol) + # abstract tracer value for ogrid slice + with self.assertRaisesRegex(core.ConcretizationTypeError, + "slice start of jnp.ogrid"): + jax.jit(lambda a, b: jnp.ogrid[a:b])(0, 2) + + def testR_(self): + a = np.arange(6).reshape((2,3)) + self.assertArraysEqual(np.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])], + jnp.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])]) + self.assertArraysEqual(np.r_['-1', a, a], jnp.r_['-1', a, a]) + + self.assertArraysEqual(np.r_['0,2', [1,2,3], [4,5,6]], jnp.r_['0,2', [1,2,3], [4,5,6]]) + self.assertArraysEqual(np.r_['0,2,0', [1,2,3], [4,5,6]], jnp.r_['0,2,0', [1,2,3], [4,5,6]]) + self.assertArraysEqual(np.r_['1,2,0', [1,2,3], [4,5,6]], jnp.r_['1,2,0', [1,2,3], [4,5,6]]) + # negative 1d axis start + self.assertArraysEqual(np.r_['0,4,-1', [1,2,3], [4,5,6]], jnp.r_['0,4,-1', [1,2,3], [4,5,6]]) + self.assertArraysEqual(np.r_['0,4,-2', [1,2,3], [4,5,6]], jnp.r_['0,4,-2', [1,2,3], [4,5,6]]) + + # matrix directives + with jtu.ignore_warning(category=PendingDeprecationWarning): + self.assertArraysEqual(np.r_['r',[1,2,3], [4,5,6]], jnp.r_['r',[1,2,3], [4,5,6]]) + self.assertArraysEqual(np.r_['c', [1, 2, 3], [4, 5, 6]], jnp.r_['c', [1, 2, 3], [4, 5, 6]]) + + # bad directive + with self.assertRaisesRegex(ValueError, "could not understand directive.*"): + jnp.r_["asdfgh",[1,2,3]] + # abstract tracer value for r_ slice + with self.assertRaisesRegex(core.ConcretizationTypeError, + "slice start of jnp.r_"): + jax.jit(lambda a, b: jnp.r_[a:b])(0, 2) + + # wrap indexer for appropriate dtype defaults. + np_r_ = _indexer_with_default_outputs(np.r_) + + # Complex number steps + atol = 1e-6 + rtol = 1e-6 + self.assertAllClose(np_r_[-1:1:6j], + jnp.r_[-1:1:6j], + atol=atol, + rtol=rtol) + with jax.numpy_dtype_promotion('standard'): # Requires dtype promotion. 
+ self.assertAllClose(np_r_[-1:1:6j, [0]*3, 5, 6], + jnp.r_[-1:1:6j, [0]*3, 5, 6], + atol=atol, + rtol=rtol) + # Non-integer steps + self.assertAllClose(np_r_[1.2:4.8:0.24], + jnp.r_[1.2:4.8:0.24], + atol=atol, + rtol=rtol) + + def testC_(self): + a = np.arange(6).reshape((2, 3)) + self.assertArraysEqual(np.c_[np.array([1,2,3]), np.array([4,5,6])], + jnp.c_[np.array([1,2,3]), np.array([4,5,6])]) + self.assertArraysEqual(np.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])], + jnp.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])]) + self.assertArraysEqual(np.c_['-1', a, a], jnp.c_['-1', a, a]) + + self.assertArraysEqual(np.c_['0,2', [1,2,3], [4,5,6]], jnp.c_['0,2', [1,2,3], [4,5,6]]) + self.assertArraysEqual(np.c_['0,2,0', [1,2,3], [4,5,6]], jnp.c_['0,2,0', [1,2,3], [4,5,6]]) + self.assertArraysEqual(np.c_['1,2,0', [1,2,3], [4,5,6]], jnp.c_['1,2,0', [1,2,3], [4,5,6]]) + # negative 1d axis start + self.assertArraysEqual(np.c_['0,4,-1', [1,2,3], [4,5,6]], jnp.c_['0,4,-1', [1,2,3], [4,5,6]]) + self.assertArraysEqual(np.c_['0,4,-2', [1,2,3], [4,5,6]], jnp.c_['0,4,-2', [1,2,3], [4,5,6]]) + # matrix directives, avoid numpy deprecation warning + with jtu.ignore_warning(category=PendingDeprecationWarning): + self.assertArraysEqual(np.c_['r',[1,2,3], [4,5,6]], jnp.c_['r',[1,2,3], [4,5,6]]) + self.assertArraysEqual(np.c_['c', [1, 2, 3], [4, 5, 6]], jnp.c_['c', [1, 2, 3], [4, 5, 6]]) + + # bad directive + with self.assertRaisesRegex(ValueError, "could not understand directive.*"): + jnp.c_["asdfgh",[1,2,3]] + # abstract tracer value for c_ slice + with self.assertRaisesRegex(core.ConcretizationTypeError, + "slice start of jnp.c_"): + jax.jit(lambda a, b: jnp.c_[a:b])(0, 2) + + # wrap indexer for appropriate dtype defaults. + np_c_ = _indexer_with_default_outputs(np.c_) + + # Complex number steps + atol = 1e-6 + rtol = 1e-6 + self.assertAllClose(np_c_[-1:1:6j], + jnp.c_[-1:1:6j], + atol=atol, + rtol=rtol) + + # Non-integer steps + self.assertAllClose(np_c_[1.2:4.8:0.24], + jnp.c_[1.2:4.8:0.24], + atol=atol, + rtol=rtol) + + def testS_(self): + self.assertEqual(np.s_[1:2:20],jnp.s_[1:2:20]) + + def testIndex_exp(self): + self.assertEqual(np.index_exp[5:3:2j],jnp.index_exp[5:3:2j]) + + @jtu.sample_product( + start_shape=[(), (2,), (2, 2)], + stop_shape=[(), (2,), (2, 2)], + num=[0, 1, 2, 5, 20], + endpoint=[True, False], + retstep=[True, False], + # floating-point compute between jitted platforms and non-jit + rounding + # cause unavoidable variation in integer truncation for some inputs, so + # we currently only test inexact 'dtype' arguments. + dtype=inexact_dtypes + [None,], + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
+ def testLinspace(self, start_shape, stop_shape, num, endpoint, retstep, dtype): + rng = jtu.rand_default(self.rng()) + # relax default tolerances slightly + tol = jtu.tolerance(dtype if dtype else np.float32) * 10 + args_maker = self._GetArgsMaker(rng, + [start_shape, stop_shape], + [dtype, dtype]) + start, stop = args_maker() + ndim = len(np.shape(start + stop)) + for axis in range(-ndim, ndim): + jnp_op = lambda start, stop: jnp.linspace( + start, stop, num, + endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) + np_op = lambda start, stop: np.linspace( + start, stop, num, + endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) + + self._CheckAgainstNumpy(np_op, jnp_op, args_maker, + check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp_op, args_maker, + check_dtypes=False, atol=tol, rtol=tol) + + @jtu.sample_product(dtype=number_dtypes) + def testLinspaceEndpoints(self, dtype): + """Regression test for Issue #3014.""" + rng = jtu.rand_default(self.rng()) + endpoints = rng((2,), dtype) + out = jnp.linspace(*endpoints, 10, dtype=dtype) + self.assertAllClose(out[np.array([0, -1])], endpoints, rtol=0, atol=0) + + @jtu.sample_product( + start_shape=[(), (2,), (2, 2)], + stop_shape=[(), (2,), (2, 2)], + num=[0, 1, 2, 5, 20], + endpoint=[True, False], + base=[10.0, 2, np.e], + # skip 16-bit floats due to insufficient precision for the test. + dtype=jtu.dtypes.inexact + [None,], + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testLogspace(self, start_shape, stop_shape, num, + endpoint, base, dtype): + if (dtype in int_dtypes and + jtu.test_device_matches(["gpu", "tpu"]) and + not config.enable_x64.value): + raise unittest.SkipTest("GPUx32 truncated exponentiation" + " doesn't exactly match other platforms.") + rng = jtu.rand_default(self.rng()) + # relax default tolerances slightly + tol = {np.float32: 1e-2, np.float64: 1e-6, np.complex64: 1e-3, np.complex128: 1e-6} + args_maker = self._GetArgsMaker(rng, + [start_shape, stop_shape], + [dtype, dtype]) + start, stop = args_maker() + ndim = len(np.shape(start + stop)) + for axis in range(-ndim, ndim): + jnp_op = lambda start, stop: jnp.logspace( + start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis) + @jtu.ignore_warning(category=RuntimeWarning, + message="overflow encountered in power") + def np_op(start, stop): + return np.logspace(start, stop, num, endpoint=endpoint, + base=base, dtype=dtype, axis=axis) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker, + check_dtypes=False, tol=tol) + if dtype in (inexact_dtypes + [None,]): + # Why do compiled and op-by-op float16 np.power numbers differ + # slightly more than expected? + atol = {np.float16: 1e-2} + self._CompileAndCheck(jnp_op, args_maker, + check_dtypes=False, atol=atol, rtol=tol) + + @jtu.sample_product( + [dict(start_shape=start_shape, stop_shape=stop_shape, axis=axis) + for start_shape in [(), (2,), (2, 2)] + for stop_shape in [(), (2,), (2, 2)] + for axis in range(-max(len(start_shape), len(stop_shape)), + max(len(start_shape), len(stop_shape))) + ], + num=[0, 1, 2, 5, 20], + endpoint=[True, False], + # NB: numpy's geomspace gives nonsense results on integer types + dtype=inexact_dtypes + [None,], + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
+ def testGeomspace(self, start_shape, stop_shape, num, + endpoint, dtype, axis): + rng = jtu.rand_default(self.rng()) + # relax default tolerances slightly + tol = {dtypes.bfloat16: 2e-2, np.float16: 4e-3, np.float32: 2e-3, + np.float64: 1e-14, np.complex64: 2e-3, np.complex128: 1e-14} + def args_maker(): + """Test the set of inputs np.geomspace is well-defined on.""" + start, stop = self._GetArgsMaker(rng, + [start_shape, stop_shape], + [dtype, dtype])() + # np.geomspace can't handle differently ranked tensors + # w. negative numbers! + start, stop = jnp.broadcast_arrays(start, stop) + if dtype in complex_dtypes: + return start, stop + # to avoid NaNs, non-complex start and stop cannot + # differ in sign, elementwise + start = start * jnp.sign(start) * jnp.sign(stop) + return start, stop + start, stop = args_maker() + def jnp_op(start, stop): + return jnp.geomspace(start, stop, num, endpoint=endpoint, dtype=dtype, + axis=axis) + def np_op(start, stop): + start = start.astype(np.float32) if dtype == jnp.bfloat16 else start + stop = stop.astype(np.float32) if dtype == jnp.bfloat16 else stop + return np.geomspace( + start, stop, num, endpoint=endpoint, + dtype=dtype if dtype != jnp.bfloat16 else np.float32, + axis=axis).astype(dtype) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker, + check_dtypes=False, tol=tol) + if dtype in (inexact_dtypes + [None,]): + self._CompileAndCheck(jnp_op, args_maker, + check_dtypes=False, atol=tol, rtol=tol) + + def testDisableNumpyRankPromotionBroadcasting(self): + with jax.numpy_rank_promotion('allow'): + jnp.ones(2) + jnp.ones((1, 2)) # works just fine + + with jax.numpy_rank_promotion('raise'): + self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2))) + jnp.ones(2) + 3 # don't want to raise for scalars + + with jax.numpy_rank_promotion('warn'): + self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on " + r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2))) + jnp.ones(2) + 3 # don't want to warn for scalars + + @unittest.skip("Test fails on CI, perhaps due to JIT caching") + def testDisableNumpyRankPromotionBroadcastingDecorator(self): + with jax.numpy_rank_promotion("allow"): + jnp.ones(2) + jnp.ones((1, 2)) # works just fine + + with jax.numpy_rank_promotion("raise"): + self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2))) + jnp.ones(2) + 3 # don't want to raise for scalars + + with jax.numpy_rank_promotion("warn"): + self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on " + r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2))) + jnp.ones(2) + 3 # don't want to warn for scalars + + def testStackArrayArgument(self): + # tests https://github.com/google/jax/issues/1271 + @jax.jit + def foo(x): + return jnp.stack(x) + foo(np.zeros(2)) # doesn't crash + + @jax.jit + def foo(x): + return jnp.concatenate(x) + foo(np.zeros((2, 2))) # doesn't crash + + def testReluGradientConstants(self): + # This is a regression test that verifies that constants associated with the + # gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the + # outermost jaxpr. This was producing some large materialized constants for + # every relu activation in a model. 
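+    # The fori_loop body below differentiates sum(maximum(z, 0.)); the assertion
+    # checks that no materialized (3, 4) constant of 2.0 ends up in jaxpr.consts.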
+ def body(i, xy): + x, y = xy + y = y + jax.grad(lambda z: jnp.sum(jnp.maximum(z, 0.)))(x) + return x, y + + f = lambda y: lax.fori_loop(0, 5, body, (y, y)) + jaxpr = jax.make_jaxpr(f)(np.zeros((3, 4), np.float32)) + self.assertFalse( + any(np.array_equal(x, np.full((3, 4), 2., dtype=np.float32)) + for x in jaxpr.consts)) + + @jtu.sample_product( + [dict(from_shape=from_shape, to_shape=to_shape) + for from_shape, to_shape in [ + [(1, 3), (4, 3)], + [(3,), (2, 1, 3)], + [(3,), (3, 3)], + [(1,), (3,)], + [(1,), 3], + ] + ], + ) + def testBroadcastTo(self, from_shape, to_shape): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [from_shape], [np.float32]) + np_op = lambda x: np.broadcast_to(x, to_shape) + jnp_op = lambda x: jnp.broadcast_to(x, to_shape) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + [dict(shapes=shapes, broadcasted_shape=broadcasted_shape) + for shapes, broadcasted_shape in [ + [[], ()], + [[()], ()], + [[(1, 3), (4, 3)], (4, 3)], + [[(3,), (2, 1, 3)], (2, 1, 3)], + [[(3,), (3, 3)], (3, 3)], + [[(1,), (3,)], (3,)], + [[(1,), 3], (3,)], + [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)], + [[[1], [0, 1]], (0, 1)], + [[(1,), np.array([0, 1])], (0, 1)], + ] + ], + ) + def testBroadcastShapes(self, shapes, broadcasted_shape): + # Test against np.broadcast_shapes once numpy 1.20 is minimum required version + np.testing.assert_equal(jnp.broadcast_shapes(*shapes), broadcasted_shape) + + def testBroadcastToIssue1522(self): + self.assertRaisesRegex( + ValueError, "Incompatible shapes for broadcasting: .*", + lambda: jnp.broadcast_to(np.ones((2, 3)), (1, 3))) + + def testBroadcastToIntIssue1548(self): + self.assertAllClose(jnp.broadcast_to(1, (3, 2)), np.ones((3, 2)), + check_dtypes=False) + + def testBroadcastToOnScalar(self): + self.assertIsInstance(jnp.broadcast_to(10.0, ()), jax.Array) + self.assertIsInstance(np.broadcast_to(10.0, ()), np.ndarray) + + def testPrecision(self): + + ones_1d = np.ones((2,)) + ones_2d = np.ones((2, 2)) + ones_3d = np.ones((2, 2, 2)) + HIGHEST = lax.Precision.HIGHEST + + jtu.assert_dot_precision(None, jnp.dot, ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.dot, precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.dot, precision=HIGHEST), + ones_3d, ones_3d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.matmul, precision=HIGHEST), + ones_2d, ones_2d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.vdot, precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.vecdot, precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.tensordot, axes=2, precision=HIGHEST), + ones_2d, ones_2d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.tensordot, axes=(0, 0), precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.tensordot, axes=((0,), (0,)), precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.einsum, 'i,i', precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.einsum, 'ij,ij', precision=HIGHEST), + ones_2d, ones_2d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.inner, precision=HIGHEST), + ones_1d, ones_1d) + + @jtu.sample_product( + funcname=['inner', 'matmul', 'dot', 'vdot', 'tensordot', 'vecdot'] + ) + def testPreferredElementType(self, funcname): + func = getattr(jnp, funcname) + 
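+    # tensordot's default axes=2 is invalid for the 1-D operands used here, so
+    # pass axes=0 (an outer product) for it; the other functions need no kwargs.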
kwargs = dict(axes=0) if funcname == 'tensordot' else {} + + ones_i32 = np.ones(2, dtype='int32') + ones_f32 = np.ones(2, dtype='float32') + + with jax.numpy_dtype_promotion('strict'): + jtu.assert_dot_preferred_element_type('int32', func, ones_i32, ones_i32, **kwargs) + jtu.assert_dot_preferred_element_type('float32', func, ones_f32, ones_f32, **kwargs) + jtu.assert_dot_preferred_element_type('bfloat16', func, ones_f32, ones_f32, **kwargs, + preferred_element_type='bfloat16') + with jax.numpy_dtype_promotion('standard'): + jtu.assert_dot_preferred_element_type('float32', func, ones_i32, ones_f32, **kwargs) + + @jtu.sample_product( + [dict(shape=shape, varargs=varargs, axis=axis) + for shape in [(10,), (10, 15), (10, 15, 20)] + for _num_axes in range(len(shape)) + for varargs in itertools.combinations(range(1, len(shape) + 1), _num_axes) + for axis in itertools.combinations(range(len(shape)), _num_axes) + ], + dtype=inexact_dtypes, + ) + def testGradient(self, shape, varargs, axis, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + jnp_fun = lambda y: jnp.gradient(y, *varargs, axis=axis) + np_fun = lambda y: np.gradient(y, *varargs, axis=axis) + self._CheckAgainstNumpy( + np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + def testZerosShapeErrors(self): + # see https://github.com/google/jax/issues/1822 + self.assertRaisesRegex( + TypeError, + "Shapes must be 1D sequences of concrete values of integer type.*", + lambda: jnp.zeros(1.)) + self.assertRaisesRegex( + TypeError, + r"Shapes must be 1D sequences of concrete values of integer type.*\n" + "If using `jit`, try using `static_argnums` or applying `jit` to " + "smaller subfunctions.", + lambda: jax.jit(jnp.zeros)(2)) + + def testTraceMethod(self): + x = self.rng().randn(3, 4).astype(jnp.float_) + self.assertAllClose(x.trace(), jnp.array(x).trace()) + self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) + + def testIntegerPowersArePrecise(self): + # See https://github.com/google/jax/pull/3036 + # Checks if the squares of float32 integers have no numerical errors. + # It should be satisfied with all integers less than sqrt(2**24). + x = jnp.arange(-2**12, 2**12, dtype=jnp.int32) + np.testing.assert_array_equal(jnp.square(x.astype(jnp.float32)), x * x) + np.testing.assert_array_equal(x.astype(jnp.float32) ** 2, x * x) + + # Similarly for cubes. 
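+    # float32 has a 24-bit significand, so x**3 is exact as long as
+    # |x|**3 <= 2**24, i.e. |x| <= 2**8.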
+ x = jnp.arange(-2**8, 2**8, dtype=jnp.int32) + np.testing.assert_array_equal(x.astype(jnp.float32) ** 3, x * x * x) + + x = np.arange(10, dtype=np.float32) + for i in range(10): + self.assertAllClose(x.astype(jnp.float32) ** i, x ** i, + check_dtypes=False) + + def testToBytes(self): + v = np.arange(12, dtype=np.int32).reshape(3, 4) + for order in ['C', 'F']: + self.assertEqual(jnp.asarray(v).tobytes(order), v.tobytes(order)) + + def testToBytesJitError(self): + v = np.arange(12, dtype=np.int32).reshape(3, 4) + f = jax.jit(lambda x: x.tobytes()) + msg = r".*The tobytes\(\) method was called on traced array" + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(v) + + def testToList(self): + v = np.arange(12, dtype=np.int32).reshape(3, 4) + self.assertEqual(jnp.asarray(v).tolist(), v.tolist()) + + def testToListJitError(self): + v = np.arange(12, dtype=np.int32).reshape(3, 4) + f = jax.jit(lambda x: x.tolist()) + msg = r".*The tolist\(\) method was called on traced array" + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(v) + + def testArangeConcretizationError(self): + msg = r"It arose in the jnp.arange argument '{}'".format + with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')): + jax.jit(jnp.arange)(3) + + with self.assertRaisesRegex(core.ConcretizationTypeError, msg('start')): + jax.jit(lambda start: jnp.arange(start, 3))(0) + + with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')): + jax.jit(lambda stop: jnp.arange(0, stop))(3) + + @jtu.sample_product(dtype=[None] + float_dtypes) + def testArange64Bit(self, dtype): + # Test that jnp.arange uses 64-bit arithmetic to define its range, even if the + # output has another dtype. The issue here is that if python scalar inputs to + # jnp.arange are cast to float32 before the range is computed, it changes the + # number of elements output by the range. It's unclear whether this was deliberate + # behavior in the initial implementation, but it's behavior that downstream users + # have come to rely on. + args = (1.2, 4.8, 0.24) + + # Ensure that this test case leads to differing lengths if cast to float32. + self.assertLen(np.arange(*args), 15) + self.assertLen(np.arange(*map(np.float32, args)), 16) + + jnp_fun = lambda: jnp.arange(*args, dtype=dtype) + np_fun = jtu.with_jax_dtype_defaults(lambda: np.arange(*args, dtype=dtype), dtype is None) + args_maker = lambda: [] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testIssue2347(self): + # https://github.com/google/jax/issues/2347 + object_list = list[tuple[jnp.array, float, float, jnp.array, bool]] + self.assertRaises(TypeError, jnp.array, object_list) + + np_object_list = np.array(object_list) + self.assertRaises(TypeError, jnp.array, np_object_list) + + @unittest.skip("JAX-metal don't support complex type yet.") + @jtu.sample_product( + [dict(shapes=shapes, dtypes=dtypes) + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(all_shapes, 2)) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, complex_dtypes + [None]) for s in shapes)) + ], + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
+ def testLogaddexpComplex(self, shapes, dtypes): + @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") + def np_op(x1, x2): + return np.log(np.exp(x1) + np.exp(x2)) + + rng = jtu.rand_some_nan(self.rng()) + args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)) + if jtu.test_device_matches(["tpu"]): + tol = {np.complex64: 1e-3, np.complex128: 1e-10} + else: + tol = {np.complex64: 1e-5, np.complex128: 1e-14} + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp.logaddexp, args_maker, tol=tol) + self._CompileAndCheck(jnp.logaddexp, args_maker, rtol=tol, atol=tol) + + @unittest.skip("JAX-metal don't support complex type yet.") + @jtu.sample_product( + [dict(shapes=shapes, dtypes=dtypes) + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(all_shapes, 2)) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, complex_dtypes + [None]) for s in shapes)) + ], + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testLogaddexp2Complex(self, shapes, dtypes): + @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") + def np_op(x1, x2): + return np.log2(np.exp2(x1) + np.exp2(x2)) + + rng = jtu.rand_some_nan(self.rng()) + args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)) + if jtu.test_device_matches(["tpu"]): + tol = {np.complex64: 1e-3, np.complex128: 1e-10} + else: + tol = {np.complex64: 1e-5, np.complex128: 1e-14} + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp.logaddexp2, args_maker, tol=tol) + self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol) + + def testDefaultDtypes(self): + precision = config.default_dtype_bits.value + assert precision in ['32', '64'] + self.assertEqual(jnp.bool_, np.bool_) + self.assertEqual(jnp.int_, np.int32 if precision == '32' else np.int64) + self.assertEqual(jnp.uint, np.uint32 if precision == '32' else np.uint64) + self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64) + self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128) + + def testFromBuffer(self): + buf = b'\x01\x02\x03' + expected = np.frombuffer(buf, dtype='uint8') + actual = jnp.frombuffer(buf, dtype='uint8') + self.assertArraysEqual(expected, actual) + + def testFromFunction(self): + def f(x, y, z): + return x + 2 * y + 3 * z + shape = (3, 4, 5) + expected = np.fromfunction(f, shape=shape) + actual = jnp.fromfunction(f, shape=shape) + self.assertArraysEqual(expected, actual, check_dtypes=False) + + def testFromString(self): + s = "1,2,3" + expected = np.fromstring(s, sep=',', dtype=int) + actual = jnp.fromstring(s, sep=',', dtype=int) + self.assertArraysEqual(expected, actual) + + @jtu.sample_product( + a_shape=nonempty_nonscalar_array_shapes, + v_shape=nonempty_shapes, + dtype=jtu.dtypes.all, + ) + def testPlace(self, a_shape, v_shape, dtype): + rng = jtu.rand_default(self.rng()) + mask_rng = jtu.rand_bool(self.rng()) + + def args_maker(): + a = rng(a_shape, dtype) + m = mask_rng(a_shape, bool) + v = rng(v_shape, dtype) + return a, m, v + + def np_fun(a, m, v): + a_copy = a.copy() + np.place(a_copy, m, v) + return a_copy + + jnp_fun = partial(jnp.place, inplace=False) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + 
a_shape=nonempty_nonscalar_array_shapes, + i_shape=all_shapes, + v_shape=all_shapes, + dtype=jtu.dtypes.all, + mode=[None, 'wrap', 'clip'], + ) + def testPut(self, mode, a_shape, i_shape, v_shape, dtype): + size = math.prod(a_shape) + if math.prod(i_shape) > size: + self.skipTest("too many indices") + rng = jtu.rand_default(self.rng()) + # Must test unique integers, because overlapping updates in + # JAX have implementation-defined order + idx_rng = jtu.rand_unique_int(self.rng(), size) + + def args_maker(): + a = rng(a_shape, dtype) + i = idx_rng(i_shape, np.int32) + v = rng(v_shape, dtype) + # put some indices out of range without duplicating indices + if mode == "clip" and i.size: + np.put(i, np.argmax(i), size + 2) + np.put(i, np.argmin(i), -2) + if mode == "wrap" and i.size: + np.put(i, 0, np.take(i, 0) + size) + return a, i, v + + def np_fun(a, i, v): + a_copy = a.copy() + np.put(a_copy, i, v, mode=mode) + return a_copy + + jnp_fun = partial(jnp.put, mode=mode, inplace=False) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def test_rot90_error(self): + with self.assertRaisesRegex( + ValueError, + "rot90 requires its first argument to have ndim at least two, " + "but got first argument of"): + jnp.rot90(jnp.ones(2)) + + @parameterized.named_parameters( + ('ones', jnp.ones), + ('zeros', jnp.zeros), + ('empty', jnp.empty)) + def test_error_hint(self, fn): + with self.assertRaisesRegex( + TypeError, + r"Did you accidentally write `jax\.numpy\..*?\(2, 3\)` " + r"when you meant `jax\.numpy\..*?\(\(2, 3\)\)`"): + fn(2, 3) + + @jtu.sample_product( + dtype=jtu.dtypes.all, + kind=['bool', 'signed integer', 'unsigned integer', 'integral', + 'real floating', 'complex floating', 'numeric'] + ) + def test_isdtype(self, dtype, kind): + # Full tests also in dtypes_test.py; here we just compare against numpy + jax_result = jnp.isdtype(dtype, kind) + if jtu.numpy_version() < (2, 0, 0) or dtype == dtypes.bfloat16: + # just a smoke test + self.assertIsInstance(jax_result, bool) + else: + numpy_result = np.isdtype(dtype, kind) + self.assertEqual(jax_result, numpy_result) + + +from jaxlib import xla_client +@unittest.skipIf(metal_plugin == None, "Tests require jax-metal plugin.") +class ReportedIssuesTests(jtu.JaxTestCase): + def dispatchOn(self, args, func, device=jax.devices('cpu')[0]): + deviceArgs = [] + for arg in args: + deviceArgs.append(jax.device_put(arg, device)) + return func(*deviceArgs) + + @staticmethod + def compile_and_exec(module, args, run_on_cpu=False): + backend = jax.lib.xla_bridge.get_backend('METAL') + if (run_on_cpu): + backend = jax.lib.xla_bridge.get_backend('cpu') + executables = backend.compile(module) + return xla_client.execute_with_python_values(executables, args, backend) + + @staticmethod + def jax_metal_supported(target_ver): + if metal_plugin is None or not hasattr(metal_plugin, 'version'): + return False + curr_ver = metal_plugin.version() + if hasattr(jtu, 'parse_version'): + return jtu.parse_version(curr_ver) >= jtu.parse_version(target_ver) + return False + + + #https://github.com/google/jax/issues/16420 + def test_broadcast_dim(self): + x = jnp.arange(2) + f = lambda x : jax.lax.broadcast_in_dim(x, (2, 2), (0,)) + res = f(x) + print(res) + res_cpu = self.dispatchOn([x],f) + jtu.check_eq(res, res_cpu) + f = lambda x : jax.lax.broadcast_in_dim(x, (2, 2), (1,)) + res = f(x) + print(res) + res_cpu = self.dispatchOn([x],f) + jtu.check_eq(res, res_cpu) + + def test_identity(self): + x = jnp.identity(4) + 
jtu.check_eq(x, np.identity(4))
+
+  def test_triu(self):
+    x = np.ones((4,4))
+    res = jnp.triu(x)
+    jtu.check_eq(res, np.triu(x))
+
+  #https://github.com/google/jax/issues/16471
+  def test_matmul_1d(self):
+    x = np.array(np.random.rand(3, 3))
+    y = np.array(np.random.rand(3))
+    z = np.array(np.random.rand(3))
+    res = jnp.dot(y, z)
+    self.assertArraysAllClose(res, np.dot(y,z))
+    res = jnp.dot(x, y)
+    self.assertArraysAllClose(res, np.dot(x,y))
+
+  #https://github.com/google/jax/issues/17175
+  def test_indexing(self):
+    x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32)
+    @jax.vmap
+    def f(i):
+      return x[i]
+    f = jax.jit(f)
+    idx = jnp.array([1,1,2,2,0])
+    res = f(idx)
+    jtu.check_eq(res, np.array([[4., 5., 6.], [4., 5., 6.], [7., 8., 9.], [7., 8., 9.], [1., 2., 3.]]))
+
+  #https://github.com/google/jax/issues/17344
+  def test_take_along_axis(self):
+    @jax.jit
+    def f():
+      idx = jnp.array([[0],[0],[0]])
+      x = jnp.array([[0.3756883, 0.05820537, 0.7399422, 0.45242703],
+                     [0.5848844, 0.18772626, 0.47942543, 0.20703673],
+                     [0.1071583, 0.26139486, 0.25664794, 0.8109596]])
+      return jnp.take_along_axis(x, idx, axis=1)
+    jtu.check_eq(f(), self.dispatchOn([], f))
+
+  #https://github.com/google/jax/issues/17590
+  def test_in1d(self):
+    a = np.array([123,2,4])
+    b = np.array([123,1])
+    res = jnp.isin(a,b)
+    jtu.check_eq(res, np.isin(a, b))
+
+  def test_indexing_update(self):
+    x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32)
+    @jax.vmap
+    def f(x):
+      return x.at[0].set(1.0)
+    f = jax.jit(f)
+    res = f(x)
+    jtu.check_eq(res, np.array([[1., 2., 3.], [1., 5., 6.,], [1., 8., 9.], [1., 11., 12.]]))
+
+  #https://github.com/google/jax/issues/16326
+  def test_indexing_update2(self):
+    @jax.jit
+    def f(x, r):
+      x = x.at[:, 0].set(x[:, 0] / r)
+      return x
+    x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
+    fx = f(x, jnp.array([10.0]))
+    jtu.check_eq(fx, np.array([[0.1, 2.0], [0.3, 4.]]))
+
+  def test_gather_ir(self):
+    ir = '''
+#loc = loc(unknown)
+module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+  func.func public @main(%arg0: tensor<3x2x3xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<3x2xi32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<3x2xf32> {
+    %0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 2], start_index_map = [0, 2], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 2, 1>} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2)
+    return %0 : tensor<3x2xf32> loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc1 = loc("/Users/shuhan/Code/jax-metal/tests/lax_numpy_indexing_test.py":1156:0)
+#loc2 = loc("jit(gather)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0, 2), start_index_map=(0, 2)) slice_sizes=(1, 2, 1) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.CLIP fill_value=None]"(#loc1))
+    '''
+    data = np.array([[[0.6369617, 0.26978672, 0.04097353],
+                      [0.01652764, 0.8132702, 0.91275555]],
+                     [[0.60663575, 0.72949654, 0.543625 ],
+                      [0.9350724, 0.81585354, 0.0027385 ]],
+                     [[0.8574043, 0.03358557, 0.72965544],
+                      [0.17565562, 0.8631789, 0.5414612 ]]], dtype=np.float32)
+    index = np.array([[1, 0],[2, 1],[0, 2]], dtype=np.int32)
+    res = ReportedIssuesTests.compile_and_exec(ir, [data, index])
+    res_ref = ReportedIssuesTests.compile_and_exec(ir, [data, index], run_on_cpu = True)
+    print(res)
+    jtu.check_eq(res, res_ref)
+
+  #https://github.com/google/jax/issues/16366
+  def test_pad_interior_1(self):
+    if not 
ReportedIssuesTests.jax_metal_supported('0.0.6'):
+      raise unittest.SkipTest("jax-metal version doesn't support it.")
+    ir = '''
+    module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+      func.func public @main(%arg0: tensor<128x7x7x64xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<f32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<128x15x15x64xf32> {
+        %206 = "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>, edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>, interior_padding = dense<[0, 1, 1, 0]> : tensor<4xi64>} : (tensor<128x7x7x64xf32>, tensor<f32>) -> tensor<128x15x15x64xf32>
+        return %206 : tensor<128x15x15x64xf32>
+      }
+    }
+    '''
+    data = np.random.rand(128,7,7,64).astype(np.float32)
+    padding = np.array(0.5, dtype=np.float32)
+    res = ReportedIssuesTests.compile_and_exec(ir, [data, padding])
+    res_ref = ReportedIssuesTests.compile_and_exec(ir, [data, padding], run_on_cpu = True)
+    jtu.check_eq(res, res_ref)
+
+  def test_pad_interior_2(self):
+    if not ReportedIssuesTests.jax_metal_supported('0.0.6'):
+      raise unittest.SkipTest("jax-metal version doesn't support it.")
+    batch = 2
+    seq_len = 8
+    num_decode = 32
+
+    seq = np.random.randint(size=(batch, seq_len, num_decode), low=0, high=256, dtype=np.uint8)
+    res = jnp.cumsum(seq, axis=-1)
+    res_ref = np.cumsum(seq, axis=-1, dtype=np.uint8)
+    jtu.check_eq(res, res_ref)
+
+  @unittest.expectedFailure
+  def test_issue_pad(self):
+    ir = '''
+    module @jit_dummy attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+      func.func public @main(%arg0: tensor<2x2xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<3x4xf32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<4x4xf32> {
+        %12 = stablehlo.slice %arg0 [0:1, 1:2] : (tensor<2x2xf32>) -> tensor<1x1xf32>
+        %13 = stablehlo.reshape %12 : (tensor<1x1xf32>) -> tensor<f32>
+        %14 = stablehlo.pad %arg1, %13, low = [0, 0], high = [1, 0], interior = [0, 0] : (tensor<3x4xf32>, tensor<f32>) -> tensor<4x4xf32>
+        return %14 : tensor<4x4xf32>
+      }
+    }
+    '''
+    data = np.array([[1, 3], [1, 3]], dtype=np.float32)
+    input = np.random.rand(3,4).astype(np.float32)
+    res = ReportedIssuesTests.compile_and_exec(ir, [data, input])
+    res_ref = ReportedIssuesTests.compile_and_exec(ir, [data, input], run_on_cpu = True)
+    jtu.check_eq(res, res_ref)
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py
index 92259c8f4342..7397cf3e4ee8 100644
--- a/tests/lax_numpy_einsum_test.py
+++ b/tests/lax_numpy_einsum_test.py
@@ -27,8 +27,7 @@
 import jax.numpy as jnp
 import jax._src.test_util as jtu
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class EinsumTest(jtu.JaxTestCase):
@@ -393,6 +392,46 @@ def test_einsum_mixed_precision(self, lhs_dtype, rhs_dtype):
     f_np = jtu.promote_like_jnp(partial(np.einsum, 'a,a->a'))
     self._CheckAgainstNumpy(f_np, f_jax, args_maker, check_dtypes=True)
 
+  @jtu.sample_product(
+    [
+      {'signature': 'i->', 'shapes': [(3,)]},
+      {'signature': 'ii->i', 'shapes': [(4, 4)]},
+      {'signature': 'ij,jk', 'shapes': [(3, 4), (4, 3)]},
+      {'signature': 'ij,jkl,klm', 'shapes': [(2, 2), (2, 3, 4), (3, 4, 2)]},
+    ],
+    optimize=[True, False, 'optimal', 'auto', 'greedy', 'eager'],
+    dtype=[np.dtype('float32')],
+  )
+  @jtu.skip_on_devices('tpu')
+  def test_einsum_optimization_modes(self, signature, shapes, optimize, dtype):
+    rng = jtu.rand_default(self.rng())
+
args_maker = lambda: [rng(shape, dtype) for shape in shapes] + jnp_fun = partial(jnp.einsum, signature, optimize=optimize) + np_fun = partial(np.einsum, signature) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=1E-4) + self._CompileAndCheck(jnp_fun, args_maker, rtol=1E-4) + + @jtu.sample_product( + [ + {'signature': 'i->', 'shapes': [(3,)]}, + {'signature': 'ii->i', 'shapes': [(4, 4)]}, + {'signature': 'ij,jk', 'shapes': [(3, 4), (4, 3)]}, + {'signature': 'ij,jkl,klm', 'shapes': [(2, 2), (2, 3, 4), (3, 4, 2)]}, + ], + optimize=[True, False, 'optimal', 'auto', 'greedy', 'eager'], + dtype=[np.dtype('float32')], + ) + @jtu.skip_on_devices('tpu') + def test_einsum_path_optimization_modes(self, signature, shapes, optimize, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype) for shape in shapes] + def jnp_fun(*args, signature=signature, optimize=optimize): + path, _ = jnp.einsum_path(signature, *args, optimize=optimize) + return jnp.einsum(signature, *args, optimize=path) + np_fun = partial(np.einsum, signature) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=1E-4) + self._CompileAndCheck(jnp_fun, args_maker, rtol=1E-4) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index f5029226fbe7..d18244062da6 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -918,6 +918,12 @@ def testSimpleIndexingUsesSlice(self): self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) + # Indexing with `Ellipsis` is not lowered to `gather`. + jaxpr = jax.make_jaxpr(lambda x: x[..., 0])(jnp.ones((3, 4, 5))) + self.assertLen((jaxpr.jaxpr.eqns), 2) + self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) + self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) + # Simple reverses lower to lax.rev_p jaxpr = jax.make_jaxpr(lambda x: x[:, ::-1])(jnp.ones((3, 4))) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) @@ -1047,6 +1053,34 @@ def testScalarBooleanIndexing(self, shape, idx): jnp_fun = lambda x: jnp.asarray(x)[idx] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + @jtu.sample_product( + shape=[(2, 3, 4, 5)], + update_ndim=[0, 1, 2], + idx=[ + np.index_exp[True], + np.index_exp[False], + np.index_exp[..., True], + np.index_exp[..., False], + np.index_exp[0, :2, True], + np.index_exp[0, :2, False], + np.index_exp[:2, 0, True], + np.index_exp[:2, 0, False], + np.index_exp[:2, np.array([0, 2]), True], + np.index_exp[np.array([1, 0]), :, True], + np.index_exp[True, :, True, :, np.array(True)], + ] + ) + def testScalarBoolUpdate(self, shape, idx, update_ndim): + update_shape = np.zeros(shape)[idx].shape[-update_ndim:] + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, np.int32), rng(update_shape, np.int32)] + def np_fun(x, update): + x = np.array(x, copy=True) + x[idx] = update + return x + jnp_fun = lambda x, update: jnp.asarray(x).at[idx].set(update) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + def testFloatIndexingError(self): BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type" with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index c1c04935f5fe..4c31684e145f 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -221,7 +221,7 @@ def 
op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []), op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [], inexact=True), - op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [], + op_record("hypot", 2, real_dtypes, all_shapes, jtu.rand_default, [], inexact=True), op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []), op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []), @@ -423,6 +423,17 @@ def _shapes_are_equal_length(shapes): return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) +def _get_testcase_name(index, params): + dtypes = "_".join(str(dt.__name__) for dt in params['dtypes']) + name = params['op_name'] if "op_name" in params else params["name"] + return f"{index}_{name}_{dtypes}" + + +def _create_named_parameters(iter_params): + for i, params in enumerate(iter_params): + yield dict(params, **{'testcase_name': _get_testcase_name(i, params)}) + + class JaxNumpyOperatorTests(jtu.JaxTestCase): """Tests for LAX-backed Numpy operators.""" @@ -436,7 +447,7 @@ def f(): for a in out] return f - @parameterized.parameters(itertools.chain.from_iterable( + @parameterized.named_parameters(_create_named_parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(op_name=rec.name, rng_factory=rec.rng_factory, check_dtypes=rec.check_dtypes, tolerance=rec.tolerance, @@ -449,7 +460,7 @@ def f(): *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))], ) for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, - JAX_COMPOUND_OP_RECORDS))) + JAX_COMPOUND_OP_RECORDS)))) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def testOp(self, op_name, rng_factory, shapes, dtypes, check_dtypes, tolerance, inexact, kwargs, alias): @@ -477,7 +488,7 @@ def testOp(self, op_name, rng_factory, shapes, dtypes, check_dtypes, self._CompileAndCheck(jnp_op, args_maker, check_dtypes=check_dtypes, atol=tol, rtol=tol) - @parameterized.parameters(itertools.chain.from_iterable( + @parameterized.named_parameters(_create_named_parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(name=rec.name, rng_factory=rec.rng_factory, tol=rec.tolerance)], [dict(shapes=shapes, dtypes=dtypes) @@ -487,7 +498,7 @@ def testOp(self, op_name, rng_factory, shapes, dtypes, check_dtypes, for dtypes in itertools.product( *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))], ) - for rec in JAX_OPERATOR_OVERLOADS)) + for rec in JAX_OPERATOR_OVERLOADS))) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): rng = rng_factory(self.rng()) @@ -498,7 +509,7 @@ def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): with jtu.strict_promotion_if_dtypes_match(dtypes): self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol) - @parameterized.parameters(itertools.chain.from_iterable( + @parameterized.named_parameters(_create_named_parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(name=rec.name, rng_factory=rec.rng_factory, op_tolerance=rec.tolerance)], @@ -509,7 +520,7 @@ def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): for dtypes in itertools.product( *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))], ) - for rec in JAX_RIGHT_OPERATOR_OVERLOADS)) + for rec in JAX_RIGHT_OPERATOR_OVERLOADS))) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes, op_tolerance): @@ -579,7 +590,7 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype): with self.assertRaises(TypeError): op(arg, other) - @parameterized.parameters(itertools.chain.from_iterable( + @parameterized.named_parameters(_create_named_parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(name=rec.name, rng_factory=rec.rng_factory, alias=rec.alias)], shapes=filter( @@ -589,7 +600,7 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype): _dtypes_are_compatible_for_bitwise_ops, itertools.combinations_with_replacement(rec.dtypes, rec.nargs)), ) - for rec in JAX_BITWISE_OP_RECORDS)) + for rec in JAX_BITWISE_OP_RECORDS))) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
def testBitwiseOp(self, name, rng_factory, shapes, dtypes, alias): np_op = getattr(np, name) if hasattr(np, name) else getattr(np, alias) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 0b8bbd79ef11..588368cd8553 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -507,14 +507,12 @@ def testReductionWithRepeatedAxisError(self): for weights_shape in ([None, shape] if axis is None or len(shape) == 1 or isinstance(axis, tuple) else [None, (shape[axis],), shape]) ], - keepdims=([False, True] if numpy_version >= (1, 23) else [None]), + keepdims=[False, True], returned=[False, True], ) def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims): rng = jtu.rand_default(self.rng()) - kwds = dict(returned=returned) - if keepdims is not None: - kwds['keepdims'] = keepdims + kwds = dict(returned=returned, keepdims=keepdims) if weights_shape is None: np_fun = lambda x: np.average(x, axis, **kwds) jnp_fun = lambda x: jnp.average(x, axis, **kwds) @@ -527,50 +525,68 @@ def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims): tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-12, np.complex64: 1e-5} check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - if numpy_version == (1, 23, 0) and keepdims and weights_shape is not None and axis is not None: - # Known failure: https://github.com/numpy/numpy/issues/21850 - pass - else: - try: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=check_dtypes, tol=tol) - except ZeroDivisionError: - self.skipTest("don't support checking for ZeroDivisionError") + try: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=check_dtypes, tol=tol) + except ZeroDivisionError: + self.skipTest("don't support checking for ZeroDivisionError") self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes, rtol=tol, atol=tol) @jtu.sample_product( + test_fns=[(np.var, jnp.var), (np.std, jnp.std)], shape=[(5,), (10, 5)], dtype=all_dtypes, out_dtype=inexact_dtypes, axis=[None, 0, -1], - ddof=[0, 1, 2], + ddof_correction=[(0, None), (1, None), (1, 0), (0, 0), (0, 1), (0, 2)], keepdims=[False, True], ) - def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): + def testStdOrVar(self, test_fns, shape, dtype, out_dtype, axis, ddof_correction, keepdims): + np_fn, jnp_fn = test_fns + ddof, correction = ddof_correction rng = jtu.rand_default(self.rng()) args_maker = self._GetArgsMaker(rng, [shape], [dtype]) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.") @jtu.ignore_warning(category=NumpyComplexWarning) def np_fun(x): + # setup ddof and correction kwargs excluding case when correction is not specified + ddof_correction_kwargs = {"ddof": ddof} + if correction is not None: + key = "correction" if numpy_version >= (2, 0) else "ddof" + ddof_correction_kwargs[key] = correction # Numpy fails with bfloat16 inputs - out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), + out = np_fn(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype, - axis=axis, ddof=ddof, keepdims=keepdims) + axis=axis, keepdims=keepdims, **ddof_correction_kwargs) return out.astype(out_dtype) - jnp_fun = partial(jnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) + jnp_fun = partial(jnp_fn, dtype=out_dtype, axis=axis, ddof=ddof, correction=correction, + keepdims=keepdims) tol = 
jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3, np.float64: 1e-3, np.complex128: 1e-6}) if (jnp.issubdtype(dtype, jnp.complexfloating) and not jnp.issubdtype(out_dtype, jnp.complexfloating)): - self.assertRaises(ValueError, lambda: jnp_fun(*args_maker())) + self.assertRaises(ValueError, jnp_fun, *args_maker()) + elif (correction is not None and ddof != 0): + self.assertRaises(ValueError, jnp_fun, *args_maker()) else: self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol) + @jtu.sample_product( + jnp_fn=[jnp.var, jnp.std], + size=[0, 1, 2] + ) + def testStdOrVarLargeDdofReturnsNan(self, jnp_fn, size): + # test for https://github.com/google/jax/issues/21330 + x = jnp.arange(size) + self.assertTrue(np.isnan(jnp_fn(x, ddof=size))) + self.assertTrue(np.isnan(jnp_fn(x, ddof=size + 1))) + self.assertTrue(np.isnan(jnp_fn(x, ddof=size + 2))) + @jtu.sample_product( shape=[(5,), (10, 5)], dtype=all_dtypes, @@ -662,6 +678,7 @@ def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, [dict(a_shape=a_shape, axis=axis) for a_shape, axis in ( ((7,), None), + ((6, 7,), None), ((47, 7), 0), ((47, 7), ()), ((4, 101), 1), @@ -716,6 +733,7 @@ def testPercentilePrecision(self): [dict(a_shape=a_shape, axis=axis) for a_shape, axis in ( ((7,), None), + ((6, 7,), None), ((47, 7), 0), ((4, 101), 1), ) @@ -768,5 +786,64 @@ def test_f16_mean(self, dtype): self.assertAllClose(expected, actual, atol=0) + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in list( + range(-len(shape), len(shape)) + ) + ([None] if len(shape) == 1 else [])], + [dict(dtype=dtype, out_dtype=out_dtype) + for dtype in (all_dtypes+[None]) + for out_dtype in ( + complex_dtypes if np.issubdtype(dtype, np.complexfloating) + else all_dtypes + ) + ], + include_initial=[False, True], + ) + @jtu.ignore_warning(category=NumpyComplexWarning) + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial): + rng = jtu.rand_some_zero(self.rng()) + + def np_mock_op(x, axis=None, dtype=None, include_initial=False): + axis = axis or 0 + out = np.cumsum(x, axis=axis, dtype=dtype or x.dtype) + if include_initial: + zeros_shape = list(x.shape) + zeros_shape[axis] = 1 + out = jnp.concat([jnp.zeros(zeros_shape, dtype=out.dtype), out], axis=axis) + return out + + + # We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as + # input because we rely on JAX-specific casting behavior + args_maker = lambda: [jnp.array(rng(shape, dtype))] + np_op = getattr(np, "cumulative_sum", np_mock_op) + kwargs = dict(axis=axis, dtype=out_dtype, include_initial=include_initial) + + np_fun = lambda x: np_op(x, **kwargs) + jnp_fun = lambda x: jnp.cumulative_sum(x, **kwargs) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + + @jtu.sample_product( + shape=filter(lambda x: len(x) != 1, all_shapes), dtype=all_dtypes, + include_initial=[False, True]) + def testCumulativeSumErrors(self, shape, dtype, include_initial): + rng = jtu.rand_some_zero(self.rng()) + x = rng(shape, dtype) + rank = jnp.asarray(x).ndim + if rank == 0: + msg = r"The input must be non-scalar to take" + with self.assertRaisesRegex(ValueError, msg): + jnp.cumulative_sum(x, include_initial=include_initial) + elif rank > 1: + msg = r"The input array has rank \d*, however" + with 
self.assertRaisesRegex(ValueError, msg): + jnp.cumulative_sum(x, include_initial=include_initial) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index e05a538e3d34..45cc177fbfd1 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -173,12 +173,35 @@ def f(): for a in out] return f + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in list(range(-len(shape), len(shape)))], + dtype=all_dtypes, + ) + def testUnstack(self, shape, axis, dtype): + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + if jnp.asarray(x).ndim == 0: + with self.assertRaisesRegex(ValueError, "Unstack requires arrays with"): + jnp.unstack(x, axis=axis) + return + y = jnp.unstack(x, axis=axis) + if shape[axis] == 0: + self.assertEqual(y, ()) + else: + self.assertArraysEqual(jnp.moveaxis(jnp.array(y), 0, axis), x) + + @parameterized.parameters( - [dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32, - jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64, - jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64, - jnp.complex64, jnp.complex128] - if dtype == dtypes.canonicalize_dtype(dtype)]) + [dtype for dtype in [ + jnp.bool, + jnp.uint4, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64, + jnp.int4, jnp.int8, jnp.int16, jnp.int32, jnp.int64, + jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64, + jnp.complex64, jnp.complex128] + if dtype == dtypes.canonicalize_dtype(dtype)]) def testDtypeWrappers(self, dtype): arr = dtype(0) self.assertIsInstance(arr, jax.Array) @@ -300,17 +323,15 @@ def testCountNonzero(self, shape, dtype, axis): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CompileAndCheck(jnp_fun, args_maker) - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) def testNonzero(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False) + self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False) @jtu.sample_product( [dict(shape=shape, fill_value=fill_value) - for shape in nonempty_array_shapes + for shape in nonempty_nonscalar_array_shapes for fill_value in [None, -1, shape or (1,)] ], dtype=all_dtypes, @@ -328,17 +349,13 @@ def np_fun(x): return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) for fval, arg in safe_zip(fillvals, result)) jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value) - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) def testFlatNonzero(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) - np_fun = jtu.ignore_warning( - category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*")(np.flatnonzero) + np_fun = np.flatnonzero jnp_fun = jnp.flatnonzero args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, 
args_maker, check_dtypes=False) @@ -348,7 +365,7 @@ def testFlatNonzero(self, shape, dtype): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - shape=nonempty_array_shapes, + shape=nonempty_nonscalar_array_shapes, dtype=all_dtypes, fill_value=[None, -1, 10, (-1,), (10,)], size=[1, 5, 10], @@ -356,7 +373,6 @@ def testFlatNonzero(self, shape, dtype): def testFlatNonzeroSize(self, shape, dtype, size, fill_value): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") def np_fun(x): result = np.flatnonzero(x) if size <= len(result): @@ -368,24 +384,20 @@ def np_fun(x): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CompileAndCheck(jnp_fun, args_maker) - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) def testArgWhere(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False) + self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False) # JIT compilation requires specifying a size statically. Full test of this # behavior is in testNonzeroSize(). jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2) - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CompileAndCheck(jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( [dict(shape=shape, fill_value=fill_value) - for shape in nonempty_array_shapes + for shape in nonempty_nonscalar_array_shapes for fill_value in [None, -1, shape or (1,)] ], dtype=all_dtypes, @@ -404,10 +416,8 @@ def np_fun(x): for fval, arg in safe_zip(fillvals, result.T)]).T jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value) - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( [dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name), @@ -568,7 +578,7 @@ def np_fun(x, y): return np.matmul(x, y).astype(dtype) args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12, - np.complex128: 1e-12} + np.complex128: 1e-12, jnp.bfloat16: 1e-1} with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol) @@ -603,7 +613,7 @@ def np_fn(x, y, axis=axis): return f(x, y, axis=axis).astype(x.dtype) jnp_fn = partial(jnp.vecdot, axis=axis) tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12, - np.complex64: 1E-3, np.complex128: 1e-12} + np.complex64: 1E-3, np.complex128: 1e-12, jnp.bfloat16: 1e-1} self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) self._CompileAndCheck(jnp_fn, args_maker, tol=tol) @@ -862,7 +872,7 @@ def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov): (np.full(1, -0.9), np.ones(1))] ], shape=all_shapes, - dtype=number_dtypes, + dtype=float_dtypes + int_dtypes + unsigned_dtypes, ) @jax.numpy_rank_promotion('allow') 
# This test explicitly exercises implicit rank promotion. @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion @@ -872,14 +882,67 @@ def testClipStaticBounds(self, shape, dtype, a_min, a_max): a_max = None if a_max is None else abs(a_max) rng = jtu.rand_default(self.rng()) np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max) - jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max) + jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CompileAndCheck(jnp_fun, args_maker) - def testClipError(self): - with self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*"): - jnp.clip(jnp.zeros((3,))) + + @jtu.sample_product( + shape=all_shapes, + dtype=default_dtypes + unsigned_dtypes, + ) + def testClipNone(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + self.assertArraysEqual(jnp.clip(x), x) + + + # TODO(micky774): Check for ValueError instead of DeprecationWarning when + # jnp.clip deprecation is completed (began 2024-4-2) and default behavior is + # Array API 2023 compliant + @jtu.sample_product(shape=all_shapes) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testClipComplexInputDeprecation(self, shape): + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype=jnp.complex64) + msg = "Complex values have no ordering and cannot be clipped" + # jit is disabled so we don't miss warnings due to caching. + with jax.disable_jit(): + with self.assertWarns(DeprecationWarning, msg=msg): + jnp.clip(x) + + with self.assertWarns(DeprecationWarning, msg=msg): + jnp.clip(x, max=x) + + x = rng(shape, dtype=jnp.int32) + with self.assertWarns(DeprecationWarning, msg=msg): + jnp.clip(x, min=-1+5j) + + with self.assertWarns(DeprecationWarning, msg=msg): + jnp.clip(x, max=jnp.array([-1+5j])) + + + # TODO(micky774): Check for ValueError instead of DeprecationWarning when + # jnp.hypot deprecation is completed (began 2024-4-2) and default behavior is + # Array API 2023 compliant + @jtu.sample_product(shape=all_shapes) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testHypotComplexInputDeprecation(self, shape): + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype=jnp.complex64) + msg = "Passing complex-valued inputs to hypot" + # jit is disabled so we don't miss warnings due to caching. 
+ with jax.disable_jit(): + with self.assertWarns(DeprecationWarning, msg=msg): + jnp.hypot(x, x) + + with self.assertWarns(DeprecationWarning, msg=msg): + y = jnp.ones_like(x) + jnp.hypot(x, y) + @jtu.sample_product( [dict(shape=shape, dtype=dtype) @@ -1234,6 +1297,22 @@ def testExtract(self, shape, dtype): args_maker = lambda: [rng(shape, jnp.float32), rng(shape, dtype)] self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker) + @jtu.sample_product(shape=nonempty_array_shapes, dtype=all_dtypes) + def testExtractSize(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, jnp.float32), rng(shape, dtype)] + def jnp_fun(condition, arr): + return jnp.extract(condition, arr, size=jnp.size(arr) - 1) + def np_fun(condition, arr): + size = jnp.size(arr) - 1 + out = np.extract(condition, arr) + result = np.zeros(np.size(arr) - 1, dtype=dtype) + size = min(len(out), size) + result[:size] = out[:size] + return result + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( [dict(ncond=ncond, nfunc=nfunc) for ncond in [1, 2, 3] @@ -1451,6 +1530,37 @@ def testCompress(self, shape, dtype, axis): jnp_fun = partial(jnp.compress, axis=axis) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in list(range(len(shape))) + ], + dtype=all_dtypes, + ) + def testCompressSize(self, shape, dtype, axis): + rng = jtu.rand_default(self.rng()) + if shape in scalar_shapes or len(shape) == 0: + cond_shape = (0,) + elif axis is None: + cond_shape = (math.prod(shape),) + else: + cond_shape = (shape[axis],) + args_maker = lambda: [rng(cond_shape, bool), rng(shape, dtype)] + + def np_fun(condition, a, axis=axis, fill_value=1): + # assuming size = a.shape[axis] + out = np.compress(condition, a, axis=axis) + result = np.full_like(a, fill_value) + result[tuple(slice(s) for s in out.shape)] = out + return result + + def jnp_fun(condition, a, axis=axis, fill_value=1): + return jnp.compress(condition, a, axis=axis, + size=a.shape[axis], fill_value=fill_value) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( shape=[(2, 3)], dtype=int_dtypes, @@ -1483,8 +1593,8 @@ def testCompressMethod(self, shape, dtype, axis): args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)] - np_fun = lambda condition, x: np.compress(condition, x, axis=axis) - jnp_fun = lambda condition, x: x.compress(condition, axis=axis) + np_fun = lambda condition, x: np.asarray(x).compress(condition, axis=axis) + jnp_fun = lambda condition, x: jnp.asarray(x).compress(condition, axis=axis) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) @jtu.sample_product( @@ -1938,9 +2048,6 @@ def np_fun(x, fill_value=fill_value): @jtu.sample_product(dtype=inexact_dtypes) def testUniqueNans(self, dtype): - if numpy_version == (1, 23, 0) and dtype == np.float16: - # https://github.com/numpy/numpy/issues/21838 - self.skipTest("Known failure on numpy 1.23.0") def args_maker(): x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan] if np.issubdtype(dtype, np.complexfloating): @@ -1960,8 +2067,6 @@ def np_fun(x): @jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False]) def testUniqueEqualNan(self, dtype, equal_nan): - if numpy_version < (1, 24, 0): - self.skipTest("np.unique equal_nan requires NumPy 1.24 or newer.") shape = (20,) rng = jtu.rand_some_nan(self.rng()) args_maker = lambda: 
[rng(shape, dtype)] @@ -2158,6 +2263,19 @@ def testEye(self, n, m, k, dtype): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + dtype=default_dtypes, + n=[0, 4], + m=[None, 0, 1, 3, 4], + k=range(-4, 4), + ) + def testEyeDynamicK(self, n, m, k, dtype): + np_fun = lambda k: np.eye(n, M=m, k=k, dtype=dtype) + jnp_fun = lambda k: jnp.eye(n, M=m, k=k, dtype=dtype) + args_maker = lambda: [k] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( dtype=default_dtypes, n=[0, 4], @@ -2537,13 +2655,18 @@ def np_fun(arg): side=['left', 'right'], dtype=number_dtypes, method=['sort', 'scan', 'scan_unrolled', 'compare_all'], + use_sorter=[True, False], ) - def testSearchsorted(self, ashape, vshape, side, dtype, method): + def testSearchsorted(self, ashape, vshape, side, dtype, method, use_sorter): rng = jtu.rand_default(self.rng()) - args_maker = lambda: [np.sort(rng(ashape, dtype)), rng(vshape, dtype)] - def np_fun(a, v): - return np.searchsorted(a, v, side=side).astype('int32') - jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side, method=method) + def args_maker(): + a = rng(ashape, dtype) + v = rng(vshape, dtype) + return (a, v, np.argsort(a)) if use_sorter else (np.sort(a), v) + def np_fun(a, v, sorter=None): + return np.searchsorted(a, v, side=side, sorter=sorter).astype('int32') + def jnp_fun(a, v, sorter=None): + return jnp.searchsorted(a, v, side=side, method=method, sorter=sorter) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @@ -2597,7 +2720,7 @@ def testSearchsortedNans(self, dtype, side, method): @jtu.sample_product( xshape=[(20,), (5, 4)], - binshape=[(1,), (5,)], + binshape=[(0,), (1,), (5,)], right=[True, False], reverse=[True, False], dtype=default_dtypes, @@ -2656,10 +2779,7 @@ def testStack(self, shape, axis, dtypes, array_input, out_dtype): else: args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - if numpy_version < (1, 24): - np_fun = jtu.promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype)) - else: - np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) + np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype) with jtu.strict_promotion_if_dtypes_match(dtypes): @@ -2686,7 +2806,7 @@ def testHVDStack(self, shape, op, dtypes, array_input, out_dtype): else: args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - if numpy_version < (1, 24) or op == "dstack": + if op == "dstack": np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype)) else: np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype, @@ -2868,6 +2988,30 @@ def testArrayCreationWithSharding(self, func, shape, dtype): out = func(**kwds, shape=shape, dtype=dtype, device=sharding) self.assertEqual(out.sharding, sharding) + @jtu.sample_product( + func=[ + lambda dtype, device: jnp.arange(5, dtype=dtype, device=device), + lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device), + ], + dtype=default_dtypes, + ) + def testArangeEyeWithDevice(self, func, dtype): + device = jax.devices()[-1] + out = func(dtype=dtype, device=device) + self.assertEqual(out.devices(), {device}) + + @jtu.sample_product( + func=[ + lambda dtype, device: jnp.arange(5, dtype=dtype, device=device), + lambda dtype, device: jnp.eye(5, 6, 
dtype=dtype, device=device), + ], + dtype=default_dtypes, + ) + def testArangeEyeWithSharding(self, func, dtype): + sharding = SingleDeviceSharding(jax.devices()[-1]) + out = func(dtype=dtype, device=sharding) + self.assertEqual(out.sharding, sharding) + @jtu.sample_product( func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], shape=array_shapes, @@ -3789,6 +3933,44 @@ def testAstype(self, from_dtype, to_dtype, use_method): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product( + from_dtype=['int32', 'float32', 'complex64'], + use_method=[True, False], + ) + def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng((3, 4), from_dtype)] + if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0 + np_op = lambda x: np.astype(x, to_dtype) + else: + np_op = lambda x: np.asarray(x).astype(to_dtype) + if use_method: + jnp_op = lambda x: jnp.asarray(x).astype(to_dtype) + else: + jnp_op = lambda x: jnp.astype(x, to_dtype) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + change_dtype=[True, False], + copy=[True, False], + ) + def testAstypeCopy(self, change_dtype, copy): + dtype = 'float32' if change_dtype else 'int32' + expect_copy = change_dtype or copy + x = jnp.arange(5, dtype='int32') + y = x.astype(dtype, copy=copy) + + self.assertEqual(y.dtype, dtype) + y.delete() + self.assertNotEqual(x.is_deleted(), expect_copy) + + def testAstypeComplexDowncast(self): + x = jnp.array(2.0+1.5j, dtype='complex64') + msg = "Casting from complex to non-complex dtypes will soon raise " + with self.assertWarns(DeprecationWarning, msg=msg): + x.astype('float32') + def testAstypeInt4(self): # Test converting from int4 to int8 x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4) @@ -4322,6 +4504,13 @@ def testTakeAlongAxisWithEmptyArgs(self): x = jnp.ones((4, 0, 3), dtype=jnp.int32) np.testing.assert_array_equal(x, jnp.take_along_axis(x, x, axis=1)) + def testTakeAlongAxisOptionalArgs(self): + x = jnp.arange(5.0) + ind = jnp.array([0, 2, 4, 6]) + expected = jnp.array([0.0, 2.0, 4.0, 10.0], dtype=x.dtype) + actual = jnp.take_along_axis(x, ind, axis=None, mode='fill', fill_value=10.0) + self.assertArraysEqual(expected, actual) + @jtu.sample_product( dtype=inexact_dtypes, shape=[0, 5], @@ -4396,23 +4585,20 @@ def args_maker(): return [] self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - shape=all_shapes, dtype=all_dtypes, + shape=nonzerodim_shapes, + dtype=all_dtypes, ) def testWhereOneArgument(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False) + self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False) # JIT compilation requires specifying a size statically. Full test of # this behavior is in testNonzeroSize(). 
jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2) - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CompileAndCheck(jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( shapes=filter(_shapes_are_broadcast_compatible, @@ -4432,18 +4618,12 @@ def testWhereExtraCode(self): def f(x): return jnp.where(x > 0, x, -x) + jaxpr = jax.make_jaxpr(jax.grad(f))(3.) # Test no comparison literal True/False in jaxpr, and hence no comparison to # literals - jaxpr = jax.make_jaxpr(jax.grad(f))(3.) self.assertNotIn('False', str(jaxpr)) self.assertNotIn('True', str(jaxpr)) - # But if we set the option off, we get the old behavior. - with config.new_select_transpose(False): - jaxpr = jax.make_jaxpr(jax.grad(f))(3.) - self.assertIn('False', str(jaxpr)) - self.assertIn('True', str(jaxpr)) - def testWhereScalarPromotion(self): x = jnp.where(jnp.array([True, False]), 3, jnp.ones((2,), dtype=jnp.float32)) @@ -4460,6 +4640,7 @@ def testWhereScalarPromotion(self): # maximal set of dtypes. dtypes=itertools.combinations_with_replacement(all_dtypes, 3), ) + @jax.numpy_rank_promotion('allow') def testSelect(self, n, shapes, dtypes): dtypes = dtypes[:n+1] rng = jtu.rand_default(self.rng()) @@ -4639,10 +4820,19 @@ def testArangeJit(self): expected = jtu.with_jax_dtype_defaults(np.arange)(5) self.assertAllClose(ans, expected) - @jtu.sample_product(args=[(5,), (0, 5)]) - def testArangeJaxpr(self, args): - jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args))() - self.assertEqual(len(jaxpr.jaxpr.eqns), 1) + @jtu.sample_product( + args=[(5,), (0, 5)], + specify_device=[True, False], + ) + def testArangeJaxpr(self, args, specify_device): + device = jax.devices()[-1] if specify_device else None + kwargs = {"device": device} + jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args, **kwargs))() + # We have 2 statements in jaxpr: + # [a:i32[5] = iota[dimension=0 dtype=int32 shape=(5,)], + # a:i32[5] = device_put[devices=[None] srcs=[None]] b] + num_eqs = 2 if device is not None else 1 + self.assertEqual(len(jaxpr.jaxpr.eqns), num_eqs) self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) def testIssue830(self): @@ -5102,8 +5292,11 @@ def np_op(start, stop): start, stop, num, endpoint=endpoint, dtype=dtype if dtype != jnp.bfloat16 else np.float32, axis=axis).astype(dtype) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker, - check_dtypes=False, tol=tol) + + # JAX follows NumPy 2.0 semantics for complex geomspace. 
+ if not (jtu.numpy_version() < (2, 0, 0) and dtypes.issubdtype(dtype, jnp.complexfloating)): + self._CheckAgainstNumpy(np_op, jnp_op, args_maker, + check_dtypes=False, tol=tol) if dtype in (inexact_dtypes + [None,]): self._CompileAndCheck(jnp_op, args_maker, check_dtypes=False, atol=tol, rtol=tol) @@ -5182,6 +5375,13 @@ def testBroadcastTo(self, from_shape, to_shape): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + def testBroadcastToInvalidShape(self): + # Regression test for https://github.com/google/jax/issues/20533 + x = jnp.zeros((3, 4, 5)) + with self.assertRaisesRegex( + ValueError, "Cannot broadcast to shape with fewer dimensions"): + jnp.broadcast_to(x, (4, 5)) + @jtu.sample_product( [dict(shapes=shapes, broadcasted_shape=broadcasted_shape) for shapes, broadcasted_shape in [ @@ -5576,6 +5776,37 @@ def test_isdtype(self, dtype, kind): numpy_result = np.isdtype(dtype, kind) self.assertEqual(jax_result, numpy_result) + @jtu.sample_product( + [dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis) + for yshape, xshape, dx, axis in [ + ((10,), None, 1.0, -1), + ((3, 10), None, 2.0, -1), + ((3, 10), None, 3.0, -0), + ((10, 3), (10,), 1.0, -2), + ((3, 10), (10,), 1.0, -1), + ((3, 10), (3, 10), 1.0, -1), + ((2, 3, 10), (3, 10), 1.0, -2), + ] + ], + dtype=float_dtypes + int_dtypes, + ) + @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test. + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def test_trapezoid(self, yshape, xshape, dtype, dx, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None] + if jtu.numpy_version() >= (2, 0, 0): + np_fun = partial(np.trapezoid, dx=dx, axis=axis) + else: + np_fun = partial(np.trapz, dx=dx, axis=axis) + jnp_fun = partial(jnp.trapezoid, dx=dx, axis=axis) + tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12, + jax.dtypes.bfloat16: 4e-2}) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol, + check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol, + check_dtypes=False) + # Most grad tests are at the lax level (see lax_test.py), but we add some here # as needed for e.g. particular compound ops of interest. 
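A minimal standalone sketch of the integration pattern the trapezoid test above exercises, assuming a JAX release that provides jnp.trapezoid (the API under test here) and a NumPy where np.trapezoid exists only from 2.0 onward, with np.trapz as the older spelling; shapes and tolerances are illustrative only.

import numpy as np
import jax.numpy as jnp

# Sample a simple quadratic on [0, 1] and integrate it with the trapezoidal rule.
x = np.linspace(0.0, 1.0, 101, dtype=np.float32)
y = x ** 2

# NumPy 2.0 renamed np.trapz to np.trapezoid; fall back on older releases.
np_result = np.trapezoid(y, x) if hasattr(np, "trapezoid") else np.trapz(y, x)
jnp_result = jnp.trapezoid(jnp.asarray(y), jnp.asarray(x))

# Both should be close to the exact integral 1/3.
assert np.allclose(np_result, jnp_result, rtol=1e-5)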
@@ -5736,14 +5967,14 @@ def testWrappedSignaturesMatch(self): 'argpartition': ['kind', 'order'], 'asarray': ['like'], 'broadcast_to': ['subok'], - 'clip': ['kwargs'], + 'clip': ['kwargs', 'out'], 'copy': ['subok'], 'corrcoef': ['ddof', 'bias', 'dtype'], 'cov': ['dtype'], 'empty_like': ['subok', 'order'], 'einsum': ['kwargs'], 'einsum_path': ['einsum_call'], - 'eye': ['device', 'order', 'like'], + 'eye': ['order', 'like'], 'hstack': ['casting'], 'identity': ['like'], 'isin': ['kind'], @@ -5763,26 +5994,30 @@ def testWrappedSignaturesMatch(self): 'partition': ['kind', 'order'], 'percentile': ['weights'], 'quantile': ['weights'], + 'reshape': ['shape', 'copy'], 'row_stack': ['casting'], 'stack': ['casting'], - 'std': ['correction', 'mean'], + 'std': ['mean'], 'tri': ['like'], - 'var': ['correction', 'mean'], + 'var': ['mean'], 'vstack': ['casting'], 'zeros_like': ['subok', 'order'] } extra_params = { + # TODO(micky774): Remove when np.clip has adopted the Array API 2023 + # standard + 'clip': ['x', 'max', 'min'], 'einsum': ['subscripts', 'precision'], 'einsum_path': ['subscripts'], - 'take_along_axis': ['mode'], + 'take_along_axis': ['mode', 'fill_value'], 'fill_diagonal': ['inplace'], } mismatches = {} for name, (jnp_fun, np_fun) in func_pairs.items(): - if numpy_version >= (1, 24) and name in ['histogram', 'histogram2d', 'histogramdd']: + if name in ['histogram', 'histogram2d', 'histogramdd']: # numpy 1.24 re-orders the density and weights arguments. # TODO(jakevdp): migrate histogram APIs to match newer numpy versions. continue @@ -5879,32 +6114,39 @@ class NumpyDocTests(jtu.JaxTestCase): def test_lax_numpy_docstrings(self): # Test that docstring wrapping & transformation didn't fail. - # Functions that have their own docstrings & don't wrap numpy. - known_exceptions = {'fromfile', 'fromiter', 'frompyfunc', 'vectorize'} + unimplemented = ['fromfile', 'fromiter'] + aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', + 'amax', 'amin'] for name in dir(jnp): - if name in known_exceptions or name.startswith('_'): + if name.startswith('_') or name in unimplemented: continue - # We only check signatures of functions. obj = getattr(jnp, name) - if isinstance(obj, type) or not callable(obj): - continue - - # Some jnp functions are imported from numpy or jax.dtypes directly. - if any(obj is getattr(mod, obj.__name__, None) for mod in [np, dtypes]): - continue - wrapped_fun = obj.__np_wrapped__ - if wrapped_fun is None: - continue - - # If the wrapped function has a docstring, obj should too - if wrapped_fun.__doc__ and not obj.__doc__: - raise Exception(f"jnp.{name} does not contain wrapped docstring.") - - if obj.__doc__ and "*Original docstring below.*" not in obj.__doc__: - raise Exception(f"jnp.{name} does not have a wrapped docstring.") + if isinstance(obj, type) or not callable(obj): + # Skip docstring checks for non-functions + pass + elif hasattr(np, name) and obj is getattr(np, name): + # Some APIs are imported directly from NumPy; we don't check these. + pass + elif hasattr(obj, '__np_wrapped__'): + # Functions decorated with @implements(...) 
should have __np_wrapped__ + wrapped_fun = obj.__np_wrapped__ + if wrapped_fun is not None: + # If the wrapped function has a docstring, obj should too + if wrapped_fun.__doc__ and not obj.__doc__: + raise Exception(f"jnp.{name} does not contain wrapped docstring.") + if obj.__doc__ and "*Original docstring below.*" not in obj.__doc__: + raise Exception(f"jnp.{name} does not have a wrapped docstring.") + elif name in aliases: + assert "Alias of" in obj.__doc__ + else: + # Other functions should have nontrivial docs including "Args" and "Returns". + doc = obj.__doc__ + self.assertNotEmpty(doc) + self.assertIn("Args:", doc, msg=f"'Args:' not found in docstring of jnp.{name}") + self.assertIn("Returns:", doc, msg=f"'Returns:' not found in docstring of jnp.{name}") @parameterized.named_parameters( {"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False]) diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 0f40e9d4d97e..40c9eb3bc4f4 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -24,8 +24,7 @@ from jax._src import test_util as jtu from jax._src.numpy.ufunc_api import get_if_single_primitive -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def scalar_add(x, y): diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py index cb0d9a0dcf64..edc344467d7c 100644 --- a/tests/lax_numpy_vectorize_test.py +++ b/tests/lax_numpy_vectorize_test.py @@ -21,8 +21,7 @@ from jax import numpy as jnp from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class VectorizeTest(jtu.JaxTestCase): diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 027cce44797f..38607cae883b 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -23,11 +23,11 @@ import scipy.special as osp_special import jax +from jax._src import deprecations from jax._src import test_util as jtu from jax.scipy import special as lsp_special -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)] @@ -53,10 +53,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t JAX_SPECIAL_FUNCTION_RECORDS = [ op_record( - "beta", 2, float_dtypes, jtu.rand_positive, False + "beta", 2, float_dtypes, jtu.rand_default, False ), op_record( - "betaln", 2, float_dtypes, jtu.rand_positive, False + "betaln", 2, float_dtypes, jtu.rand_default, False ), op_record( "betainc", 3, float_dtypes, jtu.rand_positive, False @@ -73,6 +73,9 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t op_record( "gammaincc", 2, float_dtypes, jtu.rand_positive, True ), + op_record( + "gammasgn", 1, float_dtypes, jtu.rand_default, True + ), op_record( "erf", 1, float_dtypes, jtu.rand_small_positive, True ), @@ -114,7 +117,7 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t ), op_record( "ndtri", 1, float_dtypes, - functools.partial(jtu.rand_uniform, low=0.05, high=0.95), True, + functools.partial(jtu.rand_uniform, low=0.0, high=1.0), True, ), op_record( "ndtr", 1, float_dtypes, jtu.rand_default, True @@ -145,7 +148,12 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t "rel_entr", 2, float_dtypes, jtu.rand_positive, True, ), op_record("poch", 2, 
float_dtypes, jtu.rand_positive, True), - op_record("hyp1f1", 3, float_dtypes, functools.partial(jtu.rand_uniform, low=0.5, high=30), True) + op_record( + "hyp1f1", 3, float_dtypes, + functools.partial(jtu.rand_uniform, low=0.5, high=30), True + ), + op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True), + op_record("softmax", 1, float_dtypes, jtu.rand_default, True), ] @@ -218,6 +226,44 @@ def testGammaSign(self): self._CheckAgainstNumpy(osp_special.gamma, lsp_special.gamma, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.gamma, args_maker, rtol=rtol) + def testNdtriExtremeValues(self): + # Testing at the extreme values (bounds (0. and 1.) and outside the bounds). + dtype = jax.numpy.zeros(0).dtype # default float dtype. + args_maker = lambda: [np.arange(-10, 10).astype(dtype)] + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 + self._CheckAgainstNumpy(osp_special.ndtri, lsp_special.ndtri, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.ndtri, args_maker, rtol=rtol) + + def testRelEntrExtremeValues(self): + # Testing at the extreme values (bounds (0. and 1.) and outside the bounds). + dtype = jax.numpy.zeros(0).dtype # default float dtype. + args_maker = lambda: [np.array([-2, -2, -2, -1, -1, -1, 0, 0, 0]).astype(dtype), + np.array([-1, 0, 1, -1, 0, 1, -1, 0, 1]).astype(dtype)] + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 + self._CheckAgainstNumpy(osp_special.rel_entr, lsp_special.rel_entr, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.rel_entr, args_maker, rtol=rtol) + + def testBetaParameterDeprecation(self): + with self.assertNoWarnings(): + lsp_special.beta(1, 1) + lsp_special.beta(1, b=1) + lsp_special.beta(a=1, b=1) + if deprecations.is_accelerated('jax-scipy-beta-args'): + with self.assertRaises(ValueError): + lsp_special.beta(x=1, y=1) + else: + with self.assertWarns(DeprecationWarning): + lsp_special.beta(1, y=1) + with self.assertWarns(DeprecationWarning): + lsp_special.beta(a=1, y=1) + with self.assertWarns(DeprecationWarning): + lsp_special.beta(x=1, b=1) + with self.assertWarns(DeprecationWarning): + lsp_special.beta(x=1, y=1) + with self.assertRaises(TypeError), self.assertWarns(DeprecationWarning): + lsp_special.beta(1, x=1) + with self.assertRaises(TypeError), self.assertWarns(DeprecationWarning): + lsp_special.beta(b=1, y=1) if __name__ == "__main__": diff --git a/tests/lax_scipy_spectral_dac_test.py b/tests/lax_scipy_spectral_dac_test.py index 2d353d5909e2..a09dcac5371c 100644 --- a/tests/lax_scipy_spectral_dac_test.py +++ b/tests/lax_scipy_spectral_dac_test.py @@ -14,6 +14,7 @@ import unittest +import jax from jax import lax from jax import numpy as jnp from jax._src import test_util as jtu @@ -21,8 +22,7 @@ from absl.testing import absltest -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() linear_sizes = [16, 97, 128] diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 78a24d1f349b..ab8c03c18a25 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -34,8 +34,7 @@ from jax.scipy import special as lsp_special from jax.scipy import cluster as lsp_cluster -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() scipy_version = jtu.parse_version(scipy.version.version) @@ -190,6 +189,19 @@ def testLogSumExpNans(self): result = lsp_special.logsumexp(1.0, b=1.0) self.assertEqual(result, 1.0) + @jtu.sample_product( + shape=[(0,), (1,), (2,), (3,), (4,), (5,)], + dtype=float_dtypes, + ) + def 
testLogSumExpWhere(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + rng = jtu.rand_bool(self.rng()) + mask = rng(shape, bool) + y_expected = osp_special.logsumexp(x[mask]) if mask.any() else -jnp.inf + y_actual = lsp_special.logsumexp(x, where=mask) + self.assertAllClose(y_expected, y_actual, check_dtypes=False) + @jtu.sample_product( shape=all_shapes, dtype=float_dtypes, @@ -327,8 +339,8 @@ def testLpmn(self, l_max, shape, dtype): def scipy_fun(z, m=l_max, n=l_max): # scipy only supports scalar inputs for z, so we must loop here. - vals, derivs = zip(*(osp_special.lpmn(m, n, zi) for zi in z)) - return np.dstack(vals), np.dstack(derivs) + vals, derivs = zip(*(osp_special.lpmn(m, n, zi) for zi in z.astype('float64'))) + return np.dstack(vals).astype(z.dtype), np.dstack(derivs).astype(z.dtype) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-5, atol=3e-3, check_dtypes=False) @@ -348,7 +360,7 @@ def testNormalizedLpmnValues(self, l_max, shape, dtype): def scipy_fun(z, m=l_max, n=l_max): # scipy only supports scalar inputs for z, so we must loop here. - vals, _ = zip(*(osp_special.lpmn(m, n, zi) for zi in z)) + vals, _ = zip(*(osp_special.lpmn(m, n, zi) for zi in z.astype('float64'))) a = np.dstack(vals) # apply the normalization @@ -360,7 +372,7 @@ def scipy_fun(z, m=l_max, n=l_max): c1 = (4.0 * np.pi) * osp_special.factorial(l + m) c2 = np.sqrt(c0 / c1) a_normalized[m, l] = c2 * a[m, l] - return a_normalized + return a_normalized.astype(z.dtype) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-5, atol=1e-5, check_dtypes=False) diff --git a/tests/lax_test.py b/tests/lax_test.py index aadac1d64566..ce1a2d4ff897 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -45,9 +45,8 @@ from jax._src.interpreters import pxla from jax._src.internal_test_util import lax_test_util from jax._src.lax import lax as lax_internal -from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version -from jax._src.util import NumpyComplexWarning +from jax._src.util import NumpyComplexWarning, safe_zip +from jax._src.tree_util import tree_map config.parse_flags_with_absl() @@ -70,6 +69,22 @@ (np.int32, np.float32), (np.int32, np.float64), (np.int64, np.float64)] +def _reduce_custom_add(x, y): + return x + y + +def _reduce_custom_mul(x, y): + return x * y + +def _reduce_custom_sub(x, y): + return x - y + +def _reduce_custom_min(x, y): + return jnp.minimum(x, y) + +def _reduce_custom_max(x, y): + return jnp.maximum(x, y) + + class LaxTest(jtu.JaxTestCase): """Numerical tests for LAX operations.""" @@ -1059,6 +1074,40 @@ def testDotPreferredElement(self, lhs_shape, rhs_shape, dtype, y.astype(preferred_element_type)) self.assertArraysAllClose(result_with_preferred_type, result_with_upcast_inputs) + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) + for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]], + [dict(dtype_lhs=dtype_lhs, dtype_rhs=dtype_rhs) + for dtype_lhs, dtype_rhs in [(dtypes.float8_e4m3fn, dtypes.float8_e5m2), + (dtypes.float8_e5m2, dtypes.float8_e4m3fn), + (dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz), + (dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)]], + ) + def test_mixed_fp8_dot_general(self, lhs_shape, rhs_shape, dtype_lhs, dtype_rhs): + if jtu.test_device_matches(["tpu"]): + raise SkipTest("Mixed fp8 precision matmul is not yet supported on TPU") + if not jtu.is_device_rocm() and ( + dtype_lhs in [dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz] or + dtype_rhs in 
[dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz] + ): + raise SkipTest("float8_e4m3fnuz and float8_e5m2fnuz types are only supported on ROCm") + rng = jtu.rand_default(self.rng()) + lhs = rng(lhs_shape, dtype=dtype_lhs) + rhs = rng(rhs_shape, dtype=dtype_rhs) + dot_general_result = lax.dot( + lhs, rhs, + preferred_element_type=jnp.float32 + ) + + lhs_upcasted = lhs.astype(jnp.float32) + rhs_upcasted = rhs.astype(jnp.float32) + dot_general_result_upcasted = lax.dot( + lhs_upcasted, rhs_upcasted, + preferred_element_type=jnp.float32 + ) + self.assertArraysAllClose( + dot_general_result, dot_general_result_upcasted, rtol=1e-3, atol=1e-3) + @jtu.sample_product( [ dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) @@ -1074,7 +1123,8 @@ def testDotAgainstNumpy(self, lhs_shape, rhs_shape, dtype): np.float16: 1e-2, np.float64: max(jtu.default_tolerance()[np.dtype(np.float64)], 1e-14), np.complex128: max(jtu.default_tolerance()[np.dtype(np.complex128)], - 1e-14) + 1e-14), + jnp.bfloat16: 1e-1 } lax_op = partial(lax.dot, precision=lax.Precision.HIGHEST) self._CheckAgainstNumpy(lax_reference.dot, lax_op, args_maker, tol=tol) @@ -1161,6 +1211,33 @@ def testDotGeneralAgainstNumpy(self, lhs_shape, rhs_shape, dtype, numpy_op = lambda x, y: lax_reference.dot_general(x, y, dimension_numbers) self._CheckAgainstNumpy(numpy_op, op, args_maker) + @jtu.sample_product( + [ + {'m': 5, 'k': 4, 'n': 3, 'num_groups': 1}, + {'m': 10, 'k': 9, 'n': 8, 'num_groups': 2}, + ], + dtype=jtu.dtypes.numeric, + ) + def testRaggedDot(self, m, k, n, num_groups, dtype): + """Tests ragged_dot. + + The ragged_dot is tested against numpy reference implementation, and by running JAX compilation. + + Raises: + SkipTest: in the case dtype is not supported. + """ + lhs_shape = (m, k) + rhs_shape = (num_groups, k, n) + def group_sizes(m, num_groups): + ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1)) + ends = jnp.concatenate([ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)]) + starts = jnp.concatenate([jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final]) + return ends - starts + rng = jtu.rand_small(self.rng()) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype), group_sizes(m, num_groups)] + self._CompileAndCheck(lax.ragged_dot, args_maker) + self._CheckAgainstNumpy(lax_reference.ragged_dot, lax.ragged_dot, args_maker) + @jtu.sample_product( shape=[(), (2, 3)], dtype=lax_test_util.default_dtypes, @@ -1286,7 +1363,7 @@ def testSqueeze(self, arg_shape, dimensions): numpy_op = lambda x: lax_reference.squeeze(x, dimensions) self._CompileAndCheck(op, args_maker) self._CheckAgainstNumpy(numpy_op, op, args_maker) - check_grads(op, args_maker(), 2, ["fwd", "rev"], eps=1.) + check_grads(op, args_maker(), 3, ["fwd", "rev"], eps=1.) @jtu.sample_product( input_type=["np.array", "jnp.array", "float", "np.float32"], @@ -1800,35 +1877,287 @@ def reference_fun(operand, init_val): # we separately test the version that uses a concrete init_val because it # can hit different code paths def fun(operand): - return lax.reduce_window(operand, init_val, op, dims, strides, padding, - base_dilation, window_dilation) + return lax.reduce_window( + operand, + init_val, + op, + dims, + strides, + padding, + base_dilation, + window_dilation, + ) args_maker = lambda: [rng(shape, dtype)] self._CompileAndCheck(fun, args_maker) + # TODO(voz): I broke these out to their own test for 2 reasons: + # 1. 
I wanted to show that general ops work, there's a small subset of + # ops, specifically, the ones used in the test above, lax.add, lax.max, and + # lax.min that actually route to a monoid operator that *doesn't* pass JVP + # tests. + # 2. Slightly different parameterization. @jtu.sample_product( - [dict(shape=shape, dims=dims, strides=strides, padding=padding, - base_dilation=base_dilation, window_dilation=window_dilation) - for shape, dims, strides, padding, base_dilation, window_dilation in ( - itertools.chain( - itertools.product( - [(4, 6)], - [(2, 1), (1, 2)], - [(1, 1), (2, 1), (1, 2)], - ["VALID", "SAME", [(0, 3), (1, 2)]], - [(1, 1), (2, 3)], - [(1, 1), (1, 2)]), - itertools.product( - [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], - [(1, 2, 2, 1), (1, 1, 1, 1)], - ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]], - [(1, 1, 1, 1), (2, 1, 3, 2)], - [(1, 1, 1, 1), (1, 2, 2, 1)]))) - ], - dtype=[np.float32], + [ + dict(init_val=init_val, op=op, dtype=dtype) + for init_val, op, dtypes in [ + (1, _reduce_custom_add, [np.float32]), + (0, _reduce_custom_mul, [np.float32]), + (0, _reduce_custom_sub, [np.float32]), + ] + for dtype in dtypes + ], + [ + dict( + shape=shape, + dims=dims, + strides=strides, + padding=padding, + base_dilation=base_dilation, + window_dilation=window_dilation, + ) + for shape, dims, strides, padding, base_dilation, window_dilation in ( + itertools.chain( + itertools.product( + [(4, 6)], + [(2, 1), (1, 2)], + [(1, 1), (2, 1), (1, 2)], + ['VALID', 'SAME', [(0, 3), (1, 2)]], + [(1, 1), (2, 3)], + [(1, 1), (1, 2)], + ), + itertools.product( + [(3, 2, 4, 6)], + [(1, 1, 2, 1), (2, 1, 2, 1)], + [(1, 2, 2, 1), (1, 1, 1, 1)], + ['VALID', 'SAME', [(0, 1), (1, 0), (2, 3), (0, 2)]], + [(1, 1, 1, 1), (2, 1, 3, 2)], + [(1, 1, 1, 1), (1, 2, 2, 1)], + ), + ) + ) + ], + ) + @jtu.skip_on_devices('gpu') # jax.lax.mul has an XLA bug on GPU b/339071103 + @jtu.skip_on_devices('tpu') # b/39342488 + def testReduceWindowGeneralJVP( + self, + op, + init_val, + dtype, + shape, + dims, + strides, + padding, + base_dilation, + window_dilation, + ): + rng = jtu.rand_small(self.rng()) + init_val = np.asarray(init_val, dtype=dtype) + + def fun(operand, init_val): + return lax.reduce_window( + operand, + init_val, + op, + dims, + strides, + padding, + base_dilation, + window_dilation, + ) + + args_maker = lambda: [rng(shape, dtype), init_val] + self._CompileAndCheck(fun, args_maker) + args = args_maker() + init_val = args[1] + + # we separately test the version that uses a concrete init_val because it + # can hit different code paths + def fun2(operand): + return lax.reduce_window( + operand, + init_val, + op, + dims, + strides, + padding, + base_dilation, + window_dilation, + ) + + args_maker = lambda: [rng(shape, dtype)] + self._CompileAndCheck(fun2, args_maker) + + operand = args_maker()[0] + jtu.check_jvp(fun2, partial(jax.jvp, fun2), (operand,)) + check_grads(fun2, (operand,), 3, ["fwd"], eps=1.) 
+ + @jtu.sample_product( + [ + dict(init_val=init_val, op=op, dtype=dtype) + for init_val, op, dtypes in [ + (-np.inf, lax.max, [np.float32]), + (np.inf, lax.min, [np.float32]), + (0, lax.add, [np.float32]), + ] + for dtype in dtypes + ], + [ + dict( + shape=shape, + dims=dims, + strides=strides, + padding=padding, + base_dilation=base_dilation, + window_dilation=window_dilation, + ) + for shape, dims, strides, padding, base_dilation, window_dilation in ( + itertools.chain( + itertools.product( + [(4, 6)], + [(2, 1), (1, 2)], + [(1, 1), (2, 1), (1, 2)], + ['VALID', 'SAME', [(0, 3), (1, 2)]], + [(1, 1), (2, 3)], + [(1, 1), (1, 2)], + ), + itertools.product( + [(3, 2, 4, 6)], + [(1, 1, 2, 1), (2, 1, 2, 1)], + [(1, 2, 2, 1), (1, 1, 1, 1)], + ['VALID', 'SAME', [(0, 1), (1, 0), (2, 3), (0, 2)]], + [(1, 1, 1, 1), (2, 1, 3, 2)], + [(1, 1, 1, 1), (1, 2, 2, 1)], + ), + ) + ) + ], ) + @jtu.skip_on_devices('gpu') # jax.lax.mul has an XLA bug on GPU b/339071103 + @jtu.skip_on_devices('tpu') # b/39342488 + def testReduceWindowCustomSameAsMonoid( + self, + op, + init_val, + dtype, + shape, + dims, + strides, + padding, + base_dilation, + window_dilation, + ): + rng = jtu.rand_small(self.rng()) + init_val = np.asarray(init_val, dtype=dtype) + + def fun(op_, operand_): + return lax.reduce_window( + operand_, + init_val, + op_, + dims, + strides, + padding, + base_dilation, + window_dilation, + ) + + args_maker = lambda: [rng(shape, dtype)] + args = args_maker() + operand = args[0] + rng = np.random.RandomState(0) + tangent = tree_map(partial(jtu.rand_like, rng), operand) + + # There are "special" paths for "monoid" ops that have + # their jvp defined separately, either for legacy reasons + # or for optimization - compare across both and prove + # that their jvp is the same. + # TODO(voz): Look into the "monoid" paths and collapse them as necessary. + # Especially when we go to add support for (1) recursive is_jvp (hessians), + # and (2) transpose? + custom_equiv = { + lax.max: _reduce_custom_max, + lax.min: _reduce_custom_min, + lax.add: _reduce_custom_add, + } + custom_op = custom_equiv[op] + custom_primals, custom_tangents = jax.jvp( + partial(fun, custom_op), + primals=(operand,), + tangents=(tangent,), + ) + lax_primals, lax_tangents = jax.jvp( + partial(fun, op), + primals=(operand,), + tangents=(tangent,), + ) + # tol = 1e-4 + # None is sane defaults, but useful to have here for debugging. + tol = None + jtu.check_close( + lax_primals, + custom_primals, + atol=tol, + rtol=tol, + err_msg='Mismatched primal', + ) + jtu.check_close( + lax_tangents, + custom_tangents, + atol=tol, + rtol=tol, + err_msg='Mismatched tangents', + ) + # Numerical jvp comparison for min and max values + # does not work - the underlying implementation of the test util + # nans on infs. + if init_val.item() in (np.inf, -np.inf): + return + op_bound_fn = partial(fun, op) + jtu.check_jvp( + op_bound_fn, + partial(jax.jvp, op_bound_fn), + (operand,), + ) + check_grads(partial(fun, op), [operand], 3, ["fwd"], eps=1.) + check_grads(partial(fun, custom_op), [operand], 3, ["fwd"], eps=1.) 
+ # TODO(b/183233858): variadic reduce-window is not implemented on XLA:GPU - @jtu.skip_on_devices("gpu") + @jtu.sample_product( + [ + dict( + shape=shape, + dims=dims, + strides=strides, + padding=padding, + base_dilation=base_dilation, + window_dilation=window_dilation, + ) + for shape, dims, strides, padding, base_dilation, window_dilation in ( + itertools.chain( + itertools.product( + [(4, 6)], + [(2, 1), (1, 2)], + [(1, 1), (2, 1), (1, 2)], + ['VALID', 'SAME', [(0, 3), (1, 2)]], + [(1, 1), (2, 3)], + [(1, 1), (1, 2)], + ), + itertools.product( + [(3, 2, 4, 6)], + [(1, 1, 2, 1), (2, 1, 2, 1)], + [(1, 2, 2, 1), (1, 1, 1, 1)], + ['VALID', 'SAME', [(0, 1), (1, 0), (2, 3), (0, 2)]], + [(1, 1, 1, 1), (2, 1, 3, 2)], + [(1, 1, 1, 1), (1, 2, 2, 1)], + ), + ) + ) + ], + dtype=[np.float32], + ) + @jtu.skip_on_devices('gpu') def testReduceWindowVariadic(self, dtype, shape, dims, strides, padding, base_dilation, window_dilation): if (jtu.test_device_matches(["tpu"]) and @@ -2652,6 +2981,24 @@ def testRngBitGeneratorReturnedKey(self): new_key, _ = lax.rng_bit_generator(key, (0,)) self.assertAllClose(key, new_key) + def test_rng_bit_generator_vmap(self): + def f(key): + return lax.rng_bit_generator(key, shape=(5, 7)) + + keys = np.arange(3 * 4).reshape((3, 4)).astype(np.uint32) + out_keys, bits = jax.vmap(f)(keys) + self.assertEqual(out_keys.shape, (3, 4)) + self.assertEqual(bits.shape, (3, 5, 7)) + + def test_rng_bit_generator_vmap_vmap(self): + def f(key): + return lax.rng_bit_generator(key, shape=(5, 7)) + + keys = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.uint32) + out_keys, bits = jax.vmap(jax.vmap(f))(keys) + self.assertEqual(out_keys.shape, (2, 3, 4)) + self.assertEqual(bits.shape, (2, 3, 5, 7)) + @jtu.sample_product( dtype=lax_test_util.all_dtypes + lax_test_util.python_scalar_types, weak_type=[True, False], @@ -2709,21 +3056,20 @@ def f(x): def test_constant_folding_complex_to_real_scan_regression(self): # regression test for github.com/google/jax/issues/19059 def g(hiddens): - hiddens_aug = jnp.vstack((hiddens[0], hiddens)) - new_hiddens = hiddens_aug.copy() - diff = new_hiddens[:-1] - hiddens - diff = new_hiddens[:-1] - hiddens - out = jnp.trace(jnp.conj(diff).T @ diff).real - return jnp.array(out, dtype=jnp.complex64) - + hiddens_aug = jnp.vstack((hiddens[0], hiddens)) + new_hiddens = hiddens_aug.copy() + diff = new_hiddens[:-1] - hiddens + diff = new_hiddens[:-1] - hiddens + out = jnp.trace(jnp.conj(diff).T @ diff).real + return jnp.array(out, dtype=jnp.complex64) def _step(carry, arg): - primals, f_vjp = jax.vjp( - g, - jax.random.normal(jax.random.key(0), (9, 8), dtype=jnp.complex64), - ) - out = f_vjp(np.array(1.0 + 0j, 'complex64'))[0] - return carry, carry + primals, f_vjp = jax.vjp( + g, + jax.random.normal(jax.random.key(0), (9, 8), dtype=jnp.complex64), + ) + out = f_vjp(np.array(1.0 + 0j, 'complex64'))[0] + return carry, carry a, b = jax.lax.scan(_step, 0, jnp.arange(4, dtype=jnp.complex64)) @@ -2749,7 +3095,6 @@ def testAsarray(self, typ): jax.jit(asarray_closure)() - class LazyConstantTest(jtu.JaxTestCase): def _Check(self, make_const, expected): # check casting to ndarray works @@ -2970,22 +3315,6 @@ class FooTyRules: def physical_element_aval(dtype) -> core.ShapedArray: return core.ShapedArray((2,), jnp.dtype('uint32')) - @staticmethod - def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding): - op_sharding_proto = hlo_sharding.to_proto() - new_op_sharding = op_sharding_proto.clone() - tad = list(new_op_sharding.tile_assignment_dimensions) - 
new_op_sharding.tile_assignment_dimensions = [*tad, 1] - return xc.HloSharding.from_proto(new_op_sharding) - - @staticmethod - def logical_sharding(aval, phys_sharding): - return phys_sharding - - @staticmethod - def physical_sharding(aval, sharding): - return sharding - @staticmethod def result_handler(sticky_device, aval): def handler(_, buf): @@ -3004,14 +3333,6 @@ def handler(arr): return FooArray(aval.shape, buf) return handler - @staticmethod - def replicate_trailing_dims(ctx, val, aval): - return val - - @staticmethod - def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval): - pass - class FooTy(dtypes.ExtendedDType): type = dtypes.extended @@ -3073,11 +3394,14 @@ def __repr__(self) -> str: size = property(lambda self: self.data.size // 2) ndim = property(lambda self: self.data.ndim - 1) -def shard_foo_array_handler(x, sharding): - device, = sharding._addressable_device_assignment - aval = core.raise_to_shaped(core.get_aval(x.data)) - return pxla.batched_device_put( - aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device]) +def shard_foo_array_handler(xs, shardings): + results = [] + for x, sharding in safe_zip(xs, shardings): + device, = sharding._addressable_device_assignment + aval = core.raise_to_shaped(core.get_aval(x.data)) + results.append(pxla.batched_device_put( + aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])) + return results def foo_array_constant_handler(x): return array._array_mlir_constant_handler(x.data) @@ -3237,6 +3561,54 @@ def test_scan_jaxpr(self): b, = e.outvars self.assertEqual(b.aval, core.ShapedArray((3, 4), FooTy())) + def test_scan_jaxpr_split_transpose(self): + def stage(x, w): + x = x @ w + x = jnp.tanh(x) + return (x, ()) + def loss(ws, x, split_transpose=False): + return jnp.sum(jax.lax.scan(stage, x, ws, + _split_transpose=split_transpose)[0]) + + def fn(*args, split_transpose=False): + v, fn_transpose = jax.vjp( + partial(loss, split_transpose=split_transpose), *args) + grads = fn_transpose(1.0) + return *grads, v + + # x : [batch, d_model] + x = jax.random.uniform(jax.random.key(0), [256, 100]) + # wss : [layers, d_model, d_model] + wss = jax.random.uniform(jax.random.key(1), [7, 100, 100]) + + jaxpr = jax.make_jaxpr(partial(fn))(wss, x) + jaxpr_split_transpose = jax.make_jaxpr(partial(fn, split_transpose=True))( + wss, x + ) + + # Check that the shapes were preserved. + self.assertEqual(jaxpr.in_avals, jaxpr_split_transpose.in_avals) + self.assertEqual(jaxpr.out_avals, jaxpr_split_transpose.out_avals) + + # The first two outvars (corresponding to gradients of params and inputs) + # must come from two different loops. + ct_ws = jaxpr_split_transpose.jaxpr.outvars[0] + ct_x = jaxpr_split_transpose.jaxpr.outvars[1] + + # The last two equations are the two loops we care about + backprop_scan = jaxpr_split_transpose.jaxpr.eqns[-2] + self.assertEqual(backprop_scan.primitive, jax.lax.scan_p) + + param_gradient_map = jaxpr_split_transpose.jaxpr.eqns[-1] + self.assertEqual(param_gradient_map.primitive, jax.lax.scan_p) + self.assertEqual(param_gradient_map.params['num_consts'], 0) + self.assertEqual(param_gradient_map.params['num_carry'], 0) + + # Assert that parameter gradients come from the map. + self.assertEqual(ct_ws, param_gradient_map.outvars[0]) + # And that activation gradients come from the scan. 
+ self.assertEqual(ct_x, backprop_scan.outvars[0]) + def test_scan_lowering(self): ks = jax.jit(lambda: make((3, 4)))() f = lambda ks: jax.lax.scan(lambda _, k: (None, bake(k)), None, ks) @@ -3355,85 +3727,133 @@ def f(x): class FunctionAccuracyTest(jtu.JaxTestCase): @parameterized.named_parameters( - dict(testcase_name=f"_{name}_{dtype.__name__}_{kind}", name=name, dtype=dtype, kind=kind) - for name, dtype, kind in itertools.product( - [ 'arccos', 'arccosh', 'arcsin', 'arcsinh', - 'arctan', 'arctanh', 'conjugate', 'cos', - 'cosh', 'exp', 'exp2', 'expm1', 'log', - 'log10', 'log1p', 'sin', 'sinh', 'sqrt', - 'square', 'tan', 'tanh', 'sinc', 'positive', - 'negative', 'absolute', 'sign'], + dict(testcase_name=f"_{dtype.__name__}", dtype=dtype) + for dtype in jtu.dtypes.supported([np.float32, np.float64, np.complex64, np.complex128])) + def testMPMathUtils(self, dtype): + try: + import mpmath + except ImportError as msg: + self.skipTest(f'could not import mpmath: {msg}') + + prec = {np.float32: 24, np.float64: 53, np.complex64: 24, np.complex128: 53}[dtype] + is_complex = dtype().dtype.kind == 'c' + + def func(x): + assert isinstance(x, mpmath.ctx_mp.mpnumeric) + assert x.context.prec == prec + assert isinstance(x, x.context.mpc if is_complex else x.context.mpf) + return x + + ufunc = jtu.vectorize_with_mpmath(func, mpmath=mpmath) + + with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"): + if is_complex: + arr = jtu.complex_plane_sample(dtype=dtype, size_re=11) + else: + cdtype = getattr(np, ufunc.map_float_to_complex[dtype.__name__]) + arr = jtu.complex_plane_sample(dtype=cdtype, size_re=11, size_im=0)[1:2].real + + arr2 = ufunc.mptonp(ufunc.nptomp(arr)) + with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"): + self.assertAllClose(arr, arr2, atol=0, rtol=0) + + arr3 = ufunc(arr) + with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"): + self.assertAllClose(arr, arr3, atol=0, rtol=0) + + if is_complex: + # tests scale in normalize + v = dtype(1.1071487177940644+1.1102230246251565e-16j) + r = dtype(1.1071487177940644+0j) + mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1) + nr, nv = mnp.normalize(r, r, v) + self.assertAllClose(nr, nv) + + _functions_on_complex_plane = [ + 'arccos', 'arccosh', 'arcsin', 'arcsinh', + 'arctan', 'arctanh', 'conjugate', 'cos', + 'cosh', 'exp', 'exp2', 'expm1', 'log', + 'log10', 'log1p', 'sin', 'sinh', 'sqrt', + 'square', 'tan', 'tanh', 'sinc', 'positive', + 'negative', 'absolute', 'sign' + ] + + @parameterized.named_parameters( + dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype) + for name, dtype in itertools.product( + _functions_on_complex_plane, jtu.dtypes.supported([np.complex64, np.complex128]), - ['success', 'failure'], )) @jtu.skip_on_devices("tpu") - def testOnComplexPlane(self, name, dtype, kind): - all_regions = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero'] + def testSuccessOnComplexPlane(self, name, dtype): + self._testOnComplexPlaneWorker(name, dtype, 'success') + + @parameterized.named_parameters( + dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype) + for name, dtype in itertools.product( + _functions_on_complex_plane, + jtu.dtypes.supported([np.complex64, np.complex128]), + )) + @jtu.skip_on_devices("tpu") + def testFailureOnComplexPlane(self, name, dtype): + 
self._testOnComplexPlaneWorker(name, dtype, 'failure') + + def _testOnComplexPlaneWorker(self, name, dtype, kind): + try: + import mpmath + except ImportError as msg: + self.skipTest(f'could not import mpmath: {msg}') + is_cpu = jtu.test_device_matches(["cpu"]) machine = platform.machine() + # TODO: remove is_arm_cpu as previously arm cpu related failures + # were due to numpy issues. Confirm? is_arm_cpu = machine.startswith('aarch') or machine.startswith('arm') is_cuda = jtu.test_device_matches(["cuda"]) - # TODO(pearu): eliminate all items in the following lists: - # TODO(pearu): when all items are eliminated, eliminate the kind == 'failure' tests - regions_with_inaccuracies = dict( - absolute = ['q1', 'q2', 'q3', 'q4'] if dtype == np.complex128 and is_cuda else [], - exp = (['pos', 'pinfj', 'pinf', 'ninfj', 'ninf'] - + (['q1', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])), - exp2 = ['pos', 'pinfj', 'pinf', 'ninfj', 'ninf', *(['q1', 'q4'] if is_cpu else [])], - log = ['q1', 'q2', 'q3', 'q4'], - log1p = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'ninfj', 'pinfj'], - log10 = ['q1', 'q2', 'q3', 'q4', 'zero', 'ninf', 'ninfj', 'pinf', 'pinfj'], - sinh = (['pos', 'neg', 'ninf', 'pinf'] - + (['q1', 'q2', 'q3', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])), - cosh = (['pos', 'neg', 'ninf', 'pinf'] - + (['q1', 'q2', 'q3', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])), - tan = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'], - square = (['pinf'] - + (['ninfj', 'pinfj'] if is_arm_cpu else []) - + (['ninf'] if not is_arm_cpu else []) - + (['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_cuda else []) - + (['q1', 'q2', 'q3', 'q4'] if is_cpu and dtype == np.complex128 else [])), - sinc = ['q1', 'q2', 'q3', 'q4'], - sign = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'], - arcsin = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], - arccos = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], - arctan = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], - arcsinh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], - arccosh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], - arctanh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], - sin = ['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_arm_cpu and dtype != np.complex128 else [], - cos = ['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_arm_cpu and dtype != np.complex128 else [], - expm1 = ['q1', 'q4', 'pinf'] if is_arm_cpu and dtype != np.complex128 else [], - ) + size_re = 11 + size_im = 11 + atol = None - jnp_op = getattr(jnp, name) + mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1) + mnp2 = jtu.numpy_with_mpmath(mpmath, extra_prec_multiplier=1) - if name == 'square': - # numpy square is incorrect on inputs with large absolute value - tiny = np.finfo(dtype).tiny - - def square(x): - re = (x.real - x.imag) * (x.real + x.imag) - im = x.real * x.imag * 2 - if is_cuda: - # apply FTZ - if np.isfinite(re) and abs(re) < tiny: - re *= 0 - if np.isfinite(im) and abs(im) < tiny: - im *= 0 - return np.array(complex(re, im), dtype=dtype) - - np_op = np.vectorize(square) - else: - np_op = getattr(np, name) + ref_op = getattr(mnp, name) + ref2_op = getattr(mnp2, name) + jnp_op = getattr(jnp, name) with 
jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"): - args = (jtu.complex_plane_sample(dtype=dtype, size_re=11),) + args = (jtu.complex_plane_sample(dtype=dtype, size_re=size_re, size_im=size_im),) result = np.asarray(jnp_op(*args)) - expected = np_op(*args) - - s0, s1 = (result.shape[0] - 3) // 2, (result.shape[1] - 3) // 2 + expected = ref_op(*args) + expected2 = ref2_op(*args) + + normalized_expected, normalized_result = mnp2.normalize(expected2, expected, result) + + # When comparing the results with expected, we'll divide the + # complex plane grid into smaller regions and perform the + # closeness tests on each region separately. The reason for this + # is that the inaccuracy or incorrectness issues with a particular + # function exists typically in specific regions while in other + # regions the function is accurate. So, such a division of the + # complex plane helps to identify the problematic regions as well + # as to fix the inaccuracy or incorrectness issues. + # + # Regions in complex plane: + # + # ( pinfj ) + # ( q2 ) (posj) ( q1 ) + # (ninf) ( neg ) (zero) ( pos ) (pinf) + # ( q3 ) (negj) ( q4 ) + # ( ninfj ) + # + # In addition, the 1/3 middle parts of regions q1, q2, q3, q4, + # neg, pos are tested separately as these don't contain extremely + # small or extremelly large values and functions on these regions + # ought not to possess any incorrectness issues. + + s0, s1 = size_re, size_im + s03, s13 = s0 // 3, s1 // 3 s_dict = dict( q1=(slice(s0 + 2, -1), slice(s1 + 2, -1)), q2=(slice(s0 + 2, -1), slice(1, s1 + 1)), @@ -3450,28 +3870,233 @@ def square(x): zero=(slice(s0 + 1, s0 + 2), slice(s1 + 1, s1 + 2)), ) - for region in all_regions: - s = s_dict[region] - inds = np.where(result[s] != expected[s]) - if inds[0].size > 0: - mismatches = [] - for ind in zip(*inds): - x, r, e = args[0][s][ind], str(result[s][ind]), str(expected[s][ind]) - if r == e: - # skip equal nan-s - continue - mismatches.append(f'jax.numpy.{name}{x} -> {r}, expected {e}') - mismatches = "\n".join(mismatches) + if s03 and s13: + s_dict.update( + mq1 = (slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)), + mq2 = (slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(2 + s13, 2 + 2 * s13)), + mq3 = (slice(2 + s03, 2 + 2 * s03), slice(2 + s13, 2 + 2 * s13)), + mq4 = (slice(2 + s03, 2 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)), + mneg=(s0 + 1, slice(2 + s13, 2 + 2 * s13)), + mpos=(s0 + 1, slice(s1 + 3 + s13, s1 + 3 + 2 * s13)), + mnegj=(slice(2 + s03, 2 + 2 * s03), s1 + 1), + mposj=(slice(s0 + 3 + s03, s0 + 3 + 2 * s03), s1 + 1), + ) + + # The regions are split to real and imaginary parts (of function + # return values) to (i) workaround numpy 1.x assert_allclose bug + # in comparing complex infinities, and (ii) expose more details + # about failing cases: + s_dict_parts = dict() + for k, v in s_dict.items(): + s_dict_parts[k + '.real'] = v + s_dict_parts[k + '.imag'] = v + + # Start with an assumption that all regions are problematic for a + # particular function: + regions_with_inaccuracies = list(s_dict_parts) + + # Next, we'll remove non-problematic regions from the + # regions_with_inaccuracies list by explicitly keeping problematic + # regions: + def regions_with_inaccuracies_keep(*to_keep): + to_keep_parts = [] + for r in to_keep: + if r.endswith('.real') or r.endswith('.imag'): + to_keep_parts.append(r) + else: + to_keep_parts.append(r + '.real') + to_keep_parts.append(r + '.imag') + for item in 
regions_with_inaccuracies[:]: + if item not in to_keep_parts: + regions_with_inaccuracies.remove(item) + + if name == 'absolute': + if is_cuda and dtype == np.complex128: + regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real') else: - mismatches = '' - if kind == 'success' and region not in regions_with_inaccuracies.get(name, []): - with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"): - self.assertAllClose(result[s], expected[s], err_msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{mismatches}") - if kind == 'failure' and region in regions_with_inaccuracies.get(name, []): - with self.assertRaises(AssertionError, msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"): - with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"): - self.assertAllClose(result[s], expected[s]) # on success, update regions_with_inaccuracies + regions_with_inaccuracies.clear() + + elif name == 'sign': + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4') + + elif name == 'square': + if is_cuda: + regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.real', 'pinf.real', 'ninfj.real', 'pinfj.real') + if is_cpu: + regions_with_inaccuracies_keep('ninf.real', 'pinf.real', 'q1.real', 'q2.real', 'q3.real', 'q4.real') + + elif name == 'log': + regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag') + + elif name == 'log10': + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag', 'zero.imag') + + elif name == 'log1p': + regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg.real', 'pos.real', + 'negj.real', 'posj.real', 'ninf.real', 'ninfj.real', 'pinfj.real') + + elif name == 'exp': + regions_with_inaccuracies_keep('pos.imag', 'pinf.imag', 'mpos.imag') + + elif name == 'exp2': + if dtype == np.complex64: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos.imag', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mpos.imag', 'mnegj', 'mposj') + if dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos.imag', 'negj', 'posj', 'ninf', 'pinf', 'mpos.imag') + + elif name == 'expm1': + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos') + + elif name == 'sinc': + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'mq1', 'mq2', 'mq3', 'mq4', + 'mneg.real', 'mpos.real', 'mnegj', 'mposj', + 'ninf.imag', 'pinf.imag', 'ninfj.real', 'pinfj.real') + + elif name == 'tan': + # TODO(pearu): eliminate this if-block when openxla/xla#10525 lands + regions_with_inaccuracies_keep('q1.imag', 'q2.imag', 'q3.imag', 'q4.imag', 'negj.imag', 'posj.imag', + 'ninfj.imag', 'pinfj.imag', 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.imag', 'mposj.imag', + 'ninf.imag', 'pinf.imag', 'ninf.real', 'pinf.real', 'ninfj.real', 'pinfj.real') + + elif name == 'sinh': + if is_cuda: + regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg', 'pos', + 'ninf.imag', 'pinf.imag', 'mq1.real', 'mq2.real', 'mq3.real', 'mq4.real', 'mneg', 'mpos', + 'ninfj.real', 'pinfj.real') + if is_cpu: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj.imag', 'posj.imag', 'ninf.imag', 'pinf.imag', + 'mq1.real', 'mq2.real', 'mq3.real', 'mq4.real', 'mneg', 'mpos', + 
'ninfj.real', 'pinfj.real') + elif name == 'cosh': + regions_with_inaccuracies_keep('neg.imag', 'pos.imag', 'ninf.imag', 'pinf.imag', 'mneg.imag', 'mpos.imag', + 'ninfj.imag', 'pinfj.imag') + + elif name == 'tanh': + regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj') + + elif name == 'arccos': + if dtype == np.complex64: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', + 'mpos.imag', 'mnegj') + if dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos.imag', 'mnegj') + + elif name == 'arccosh': + if dtype == np.complex64: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj.real', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos.imag', 'mnegj') + if dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos.real', 'negj', 'posj.real', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mnegj') + + elif name == 'arcsin': + if dtype == np.complex64: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj') + if dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', + 'mq1', 'mq2', 'mq3', 'mq4', 'mneg.imag', 'mpos.imag', 'mnegj', 'mposj') + + elif name == 'arcsinh': + if dtype == np.complex64: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg.real', 'pos.real', 'negj', 'posj.real', 'ninf', 'pinf', 'ninfj', 'pinfj', + 'mq1.real', 'mq2', 'mq3', 'mq4.real', 'mneg.real', 'mpos.real', 'mnegj') + if dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg.real', 'pos.real', 'negj', 'posj.real', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg.real', 'mnegj') + + elif name == 'arctan': + if dtype == np.complex64: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', + 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.imag', 'mposj.imag') + if dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj') + + elif name == 'arctanh': + regions_with_inaccuracies_keep('pos.imag', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag') + + elif name in {'cos', 'sin'}: + regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag') + + elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'log1p'}: + regions_with_inaccuracies.clear() + else: + assert 0 # unreachable + + # Finally, perform the closeness tests per region: + unexpected_success_regions = [] + for region_name, region_slice in s_dict_parts.items(): + region = args[0][region_slice] + if region_name.endswith('.real'): + result_slice, expected_slice = result[region_slice].real, expected[region_slice].real + normalized_result_slice, normalized_expected_slice = normalized_result[region_slice].real, normalized_expected[region_slice].real + elif region_name.endswith('.imag'): + result_slice, expected_slice = result[region_slice].imag, expected[region_slice].imag + normalized_result_slice, normalized_expected_slice = normalized_result[region_slice].imag, normalized_expected[region_slice].imag + else: + 
result_slice, expected_slice = result[region_slice], expected[region_slice] + normalized_result_slice, normalized_expected_slice = normalized_result[region_slice], normalized_expected[region_slice] + + inexact_indices = np.where(normalized_result_slice != normalized_expected_slice) + if inexact_indices[0].size == 0: + inexact_samples = '' + else: + inexact_samples = [] + for ind in zip(*inexact_indices): + x = region[ind] + y1, y2 = result[region_slice][ind], expected[region_slice][ind] + ny1, ny2 = normalized_result[region_slice][ind], normalized_expected[region_slice][ind] + if str(y1) == str(y2): # skip equal nan-s + continue + max_abs_diff = abs(ny1 - ny2).max() if np.isfinite(y1) and np.isfinite(y1) else np.inf + inexact_samples.append((max_abs_diff, f'jax.numpy.{name}({x}) -> {y1} [{ny1}], expected {y2} [{ny2}]')) + inexact_samples = "\n".join([msg for _, msg in sorted(inexact_samples)]) + + if kind == 'success' and region_name not in regions_with_inaccuracies: + with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"): + self.assertAllClose( + normalized_result_slice, normalized_expected_slice, atol=atol, + err_msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=},\n{inexact_samples}") + + if kind == 'failure' and region_name in regions_with_inaccuracies: + try: + with self.assertRaises(AssertionError, msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}"): + with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"): + self.assertAllClose(normalized_result_slice, normalized_expected_slice) + except AssertionError as msg: + if str(msg).startswith('AssertionError not raised'): + unexpected_success_regions.append(region_name) + else: + raise # something else is wrong.. + + def eliminate_parts(seq): + # replace n.real and n.imag items in seq with n. + result = [] + for part_name in seq: + name = part_name.split('.')[0] + if name in result: + continue + if name + '.real' in seq and name + '.imag' in seq: + result.append(name) + else: + result.append(part_name) + return result + + regions_with_inaccuracies = eliminate_parts(regions_with_inaccuracies) + unexpected_success_regions = eliminate_parts(unexpected_success_regions) + + if kind == 'success' and regions_with_inaccuracies: + reason = "xfail: problematic regions: " + ", ".join(regions_with_inaccuracies) + raise unittest.SkipTest(reason) + + if kind == 'failure': + if not regions_with_inaccuracies: + raise unittest.SkipTest("no problematic regions") + elif unexpected_success_regions: + # This skip ought to be effective only when fixing functions + # on problematic regions in XLA that should follow up a JAX PR + # that enables testing the functions on these regions for + # success. + raise unittest.SkipTest( + f"detected success in regions {', '.join(unexpected_success_regions)}, please update regions_with_inaccuracies!" 
+ ) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_vmap_op_test.py b/tests/lax_vmap_op_test.py index 5d30281327d6..c7059a29343c 100644 --- a/tests/lax_vmap_op_test.py +++ b/tests/lax_vmap_op_test.py @@ -26,8 +26,7 @@ from jax._src.internal_test_util import lax_test_util from jax._src import util -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 0d22d801d085..37d51c04f8de 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -35,8 +35,7 @@ from jax._src.lib import xla_client from jax._src.util import safe_map, safe_zip -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip diff --git a/tests/layout_test.py b/tests/layout_test.py index 7a7168ee7554..d0d0a27b8951 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -12,42 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import math -import os from absl.testing import absltest import numpy as np +from functools import partial import jax -from jax.sharding import NamedSharding, PartitionSpec as P +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding from jax._src import config -from jax._src import layout +from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip -from jax._src import xla_bridge -from jax._src.lib import xla_extension_version config.parse_flags_with_absl() -prev_xla_flags = None +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. 
- xla_bridge.get_backend.cache_clear() - + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() + _exit_stack.close() class LayoutTest(jtu.JaxTestCase): @@ -55,14 +42,14 @@ class LayoutTest(jtu.JaxTestCase): def setUp(self): if not jtu.test_device_matches(['tpu']): self.skipTest("Layouts do not work on CPU and GPU backends yet.") - if xla_extension_version < 215: - self.skipTest('All tests require xla_extension_version >= 215') super().setUp() def test_auto_layout(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) shape1 = (128, 128) shape2 = (128, 128) + s1 = NamedSharding(mesh, P('x', 'y')) + s2 = NamedSharding(mesh, P('x')) def apply(x, y): return x.T, y.T @@ -71,70 +58,98 @@ def init(x, y): return x * 2, y * 2 np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', 'y'))) np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) - arr2 = jax.device_put(np_inp2, NamedSharding(mesh, P('x'))) + sds1 = jax.ShapeDtypeStruct(np_inp1.shape, np_inp1.dtype, sharding=s1) + sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2) - lowered_apply = jax.jit(apply).lower(arr1, arr2, _in_layouts=layout.AUTO, - _out_layouts=layout.AUTO) + lowered_apply = jax.jit(apply, in_shardings=Layout(DLL.AUTO), + out_shardings=Layout(DLL.AUTO)).lower(sds1, sds2) compiled_apply = lowered_apply.compile() - arg_layouts, kw_layouts = compiled_apply._input_layouts() + arg_layouts, kw_layouts = compiled_apply.input_layouts() self.assertEmpty(kw_layouts) - for i, o in zip(arg_layouts, compiled_apply._output_layouts()): - self.assertEqual(i._minor_to_major, o._minor_to_major[::-1]) - init_compiled = jax.jit(init).lower( - arr1, arr2, _out_layouts=arg_layouts).compile() + for i, o in zip(arg_layouts, compiled_apply.output_layouts()): + self.assertEqual(i.device_local_layout.major_to_minor, + o.device_local_layout.major_to_minor[::-1]) + + init_compiled = jax.jit( + init, out_shardings=arg_layouts).lower(sds1, sds2).compile() + + for i, o in zip(init_compiled.input_layouts()[0], + init_compiled.output_layouts()): + self.assertEqual(i, o) - for i, o in zip(init_compiled._input_layouts()[0], - init_compiled._output_layouts()): - self.assertEqual(i._minor_to_major, o._minor_to_major) + arr1 = jax.device_put(np_inp1, s1) + arr2 = jax.device_put(np_inp2, s2) with jtu.count_aot_jit_cpp_cache_miss() as init_count: init_out = init_compiled(arr1, arr2) init_compiled(arr1, arr2) self.assertEqual(init_count[0], 1) + self.assertEqual(init_out[0].layout, init_compiled.output_layouts()[0]) + self.assertEqual(init_out[1].layout, init_compiled.output_layouts()[1]) + with jtu.count_aot_jit_cpp_cache_miss() as apply_count: apply_out = compiled_apply(*init_out) compiled_apply(*init_out) self.assertEqual(apply_count[0], 1) + self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts()[0]) + self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts()[1]) + + self.assertTupleEqual(apply_out[0].layout.device_local_layout.major_to_minor, + init_out[0].layout.device_local_layout.major_to_minor[::-1]) + self.assertTupleEqual(apply_out[1].layout.device_local_layout.major_to_minor, + init_out[1].layout.device_local_layout.major_to_minor[::-1]) + self.assertArraysEqual(init_out[0], np_inp1 * 2) 
self.assertArraysEqual(init_out[1], np_inp2 * 2) self.assertArraysEqual(apply_out[0], (np_inp1 * 2).T) self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T) def test_default_layout(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) - shape = (8, 4, 2) + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + shape = (4, 4, 2) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) + sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s) arr = jax.device_put(np_inp, s) def f(x): return x.T - lowered = jax.jit(f).lower(arr, _in_layouts=None, _out_layouts=None) + lowered = jax.jit(f, in_shardings=None, out_shardings=None).lower(sds) self.assertIn("default", lowered.as_text()) compiled = lowered.compile() out = compiled(arr) - self.assertTupleEqual(compiled._input_layouts()[0][0]._minor_to_major, (2, 1, 0)) - self.assertTupleEqual(compiled._output_layouts()._minor_to_major, (2, 1, 0)) + self.assertTupleEqual( + compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + (2, 1, 0)) + self.assertTupleEqual( + compiled.output_layouts().device_local_layout.major_to_minor[::-1], + (2, 1, 0)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) - compiled_auto = jax.jit(f).lower(arr, _in_layouts=layout.AUTO, - _out_layouts=layout.AUTO).compile() - self.assertTupleEqual(compiled_auto._input_layouts()[0][0]._minor_to_major, - (2, 1, 0)) - self.assertTupleEqual(compiled_auto._output_layouts()._minor_to_major, - (0, 1, 2)) + compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO), + out_shardings=Layout(DLL.AUTO)).lower(sds).compile() + self.assertTupleEqual( + compiled_auto.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + (2, 1, 0)) + self.assertTupleEqual( + compiled_auto.output_layouts().device_local_layout.major_to_minor[::-1], + (0, 1, 2)) + + with self.assertRaisesRegex( + ValueError, "jax.jit` does not accept device-local layouts directly"): + jax.jit(f, in_shardings=DLL.AUTO, + out_shardings=DLL.AUTO).lower(sds).compile() def test_in_layouts_out_layouts(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) shape = (8, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) @@ -142,26 +157,36 @@ def test_in_layouts_out_layouts(self): def f(x): return x.T - compiled = jax.jit(f).lower( - arr, _in_layouts=None, _out_layouts=layout.AUTO).compile() - self.assertTupleEqual(compiled._input_layouts()[0][0]._minor_to_major, (1, 0)) - self.assertTupleEqual(compiled._output_layouts()._minor_to_major, (0, 1)) + + compiled = jax.jit(f, in_shardings=Layout(), + out_shardings=Layout(DLL.AUTO)).lower(arr).compile() + self.assertTupleEqual( + compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + (1, 0)) + self.assertTupleEqual( + compiled.output_layouts().device_local_layout.major_to_minor[::-1], + (0, 1)) out = compiled(arr) self.assertArraysEqual(out, np_inp.T) + self.assertEqual(out.layout, compiled.output_layouts()) self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x'))) def test_sharding_and_layouts(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) shape = (4, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) - compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower( - np_inp, _in_layouts=layout.AUTO, 
_out_layouts=layout.AUTO).compile() + compiled = jax.jit(lambda x: x.T, in_shardings=Layout(DLL.AUTO, s), + out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile() out = compiled(np_inp) - self.assertTupleEqual(compiled._input_layouts()[0][0]._minor_to_major, (1, 0)) - self.assertTupleEqual(compiled._output_layouts()._minor_to_major, (0, 1)) + self.assertTupleEqual( + compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + (1, 0)) + self.assertTupleEqual( + compiled.output_layouts().device_local_layout.major_to_minor[::-1], + (0, 1)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, s) @@ -171,22 +196,224 @@ def f(x, y, z, a, b, c): shape = (8, 2) inps = [np.arange(math.prod(shape)).reshape(shape)] * 6 - compiled = jax.jit(f).lower(*inps, _in_layouts=layout.AUTO, - _out_layouts=layout.AUTO).compile() - arg_layouts, _ = compiled._input_layouts() + compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO), + out_shardings=Layout(DLL.AUTO)).lower(*inps).compile() + arg_layouts, _ = compiled.input_layouts() out1, out2 = compiled(*inps) - compiled2 = jax.jit(f).lower(*inps, _in_layouts=arg_layouts).compile() + compiled2 = jax.jit(f, in_shardings=arg_layouts).lower(*inps).compile() out3, out4 = compiled2(*inps) - for l1, l2 in safe_zip(arg_layouts, compiled2._input_layouts()[0]): + for l1, l2 in safe_zip(arg_layouts, compiled2.input_layouts()[0]): self.assertEqual(l1, l2) self.assertArraysEqual(out1, out3) self.assertArraysEqual(out2, out4) - # TODO(yashkatariya, frostig): Also use the arg_layouts to create an Array - # and then pass that back into compiled. + arrs = [jax.device_put(i, l) for i, l in zip(inps, arg_layouts)] + out5, out6 = jax.jit(f)(*arrs) + self.assertArraysEqual(out1, out5) + self.assertArraysEqual(out2, out6) + + def test_no_error_dced_args(self): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + shape = (8, 2) + s = NamedSharding(mesh, P('x', 'y')) + np_inp = np.arange(math.prod(shape)).reshape(shape) + arr1 = jax.device_put(np_inp, s) + arr2 = jax.device_put(np_inp, s) + arrs = [arr1, arr2] + + def f(x, y): + return x * 2 + + jf = jax.jit(f, in_shardings=Layout(DLL.AUTO, s), + out_shardings=Layout(DLL.AUTO, s)) + compiled = jf.lower(np_inp, np_inp).compile() + arg_layouts, _ = compiled.input_layouts() + arrs = [jax.device_put(i, l) for i, l in zip(arrs, arg_layouts)] + compiled(*arrs) + + def test_aot_layout_mismatch(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + shape = (256, 4, 2) + np_inp = np.arange(math.prod(shape)).reshape(shape) + s = NamedSharding(mesh, P('x')) + + sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s) + arr = jax.device_put(np_inp, s) + + def f(x): + return (x * 2).T + + with self.assertRaisesRegex( + ValueError, + 'Layout passed to jit does not match the layout on the respective arg'): + jax.jit(f, in_shardings=Layout(DLL.AUTO)).lower(arr) + + compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO), + out_shardings=Layout(DLL.AUTO)).lower(sds).compile() + + with self.assertRaisesRegex( + ValueError, + r'Compiled object called with input layout\(s\) does' + r' not match the layout\(s\) the computation was' + ' compiled with'): + compiled(arr) + + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") + def test_cpu_default_backend_layout(self): + inp = jax.device_put(np.ones((8, 8)), device=jax.devices('cpu')[0]) + out_cpu = jax.jit(jnp.dot)(inp, inp) + + jax.jit(jnp.dot, backend=jax.default_backend()).lower( + out_cpu, out_cpu).compile() # doesn't crash 
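The layout tests above replace the private _in_layouts/_out_layouts lowering arguments with Layout(DLL.AUTO) passed through in_shardings/out_shardings. A condensed sketch of that workflow, mirroring the tests' jax._src.layout import and subject to the same caveat that these tests only run on TPU backends:

import numpy as np
import jax
from jax._src.layout import Layout, DeviceLocalLayout as DLL

x = np.arange(64.0).reshape(8, 8)

# Let the compiler pick layouts for both the argument and the result.
compiled = jax.jit(lambda a: a.T,
                   in_shardings=Layout(DLL.AUTO),
                   out_shardings=Layout(DLL.AUTO)).lower(x).compile()

arg_layouts, _ = compiled.input_layouts()  # positional arg layouts, kwargs
out_layout = compiled.output_layouts()     # single output -> single Layout

# An array placed with the compiled argument layout can be fed straight back
# in, and the result carries the compiler-chosen output layout.
arr = jax.device_put(x, arg_layouts[0])
out = compiled(arr)
assert out.layout == out_layout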
+ + def test_device_put_concrete_layout(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + shape = (8, 128) + np_inp = np.arange(math.prod(shape)).reshape(shape) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + compiled = jax.jit( + lambda x: x * 2, out_shardings=Layout(DLL.AUTO)).lower(arr).compile() + col = compiled.output_layouts() + + out = jax.device_put(np_inp, col) + self.assertEqual(out.layout, col) + self.assertArraysEqual(out, np_inp) + for s in out.addressable_shards: + self.assertEqual(out.layout.device_local_layout, + s.data.layout.device_local_layout) + + def test_device_put_non_concrete_layout_error(self): + np_inp = np.arange(16).reshape(8, 2) + + l1 = Layout(DLL.AUTO, SingleDeviceSharding(jax.devices()[0])) + with self.assertRaisesRegex( + ValueError, 'sharding and device_local_layout.*should be concrete'): + jax.device_put(np_inp, l1) + + l2 = Layout(DLL.AUTO) + with self.assertRaisesRegex( + ValueError, 'sharding and device_local_layout.*should be concrete'): + jax.device_put(np_inp, l2) + + l3 = Layout(None, SingleDeviceSharding(jax.devices()[0])) + out = jax.device_put(np_inp, l3) + self.assertArraysEqual(out, np_inp) + self.assertTrue(out._committed) + + def invalid_layout_spec(self): + x = np.arange(8) + compiled = jax.jit(lambda x: x).lower(x).compile() + with self.assertRaisesRegex( + ValueError, 'Sharding has to be concrete when layout.*'): + Layout(compiled.output_layouts()[0], None) + + def test_layout_on_sds(self): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, s) + + out_layout = jax.jit(jnp.sin, out_shardings=Layout(DLL.AUTO)).lower( + arr).compile().output_layouts() + + sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=out_layout) + arg_layout, _ = jax.jit(lambda x: x * 2).lower(sds).compile().input_layouts() + self.assertEqual(arg_layout[0], out_layout) + + with self.assertRaisesRegex( + TypeError, + 'DeviceLocalLayout.AUTO` cannot be used in place of a device-local' + ' layout in a `ShapeDtypeStruct`'): + jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=Layout(DLL.AUTO)) + + def test_make_array_from_callback(self): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s) + + layout = jax.jit(lambda x: x * 2).lower(sds).compile().output_layouts() + + out = jax.make_array_from_callback(np_inp.shape, layout, + lambda idx: np_inp[idx]) + self.assertArraysEqual(out, np_inp) + self.assertEqual(out.layout, layout) + + with self.assertRaisesRegex( + TypeError, + '`DeviceLocalLayout.AUTO` cannot be used in place of a device-local' + ' layout'): + jax.make_array_from_callback(np_inp.shape, Layout(DLL.AUTO, s), + lambda idx: np_inp[idx]) + + with self.assertRaisesRegex( + TypeError, 'sharding should be an instance of `jax.sharding`'): + jax.make_array_from_callback( + np_inp.shape, Layout(None, None), lambda idx: np_inp[idx]) + + def test_wsc_concrete_layout(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + shape = (128, 128) + s = NamedSharding(mesh, P('x')) + np_inp = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np_inp, s) + + # Create a custom layout instead of using `arr.layout` to test the API. + custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),)) + + # We need AUTO so that XLA can override the entry computation layout set. 
+ # TODO(yashkatariya): Expose a config that sets out_shardings to AUTO by + # default instead of `None` i.e. default layout and let the compiler choose + # the layout or try setting it to AUTO by default and see if there is chaos. + @partial(jax.jit, out_shardings=Layout(DLL.AUTO)) + def f(x): + y = x.T + # Constrain `y` to the original layout of `arr` because without it, + # the layout of `y` would be the transpose of `arr`. + return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s)) + + out = f(arr) + self.assertEqual(out.layout, Layout(custom_dll, s)) + self.assertEqual(out.layout, arr.layout) + self.assertArraysEqual(out, np_inp.T) + + def test_wsc_concrete_layout_bfloat16(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + shape = (128, 128) + s = NamedSharding(mesh, P('x')) + inp = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) + arr = jax.device_put(inp, s) + + # Create a custom layout instead of using `arr.layout` to test the API. + custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128), (2, 1))) + + @partial(jax.jit, out_shardings=Layout(DLL.AUTO)) + def f(x): + y = x.T + # Constrain `y` to the original layout of `arr` because without it, + # the layout of `y` would be the transpose of `arr`. + return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s)) + + out = f(arr) + self.assertEqual(out.layout, Layout(custom_dll, s)) + self.assertEqual(out.layout, arr.layout) + self.assertArraysEqual(out, inp.T) + + def test_device_put_user_concrete_layout(self): + shape = (8, 128) + np_inp = np.arange(math.prod(shape)).reshape(shape) + dll = DLL(major_to_minor=(1, 0), tiling=((8, 128),)) + s = SingleDeviceSharding(jax.devices()[0]) + + out = jax.device_put(np_inp, Layout(dll, s)) + self.assertEqual(out.layout, Layout(dll, s)) + self.assertArraysEqual(out, np_inp) if __name__ == '__main__': diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 78971889a6d2..fe7ddc83e8b6 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -16,7 +16,6 @@ from functools import partial import itertools -import unittest import numpy as np import scipy @@ -30,7 +29,6 @@ from jax import lax from jax import numpy as jnp from jax import scipy as jsp -from jax._src.lib import version as jaxlib_version from jax._src import config from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu @@ -182,6 +180,19 @@ def testTensorsolve(self, m, nq, dtype): args_maker, rtol={np.float64: 1e-13}) + def testTensorsolveAxes(self): + a_shape = (2, 1, 3, 6) + b_shape = (1, 6) + axes = (0, 2) + dtype = "float32" + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] + np_fun = partial(np.linalg.tensorsolve, axes=axes) + jnp_fun = partial(jnp.linalg.tensorsolve, axes=axes) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( [dict(dtype=dtype, method=method) for dtype in float_types + complex_types @@ -269,7 +280,6 @@ def check_left_eigenvectors(a, w, vl): # TODO(phawkins): enable when there is an eigendecomposition implementation # for GPU/TPU. 
@jtu.run_on_devices("cpu") - @unittest.skipIf(jaxlib_version < (0, 4, 21), "Test requires jaxlib 0.4.21") def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" @@ -280,7 +290,6 @@ def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, for result in results: self.assertTrue(np.all(np.isnan(result))) - @jtu.sample_product( shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)], dtype=float_types + complex_types, @@ -335,13 +344,13 @@ def testEigBatching(self, shape, dtype): np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3)) @jtu.sample_product( - n=[0, 4, 5, 50, 512], - dtype=float_types + complex_types, - lower=[True, False], + n=[0, 4, 5, 50, 512], + dtype=float_types + complex_types, + lower=[True, False], ) def testEigh(self, n, dtype, lower): rng = jtu.rand_default(self.rng()) - tol = 0.5 * np.maximum(n, 80) * np.finfo(dtype).eps + eps = np.finfo(dtype).eps args_maker = lambda: [rng((n, n), dtype)] uplo = "L" if lower else "U" @@ -351,15 +360,36 @@ def testEigh(self, n, dtype, lower): w, v = jnp.linalg.eigh(np.tril(a) if lower else np.triu(a), UPLO=uplo, symmetrize_input=False) w = w.astype(v.dtype) - self.assertLessEqual( - np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 4 * tol + tol = 2 * n * eps + self.assertAllClose( + np.eye(n, dtype=v.dtype), + np.matmul(np.conj(T(v)), v), + atol=tol, + rtol=tol, ) + with jax.numpy_rank_promotion('allow'): - self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v), - tol * np.linalg.norm(a)) + tol = 100 * eps + self.assertLessEqual( + np.linalg.norm(np.matmul(a, v) - w * v), tol * np.linalg.norm(a) + ) self._CompileAndCheck( - partial(jnp.linalg.eigh, UPLO=uplo), args_maker, rtol=tol + partial(jnp.linalg.eigh, UPLO=uplo), args_maker, rtol=eps + ) + + # Compare eigenvalues against Numpy using double precision. We do not compare + # eigenvectors because they are not uniquely defined, but the two checks above + # guarantee that that they satisfy the conditions for being eigenvectors. + double_type = dtype + if dtype == np.float32: + double_type = np.float64 + if dtype == np.complex64: + double_type = np.complex128 + w_np = np.linalg.eigvalsh(a.astype(double_type)) + tol = 8 * eps + self.assertAllClose( + w_np.astype(w.dtype), w, atol=tol * np.linalg.norm(a), rtol=tol ) @jtu.sample_product( @@ -373,7 +403,7 @@ def testEighSubsetByIndex(self, start, end): dtype = np.float32 n = 256 rng = jtu.rand_default(self.rng()) - tol = np.maximum(n, 80) * np.finfo(dtype).eps + eps = np.finfo(dtype).eps args_maker = lambda: [rng((n, n), dtype)] subset_by_index = (start, end) k = end - start @@ -387,21 +417,36 @@ def testEighSubsetByIndex(self, start, end): self.assertEqual(v.shape, (n, k)) self.assertEqual(w.shape, (k,)) - self.assertLessEqual( - np.linalg.norm(np.eye(k) - np.matmul(np.conj(T(v)), v)), 3 * tol - ) with jax.numpy_rank_promotion("allow"): + tol = 200 * eps self.assertLessEqual( np.linalg.norm(np.matmul(a, v) - w * v), tol * np.linalg.norm(a) ) + tol = 3 * n * eps + self.assertAllClose( + np.eye(k, dtype=v.dtype), + np.matmul(np.conj(T(v)), v), + atol=tol, + rtol=tol, + ) - self._CompileAndCheck(partial(jnp.linalg.eigh), args_maker, rtol=tol) + self._CompileAndCheck(partial(jnp.linalg.eigh), args_maker, rtol=eps) # Compare eigenvalues against Numpy. 
We do not compare eigenvectors because # they are not uniquely defined, but the two checks above guarantee that # that they satisfy the conditions for being eigenvectors. - w_np = np.linalg.eigvalsh(a)[subset_by_index[0] : subset_by_index[1]] - self.assertAllClose(w_np, w, atol=tol, rtol=tol) + double_type = dtype + if dtype == np.float32: + double_type = np.float64 + if dtype == np.complex64: + double_type = np.complex128 + w_np = np.linalg.eigvalsh(a.astype(double_type))[ + subset_by_index[0] : subset_by_index[1] + ] + tol = 20 * eps + self.assertAllClose( + w_np.astype(w.dtype), w, atol=tol * np.linalg.norm(a), rtol=tol + ) def testEighZeroDiagonal(self): a = np.array([[0., -1., -1., 1.], @@ -425,7 +470,7 @@ def testEighTinyNorm(self): w = w.astype(v.dtype) with jax.numpy_rank_promotion("allow"): self.assertLessEqual( - np.linalg.norm(np.matmul(a, v) - w * v), 20 * eps * np.linalg.norm(a) + np.linalg.norm(np.matmul(a, v) - w * v), 80 * eps * np.linalg.norm(a) ) @jtu.sample_product( @@ -558,7 +603,7 @@ def testEighBatching(self, shape, dtype): ws, vs = vmap(jsp.linalg.eigh)(args) ws = ws.astype(vs.dtype) norm = np.max(np.linalg.norm(np.matmul(args, vs) - ws[..., None, :] * vs)) - self.assertLess(norm, 1e-2) + self.assertLess(norm, 1.4e-2) @jtu.sample_product( shape=[(1,), (4,), (5,)], @@ -687,6 +732,12 @@ def testVecdot(self, lhs_shape, rhs_shape, axis, dtype): self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) self._CompileAndCheck(jnp_fn, args_maker, tol=tol) + # smoke-test for optional kwargs. + jnp_fn = partial(jnp.linalg.vecdot, axis=axis, + precision=lax.Precision.HIGHEST, + preferred_element_type=dtype) + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) + # jnp.linalg.matmul is an alias of jnp.matmul; do a minimal test here. @jtu.sample_product( [ @@ -709,6 +760,12 @@ def testMatmul(self, lhs_shape, rhs_shape, dtype): self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) self._CompileAndCheck(jnp_fn, args_maker, tol=tol) + # smoke-test for optional kwargs. + jnp_fn = partial(jnp.linalg.matmul, + precision=lax.Precision.HIGHEST, + preferred_element_type=dtype) + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) + # jnp.linalg.tensordot is an alias of jnp.tensordot; do a minimal test here. @jtu.sample_product( [ @@ -732,6 +789,12 @@ def testTensordot(self, lhs_shape, rhs_shape, axes, dtype): self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) self._CompileAndCheck(jnp_fn, args_maker, tol=tol) + # smoke-test for optional kwargs. 
+ jnp_fn = partial(jnp.linalg.tensordot, axes=axes, + precision=lax.Precision.HIGHEST, + preferred_element_type=dtype) + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) + @jtu.sample_product( [ dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian) @@ -767,7 +830,7 @@ def compute_max_backward_error(operand, reconstructed_operand): max_backward_error = np.amax(backward_error) return max_backward_error - tol = 80 * jnp.finfo(dtype).eps + tol = 100 * jnp.finfo(dtype).eps reconstruction_tol = 2 * tol unitariness_tol = 3 * tol @@ -977,7 +1040,7 @@ def _args_gen(): args_maker = args_gen(pnorm) if pnorm not in [2, -2] and len(set(shape[-2:])) != 1: - with self.assertRaises(np.linalg.LinAlgError): + with self.assertRaises(ValueError): jnp.linalg.cond(*args_maker()) else: self._CheckAgainstNumpy(np.linalg.cond, jnp.linalg.cond, args_maker, @@ -987,28 +1050,16 @@ def _args_gen(): check_dtypes=False, rtol=1e-03, atol=1e-03) @jtu.sample_product( - shape=[(1, 1), (4, 4), (200, 200), (7, 7, 7, 7)], - dtype=float_types, + shape=[(1, 1), (4, 4), (6, 2, 3), (3, 4, 2, 6)], + dtype=float_types + complex_types, ) def testTensorinv(self, shape, dtype): rng = jtu.rand_default(self.rng()) - - def tensor_maker(): - invertible = False - while not invertible: - a = rng(shape, dtype) - try: - np.linalg.inv(a) - invertible = True - except np.linalg.LinAlgError: - pass - return a - - args_maker = lambda: [tensor_maker(), int(np.floor(len(shape) / 2))] - self._CheckAgainstNumpy(np.linalg.tensorinv, jnp.linalg.tensorinv, args_maker, - check_dtypes=False, tol=1e-3) - partial_inv = partial(jnp.linalg.tensorinv, ind=int(np.floor(len(shape) / 2))) - self._CompileAndCheck(partial_inv, lambda: [tensor_maker()], check_dtypes=False, rtol=1e-03, atol=1e-03) + args_maker = lambda: [rng(shape, dtype)] + np_fun = partial(np.linalg.tensorinv, ind=len(shape) // 2) + jnp_fun = partial(jnp.linalg.tensorinv, ind=len(shape) // 2) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=1E-4) + self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) @@ -1016,7 +1067,6 @@ def tensor_maker(): ((1, 1), (1, 1)), ((4, 4), (4,)), ((8, 8), (8, 4)), - ((1, 2, 2), (3, 2)), ((2, 2), (3, 2, 2)), ((2, 1, 3, 3), (1, 4, 3, 4)), ((1, 0, 0), (1, 0, 2)), @@ -1043,9 +1093,15 @@ def testSolveBroadcasting(self, lhs_shape, rhs_shape): # that we match NumPy's convention in all cases. rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, 'float32'), rng(rhs_shape, 'float32')] - # As of numpy 1.26.3, np.linalg.solve fails when this condition is not met. - if len(lhs_shape) == 2 or len(rhs_shape) > 1: - self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, tol=1E-3) + + if jtu.numpy_version() >= (2, 0, 0): + # TODO(jakevdp) remove this condition after solve broadcast deprecation is finalized. + if len(rhs_shape) == 1 or (len(lhs_shape) != len(rhs_shape) + 1): + self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, tol=1E-3) + else: # numpy 1.X + # As of numpy 1.26.3, np.linalg.solve fails when this condition is not met. 
+ if len(lhs_shape) == 2 or len(rhs_shape) > 1: + self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, tol=1E-3) self._CompileAndCheck(jnp.linalg.solve, args_maker) @jtu.sample_product( @@ -1284,6 +1340,18 @@ def testDiagonal(self, shape, dtype, offset): self._CheckAgainstNumpy(np_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) + def testTrace(self): + shape, dtype, offset, out_dtype = (3, 4), "float32", 0, None + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + lax_fun = partial(jnp.linalg.trace, offset=offset, dtype=out_dtype) + if jtu.numpy_version() >= (2, 0, 0): + np_fun = partial(np.linalg.trace, offset=offset) + else: + np_fun = partial(np.trace, offset=offset, axis1=-2, axis2=-1, dtype=out_dtype) + self._CheckAgainstNumpy(np_fun, lax_fun, args_maker) + self._CompileAndCheck(lax_fun, args_maker) + class ScipyLinalgTest(jtu.JaxTestCase): @@ -1359,6 +1427,8 @@ def testLuBatching(self, shape, dtype): self.assertAllClose(us, actual_us) @jtu.skip_on_devices("cpu", "tpu") + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def testLuCPUBackendOnGPU(self): # tests running `lu` on cpu when a gpu is present. jit(jsp.linalg.lu, backend="cpu")(np.ones((2, 2))) # does not crash @@ -2079,6 +2149,16 @@ def testMatrixTranspose(self, shape, dtype): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + n=[0, 1, 5, 10, 20], + ) + def testHilbert(self, n): + args_maker = lambda: [] + osp_fun = partial(osp.linalg.hilbert, n=n) + jsp_fun = partial(jsp.linalg.hilbert, n=n) + self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker) + self._CompileAndCheck(jsp_fun, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lobpcg_test.py b/tests/lobpcg_test.py index 167e3744b7d3..02d340abc1b7 100644 --- a/tests/lobpcg_test.py +++ b/tests/lobpcg_test.py @@ -30,7 +30,6 @@ import scipy.sparse as sps import jax -from jax import config from jax._src import test_util as jtu from jax.experimental.sparse import linalg, bcoo import jax.numpy as jnp @@ -288,7 +287,10 @@ def _debug_plots(self, X, eigs, info, matrix_name, lobpcg_debug_plot_dir): # We import matplotlib lazily because (a) it's faster this way, and # (b) concurrent imports of matplotlib appear to trigger some sort of # collision on the matplotlib cache lock on Windows. - from matplotlib import pyplot as plt + try: + from matplotlib import pyplot as plt + except (ModuleNotFoundError, ImportError): + return # If matplotlib isn't available, don't emit plots. 
os.makedirs(lobpcg_debug_plot_dir, exist_ok=True) clean_matrix_name = _clean_matrix_name(matrix_name) @@ -414,21 +416,21 @@ def setUp(self): super().setUp() @parameterized.named_parameters(_make_concrete_cases(f64=True)) - @jtu.skip_on_devices("tpu", "iree", "gpu") + @jtu.skip_on_devices("tpu", "gpu") def testLobpcgConsistencyF64(self, matrix_name, n, k, m, tol): self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float64) @parameterized.named_parameters(_make_concrete_cases(f64=True)) - @jtu.skip_on_devices("tpu", "iree", "gpu") + @jtu.skip_on_devices("tpu", "gpu") def testLobpcgMonotonicityF64(self, matrix_name, n, k, m, tol): self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float64) @parameterized.named_parameters(_make_callable_cases(f64=True)) - @jtu.skip_on_devices("tpu", "iree", "gpu") + @jtu.skip_on_devices("tpu", "gpu") def testCallableMatricesF64(self, matrix_name): self.checkApproxEigs(matrix_name, jnp.float64) if __name__ == '__main__': - config.parse_flags_with_absl() + jax.config.parse_flags_with_absl() absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/logging_test.py b/tests/logging_test.py index 05bb31015c1a..5a495d47d31b 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -15,6 +15,7 @@ import contextlib import io import logging +import os import platform import subprocess import sys @@ -22,7 +23,6 @@ import unittest import jax -from jax import config import jax._src.test_util as jtu from jax._src import xla_bridge @@ -33,7 +33,20 @@ # parsing to work correctly with bazel (otherwise we could avoid importing # absltest/absl logging altogether). from absl.testing import absltest -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() + + +@contextlib.contextmanager +def jax_debug_log_modules(value): + # jax_debug_log_modules doesn't have a context manager, because it's + # not thread-safe. But since tests are always single-threaded, we + # can define one here. + original_value = jax.config.jax_debug_log_modules + jax.config.update("jax_debug_log_modules", value) + try: + yield + finally: + jax.config.update("jax_debug_log_modules", original_value) @contextlib.contextmanager @@ -71,9 +84,17 @@ def test_no_log_spam(self): """) python = sys.executable assert "python" in python + env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"} + if os.getenv("PYTHONPATH"): + env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH") + if os.getenv("LD_LIBRARY_PATH"): + env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH") # Make sure C++ logging is at default level for the test process. - proc = subprocess.run([python, "-c", program], capture_output=True, - env={"TF_CPP_MIN_LOG_LEVEL": "1"}) + proc = subprocess.run( + [python, "-c", program], + capture_output=True, + env=env_variables, + ) lines = proc.stdout.split(b"\n") lines.extend(proc.stderr.split(b"\n")) @@ -96,30 +117,30 @@ def test_debug_logging(self): self.assertEmpty(log_output.getvalue()) # Turn on all debug logging. - config.update("jax_debug_log_modules", "jax") - with capture_jax_logs() as log_output: - jax.jit(lambda x: x + 1)(1) - self.assertIn("Finished tracing + transforming", log_output.getvalue()) - self.assertIn("Compiling ", log_output.getvalue()) + with jax_debug_log_modules("jax"): + with capture_jax_logs() as log_output: + jax.jit(lambda x: x + 1)(1) + self.assertIn("Finished tracing + transforming", log_output.getvalue()) + self.assertIn("Compiling ", log_output.getvalue()) # Turn off all debug logging. 
- config.update("jax_debug_log_modules", None) - with capture_jax_logs() as log_output: - jax.jit(lambda x: x + 1)(1) - self.assertEmpty(log_output.getvalue()) + with jax_debug_log_modules(""): + with capture_jax_logs() as log_output: + jax.jit(lambda x: x + 1)(1) + self.assertEmpty(log_output.getvalue()) # Turn on one module. - config.update("jax_debug_log_modules", "jax._src.dispatch") - with capture_jax_logs() as log_output: - jax.jit(lambda x: x + 1)(1) - self.assertIn("Finished tracing + transforming", log_output.getvalue()) - self.assertNotIn("Compiling ", log_output.getvalue()) + with jax_debug_log_modules("jax._src.dispatch"): + with capture_jax_logs() as log_output: + jax.jit(lambda x: x + 1)(1) + self.assertIn("Finished tracing + transforming", log_output.getvalue()) + self.assertNotIn("Compiling ", log_output.getvalue()) # Turn everything off again. - config.update("jax_debug_log_modules", None) - with capture_jax_logs() as log_output: - jax.jit(lambda x: x + 1)(1) - self.assertEmpty(log_output.getvalue()) + with jax_debug_log_modules(""): + with capture_jax_logs() as log_output: + jax.jit(lambda x: x + 1)(1) + self.assertEmpty(log_output.getvalue()) if __name__ == "__main__": diff --git a/tests/lru_cache_test.py b/tests/lru_cache_test.py new file mode 100644 index 000000000000..fb999cbef0cf --- /dev/null +++ b/tests/lru_cache_test.py @@ -0,0 +1,155 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
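The new tests/lru_cache_test.py added below exercises the file-backed cache in jax._src.lru_cache. A minimal usage sketch, outside the patch, relying only on the constructor and put/get behavior that the tests demonstrate (max_size is in bytes, -1 appears to disable eviction, and the optional filelock package must be installed, as setUp checks):

import tempfile
from jax._src.lru_cache import LRUCache

with tempfile.TemporaryDirectory() as cache_dir:
  cache = LRUCache(cache_dir, max_size=1024)
  cache.put("cache-a", b"serialized payload")            # store bytes under a string key
  assert cache.get("cache-a") == b"serialized payload"
  assert cache.get("cache-missing") is None              # unknown keys return None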
+ +from __future__ import annotations + +import importlib.util +import tempfile +import time + +from absl.testing import absltest + +from jax._src import path as pathlib +from jax._src.lru_cache import LRUCache +import jax._src.test_util as jtu + + +class LRUCacheTestCase(jtu.JaxTestCase): + name: str | None + path: pathlib.Path | None + + def setUp(self): + if importlib.util.find_spec("filelock") is None: + self.skipTest("filelock is not installed") + + super().setUp() + tmpdir = tempfile.TemporaryDirectory() + self.enter_context(tmpdir) + self.name = tmpdir.name + self.path = pathlib.Path(self.name) + + def tearDown(self): + self.path = None + self.name = None + super().tearDown() + + +class LRUCacheTest(LRUCacheTestCase): + + def test_get_nonexistent_key(self): + cache = LRUCache(self.name, max_size=-1) + self.assertIsNone(cache.get("cache-a")) + + def test_put_and_get_key(self): + cache = LRUCache(self.name, max_size=-1) + + cache.put("cache-a", b"a") + self.assertEqual(cache.get("cache-a"), b"a") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a"}) + + cache.put("cache-b", b"b") + self.assertEqual(cache.get("cache-a"), b"a") + self.assertEqual(cache.get("cache-b"), b"b") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-b"}) + + def test_put_empty_value(self): + cache = LRUCache(self.name, max_size=-1) + + cache.put("cache-a", b"") + self.assertEqual(cache.get("cache-a"), b"") + + def test_put_empty_key(self): + cache = LRUCache(self.name, max_size=-1) + + with self.assertRaisesRegex(ValueError, r"key cannot be empty"): + cache.put("", b"a") + + def test_eviction(self): + cache = LRUCache(self.name, max_size=2) + + cache.put("cache-a", b"a") + cache.put("cache-b", b"b") + + # `sleep()` is necessary to guarantee that `cache-b`"s timestamp is strictly greater than `cache-a`"s + time.sleep(1) + cache.get("cache-b") + + # write `cache-c`, evict `cache-a` + cache.put("cache-c", b"c") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-b", self.path / "cache-c"}) + + # calling `get()` on `cache-b` makes `cache-c` least recently used + time.sleep(1) + cache.get("cache-b") + + # write `cache-d`, evict `cache-c` + cache.put("cache-d", b"d") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-b", self.path / "cache-d"}) + + def test_eviction_with_empty_value(self): + cache = LRUCache(self.name, max_size=1) + + cache.put("cache-a", b"a") + + # write `cache-b` with length 0 + # eviction should not happen even though the cache is full + cache.put("cache-b", b"") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-b"}) + + # calling `get()` on `cache-a` makes `cache-b` least recently used + time.sleep(1) + cache.get("cache-a") + + # writing `cache-c` should result in evicting the + # least recent used file (`cache-b`) first, + # but this is not sufficient to make room for `cache-c`, + # so `cache-a` should be evicted as well + cache.put("cache-c", b"c") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-c"}) + + def test_existing_cache_dir(self): + cache = LRUCache(self.name, max_size=2) + + cache.put("cache-a", b"a") + + # simulates reinitializing the cache in another process + del cache + cache = LRUCache(self.name, max_size=2) + + self.assertEqual(cache.get("cache-a"), b"a") + + # ensure that the LRU policy survives cache reinitialization + cache.put("cache-b", b"b") + + # calling `get()` on `cache-a` makes `cache-b` least recently 
used + time.sleep(1) + cache.get("cache-a") + + # write `cache-c`, evict `cache-b` + cache.put("cache-c", b"c") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-c"}) + + def test_max_size(self): + cache = LRUCache(self.name, max_size=1) + + msg = (r"Cache value for key .+? of size \d+ bytes exceeds the maximum " + r"cache size of \d+ bytes") + with self.assertWarnsRegex(UserWarning, msg): + cache.put("cache-a", b"aaaa") + self.assertIsNone(cache.get("cache-a")) + self.assertEqual(set(self.path.glob("cache-*")), set()) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/memories_test.py b/tests/memories_test.py index 6903cdbef789..afe265425a08 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import functools import math +import re from absl.testing import absltest from absl.testing import parameterized from absl import flags @@ -23,16 +25,16 @@ from jax import lax from jax._src import test_util as jtu from jax._src import xla_bridge as xb -from jax._src.lib import xla_extension_version from jax._src import config from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.numpy as jnp -from jax.sharding import PartitionSpec as P -from jax.ad_checkpoint import Offloadable, remat +from jax.ad_checkpoint import Offloadable, remat, Recompute +from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import (NamedSharding, PositionalSharding, SingleDeviceSharding, GSPMDSharding, - TransferToMemoryKind, - common_devices_indices_map) + TransferToMemoryKind, PartitionSpec as P) +from jax.experimental.compute_on import compute_on +from jax.experimental.shard_map import shard_map import numpy as np config.parse_flags_with_absl() @@ -45,40 +47,22 @@ def get_memory_kinds_from_executable(f, args): def _create_inputs(shape, pspec, mem_kind=None): - mesh = jtu.create_global_mesh((4, 2), ("x", "y")) + mesh = jtu.create_global_mesh((2, 2), ("x", "y")) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, pspec, memory_kind=mem_kind) inp = jax.device_put(np_inp, s) return mesh, s, np_inp, inp -# Tests TODO -# * wsc with memory_kinds -# * shard_map -# * AOT -# * autodiff tests (jtu.check_grads) -# * scan tests -# * jaxpr checks for primitive running on different mem kinds -# * nested jit - - +@jtu.with_config(jax_enable_memories=True) class ShardingMemoriesTest(jtu.JaxTestCase): def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Memories do not work on CPU and GPU backends yet.") - # TODO(b/311021572) - if jtu.is_cloud_tpu(): - self.skipTest("Experimental feature not yet implemented on Cloud TPU") super().setUp() - self.orig_memories_flag = config.enable_memories.value - jax.config.update('jax_enable_memories', True) - FLAGS.xla_tpu_enable_host_aware_passes = True - - def tearDown(self): - jax.config.update('jax_enable_memories', self.orig_memories_flag) - FLAGS.xla_tpu_enable_host_aware_passes = False - super().tearDown() + if jtu.test_device_matches(["cpu"]): + self._default_memory_kind = "unpinned_host" + else: + self._default_memory_kind = "device" @parameterized.named_parameters( ("named_sharding", "named_sharding"), @@ -90,17 +74,17 @@ def test_canonicalize_memory_kind(self, name): if name == "named_sharding": mesh = jtu.create_global_mesh((1,), "x") ns = 
NamedSharding(mesh, P("x")) - self.assertEqual(ns.memory_kind, "device") + self.assertEqual(ns.memory_kind, self._default_memory_kind) elif name == "positional_sharding": ps = PositionalSharding(jax.devices()) - self.assertEqual(ps.memory_kind, "device") + self.assertEqual(ps.memory_kind, self._default_memory_kind) elif name == "single_device_sharding": ss = SingleDeviceSharding(jax.devices()[0]) - self.assertEqual(ss.memory_kind, "device") + self.assertEqual(ss.memory_kind, self._default_memory_kind) else: assert name == "gspmd_sharding" gs = GSPMDSharding.get_replicated(jax.devices()) - self.assertEqual(gs.memory_kind, "device") + self.assertEqual(gs.memory_kind, self._default_memory_kind) @parameterized.named_parameters( ("named_sharding", "named_sharding"), @@ -111,27 +95,26 @@ def test_canonicalize_memory_kind(self, name): def test_wrong_memory_kind(self, name): if name == "named_sharding": with self.assertRaisesRegex( - ValueError, "Could not find memory addressable by device TPU.*" + ValueError, "Could not find memory addressable by device.*" ): - mesh = jtu.create_global_mesh((8,), ("x",)) + mesh = jtu.create_global_mesh((1,), ("x",)) NamedSharding(mesh, P("x"), memory_kind="hbm") elif name == "positional_sharding": with self.assertRaisesRegex( - ValueError, "Could not find memory addressable by device TPU.*" + ValueError, "Could not find memory addressable by device.*" ): PositionalSharding(jax.devices(), memory_kind="gpu_hbm") elif name == "single_device_sharding": with self.assertRaisesRegex( ValueError, - "Could not find memory addressable by device TPU.*Device TPU.*" - " can address the following memory kinds: " - "(device, unpinned_host|unpinned_host, device).*", + "Could not find memory addressable by device.*Device.*" + " can address the following memory kinds.*", ): SingleDeviceSharding(jax.devices()[0], memory_kind="host") else: assert name == "gspmd_sharding" with self.assertRaisesRegex( - ValueError, "Could not find memory addressable by device TPU.*" + ValueError, "Could not find memory addressable by device.*" ): GSPMDSharding.get_replicated(jax.devices(), memory_kind="my_host") @@ -142,11 +125,14 @@ def test_wrong_memory_kind(self, name): ("gspmd_sharding", "gspmd_sharding"), ) def test_correct_tpu_memory_kind(self, name): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("TPU memory kind test.") + if name == "named_sharding": - mesh = jtu.create_global_mesh((8,), ("x",)) - NamedSharding(mesh, P("x"), memory_kind="device") + mesh = jtu.create_global_mesh((1,), ("x",)) + NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind) elif name == "positional_sharding": - PositionalSharding(jax.devices(), memory_kind="device") + PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind) elif name == "single_device_sharding": SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host") else: @@ -161,35 +147,35 @@ def test_correct_tpu_memory_kind(self, name): ) def test_sharding_eq(self, name): if name == "named_sharding": - mesh = jtu.create_global_mesh((8,), ("x",)) + mesh = jtu.create_global_mesh((1,), ("x",)) s1 = NamedSharding(mesh, P("x")) - s2 = NamedSharding(mesh, P("x"), memory_kind="device") + s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind) self.assertEqual(s1, s2) elif name == "positional_sharding": s1 = PositionalSharding(jax.devices()) - s2 = PositionalSharding(jax.devices(), memory_kind="device") + s2 = PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind) self.assertEqual(s1, s2) 
elif name == "single_device_sharding": s1 = SingleDeviceSharding(jax.devices()[0]) - s2 = SingleDeviceSharding(jax.devices()[0], memory_kind="device") + s2 = SingleDeviceSharding(jax.devices()[0], memory_kind=self._default_memory_kind) self.assertEqual(s1, s2) elif name == "gspmd_sharding": s1 = GSPMDSharding.get_replicated(jax.devices()) - s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind="device") + s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind=self._default_memory_kind) self.assertEqual(s1, s2) def test_sharding_equivalent(self): - mesh = jtu.create_global_mesh((8,), ("x",)) + mesh = jtu.create_global_mesh((1,), ("x",)) ndim = 2 ns1 = NamedSharding(mesh, P("x")) gs1 = GSPMDSharding( tuple(mesh.devices.flat), ns1._to_xla_hlo_sharding(ndim), - memory_kind="device", + memory_kind=self._default_memory_kind, ) self.assertTrue(ns1.is_equivalent_to(gs1, ndim)) - ns2 = NamedSharding(mesh, P("x"), memory_kind="device") + ns2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind) gs2 = GSPMDSharding( tuple(mesh.devices.flat), ns2._to_xla_hlo_sharding(ndim) ) @@ -197,21 +183,17 @@ def test_sharding_equivalent(self): def test_default_memory_kind(self): dev = jax.devices()[0] - self.assertEqual(dev.default_memory().kind, "device") + self.assertEqual(dev.default_memory().kind, self._default_memory_kind) -class MemoriesComputationTest(jtu.BufferDonationTestCase): +@jtu.with_config(jax_enable_memories=True) +class DevicePutTest(jtu.JaxTestCase): def setUp(self): - self.skipTest("Compute via memories does not work yet.") + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Memories do not work on CPU and GPU backends yet.") super().setUp() - def _check_mem_kind(self, executable_kind, out_sharding, expected_kind): - out_kind = out_sharding.memory_kind - self.assertEqual(executable_kind, out_kind) - self.assertEqual(out_kind, expected_kind) - self.assertEqual(executable_kind, expected_kind) - def _check_device_put_addressable_shards( self, out, inp, expected_sharding, expected_mem_kind, index=True): self.assertArraysEqual(out, inp) @@ -224,885 +206,1044 @@ def _check_device_put_addressable_shards( self.assertArraysEqual(s.data, inp) self.assertEqual(s.data.sharding.memory_kind, expected_mem_kind) - def test_jit_memory_transfer_to_host_middle(self): - _, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"), mem_kind="device") + def test_error_transfer_to_memory_kind_outside_jit(self): + with self.assertRaisesRegex( + ValueError, + "TransferToMemoryKind argument to jax.device_put can only be used" + " inside jax.jit"): + jax.device_put(np.arange(16), TransferToMemoryKind("device")) - @jax.jit - def f(x): - x = x * 2 - y = jax.device_put(x, s.with_memory_kind("unpinned_host")) - z = y * 3 - a = jax.device_put(z, s.with_memory_kind("device")) - return a * 4, a - - out1, out2 = f(inp) - executable_mk = get_memory_kinds_from_executable(f, [inp]) - - self.assertArraysEqual(out1, np_inp * 24) - self.assertArraysEqual(out2, np_inp * 6) - self.assertEqual(out1.sharding, s) - self.assertEqual(out2.sharding, s) - self._check_mem_kind(executable_mk[0], out1.sharding, "device") - self._check_mem_kind(executable_mk[1], out2.sharding, "device") - - def test_addressable_shards_mem_kind(self): - _, s, np_inp, inp = _create_inputs((8, 2), P("x", "y")) + @parameterized.parameters("unpinned_host", "pinned_host") + def test_device_put_host_to_hbm(self, host_memory_kind: str): + mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + s_host = NamedSharding(mesh, P("y"), 
memory_kind=host_memory_kind) + np_inp = np.arange(16).reshape(8, 2) - @jax.jit - def f(x): - x = jax.device_put(x, s.with_memory_kind("unpinned_host")) - return x * 2 + out_on_host = jax.device_put(np_inp, s_host) + self.assertEqual(out_on_host.sharding, s_host) - out = f(inp) - executable_mk = get_memory_kinds_from_executable(f, [inp]) + s_hbm = s_host.with_memory_kind("device") + out_on_hbm = jax.device_put(out_on_host, s_hbm) + self._check_device_put_addressable_shards( + out_on_hbm, np_inp, s_hbm, "device") - expected_out = np_inp * 2 - self.assertArraysEqual(out, expected_out) - self.assertEqual(out.sharding, s.with_memory_kind("unpinned_host")) - self._check_mem_kind(executable_mk[0], out.sharding, "unpinned_host") - for s in out.addressable_shards: - self.assertArraysEqual(s.data, expected_out[s.index]) - self._check_mem_kind(executable_mk[0], s.data.sharding, "unpinned_host") + @parameterized.parameters("unpinned_host", "pinned_host") + def test_device_put_hbm_to_host(self, host_memory_kind: str): + mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind) + inp = jnp.arange(16).reshape(8, 2) - def test_jit_host_multi_outputs(self): - _, s, np_inp, inp = _create_inputs((8, 2), P("x")) + out_on_host = jax.device_put(inp, s_host) + self._check_device_put_addressable_shards( + out_on_host, inp, s_host, host_memory_kind) - @jax.jit - def f(x, y): - x, y = jnp.sin(x), jnp.cos(y) - x = jax.device_put(x, s.with_memory_kind("unpinned_host")) - y = jax.device_put(y, s.with_memory_kind("device")) - return x, y + sharded_inp = jax.device_put(inp, s_host.with_memory_kind("device")) + sharded_out_on_host = jax.device_put(sharded_inp, s_host) + self._check_device_put_addressable_shards( + sharded_out_on_host, sharded_inp, s_host, host_memory_kind) - out1, out2 = f(inp, inp) + @parameterized.parameters("unpinned_host", "pinned_host") + def test_device_put_different_device_and_memory_host_to_hbm( + self, host_memory_kind: str + ): + if jax.device_count() < 3: + raise unittest.SkipTest("Test requires >=3 devices") - self.assertArraysAllClose(out1, np.sin(np_inp)) - self.assertArraysAllClose(out2, np.cos(np_inp)) - self.assertEqual(out1.sharding, s.with_memory_kind("unpinned_host")) - self.assertEqual(out2.sharding, s.with_memory_kind("device")) + out_host0 = jax.device_put( + jnp.arange(8), + SingleDeviceSharding(jax.devices()[0], memory_kind=host_memory_kind)) - def test_jit_explicit_device(self): - _, s, np_inp, inp = _create_inputs((8, 2), P("x"), mem_kind="device") + dev2 = jax.devices()[2] + out_hbm1 = jax.device_put( + out_host0, SingleDeviceSharding(dev2, memory_kind="device")) + self.assertEqual(out_hbm1.sharding.memory_kind, "device") + self.assertEqual(out_hbm1.sharding._device, dev2) + self.assertEqual(out_hbm1.addressable_shards[0].data.sharding._device, dev2) + self.assertEqual( + out_hbm1.addressable_shards[0].data.sharding.memory_kind, "device") - @jax.jit - def f(x): - return x * 2 + @parameterized.parameters("unpinned_host", "pinned_host") + def test_device_put_different_device_and_memory_hbm_to_host( + self, host_memory_kind: str + ): + if jax.device_count() < 3: + raise unittest.SkipTest("Test requires >=3 devices") - out = f(inp) - executable_mk = get_memory_kinds_from_executable(f, [inp]) - self.assertEqual(out.sharding, s) - self.assertArraysEqual(out, np_inp * 2) - self._check_mem_kind(executable_mk[0], out.sharding, "device") + out_hbm0 = jnp.arange(8) - def test_same_constant_value_on_different_memories(self): - 
_, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"), mem_kind="device") + dev2 = jax.devices()[2] + out_host1 = jax.device_put( + out_hbm0, SingleDeviceSharding(dev2, memory_kind=host_memory_kind)) + self.assertEqual(out_host1.sharding.memory_kind, host_memory_kind) + self.assertEqual(out_host1.sharding._device, dev2) + self.assertEqual(out_host1.addressable_shards[0].data.sharding._device, + dev2) + self.assertEqual( + out_host1.addressable_shards[0].data.sharding.memory_kind, + host_memory_kind) - @jax.jit - def f(x): - x = x * 2 - y = jax.device_put(x, s.with_memory_kind("unpinned_host")) - z = y * 2 - a = jax.device_put(z, s.with_memory_kind("device")) - return a * 2, z + @parameterized.parameters("unpinned_host", "pinned_host") + def test_device_put_on_different_device_with_the_same_memory_kind( + self, host_memory_kind: str): + if len(jax.devices()) < 2: + raise unittest.SkipTest("Test requires >=2 devices.") - out1, out2 = f(inp) - executable_mk = get_memory_kinds_from_executable(f, [inp]) + np_inp = np.arange(16).reshape(8, 2) - self.assertArraysEqual(out1, np_inp * 8) - self.assertArraysEqual(out2, np_inp * 4) - self._check_mem_kind(executable_mk[0], out1.sharding, "device") - self._check_mem_kind(executable_mk[1], out2.sharding, "unpinned_host") + s_hbm_dev_0 = SingleDeviceSharding(jax.devices()[0], memory_kind="device") + s_hbm_dev_1 = SingleDeviceSharding(jax.devices()[1], memory_kind="device") + inp_hbm_dev0 = jax.device_put(np_inp, s_hbm_dev_0) + out_hbm_dev_1 = jax.device_put(inp_hbm_dev0, s_hbm_dev_1) + self._check_device_put_addressable_shards( + out_hbm_dev_1, np_inp, s_hbm_dev_1, "device") - def test_jit_out_shardings(self): - _, s, _, inp = _create_inputs((8, 2), P("x", "y")) + inp_host_dev0 = jax.device_put( + np_inp, s_hbm_dev_0.with_memory_kind(host_memory_kind)) + s_host_dev_1 = s_hbm_dev_1.with_memory_kind(host_memory_kind) + out_host_dev_1 = jax.device_put(inp_host_dev0, s_host_dev_1) + self._check_device_put_addressable_shards( + out_host_dev_1, np_inp, s_host_dev_1, host_memory_kind) - def _check(fun): - executable_mk = get_memory_kinds_from_executable(fun, [inp]) - outs = fun(inp) - for o, m in zip(outs, executable_mk): - self._check_mem_kind(m, o.sharding, "unpinned_host") - self.assertEqual(o.sharding, s.with_memory_kind("unpinned_host")) + # TODO(yashkatariya): Enable this once we can compute on host. 
+ # def test_device_put_resharding(self): + # mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + # s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host") + # s_hbm = s_host.with_memory_kind("device") + # np_inp = np.arange(16).reshape(8, 2) - @functools.partial( - jax.jit, out_shardings=s.with_memory_kind("unpinned_host") - ) - def f(x): - return x * 2, x * 2 + # # Reshard single device array on HBM to multi device on host + # sds_inp_hbm = jax.device_put( + # jnp.arange(16).reshape(8, 2), + # SingleDeviceSharding(jax.devices()[0], memory_kind="device")) + # # device_put on host + # out_sharded_host = jax.device_put(sds_inp_hbm, s_host) + # self._check_device_put_addressable_shards( + # out_sharded_host, np_inp, s_host, "unpinned_host") - _check(f) + # # Reshard single device array on host to multi device on hbm + # sds_inp_host = jax.device_put( + # jnp.arange(16).reshape(8, 2), + # SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host")) + # # device_put on hbm + # out_sharded_hbm = jax.device_put(sds_inp_host, s_hbm) + # self._check_device_put_addressable_shards( + # out_sharded_hbm, np_inp, s_hbm, "device") - @functools.partial( - jax.jit, out_shardings=s.with_memory_kind("unpinned_host") - ) - def h(x): - return x, x * 3 + @parameterized.parameters("unpinned_host", "pinned_host") + def test_device_put_numpy_array(self, host_memory_kind: str): + mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + np_inp = np.arange(16).reshape(8, 2) + s_hbm = NamedSharding(mesh, P(("x", "y")), memory_kind="device") + s_host = s_hbm.with_memory_kind(host_memory_kind) - _check(h) + out_hbm = jax.device_put(np_inp, s_hbm) + self._check_device_put_addressable_shards(out_hbm, np_inp, s_hbm, "device") - @functools.partial( - jax.jit, out_shardings=s.with_memory_kind("unpinned_host") - ) - def i(x): - return x, x + out_host = jax.device_put(np_inp, s_host) + self._check_device_put_addressable_shards( + out_host, np_inp, s_host, host_memory_kind) - _check(i) + @parameterized.parameters("unpinned_host", "pinned_host") + def test_device_put_numpy_scalar(self, host_memory_kind: str): + np_inp = np.float32(8) + s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device") + s_host = s_hbm.with_memory_kind(host_memory_kind) - def test_jit_out_shardings_single_output(self): - mesh, _, _, inp = _create_inputs((8, 2), P("x", "y")) - out_s = NamedSharding(mesh, P(), memory_kind="unpinned_host") + out_hbm = jax.device_put(np_inp, s_hbm) + self._check_device_put_addressable_shards(out_hbm, np_inp, s_hbm, "device") - @functools.partial(jax.jit, out_shardings=out_s) - def g(x): - return jnp.sum(x * 2) + out_host = jax.device_put(np_inp, s_host) + self._check_device_put_addressable_shards( + out_host, np_inp, s_host, host_memory_kind) - out = g(inp) - self.assertEqual(out.sharding, out_s) - executable_mk = get_memory_kinds_from_executable(g, [inp]) - self._check_mem_kind(executable_mk[0], out.sharding, "unpinned_host") + @parameterized.parameters("unpinned_host", "pinned_host") + def test_device_put_python_scalar(self, host_memory_kind: str): + py_scalar = float(8) + s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device") + s_host = s_hbm.with_memory_kind(host_memory_kind) - @jax.jit - def h(x): - x = jnp.sum(x * 2) - out = jax.device_put(x, out_s) - return out + out_hbm = jax.device_put(py_scalar, s_hbm) + self._check_device_put_addressable_shards( + out_hbm, py_scalar, s_hbm, "device", index=False) - out = h(inp) - self.assertEqual(out.sharding, out_s) - executable_mk = 
get_memory_kinds_from_executable(h, [inp]) - self._check_mem_kind(executable_mk[0], out.sharding, "unpinned_host") + out_host = jax.device_put(py_scalar, s_host) + self._check_device_put_addressable_shards( + out_host, py_scalar, s_host, host_memory_kind, index=False) - def test_jit_device_put_host_output(self): - _, s, _, inp = _create_inputs((8, 2), P("x", "y")) + @parameterized.parameters("unpinned_host", "pinned_host") + def test_device_put_python_int(self, host_memory_kind: str): + py_inp = 8 + s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device") + s_host = s_hbm.with_memory_kind(host_memory_kind) - def _check(fun): - executable_mk = get_memory_kinds_from_executable(fun, [inp]) - outs = fun(inp) - for o, m in zip(outs, executable_mk): - self._check_mem_kind(m, o.sharding, "unpinned_host") - self.assertEqual(o.sharding, s.with_memory_kind("unpinned_host")) + out_hbm = jax.device_put(py_inp, s_hbm) + self._check_device_put_addressable_shards( + out_hbm, py_inp, s_hbm, "device", index=False) - @jax.jit - def f(x): - x = x * 2 - out = jax.device_put(x, s.with_memory_kind("unpinned_host")) - return out, out + out_host = jax.device_put(py_inp, s_host) + self._check_device_put_addressable_shards( + out_host, py_inp, s_host, host_memory_kind, index=False) - _check(f) + def test_device_put_inside_jit(self): + _, s_host, np_inp, inp_host = _create_inputs( + (8, 2), P("x", "y"), mem_kind="pinned_host") + s_dev = s_host.with_memory_kind("device") @jax.jit - def h(x): - x = x * 2 - out = jax.device_put(x, s.with_memory_kind("unpinned_host")) - return out, out * 3 + def f(a, b): + x, y = jax.device_put((a, b), s_dev) + return x * y - _check(h) + out = f(inp_host, inp_host) + self._check_device_put_addressable_shards( + out, np_inp * np_inp, s_dev, "device") + + def test_parameter_streaming(self): + _, s_host, np_inp, inp_host = _create_inputs( + (8, 2), P("x", "y"), mem_kind="pinned_host") + s_dev = s_host.with_memory_kind('device') + inp_dev = jax.device_put(np_inp, s_dev) + + @functools.partial(jax.jit, out_shardings=s_host) + def f(a, b): + x = b * 2 + y = jax.device_put(a, s_dev) + z = x * y + return z * 4, z + + compiled = f.lower(inp_host, inp_dev).compile() # doesn't crash + compiled_text = compiled.as_text() + self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}") + + out1, out2 = f(inp_host, inp_dev) + self._check_device_put_addressable_shards( + out1, np_inp * np_inp * 8, s_host, 'pinned_host') + self._check_device_put_addressable_shards( + out2, np_inp * np_inp * 2, s_host, 'pinned_host') - @jax.jit - def i(x): - x = x * 2 - out = jax.device_put(x, s.with_memory_kind("unpinned_host")) - return out * 2, out * 2 + def test_parameter_streaming_with_scalar_and_constant(self): + mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + scalar_inp = 1 + s_host = NamedSharding(mesh, P(), memory_kind="pinned_host") - _check(i) + @functools.partial(jax.jit, out_shardings=s_host) + def f(scalar_input): + y = jax.device_put(scalar_input, s_host) + z = 2 + w = jax.device_put(z, s_host) + return y, w - def test_jit_in_shardings(self): - _, s, np_inp, inp = _create_inputs((8, 2), P("x", "y")) + compiled = f.lower(scalar_inp).compile() # doesn't crash + compiled_text = compiled.as_text() + self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}") - @functools.partial( - jax.jit, in_shardings=s.with_memory_kind("unpinned_host") + out1, out2 = f(scalar_inp) + self._check_device_put_addressable_shards( + out1, scalar_inp, s_host, "pinned_host", index=False + ) + 
self._check_device_put_addressable_shards( + out2, 2, s_host, "pinned_host", index=False ) - def f(x): - return x * 2 - with self.assertRaisesRegex( - ValueError, - "Memory kinds passed to jax.jit does not match memory kind on the" - " respective arg. Got pjit memory kind: unpinned_host, arg memory kind:" - " device for arg shape.*", - ): - f(jnp.arange(16).reshape(8, 2)) # uncommitted inp also raises error + def test_parameter_and_output_streaming_with_array(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") + mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + np_inp = np.arange(16).reshape(8, 2) + s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host") + inp_host = jax.device_put(np_inp, s_host) - with self.assertRaisesRegex( - ValueError, - "Memory kinds passed to jax.jit does not match memory kind on the" - " respective arg. Got pjit memory kind: unpinned_host, arg memory kind:" - " device for arg shape.*", - ): - f(inp) # committed inp raises error. + @functools.partial(jax.jit, out_shardings=(s_host, s_host)) + def f(x): + return (x, x) - @functools.partial(jax.jit, in_shardings=s.with_memory_kind("device")) - def g(x): - return x * 2 + compiled = f.lower(inp_host).compile() # doesn't crash + compiled_text = compiled.as_text() + if compiled_text is not None: + self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}") - out = g(inp) - executable_kind = get_memory_kinds_from_executable(g, [inp]) - self.assertArraysEqual(out, np_inp * 2) - self._check_mem_kind(executable_kind[0], out.sharding, "device") + out1, out2 = f(inp_host) + self._check_device_put_addressable_shards( + out1, np_inp, s_host, "pinned_host" + ) + self._check_device_put_addressable_shards( + out2, np_inp, s_host, "pinned_host" + ) - def test_jit_in_out_shardings(self): - mesh, s, np_inp, inp = _create_inputs( - (8, 2), P("x", "y"), mem_kind="device" + def test_parameter_and_output_streaming_with_scalar(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") + + mesh = jax.sharding.Mesh(jax.devices(), "axis") + s_host = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(), memory_kind="pinned_host" ) - out_s = NamedSharding(mesh, P(), memory_kind="device") + scalar_inp = 1 - @functools.partial(jax.jit, in_shardings=s, out_shardings=out_s) + @functools.partial(jax.jit, out_shardings=(s_host, s_host)) def f(x): - return jnp.sum(x) + return (x, x) - out = f(inp) - executable_kind = get_memory_kinds_from_executable(f, [inp]) - self.assertArraysEqual(out, np.sum(np_inp)) - self._check_mem_kind(executable_kind[0], out.sharding, "device") + compiled = f.lower(scalar_inp).compile() # doesn't crash + compiled_text = compiled.as_text() + if compiled_text is not None: + self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}") - @functools.partial( - jax.jit, - in_shardings=s, - out_shardings=out_s.with_memory_kind("unpinned_host"), + out1, out2 = f(scalar_inp) + self._check_device_put_addressable_shards( + out1, scalar_inp, s_host, "pinned_host", index=False + ) + self._check_device_put_addressable_shards( + out2, scalar_inp, s_host, "pinned_host", index=False ) - def g(x): - return jnp.sum(x) - out = g(inp) - executable_kind = get_memory_kinds_from_executable(g, [inp]) - self.assertArraysEqual(out, np.sum(np_inp)) - self._check_mem_kind(executable_kind[0], out.sharding, "unpinned_host") + def 
test_identity_jit_host_to_device_and_vice_versa(self): + mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + np_inp = np.arange(16).reshape(8, 2) + s_host = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host') + s_dev = s_host.with_memory_kind('device') + arr_host = jax.device_put(np_inp, s_host) + arr_dev = jax.device_put(np_inp, s_dev) + + # pinned_host -> device + f = jax.jit(lambda x: x, out_shardings=s_dev) + out_dev = f(arr_host) + self.assertArraysEqual(out_dev, np_inp) + self.assertEqual(out_dev.sharding, s_dev) + + # device -> pinned_host + g = jax.jit(lambda x: x, out_shardings=s_host) + out_host = g(arr_dev) + self.assertArraysEqual(out_host, np_inp) + self.assertEqual(out_host.sharding, s_host) + + def test_parameter_streaming_inside_scan(self): + mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z")) + np_inp = np.arange(4096.0).reshape(16, 16, 16) + s_host = NamedSharding(mesh, P("x", "y", "z"), memory_kind="pinned_host") + arr_host = jax.device_put(np_inp, s_host) - def test_device_put_different_devices(self): - _, _, _, inp = _create_inputs((8, 2), P("x", "y")) + @jax.jit + def f(xs): + def body(carry, x): + x_tpu = jax.device_put(x, TransferToMemoryKind("device")) + return carry, x_tpu + carry + + return jax.lax.scan(body, 1.0, xs) + + _, out_hbm = f(arr_host) + self.assertArraysEqual(out_hbm, np_inp + 1.0) + # Only expect the last dimension to have a named sharding. + out_s = NamedSharding(mesh, P(None, None, "z"), memory_kind="device") + self.assertEqual(out_hbm.sharding, out_s) + + def test_output_streaming(self): + mesh = jtu.create_global_mesh((1, 1), ("x", "y")) + np_inp = np.arange(16.0).reshape(8, 2) + s_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="device") + s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host") + arr_hbm = jax.device_put(np_inp, s_hbm) + + @functools.partial(jax.jit, out_shardings=s_host) + def f(xs): + out_tpu = xs + 1.0 + return out_tpu + + out_host = f(arr_hbm) + self.assertArraysEqual(out_host, np_inp + 1.0) + self.assertEqual(out_host.sharding, s_host) + + def test_weight_offload_with_dp_on_output(self): + _, s_dev, np_inp, inp_dev = _create_inputs( + (8, 2), P("x", "y"), mem_kind="device") + s_host = s_dev.with_memory_kind('pinned_host') @jax.jit def f(x): - return jax.device_put( - x, SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host") - ) + x = x * 2 + y = jax.device_put(x, s_host) + return y - with self.assertRaisesRegex( - ValueError, "Received incompatible devices for jitted computation" - ): - f(inp) + out_host = f(inp_dev) + self._check_device_put_addressable_shards( + out_host, np_inp * 2, s_host, 'pinned_host') - def test_jit_multiple_transfers(self): - mesh, _, np_inp, inp = _create_inputs((8, 2), P(None, "y")) - s2 = NamedSharding(mesh, P("x")) - inp2 = jax.device_put(np_inp, s2) + def test_output_streaming_inside_scan(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") + mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z")) + np_inp = np.arange(4096).reshape(16, 16, 16) + s_hbm = NamedSharding(mesh, P(None, "y", "z"), memory_kind="device") + arr_hbm = jax.device_put(np_inp, s_hbm) @jax.jit - def f(x, y): - a = x + y - b, c = jax.device_put((a, x), s2.with_memory_kind("unpinned_host")) - return b * c, y * 2 - - out1, out2 = f(inp, inp2) - executable_mem = get_memory_kinds_from_executable(f, [inp, inp2]) - self.assertArraysEqual(out1, (np_inp + np_inp) * np_inp) - self.assertArraysEqual(out2, 
np_inp * 2) - self._check_mem_kind(executable_mem[0], out1.sharding, "unpinned_host") - self._check_mem_kind(executable_mem[1], out2.sharding, "device") - - def test_jit_single_device_multi_output_host_mem(self): - if xb.using_pjrt_c_api(): - raise unittest.SkipTest("GetOutputShardings not supported in PJRT C API") - inp = jnp.arange(16).reshape(8, 2) + def f(xs): + def body(carry, x): + out_tpu = x + carry + return carry, jax.device_put( + out_tpu, NamedSharding(mesh, P("y", "z"), memory_kind="pinned_host")) + _, res = jax.lax.scan(body, 1, xs) + return res - @jax.jit - def f(x): - x = jax.device_put( - x, SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host") - ) - return x * 2, x * 3 - - out1, out2 = f(inp) - executable_mem = get_memory_kinds_from_executable(f, [inp]) - self.assertArraysEqual(out1, inp * 2) - self.assertArraysEqual(out2, inp * 3) - self._check_mem_kind(executable_mem[0], out1.sharding, "unpinned_host") - self._check_mem_kind(executable_mem[1], out2.sharding, "unpinned_host") - - def test_jit_reshard(self): - mesh, _, np_inp, inp = _create_inputs((8, 2), P(None, "y")) - out_s = NamedSharding(mesh, P(("x", "y")), memory_kind="unpinned_host") - - def _check(fun, inp): - out = fun(inp) - self.assertArraysEqual(out, np_inp * 2) - self.assertEqual(out.sharding, out_s) - executable_kind = get_memory_kinds_from_executable(fun, [inp]) - self._check_mem_kind(executable_kind[0], out.sharding, "unpinned_host") + out = f(arr_hbm) + self.assertArraysEqual(out, np_inp + 1) + self.assertEqual(out.sharding.memory_kind, 'pinned_host') - @functools.partial(jax.jit, out_shardings=out_s) - def f(x): - return x * 2 + def test_deepcopy(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") + mesh = jax.sharding.Mesh(jax.devices(), "x") + s_host = NamedSharding(mesh, P(), memory_kind="pinned_host") - _check(f, inp) + t = jax.device_put(jnp.zeros((8, 2)), s_host) + t_copy = copy.deepcopy(t) + self.assertArraysEqual(t, t_copy) + self.assertEqual(t.shape, t_copy.shape) - @jax.jit - def g(x): - y = jax.device_put(x, out_s) - return y * 2 - _check(g, inp) +@jtu.with_config(jax_enable_memories=True) +class ComputeOffload(jtu.BufferDonationTestCase): - def test_jit_cpp_cache_hit(self): - mesh, _, np_inp, inp = _create_inputs((8, 2), P("x", "y")) - inp2 = jax.device_put( - np_inp, NamedSharding(mesh, P("x", "y"), memory_kind="device") - ) + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Memories do not work on CPU and GPU backends yet.") + super().setUp() - f = jax.jit(lambda x: x @ x.T) + def _check_mem_kind(self, executable_kind, out_sharding, expected_kind): + out_kind = out_sharding.memory_kind + self.assertEqual(executable_kind, out_kind) + self.assertEqual(out_kind, expected_kind) + self.assertEqual(executable_kind, expected_kind) - with jtu.count_pjit_cpp_cache_miss() as count: - out = f(inp) - out2 = f(inp2) - self.assertEqual(count[0], 1) + def test_compute_no_inputs(self): + mesh = jtu.create_global_mesh((4,), ('data')) - self.assertArraysEqual(out, np_inp @ np_inp.T) - self.assertArraysEqual(out2, np_inp @ np_inp.T) + tpu_sharding = NamedSharding(mesh, P('data')) + cpu_sharding = NamedSharding(mesh, P('data'), memory_kind='pinned_host') - def test_jit_compilation_cache_hit(self): - mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y")) - inp2 = jax.device_put( - np_inp, - GSPMDSharding( - tuple(mesh.devices.flat), - s._to_xla_hlo_sharding(inp.ndim), - 
memory_kind="device", - ), - ) + @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding)) + def init(): + tpu_array = jax.random.normal(jax.random.key(42), (16,16)) + cpu_array = jax.random.normal(jax.random.key(42), (16,16)) + return tpu_array, cpu_array - f = jax.jit(lambda x: x @ x.T) + tpu_array, cpu_array = init() + self.assertEqual(tpu_array.sharding, tpu_sharding) + self.assertEqual(cpu_array.sharding, cpu_sharding) - with ( - jtu.count_pjit_cpp_cache_miss() as cpp_count, - jtu.count_jit_and_pmap_compiles() as compile_count, - ): - f(inp) - f(inp2) - self.assertEqual(cpp_count[0], 2) - self.assertEqual(compile_count[0], 1) + def test_compute_no_inputs_host_replicated(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3: + self.skipTest("This test requires an xla_version >= 3.") + mesh = jtu.create_global_mesh((4,), ('data')) - def test_jit_cpp_cache_output_hit(self): - _, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="device") + tpu_sharding = NamedSharding(mesh, P('data')) + cpu_sharding = NamedSharding(mesh, P(), memory_kind='pinned_host') + + @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding)) + def init(): + tpu_array = jax.random.normal(jax.random.key(42), (16,16)) + cpu_array = jax.random.normal(jax.random.key(42), (16,16)) + return tpu_array, cpu_array + + tpu_array, cpu_array = init() + self.assertEqual(tpu_array.sharding, tpu_sharding) + self.assertEqual(cpu_array.sharding, cpu_sharding) + def test_compute_on_basic(self): + out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host') + + @compute_on('device_host') @jax.jit - def mul_two(x): + def g(x): return x * 2 - with jtu.count_pjit_cpp_cache_miss() as count: - out = mul_two(inp) - mul_two(out) - self.assertEqual(count[0], 1) + @jax.jit + def f(x): + y = g(x) + return y * 3 - def test_jit_cache_miss(self): - mesh, _, np_inp, inp = _create_inputs( - (8, 2), P("x", "y"), mem_kind="device" - ) - out_s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host") + inp = jnp.arange(8) + out = f(inp) + self.assertArraysEqual(out, inp * 6) - @functools.partial(jax.jit, out_shardings=out_s_host) - def mul_three(x): - return x * 3 + lowered_text = f.lower(jnp.arange(8)).as_text() + self.assertIn('_xla_compute_type', lowered_text) - with ( - jtu.count_pjit_cpp_cache_miss() as cpp_count, - jtu.count_jit_and_pmap_compiles() as compile_count, - ): - out = mul_three(inp) - out2 = mul_three(out) + @functools.partial(jax.jit, out_shardings=out_s) + def h(x): + y = g(x) + return y * 3 - self.assertEqual(cpp_count[0], 2) - self.assertEqual(compile_count[0], 2) - self.assertEqual(out.sharding, out_s_host) - self.assertEqual(out2.sharding, out_s_host) - self.assertArraysEqual(out, np_inp * 3) - self.assertArraysEqual(out2, np_inp * 9) - executable_mk = get_memory_kinds_from_executable(mul_three, [inp]) - self._check_mem_kind(executable_mk[0], out.sharding, "unpinned_host") - executable_mk2 = get_memory_kinds_from_executable(mul_three, [out]) - self._check_mem_kind(executable_mk2[0], out2.sharding, "unpinned_host") - - def test_jit_host_input_from_another_jit_output(self): - mesh, _, np_inp, inp = _create_inputs((8, 2), P("x", "y")) - out_host_s = jax.sharding.NamedSharding( - mesh, P("x", "y"), memory_kind="unpinned_host" - ) + out2 = h(inp) + self.assertArraysEqual(out2, inp * 6) + self.assertEqual(out2.sharding.memory_kind, 'pinned_host') + + def test_compute_on_reduction(self): + out_s = SingleDeviceSharding(jax.devices()[0], 
memory_kind='pinned_host') + + @compute_on('device_host') + @jax.jit + def g(x): + # Reduction generates multiple host computations (inside a single host + # computation module): the main one and a reduction body. + return jnp.sum(x) - @functools.partial(jax.jit, out_shardings=out_host_s) + @jax.jit def f(x): - return x * 2 + y = g(x) + z = jnp.sum(x) + return y * z + inp = jnp.arange(8) out = f(inp) - self.assertEqual(out.sharding, out_host_s) - executable_kind = get_memory_kinds_from_executable(f, [inp]) - self._check_mem_kind(executable_kind[0], out.sharding, "unpinned_host") - self.assertArraysEqual(out, np_inp * 2) + self.assertArraysEqual(out, np.sum(inp) * np.sum(inp)) - # Input to `f` is on host memory. - out2 = f(out) - self.assertEqual(out2.sharding, out_host_s) - executable_kind = get_memory_kinds_from_executable(f, [out]) - self._check_mem_kind(executable_kind[0], out2.sharding, "unpinned_host") - self.assertArraysEqual(out2, np_inp * 4) + lowered_text = f.lower(jnp.arange(8)).as_text() + self.assertIn('_xla_compute_type', lowered_text) - lowered_hlo = f.lower(out).as_text(dialect="hlo") - self.assertIn('_xla_buffer_placement="arg"', lowered_hlo) + @functools.partial(jax.jit, out_shardings=out_s) + def h(x): + y = g(x) + z = jnp.sum(x) + return y * z - def test_jit_cache_hit_with_default_and_specified_mem_kind(self): - _, s, np_inp, _ = _create_inputs((8, 2), P("x", "y")) - _, s2, np_inp2, _ = _create_inputs((8, 2), P("x", "y"), mem_kind="device") + out2 = h(inp) + self.assertArraysEqual(out2, np.sum(inp) * np.sum(inp)) + self.assertEqual(out2.sharding.memory_kind, 'pinned_host') - def mul(x): - return x @ x.T + def test_nested_compute_error(self): + @compute_on('device') + @jax.jit + def f0(x): + return x * 2 - f = jax.jit(mul, in_shardings=s) - g = jax.jit(mul, in_shardings=s2) + @compute_on('device_host') + @jax.jit + def f1(x): + return f0(x) - with jtu.count_jit_and_pmap_compiles() as count: - out = f(np_inp) - out2 = g(np_inp2) - self.assertEqual(count[0], 1) + @jax.jit + def f2(x): + return f1(x) - self.assertArraysEqual(out, np_inp @ np_inp.T) - self.assertArraysEqual(out2, np_inp2 @ np_inp2.T) + with self.assertRaisesRegex( + NotImplementedError, + "Nesting `compute_on` with different compute types is not supported" + " yet."): + f2(jnp.arange(8)) - def test_sharding_devices_indices_map_cache_hit(self): - mesh = jtu.create_global_mesh((4, 2), ("x", "y")) - shape = (8, 2) - s1 = NamedSharding(mesh, P("x", "y")) - s2 = NamedSharding(mesh, P("x", "y"), memory_kind="device") + def test_compute_on_grad(self): + @compute_on('device_host') + @jax.jit + def g(x): + return jnp.sin(x) - s1.devices_indices_map(shape) - cache_info1 = common_devices_indices_map.cache_info() - s2.devices_indices_map(shape) - cache_info2 = common_devices_indices_map.cache_info() - self.assertEqual(cache_info2.hits, cache_info1.hits + 1) - self.assertEqual(cache_info2.misses, cache_info1.misses) + def f(x): + y = g(x) + return jnp.sum(y) - def test_jit_host_inputs_via_device_put_outside(self): - mesh = jtu.create_global_mesh((4, 2), ("x", "y")) - s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host") - s_hbm = s_host.with_memory_kind("device") - inp = jnp.arange(16).reshape(8, 2) - np_inp = np.arange(16).reshape(8, 2) + inp = jnp.arange(8.) 
+ jf = jax.jit(jax.grad(f)) - inp_host = jax.device_put(inp, s_host) - inp_hbm = jax.device_put(inp, s_hbm) + jtu.check_grads(jf, (inp,), order=2) - @jax.jit - def f(x, y): - return x * 2, y * 2 + lowered_text = jf.lower(inp).as_text('hlo') + out = re.findall(r"call.*to_apply.*_xla_compute_type", lowered_text) + self.assertLen(out, 2) - out_host, out_hbm = f(inp_host, inp_hbm) + def test_compute_on_remat(self): + inp = jnp.arange(16.) - self._check_device_put_addressable_shards( - out_host, np_inp * 2, s_host, "unpinned_host") - self._check_device_put_addressable_shards( - out_hbm, np_inp * 2, s_hbm, "device") + def policy(prim, *avals, **params): + return Recompute - def test_trivial_computation(self): - if xb.using_pjrt_c_api(): - raise unittest.SkipTest("GetOutputShardings not supported in PJRT C API") - mesh = jtu.create_global_mesh((2, 1), ("x", "y")) - np_inp = np.arange(16).reshape(8, 2) + @compute_on('device_host') + @jax.jit + def g(x): + x = jnp.sin(x) + x = jnp.sin(x) + x = jnp.sin(x) + return x - s_hbm = NamedSharding(mesh, P("x")) - inp = jax.device_put(np_inp, s_hbm) - f = jax.jit(lambda x: x) - out = f(inp) - self.assertArraysEqual(out, np_inp) - self.assertEqual(out.sharding, s_hbm) + @functools.partial(remat, policy=policy) + def f(x): + x = g(x) + return jnp.sum(x) - s_host = NamedSharding(mesh, P(None, "x"), memory_kind="unpinned_host") - inp = jax.device_put(np_inp, s_host) - f = jax.jit(lambda x: x) - out = f(inp) - self.assertArraysEqual(out, np_inp) - self.assertEqual(out.sharding, s_host) + # Execution test. + jf = jax.jit(jax.grad(f)) + jf(inp) # doesn't crash - def test_no_donation_across_memory_kinds(self): - if xb.using_pjrt_c_api(): - raise unittest.SkipTest("GetOutputShardings not supported in PJRT C API") - mesh = jtu.create_global_mesh((2, 1), ("x", "y")) + lowered_text = jf.lower(inp).as_text('hlo') + out = re.findall(r"call.*to_apply.*_xla_compute_type", lowered_text) + self.assertLen(out, 2) + + def test_nested_no_op_compute(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(16).reshape(8, 2) - s_hbm = NamedSharding(mesh, P("x")) - s_host = s_hbm.with_memory_kind("unpinned_host") - inp = jax.device_put(np_inp, s_hbm) + arr = jax.device_put(np_inp, s) - @functools.partial(jax.jit, out_shardings=s_host, donate_argnums=0) - def f(x): + @compute_on('device_host') + @jax.jit + def f0(x): return x * 2 - with self.assertWarnsRegex(UserWarning, "Some donated buffers were not usable"): - f(inp) + @compute_on('device_host') + @jax.jit + def f1(x): + x = x * 3 + return f0(x) - lowered_text = f.lower(inp).as_text("hlo") - self.assertNotIn("input_output_alias", lowered_text) - self.assertNotDeleted(inp) + @jax.jit + def f2(x): + return f1(x) - @parameterized.named_parameters( - ("hbm_to_host", "device", "unpinned_host"), - ("host_to_hbm", "unpinned_host", "device") - ) - def test_device_put_memory_kind_no_sharding(self, inp_mem_kind, out_mem_kind): - mesh = jtu.create_global_mesh((4, 2), ("x", "y")) + out = f2(arr) + self.assertArraysEqual(out, arr * 6) + self.assertEqual(out.sharding, s) + + def test_sharded_compute_on_host(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(16).reshape(8, 2) - s = NamedSharding(mesh, P("x", "y"), memory_kind=inp_mem_kind) - inp = jax.device_put(np_inp, s) + arr = jax.device_put(np_inp, s) + @compute_on('device_host') @jax.jit - def f(x): - y = x @ x.T - z = jax.device_put(y, 
TransferToMemoryKind(out_mem_kind)) - return z * 2 + def g(x, y): + return x * y - out = f(inp) + @jax.jit + def f(x): + x = x * 3 + return g(x, x) - self._check_device_put_addressable_shards( - out, (np_inp @ np_inp.T) * 2, - NamedSharding(mesh, P("x"), memory_kind=out_mem_kind), - out_mem_kind) - executable_kind = get_memory_kinds_from_executable(f, [inp]) - self._check_mem_kind(executable_kind[0], out.sharding, out_mem_kind) + out = f(arr) + expected_out = (np_inp * 3) * (np_inp * 3) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, expected_out) - @parameterized.named_parameters( - ("hbm_to_host", "device", "unpinned_host"), - ("host_to_hbm", "unpinned_host", "device") - ) - def test_device_put_memory_kind_no_sharding_output( - self, inp_mem_kind, out_mem_kind): - mesh = jtu.create_global_mesh((4, 2), ("x", "y")) - np_inp = np.arange(16).reshape(8, 2) - s = NamedSharding(mesh, P("x", "y"), memory_kind=inp_mem_kind) - inp = jax.device_put(np_inp, s) + def test_host_offload_in_custom_vjp(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") + @jax.custom_vjp + def f(x): + return jnp.sin(x) + @compute_on('device_host') @jax.jit + def eq(x, y): + return (x == y).astype(jnp.float32) + + def f_fwd(x): + y = x * 2 + z = jax.device_put(y, TransferToMemoryKind('pinned_host')) + return y, (x, z) + + def f_bwd(res, tx): + x, z = res + y = x * 2 + z2 = jax.device_put(y, TransferToMemoryKind('pinned_host')) + return (eq(z, z2),) + + f.defvjp(f_fwd, f_bwd) + g = jax.jit(jax.grad(lambda x: f(x).sum())) + + x = jnp.ones(3) * 4 + all_true = jnp.ones(3) + self.assertArraysEqual(g(x), all_true) + + def test_host_offload_in_custom_vjp_sharded(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") + mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + s = NamedSharding(mesh, P('x')) + + @jax.custom_vjp def f(x): - y = x @ x.T - return jax.device_put(y, TransferToMemoryKind(out_mem_kind)) + return jnp.sin(x) - out = f(inp) + @compute_on('device_host') + @jax.jit + def eq(x, y): + return (x == y).astype(jnp.float32) - self._check_device_put_addressable_shards( - out, np_inp @ np_inp.T, - NamedSharding(mesh, P("x"), memory_kind=out_mem_kind), - out_mem_kind) - executable_kind = get_memory_kinds_from_executable(f, [inp]) - self._check_mem_kind(executable_kind[0], out.sharding, out_mem_kind) + def f_fwd(x): + y = x * 2 + z = jax.device_put(y, s.with_memory_kind('pinned_host')) + return y, (x, z) - @parameterized.named_parameters( - ("hbm_to_host", "device", "unpinned_host"), - ("host_to_hbm", "unpinned_host", "device") - ) - def test_device_put_memory_kind_no_sharding_input( - self, inp_mem_kind, out_mem_kind): - mesh = jtu.create_global_mesh((4, 2), ("x", "y")) + def f_bwd(res, tx): + x, z = res + y = x * 2 + z2 = jax.device_put(y, s.with_memory_kind('pinned_host')) + return (eq(z, z2),) + + f.defvjp(f_fwd, f_bwd) + g = jax.jit(jax.grad(lambda x: f(x).sum())) + + arr = jax.device_put(jnp.ones(4) * 4, s) + all_true = jnp.ones(4) + self.assertArraysEqual(g(arr), all_true) + + def test_scan_offload(self): + np_inp = jnp.arange(4096).reshape(16, 16, 16) + + @jax.jit + def f(xs): + def body(carry, x): + with compute_on('device_host'): + out_tpu = x + carry + return carry, out_tpu + _, res = jax.lax.scan(body, 1, xs) + return res + + out = f(np_inp) + self.assertArraysEqual(out, np_inp + 1) + + @compute_on('device_host') + 
@jax.jit + def body2(carry, x): + out_tpu = x + carry + return carry, out_tpu + + @jax.jit + def f2(xs): + _, res = jax.lax.scan(body2, 1, xs) + return res + + out2 = f2(np_inp) + self.assertArraysEqual(out2, np_inp + 1) + + def test_pure_host_data_and_compute(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host') np_inp = np.arange(16).reshape(8, 2) - s = NamedSharding(mesh, P("x", "y"), memory_kind=inp_mem_kind) - inp = jax.device_put(np_inp, s) + arr_host = jax.device_put(np_inp, s) + @compute_on('device_host') @jax.jit + def g(x): + return x * x + + @functools.partial(jax.jit, out_shardings=s) def f(x): - y = jax.device_put(x, TransferToMemoryKind(out_mem_kind)) - return y + return g(x) - # committed sharded input. - out = f(inp) - self.assertTrue(out._committed) - self._check_device_put_addressable_shards( - out, np_inp, s.with_memory_kind(out_mem_kind), out_mem_kind) + out = f(arr_host) + self.assertEqual(out.sharding, s) + self.assertEqual(out.sharding.memory_kind, 'pinned_host') + self.assertArraysEqual(out, np_inp * np_inp) - s1 = SingleDeviceSharding(jax.devices()[1], memory_kind=inp_mem_kind) - committed_single_device_inp = jax.device_put(np_inp, s1) - out2 = f(committed_single_device_inp) - self.assertTrue(out2._committed) - self._check_device_put_addressable_shards( - out2, np_inp, s1.with_memory_kind(out_mem_kind), out_mem_kind) + def test_eager_compute(self): + inp = jnp.arange(8.) + with compute_on('device_host'): + out = inp * 2 + out = jnp.sin(out) + self.assertArraysAllClose(out, jnp.sin(inp * 2)) + + def test_compute_per_annotation(self): + mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + s = NamedSharding(mesh, P("x", "y")) + np_inp = np.arange(16.).reshape(8, 2) + arr = jax.device_put(np_inp, s) @jax.jit - def g(x): - y = jax.device_put(x, TransferToMemoryKind(out_mem_kind)) - return y + @compute_on('device_host') + def f(x): + return jnp.sin(x * 2) - # Uncommitted input but output will be committed because of device_put. 
- out3 = g(np_inp) - self.assertTrue(out3._committed) - self._check_device_put_addressable_shards( - out3, np_inp, - SingleDeviceSharding(jax.devices()[0], memory_kind=out_mem_kind), - out_mem_kind) + # sharded input + out = f(arr) + self.assertArraysAllClose(out, np.sin(np_inp * 2)) - @functools.partial(jax.jit, in_shardings=s) - def h(x): - y = jax.device_put(x, TransferToMemoryKind(out_mem_kind)) - return y + out2 = f(np_inp) + self.assertArraysAllClose(out2, np.sin(np_inp * 2)) - out4 = h(np_inp) - self.assertTrue(out4._committed) - self._check_device_put_addressable_shards( - out4, np_inp, s.with_memory_kind(out_mem_kind), out_mem_kind) + def test_jit_host_multi_outputs(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") + _, s, np_inp, inp = _create_inputs((8, 2), P("x")) - def test_error_transfer_to_memory_kind_outside_jit(self): - with self.assertRaisesRegex( - ValueError, - "TransferToMemoryKind argument to jax.device_put can only be used" - " inside jax.jit"): - jax.device_put(np.arange(16), TransferToMemoryKind("device")) + @jax.jit + def f(x, y): + x, y = jnp.sin(x), jnp.cos(y) + x = jax.device_put(x, s.with_memory_kind("pinned_host")) + y = jax.device_put(y, s.with_memory_kind("device")) + return x, y - def test_single_mem_kind_donation_default_mem_kind(self): - mesh = jtu.create_global_mesh((2,), "x") - s = NamedSharding(mesh, P()) + out1, out2 = f(inp, inp) - @functools.partial(jax.jit, out_shardings=s, donate_argnums=0) - def f(inp1): - return inp1 * 2 + self.assertArraysAllClose(out1, np.sin(np_inp)) + self.assertArraysAllClose(out2, np.cos(np_inp)) + self.assertEqual(out1.sharding, s.with_memory_kind("pinned_host")) + self.assertEqual(out2.sharding, s.with_memory_kind("device")) - x = jax.device_put(np.arange(16).reshape(8, 2), s) + def test_jit_out_shardings_single_output(self): + mesh, _, _, inp = _create_inputs((8, 2), P("x", "y")) + out_s = NamedSharding(mesh, P(), memory_kind="pinned_host") - f(x) + @functools.partial(jax.jit, out_shardings=out_s) + def g(x): + return jnp.sum(x * 2) - lowered_text = f.lower(x).as_text("hlo") - self.assertIn("input_output_alias", lowered_text) - self.assertDeleted(x) + out = g(inp) + self.assertEqual(out.sharding, out_s) + executable_mk = get_memory_kinds_from_executable(g, [inp]) + self._check_mem_kind(executable_mk[0], out.sharding, "pinned_host") - def test_host_offload_in_custom_vjp(self): - if xb.using_pjrt_c_api(): - raise unittest.SkipTest("GetOutputShardings not supported in PJRT C API") - @jax.custom_vjp + @jax.jit + def h(x): + x = jnp.sum(x * 2) + out = jax.device_put(x, out_s) + return out + + out = h(inp) + self.assertEqual(out.sharding, out_s) + executable_mk = get_memory_kinds_from_executable(h, [inp]) + self._check_mem_kind(executable_mk[0], out.sharding, "pinned_host") + + def test_jit_in_shardings(self): + _, s, np_inp, inp = _create_inputs((8, 2), P("x", "y")) + + @functools.partial(jax.jit, in_shardings=s.with_memory_kind("pinned_host")) def f(x): - return jnp.sin(x) + return x * 2 - def f_fwd(x): - y = x * 2 - z = jax.device_put(y, TransferToMemoryKind('unpinned_host')) - return y, (x, z) + with self.assertRaisesRegex( + ValueError, + "Memory kinds passed to jax.jit does not match memory kind on the" + " respective arg. 
Got pjit memory kind: pinned_host, arg memory kind:" + " device for arg shape.*"): + f(jnp.arange(16).reshape(8, 2)) # uncommitted inp also raises error - def f_bwd(res, tx): - x, z = res - y = x * 2 - z2 = jax.device_put(y, TransferToMemoryKind('unpinned_host')) - return ((z == z2).astype(jnp.float32),) + with self.assertRaisesRegex( + ValueError, + "Memory kinds passed to jax.jit does not match memory kind on the" + " respective arg. Got pjit memory kind: pinned_host, arg memory kind:" + " device for arg shape.*"): + f(inp) # committed inp raises error. - f.defvjp(f_fwd, f_bwd) - g = jax.jit(jax.grad(lambda x: f(x).sum())) + @functools.partial(jax.jit, in_shardings=s.with_memory_kind("device")) + def g(x): + return x * 2 - x = jnp.ones(3) * 4 - all_true = jnp.ones(3) - self.assertArraysEqual(g(x), all_true) + out = g(inp) + executable_kind = get_memory_kinds_from_executable(g, [inp]) + self.assertArraysEqual(out, np_inp * 2) + self._check_mem_kind(executable_kind[0], out.sharding, "device") - def test_host_offload_in_custom_vjp_sharded(self): - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) - s = NamedSharding(mesh, P('x')) + def test_jit_in_out_shardings(self): + mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"), mem_kind="device") + out_s = NamedSharding(mesh, P(), memory_kind="device") - @jax.custom_vjp + @functools.partial(jax.jit, in_shardings=s, out_shardings=out_s) def f(x): - return jnp.sin(x) + return jnp.sum(x) - def f_fwd(x): - y = x * 2 - z = jax.device_put(y, s.with_memory_kind('unpinned_host')) - return y, (x, z) + out = f(inp) + executable_kind = get_memory_kinds_from_executable(f, [inp]) + self.assertArraysEqual(out, np.sum(np_inp)) + self._check_mem_kind(executable_kind[0], out.sharding, "device") - def f_bwd(res, tx): - x, z = res - y = x * 2 - z2 = jax.device_put(y, s.with_memory_kind('unpinned_host')) - return ((z == z2).astype(jnp.float32),) + @functools.partial( + jax.jit, + in_shardings=s, + out_shardings=out_s.with_memory_kind("pinned_host"), + ) + def g(x): + return jnp.sum(x) - f.defvjp(f_fwd, f_bwd) - g = jax.jit(jax.grad(lambda x: f(x).sum())) + out = g(inp) + executable_kind = get_memory_kinds_from_executable(g, [inp]) + self.assertArraysEqual(out, np.sum(np_inp)) + self._check_mem_kind(executable_kind[0], out.sharding, "pinned_host") - x = jax.device_put(jnp.ones(4) * 4, s) - all_true = jnp.ones(4) - self.assertArraysEqual(g(x), all_true) + def test_device_put_different_devices(self): + _, _, _, inp = _create_inputs((8, 2), P("x", "y")) + @jax.jit + def f(x): + return jax.device_put( + x, SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")) -class DevicePutTest(jtu.JaxTestCase): + with self.assertRaisesRegex( + ValueError, "Received incompatible devices for jitted computation"): + f(inp) - def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Memories do not work on CPU and GPU backends yet.") - super().setUp() - self.orig_memories_flag = config.enable_memories.value - jax.config.update('jax_enable_memories', True) + def test_jit_cpp_cache_hit(self): + mesh, _, np_inp, inp = _create_inputs((8, 2), P("x", "y")) + inp2 = jax.device_put( + np_inp, NamedSharding(mesh, P("x", "y"), memory_kind="device")) - def tearDown(self): - jax.config.update('jax_enable_memories', self.orig_memories_flag) - super().tearDown() + f = jax.jit(lambda x: x @ x.T) - def _check_device_put_addressable_shards( - self, out, inp, expected_sharding, expected_mem_kind, index=True): - self.assertArraysEqual(out, inp) - 
self.assertEqual(out.sharding, expected_sharding) - self.assertEqual(out.sharding.memory_kind, expected_mem_kind) - for s in out.addressable_shards: - if index: - self.assertArraysEqual(s.data, inp[s.index]) - else: - self.assertArraysEqual(s.data, inp) - self.assertEqual(s.data.sharding.memory_kind, expected_mem_kind) + with jtu.count_pjit_cpp_cache_miss() as count: + out = f(inp) + out2 = f(inp2) + self.assertEqual(count[0], 1) - def test_device_put_host_to_hbm(self): - mesh = jtu.create_global_mesh((4, 2), ("x", "y")) - s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host") - np_inp = np.arange(16).reshape(8, 2) + self.assertArraysEqual(out, np_inp @ np_inp.T) + self.assertArraysEqual(out2, np_inp @ np_inp.T) - out_on_host = jax.device_put(np_inp, s_host) - self.assertEqual(out_on_host.sharding, s_host) + def test_jit_compilation_cache_hit(self): + mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y")) + inp2 = jax.device_put( + np_inp, GSPMDSharding(tuple(mesh.devices.flat), + s._to_xla_hlo_sharding(inp.ndim), + memory_kind="device") + ) - s_hbm = s_host.with_memory_kind("device") - out_on_hbm = jax.device_put(out_on_host, s_hbm) - self._check_device_put_addressable_shards( - out_on_hbm, np_inp, s_hbm, "device") + f = jax.jit(lambda x: x @ x.T) - def test_device_put_hbm_to_host(self): - mesh = jtu.create_global_mesh((4, 2), ("x", "y")) - s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host") - inp = jnp.arange(16).reshape(8, 2) + with (jtu.count_pjit_cpp_cache_miss() as cpp_count, + jtu.count_jit_and_pmap_compiles() as compile_count): + f(inp) + f(inp2) + self.assertEqual(cpp_count[0], 2) + self.assertEqual(compile_count[0], 1) - out_on_host = jax.device_put(inp, s_host) - self._check_device_put_addressable_shards( - out_on_host, inp, s_host, "unpinned_host") + def test_jit_cpp_cache_output_hit(self): + _, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="device") - sharded_inp = jax.device_put(inp, s_host.with_memory_kind("device")) - sharded_out_on_host = jax.device_put(sharded_inp, s_host) - self._check_device_put_addressable_shards( - sharded_out_on_host, sharded_inp, s_host, "unpinned_host") + @jax.jit + def mul_two(x): + return x * 2 - def test_device_put_different_device_and_memory_host_to_hbm(self): - if jax.device_count() < 3: - raise unittest.SkipTest("Test requires >=3 devices") + with jtu.count_pjit_cpp_cache_miss() as count: + out = mul_two(inp) + mul_two(out) + self.assertEqual(count[0], 1) - out_host0 = jax.device_put( - jnp.arange(8), - SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host")) + def test_jit_cache_hit_with_default_and_specified_mem_kind(self): + _, s, np_inp, _ = _create_inputs((8, 2), P("x", "y")) + _, s2, np_inp2, _ = _create_inputs((8, 2), P("x", "y"), mem_kind="device") - dev2 = jax.devices()[2] - out_hbm1 = jax.device_put( - out_host0, SingleDeviceSharding(dev2, memory_kind="device")) - self.assertEqual(out_hbm1.sharding.memory_kind, "device") - self.assertEqual(out_hbm1.sharding._device, dev2) - self.assertEqual(out_hbm1.addressable_shards[0].data.sharding._device, dev2) - self.assertEqual( - out_hbm1.addressable_shards[0].data.sharding.memory_kind, "device") + def mul(x): + return x @ x.T - def test_device_put_different_device_and_memory_hbm_to_host(self): - if jax.device_count() < 3: - raise unittest.SkipTest("Test requires >=3 devices") + f = jax.jit(mul, in_shardings=s) + g = jax.jit(mul, in_shardings=s2) - out_hbm0 = jnp.arange(8) + with jtu.count_jit_and_pmap_compiles() as count: + out = f(np_inp) + 
out2 = g(np_inp2) + self.assertEqual(count[0], 1) - dev2 = jax.devices()[2] - out_host1 = jax.device_put( - out_hbm0, SingleDeviceSharding(dev2, memory_kind="unpinned_host")) - self.assertEqual(out_host1.sharding.memory_kind, "unpinned_host") - self.assertEqual(out_host1.sharding._device, dev2) - self.assertEqual(out_host1.addressable_shards[0].data.sharding._device, - dev2) - self.assertEqual( - out_host1.addressable_shards[0].data.sharding.memory_kind, - "unpinned_host") + self.assertArraysEqual(out, np_inp @ np_inp.T) + self.assertArraysEqual(out2, np_inp2 @ np_inp2.T) - def test_device_put_on_different_device_with_the_same_memory_kind(self): - if len(jax.devices()) < 2: - raise unittest.SkipTest("Test requires >=2 devices.") + def test_sharding_devices_indices_map_cache_hit(self): + mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + shape = (8, 2) + s1 = NamedSharding(mesh, P("x", "y")) + s2 = NamedSharding(mesh, P("x", "y"), memory_kind="device") + + s1.devices_indices_map(shape) + cache_info1 = common_devices_indices_map.cache_info() + s2.devices_indices_map(shape) + cache_info2 = common_devices_indices_map.cache_info() + self.assertEqual(cache_info2.hits, cache_info1.hits + 1) + self.assertEqual(cache_info2.misses, cache_info1.misses) + def test_no_donation_across_memory_kinds(self): + if xb.using_pjrt_c_api(): + raise unittest.SkipTest("GetOutputShardings not supported in PJRT C API") + mesh = jtu.create_global_mesh((2, 1), ("x", "y")) np_inp = np.arange(16).reshape(8, 2) + s_hbm = NamedSharding(mesh, P("x")) + s_host = s_hbm.with_memory_kind("pinned_host") + inp = jax.device_put(np_inp, s_hbm) - s_hbm_dev_0 = SingleDeviceSharding(jax.devices()[0], memory_kind="device") - s_hbm_dev_1 = SingleDeviceSharding(jax.devices()[1], memory_kind="device") - inp_hbm_dev0 = jax.device_put(np_inp, s_hbm_dev_0) - out_hbm_dev_1 = jax.device_put(inp_hbm_dev0, s_hbm_dev_1) - self._check_device_put_addressable_shards( - out_hbm_dev_1, np_inp, s_hbm_dev_1, "device") + @functools.partial(jax.jit, out_shardings=s_host, donate_argnums=0) + def f(x): + return x * 2 - inp_host_dev0 = jax.device_put( - np_inp, s_hbm_dev_0.with_memory_kind("unpinned_host")) - s_host_dev_1 = s_hbm_dev_1.with_memory_kind("unpinned_host") - out_host_dev_1 = jax.device_put(inp_host_dev0, s_host_dev_1) - self._check_device_put_addressable_shards( - out_host_dev_1, np_inp, s_host_dev_1, "unpinned_host") + with self.assertWarnsRegex( + UserWarning, "Some donated buffers were not usable"): + f(inp) - # TODO(yashkatariya): Enable this once we can compute on host. 
- # def test_device_put_resharding(self): - # mesh = jtu.create_global_mesh((2, 2), ("x", "y")) - # s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host") - # s_hbm = s_host.with_memory_kind("device") - # np_inp = np.arange(16).reshape(8, 2) + lowered_text = f.lower(inp).as_text("hlo") + self.assertNotIn("input_output_alias", lowered_text) + self.assertNotDeleted(inp) - # # Reshard single device array on HBM to multi device on host - # sds_inp_hbm = jax.device_put( - # jnp.arange(16).reshape(8, 2), - # SingleDeviceSharding(jax.devices()[0], memory_kind="device")) - # # device_put on host - # out_sharded_host = jax.device_put(sds_inp_hbm, s_host) - # self._check_device_put_addressable_shards( - # out_sharded_host, np_inp, s_host, "unpinned_host") + def test_single_mem_kind_donation_default_mem_kind(self): + mesh = jtu.create_global_mesh((2,), "x") + s = NamedSharding(mesh, P()) - # # Reshard single device array on host to multi device on hbm - # sds_inp_host = jax.device_put( - # jnp.arange(16).reshape(8, 2), - # SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host")) - # # device_put on hbm - # out_sharded_hbm = jax.device_put(sds_inp_host, s_hbm) - # self._check_device_put_addressable_shards( - # out_sharded_hbm, np_inp, s_hbm, "device") + @functools.partial(jax.jit, out_shardings=s, donate_argnums=0) + def f(inp1): + return inp1 * 2 - def test_device_put_numpy_array(self): - mesh = jtu.create_global_mesh((4, 2), ("x", "y")) - np_inp = np.arange(16).reshape(8, 2) - s_hbm = NamedSharding(mesh, P(("x", "y")), memory_kind="device") - s_host = s_hbm.with_memory_kind("unpinned_host") + x = jax.device_put(np.arange(16).reshape(8, 2), s) - out_hbm = jax.device_put(np_inp, s_hbm) - self._check_device_put_addressable_shards(out_hbm, np_inp, s_hbm, "device") + f(x) - out_host = jax.device_put(np_inp, s_host) - self._check_device_put_addressable_shards( - out_host, np_inp, s_host, "unpinned_host") + lowered_text = f.lower(x).as_text("hlo") + self.assertIn("input_output_alias", lowered_text) + self.assertDeleted(x) - def test_device_put_numpy_scalar(self): - np_inp = np.float32(8) - s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device") - s_host = s_hbm.with_memory_kind("unpinned_host") + def test_compute_offload_inside_shmap(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, s) - out_hbm = jax.device_put(np_inp, s_hbm) - self._check_device_put_addressable_shards(out_hbm, np_inp, s_hbm, "device") + @compute_on('device_host') + @jax.jit + def g(x): + return x * 2 - out_host = jax.device_put(np_inp, s_host) - self._check_device_put_addressable_shards( - out_host, np_inp, s_host, "unpinned_host") + def f(x): + x = x * 3 + y = g(x) + return y * 4 - def test_device_put_python_scalar(self): - py_scalar = float(8) - s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device") - s_host = s_hbm.with_memory_kind("unpinned_host") + out = jax.jit(shard_map(f, mesh=mesh, in_specs=P('x', 'y'), + out_specs=P('x', 'y')))(arr) + self.assertArraysEqual(out, np_inp * 24) - out_hbm = jax.device_put(py_scalar, s_hbm) - self._check_device_put_addressable_shards( - out_hbm, py_scalar, s_hbm, "device", index=False) + def test_qr_decomposition_offload(self): + shape = (3, 3) + dtype = np.float32 + operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape) - out_host = jax.device_put(py_scalar, s_host) - self._check_device_put_addressable_shards( - 
out_host, py_scalar, s_host, "unpinned_host", index=False) + @compute_on("device_host") + @jax.jit + def g(x): + return lax.linalg.qr(x, full_matrices=True) - def test_device_put_python_int(self): - py_inp = 8 - s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device") - s_host = s_hbm.with_memory_kind("unpinned_host") + @jax.jit + def f(x): + x, _ = lax.linalg.qr(x, full_matrices=True) + x, _ = g(x) + return x - out_hbm = jax.device_put(py_inp, s_hbm) - self._check_device_put_addressable_shards( - out_hbm, py_inp, s_hbm, "device", index=False) + out = f(operand) # doesn't crash + lowered_text = f.lower(operand).as_text() + self.assertIn('@lapack_sgeqrf', lowered_text) + self.assertIn('@Qr', lowered_text) - out_host = jax.device_put(py_inp, s_host) - self._check_device_put_addressable_shards( - out_host, py_inp, s_host, "unpinned_host", index=False) + @jax.jit + def h(x): + x, _ = lax.linalg.qr(x, full_matrices=True) + x, _ = lax.linalg.qr(x, full_matrices=True) + return x + + expected_out = h(operand) + self.assertArraysAllClose(out, expected_out, rtol=1e-3) + +@jtu.with_config(jax_enable_memories=True) class ActivationOffloadingTest(jtu.JaxTestCase): def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Memories do not work on CPU and GPU backends yet.") + if not jtu.test_device_matches(["tpu", "gpu"]): + self.skipTest("Memories do not work on CPU backend.") super().setUp() def test_remat_jaxpr_offloadable(self): @@ -1145,7 +1286,7 @@ def f(x): compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: - if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43): + if jtu.pjrt_c_api_version_at_least(0, 43): self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def test_remat_scan_jaxpr_offloadable(self): @@ -1202,10 +1343,12 @@ def g(ys, _): compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: - if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43): + if jtu.pjrt_c_api_version_at_least(0, 43): self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def test_remat_scan_layout_change_offloadable(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Remat scan does not work on GPU backend.") mesh = jtu.create_global_mesh((2,), ("x",)) shape = (256, 128) np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) @@ -1242,7 +1385,7 @@ def g(ys, _): compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: - if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43): + if jtu.pjrt_c_api_version_at_least(0, 43): self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def test_remat_checkpoint_dots_with_no_batch_dims(self): @@ -1274,7 +1417,7 @@ def f(x): compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: - if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43): + if jtu.pjrt_c_api_version_at_least(0, 43): self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) if __name__ == "__main__": diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 0ed5ce4c521c..1c8893f36070 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -18,14 +18,24 @@ from collections.abc import Sequence import dataclasses -import numpy as np - from absl import logging from absl.testing import absltest from absl.testing import parameterized -from jax.experimental import mesh_utils -from jax.sharding import Mesh +from jax._src import mesh as mesh_lib 
from jax._src import test_util +from jax._src.sharding_impls import NamedSharding, PartitionSpec, local_to_global_shape +from jax.experimental import mesh_utils +from jax.sharding import Mesh # pylint: disable=g-importing-member +import numpy as np + +# pyformat: disable + + +@dataclasses.dataclass(frozen=True) +class MockClient: + """Mock client for testing, everything is done as process index 0.""" + def process_index(self) -> int: + return 0 @dataclasses.dataclass(frozen=True) @@ -38,6 +48,7 @@ class MockTpuDevice: coords: Sequence[int] core_on_chip: int slice_index: int = 0 + client: MockClient = dataclasses.field(default_factory=MockClient) def mock_tpu_devices(x, y, z, dev_kind, one_device_per_chip, num_slices=1, @@ -183,15 +194,173 @@ def test_get_physical_tpu_mesh(self, xyz, reorder): ('4x8x8', mock_4x8x8_devices, [1, 32, 8], [(), (0, 2), (1,)]), ('8x8x8', mock_8x8x8_devices, [1, 64, 8], [(), (1, 2), (0,)]), ('8x8x16', mock_8x8x16_devices, [1, 64, 16], [(), (0, 1), (2,)]), - ('8x8', mock_8x8_devices, [8, 8], [(1,), (0, 2)]) + ('8x8', mock_8x8_devices, [8, 8], [(1,), (0, 2)]), ) - def test_create_device_mesh_for_nd_torus(self, devices, mesh_shape, - expected_assignment): + def test_create_device_mesh_for_nd_torus( + self, devices, mesh_shape, expected_assignment + ): jax_devices = devices(True) physical_mesh = mesh_utils._get_physical_tpu_mesh(jax_devices) _, assignment = mesh_utils._create_device_mesh_for_nd_torus( - physical_mesh, mesh_shape) - self.assertEqual(assignment, expected_assignment) + physical_mesh, mesh_shape + ) + + # The expected assignment is specified as a list, where each element is a + # sequence of physical axis assigned. We convert this into assignment + # matrix. + expected_assignment_matrix = np.ones( + [physical_mesh.ndim, len(mesh_shape)], dtype=np.int64 + ) + for logical_axis, axis_assignment in enumerate(expected_assignment): + for physical_axis in axis_assignment: + expected_assignment_matrix[physical_axis, logical_axis] = ( + physical_mesh.shape[physical_axis] + ) + self.assertArraysEqual(assignment, expected_assignment_matrix) + + @parameterized.named_parameters( + ('2x2x1', mock_2x2x1_devices,), + ('2x2x4', mock_2x2x4_devices, ), + ('4x4x4', mock_4x4x4_devices,), + ('4x4x8', mock_4x4x8_devices,), + ('4x8x8', mock_4x8x8_devices, ), + ('8x8', mock_8x8_devices), + ) + def test_create_device_mesh_has_computable_global_shape(self, devices): + def factorize(n, max_factors=3): + if max_factors == 1 or n == 1: + yield (n, ) * max_factors + return + for i in range(2, n+1): + if n % i == 0: + for remaining in factorize(n // i, max_factors=max_factors - 1): + yield (i, *remaining) + jax_devices = devices(True) + for mesh_shape in factorize(len(jax_devices), max_factors=3): + mesh = mesh_utils.create_device_mesh(mesh_shape, devices=jax_devices, + allow_split_physical_axes=True) + mesh = mesh_lib.Mesh(mesh, ('a', 'b', 'c')) + sharding = NamedSharding(mesh, PartitionSpec('a', 'b', 'c')) + computed_global_shape = local_to_global_shape(sharding, (1, 1, 1)) + self.assertFalse( + np.any([x is None for x in computed_global_shape]), + f'{mesh_shape=}, {computed_global_shape=} is not uniform') + + sharding = NamedSharding(mesh, PartitionSpec(('a', 'c',), 'b')) + computed_global_shape = local_to_global_shape(sharding, (1, 1, 1)) + self.assertFalse( + np.any([x is None for x in computed_global_shape]), + f'{mesh_shape=}, {computed_global_shape=} is not uniform') + + + @parameterized.named_parameters( + ('2x2x1', mock_2x2x1_devices, [1, 1, 4], [(), (), (0, 1, 2)]), + 
('2x2x4', mock_2x2x4_devices, [1, 4, 4], [(), (2,), (0, 1)]), + ('4x4x4', mock_4x4x4_devices, [1, 16, 4], [(), (1, 2), (0,)]), + ('4x4x8a', mock_4x4x8_devices, [1, 16, 8], [(), (0, 1), (2,)]), + ('4x4x8b', mock_4x4x8_devices, [1, 8, 16], [(), (2,), (0, 1)]), + ('4x4x8c', mock_4x4x8_devices, [16, 8, 1], [(0, 1), (2,), ()]), + ('4x8x8', mock_4x8x8_devices, [1, 32, 8], [(), (0, 2), (1,)]), + ('8x8x8', mock_8x8x8_devices, [1, 64, 8], [(), (1, 2), (0,)]), + ('8x8x16', mock_8x8x16_devices, [1, 64, 16], [(), (0, 1), (2,)]), + ('8x8', mock_8x8_devices, [8, 8], [(1,), (0, 2)]), + ) + def test_create_device_mesh_for_nd_torus_split_axes_backward_compatible( + self, devices, mesh_shape, expected_assignment + ): + jax_devices = devices(True) + physical_mesh = mesh_utils._get_physical_tpu_mesh(jax_devices) + _, assignment = mesh_utils._create_device_mesh_for_nd_torus_splitting_axes( + physical_mesh, mesh_shape + ) + + # The expected assignment is specified as a list, where each element is a + # sequence of physical axis assigned. We convert this into assignment + # matrix. + expected_assignment_matrix = np.ones( + [physical_mesh.ndim, len(mesh_shape)], dtype=np.int64 + ) + for logical_axis, axis_assignment in enumerate(expected_assignment): + for physical_axis in axis_assignment: + expected_assignment_matrix[physical_axis, logical_axis] = ( + physical_mesh.shape[physical_axis] + ) + self.assertArraysEqual(assignment, expected_assignment_matrix) + + @parameterized.named_parameters( + ( + '4x4x4a', + mock_4x4x4_devices, + [2, 1, 32], + [ + [1, 1, 4], + [1, 1, 4], + [2, 1, 2], + ], + ), + ( + '4x4x4b', + mock_4x4x4_devices, + [8, 8, 1], + [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1], + ], + ), + ( + '4x4x8a', + mock_4x4x8_devices, + [2, 2, 8, 4], + [ + [1, 1, 1, 4], + [2, 2, 1, 1], + [1, 1, 8, 1], + ], + ), + ( + '4x4x8b', + mock_4x4x8_devices, + [2, 4, 1, 16], + [ + [1, 1, 1, 4], + [1, 1, 1, 4], + [2, 4, 1, 1], + ], + ), + ( + '4x8x8', + mock_4x8x8_devices, + [1, 128, 2], + [ + [1, 2, 2], + [1, 8, 1], + [1, 8, 1], + ], + ), + ( + '8x8', + mock_8x8_devices, + [2, 1, 32, 1], + [ + [1, 1, 8, 1], + [2, 1, 4, 1], + [1, 1, 1, 1], + ], + ), + ) + def test_create_device_mesh_for_nd_torus_split_axes_can_handle_axes_split( + self, devices, mesh_shape, assignment_matrix + ): + jax_devices = devices(True) + physical_mesh = mesh_utils._get_physical_tpu_mesh(jax_devices) + logical_mesh, assignment = mesh_utils._create_device_mesh_for_nd_torus( + physical_mesh, mesh_shape, allow_split_physical_axes=True + ) + self.assertEqual(logical_mesh.shape, tuple(mesh_shape)) + self.assertArraysEqual( + assignment, np.array(assignment_matrix, dtype=np.int64) + ) @parameterized.named_parameters( ('2X4x4x4a', (1, 16, 4), (2, 1, 1)), @@ -324,7 +493,160 @@ def test_create_contiguous_submeshes_errors(self): "(1, 128, 2) and physical mesh topology (4, 8, 8). " 'Available mesh_shapes: [(64, 4), (4, 64)]'): mesh_utils.create_device_mesh( - mesh_shape, devices=devices, contiguous_submeshes=True) + mesh_shape, devices=devices, contiguous_submeshes=True + ) + + +def int64_array(x) -> np.ndarray: + return np.array(x, dtype=np.int64) + + +def get_int_mesh(shape: Sequence[int]) -> np.ndarray: + return np.arange(np.prod(shape), dtype=np.int64).reshape(shape) + + +class SplitAxesDeviceMeshCreationTest(test_util.JaxTestCase): + + def test_get_prime_factors(self): + self.assertEqual(mesh_utils._get_prime_factors(1), []) # 1 has no factor. 
+ self.assertEqual(mesh_utils._get_prime_factors(2), [2]) + self.assertEqual(mesh_utils._get_prime_factors(4), [2, 2]) + self.assertEqual(mesh_utils._get_prime_factors(8), [2, 2, 2]) + self.assertEqual(mesh_utils._get_prime_factors(6), [2, 3]) + self.assertEqual(mesh_utils._get_prime_factors(16), [2, 2, 2, 2]) + self.assertEqual(mesh_utils._get_prime_factors(12), [2, 2, 3]) + self.assertEqual(mesh_utils._get_prime_factors(121), [11, 11]) # square + self.assertEqual(mesh_utils._get_prime_factors(43), [43]) # prime + + @parameterized.named_parameters( + ( + '2x2x1', + [2, 2, 1], + [1, 2, 1], + 4, + [], # infeasible + ), + ( + '12x4x4', + [12, 4, 4], + [2, 2, 1], + 6, + [[6, 1, 1], [3, 2, 1], [3, 1, 2]], + ), + ( + '4x4x8', + [4, 4, 8], + [2, 2, 2], + 4, + [[2, 2, 1], [2, 1, 2], [1, 2, 2], [1, 1, 4]], + ), + ) + def test_enumerate_feasible_axis_assignments( + self, + physical_mesh_shape, + assigned_physical_mesh_shape, + logical_axis_size, + expected_assignments, + ): + assignment = int64_array([list(assigned_physical_mesh_shape)]).T + self.assertArraysEqual( + list( + mesh_utils._enumerate_feasible_logical_axis_assignments( + physical_mesh_shape, + assignment, + logical_axis_size=logical_axis_size, + ) + ), + [int64_array(a) for a in expected_assignments], + ) + + @parameterized.named_parameters( + ( + '2x2x1', + [2, 2, 1], + [1, 2, 2, 1], + [ + [1, 2, 1, 1], + [1, 1, 2, 1], + [1, 1, 1, 1], + ], + ), + ( + '4x4x4', + [4, 4, 4], + [2, 1, 32], + [ + [1, 1, 4], + [2, 1, 2], + [1, 1, 4], + ], + ), + ( + '12x4x8', + [12, 4, 8], + [2, 8, 24], + [ + [2, 2, 3], + [1, 2, 4], + [1, 2, 2], + ], + ), + ) + def test_generate_logical_mesh( + self, + physical_mesh_shape, + logical_mesh_shape, + assignment, + ): + assignment = np.array(assignment, dtype=np.int64) + physical_mesh = get_int_mesh(physical_mesh_shape) + logical_mesh = mesh_utils._generate_logical_mesh( + physical_mesh, logical_mesh_shape, assignment + ) + self.assertEqual(logical_mesh.shape, tuple(logical_mesh_shape)) + # We check that the logical mesh is assigned correctly using the following + # consistency check, which transforms the logical mesh back to physical + # mesh. + transpose = ( + np.arange(assignment.size).reshape(assignment.shape).T.reshape([-1]) + ) + self.assertArraysEqual( + physical_mesh.reshape([-1]), + logical_mesh.reshape(np.reshape(assignment.T, [-1])) + .transpose(transpose) + .reshape([-1]), + ) + + def test_prefer_assignment_whole_axis_size(self): + self.assertTrue( + mesh_utils._prefer_first_logical_axis_assignment( + int64_array([1, 2, 1]), + int64_array([1, 1, 2]), + physical_mesh_shape=[2, 2, 4], + assignment=int64_array([[1, 1, 1]]).T, + ) + ) + + def test_prefer_assignment_more_whole_axes(self): + # This entails the original implementation already. 
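+    # Here the candidate assignment spanning two whole physical axes
+    # (sizes 4 and 4) should be preferred over the candidate spanning the
+    # single size-16 axis.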
+ self.assertTrue( + mesh_utils._prefer_first_logical_axis_assignment( + int64_array([4, 4, 1]), + int64_array([1, 1, 16]), + physical_mesh_shape=[4, 4, 16], + assignment=int64_array([[1, 1, 1]]).T, + ) + ) + + def test_prefer_assignment_avoid_already_assigned(self): + self.assertTrue( + mesh_utils._prefer_first_logical_axis_assignment( + int64_array([2, 1]), + int64_array([1, 2]), + physical_mesh_shape=[2, 4], + assignment=int64_array([[1, 2]]).T, + ) + ) if __name__ == '__main__': diff --git a/tests/metadata_test.py b/tests/metadata_test.py index 3511595d6c5a..c0e406bf6696 100644 --- a/tests/metadata_test.py +++ b/tests/metadata_test.py @@ -23,8 +23,7 @@ from jax._src.lib.mlir import ir from jax import numpy as jnp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def module_to_string(module: ir.Module) -> str: @@ -82,7 +81,6 @@ def false_fun(x): def f(which, x): return jax.lax.cond(which, x, true_fun, x, false_fun) hlo = module_to_string(jax.jit(f).lower(True, 1.).compiler_ir()) - self.assertRegex(hlo, r'loc\(".*cond\[linear=\(False, False\)\]"') self.assertRegex(hlo, r'loc\(".*cond/branch_0_fun/cos"') self.assertRegex(hlo, r'loc\(".*cond/branch_1_fun/sin"') diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py index b955f0398e0c..fd1fe560ff4b 100644 --- a/tests/mock_gpu_test.py +++ b/tests/mock_gpu_test.py @@ -17,31 +17,21 @@ from absl.testing import absltest import jax -from jax import config from jax._src import test_util as jtu import jax.numpy as jnp from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() +NUM_SHARDS = 16 +@jtu.with_config(use_mock_gpu_client=True, mock_num_gpus=NUM_SHARDS) class MockGPUTest(jtu.JaxTestCase): - def setUp(self): - super().setUp() - jax.config.update('use_mock_gpu_client', True) - - def tearDown(self): - jax.config.update('use_mock_gpu_client', False) - jax.config.update('mock_num_gpus', 1) - super().tearDown() - def testMockWithSharding(self): - num_shards = 16 - jax.config.update('mock_num_gpus', num_shards) - mesh = jtu.create_global_mesh((num_shards,), ('x',)) + mesh = jtu.create_global_mesh((NUM_SHARDS,), ('x',)) @partial( jax.jit, in_shardings=NamedSharding(mesh, P('x',)), diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD new file mode 100644 index 000000000000..d182c99be7b1 --- /dev/null +++ b/tests/mosaic/BUILD @@ -0,0 +1,89 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load( + "//jaxlib:jax.bzl", + "jax_generate_backend_suites", + "jax_test", + "py_deps", +) + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//visibility:private"], +) + +jax_generate_backend_suites() + +DISABLED_BACKENDS = [ + "cpu", + "tpu", +] + +DISABLED_CONFIGS = [ + "gpu", + "gpu_a100", + "gpu_p100", + "gpu_p100_x32", + "gpu_x32", + "gpu_pjrt_c_api", +] + +jax_test( + name = "gpu_test", + srcs = ["gpu_test.py"], + disable_backends = DISABLED_BACKENDS, + disable_configs = DISABLED_CONFIGS, + shard_count = 4, + deps = [ + "//jax:mosaic_gpu", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +jax_test( + name = "matmul_test", + srcs = ["matmul_test.py"], + disable_backends = DISABLED_BACKENDS, + disable_configs = DISABLED_CONFIGS, + shard_count = 16, + deps = [ + "//jax:mosaic_gpu", + "//jax/experimental/mosaic/gpu/examples:matmul", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +jax_test( + name = "flash_attention", + srcs = ["//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py"], + disable_backends = DISABLED_BACKENDS, + disable_configs = DISABLED_CONFIGS, + main = "//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py", + tags = ["notap"], + deps = [ + "//jax:mosaic_gpu", + ] + py_deps("numpy"), +) + +jax_test( + name = "flash_attention_test", + srcs = ["flash_attention_test.py"], + disable_backends = DISABLED_BACKENDS, + disable_configs = DISABLED_CONFIGS, + deps = [ + "//jax:mosaic_gpu", + "//jax/experimental/mosaic/gpu/examples:flash_attention", + ] + py_deps("absl/testing"), +) diff --git a/tests/mosaic/flash_attention_test.py b/tests/mosaic/flash_attention_test.py new file mode 100644 index 000000000000..1d15159ca44e --- /dev/null +++ b/tests/mosaic/flash_attention_test.py @@ -0,0 +1,78 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test different parameterizations of FlashAttention.""" + +import os + +from absl.testing import absltest, parameterized +from jax._src import config +from jax._src import test_util as jtu + +# pylint: disable=g-import-not-at-top +try: + # We only import this to see if Mosaic is available. 
+ import jax.experimental.mosaic.gpu # noqa: F401 +except ImportError: + flash_attention = None +else: + from jax.experimental.mosaic.gpu.examples import flash_attention + + +config.parse_flags_with_absl() +os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0") + + +@jtu.with_config(jax_traceback_filtering="off") +class FlashAttentionTestCase(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if flash_attention is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + + @parameterized.product( + batch_size=(1,), + q_seq_len=(4096,), + kv_seq_len=(4096,), + num_q_and_kv_heads=((4, 1), # MQA + (6, 3), # GQA + (4, 4),), # MHA + head_dim=(64, 128, 256), + # Provide a default value for exp_impl if 'flash_attention' is not + # available. Bypasses test failures when Mosaic is not available. + exp_impl=[*(flash_attention.ExpImplementation + if flash_attention is not None else (NotImplementedError,))], + ) + def test_flash_attention(self, batch_size, q_seq_len, kv_seq_len, + num_q_and_kv_heads, head_dim, exp_impl): + num_q_heads, num_kv_heads = num_q_and_kv_heads + flash_attention.benchmark_and_verify( + batch_size=batch_size, + q_seq_len=q_seq_len, + kv_seq_len=kv_seq_len, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + exp_impl=exp_impl, + blocks=flash_attention.BlockSizes(stages=2, q=64, kv=64) + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py new file mode 100644 index 000000000000..48a9d0b47d2b --- /dev/null +++ b/tests/mosaic/gpu_test.py @@ -0,0 +1,976 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for Mosaic GPU DSL functions and utilities.""" + +from functools import partial +import operator + +from absl.testing import absltest, parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import scf +from jax._src.lib.mlir.dialects import vector +import jax.numpy as jnp +import numpy as np +try: + import jax._src.lib.mosaic_gpu # noqa: F401 + HAS_MOSAIC_GPU = True +except ImportError: + HAS_MOSAIC_GPU = False +else: + from jax.experimental.mosaic import gpu as mosaic_gpu + from jax.experimental.mosaic.gpu import dsl as mgpu + from jax.experimental.mosaic.gpu import profiler + from jax.experimental.mosaic.gpu.utils import * # noqa: F403 + from jax._src.lib.mlir.dialects import gpu + from jax._src.lib.mlir.dialects import llvm + + +# ruff: noqa: F405 +config.parse_flags_with_absl() + +def nd_loop(bounds, body, *, _idxs = ()): + if not bounds: + body(*_idxs) + return + bound, *other_bounds = bounds + @fori(bound, ()) + def _loop_body(i, _): + nd_loop(other_bounds, body, _idxs=(*_idxs, i)) + return () + + +def mlir_sum(elems): + assert elems + total = elems[0] + for elem in elems[1:]: + total = arith.addi(total, elem) + return total + + +def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None): + index = ir.IndexType.get() + thread_id = gpu.thread_id(gpu.Dimension.x) + stride = gpu.block_dim(gpu.Dimension.x) + for dim in (gpu.Dimension.y, gpu.Dimension.z): + thread_id = arith.addi(thread_id, arith.muli(gpu.thread_id(dim), stride)) + stride = arith.muli(stride, gpu.block_dim(dim)) + is_first_thread = arith.cmpi(arith.CmpIPredicate.eq, thread_id, c(0, index)) + src_ty = ir.MemRefType(src.type) + dst_ty = ir.MemRefType(dst.type) + if src_ty.shape != dst_ty.shape: + raise ValueError( + f"src and dst shapes don't match: {src_ty.shape} != {dst_ty.shape}" + ) + shape = src_ty.shape + dyn_strides = [c(s, index) for s in get_contiguous_strides(shape)] + with ir.InsertionPoint(scf.IfOp(is_first_thread).then_block): + def body(*idx): + dst_idx = idx + if swizzle is not None: + if swizzle != 128: + raise NotImplementedError("Only swizzle 128B implemented") + # TODO(apaszke): This can probably be cleaned up. + # But it works and it's test-only, so it doesn't matter much. + # After all, swizzle should just be an xor of row and linear idx, + # adjusted for the bytewidth. 
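+          # Concretely, for float32 inputs (as in test_copy_swizzle below):
+          # bytes_per_element is 4, so a 1024-byte tile holds 256 elements
+          # laid out as 8 rows of 32, each row split into eight 16-byte
+          # groups of 4 elements; an element's destination group is its
+          # source group XOR-ed with its row index.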
+ bytes_per_element = bytewidth(src_ty.element_type) + elems_per_tile = 1024 // bytes_per_element + elems_per_row = elems_per_tile // 8 + elems_per_group = 16 // bytes_per_element + linear_idx = c(0, index) + for stride, i in zip(dyn_strides, idx): + linear_idx = arith.addi(linear_idx, arith.muli(i, stride)) + tile_offset = arith.remui(linear_idx, c(elems_per_tile, index)) + linear_tile_start = arith.subi(linear_idx, tile_offset) + row = arith.divui(tile_offset, c(elems_per_row, index)) + row_offset = arith.remui(tile_offset, c(elems_per_row, index)) + src_group = arith.divui(row_offset, c(elems_per_group, index)) + group_offset = arith.remui(row_offset, c(elems_per_group, index)) + dst_group = arith.xori(src_group, row) + dst_linear_idx = mlir_sum([ + linear_tile_start, + arith.muli(row, c(elems_per_row, index)), + arith.muli(dst_group, c(elems_per_group, index)), + group_offset, + ]) + dst_idx = [ + arith.remui(arith.divui(dst_linear_idx, stride), c(bound, index)) + for stride, bound in zip(dyn_strides, shape) + ] + memref.store(memref.load(src, idx), dst, dst_idx) + nd_loop([c(d, index) for d in shape], body) + scf.yield_([]) + gpu.barrier() + nvvm.fence_proxy(nvvm.ProxyKind.async_) + + +def iota_tensor(m, n, mlir_dtype): + assert m % 64 == 0 + assert n % 8 == 0 + def c(i): + return arith.constant(index, ir.IntegerAttr.get(index, i)) + index = ir.IndexType.get() + i32 = ir.IntegerType.get_signless(32) + warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32)) + within_warp_id = arith.remui(gpu.thread_id(gpu.Dimension.x), c(32)) + warp_row_start = arith.muli(warp_id, c(16)) + within_warp_row = arith.divui(within_warp_id, c(4)) + start_row = arith.addi(warp_row_start, within_warp_row) + start_col = arith.muli(arith.remui(within_warp_id, c(4)), c(2)) + registers = np.empty((m // 64, n // 8, 2, 1), dtype=object) + for row_tile, col_tile, row_subtile, _ in np.ndindex(registers.shape): + row = arith.addi(start_row, c(row_tile * 64 + row_subtile * 8)) + col = arith.addi(start_col, c(col_tile * 8)) + row_value_base = arith.muli(row, c(n)) + vec = llvm.mlir_undef(ir.VectorType.get((2,), i32)) + for col_offset in range(2): + value = arith.addi(row_value_base, arith.addi(c(col_offset), col)) + value = arith.index_cast(i32, value) + vec = vector.insertelement(value, vec, position=c(col_offset)) + registers[row_tile, col_tile, row_subtile, 0] = vec + t = mgpu.FragmentedArray(_registers=registers, _layout=mgpu.WGMMA_LAYOUT) + return t.astype(mlir_dtype) + + +class TestCase(parameterized.TestCase): + + def setUp(self): + if not HAS_MOSAIC_GPU: + self.skipTest("jaxlib built without Mosaic GPU") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + super().setUp() + self.prng = np.random.default_rng(1234) + self.enter_context(jtu.global_config_context(jax_traceback_filtering="off")) + self.enter_context(mlir.make_ir_context()) + self.enter_context(ir.Location.unknown()) + + +class TestUtilTest(TestCase): + + def test_copy_basic(self): + def kernel(ctx, src, dst, _): + copy(src, dst) + x = jnp.arange(2 * 3 * 5).reshape(2, 5, 3) + y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) + np.testing.assert_array_equal(y, x) + + def test_copy_swizzle(self): + def kernel(ctx, src, dst, _): + copy(src, dst, swizzle=128) + x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32) + y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) + expected = 
np.zeros_like(y) + for i in range(8): + for j in range(8): + js = j ^ i + expected[i, (j * 4):(j * 4) + 4] = x[i, (js * 4):(js * 4) + 4] + np.testing.assert_array_equal(y, expected) + + def test_copy_swizzle_noop(self): + # Two swizzles cancel out + def kernel(ctx, src, dst, smem): + copy(src, smem, swizzle=128) + copy(smem, dst, swizzle=128) + x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32) + y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + np.testing.assert_array_equal(y, x) + + def test_iota_tensor(self): + m = n = 64 + def kernel(ctx, dst, _): + f32 = ir.F32Type.get() + index = ir.IndexType.get() + registers = iota_tensor(m, n, f32).registers + assert registers.size == 16, registers.size + for i, vec_reg in enumerate(registers.flat): + for j in range(2): + reg = vector.extractelement(vec_reg, position=c(j, index)) + memref.store( + reg, dst, [gpu.thread_id(gpu.Dimension.x), c(2 * i + j, index)] + ) + out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) + regs = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + thread_ids = np.arange(128) + warp_ids = thread_ids // 32 + lane_ids = thread_ids % 32 + thread_rows = warp_ids * 16 + lane_ids // 4 + thread_start_cols = (lane_ids % 4) * 2 + thread_cols = thread_start_cols[:, None] + (np.arange(n // 8)[None] * 8) + regs = regs.reshape(128, 8, 2, 2) + for row_half in range(2): + for col_half in range(2): + np.testing.assert_array_equal( + regs[..., row_half, col_half], + (thread_rows[:, None] + row_half * 8) * n + thread_cols + col_half + ) + + +class MemRefTest(TestCase): + @parameterized.product( + dim=tuple(range(3)), + strided=(False, True) + ) + def test_unsqueeze(self, dim, strided): + def kernel(ctx, inp, out, _): + if strided: + for i in range(8): + s = ds(i, 1) + out_slice = s if dim != 0 else (slice(None), s) + copy( + memref_unsqueeze(memref_slice(inp, s), dim), + memref_slice(out, out_slice), + ) + else: + copy(memref_unsqueeze(inp, dim), out) + x = np.arange(8 * 16, dtype=jnp.float32).reshape(8, 16) + out_shape = list(x.shape) + out_shape.insert(dim, 1) + out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) + y = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () + )(x) + np.testing.assert_array_equal(y, x.reshape(out_shape)) + + @parameterized.product( + dim=tuple(range(2)), + strided=(False, True) + ) + def test_unfold(self, dim, strided): + in_shape = (8, 16) + def kernel(ctx, inp, out, _): + if strided: + # We slice the dim we don't unfold + for i in range(in_shape[1 - dim] // 4): + s = ds(i * 4, 4) + in_slice = s if dim == 1 else (slice(None), s) + out_slice = s if dim == 1 else (slice(None),) * 3 + (s,) + copy( + memref_unfold(memref_slice(inp, in_slice), dim, (2, 2, None)), + memref_slice(out, out_slice), + ) + else: + copy(memref_unfold(inp, dim, (2, 2, None)), out) + x = np.arange(np.prod(in_shape), dtype=jnp.float32).reshape(in_shape) + out_shape = list(in_shape) + out_shape[dim:dim + 1] = [2, 2, out_shape[dim] // 4] + out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) + y = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () + )(x) + np.testing.assert_array_equal(y, x.reshape(out_ty.shape)) + + @parameterized.product( + dim=tuple(range(2)), + ) + def test_fold_not_strided(self, dim): + def kernel(ctx, inp, out, _): + copy(memref_fold(inp, dim, 2), out) + + x = np.arange(8 * 2 * 8, dtype=jnp.float32).reshape(8, 2, 8) + out_ty = jax.ShapeDtypeStruct((16, 8) if dim == 0 else (8, 16), jnp.float32) + y = 
mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () + )(x) + np.testing.assert_array_equal(y, x.reshape(out_ty.shape)) + + @parameterized.named_parameters([ + ("packed", (4, 4, 4), (16, 4, 1), 1, 2, False), + ("strided_end", (4, 4, 4, 4), (256, 64, 16, 4), 1, 2, False), + ("strided_bot", (4, 4, 4, 4), (256, 16, 4, 1), 1, 2, False), + ("strided_top", (4, 4, 4, 4), (256, 64, 4, 1), 1, 2, True), + ("strided_mid", (4, 4, 4, 4), (265, 64, 16, 1), 1, 3, True), + ("overap", (2, 4, 4), (16, 1, 1), 0, 3, True), + ]) + def test_fold_strided( + self, shape, strides, dim, fold_rank, throws_not_impl + ): + expanded_shape = get_packed_shape(strides, shape) + total_size = np.prod(expanded_shape) + np_inp = np.arange(total_size, dtype=jnp.float32).reshape(expanded_shape) + index = tuple([slice(0, s) for s in shape]) + + # Reference implementation + def np_fold(inp, dim, fold_rank): + out_shape = list(inp.shape) + out_shape[dim : dim + fold_rank] = [ + int(np.prod(inp.shape[dim : dim + fold_rank])) + ] + if throws_not_impl: + return jax.ShapeDtypeStruct(shape=out_shape, dtype=inp.dtype) + else: + return inp.reshape(*out_shape) + + total_size = np.prod(shape) * np.prod(strides) + + def do_test(): + def kernel(ctx, inp, out, _): + copy(memref_fold(memref_slice(inp, index), dim, fold_rank), out) + + out = np_fold(np_inp[index], dim, fold_rank) + y = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), np_inp, out, () + )(np_inp) + assert ( + not throws_not_impl + ), "If it should have thrown it would during the call." + np.testing.assert_array_equal(y, out) + + if throws_not_impl: + with self.assertRaises(NotImplementedError): + do_test() + else: + do_test() + + +def get_packed_shape(strides, shape): + perm = sorted(range(len(strides)), key=lambda i: strides[i], reverse=True) + ordered_strides = [strides[i] for i in perm] + ordered_shape = [shape[i] for i in perm] + packed_shape = [ordered_shape[-1]] + packed_shape += [ + stride0 // stride + for stride0, stride in zip(ordered_strides, ordered_strides[1:]) + ] + # Invert permutation + inv_perm = [None] * len(perm) + for i, p in enumerate(perm): + inv_perm[p] = i + return [packed_shape[i] for i in inv_perm] + + +class WGMMATest(TestCase): + + @parameterized.named_parameters( + ("f32", ir.F32Type, jnp.float32), ("f16", ir.F16Type, jnp.float16) + ) + def test_store_untiled(self, mlir_dtype_cls, jax_dtype): + mlir_dtype = mlir_dtype_cls.get() + def kernel(ctx, out, _): + del ctx + iota_tensor(64, 64, mlir_dtype).store_untiled(out) + expected = np.arange(64 * 64, dtype=jax_dtype).reshape(64, 64) + iota = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), expected, () + )() + np.testing.assert_array_equal(iota, expected) + + @parameterized.named_parameters( + ("f32", ir.F32Type.get, jnp.float32), + ("f16", ir.F16Type.get, jnp.float16), + ("i8", partial(ir.IntegerType.get_signless, 8), jnp.int8), + ) + def test_store_tiled(self, mlir_dtype_cls, jax_dtype): + mlir_dtype = mlir_dtype_cls() + m = 128 + n = 256 + tiling = (64, 128 // bytewidth(mlir_dtype)) + def kernel(ctx, out, smem): + del ctx + iota_tensor(m, n, mlir_dtype).store_tiled(smem, swizzle=128) + copy(smem, out, swizzle=128) + expected = ( + np.arange(m * n, dtype=jax_dtype) + .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) + .transpose(0, 2, 1, 3) + ) + iota = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), expected, expected + )() + np.testing.assert_array_equal(iota, expected) + + @parameterized.named_parameters( + ("bf16_i8", + 
ir.BF16Type.get, jnp.bfloat16, + lambda: ir.IntegerType.get_signless(8), jnp.int8), + ("i8_bf16", + lambda: ir.IntegerType.get_signless(8), jnp.int8, + ir.BF16Type.get, jnp.bfloat16), + ("i8_i8", + lambda: ir.IntegerType.get_signless(8), jnp.int8, + lambda: ir.IntegerType.get_signless(8), jnp.int8), + ) + def test_convert_tiled(self, + mlir_dtype_cls_from, jax_dtype_from, + mlir_dtype_cls_to, jax_dtype_to): + mlir_dtype_from = mlir_dtype_cls_from() + mlir_dtype_to = mlir_dtype_cls_to() + m = 128 + n = 256 // bytewidth(mlir_dtype_from) + def kernel(ctx, inp, out, smem): + del ctx + smem_from, smem_to = smem + copy(inp, smem_from, swizzle=128) + t = mgpu.FragmentedArray.load_tiled(smem_from, swizzle=128) + t = t.astype(mlir_dtype_to) + t.store_tiled(smem_to, swizzle=128) + copy(smem_to, out, swizzle=128) + + from_tiling = (64, 128 // bytewidth(mlir_dtype_from)) + to_tiling = (64, 128 // bytewidth(mlir_dtype_to)) + expected_raw = self.prng.integers( + low=-127, high=127, size=(m, n), dtype=np.int8 + ) + expected = lambda jax_dtype, tiling: expected_raw.reshape( + m // tiling[0], tiling[0], n // tiling[1], tiling[1] + ).transpose(0, 2, 1, 3).astype(jax_dtype) + + expected_from = expected(jax_dtype_from, from_tiling) + expected_to = expected(jax_dtype_to, to_tiling) + res = mosaic_gpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + expected_from, + expected_to, + (expected_from, expected_to), + )(expected_from) + np.testing.assert_array_equal(res, expected_to) + + @parameterized.named_parameters( + ("f32", ir.F32Type.get, jnp.float32), + ("f16", ir.F16Type.get, jnp.float16), + ("i8", partial(ir.IntegerType.get_signless, 8), jnp.int8), + ) + def test_load_tiled(self, mlir_dtype_cls, jax_dtype): + mlir_dtype = mlir_dtype_cls() + m = 128 + n = 256 // bytewidth(mlir_dtype) + tiling = (64, 128 // bytewidth(mlir_dtype)) + def kernel(ctx, in_, out, smem): + del ctx + smem1, smem2 = smem + copy(in_, smem1, swizzle=128) + t = mgpu.FragmentedArray.load_tiled(smem1, swizzle=128) + t.store_tiled(smem2, swizzle=128) + copy(smem2, out, swizzle=128) + expected = ( + np.arange(m * n, dtype=jax_dtype) + .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) + .transpose(0, 2, 1, 3) + ) + iota = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), expected, expected, (expected,) * 2 + )(expected) + np.testing.assert_array_equal(iota, expected) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + in_mlir_dtype_cls=(ir.F16Type, ir.BF16Type, ir.F32Type), + m=(64, 128, 192), + n=(64, 128, 192), + k_steps=(1, 2), + tma_inputs=(False, True), + swizzle=(32, 64, 128), + jax_out_dtype=(jnp.float16, jnp.float32), + ) + def test_wgmma_basic( + self, + m, + n, + k_steps, + in_mlir_dtype_cls, + lhs_transpose, + rhs_transpose, + tma_inputs, + swizzle, + jax_out_dtype, + ): + if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls is not ir.F16Type: + raise self.skipTest("Only f16 input is supported for f16 output.") + if swizzle != 128 and lhs_transpose: + raise self.skipTest("Transpose only supported in 128B swizzled WGMMA") + if swizzle != 128 and not tma_inputs: + raise self.skipTest("Copy with non-128B swizzles not implemented") + + in_mlir_dtype = in_mlir_dtype_cls.get() + out_mlir_dtype = mlir.dtype_to_ir_type(jnp.dtype(jax_out_dtype)) + if ir.F32Type.isinstance(in_mlir_dtype): # We actually use tf32 instead + in_jax_dtype = jnp.float32 + if lhs_transpose or not rhs_transpose: + self.skipTest("Transpose only supported in 16-bit WGMMA") + exponent_bits, 
mantissa_bits = 8, 10 # Use tf32 + elif bytewidth(in_mlir_dtype) == 2: + if n % 64 != 0: + self.skipTest("16-bit WGMMA only supports n % 64 == 0") + if ir.F16Type.isinstance(in_mlir_dtype): + in_jax_dtype = jnp.float16 + exponent_bits, mantissa_bits = 5, 10 + elif ir.BF16Type.isinstance(in_mlir_dtype): + in_jax_dtype = jnp.bfloat16 + exponent_bits, mantissa_bits = 8, 7 + else: + raise NotImplementedError(in_mlir_dtype) + else: + raise NotImplementedError(in_mlir_dtype) + nk_tile = swizzle // bytewidth(in_mlir_dtype) + k = nk_tile * k_steps + assert m % 64 == 0 and n % nk_tile == 0 + index = ir.IndexType.get() + + row_major = mgpu.WGMMALayout.ROW_MAJOR + col_major = mgpu.WGMMALayout.COL_MAJOR + lhs_order = col_major if lhs_transpose else row_major + rhs_order = col_major if rhs_transpose else row_major + + def kernel(ctx, lhs, rhs, out, scratch): + lhs_smem, rhs_smem = scratch + if tma_inputs: + lhs_transform = (mosaic_gpu.TileTransform((64, nk_tile)),) + if lhs_transpose: + assert nk_tile == 64 # Make sure we didn't have to transpose tiling. + lhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + rhs_transform = (mosaic_gpu.TileTransform((nk_tile, nk_tile)),) + if rhs_transpose: + rhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + barriers = BarrierArray(2) + ctx.async_copy( + src_ref=lhs, + dst_ref=lhs_smem, + swizzle=swizzle, + gmem_transform=lhs_transform, + barrier=barriers[0], + ) + ctx.async_copy( + src_ref=rhs, + dst_ref=rhs_smem, + swizzle=swizzle, + gmem_transform=rhs_transform, + barrier=barriers[1], + ) + for i in range(2): + barriers[i].wait() + else: + for mi in range(m // 64): + for ki in range(k // nk_tile): + lhs_slice = ( + ds(c(mi * 64, index), 64), + ds(c(ki * nk_tile, index), nk_tile), + ) + if lhs_transpose: + lhs_slice = lhs_slice[::-1] + copy( + src=memref_slice(lhs, lhs_slice), + dst=memref_slice(lhs_smem, (mi, ki)), + swizzle=swizzle, + ) + for ki in range(k // nk_tile): + k_slice = ds(c(ki * nk_tile, index), nk_tile) + for ni in range(n // nk_tile): + rhs_slice = (k_slice, ds(c(ni * nk_tile, index), nk_tile)) + if rhs_transpose: + rhs_slice = rhs_slice[::-1] + copy( + src=memref_slice(rhs, rhs_slice), + dst=memref_slice(rhs_smem, (ki, ni)), + swizzle=swizzle, + ) + init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype) + acc = mgpu.wgmma( + init_acc, lhs_smem, rhs_smem, + a_order=lhs_order, b_order=rhs_order, swizzle=swizzle, + ) + nvvm.wgmma_commit_group_sync_aligned() + nvvm.wgmma_wait_group_sync_aligned(0) + acc.value.store_untiled(out) + + def quantize(x): + # Quantize the input to avoid rounding when feeding the WGMMA + return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) + + x_shape = (k, m) if lhs_transpose else (m, k) + x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) + y_shape = (n, k) if rhs_transpose else (k, n) + y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), jax_out_dtype) + scratch_shape = [ + jax.ShapeDtypeStruct((m // 64, k // nk_tile, 64, nk_tile), in_jax_dtype), + jax.ShapeDtypeStruct( + (k // nk_tile, n // nk_tile, nk_tile, nk_tile), in_jax_dtype + ), + ] + z = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape + )(x, y) + x32, y32 = x.astype(np.float32), y.astype(np.float32) + ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) + atol = 2e-2 if jax_out_dtype == jnp.float16 else 5e-6 + np.testing.assert_allclose(z, ref, atol=atol) + + # 
TODO(apaszke): Add support for f32 + @parameterized.product( + m=(64, 128, 192), + n=(64, 128, 192), + k_steps=(1, 2), + rhs_transpose=(False, True), + mlir_dtype_cls=(ir.F16Type, ir.BF16Type), + ) + def test_wgmma_reg_lhs(self, m, n, k_steps, rhs_transpose, mlir_dtype_cls): + k = 64 * k_steps + index = ir.IndexType.get() + + row_major = mgpu.WGMMALayout.ROW_MAJOR + col_major = mgpu.WGMMALayout.COL_MAJOR + rhs_order = col_major if rhs_transpose else row_major + + def kernel(ctx, rhs, out, rhs_smem): + del ctx + for ki in range(k_steps): + for ni in range(n // 64): + rhs_slice = (ds(c(ki * 64, index), 64), ds(c(ni * 64, index), 64)) + if rhs_transpose: + rhs_slice = rhs_slice[::-1] + copy( + src=memref_slice(rhs, rhs_slice), + dst=memref_slice(rhs_smem, (ki, ni)), + swizzle=128, + ) + init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n) + lhs_regs = iota_tensor(m, k, mlir_dtype_cls.get()) + acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order) + nvvm.wgmma_commit_group_sync_aligned() + nvvm.wgmma_wait_group_sync_aligned(0) + acc.value.store_untiled(out) + + jax_dtype = jnp.float16 if mlir_dtype_cls == ir.F16Type else jnp.bfloat16 + y_shape = (n, k) if rhs_transpose else (k, n) + y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + scratch_shape = jax.ShapeDtypeStruct( + (k_steps, n // 64, 64, 64), jax_dtype + ) + z = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape + )(y) + x = np.arange(m * k, dtype=jax_dtype).reshape(m, k) + ref = jax.lax.dot( + x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32 + ) + rtol = 0 if k_steps == 1 else 2.2e-4 + np.testing.assert_allclose(z, ref, rtol=rtol, atol=0) + + +class BarrierTest(TestCase): + + def test_wg_communication(self): + i32 = ir.IntegerType.get_signless(32) + def kernel(ctx, dst, tmp): + del ctx # Unused. + barriers = BarrierArray(3, arrival_count=128) + gpu.barrier() # Make sure the barriers are initialized. + wg_idx = arith.divui(mgpu.warp_idx(), c(4, i32)) + is_first_wg = arith.cmpi(arith.CmpIPredicate.eq, wg_idx, c(0, i32)) + is_second_wg = arith.cmpi(arith.CmpIPredicate.eq, wg_idx, c(1, i32)) + arr = mgpu.FragmentedArray.splat( + arith.addi(wg_idx, c(1, i32)), + (128,), + mgpu.WGStridedFragLayout((128,), 1), + ) + with ir.InsertionPoint(scf.IfOp(is_first_wg).then_block): + arr.store_untiled(tmp) + barriers[0].arrive() # Signal that tmp is ready. + barriers[1].wait() # Wait for the other warp to produce tmp. + final_arr = arr + mgpu.FragmentedArray.load_strided(tmp) + final_arr.store_untiled(memref_slice(dst, 0)) + scf.yield_([]) + with ir.InsertionPoint(scf.IfOp(is_second_wg).then_block): + barriers[0].wait() + final_arr = arr + mgpu.FragmentedArray.load_strided(tmp) + barriers[2].arrive() + barriers[2].wait() # Synchronize this warpgroup before we overwrite tmp. + arr.store_untiled(tmp) + barriers[1].arrive() # Signal that tmp is ready. 
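+      # Editor's sketch of the handshake traced as comments (assumption-free, it
+      # just restates the code above): each warpgroup has now added the other
+      # warpgroup's splat (1 or 2) to its own, so both rows of dst end up filled
+      # with 1 + 2 == 3, matching the np.full_like(y, 3) assertion below.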
+ final_arr.store_untiled(memref_slice(dst, 1)) + scf.yield_([]) + out_shape = jax.ShapeDtypeStruct((2, 128), jnp.int32) + y = mosaic_gpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (2 * 128, 1, 1), + (), + out_shape, + jax.ShapeDtypeStruct((128,), jnp.int32), + )() + np.testing.assert_array_equal(y, np.full_like(y, 3, dtype=np.int32)) + +class TMATest(TestCase): + + @parameterized.product( + swizzle=(None, 128), + shape=((64, 64), (5, 64), (2, 3, 5, 64)), + dtype=(jnp.float16, jnp.float32), + ) + def test_tma_load_basic(self, swizzle, shape, dtype): + if dtype == jnp.float32: + shape = (*shape[:-1], shape[-1] // 2) + i1 = ir.IntegerType.get_signless(1) + def kernel(ctx, src, dst, tmp): + barrier = BarrierArray(1)[0] + ctx.async_copy(src_ref=src, dst_ref=tmp, swizzle=swizzle, barrier=barrier) + barrier.wait_parity(c(0, i1)) + copy(tmp, dst, swizzle=swizzle) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + np.testing.assert_array_equal(y, x) + + @parameterized.product( + swizzle=(None, 128), + shape=((128, 128), (5, 32, 128)), + dtype=(jnp.float16, jnp.float32), + ) + def test_tma_load_tiled(self, swizzle, shape, dtype): + i1 = ir.IntegerType.get_signless(1) + index = ir.IndexType.get() + tiling = (32, 128 // jnp.dtype(dtype).itemsize) + tiled_shape = tile_shape(shape, tiling)[:len(shape)] + def kernel(ctx, src, dst, tmp): + barrier = BarrierArray(1)[0] + ctx.async_copy( + src_ref=src, + dst_ref=tmp, + swizzle=swizzle, + barrier=barrier, + gmem_transform=mosaic_gpu.TileTransform(tiling), + ) + barrier.wait_parity(c(0, i1)) + for idxs in np.ndindex(tiled_shape): + untiled_idxs, tiled_idxs = idxs[:-len(tiling)], idxs[-len(tiling):] + s = ( + *untiled_idxs, + *(ds(c(ix * t, index), t) for ix, t in zip(tiled_idxs, tiling)), + ) + copy(memref_slice(tmp, idxs), memref_slice(dst, s), swizzle=swizzle) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + smem = jax.ShapeDtypeStruct(tile_shape(shape, tiling), dtype) + f = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem) + y = f(x) + np.testing.assert_array_equal(y, x) + + @parameterized.product( + swizzle=(None, 128), + dtype=(jnp.float16, jnp.float32), + ) + def test_tma_squeeze_indexing(self, swizzle, dtype): + shape = (4, 5, 64) + if dtype == jnp.float32: + shape = (*shape[:-1], shape[-1] // 2) + def kernel(ctx, src, dst, tmp): + barrier = BarrierArray(1)[0] + for i in range(4): + ctx.async_copy( + src_ref=src, + dst_ref=memref_slice(tmp, i), + gmem_slice=i, + swizzle=swizzle, + barrier=barrier, + ) + barrier.wait() + copy(tmp, dst, swizzle=swizzle) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + np.testing.assert_array_equal(y, x) + + def test_parity_tracking(self): + shape = (16, 64) + index = ir.IndexType.get() + def kernel(ctx, src, dst, tmp): + barrier = BarrierArray(1)[0] + for i in range(shape[0]): + s = ds(c(i, index), 1) + ctx.async_copy( + src_ref=src, dst_ref=tmp, gmem_slice=s, barrier=barrier, + ) + barrier.wait() + copy(tmp, memref_slice(dst, s)) + x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape) + y = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, x[0:1] + )(x) + np.testing.assert_array_equal(y, x) + + @parameterized.product( + swizzle=(None, 128), + shape=((64, 64), (5, 64), (2, 3, 5, 64)), + dtype=(jnp.float16, jnp.float32), + ) + def test_tma_store(self, swizzle, shape, dtype): + if dtype == 
jnp.float32: + shape = (*shape[:-1], shape[-1] // 2) + def kernel(ctx, src, dst, tmp): + copy(src, tmp, swizzle=swizzle) + ctx.async_copy(src_ref=tmp, dst_ref=dst, swizzle=swizzle) + ctx.await_async_copy(0) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + np.testing.assert_array_equal(y, x) + + +class FragmentedArrayTest(TestCase): + + @parameterized.product( + op=( + operator.add, + operator.mul, + operator.sub, + operator.truediv, + (lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum), + ), + m=(64, 128), + n=(8, 16, 32, 64, 80, 128, 256), + ) + def test_binary(self, op, m=64, n=32): + if isinstance(op, tuple): + op, np_op = op + else: + np_op = op + + for scalar_rhs in [None, 2]: + def kernel(ctx, dst, _): + f32 = ir.F32Type.get() + iota = iota_tensor(m=m, n=n, mlir_dtype=f32) + rhs = iota if scalar_rhs is None else c(scalar_rhs, iota.mlir_dtype) + op(iota, rhs).store_untiled(dst) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + result = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + ref_x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + ref_rhs = scalar_rhs or ref_x + if op == operator.truediv: + np.testing.assert_allclose(result, np_op(ref_x, ref_rhs), atol=2e-7) + else: + np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs)) + + @parameterized.product( + ops=((lambda x: mgpu.FragmentedArray.exp(x), np.exp),), + m=(64, 128), + n=(8, 16, 32, 64, 80, 128, 256), + ) + def test_unary(self, ops, m=64, n=32): + op, np_op = ops + def kernel(ctx, dst, _): + f32 = ir.F32Type.get() + iota = iota_tensor(m=m, n=n, mlir_dtype=f32) + op(iota).store_untiled(dst) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + result = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + np.testing.assert_allclose(result, np_op(x), atol=2e-7, rtol=2e-7) + + @parameterized.product( + op=(arith.addf, arith.maximumf), + m=(64, 128), + n=(8, 16, 32, 64, 80, 128, 256), + ) + def test_reduce(self, op, m=64, n=32): + def kernel(ctx, dst, _): + f32 = ir.F32Type.get() + iota = iota_tensor(m=m, n=n, mlir_dtype=f32) + iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + result = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + if op == arith.addf: + expected = np.broadcast_to(x.sum(axis=1, keepdims=True), x.shape) + elif op == arith.maximumf: + expected = np.broadcast_to(x.max(axis=1, keepdims=True), x.shape) + else: + raise NotImplementedError(f"Unsupported op: {op}") + np.testing.assert_array_equal(result, expected) + + def test_splat_layout(self): + m, n = 64, 8 + def kernel(ctx, dst, _): + f32 = ir.F32Type.get() + iota = iota_tensor(m=m, n=n, mlir_dtype=f32) + cte = c(1, iota.mlir_dtype) + cte_arr = mgpu.FragmentedArray.splat(cte, ()) + cte_arr = cte_arr.reshape((1, 1)).broadcast((m, n)) + (iota + cte_arr).store_untiled(dst) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + result = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + expected = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + 1 + np.testing.assert_array_equal(result, expected) + + def test_splat(self): + def kernel(ctx, dst, _): + f32 = ir.F32Type.get() + v = arith.constant(f32, ir.FloatAttr.get(f32, 3.14)) + 
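# The next two lines splat the 3.14 constant into a 128-element row in the
+      # WGMMA row layout and then replicate it across 32 columns, so the kernel
+      # writes a constant (128, 32) tile (checked against np.full below).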
t = mgpu.FragmentedArray.splat(v, (128,), mgpu.WGMMA_ROW_LAYOUT) + t.broadcast_minor(32).store_untiled(dst) + out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) + result = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32)) + + @parameterized.product(in_shape=((128, 128), (128, 64), (64, 128))) + def test_strided_load_store(self, in_shape): + def kernel(ctx, *args): + gmem_input, gmem_output, (smem_input, smem_output) = args + copy(gmem_input, smem_input) + t = mgpu.FragmentedArray.load_strided(smem_input) + t.store_untiled(smem_output) + copy(smem_output, gmem_output) + + inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32) + result = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], + )(inp) + np.testing.assert_array_equal(inp, result) + + def test_warp_tree_reduce(self): + def kernel(ctx, out, *_): + del ctx + i32 = ir.IntegerType.get_signless(32) + tid = gpu.thread_id(gpu.Dimension.x) + value = arith.index_cast(i32, tid) + grp = warp_tree_reduce(value, arith.addi, 4) + memref.store(grp, out, [tid]) + + x = np.arange(128, dtype=jnp.int32) + result = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), x, [], + )() + for i in range(0, 128, 4): + x[i:i + 4] = jnp.sum(x[i:i + 4]) + + np.testing.assert_array_equal(result, x) + + +class ProfilerTest(TestCase): + + def test_measure(self): + x = jnp.arange(1024 * 1024) + profiler.measure(lambda x, y: x + y, x, x) # This is just a smoke test + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py new file mode 100644 index 000000000000..9e6f66b3a72d --- /dev/null +++ b/tests/mosaic/matmul_test.py @@ -0,0 +1,149 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test different parameterizations of a matmul.""" + +import os + +from absl.testing import absltest, parameterized +from jax._src import config +from jax._src import test_util as jtu +import jax.numpy as jnp +try: + # We only import this to see if Mosaic is available. 
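+  # If the import fails (e.g. a build without Mosaic GPU support), `matmul`
+  # stays None and setUp below skips every test in this file.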
+ import jax.experimental.mosaic.gpu # noqa: F401 +except ImportError: + matmul = None +else: + from jax.experimental.mosaic.gpu.examples import matmul + + +config.parse_flags_with_absl() +os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0") + + +@jtu.with_config(jax_traceback_filtering="off") +class MatmulTestCase(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if matmul is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + + @parameterized.product( + m=(128, 256, 512, 2048), + n=(128, 256, 512, 2048), + k=(128, 256, 512, 2048), + stages=(2, 4), + tile_m=(64, 128, 256), + tile_n=(64, 128, 256), + in_dtype=(jnp.float16, jnp.bfloat16), # f32 tested separately + ) + def test_matmul(self, m, k, n, stages, tile_m, tile_n, in_dtype): + if stages * (128 // jnp.dtype(in_dtype).itemsize) > k: + self.skipTest("Too many stages.") + + if m < tile_m: + self.skipTest(f"No use in running a test with {m=} < {tile_m=}.") + + if n < tile_n: + self.skipTest(f"No use in running a test with {n=} < {tile_n=}.") + + try: + matmul.verify( + m, + k, + n, + stages, + tile_m=tile_m, + tile_n=tile_n, + lhs_dtype=in_dtype, + rhs_dtype=in_dtype, + rhs_transpose=True, + ) + except ValueError as e: + if "Mosaic GPU kernel exceeds available shared memory" in str(e): + self.skipTest("Not enough shared memory for test, skipping.") + raise e + + @parameterized.product( + m=(128, 256, 512, 2048), + n=(128, 256, 512, 2048), + k=(128, 256, 512, 2048), + stages=(2, 4), + tile_m=(64, 128, 256), + tile_n=(64, 128, 256), + high_precision=(False, True), + ) + def test_matmul_f32(self, m, k, n, stages, tile_m, tile_n, high_precision): + if stages * (128 // jnp.dtype(jnp.float32).itemsize) > k: + self.skipTest("Too many stages.") + + if m < tile_m: + self.skipTest(f"No use in running a test with {m=} < {tile_m=}.") + + if n < tile_n: + self.skipTest(f"No use in running a test with {n=} < {tile_n=}.") + + try: + matmul.verify( + m, + k, + n, + stages, + tile_m=tile_m, + tile_n=tile_n, + lhs_dtype=jnp.float32, + rhs_dtype=jnp.float32, + rhs_transpose=True, + precision=( + matmul.F32Precision.TF32_X3 + if high_precision + else matmul.F32Precision.DEFAULT + ), + ) + except ValueError as e: + if "Mosaic GPU kernel exceeds available shared memory" in str(e): + self.skipTest("Not enough shared memory for test, skipping.") + raise e + + @parameterized.parameters( + dict(m=55 * 128, n=95 * 128, k=48 * 128, stages=4, tile_m=128), + dict(m=55 * 128, n=45 * 128, k=48 * 128, stages=4, tile_m=128), + dict(m=64, n=95 * 128, k=48 * 128, stages=4, tile_m=64), + dict(m=64, n=45 * 128, k=48 * 128, stages=4, tile_m=64), + ) + def test_mixed_matmul(self, m, k, n, stages, tile_m): + # RHS.element_size==1b so k_tile=128 + if stages * 128 > k: + self.skipTest("Too many stages.") + + matmul.verify( + m, + k, + n, + stages, + tile_m=tile_m, + rhs_transpose=False, + lhs_dtype=jnp.bfloat16, + rhs_dtype=jnp.int8, + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic_test.py b/tests/mosaic_test.py index 518766c1e7d1..03c8f1ce36f0 100644 --- a/tests/mosaic_test.py +++ b/tests/mosaic_test.py @@ -14,9 +14,9 @@ from absl.testing import absltest from jax._src import test_util as jtu -from jax import config +import jax -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class 
ImportTest(jtu.JaxTestCase): diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 0060df9deada..b2731f256566 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import os +import contextlib from unittest import SkipTest import tracemalloc as tm @@ -24,33 +23,17 @@ from jax import lax from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from jax._src import test_util as jtu -from jax._src import xla_bridge - -from jax import config -config.parse_flags_with_absl() - -prev_xla_flags = None +jax.config.parse_flags_with_absl() # Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() + def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - -# Reset to previous configuration in case other test modules will be run. + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() + _exit_stack.close() class MultiDeviceTest(jtu.JaxTestCase): @@ -196,8 +179,10 @@ def f(): return lax.add(3., 4.) self.assertIsInstance(f(), jax.Array) self.assert_uncommitted_to_device(f(), devices[0]) self.assert_uncommitted_to_device(jax.jit(f)(), devices[0]) - self.assert_committed_to_device(jax.jit(f, device=devices[1])(), - devices[1]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + self.assert_committed_to_device(jax.jit(f, device=devices[1])(), + devices[1]) def test_reshape(self): devices = self.get_devices() diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py index 40cbb6630411..4f2e36c64f4b 100644 --- a/tests/multibackend_test.py +++ b/tests/multibackend_test.py @@ -25,8 +25,7 @@ from jax._src import test_util as jtu from jax import numpy as jnp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() npr.seed(0) @@ -35,6 +34,8 @@ class MultiBackendTest(jtu.JaxTestCase): """Tests jit targeting to different backends.""" @jtu.sample_product(backend=['cpu', 'gpu', 'tpu', None]) + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def testMultiBackend(self, backend): if backend not in ('cpu', jtu.device_under_test(), None): raise SkipTest("Backend is not CPU or the device under test") @@ -53,6 +54,8 @@ def fun(x, y): @jtu.sample_product( ordering=[('cpu', None), ('gpu', None), ('tpu', None), (None, None)] ) + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def testMultiBackendNestedJit(self, ordering): outer, inner = ordering if outer not in ('cpu', jtu.device_under_test(), None): @@ -79,6 +82,8 @@ def infun(x, y): (None, 'cpu'), (None, 'gpu'), (None, 'tpu'), ], ) + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def testMultiBackendNestedJitConflict(self, ordering): outer, inner = ordering if outer not in ('cpu', jtu.device_under_test(), 
None): @@ -106,6 +111,8 @@ def infun(x, y): self.assertRaises(ValueError, lambda: fun(x, y)) @jtu.sample_product(backend=['cpu', 'gpu', 'tpu']) + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def testGpuMultiBackendOpByOpReturn(self, backend): if backend not in ('cpu', jtu.device_under_test()): raise SkipTest("Backend is not CPU or the device under test") @@ -120,6 +127,8 @@ def fun(x, y): self.assertEqual(list(w.devices())[0].platform, backend) @jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def testJitCpu(self): @partial(jax.jit, backend='cpu') def get_arr(scale): @@ -136,6 +145,8 @@ def get_arr(scale): self.assertEqual(c.devices(), {jax.devices('cpu')[0]}) @jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def test_closed_over_values_device_placement(self): # see https://github.com/google/jax/issues/1431 def f(): return jnp.add(3., 4.) @@ -145,6 +156,8 @@ def f(): return jnp.add(3., 4.) {jax.devices('cpu')[0]}) @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def test_jit_on_nondefault_backend(self): cpus = jax.devices("cpu") self.assertNotEmpty(cpus) diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index bbe79ecff969..760e340815af 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -26,21 +26,22 @@ import numpy as np import jax -from jax import config from jax._src import core from jax._src import distributed -from jax._src import maps from jax._src import test_util as jtu from jax._src import util from jax.experimental import pjit import jax.numpy as jnp +# Used to test for mpi4py installation and skip tests if not installed +import importlib.util + try: import portpicker except ImportError: portpicker = None -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() @unittest.skipIf(not portpicker, "Test requires portpicker") class DistributedTest(jtu.JaxTestCase): @@ -220,11 +221,52 @@ def test_gpu_ompi_distributed_initialize(self): finally: proc.kill() + def test_gpu_mpi4py_distributed_initialize(self): + if not jtu.test_device_matches(['gpu']): + raise unittest.SkipTest('Tests only for GPU.') + if shutil.which('mpirun') is None: + raise unittest.SkipTest('Tests only for MPI (mpirun not found).') + if importlib.util.find_spec("mpi4py") is None: + raise unittest.SkipTest('Test of mpi4py initialize only possible with mpi4py installed.') + + num_gpus = 4 + num_gpus_per_task = 1 + + with contextlib.ExitStack() as exit_stack: + args = [ + 'mpirun', + '--oversubscribe', + '--allow-run-as-root', + '-n', + str(num_gpus), + sys.executable, + '-c', + ('import jax, os; ' + 'jax.distributed.initialize(spec_detection_method="mpi4py"); ' + 'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")' + ) + ] + env = os.environ.copy() + # In case the job was launched via Slurm, + # prevent OpenMPI from detecting Slurm environment + env.pop('SLURM_JOBID', None) + proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, universal_newlines=True) + proc = exit_stack.enter_context(proc) + + try: + out, _ = proc.communicate() + self.assertEqual(proc.returncode, 0) + 
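# Only process 0 prints the "local_device_count,device_count" pair; the other
+ # ranks print nothing, so the combined stdout should be exactly one such pair.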
self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}') + finally: + proc.kill() + @unittest.skipIf( os.environ.get("SLURM_JOB_NUM_NODES", None) != "2", "Slurm environment with at least two nodes needed!") @jtu.pytest_mark_if_available('SlurmMultiNodeGpuTest') +@jtu.with_config(experimental_xmap_spmd_lowering=True) class SlurmMultiNodeGpuTest(jtu.JaxTestCase): def sorted_devices(self): @@ -257,16 +299,6 @@ def create_2d_non_contiguous_mesh(self): ] == [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15] return jax.sharding.Mesh(device_mesh, ("x", "y")) - def setUp(self): - super().setUp() - self.xmap_spmd_lowering_enabled = maps.SPMD_LOWERING.value - jax.config.update("experimental_xmap_spmd_lowering", True) - - def tearDown(self): - jax.config.update("experimental_xmap_spmd_lowering", - self.xmap_spmd_lowering_enabled) - super().tearDown() - def test_gpu_multi_node_initialize_and_psum(self): # Hookup the ENV vars expected to be set already in the SLURM environment diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py new file mode 100644 index 000000000000..fe7ddd618ee1 --- /dev/null +++ b/tests/mutable_array_test.py @@ -0,0 +1,227 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +import jax +from jax._src import core +from jax._src import config +from jax._src import test_util as jtu +import jax.numpy as jnp + +from jax._src.state.types import (RefEffect) + +config.parse_flags_with_absl() + +class MutableArrayTest(jtu.JaxTestCase): + + @parameterized.parameters([True, False]) + def test_basic(self, jit): + def f(x_mut): + x_mut[...] += 1. + x_mut[0] += 1 + x_mut[1] += 5 + + if jit: + f = jax.jit(f) + + x_mut = core.mutable_array(jnp.zeros(3)) + f(x_mut) + + self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]), + check_dtypes=False) + + jaxpr = jax.make_jaxpr(f)(x_mut) + self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects)) + + # disabling this test for now. TODO(dougalm): re-enable once we add checks to + # ensure mutable arrays aren't returned or duplicated etc. + # def test_staging_error(self): + # x = jnp.zeros(3) + # with self.assertRaises(Exception): + # jax.jit(core.mutable_array)(x) + + @parameterized.parameters([True, False]) + def test_multiple_inputs_and_outputs(self, jit): + def f(x_mut, y, z_mut, w): + x_mut[...] += 1 + z_mut[...] += 1 + return x_mut[...] + y + z_mut[...] 
+ w, y + w + + if jit: + f = jax.jit(f) + + x_mut = core.mutable_array(jnp.zeros((1, 3))) + y = jnp.ones((2, 3)) + z_mut = core.mutable_array(jnp.zeros((2, 3))) + w = jnp.ones((2, 1)) + + out1, out2 = f(x_mut, y, z_mut, w) + + self.assertAllClose(x_mut[...], jnp.ones((1, 3)), check_dtypes=False) + self.assertAllClose(z_mut[...], jnp.ones((2, 3)), check_dtypes=False) + self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False) + self.assertAllClose(out2, y + w, check_dtypes=False) + + @parameterized.parameters([True, False]) + def test_closed_over_basic(self, jit): + x_mut = core.mutable_array(jnp.zeros(3)) + def f(): + x_mut[...] += 1. + x_mut[0] += 1 + x_mut[1] += 5 + + if jit: + f = jax.jit(f) + + f() + + self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]), + check_dtypes=False) + + jaxpr = jax.make_jaxpr(f)() + self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects)) + + @parameterized.parameters([True, False]) + def test_closed_over_nested(self, jit): + x_mut = core.mutable_array(jnp.zeros(3)) + + @jax.jit + def f(y_mut, z): + x_mut[...] += 1. + x_mut[0] += 1 + x_mut[1] += 5 + + y_mut[2] += 7 + return z + 9 + + if jit: + f = jax.jit(f) + + y_mut = core.mutable_array(np.zeros(3)) + + w = f(y_mut, 1) + + self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]), + check_dtypes=False) + self.assertAllClose(y_mut[...], jnp.array([0., 0., 7.]), + check_dtypes=False) + self.assertAllClose(w, 10, check_dtypes=False) + + @parameterized.parameters([True, False]) + def test_internal_mutarray_basic(self, jit): + def f(): + x_mut = core.mutable_array(jnp.zeros(3)) + x_mut[0] += 1 + x_mut[0] += 1 + x_mut[2] += 1 + return x_mut[...] + + if jit: + f = jax.jit(f) + + out = f() + self.assertAllClose(out, jnp.array([2., 0., 1.]), check_dtypes=False) + + @parameterized.parameters([True, False]) + def test_refs_in_vjps(self, jit): + def gradient_history_calculator_fwd(x, ref): + return x, ref + + def gradient_history_calculator_bwd(amax_history, grad_output): + amax_update = jnp.max(jnp.abs(grad_output)) + shifted = jnp.roll(amax_history[:], 1) + shifted = shifted.at[0].set(amax_update) + amax_history[:] = shifted + amax_from_history = jnp.max(amax_history[:]) + grad_output = grad_output / amax_from_history + return grad_output, None + + @jax.custom_vjp + def gradient_history_calculator(x, ref): + return x + + gradient_history_calculator.defvjp( + gradient_history_calculator_fwd, + gradient_history_calculator_bwd) + + class DotOp: + def __init__(self): + self.amax_history = core.mutable_array(jnp.zeros(5,)) + + def forward(self, x, y): + out = jnp.dot(x, y) + out = gradient_history_calculator(out, self.amax_history) + return out + + dot_op = DotOp() + x_top = jnp.ones((5,)) + y_top = jnp.ones((5,)) + + def loss(x, y): + return dot_op.forward(x, y).sum() + + if jit: + loss = jax.jit(loss) + + for i in range(3): + jax.grad(loss, (0,1))(x_top, y_top) + self.assertAllClose(dot_op.amax_history[:], jnp.zeros((5,)).at[:i+1].set(1.0), check_dtypes=False) + + @parameterized.parameters([True, False]) + def test_scan_internal_mut_array(self, jit): + def body_fun(_, x): + x_mut = core.mutable_array(x) + x_mut[...] += 2 + return ((), x_mut[...]) + doit = lambda: jax.lax.scan(body_fun, (), np.arange(5)) + if jit: + doit = jax.jit(doit) + _, xs = doit() + self.assertAllClose(xs, (np.arange(5) + 2), check_dtypes=False) + + @parameterized.parameters([True, False]) + def test_scan_closed_over_mut_array(self, jit): + x_mut = core.mutable_array(0) + def body_fun(_, x): + x_mut[...] 
+= 2 + return ((), x_mut[...]) + + doit = lambda: jax.lax.scan(body_fun, (), np.arange(5)) + if jit: + doit = jax.jit(doit) + _, xs = doit() + self.assertAllClose(x_mut[...], 10) + self.assertAllClose(xs, np.arange(5) * 2 + 2, check_dtypes=False) + + @parameterized.parameters([True, False]) + def test_scan_scanned_mut_array(self, jit): + def body_fun(_, index_x): + (index, x) = index_x + x[...] += index + # breakpoint() + return ((), x[...]) + + x_mut = core.mutable_array(np.arange(5)) + doit = lambda: jax.lax.scan(body_fun, (), (np.arange(5), x_mut)) + if jit: + doit = jax.jit(doit) + _, xs = doit() + self.assertAllClose(xs, (np.arange(5) * 2), check_dtypes=False) + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index e6ac29e7088b..993c729f01d8 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -20,16 +20,15 @@ from jax import lax from jax._src.pjit import pjit from jax._src import linear_util as lu -from jax import config from jax._src import test_util as jtu from jax._src.lib import xla_client from jax._src import ad_checkpoint -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _get_hlo(f): def wrapped(*args, **kwargs): - c = jax.xla_computation(f)(*args, **kwargs) + c = jax.jit(f).lower(*args, **kwargs).compiler_ir('hlo') print_opts = xla_client._xla.HloPrintOptions.short_parsable() print_opts.print_metadata = True return c.as_hlo_module().to_string(print_opts) @@ -216,7 +215,7 @@ def f(x): self.assertIn('transpose(jvp(foo))/mul', hlo_text) def test_grad_should_add_jvp_and_transpose_to_call_jaxpr(self): - @jax.grad + @jax.value_and_grad @jax.named_scope('foo') @jax.jit def f(x): @@ -241,7 +240,7 @@ def f(x): def test_nested_jit_stack(self): - @jax.grad + @jax.value_and_grad @jax.jit def f(x): @jax.jit @@ -255,7 +254,7 @@ def g(y): self.assertIn('transpose(jvp(jit(f)))/jit(g)/mul', hlo_text) def test_nested_pjit_stack(self): - @jax.grad + @jax.value_and_grad @pjit def f(x): @pjit @@ -498,16 +497,19 @@ def false_fn(x): def test_grad_of_cond_transforms_name_stack(self): - @jax.grad + @jax.value_and_grad @jax.named_scope('foo') def f(x, y): @jax.named_scope('true') def true_fn(x): return x * x * 2. 
+ @jax.named_scope('false') def false_fn(x): return x / jnp.square(x) + return lax.cond(y, true_fn, false_fn, x) + jaxpr = jax.make_jaxpr(f)(1., True) self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'jvp(foo)') self.assertEqual(str(jaxpr.eqns[2].source_info.name_stack), @@ -530,7 +532,7 @@ def false_fn(x): def test_vmap_of_grad_of_cond_transforms_name_stack(self): @functools.partial(jax.vmap, in_axes=(0, None)) - @jax.grad + @jax.value_and_grad @jax.named_scope('foo') def f(x, y): @jax.named_scope('true') diff --git a/tests/nn_test.py b/tests/nn_test.py index 1a89389a9f58..5f8f499b7471 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -63,6 +63,25 @@ def testSoftplusGradNan(self): def testSoftplusZero(self, dtype): self.assertEqual(jnp.log(dtype(2)), nn.softplus(dtype(0))) + def testSparseplusGradZero(self): + check_grads(nn.sparse_plus, (-2.,), order=1, + rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) + + def testSparseplusGrad(self): + check_grads(nn.sparse_plus, (0.,), order=1, + rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) + + def testSparseplusAndSparseSigmoid(self): + self.assertAllClose( + jax.grad(nn.sparse_plus)(0.), nn.sparse_sigmoid(0.), + check_dtypes=False) + self.assertAllClose( + jax.grad(nn.sparse_plus)(2.), nn.sparse_sigmoid(2.), + check_dtypes=False) + self.assertAllClose( + jax.grad(nn.sparse_plus)(-2.), nn.sparse_sigmoid(-2.), + check_dtypes=False) + def testSquareplusGrad(self): check_grads(nn.squareplus, (1e-8,), order=4, rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) @@ -83,6 +102,26 @@ def testSquareplusGradNan(self): def testSquareplusZero(self, dtype): self.assertEqual(dtype(1), nn.squareplus(dtype(0), dtype(4))) + def testMishGrad(self): + check_grads(nn.mish, (1e-8,), order=4, + rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) + + def testMishGradZero(self): + check_grads(nn.mish, (0.,), order=1, + rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) + + def testMishGradNegInf(self): + check_grads(nn.mish, (-float('inf'),), order=1, + rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) + + def testMishGradNan(self): + check_grads(nn.mish, (float('nan'),), order=1, + rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) + + @parameterized.parameters([float] + jtu.dtypes.floating) + def testMishZero(self, dtype): + self.assertEqual(dtype(0), nn.mish(dtype(0))) + def testReluGrad(self): rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None check_grads(nn.relu, (1.,), order=3, rtol=rtol) @@ -101,10 +140,23 @@ def testSoftplusValue(self): val = nn.softplus(89.) self.assertAllClose(val, 89., check_dtypes=False) + def testSparseplusValue(self): + val = nn.sparse_plus(89.) + self.assertAllClose(val, 89., check_dtypes=False) + + def testSparsesigmoidValue(self): + self.assertAllClose(nn.sparse_sigmoid(-2.), 0., check_dtypes=False) + self.assertAllClose(nn.sparse_sigmoid(2.), 1., check_dtypes=False) + self.assertAllClose(nn.sparse_sigmoid(0.), .5, check_dtypes=False) + def testSquareplusValue(self): val = nn.squareplus(1e3) self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3) + def testMishValue(self): + val = nn.mish(1e3) + self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3) + @jtu.skip_on_flag("jax_skip_slow_tests", True) def testEluGrad(self): check_grads(nn.elu, (1e4,), order=4, eps=1.) 
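(The new tests above exercise nn.sparse_plus, nn.sparse_sigmoid and nn.mish. As a reading aid only — these are NumPy reference sketches, not the jax.nn implementations — the definitions below are consistent with the values the tests check: sparse_plus(89.) == 89., sparse_sigmoid matching grad(sparse_plus), and mish(0) == 0 with mish(1e3) ~= 1e3.)

import numpy as np

def sparse_plus_ref(x):
  # 0 for x <= -1, identity for x >= 1, and a smooth quadratic bridge in between.
  x = np.asarray(x, dtype=float)
  return np.where(x <= -1, 0.0, np.where(x >= 1, x, (x + 1) ** 2 / 4))

def sparse_sigmoid_ref(x):
  # Piecewise-linear sigmoid that saturates exactly; the derivative of sparse_plus_ref.
  x = np.asarray(x, dtype=float)
  return np.clip((x + 1) / 2, 0.0, 1.0)

def mish_ref(x):
  # mish(x) = x * tanh(softplus(x)); tends to x for large inputs and is 0 at 0.
  x = np.asarray(x, dtype=float)
  return x * np.tanh(np.logaddexp(0.0, x))

assert np.isclose(sparse_plus_ref(89.0), 89.0)
assert np.allclose(sparse_sigmoid_ref([-2.0, 0.0, 2.0]), [0.0, 0.5, 1.0])
assert np.isclose(mish_ref(1e3), 1e3) and mish_ref(0.0) == 0.0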
@@ -137,7 +189,7 @@ def gelu_reference(x): (jnp.float32, jnp.bfloat16, jnp.float16), (partial(nn.gelu, approximate=False), partial(nn.gelu, approximate=True), - nn.relu, nn.softplus, nn.sigmoid, nn.squareplus))) + nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish))) def testDtypeMatchesInput(self, dtype, fn): x = jnp.zeros((), dtype=dtype) out = fn(x) @@ -153,12 +205,24 @@ def testHardTanhMemory(self): with jax.enable_checks(False): # With checks we materialize the array jax.make_jaxpr(lambda: nn.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom + @parameterized.parameters([nn.softmax, nn.log_softmax]) + def testSoftmaxEmptyArray(self, fn): + x = jnp.array([], dtype=float) + self.assertArraysEqual(fn(x), x) + + @parameterized.parameters([nn.softmax, nn.log_softmax]) + def testSoftmaxEmptyMask(self, fn): + x = jnp.array([5.5, 1.3, -4.2, 0.9]) + m = jnp.zeros_like(x, dtype=bool) + expected = jnp.full_like(x, 0.0 if fn is nn.softmax else -jnp.inf) + self.assertArraysEqual(fn(x, where=m), expected) + @parameterized.parameters([nn.softmax, nn.log_softmax]) def testSoftmaxWhereMask(self, fn): x = jnp.array([5.5, 1.3, -4.2, 0.9]) m = jnp.array([True, False, True, True]) - out = fn(x, where=m, initial=-jnp.inf) + out = fn(x, where=m) self.assertAllClose(out[m], fn(x[m])) probs = out if fn is nn.softmax else jnp.exp(out) @@ -178,7 +242,7 @@ def testSoftmaxWhereGrad(self, fn): x = jnp.array([36., 10000.]) mask = x < 1000 - f = lambda x, mask: fn(x, where=mask, initial=x.min())[0] + f = lambda x, mask: fn(x, where=mask)[0] self.assertAllClose(jax.grad(f)(x, mask), jnp.zeros_like(x)) diff --git a/tests/ode_test.py b/tests/ode_test.py index 2d2bcc971434..834745e1cf1c 100644 --- a/tests/ode_test.py +++ b/tests/ode_test.py @@ -24,8 +24,7 @@ import scipy.integrate as osp_integrate -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ODETest(jtu.JaxTestCase): diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py index 3fb3101c4142..b7710d9b94c2 100644 --- a/tests/optimizers_test.py +++ b/tests/optimizers_test.py @@ -26,8 +26,7 @@ from jax import lax from jax.example_libraries import optimizers -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class OptimizerTests(jtu.JaxTestCase): diff --git a/tests/package_structure_test.py b/tests/package_structure_test.py index 12424b475bac..448f8f1e7f93 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -32,6 +32,7 @@ class PackageStructureTest(jtu.JaxTestCase): # TODO(jakevdp): expand test to other public modules. 
_mod("jax.errors"), _mod("jax.nn.initializers"), + _mod("jax.tree_util", exclude=['PyTreeDef', 'default_registry']), ]) def test_exported_names_match_module(self, module_name, include, exclude): """Test that all public exports have __module__ set correctly.""" @@ -43,7 +44,8 @@ def test_exported_names_match_module(self, module_name, include, exclude): obj = getattr(module, name) if isinstance(obj, types.ModuleType): continue - self.assertEqual(obj.__module__, module_name) + self.assertEqual(obj.__module__, module_name, + f"{obj} has {obj.__module__=}, expected {module_name}") if __name__ == '__main__': diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 5dd5547c191f..30f897d0fe6e 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -33,32 +33,35 @@ jax_test( srcs = [ "pallas_test.py", ], - backend_tags = { - "gpu": ["noasan"], # https://github.com/openai/triton/issues/2918 - }, config_tags_overrides = { - "gpu_x32": { + "gpu_a100_x32": { "ondemand": False, # Include in presubmit. }, }, - disable_backends = [ - "cpu", - "tpu", - ], disable_configs = [ + "cpu", # The 64-bit variant "gpu", + "gpu_x32", "gpu_a100", "gpu_h100", + "gpu_p100", + "gpu_p100_x32", ], enable_configs = [ - "gpu_x32", "gpu_a100_x32", - "gpu_p100_x32", "gpu_h100_x32", ], - shard_count = 4, + shard_count = { + "cpu": 8, + "gpu": 4, + "tpu": 4, + }, deps = [ + "//jax:pallas", "//jax:pallas_gpu", + "//jax:pallas_gpu_ops", + "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", ] + py_deps("absl/testing") + py_deps("numpy"), ) @@ -67,11 +70,8 @@ jax_test( srcs = [ "gpu_attention_test.py", ], - backend_tags = { - "gpu": ["noasan"], # https://github.com/openai/triton/issues/2918 - }, config_tags_overrides = { - "gpu_x32": { + "gpu_a100_x32": { "ondemand": False, # Include in presubmit. }, }, @@ -81,55 +81,52 @@ jax_test( ], disable_configs = [ "gpu", + "gpu_x32", "gpu_p100", + "gpu_p100_x32", "gpu_a100", "gpu_h100", ], enable_configs = [ - "gpu_x32", "gpu_a100_x32", - "gpu_p100_x32", "gpu_h100_x32", ], shard_count = 1, deps = [ - "//jax:pallas_gpu", + "//jax:pallas", + "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_gpu_ops", ] + py_deps("absl/testing") + py_deps("numpy"), ) jax_test( - name = "pallas_via_xla_test", + name = "ops_test", srcs = [ - "pallas_test.py", + "ops_test.py", ], - backend_tags = { - "gpu": ["noasan"], # https://github.com/openai/triton/issues/2918 - }, config_tags_overrides = { - "gpu_x32": { + "gpu_a100_x32": { "ondemand": False, # Include in presubmit. }, }, disable_backends = [ "cpu", - "tpu", ], disable_configs = [ "gpu", - "gpu_p100", + "gpu_x32", "gpu_a100", + "gpu_p100", + "gpu_p100_x32", "gpu_h100", ], enable_configs = [ - "gpu_x32", "gpu_a100_x32", - "gpu_p100_x32", "gpu_h100_x32", ], - shard_count = 4, deps = [ - "//jax:pallas_gpu", + "//jax:pallas", + "//jax:pallas_gpu", # build_cleaner: keep ] + py_deps("absl/testing") + py_deps("numpy"), ) @@ -208,3 +205,192 @@ jax_test( "//jax:pallas_tpu_ops", ], ) + +jax_test( + name = "pallas_pipeline_tpu_test", + srcs = ["pallas_pipeline_tpu_test.py"], + disable_backends = [ + "gpu", + ], + main = "pallas_pipeline_tpu_test.py", + shard_count = 2, + tags = [ + "noasan", # Times out. + "nomsan", # Times out. + "notsan", # Times out. + ], + deps = [ + "//jax:extend", + "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", + ] + py_deps("hypothesis"), +) + +jax_test( + name = "paged_attention_kernel_test", + srcs = ["paged_attention_kernel_test.py"], + disable_backends = [ + "cpu", + "gpu", + ], + shard_count = 5, + tags = [ + "noasan", # Times out. 
+ "nomsan", # Times out. + "notsan", # Times out. + ], + deps = [ + "//jax:pallas_tpu_ops", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +jax_test( + name = "gmm_test", + srcs = [ + "gmm_test.py", + ], + disable_backends = [ + "cpu", + "gpu", + ], + shard_count = 50, + tags = [ + "noasan", # Times out. + "nomsan", # Times out. + "notsan", # Times out. + ], + deps = [ + "//jax:pallas_tpu_ops", + ] + py_deps([ + "absl/testing", + "absl/flags", + "numpy", + "hypothesis", + ]), +) + +jax_test( + name = "mosaic_gpu_test", + srcs = [ + "mosaic_gpu_test.py", + ], + config_tags_overrides = { + # TODO(slebedev): Switch to False once Mosaic GPU is unconditionally enabled. + "gpu_h100_x32": { + "ondemand": True, # Include in presubmit. + }, + }, + disable_backends = [ + "cpu", + "tpu", + ], + disable_configs = [ + "gpu", + "gpu_x32", + "gpu_a100", + "gpu_a100_x32", + "gpu_p100", + "gpu_p100_x32", + "gpu_h100", + ], + enable_configs = [ + "gpu_h100_x32", + ], + env = { + "JAX_PALLAS_USE_MOSAIC_GPU": "1", + }, + tags = ["notap"], + deps = [ + "//jax:pallas", + "//jax:pallas_gpu", # build_cleaner: keep + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +jax_test( + name = "export_back_compat_pallas_test", + srcs = ["export_back_compat_pallas_test.py"], + config_tags_overrides = { + "gpu_a100_x32": { + "ondemand": False, # Include in presubmit. + }, + }, + disable_backends = [ + "cpu", + "tpu", + ], + disable_configs = [ + "gpu", + "gpu_x32", + "gpu_a100", + "gpu_h100", + "gpu_p100", + "gpu_p100_x32", + "gpu_pjrt_c_api", + ], + enable_configs = [ + "gpu_a100_x32", + "gpu_h100_x32", + ], + tags = [], + deps = [ + "//jax:internal_export_back_compat_test_data", + "//jax:internal_export_back_compat_test_util", + "//jax:pallas", + "//jax:pallas_gpu", # build_cleaner: keep + ], +) + +jax_test( + name = "export_pallas_test", + srcs = ["export_pallas_test.py"], + config_tags_overrides = { + "gpu_a100_x32": { + "ondemand": False, # Include in presubmit. + }, + }, + disable_configs = [ + "gpu", + "gpu_x32", + "gpu_a100", + "gpu_h100", + "gpu_p100", + "gpu_p100_x32", + "gpu_pjrt_c_api", + ], + enable_configs = [ + "gpu_a100_x32", + ], + tags = [], + deps = [ + "//jax:pallas", + "//jax:pallas_gpu", # build_cleaner: keep + "//jax:pallas_tpu", # build_cleaner: keep + ], +) + +jax_test( + name = "pallas_shape_poly_test", + srcs = ["pallas_shape_poly_test.py"], + config_tags_overrides = { + "gpu_a100_x32": { + "ondemand": False, # Include in presubmit. 
+ }, + }, + disable_configs = [ + "gpu_x32", + "gpu_h100", + "gpu_p100", + "gpu_p100_x32", + "gpu_pjrt_c_api", + ], + enable_configs = [ + "gpu_a100_x32", + ], + tags = [], + deps = [ + "//jax:pallas", + "//jax:pallas_gpu", # build_cleaner: keep + "//jax:pallas_tpu", # build_cleaner: keep + "//jax/experimental/export", + ], +) diff --git a/tests/pallas/all_gather_test.py b/tests/pallas/all_gather_test.py index b151594d3e64..98b3e5b40135 100644 --- a/tests/pallas/all_gather_test.py +++ b/tests/pallas/all_gather_test.py @@ -84,13 +84,14 @@ def _array_dtypes(draw): class AllGatherTest(jtu.JaxTestCase): def setUp(self): - super().setUp() if not jtu.test_device_matches(["tpu"]): self.skipTest("Need TPU devices") if not jtu.is_device_tpu(version=5, variant="e"): # TODO(sharadmv,apaszke): expand support to more versions self.skipTest("Currently only supported on TPU v5e") + super().setUp() + @hp.given(hps.booleans(), _array_shapes(), _array_dtypes()) def test_all_gather_1d_mesh(self, is_vmem, shape, dtype): if jax.device_count() < 2: diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py new file mode 100644 index 000000000000..8cf3f9708e38 --- /dev/null +++ b/tests/pallas/export_back_compat_pallas_test.py @@ -0,0 +1,61 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for backwards compatibility of exporting code with Pallas custom calls. + +See the export_back_compat_test_util module docstring for how to setup and +update these tests. 
+""" + +from absl.testing import absltest + +import jax +import jax.numpy as jnp +from jax._src import config +from jax._src import test_util as jtu +from jax._src.internal_test_util import export_back_compat_test_util as bctu +from jax._src.internal_test_util.export_back_compat_test_data.pallas import cuda_add_one +from jax.experimental import pallas as pl + +config.parse_flags_with_absl() + + +@jtu.with_config(jax_include_full_tracebacks_in_locations=False) +class CompatTest(bctu.CompatTestBase): + + def setUp(self): + if jax.config.x64_enabled: + self.skipTest("Only works in 32-bit") + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Only works on GPU") + if (jtu.test_device_matches(["cuda"]) and + not jtu.is_cuda_compute_capability_at_least("8.0")): + self.skipTest("Only works on GPUs with capability >= sm80") + super().setUp() + + def test_cuda_add_one(self): + def func(x): + def add_one(x_ref, o_ref): + o_ref[0] = x_ref[0] + 1 + return pl.pallas_call(add_one, + out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), + in_specs=[pl.BlockSpec((1,), lambda i: i)], + out_specs=pl.BlockSpec((1,), lambda i: i), + grid=8)(x) + data = self.load_testdata(cuda_add_one.data_2024_05_02) + + self.run_one_test(func, data) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/export_pallas_test.py b/tests/pallas/export_pallas_test.py new file mode 100644 index 000000000000..3ea293f55c3d --- /dev/null +++ b/tests/pallas/export_pallas_test.py @@ -0,0 +1,64 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test exporting Pallas kernels.""" +import sys + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax import export +# Import mosaic for flag definitions +from jax.experimental import mosaic as _ # noqa: F401 +from jax.experimental import pallas as pl +import numpy as np + + +jax.config.parse_flags_with_absl() + + +class ExportTest(jtu.JaxTestCase): + + def setUp(self): + if sys.platform == "win32": + self.skipTest("Only works on non-Windows platforms") + + super().setUp() + + def test_cross_platform(self): + def add_vectors_kernel(x_ref, y_ref, o_ref): + x, y = x_ref[...], y_ref[...] + o_ref[...] = x + y + + @jax.jit + def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: + return pl.pallas_call(add_vectors_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) + )(x, y) + + a = np.arange(8) + exp = export.export( + add_vectors, + lowering_platforms=["tpu", "cuda"], + )(a, a) + + if (jtu.device_under_test() == "tpu" or + (jtu.device_under_test() == "gpu" and + jtu.is_cuda_compute_capability_at_least("8.0"))): + res = exp.call(a, a) + self.assertAllClose(res, a + a) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/gmm_test.py b/tests/pallas/gmm_test.py new file mode 100644 index 000000000000..be830a6a4473 --- /dev/null +++ b/tests/pallas/gmm_test.py @@ -0,0 +1,390 @@ +# Copyright 2024 The JAX Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import itertools +from typing import Any + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +import jax.experimental.pallas.ops.tpu.megablox as mblx +import jax.numpy as jnp +import numpy as np + +try: + import hypothesis as hp + import hypothesis.strategies as hps + CAN_USE_HYPOTHESIS = True +except (ModuleNotFoundError, ImportError): + CAN_USE_HYPOTHESIS = False + +jax.config.parse_flags_with_absl() + +P = jax.sharding.PartitionSpec + +partial = functools.partial + +if CAN_USE_HYPOTHESIS: + hp.settings.register_profile( + "deterministic", + database=None, + derandomize=True, + deadline=None, + max_examples=10, + print_blob=True, + ) + hp.settings.load_profile("deterministic") + + + def seed_strategy() -> hps.SearchStrategy[int]: + return hps.integers(min_value=0, max_value=4) + + + @hps.composite + def group_strategy( + draw: hps.DrawFn, + max_groups: int = 32, + max_stride: int = 32, + min_groups: int = 1, + ) -> tuple[int, int]: + assert max_stride <= max_groups + + # Sample the number of groups owned by each shard. + group_stride = draw(hps.integers(min_value=1, max_value=max_stride)) + + # Sample the number of groups as a multiple of the stride to ensure that we + # have an equal number of groups per shard. Round down s.t. num_groups <= + # max_groups. + num_groups = group_stride * draw( + hps.integers(min_value=min_groups, max_value=max_groups // group_stride) + ) + return num_groups, group_stride + + + @hps.composite + def group_sizes_strategy( + draw: hps.DrawFn, m: int, num_groups: int + ) -> jnp.ndarray: + # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer + # sample with replacement so that it's possible to get zero-sized groups. Get + # 'num_groups - 1' run ends. The final group will end at 'm'. + ends_no_final = np.sort( + np.array( + [ + draw(hps.integers(min_value=0, max_value=m)) + for _ in range(num_groups - 1) + ], + dtype=np.int32, + ), + ) + ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)]) + + # Calculate the run starts by shifting ends 1 to the right. The first run + # starts at zero. + starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final]) + return jnp.array(ends - starts, dtype=jnp.int32) + + + GROUPED_MATMUL_TESTS = ( + (128, 128, 128), + (256, 128, 128), + (128, 256, 128), + (128, 128, 256), + (256, 128, 512), + (512, 128, 128), + (512, 2048, 128), + (128, 8, 16), # Test partial tiles. 
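+      # Each base case above is an (m, k, n) problem size; with_dtype_arguments
+      # and with_transpose_argument below expand them with dtype and transpose
+      # combinations.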
+ ) + + + def random_dense( + shape: tuple[int, ...], + key: jax.Array, + dtype: jnp.dtype, + limit: int | None = None, + ) -> jnp.ndarray: + if limit is None: + limit = 1 / np.prod(shape) + x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit) # pylint: disable=invalid-unary-operand-type + return x.astype(jnp.bfloat16).astype(dtype) + + + def dot( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + transpose_lhs: bool = False, + transpose_rhs: bool = False, + preferred_element_type: jnp.dtype = jnp.float32, + ) -> jnp.ndarray: + lhs = jnp.transpose(lhs) if transpose_lhs else lhs + rhs = jnp.transpose(rhs) if transpose_rhs else rhs + return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type) + + + def reference_gmm( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + preferred_element_type: jnp.dtype = jnp.float32, + ) -> jnp.ndarray: + + start = 0 + out = [] + for i, size in enumerate(group_sizes): + result = dot( + lhs[start : start + size, :], + rhs[i, :, :], + preferred_element_type=preferred_element_type, + ) + + out.append(result) + start += group_sizes[i] + return jnp.concatenate(out, axis=0) + + + def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]: + dtypes = [jnp.float32, jnp.bfloat16] + + result = [] + for x in xs: + for dtypes_tuple in itertools.product(dtypes, dtypes, dtypes): + result.append(x + dtypes_tuple) + return tuple(result) + + + def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]: + flags = [False, True] + result = [] + for x in xs: + for flag in flags: + result.append(x + (flag,)) + return tuple(result) + + + def tolerances( + lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype + ) -> tuple[float, float]: + if ( + lhs_dtype == jnp.bfloat16 + or rhs_dtype == jnp.bfloat16 + or out_dtype == jnp.bfloat16 + ): + return 1e-3, 1e-2 # atol, rtol + return 1e-3, 1e-5 # atol, rtol + + + # TODO(tgale): Fix errors with strict dtype promotion. 
+ @jtu.with_config(jax_numpy_dtype_promotion="standard") + class GroupedMatmulTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Test requires TPU device.") + + super().setUp() + self.key = jax.random.PRNGKey(1234) + + def assert_allclose( + self, + out: jnp.ndarray, + expected_out: jnp.ndarray, + *, + atol: float = 1e-5, + rtol: float = 1e-5, + ): + self.assertEqual(out.dtype, expected_out.dtype) + np.testing.assert_allclose( + out.astype(jnp.float32), + expected_out.astype(jnp.float32), + atol=atol, + rtol=rtol, + ) + + def gmm_test( + self, + m: int, + k: int, + n: int, + lhs_dtype: jnp.dtype, + rhs_dtype: jnp.dtype, + out_dtype: jnp.dtype, + transpose_rhs: bool, + data: hps.SearchStrategy[hps.DataObject], + interpret: bool = False, + ): + seed = data.draw(seed_strategy()) + num_groups, _ = data.draw(group_strategy(max_stride=1)) + + key = jax.random.key(seed) + k1, k2 = jax.random.split(key, 2) + lhs = random_dense((m, k), k1, lhs_dtype, limit=1) + rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) + group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) + + out, vjpfun = jax.vjp( + partial( + mblx.gmm, + preferred_element_type=out_dtype, + transpose_rhs=transpose_rhs, + interpret=interpret, + ), + lhs, + rhs.swapaxes(1, 2) if transpose_rhs else rhs, + group_sizes, + ) + + def reference_fn(lhs, rhs, group_sizes, preferred_element_type): + rhs = rhs.swapaxes(1, 2) if transpose_rhs else rhs + return reference_gmm( + lhs, rhs, group_sizes, preferred_element_type=preferred_element_type + ) + + expected_out, reference_vjpfun = jax.vjp( + partial(reference_fn, preferred_element_type=out_dtype), + lhs, + rhs.swapaxes(1, 2) if transpose_rhs else rhs, + group_sizes, + ) + self.assertEqual(out.dtype, out_dtype) + self.assertEqual(expected_out.dtype, out_dtype) + + atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) + self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) + + cotangent = random_dense((m, n), k1, out_dtype, limit=1) + grad_lhs, grad_rhs, *_ = vjpfun(cotangent) + expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) + self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) + self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) + + @parameterized.parameters( + *with_transpose_argument(with_dtype_arguments(GROUPED_MATMUL_TESTS)) + ) + @hp.given(hps.data()) + def test_gmm( + self, + m: int, + k: int, + n: int, + lhs_dtype: jnp.dtype, + rhs_dtype: jnp.dtype, + out_dtype: jnp.dtype, + transpose_rhs: bool, + data: hps.SearchStrategy[hps.DataObject], + ): + self.gmm_test(m, k, n, lhs_dtype, rhs_dtype, out_dtype, transpose_rhs, data) + + # NOTE: Run fewer tests with interpret mode. We just want to sanity check that + # changes do not break running these kernels with interpret=True. 
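+  # Only the first (m, k, n) shape from GROUPED_MATMUL_TESTS is exercised here.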
+ @parameterized.parameters(*with_dtype_arguments(GROUPED_MATMUL_TESTS[0:1])) + @hp.given(hps.data()) + def test_gmm_interpret( + self, + m: int, + k: int, + n: int, + lhs_dtype: jnp.dtype, + rhs_dtype: jnp.dtype, + out_dtype: jnp.dtype, + data: hps.SearchStrategy[hps.DataObject], + ): + self.skipTest("interpret mode with dynamic grids is unsupported") + self.gmm_test( + m, + k, + n, + lhs_dtype, + rhs_dtype, + out_dtype, + transpose_rhs=False, + data=data, + interpret=True, + ) + + @parameterized.parameters(*with_dtype_arguments(GROUPED_MATMUL_TESTS)) + @hp.given(hps.data()) + def test_gmm_sharded_groups( + self, + m: int, + k: int, + n: int, + lhs_dtype: jnp.dtype, + rhs_dtype: jnp.dtype, + out_dtype: jnp.dtype, + data: hps.SearchStrategy[hps.DataObject], + ): + seed = data.draw(seed_strategy()) + num_groups, group_stride = data.draw(group_strategy()) + + key = jax.random.key(seed) + k1, k2 = jax.random.split(key, 2) + lhs = random_dense((m, k), k1, lhs_dtype, limit=1) + rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) + group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) + + out, shard_vjpfun = jax.vjp( + partial(mblx.gmm, preferred_element_type=out_dtype), + lhs, + rhs[0:group_stride], + group_sizes, + ) + vjpfuns = [shard_vjpfun] + for group_offset in range(group_stride, num_groups, group_stride): + out, shard_vjpfun = jax.vjp( + lambda lhs, rhs, group_sizes, out: mblx.gmm( + lhs, + rhs, + group_sizes, + out_dtype, + group_offset=jnp.array(group_offset, dtype=jnp.int32), # pylint: disable=cell-var-from-loop + existing_out=out, + ), + lhs, + rhs[group_offset : group_offset + group_stride], + group_sizes, + out, + ) + vjpfuns.append(shard_vjpfun) + + expected_out, reference_vjpfun = jax.vjp( + partial(reference_gmm, preferred_element_type=out_dtype), + lhs, + rhs, + group_sizes, + ) + self.assertEqual(out.dtype, out_dtype) + self.assertEqual(expected_out.dtype, out_dtype) + atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) + self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) + + cotangent = random_dense((m, n), k1, out_dtype, limit=1) + shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[0](cotangent) + grad_lhs = shard_grad_lhs + grad_rhs = [shard_grad_rhs] + for i, group_offset in enumerate( + range(group_stride, num_groups, group_stride) + ): + shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[i + 1](cotangent) + grad_lhs += shard_grad_lhs + grad_rhs.append(shard_grad_rhs) + grad_rhs = jnp.concatenate(grad_rhs, axis=0) + expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) + self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) + self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/gpu_attention_test.py b/tests/pallas/gpu_attention_test.py index 3ff740227a24..95688474a099 100644 --- a/tests/pallas/gpu_attention_test.py +++ b/tests/pallas/gpu_attention_test.py @@ -13,7 +13,6 @@ # limitations under the License. 
import os -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -28,32 +27,20 @@ os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" -try: - from jax.experimental.pallas import gpu as plgpu -except ImportError: - pass # pylint: disable=no-value-for-parameter -config.update("jax_traceback_filtering", "off") config.parse_flags_with_absl() -class PallasTest(jtu.JaxTestCase): - - def check_gpu_capability_at_least(self, capability, device: int = 0): - return plgpu.get_compute_capability(device) >= capability +@jtu.with_config(jax_traceback_filtering="off") +class DecodeAttentionTest(jtu.JaxTestCase): def setUp(self): - if not jtu.test_device_matches(["gpu"]): - self.skipTest("Only works on GPU") - try: - import triton # noqa: F401 - except ImportError: - self.skipTest("Triton is not installed. Skipping PallasTest.") - super().setUp() + if not jtu.is_cuda_compute_capability_at_least("8.0"): + self.skipTest("Fused attention only works on GPUs with capability >= sm80") -class DecodeAttentionTest(PallasTest): + super().setUp() @parameterized.named_parameters(*[ ( @@ -86,10 +73,6 @@ def test_mqa( kwargs, ): del kwargs - if not self.check_gpu_capability_at_least(80): - raise unittest.SkipTest( - "Fused attention only works on GPUs with capability >= sm80" - ) k1, k2, k3 = random.split(random.key(0), 3) q = random.normal(k1, (batch_size, num_heads, head_dim), dtype=jnp.float16) @@ -134,10 +117,6 @@ def test_gqa( kwargs, ): del kwargs - if not self.check_gpu_capability_at_least(80): - raise unittest.SkipTest( - "Fused attention only works on GPUs with capability >= sm80" - ) k1, k2, k3 = random.split(random.key(0), 3) q = random.normal( diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index d66991bd986a..da4da6c7254f 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -19,11 +19,14 @@ import unittest from absl.testing import absltest +from absl.testing import parameterized import jax from jax._src import test_util as jtu from jax._src import util from jax._src.state import indexing import numpy as np +import jax.numpy as jnp +from jax.experimental import pallas as pl try: import hypothesis as hp @@ -97,25 +100,18 @@ def test_simple_ndindexer(self): def test_invalid_ndindexer(self): indices = (0, 0, 0) shape = (5, 5) - with self.assertRaises(ValueError): + with self.assertRaisesRegex( + ValueError, "`indices` must not be longer than `shape`" + ): _ = NDIndexer.from_indices_shape(indices, shape) - def test_invalid_ndindexer_oob_int(self): - indices = (4, 0) - shape = (3, 5) - with self.assertRaises(ValueError): - _ = NDIndexer.from_indices_shape(indices, shape) - - def test_invalid_ndindexer_oob_slice_start(self): - indices = (slice(3, 2), 0) - shape = (3, 5) - with self.assertRaises(ValueError): - _ = NDIndexer.from_indices_shape(indices, shape) - - def test_invalid_ndindexer_oob_slice_end(self): - indices = (Slice(2, 2), 0) - shape = (3, 5) - with self.assertRaises(ValueError): + @parameterized.parameters( + ((4, 0), (3, 5)), + ((slice(3, 2), 0), (3, 5)), + ((Slice(2, 2), 0), (3, 5)), + ) + def test_invalid_ndindexer_oob(self, indices, shape): + with self.assertRaisesRegex(ValueError, "Out of bound"): _ = NDIndexer.from_indices_shape(indices, shape) def test_ndindexer_with_padding(self): @@ -124,6 +120,12 @@ def test_ndindexer_with_padding(self): indexer = NDIndexer.from_indices_shape(indices, shape) self.assertTupleEqual(indexer.get_indexer_shape(), shape) + def test_ndindexer_with_ellipsis(self): + indices 
= (..., 4) + shape = (5, 5) + indexer = NDIndexer.from_indices_shape(indices, shape) + self.assertTupleEqual(indexer.get_indexer_shape(), (5,)) + def test_ndindexer_with_slices(self): indices = (slice(2, 3), slice(4, 7)) shape = (5, 6) @@ -152,6 +154,14 @@ def test_ndindexer_with_arrays_and_broadcasting(self): indexer = NDIndexer.from_indices_shape(indices, shape) self.assertTupleEqual(indexer.get_indexer_shape(), (10, 20)) + def test_ndindexer_with_arrays_and_invalid_broadcasting(self): + indices = (np.arange(10)[None], np.arange(20)[None, :]) + shape = (5, 5) + with self.assertRaisesRegex( + ValueError, "Cannot broadcast shapes for indexing" + ): + indexer = NDIndexer.from_indices_shape(indices, shape) + def test_indexer_with_all_types(self): indices = (0, slice(10), np.arange(5)) shape = (2, 3, 4) @@ -197,5 +207,108 @@ def test_ndindexer(self, data): indexer.get_indexer_shape()) + def test_multi_indexing_interpreter_only(self): + # Interpreter only test! YMMV actually compiling this. + def permute(left, right, left_out_ref, right_out_ref): + left_out = jnp.zeros_like(left) + left_out = left_out.at[:, 0].set(left[:, 0]) + left_out = left_out.at[:, 1].set(right[:, 0]) + left_out = left_out.at[:, 2:].set(left[:, 1:-1]) + + right_out = jnp.zeros_like(right) + right_out = right_out.at[:, :-1].set(right[:, 1:]) + right_out = right_out.at[:, -1].set(left[:, -1]) + + left_out_ref[...] = left_out + right_out_ref[...] = right_out + + def invoke_permutes(x_ref, y_ref, x_out_ref, y_out_ref): + shape = x_ref.shape + _, n = shape[-2], shape[-1] + x_ref = x_ref.at[: n // 2, : n // 2] + y_ref = y_ref.at[: n // 2, : n // 2] + x_out_ref = x_out_ref.at[: n // 2, : n // 2] + y_out_ref = y_out_ref.at[: n // 2, : n // 2] + permute(x_ref, y_ref, x_out_ref, y_out_ref) + + n = 8 + x = jnp.ones([n, n]) + y = jnp.ones([n, n]) + jitted_permute = jax.jit(invoke_permutes) + grid = (1,) + pl.pallas_call( + jitted_permute, + grid=grid, + out_shape=[ + jax.ShapeDtypeStruct(x.shape, x.dtype), + jax.ShapeDtypeStruct(x.shape, y.dtype), + ], + in_specs=[ + pl.BlockSpec(x.shape, lambda i: (0, 0)), + pl.BlockSpec(y.shape, lambda i: (0, 0)), + ], + out_specs=[ + pl.BlockSpec(x.shape, lambda i: (0, 0)), + pl.BlockSpec(y.shape, lambda i: (0, 0)), + ], + interpret=True, + )(x, y) + + def test_ellipsis_indexing_iterpret_only(self): + # Interpreter only test! YMMV actually compiling this. + def permute_columns_in_row_kernel(left, right, new_left, new_right): + shape = left.shape + k = shape[-1] + ndim = len(shape) + left_slices = [ + left[..., :1], + right[..., :1], + left[..., 1:k-1] + ] + right_slices = [ + right[..., 1:k], + left[..., k-1:k] + ] + new_left[...] = np.concatenate(left_slices, axis=ndim - 1) + new_right[...] 
= np.concatenate(right_slices, axis=ndim - 1) + + left = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32) + right = jnp.array([[7, 8, 9], [10, 11, 12]], dtype=jnp.float32) + + output_shape = left.shape + + # hack to reuse the same fn for np cat + import jax.numpy as np # noqa: F811 + left_out, right_out = pl.pallas_call( + permute_columns_in_row_kernel, + grid=(1,), + out_shape=[ + jax.ShapeDtypeStruct(output_shape, jnp.float32), + jax.ShapeDtypeStruct(output_shape, jnp.float32) + ], + in_specs=[ + pl.BlockSpec(left.shape, lambda i: (0, 0)), + pl.BlockSpec(right.shape, lambda i: (0, 0)) + ], + out_specs=[ + pl.BlockSpec(output_shape, lambda i: (0, 0)), + pl.BlockSpec(output_shape, lambda i: (0, 0)) + ], + interpret=True, + )(left, right) + + + import numpy as np # noqa: F811 + left_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + right_np = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32) + left_out_np = left_np.copy() + right_out_np = right_np.copy() + + + permute_columns_in_row_kernel(left_np, right_np, left_out_np, right_out_np) + np.testing.assert_array_equal(left_out_np, left_out) + np.testing.assert_array_equal(right_out_np, right_out) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py new file mode 100644 index 000000000000..57b38e8dc305 --- /dev/null +++ b/tests/pallas/mosaic_gpu_test.py @@ -0,0 +1,118 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax.experimental import pallas as pl +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() + + +class PallasTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Only works on a GPU with capability >= sm90") + + super().setUp() + + +class PallasCallTest(PallasTest): + + def test_add_one(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def add_one(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1.0 + + x = jnp.arange(256).astype(jnp.float32) + np.testing.assert_array_equal(add_one(x), x + 1.0) + + @parameterized.product(input_factor=[0.001, 1, 10, 100, 100]) + def test_layer_norm(self, input_factor): + eps = 1e-5 + gamma = 1.0 + beta = 1.0 + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + compiler_params={"smem_scratch_bytes": 4 * 4}, + ) + def layer_norm(x_ref, o_ref): + x_mean = jnp.mean(x_ref[...]) + x_centered = x_ref[...] - x_mean + o_ref[...] 
= ( + x_centered * jax.lax.rsqrt(jnp.mean(x_centered**2) + eps) * gamma + + beta + ) + + def layer_norm_np(x): + x_mean = np.mean(x) + x_centered = x - x_mean + return (x_centered / np.sqrt(np.mean(x_centered**2) + eps) * gamma) + beta + + # Ones are always fully precise + x = jnp.ones((256,)).astype(jnp.float32) * input_factor + np.testing.assert_allclose(layer_norm(x), layer_norm_np(x)) + + # random (and anything else is not) + x = ( + jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32) + * input_factor + ) + # TODO(cperivol): find out why in this particular case we have a small-ish error. + rtol = 1e-07 if input_factor > 10 else 5e-5 + np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=rtol) + + def test_print(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def kernel(x_ref, o_ref): + del x_ref, o_ref + pl.debug_print("It works!") + + x = jnp.arange(256).astype(jnp.float32) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + + self.assertEqual(output(), "It works!\n") + + def test_print_with_values(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def kernel(x_ref, o_ref): + del o_ref + pl.debug_print("x[0] = {}", x_ref[0]) + + x = jnp.arange(256).astype(jnp.float32) + with self.assertRaises(Exception): + # TODO(slebedev): Remove assertRaises() once we support indexing. + kernel(x) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py new file mode 100644 index 000000000000..78679a92b0fe --- /dev/null +++ b/tests/pallas/ops_test.py @@ -0,0 +1,82 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for common JAX operations within pallas_call.""" + +import functools + +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +import jax +import jax.numpy as jnp +from jax import lax +from jax._src import test_util as jtu +from jax.experimental import pallas as pl + +# Import mosaic for flag definitions +from jax.experimental import mosaic as _ # noqa: F401 + + +jax.config.parse_flags_with_absl() + + +class OpsTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + if jax.config.x64_enabled: + self.skipTest("Only works in 32-bit") + if not self.INTERPRET: + if jtu.device_under_test() == "cpu": + self.skipTest("Only interpreter mode supported on CPU") + if (jtu.test_device_matches(["cuda"]) and + not jtu.is_cuda_compute_capability_at_least("8.0")): + self.skipTest("Only works on GPUs with capability >= sm80") + + super().setUp() + + @classmethod + def pallas_call(cls, *args, **kwargs): + return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) + + @parameterized.named_parameters( + (fn.__name__, fn, dtype) for fn, dtype in [ + (lax.pow, jnp.float32), + (lax.bitwise_and, jnp.int32), + (lax.bitwise_or, jnp.int32), + (lax.bitwise_xor, jnp.int32), + (lax.shift_left, jnp.int32), + (lax.shift_right_logical, jnp.int32), + ] + ) + def test_weak_dtype(self, fn, dtype): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct([1], dtype), + ) + def kernel(x_ref, o_ref): + o_ref[:] = fn(x_ref[:], y) + + x = jnp.array([4], dtype=dtype) + y = 2 if jnp.issubdtype(dtype, jnp.integer) else 2.0 + np.testing.assert_allclose(kernel(x), fn(x, y)) + + +class OpsInterpreterTest(OpsTest): + INTERPRET = True + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/pallas/paged_attention_kernel_test.py b/tests/pallas/paged_attention_kernel_test.py new file mode 100644 index 000000000000..b1dcb2fab5f8 --- /dev/null +++ b/tests/pallas/paged_attention_kernel_test.py @@ -0,0 +1,182 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax.experimental.pallas.ops.tpu import paged_attention +from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() + + +def _generate_qkv( + seq_lens, + page_size, + max_seq_len, + num_kv_heads, + num_heads, + head_dim, + prng_key, + dtype=jnp.float32, + are_kv_quantized=False, +): + assert max_seq_len % page_size == 0 + pages_per_sequence = max_seq_len // page_size + batch_size = len(seq_lens) + total_pages = batch_size * pages_per_sequence + k1, k2, k3, k4 = jax.random.split(prng_key, 4) + k_pages = jax.random.normal( + k1, (num_kv_heads, total_pages, page_size, head_dim), dtype=dtype + ) + v_pages = jax.random.normal( + k2, (num_kv_heads, total_pages, page_size, head_dim), dtype=dtype + ) + + if are_kv_quantized: + k_pages = quantization_utils.quantize_to_int8(k_pages) + v_pages = quantization_utils.quantize_to_int8(v_pages) + + page_indices = jnp.arange(batch_size * pages_per_sequence, dtype=jnp.int32) + page_indices = jax.random.permutation(k3, page_indices, independent=True) + page_indices = page_indices.reshape(batch_size, pages_per_sequence) + q = jax.random.normal(k4, (batch_size, num_heads, head_dim), dtype=dtype) + return q, k_pages, v_pages, page_indices + + +def _reconstruct_kv(page_indices, pages): + if isinstance(pages, quantization_utils.QuantizedTensor): + pages = quantization_utils.unquantize_from_int8(pages, dtype=jnp.float32) + + batch_size = page_indices.shape[0] + num_heads, _, _, head_dim = pages.shape + + def per_sequence_page_gather(pages, page_indices): + return jnp.take(pages, page_indices, 1) + + gathered = jax.vmap(per_sequence_page_gather, in_axes=(None, 0))( + pages, page_indices + ) + return gathered.reshape(batch_size, num_heads, -1, head_dim) + + +def _grouped_query_attention_reference(q, k, v, lengths): + batch_size, num_heads, head_dim = q.shape + _, num_kv_heads, max_seq_len, _ = k.shape + assert k.shape == v.shape + assert num_heads % num_kv_heads == 0 + q = q.reshape(batch_size, num_kv_heads, num_heads // num_kv_heads, head_dim) + + if isinstance(k, quantization_utils.QuantizedTensor): + k = quantization_utils.unquantize_from_int8(k, dtype=jnp.float32) + if isinstance(v, quantization_utils.QuantizedTensor): + v = quantization_utils.unquantize_from_int8(v, dtype=jnp.float32) + + logits = jnp.einsum( + "bhgd,bhtd->bhgt", q.astype(jnp.float32), k.astype(jnp.float32) + ) + mask = jnp.arange(max_seq_len)[None] < lengths[:, None] + mask_value = -0.7 * float(np.finfo(np.dtype("float32")).max) + logits = logits + jnp.where(mask, 0.0, mask_value)[:, None, None, :] + weights = jax.nn.softmax(logits, axis=-1) + o = jnp.einsum("bhgt,bhtd->bhgd", weights.astype(v.dtype), v) + return o.reshape(batch_size, num_heads, head_dim) + + +def _megacore_enabled(): + return jax.devices()[0].device_kind == "TPU v4" or jtu.is_device_tpu( + version=5, variant="p" + ) + + +@jtu.with_config(jax_numpy_dtype_promotion="standard") +class PagedAttentionKernelTest(jtu.JaxTestCase): + @parameterized.product( + dtype=(jnp.float32, jnp.bfloat16), + page_size=(16, 32, 64), + num_kv_heads=(1, 8), + q_kv_head_ratio=(1, 4, 8), + head_dim=(128, 256), + megacore_mode=("batch", "kv_head", None), + are_kv_quantized=( + False, + True, + ), + ) + def test_paged_attention( + self, + dtype, + page_size, + num_kv_heads, + q_kv_head_ratio, + head_dim, + 
megacore_mode, + are_kv_quantized, + ): + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Only supports TPU generation 4 or above") + if jtu.is_device_tpu(version=4) and are_kv_quantized: + # TPU v4 has only 16MiB of VMEM which is not sufficient to store both the + # weight and scale tensors for quantized tensors. When enabled on TPUv4, + # the tests sometimes failed with resource exhausted error. + self.skipTest("Quantization is not supported on TPU v4") + if megacore_mode and not _megacore_enabled(): + self.skipTest("Megacore is only available on TPU v4 or TPU v5p") + if num_kv_heads % 2 != 0 and megacore_mode == "kv_head": + self.skipTest("Skip kv_head megacore mode when num_kv_heads is odd") + max_kv_len = 2048 + block_size = 512 + seq_lens = np.asarray([0, 3, 256, 513, 1023, 2048]) + q, k_pages, v_pages, page_indices = _generate_qkv( + seq_lens, + page_size, + max_kv_len, + num_kv_heads, + num_kv_heads * q_kv_head_ratio, + head_dim, + jax.random.key(0), + dtype, + are_kv_quantized=are_kv_quantized, + ) + o = paged_attention.paged_attention( + q, + k_pages, + v_pages, + seq_lens, + page_indices, + pages_per_compute_block=block_size // page_size, + megacore_mode=megacore_mode, + ) + k = _reconstruct_kv(page_indices, k_pages) + v = _reconstruct_kv(page_indices, v_pages) + o_ref = _grouped_query_attention_reference(q, k, v, seq_lens) + + if q_kv_head_ratio > 1: + atol, rtol = 1e-2, 2e-2 + else: + atol, rtol = 2e-1, 1e-1 + np.testing.assert_allclose( + o[np.where(seq_lens > 0)].astype(jnp.float32), + o_ref[np.where(seq_lens > 0)].astype(jnp.float32), + atol=atol, + rtol=rtol, + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py index c27e6468bb4b..92e5182061b2 100644 --- a/tests/pallas/pallas_call_tpu_test.py +++ b/tests/pallas/pallas_call_tpu_test.py @@ -14,11 +14,16 @@ """Test TPU-specific extensions to pallas_call.""" +import contextlib import functools +import io +import re +import sys from absl.testing import absltest from absl.testing import parameterized import jax from jax import lax +from jax._src import checkify from jax._src import state from jax._src import test_util as jtu from jax._src.interpreters import partial_eval as pe @@ -40,15 +45,25 @@ partial = functools.partial +@contextlib.contextmanager +def string_stdout(): + """Redirects stdout to a string.""" + initial_stdout = sys.stdout + stringio = io.StringIO() + sys.stdout = stringio + yield stringio + sys.stdout = initial_stdout + class PallasTPUTest(jtu.JaxTestCase): interpret: bool = False def setUp(self): - super().setUp() if not self.interpret and jtu.device_under_test() != 'tpu': self.skipTest('Only interpret mode supported on non-TPU') + super().setUp() + def pallas_call(self, *args, **kwargs): return pl.pallas_call(*args, **kwargs, interpret=self.interpret) @@ -72,10 +87,11 @@ def _x_transform(i, s_ref): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, in_specs=[ - pl.BlockSpec(_x_transform, (x.shape[0] // 8, x.shape[1])), + pl.BlockSpec((x.shape[0] // 8, x.shape[1]), _x_transform), ], - out_specs=pl.BlockSpec(lambda i, _: (i, 0), - (x.shape[0] // 8, x.shape[1])), + out_specs=pl.BlockSpec( + (x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0) + ), grid=8, ), interpret=self.interpret, @@ -117,10 +133,11 @@ def f(x): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, in_specs=[ - pl.BlockSpec(_x_transform, (x.shape[0] // 8, x.shape[1])), + pl.BlockSpec((x.shape[0] // 
8, x.shape[1]), _x_transform), ], - out_specs=pl.BlockSpec(lambda i, _: (i, 0), - (x.shape[0] // 8, x.shape[1])), + out_specs=pl.BlockSpec( + (x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0) + ), grid=8, ), interpret=self.interpret, @@ -150,9 +167,9 @@ def _o_transform(i, _, s2_ref): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=2, in_specs=[ - pl.BlockSpec(_x_transform, (8, 128)), + pl.BlockSpec((8, 128), _x_transform), ], - out_specs=pl.BlockSpec(_o_transform, (8, 128)), + out_specs=pl.BlockSpec((8, 128), _o_transform), grid=8, ), interpret=self.interpret, @@ -200,10 +217,10 @@ def single_inst(i, _): ], grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, - in_specs=[pl.BlockSpec(lambda i, *_: (i, 0), (8, 128))], + in_specs=[pl.BlockSpec((8, 128), lambda i, *_: (i, 0))], out_specs=[ - pl.BlockSpec(lambda i, *_: (i, 0), (8, 128)), - pl.BlockSpec(lambda *_: (0, 0), (8, 128)), + pl.BlockSpec((8, 128), lambda i, *_: (i, 0)), + pl.BlockSpec((8, 128), lambda *_: (0, 0)), ], grid=8, ), @@ -237,7 +254,7 @@ def loop_body(i, carry): out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, - out_specs=pl.BlockSpec(lambda *_: (0, 0), (8, 128)), + out_specs=pl.BlockSpec((8, 128), lambda *_: (0, 0)), grid=1, ), interpret=self.interpret, @@ -261,20 +278,23 @@ def _x_transform(i, s_ref): s = s[None] x = x[None] - out = jax.vmap(pl.pallas_call( - body, - out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - in_specs=[ - pl.BlockSpec(_x_transform, (x.shape[1] // 8, x.shape[2])), - ], - out_specs=pl.BlockSpec(lambda i, _: (i, 0), - (x.shape[1] // 8, x.shape[2])), - grid=8, - ), - interpret=self.interpret, - ))(s, x) + out = jax.vmap( + pl.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + in_specs=[ + pl.BlockSpec((x.shape[1] // 8, x.shape[2]), _x_transform), + ], + out_specs=pl.BlockSpec( + (x.shape[1] // 8, x.shape[2]), lambda i, _: (i, 0) + ), + grid=8, + ), + interpret=self.interpret, + ) + )(s, x) np.testing.assert_allclose( out, x.reshape((1, 8, 8, -1))[:, s].reshape(x.shape) ) @@ -284,34 +304,40 @@ def body(_, x_ref, o_ref): o_ref[...] = x_ref[...] 
s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32) - x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) + x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128)) def _x_transform(i, s_ref): s = pl.load(s_ref, (i,)) return (s, 0) s = jnp.tile(s[None], [2, 1]) - x = jnp.tile(x[None], [2, 1, 1]) - - with self.assertRaises(NotImplementedError): - jax.vmap( - pl.pallas_call( - body, - out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - in_specs=[ - pl.BlockSpec(_x_transform, (x.shape[1] // 8, x.shape[2])), - ], - out_specs=pl.BlockSpec( - lambda i, _: (i, 0), (x.shape[1] // 8, x.shape[2]) - ), - grid=8, + + @jax.jit + @jax.vmap + def kernel(s, x): + return pl.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + in_specs=[ + pl.BlockSpec((x.shape[0] // 8, x.shape[1]), _x_transform), + ], + out_specs=pl.BlockSpec( + (x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0) ), - interpret=self.interpret, - ) + grid=8, + ), + interpret=self.interpret, + compiler_params=dict(mosaic=dict(allow_input_fusion=[False, True])), )(s, x) + first = x[0, ...].reshape((1, 8, 8, -1))[:, s[0, ...]].reshape(x.shape[1:]) + second = x[1, ...].reshape((1, 8, 8, -1))[:, s[1, ...]].reshape(x.shape[1:]) + + expected = jnp.stack([first, second]) + np.testing.assert_allclose(kernel(s, x), expected) + class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest): interpret: bool = True @@ -319,6 +345,46 @@ class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest): class PallasCallDynamicGridTest(PallasTPUTest): + def test_can_query_grid_statically_via_num_programs(self): + + def kernel(_): + num_programs = pl.num_programs(0) + self.assertIsInstance(num_programs, int) + self.assertEqual(num_programs, 2) + + pl.pallas_call(kernel, out_shape=None, grid=(2,))() + + def test_can_query_grid_statically_via_num_programs_in_block_spec(self): + + def kernel(*_): + pass + + def x_index_map(_): + num_programs = pl.num_programs(0) + self.assertIsInstance(num_programs, int) + self.assertEqual(num_programs, 2) + return 0 + pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec((8, 128), x_index_map)], + out_shape=None, + grid=(2,), + )(jnp.ones((8, 128))) + + def test_dynamic_grid_has_dynamic_size(self): + + def kernel(_): + num_programs = pl.num_programs(0) + self.assertIsInstance(num_programs, int, msg=type(num_programs)) + self.assertEqual(num_programs, 2) + num_programs = pl.num_programs(1) + self.assertIsInstance(num_programs, jax.Array) + + @jax.jit + def outer(x): + pl.pallas_call(kernel, out_shape=None, grid=(2, x))() + outer(2) + def test_dynamic_grid(self): shape = (8, 128) result_ty = jax.ShapeDtypeStruct(shape, jnp.float32) @@ -334,13 +400,37 @@ def dynamic_kernel(steps): return self.pallas_call( kernel, grid=(steps * 2,), - out_specs=pl.BlockSpec(lambda i: (0, 0), shape), + out_specs=pl.BlockSpec(shape, lambda i: (0, 0)), out_shape=result_ty, )() np.testing.assert_array_equal( dynamic_kernel(jnp.int32(4)), np.full(shape, 8.0, np.float32) ) + def test_dynamic_grid_overflow(self): + # If we pad statically the dynamic grid dims to max int32, then the product + # of this grid size will overflow int64 and can cause failing checks in XLA. + shape = (8, 128) + result_ty = jax.ShapeDtypeStruct(shape, jnp.float32) + + def kernel(y_ref): + @pl.when(sum(pl.program_id(i) for i in range(3)) == 0) + def _init(): + y_ref[...] 
= jnp.zeros_like(y_ref) + y_ref[...] += 1 + + @jax.jit + def dynamic_kernel(steps): + return self.pallas_call( + kernel, + grid=(steps * 2, steps + 1, 3), + out_specs=pl.BlockSpec(shape, lambda *_: (0, 0)), + out_shape=result_ty, + )() + np.testing.assert_array_equal( + dynamic_kernel(jnp.int32(4)), np.full(shape, 120.0, np.float32) + ) + # TODO(apaszke): Add tests for scalar_prefetch too def test_dynamic_grid_scalar_input(self): shape = (8, 128) @@ -355,7 +445,7 @@ def dynamic_kernel(steps): kernel, out_shape=result_ty, in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], - out_specs=pl.BlockSpec(lambda i: (0, 0), shape), + out_specs=pl.BlockSpec(shape, lambda i: (0, 0)), grid=(steps * 2,), )(jnp.array([[42]], dtype=jnp.int32)) @@ -379,7 +469,8 @@ def dynamic_kernel(steps, x): return self.pallas_call( kernel, grid=(steps * 2,), - out_specs=pl.BlockSpec(lambda i: (0, 0), shape), + in_specs=[pl.BlockSpec(shape, lambda i: (0, 0))], + out_specs=pl.BlockSpec(shape, lambda i: (0, 0)), out_shape=result_ty, )(x) x = jnp.arange(8 * 128., dtype=jnp.float32).reshape((1, *shape)) @@ -405,11 +496,14 @@ def dynamic_kernel(steps): return self.pallas_call( kernel, grid=(steps * 2,), - out_specs=pl.BlockSpec(lambda i: (0, 0), shape), + out_specs=pl.BlockSpec(shape, lambda i: (0, 0)), out_shape=result_ty, )() - with self.assertRaises(NotImplementedError): - dynamic_kernel(jnp.array([4, 8], jnp.int32)) + out = dynamic_kernel(jnp.array([4, 8], jnp.int32)) + first = jnp.full(shape, fill_value=8.0, dtype=jnp.float32) + second = jnp.full(shape, fill_value=16.0, dtype=jnp.float32) + expected_out = jnp.stack([first, second], axis=0) + np.testing.assert_array_equal(out, expected_out) def test_vmap_dynamic_grid(self): shape = (8, 128) @@ -426,7 +520,7 @@ def dynamic_kernel(x, steps): return self.pallas_call( kernel, grid=(steps * 2,), - out_specs=pl.BlockSpec(lambda i: (0, 0), shape), + out_specs=pl.BlockSpec(shape, lambda i: (0, 0)), out_shape=result_ty, )(x) x = jnp.arange(4 * 8 * 128., dtype=jnp.float32).reshape((4, *shape)) @@ -448,7 +542,28 @@ def dynamic_kernel(steps): out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32), )() - self.assertEqual(dynamic_kernel(4), 8) + self.assertEqual(dynamic_kernel(np.int32(4)), 8) + + @parameterized.parameters(range(1, 4)) + def test_vmap_num_programs(self, num_vmaps): + result_ty = jax.ShapeDtypeStruct((8, 128), jnp.int32) + + def kernel(y_ref): + y_ref[...] 
= jnp.full_like(y_ref, pl.num_programs(0)) + + kernel_call = self.pallas_call( + kernel, + grid=(8,), + out_specs=pl.BlockSpec(result_ty.shape, lambda i: (0, 0)), + out_shape=result_ty, + ) + + out_shape = (*(2 for _ in range(num_vmaps)), *result_ty.shape) + f = kernel_call + for _ in range(num_vmaps): + f = lambda impl=f: jax.vmap(impl, axis_size=2)() + out = jax.jit(f)() + np.testing.assert_array_equal(out, np.full(out_shape, 8.0)) def test_num_programs_block_spec(self): def kernel(x_ref, y_ref): @@ -461,40 +576,41 @@ def dynamic_kernel(steps, x): grid=(steps * 2,), in_specs=[ pl.BlockSpec( + (8, 128), # Should always evaluate to (1, 0) lambda i: (1 + 8 - pl.num_programs(0), 0), - (8, 128), ) ], - out_specs=pl.BlockSpec(lambda i: (0, 0), (8, 128)), + out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32), )(x) x = np.arange(4 * 8 * 128., dtype=np.int32).reshape((4 * 8, 128)) - np.testing.assert_array_equal(dynamic_kernel(4, x), x[8:16]) + np.testing.assert_array_equal(dynamic_kernel(np.int32(4), x), x[8:16]) class PallasCallInterpretDynamicGridTest(PallasCallDynamicGridTest): interpret: bool = True -class PallasCallDMATest(parameterized.TestCase): +class PallasCallDMATest(PallasTPUTest): def setUp(self): - super().setUp() if not jtu.is_device_tpu_at_least(4): self.skipTest('DMAs not supported on TPU generations <= 3') + super().setUp() + def test_can_have_unspecified_memory_spaces(self): def kernel(x_ref, y_ref): # Just test whether things compile del x_ref, y_ref x = jnp.ones((8, 128), dtype=jnp.float32) - y = pl.pallas_call( + y = self.pallas_call( kernel, - in_specs=[pl.BlockSpec(None, None, pltpu.TPUMemorySpace.ANY)], - out_specs=pl.BlockSpec(None, None, pltpu.TPUMemorySpace.ANY), + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) jax.block_until_ready(y) @@ -526,7 +642,7 @@ def body(x_ref): pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32)) - o = pl.pallas_call( + o = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )() @@ -543,7 +659,7 @@ def inner_body(z_ref): y_ref[...] = 4 * x_ref[...] pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32)) - o = pl.pallas_call( + o = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )() @@ -555,7 +671,7 @@ def body(sem1): pass pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) - jax.block_until_ready(pl.pallas_call( + jax.block_until_ready(self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )()) @@ -567,7 +683,7 @@ def body(sem1, sem2): pltpu.run_scoped(body, pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.REGULAR) - jax.block_until_ready(pl.pallas_call( + jax.block_until_ready(self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )()) @@ -582,7 +698,7 @@ def body(dma_sems, sems): pltpu.run_scoped(body, pltpu.SemaphoreType.DMA((4,)), pltpu.SemaphoreType.REGULAR((3,))) - jax.block_until_ready(pl.pallas_call( + jax.block_until_ready(self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )()) @@ -594,6 +710,7 @@ def kernel(y_ref, dma_sems, sems): self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore)) self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore)) + # TODO(b/345534352): Add interpret support for REGULAR semaphore. 
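+    # Left as pl.pallas_call (not self.pallas_call); see the TODO above.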
jax.block_until_ready( pl.pallas_call( kernel, @@ -628,6 +745,7 @@ def body3(sem): pltpu.semaphore_wait(sem) pltpu.run_scoped(body3, pltpu.SemaphoreType.REGULAR) + # TODO(b/345534352): Add interpret support for semaphore signal/wait. jax.block_until_ready(pl.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), @@ -651,6 +769,7 @@ def body(sems): pltpu.semaphore_wait(sems.at[2]) pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR((3,))) + # TODO(b/345534352): Add interpret support for semaphore signal/wait. jax.block_until_ready(pl.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), @@ -675,14 +794,43 @@ def body(sems): pltpu.semaphore_wait(sems.at[i, 2]) pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR((4, 3))) - jax.block_until_ready(pl.pallas_call( - kernel, - in_specs=[], - out_specs=pl.BlockSpec(lambda i: (0, 0), (8, 128)), - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - grid=4, - debug=True, - )()) + # TODO(b/345534352): Add interpret support for semaphore signal/wait. + jax.block_until_ready( + pl.pallas_call( + kernel, + in_specs=[], + out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + grid=4, + debug=True, + )() + ) + + def test_can_read_semaphore(self): + m, n = 2, 3 + + def kernel(y_ref): + def body(sems): + for r in range(m): + for c in range(n): + v = r * n + c + pltpu.semaphore_signal(sems.at[r, c],v) + y_ref[r, c] = pltpu.semaphore_read(sems.at[r, c]) + pltpu.semaphore_wait(sems.at[r, c], v) + + pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR((m, n))) + + # TODO(b/345534352): Add interpret support for semaphore signal/wait. + y = jax.block_until_ready( + pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32), + )() + ) + np.testing.assert_array_equal( + y, jnp.arange(m * n).astype(jnp.int32).reshape((m, n)) + ) def test_hbm_hbm_dma(self): def kernel(x_hbm_ref, y_hbm_ref): @@ -691,7 +839,7 @@ def body(sem): sem).wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -707,6 +855,8 @@ def body(sem): pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], sem).wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA((1,))) + + # TODO(b/345534352): Add interpret support for nonscalar semaphores. with self.assertRaisesRegex(ValueError, 'Cannot signal'): x = jnp.arange(8 * 128.).reshape((8, 128)) pl.pallas_call( @@ -725,6 +875,8 @@ def body(sem): sem.at[0]).wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA((1,))) x = jnp.arange(8 * 128.).reshape((8, 128)) + + # TODO(b/345534352): Add interpret support for nonscalar semaphores. 
y = pl.pallas_call( kernel, in_specs=[ @@ -746,7 +898,7 @@ def body(sem): ).wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -766,7 +918,7 @@ def body(x_ref, sem): pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -783,7 +935,7 @@ def body(y_ref, sem): pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), @@ -801,7 +953,7 @@ def body(x_ref, y_ref, sem): pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -817,7 +969,7 @@ def body(x_ref, sem): pltpu.run_scoped(body, pltpu.SMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA) x = 4 * jnp.ones((8, 128), jnp.float32) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -829,20 +981,21 @@ def body(x_ref, sem): def test_smem_hbm_dma(self): def kernel(x_ref, y_hbm_ref): def body(y_ref, sem): - y_ref[0, 0] = x_ref[4, 4] + y_ref[0, 0] = 0.0 + y_ref[0, 1] = x_ref[4, 4] pltpu.async_copy(y_ref, y_hbm_ref, sem).wait() - pltpu.run_scoped(body, pltpu.SMEM((8, 128), jnp.float32), + pltpu.run_scoped(body, pltpu.SMEM((1, 2), jnp.float32), pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), ], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + out_shape=jax.ShapeDtypeStruct((1, 2), jnp.float32), )(x) - expected = jnp.zeros_like(x).at[0, 0].set(x[4, 4]) + expected = jnp.zeros_like(x[0:1, 0:2]).at[0, 1].set(x[4, 4]) np.testing.assert_allclose(y, expected) def test_vmem_vmem_dma(self): @@ -851,7 +1004,7 @@ def body(sem): pltpu.async_copy(x_ref, y_ref, sem).wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), @@ -874,7 +1027,7 @@ def body(sem): dma2.wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(2 * 8 * 128.).reshape((16, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -897,7 +1050,7 @@ def body(sem): dma2.wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -923,7 +1076,7 @@ def body(sem): dma2.wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(3 * 2 * 8 * 128.).reshape((3, 2, 8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ 
-947,7 +1100,7 @@ def body(sem): pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128)) with self.assertRaises(Exception): - _ = pl.pallas_call( + _ = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -978,10 +1131,10 @@ def _(): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(lambda i: (0, 0), (8, 128)), + pl.BlockSpec((8, 128), lambda i: (0, 0)), ], scratch_shapes=[pltpu.VMEM((8, 128), jnp.float32)], - out_specs=pl.BlockSpec(lambda i: (0, 0), (8, 128)), + out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), grid=(3,), ), interpret=interpret, @@ -1003,7 +1156,7 @@ def kernel(y_ref, scratch_ref): num_scalar_prefetch=0, in_specs=[], scratch_shapes=[pltpu.SMEM((1, 1), jnp.int32)], - out_specs=pl.BlockSpec(lambda i: (i, 0, 0), (None, 8, 128)), + out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)), grid=(2,), ), debug=True, @@ -1019,6 +1172,7 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem): pltpu.semaphore_wait(sem) pltpu.async_copy(x_bbm_ref, y_ref, dma_sem).wait() + # TODO(b/345534352): Add interpret support for semaphore signal/wait. x = jnp.arange(8 * 128.).reshape((8, 128)) y = pl.pallas_call( kernel, @@ -1043,7 +1197,7 @@ def test_large_array_indexing(self): def kernel(index, x, y, sem): pltpu.async_copy(x.at[index[0]], y.at[:], sem).wait() - run = pl.pallas_call(kernel, + run = self.pallas_call(kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, in_specs=[ @@ -1061,16 +1215,16 @@ def kernel(index, x, y, sem): np.testing.assert_array_equal(y, i) del y - class PallasCallRemoteDMATest(parameterized.TestCase): def setUp(self): - super().setUp() if jax.device_count() < 2: self.skipTest('Only >=2 devices are supported.') if not jtu.is_device_tpu_at_least(5): self.skipTest('Only works with TPU v5') + super().setUp() + @parameterized.named_parameters( ('vmem', pltpu.TPUMemorySpace.VMEM), ('hbm', pltpu.TPUMemorySpace.ANY), @@ -1270,10 +1424,11 @@ def body(x): class PallasCallTest(PallasTPUTest): def setUp(self): - super().setUp() if jtu.device_under_test() != 'tpu': self.skipTest('Test only works on TPU') + super().setUp() + def test_cost_analysis(self): def kernel(x, y): y[:] = x[:] @@ -1313,14 +1468,62 @@ def kernel(x_ref, y_ref): compiler_params=dict(mosaic=dict(vmem_limit_bytes=int(2**18))), )(x) + def test_allow_input_fusion(self): + shape = (3, 128, 128) + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] + + def f(x, y): + z = jax.numpy.add(x, y) + return pl.pallas_call( + kernel, + grid=(3,), + in_specs=[pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0))], + out_specs=pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0)), + out_shape=x, + compiler_params=dict(mosaic=dict(allow_input_fusion=[True])), + )(z) + + x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape) + y = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + out = f(x, y) + expected = x + y + np.testing.assert_array_equal(out, expected) + compiled = jax.jit(f).lower(x, y).compile().as_text() + assert re.search(r'fusion.*kind=kCustom.*fused_computation', compiled) + + def test_set_internal_scratch_size(self): + shape = (128, 128) + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] 
+ + x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + requested_bytes = 128 * 4 + with self.assertRaisesRegex( + Exception, + f'Requested internal scratch size {requested_bytes} needs to be at' + ' least', + ): + pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + compiler_params=dict( + mosaic=dict(internal_scratch_in_bytes=requested_bytes) + ), + )(x) + class PallasCallUnblockedIndexingTest(PallasTPUTest): def setUp(self): - super().setUp() if not self.interpret and jtu.device_under_test() != 'tpu': self.skipTest('Only interpret mode supported on non-TPU') + super().setUp() + def test_unblocked_indexing(self): shape = (16 * 8, 128) result_ty = jax.ShapeDtypeStruct((15 * 8, 128), jnp.float32) @@ -1334,10 +1537,10 @@ def kernel(x_ref, y_ref): grid=(15,), in_specs=( pl.BlockSpec( - lambda i: (i * 8, 0), (2 * 8, 128), indexing_mode=pl.unblocked + (2 * 8, 128), lambda i: (i * 8, 0), indexing_mode=pl.unblocked ), ), - out_specs=pl.BlockSpec(lambda i: (i, 0), (8, 128)), + out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)), out_shape=result_ty, interpret=self.interpret, )(x) @@ -1360,12 +1563,12 @@ def kernel(x_ref, y_ref): grid=(1,), in_specs=( pl.BlockSpec( - lambda i: (0, 0), (2 * 8, 128), + lambda i: (0, 0), indexing_mode=pl.Unblocked(((0, 8), (0, 0))), ), ), - out_specs=pl.BlockSpec(lambda i: (0, 0), (8, 128)), + out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), out_shape=result_ty, interpret=self.interpret, )(x) @@ -1381,10 +1584,11 @@ class PallasCallInterpreterUnblockedIndexingTest( class PallasUXTest(PallasTPUTest): def setUp(self): - super().setUp() if jtu.device_under_test() != 'tpu': self.skipTest('Test only works on TPU') + super().setUp() + def test_mlir_location(self): # Make sure that MLIR locations are correctly propagated to primitives. 
args = (jax.ShapeDtypeStruct((8, 128), jnp.float32),) @@ -1404,10 +1608,11 @@ def capture_as_tpu_kernel(module, *args, **kwargs): class PallasCallInputOutputAliasingTest(PallasTPUTest): def setUp(self): - super().setUp() if not self.interpret and jtu.device_under_test() != 'tpu': self.skipTest('Only interpret mode supported on non-TPU') + super().setUp() + def test_basic_input_output_aliasing(self): # Input needs to be big so it doesn't fit in VMEM x = jnp.ones((32, 1024, 1024)) @@ -1420,15 +1625,15 @@ def f(x): return pl.pallas_call( kernel, out_shape=x, - in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, 1024, 1024))], - out_specs=pl.BlockSpec(lambda i: (i, 0, 0), (None, 1024, 1024)), + in_specs=[pl.BlockSpec((None, 1024, 1024), lambda i: (i, 0, 0))], + out_specs=pl.BlockSpec((None, 1024, 1024), lambda i: (i, 0, 0)), grid=(x.shape[0],), input_output_aliases={0: 0}, interpret=self.interpret, )(x) o = f(x) np.testing.assert_array_equal(o, expected) - compiled = f.lower(x).compile() + compiled = f.lower(jax.ShapeDtypeStruct(x.shape, x.dtype)).compile() mem_analysis = compiled.memory_analysis() expected_num_bytes = np.prod(x.shape) * x.dtype.itemsize self.assertEqual(mem_analysis.alias_size_in_bytes, expected_num_bytes) @@ -1447,16 +1652,20 @@ def f(x): out_shape=x, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, - in_specs=[pl.BlockSpec(lambda i, _: (i, 0, 0), (None, 1024, 1024))], - out_specs=pl.BlockSpec(lambda i, _: (i, 0, 0), (None, 1024, 1024)), + in_specs=[ + pl.BlockSpec((None, 1024, 1024), lambda i, _: (i, 0, 0)) + ], + out_specs=pl.BlockSpec( + (None, 1024, 1024), lambda i, _: (i, 0, 0) + ), grid=(x.shape[0],), ), input_output_aliases={1: 0}, interpret=self.interpret, - )(jnp.array([1,2,3]), x) + )(jnp.array([1, 2, 3]), x) o = f(x) np.testing.assert_array_equal(o, expected) - compiled = f.lower(x).compile() + compiled = f.lower(jax.ShapeDtypeStruct(x.shape, x.dtype)).compile() mem_analysis = compiled.memory_analysis() expected_num_bytes = np.prod(x.shape) * x.dtype.itemsize self.assertEqual(mem_analysis.alias_size_in_bytes, expected_num_bytes) @@ -1470,10 +1679,11 @@ class PallasCallInterpreterInputOutputAliasingTest(PallasTPUTest): class PallasMegacoreTest(PallasTPUTest): def setUp(self): - super().setUp() if jtu.device_under_test() != 'tpu': self.skipTest('Test only works on TPU') + super().setUp() + def test_megacore_splitting(self): # We want to make sure a 3-sized dimension is split across megacore # correctly, and if we combine the (3, 3) dimensions together it is still @@ -1489,29 +1699,32 @@ def _(): x = jax.random.uniform(k1, (3, 3, 512, 512)) y = jax.random.uniform(k2, (3, 3, 512, 512)) - z = jax.vmap(jax.vmap( - pl.pallas_call( - matmul_kernel, - out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), - grid=(4, 4, 4), - in_specs=[ - pl.BlockSpec(lambda i, j, k: (i, k), (128, 128)), - pl.BlockSpec(lambda i, j, k: (k, j), (128, 128)), - ], - out_specs=pl.BlockSpec(lambda i, j, k: (i, j), (128, 128)), - debug=True, + z = jax.vmap( + jax.vmap( + pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), + grid=(4, 4, 4), + in_specs=[ + pl.BlockSpec((128, 128), lambda i, j, k: (i, k)), + pl.BlockSpec((128, 128), lambda i, j, k: (k, j)), + ], + out_specs=pl.BlockSpec((128, 128), lambda i, j, k: (i, j)), + debug=True, + ) ) - ))(x, y) + )(x, y) np.testing.assert_allclose(z, jax.vmap(jax.vmap(jnp.dot))(x, y)) class PallasCallVmapTest(PallasTPUTest): def setUp(self): - super().setUp() if jtu.device_under_test() != 'tpu': 
self.skipTest('Test only works on TPU') + super().setUp() + def test_scratch_input_vmap(self): """Test that vmapp-ing a kernel with scratch inputs works correctly.""" @@ -1531,8 +1744,8 @@ def add_one_with_scratch(x_ref, o_ref, scratch_ref): out_shape=jax.ShapeDtypeStruct(array_shape, jnp.int32), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - in_specs=[pl.BlockSpec(lambda i, j: (i, j), tile_shape)], - out_specs=pl.BlockSpec(lambda i, j: (i, j), tile_shape), + in_specs=[pl.BlockSpec(tile_shape, lambda i, j: (i, j))], + out_specs=pl.BlockSpec(tile_shape, lambda i, j: (i, j)), scratch_shapes=[pltpu.VMEM(tile_shape, dtype=jnp.int32)], grid=(2, 2), ), @@ -1550,10 +1763,11 @@ def add_one_with_scratch(x_ref, o_ref, scratch_ref): class PallasCallControlFlowTest(PallasTPUTest): def setUp(self): - super().setUp() if jtu.device_under_test() != 'tpu': self.skipTest('Test only works on TPU') + super().setUp() + def test_nested_conds(self): def kernel(y_ref): def select(pred, x, y, nesting=0): @@ -1576,418 +1790,924 @@ def _false(): pl.pallas_call( kernel, grid=(1,), - out_specs=pl.BlockSpec(lambda i: (0, 0), (8, 128)), + out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32), )() return -class PallasCallPipelineTest(parameterized.TestCase): +class PallasCallWhileLoopTest(PallasTPUTest): def setUp(self): + if jtu.device_under_test() != 'tpu': + self.skipTest('Test only works on TPU') + super().setUp() - if jax.device_count() < 2: - self.skipTest('Only >=2 devices are supported.') - if not jtu.is_device_tpu_at_least(5): - self.skipTest('Only works with TPU v5') - @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM), - ('hbm', pltpu.TPUMemorySpace.ANY), - ) - def test_pipeline_matmul(self, memory_space): - k1, k2 = jax.random.split(jax.random.key(0)) - x = jax.random.uniform(k1, (512, 512)) - y = jax.random.uniform(k2, (512, 512)) + def test_range_while_loop(self): + """Tests lowering of a while_loop which can reduce to a fori_loop.""" - def matmul_pipeline(x_ref, y_ref, z_ref): - @pl.when(pl.program_id(2) == 0) + def kernel(x_ref, r_ref): + @pl.when(pl.program_id(0) == 0) def _(): - z_ref[...] = jnp.zeros(z_ref.shape, jnp.float32) + pl.store(r_ref, (0, 0), 0) + + def cond(carry): + i, j = carry + return i < j + + def body(carry): + io, j = carry + i = io - 128 + sl = jax.lax.div(i, 128) + l = jax.lax.rem(i, 128) + v = x_ref[0, sl, l] + s = pl.load(r_ref, (0, 0)) + pl.store(r_ref, (0, 0), s + v) + return io + 1, j + + i = 128 + j = 128 + 1024 + i, j = jax.lax.while_loop(cond, body, (i, j)) + + x = jnp.arange(4096) + x = jnp.reshape(x, [4, 8, 128]) + + r = pl.pallas_call( + kernel, + grid=(1,), + out_specs=pl.BlockSpec((1, 1), memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32), + in_specs=[ + pl.BlockSpec( + (1, 8, 128), + lambda i: (i, 0, 0), + memory_space=pltpu.SMEM, + ) + ], + )(x) + expected = jnp.sum(jnp.arange(1024)) + np.testing.assert_array_equal(r, expected) - z_ref[...] += x_ref[...] @ y_ref[...] 
+ def test_fori(self): + """Tests lowering of a while_loop which can reduce to a fori_loop.""" - def matmul_kernel(x_ref, y_ref, z_ref): - pltpu.emit_pipeline( - matmul_pipeline, - grid=(4, 4, 4), - in_specs=[ - pl.BlockSpec(lambda i, j, k: (i, k), (128, 128)), - pl.BlockSpec(lambda i, j, k: (k, j), (128, 128)), - ], - out_specs=pl.BlockSpec(lambda i, j, k: (i, j), (128, 128)), - )(x_ref, y_ref, z_ref) + def kernel(lb_ref, ub_ref, o_ref): + o_ref[0, 0] = 0 + + def body(i, _): + o_ref[0, 0] += 1 + + jax.lax.fori_loop(lb_ref[0, 0], ub_ref[0, 0], body, None) + + smem = pl.BlockSpec(memory_space=pltpu.SMEM) + r = pl.pallas_call( + kernel, + in_specs=(smem, smem), + out_specs=smem, + out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32), + )(*(jnp.array([[x]]) for x in (2, 6))) + np.testing.assert_array_equal(r, 4) - z = pl.pallas_call( - matmul_kernel, - out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), + def test_non_range_while_loop(self): + """Tests lowering of a while_loop which cannot reduce to a fori_loop.""" + + def kernel(x_ref, r_ref): + @pl.when(pl.program_id(0) == 0) + def _(): + pl.store(r_ref, (0, 0), 0) + + def cond(state): + i, s = state + return jnp.logical_and(i < 1024, s < 1024) + + def body(state): + i, s = state + sl = jax.lax.div(i, 128) + l = jax.lax.rem(i, 128) + v = pl.load(x_ref, (0, sl, l)) + return i + 1, s + v + + i = jnp.int32(0) + s = pl.load(r_ref, (0, 0)) + + i, s = jax.lax.while_loop(cond, body, (i, s)) + pl.store(r_ref, (0, 0), s) + + x = jnp.arange(4096) + x = jnp.reshape(x, [4, 8, 128]) + + r = pl.pallas_call( + kernel, + grid=(4,), + out_specs=pl.BlockSpec((1, 1), memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32), in_specs=[ - pl.BlockSpec(memory_space=memory_space), - pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec( + (1, 8, 128), + lambda i: (i, 0, 0), + memory_space=pltpu.SMEM, + ) ], - out_specs=pl.BlockSpec(memory_space=memory_space), - ) + )(x) + np.testing.assert_array_equal(r, [[1035]]) - jax.block_until_ready(z(x, y)) - jax.block_until_ready(jnp.dot(x, y)) + def test_vector_carry_while_loop(self): + """Tests lowering of a while_loop which carries a vector quantity.""" - out = jax.block_until_ready(z(x, y)) - expected_out = jax.block_until_ready(jnp.dot(x, y)) + def kernel(x_ref, r_ref): - np.testing.assert_allclose(out, expected_out) + def cond(v): + return v[0, 0] < 16 + + def body(v): + return v * 2 + + r_ref[:] = jax.lax.while_loop(cond, body, x_ref[:]) + + x = jnp.full((8, 128), 3, dtype=jnp.int32) + fn = pl.pallas_call( + kernel, + grid=(1,), + in_specs=[pl.BlockSpec((8, 128), lambda i: (0, 0))], + out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32), + ) + r = fn(x) + reduced = jnp.sum(r) + # 3 -> 6 -> 12 -> 24 + np.testing.assert_array_equal(reduced, 1024 * 24) @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM), - ('hbm', pltpu.TPUMemorySpace.ANY), + ('1x128', (1, 128)), + ('2x128', (2, 128)), + ('4x128', (4, 128)), + ('8x128', (8, 128)), + ('8x256', (8, 256)), ) - def test_double_pipeline_matmul(self, memory_space): - k1, k2 = jax.random.split(jax.random.key(0)) - x = jax.random.uniform(k1, (512, 512)) - y = jax.random.uniform(k2, (512, 512)) + def test_while_loop_carry_memref(self, shape): + """Tests a while loop carrying a memref.""" - def matmul_pipeline(x_ref, y_ref, z_ref): - @pl.when(pl.program_id(2) == 0) - def _(): - z_ref[...] = jnp.zeros(z_ref.shape, jnp.float32) + # TODO(hmckenzie): Investigate further why this occurs. 
+ if shape == (1, 128): + self.skipTest('memref<1x128> inexplicably doubles to 2x128.') - z_ref[...] += x_ref[...] @ y_ref[...] + def kernel(out_ref, bound): + def cond(i): + return i < bound - def matmul_kernel(x_ref, y_ref, z_ref): + def body(i): + out_ref[0, i] = 2 + return i + 1 - def emit_pipeline(should_accumulate_out): - pltpu.emit_pipeline( - matmul_pipeline, - grid=(4, 4, 4), - in_specs=[ - pl.BlockSpec(lambda i, j, k: (i, k), (128, 128)), - pl.BlockSpec(lambda i, j, k: (k, j), (128, 128)), - ], - out_specs=pl.BlockSpec(lambda i, j, k: (i, j), (128, 128)), - should_accumulate_out=should_accumulate_out, - )(x_ref, y_ref, z_ref) + jax.lax.while_loop(cond, body, 0) - emit_pipeline(False) - emit_pipeline(True) + x = jnp.asarray([1, 1, 1, 1]) + x = jnp.asarray(x) + x = jnp.pad(x, (0, np.prod(shape) - 4), constant_values=0) + x = jnp.reshape(x, shape) + kernel = partial(kernel, bound=x.shape[1]) - z = pl.pallas_call( - matmul_kernel, - out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), + fn = pl.pallas_call( + kernel, + grid=(1,), + out_specs=[ + pl.BlockSpec(shape, lambda i: (0, 0), memory_space=pltpu.SMEM), + ], + out_shape=[ + jax.ShapeDtypeStruct(shape, jnp.int32), + ], + ) + y = fn()[0] + np.testing.assert_array_equal(y[0, 0], 2) + np.testing.assert_array_equal(y[0, 1], 2) + np.testing.assert_array_equal(y[0, 2], 2) + np.testing.assert_array_equal(y[0, 3], 2) + + def test_nested_while_loop(self): + """Tests lowering a nested while_loop.""" + + def kernel(in_key_ref, out_segment_count, out_size_ref, key_count): + # Compute the length of contiguous segments of keys. + + def inner_cond(carry): + i, prev_key = carry + sl = jax.lax.div(i, 128) + l = jax.lax.rem(i, 128) + key = jax.lax.cond( + i < key_count, lambda i: in_key_ref[sl, l], lambda i: -1, i + ) + return jnp.logical_and(i < key_count, key == prev_key) + + def inner_body(carry): + i, key = carry + return i + 1, key + + def outer_cond(carry): + i, _ = carry + return i < key_count + + def outer_body(carry): + i, next_out_idx = carry + sl = jax.lax.div(i, 128) + l = jax.lax.rem(i, 128) + key = in_key_ref[sl, l] + end, _ = jax.lax.while_loop(inner_cond, inner_body, (i + 1, key)) + + sl = jax.lax.div(next_out_idx, 128) + l = jax.lax.rem(next_out_idx, 128) + out_size_ref[sl, l] = end - i + return end, next_out_idx + 1 + + _, count = jax.lax.while_loop(outer_cond, outer_body, (0, 0)) + out_segment_count[0, 0] = count + + keys = [4, 4, 4, 3, 2, 2, 7, 7, 7, 7] + keys = jnp.asarray(keys) + real_keys = keys.shape[0] + key_count = 1024 + keys = jnp.pad(keys, (0, key_count - real_keys), constant_values=32768) + keys = jnp.reshape(keys, (8, 128)) + kernel_fn = partial(kernel, key_count=key_count) + + fn = pl.pallas_call( + kernel_fn, + grid=(1,), in_specs=[ - pl.BlockSpec(memory_space=memory_space), - pl.BlockSpec(memory_space=memory_space), + # keys. + pl.BlockSpec((8, 128), lambda i: (0, 0), memory_space=pltpu.SMEM), ], - out_specs=pl.BlockSpec(memory_space=memory_space), - )(x, y) + out_specs=[ + # Segments found. + pl.BlockSpec((1, 1), memory_space=pltpu.SMEM), + # Segment sizes. 
+ pl.BlockSpec((8, 128), memory_space=pltpu.SMEM), + ], + out_shape=[ + jax.ShapeDtypeStruct((1, 1), jnp.int32), + jax.ShapeDtypeStruct((8, 128), jnp.int32), + ], + ) + count, sizes = fn(keys) + np.testing.assert_equal(count[0, 0], jnp.asarray(5)) + np.testing.assert_equal(sizes[0, 0], jnp.asarray(3)) + np.testing.assert_equal(sizes[0, 1], jnp.asarray(1)) + np.testing.assert_equal(sizes[0, 2], jnp.asarray(2)) + np.testing.assert_equal(sizes[0, 3], jnp.asarray(4)) + np.testing.assert_equal(sizes[0, 4], jnp.asarray(key_count - real_keys)) - np.testing.assert_allclose(z, jnp.dot(x, y) + jnp.dot(x, y)) - @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16), - ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16), - ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32), - ) - def test_pipeline_all_gather_matmul(self, memory_space, out_dtype): - num_devices = jax.device_count() - if num_devices < 2: - self.skipTest('Only >=2 devices are supported.') - steps = num_devices // 2 +class PallasCallReductionTest(PallasTPUTest): - tm = 1024 - tk = 768 - tn = 2048 + def setUp(self): + if jtu.device_under_test() != 'tpu': + self.skipTest('Test only works on TPU') - m = 1024 - k = 6144 - n = 6144 * 8 + super().setUp() - sharded_k = k // num_devices - sharded_n = n // num_devices + def test_integer_sum(self): + def kernel(x_ref, o_ref): + x = x_ref[:] + # We'd prefer to say: + # o_ref[0, 0] = jnp.sum(x) + # But this currently hits issues in both Pallas and Mosaic lowering. + r = jnp.sum(x, keepdims=True, axis=1) + r = jnp.sum(r, keepdims=True, axis=0) + o_ref[0, 0] = r[0, 0] + + x = jnp.full([8, 128], 2.0) + result = pl.pallas_call( + kernel, + in_specs=[ + pl.BlockSpec((8, 128), lambda *_: (0, 0)), + ], + out_specs=pl.BlockSpec((1, 1), memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct([1, 1], jnp.float32), + grid=(1,), + )(x) - k1, k2 = jax.random.split(jax.random.key(0)) - x = jax.random.uniform(k1, (m, k), dtype=jnp.bfloat16, minval=-1, maxval=1) - y = jax.random.uniform( - k2, (k, sharded_n), dtype=jnp.bfloat16, minval=-1, maxval=1 - ) + np.testing.assert_array_equal(result[0, 0], 2048.0) + + def test_integer_max(self): + def kernel(x_ref, o_ref): + x = x_ref[:] + # We'd prefer to say: + # o_ref[0, 0] = jnp.max(x) + # But this currently hits issues in both Pallas and Mosaic lowering. + x = jnp.max(x, keepdims=True, axis=1) + x = jnp.max(x, keepdims=True, axis=0) + o_ref[0, 0] = x[0, 0] + + x = jnp.arange(1024.0) + x = jnp.reshape(x, [8, 128]) + result = pl.pallas_call( + kernel, + in_specs=[ + pl.BlockSpec((8, 128), lambda *_: (0, 0)), + ], + out_specs=pl.BlockSpec((1, 1), memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct([1, 1], jnp.float32), + grid=(1,), + )(x) - def existing_matmul_kernel( - lhs_ref, rhs_ref, out_ref, acc_scratch_ref, *, acc_steps - ): - @pl.when(pl.program_id(2) == 0) - def _zero_acc(): - acc_scratch_ref[...] = jnp.zeros( - acc_scratch_ref.shape, acc_scratch_ref.dtype - ) + np.testing.assert_array_equal(result[0, 0], 1023.0) - acc_scratch_ref[...] += jnp.dot( - lhs_ref[...], - rhs_ref[...], - preferred_element_type=acc_scratch_ref.dtype, - ) - @pl.when(pl.program_id(2) == acc_steps - 1) - def _store_acc(): - out_ref[...] 
= acc_scratch_ref[...].astype(out_ref.dtype) +class PallasCallDynamicDMATest(PallasTPUTest): + + def setUp(self): + if not jtu.is_device_tpu_at_least(4): + self.skipTest('DMAs not supported on TPU generations <= 3') + + super().setUp() + + def test_simple_tile_aligned_dynamic_size_dma(self): + + def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): + size = size_smem_ref[0] + pltpu.async_copy( + x_hbm_ref.at[pl.ds(0, size)], + o_hbm_ref.at[pl.ds(0, size)], sem).wait() + + x = jnp.tile(jnp.arange(8, dtype=jnp.int32)[:, None, None], [1, 8, 128]) + o = jnp.zeros((8, 8, 128), dtype=jnp.int32) + size = jnp.array([4], dtype=jnp.int32) + + out = pl.pallas_call( + kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY)], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + scratch_shapes=[pltpu.SemaphoreType.DMA] + ), + out_shape=o, + input_output_aliases={2: 0}, + )(size, x, o) + expected = o.at[:4].set(x.at[:4].get()) + np.testing.assert_array_equal(out, expected) + + def test_simple_dynamic_size_dma(self): + self.skipTest("doesn't work yet.") + def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): + size = size_smem_ref[0] + pltpu.async_copy( + x_hbm_ref.at[pl.ds(0, size)], + o_hbm_ref.at[pl.ds(0, size)], sem).wait() + + x = jnp.arange(8, dtype=jnp.int32) + o = jnp.zeros(8, dtype=jnp.int32) + size = jnp.array([4], dtype=jnp.int32) + + out = pl.pallas_call( + kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY)], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + scratch_shapes=[pltpu.SemaphoreType.DMA] + ), + out_shape=o, + input_output_aliases={2: 0}, + )(size, x, o) + expected = o.at[:4].set(x.at[:4].get()) + np.testing.assert_array_equal(out, expected) + + +class PallasCallComparisonTest(PallasTPUTest): + + def setUp(self): + if jtu.device_under_test() != 'tpu': + self.skipTest('Test only works on TPU') + + super().setUp() + + @parameterized.named_parameters( + ('integer_1_1', (1, 1)), + ('integer_1_16', (1, 16)), + ('integer_16_1', (16, 1)), + ('integer_-1_1', (-1, 1)), + ('integer_1_-1', (1, -1)), + ('float_1_1', (1.0, 1.0)), + ('float_1_16', (1.0, 16.0)), + ('float_16_1', (16.0, 1.0)), + ('float_-1_1', (-1.0, 1.0)), + ('float_1_-1', (1.0, -1.0)), + ('float_1_inf', (1.0, float('inf'))), + ('float_inf_1', (float('inf'), 1.0)), + ('float_inf_inf', (float('inf'), float('inf'))), + ('float_1_nan', (1.0, float('nan'))), + ('float_nan_1', (float('nan'), 1.0)), + ('float_nan_nan', (float('nan'), float('nan'))), + ('float_inf_nan', (float('inf'), float('nan'))), + ('float_nan_inf', (float('inf'), float('inf'))), + ) + def test_scalar_compare(self, params): + """Test some scalar compares. + + We don't really expect that the results would be wrong, but rather we want + to exercise the lowering rules. 
+ """ + + def kernel(x_ref, y_ref, o_ref): + x = x_ref[0, 0] + y = y_ref[0, 0] + o_ref[0, 0] = jax.lax.select(x == y, 1, 0) + o_ref[0, 1] = jax.lax.select(x != y, 1, 0) + o_ref[0, 2] = jax.lax.select(x < y, 1, 0) + o_ref[0, 3] = jax.lax.select(x <= y, 1, 0) + o_ref[0, 4] = jax.lax.select(x > y, 1, 0) + o_ref[0, 5] = jax.lax.select(x >= y, 1, 0) + + x, y = params + r = jnp.array( + [ + [x == y, x != y, x < y, x <= y, x > y, x >= y], + ], + jnp.int32, + ) + x = jnp.array([[x]]) + y = jnp.array([[y]]) + + result = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([1, 128], jnp.int32), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + out_specs=pl.BlockSpec( + (1, 128), lambda i: (0, 0), memory_space=pltpu.SMEM + ), + grid=(1,), + )(x, y) + np.testing.assert_array_equal(r, result[..., 0:6]) - grid_k = sharded_k // tk - pipeline, make_pipeline_allocations = pltpu.emit_pipeline_with_allocations( - partial(existing_matmul_kernel, acc_steps=grid_k), - grid=(sharded_n // tn, m // tm, grid_k), + @parameterized.named_parameters( + ('integer_1_1', (1, 1)), + ('integer_1_16', (1, 16)), + ('integer_16_1', (16, 1)), + ('integer_-1_1', (-1, 1)), + ('integer_1_-1', (1, -1)), + ('float_1_1', (1.0, 1.0)), + ('float_1_16', (1.0, 16.0)), + ('float_16_1', (16.0, 1.0)), + ('float_-1_1', (-1.0, 1.0)), + ('float_1_-1', (1.0, -1.0)), + ('float_1_inf', (1.0, float('inf'))), + ('float_inf_1', (float('inf'), 1.0)), + ('float_inf_inf', (float('inf'), float('inf'))), + ('float_1_nan', (1.0, float('nan'))), + ('float_nan_1', (float('nan'), 1.0)), + ('float_nan_nan', (float('nan'), float('nan'))), + ('float_inf_nan', (float('inf'), float('nan'))), + ('float_nan_inf', (float('inf'), float('inf'))), + ) + def test_vector_compare(self, params): + """Test some vector compares. + + We don't really expect that the results would be wrong, but rather we want + to exercise the lowering rules. + """ + + def kernel(x_ref, y_ref, o_ref): + x = x_ref[:] + y = y_ref[:] + one = jnp.ones([8, 128], dtype=jnp.int32) + zero = jnp.zeros([8, 128], dtype=jnp.int32) + o_ref[0] = jax.lax.select(x == y, one, zero) + o_ref[1] = jax.lax.select(x != y, one, zero) + o_ref[2] = jax.lax.select(x < y, one, zero) + o_ref[3] = jax.lax.select(x <= y, one, zero) + o_ref[4] = jax.lax.select(x > y, one, zero) + o_ref[5] = jax.lax.select(x >= y, one, zero) + + # Widen out our params to (8, 128) vectors. 
+ x, y = params + x = jnp.full([8, 128], x) + y = jnp.full([8, 128], y) + + r = [x == y, x != y, x < y, x <= y, x > y, x >= y] + + result = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([6, 8, 128], jnp.int32), in_specs=[ - pl.BlockSpec(lambda n, m, k: (m, k), (tm, tk)), - pl.BlockSpec(lambda n, m, k: (k, n), (tk, tn)), + pl.BlockSpec((8, 128), lambda *_: (0, 0)), + pl.BlockSpec((8, 128), lambda *_: (0, 0)), ], - out_specs=pl.BlockSpec(lambda n, m, k: (m, n), (tm, tn)), - should_accumulate_out=True, + out_specs=pl.BlockSpec((6, 8, 128), lambda *_: (0, 0, 0)), + grid=(1,), + )(x, y) + np.testing.assert_array_equal(r[0], result[0]) + np.testing.assert_array_equal(r[1], result[1]) + np.testing.assert_array_equal(r[2], result[2]) + np.testing.assert_array_equal(r[3], result[3]) + np.testing.assert_array_equal(r[4], result[4]) + np.testing.assert_array_equal(r[5], result[5]) + + +class PallasCallPrintTest(PallasTPUTest): + + def test_debug_print(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), ) + def kernel(x_ref, o_ref): + pl.debug_print('It works!') + + x = jnp.array([4.2, 2.4]).astype(jnp.float32) + compiled_kernel = ( + jax.jit(kernel) + .lower(x) + .compile({'xla_tpu_enable_log_recorder': 'true'}) + ) + compiled_kernel(x) + + def test_debug_print_with_values(self): + @functools.partial( + self.pallas_call, + in_specs=(pl.BlockSpec(memory_space=pltpu.SMEM),), + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + ) + def kernel(x_ref, o_ref): + pl.debug_print('x[0] == {}', x_ref[0]) + + x = jnp.array([42, 24]).astype(jnp.int32) + compiled_kernel = ( + jax.jit(kernel) + .lower(x) + .compile({'xla_tpu_enable_log_recorder': 'true'}) + ) + compiled_kernel(x) - # Given shapes: - # lhs: A 2d, jnp.ndarray with shape [m, k // lax.psum(1, - # collective_axes.axes)]. - # rhs: A wd, jnp.ndarray with shape [k, n]. - - # We start with a prologue that gets us the lhs chunk that our left neighbor - # will send backward for us to send forward. After that at every step we do - # compute on our local chunks while overlapping the backward and forward - # collective permutes of lhs. We add to the same accumulator at every step. - # Effectively, this permute + compute pattern achieves an all-gather of lhs - # that is overlapped with the matmul. - - # We wait for the permutes in the pipeline epilogues so we can fuse the - # inner compute pipeline across matmul steps and avoid bubbles. - def all_gather_lhs_matmul_kernel( - lhs_ref, # [m, sharded_k] - rhs_ref, # [k, n] - out_ref, # [m, n] - # Fwd/bwd, and double buffered. 
- lhs_scratch_ref, # [2, 2, m, sharded_k] - acc_scratch_ref, # [tm, tn] - bwd_recv_sem, - bwd_send_sem, - fwd_recv_sem, - fwd_send_sem, - pipeline_allocations, - ): - step = pl.program_id(0) - fwd_bwd = pl.program_id(1) - is_first_step = step == 0 - is_not_last_step = step != steps - 1 - is_start_of_step = fwd_bwd == 0 - is_end_of_step = jnp.logical_not(is_start_of_step) - is_start = jnp.logical_and(is_first_step, is_start_of_step) - is_end = jnp.logical_and(step == steps - 1, is_end_of_step) - compute_buffer = lax.rem(step, 2) - send_buffer = 1 - compute_buffer - my_id = lax.axis_index('x') - right_neighbor = lax.rem(my_id + 1, num_devices) - left_neighbor = lax.rem(my_id - 1, num_devices) - left_neighbor = jnp.where( - left_neighbor < 0, left_neighbor + num_devices, left_neighbor - ) - prologue_fwd_copy = pltpu.make_async_remote_copy( - lhs_ref, - lhs_scratch_ref.at[1, compute_buffer], - fwd_send_sem, - fwd_recv_sem, - device_id=right_neighbor, +class PallasCallTPUInterpretTest(PallasTPUTest): + + def test_local_dma(self): + def test_kernel(x_ref, + o_ref, + copy_sem, + ): + o_ref[...] = jnp.zeros_like(o_ref[...]) + input_to_output_copy = pltpu.make_async_copy( + src_ref=x_ref.at[0:8], + dst_ref=o_ref.at[0:8], + sem=copy_sem, ) + input_to_output_copy.start() + input_to_output_copy.wait() - @pl.when(is_start) - @pltpu.trace('sync_and_bwd_prologue') - def _sync_and_bwd_prologue(): - barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) - pltpu.semaphore_signal(barrier_sem, device_id=right_neighbor) - pltpu.semaphore_wait(barrier_sem, 2) - prologue_bwd_copy = pltpu.make_async_copy( - lhs_ref, - lhs_scratch_ref.at[0, compute_buffer], - bwd_send_sem, - ) - prologue_bwd_copy.start() - prologue_fwd_copy.start() - prologue_bwd_copy.wait() - - bwd_kwargs, fwd_kwargs = [ - { - 'src_ref': scratch_ref.at[compute_buffer], - 'dst_ref': scratch_ref.at[send_buffer], - 'send_sem': send_sem, - 'recv_sem': recv_sem, - 'device_id': device_id, - } - for scratch_ref, send_sem, recv_sem, device_id in [ - ( - lhs_scratch_ref.at[0], - bwd_send_sem, - bwd_recv_sem, - left_neighbor, - ), - ( - lhs_scratch_ref.at[1], - fwd_send_sem, - fwd_recv_sem, - right_neighbor, - ), - ] - ] - - @pl.when(jnp.logical_and(is_not_last_step, is_start_of_step)) - @pltpu.trace('send_next_dma') - def _send_next_dma(): - pltpu.make_async_remote_copy(**bwd_kwargs).start() - - @pl.when(jnp.logical_not(is_start)) - def _send_next_fwd_dma(): - pltpu.make_async_remote_copy(**fwd_kwargs).start() - - def get_rhs_slice(step, is_start_of_step=is_start_of_step): - bwd_rhs_offset = lax.rem(my_id + step, num_devices) - fwd_rhs_offset = lax.rem(my_id - step - 1, num_devices) - fwd_rhs_offset = jnp.where( - fwd_rhs_offset < 0, fwd_rhs_offset + num_devices, fwd_rhs_offset + out_shape = (jax.ShapeDtypeStruct((9, 128), jnp.float32)) + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + scratch_shapes=( + [pltpu.SemaphoreType.DMA] + ) ) - offset = jnp.where(is_start_of_step, bwd_rhs_offset, fwd_rhs_offset) - return pl.ds( - pl.multiple_of(offset * sharded_k, sharded_k), - sharded_k, + + kernel = pl.pallas_call( + test_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + interpret=True + ) + x = jax.random.normal(jax.random.key(0), shape=(16, 128)) + result = kernel(x) + np.testing.assert_array_equal(result[0:8], x[0:8]) + np.testing.assert_array_equal(result[8:], jnp.zeros_like(result[8:])) + + 
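+    # Note: with interpret=True the kernel above runs in Pallas interpret mode
+    # rather than being compiled for TPU, so (roughly) the async_copy
+    # start()/wait() pair behaves like a blocking copy of the sliced rows:
+    # rows 0-7 of the output match the input and the remaining row stays
+    # zero, which is exactly what the asserts check. The next test applies
+    # the same pattern to remote copies that implement a ppermute-style
+    # rotation under shard_map.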
@parameterized.parameters(('left',), ('right',)) + def test_remote_dma_ppermute(self, permutation): + if jax.device_count() <= 1: + self.skipTest('Test requires multiple devices.') + num_devices = jax.device_count() + if permutation == 'left': + permute_fn = lambda x: lax.rem(x + num_devices - 1, num_devices) + else: + permute_fn = lambda x: lax.rem(x + num_devices + 1, num_devices) + + # Construct a kernel which performs a ppermute based on permute_fn. + def test_kernel(x_ref, + o_ref, + copy_send_sem, + copy_recv_sem, + ): + o_ref[...] = jnp.zeros_like(o_ref[...]) + my_id = lax.axis_index('x') + dst_device = permute_fn(my_id) + input_to_output_copy = pltpu.make_async_remote_copy( + src_ref=x_ref, + dst_ref=o_ref, + send_sem=copy_send_sem, + recv_sem=copy_recv_sem, + device_id=dst_device, + device_id_type=pltpu.DeviceIdType.LOGICAL, + ) + input_to_output_copy.start() + input_to_output_copy.wait() + + out_shape = (jax.ShapeDtypeStruct((8, 128), jnp.float32)) + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + scratch_shapes=( + [pltpu.SemaphoreType.DMA] * 2 + ) ) - with pltpu.trace('dots'): + devices = mesh_utils.create_device_mesh((1, num_devices)) + mesh = jax.sharding.Mesh(devices, P(None, 'x')) + sharding = jax.sharding.NamedSharding(mesh, P(None, 'x')) + unsharded_arr = jax.random.normal( + jax.random.key(0), shape=(8, 128 * num_devices)) + sharded_arr = jax.device_put(unsharded_arr, sharding) - def epilogue(epilogue_args: pltpu.PipelineCallbackArgs): + kernel = pl.pallas_call( + test_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + interpret=True + ) + compiled_func = jax.jit(shard_map.shard_map( + kernel, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P(None, 'x'), + check_rep=False)) + result = compiled_func(sharded_arr) + + perm = tuple((src, permute_fn(src)) for src in range(num_devices)) + perm = jax.tree_util.tree_map(int, perm) + def lax_permute(x): + return lax.ppermute(x, 'x', perm) + expected = jax.jit(shard_map.shard_map(lax_permute, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P(None, 'x')))(sharded_arr) + np.testing.assert_array_equal(result, expected) + + +class PallasCallTraceTest(PallasTPUTest): + interpret: bool = False - @pl.when(is_start) - @pltpu.trace('fwd_prologue') - def _fwd_prologue(): - prologue_fwd_copy.wait() - pltpu.make_async_remote_copy(**fwd_kwargs).start() + def parse_debug_string(self, debug_string): + jaxpr, mlir = debug_string.split('module') + return {'jaxpr': jaxpr, 'mlir': mlir} - @pl.when(jnp.logical_and(is_not_last_step, is_end_of_step)) - @pltpu.trace('wait_on_prev_dma') - def _wait_on_prev_dma(): - pltpu.make_async_remote_copy(**bwd_kwargs).wait() - pltpu.make_async_remote_copy(**fwd_kwargs).wait() + def test_trace_start_stop_match(self): + def kernel(o_ref): + with jax.named_scope('scope1'): + o_ref[...] 
= jnp.zeros_like(o_ref[...]) - def prefetch_pipeline_inputs(): - prefetch_compute_buffer = jnp.where( - is_start_of_step, compute_buffer, send_buffer - ) - prefetch_fwd_bwd = lax.rem(fwd_bwd + 1, 2) - prefetch_pipeline_refs = epilogue_args.make_pipeline_refs( - lhs_scratch_ref.at[prefetch_fwd_bwd, prefetch_compute_buffer], - rhs_ref.at[ - get_rhs_slice( - jnp.where(is_start_of_step, step, step + 1), - jnp.logical_not(is_start_of_step), - ) - ], - out_ref, - ) - return epilogue_args.start_pipeline_prefetch( - pltpu.PipelinePrefetchArgs( - prefetch_pipeline_refs, - epilogue_args.pipeline_allocations, - epilogue_args.pipeline_buffers, - ), - # Force copy lhs because we just permuted it. - # Force copy rhs because we need a different slice. - force_copy=([True, True], False), - ) + with string_stdout() as msg: + _ = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + debug=True, + )() + # TODO(justinfu): Add an official lowering API to get the MLIR. + mlir = self.parse_debug_string(msg.getvalue())['mlir'] + + num_start = mlir.count('tpu.trace_start') + num_stop = mlir.count('tpu.trace_stop') + self.assertEqual(num_start, 1) + self.assertEqual(num_stop, 1) + + def test_run_scoped(self): + def kernel(o_ref): + def scope1(): + with jax.named_scope('scope1'): + o_ref[...] = jnp.zeros_like(o_ref[...]) + pltpu.run_scoped(scope1) + + def scope2(): + with jax.named_scope('scope2'): + o_ref[...] = o_ref[...] + 1 + pltpu.run_scoped(scope2) + + with string_stdout() as msg: + _ = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + debug=True, + )() + # TODO(justinfu): Add an official lowering API to get the MLIR. + mlir = self.parse_debug_string(msg.getvalue())['mlir'] - return lax.cond( - jnp.logical_not(is_end), - prefetch_pipeline_inputs, - lambda: ( - epilogue_args.pipeline_buffers.input, - epilogue_args.pipeline_buffers.in_out, - ), - ) + num_start = mlir.count('tpu.trace_start') + num_stop = mlir.count('tpu.trace_stop') + self.assertEqual(num_start, 2) + self.assertEqual(num_stop, 2) - pipeline( - lhs_scratch_ref.at[fwd_bwd, compute_buffer], - rhs_ref.at[get_rhs_slice(step)], - out_ref, - scratchs=[acc_scratch_ref], - allocations=pipeline_allocations, - init_allocations=is_start, - prologue=lambda _: ( - # Input and accum prologue input copy start skip conditions. - ( - jnp.logical_not(is_start), - jnp.logical_not(is_start), - ), - # Force input and accum input copy wait. - ([True, True], False), - ), - epilogue=epilogue, - # Only skip prologue output copy wait if starting and there is no - # previous output. - out_prologue=lambda _: is_start, - # Skip epilogue output copy wait unless it's the end. - out_epilogue=lambda _: jnp.logical_not(is_end), - ) - kernel = pl.pallas_call( - all_gather_lhs_matmul_kernel, - out_shape=[ - jax.ShapeDtypeStruct((m, sharded_n), out_dtype), - jax.ShapeDtypeStruct((2, 2, m, sharded_k), x.dtype), - ], +class PallasCallTPUCheckifyTest(PallasTPUTest): + interpret: bool = True + + @parameterized.parameters((2,), (5,), (6,), (7,)) + def test_checkify_with_scalar_prefetch(self, threshold): + def body(scalar_ref, x_ref, o_ref): + scalar = scalar_ref[pl.program_id(0)] + o_ref[...] = x_ref[...] 
+ checkify.check(scalar < threshold, 'failed on value {x}', x=scalar) + + s = jnp.array([4, 3, 2, 6, 3, 5, 2, 7], jnp.int32) + x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) + + def _x_transform(i, s_ref): + s = pl.load(s_ref, (i,)) + return (s, 0) + + pallas_call = self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32), grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=0, + num_scalar_prefetch=1, in_specs=[ - pl.BlockSpec(memory_space=memory_space), - pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec((x.shape[0] // 8, x.shape[1]), _x_transform), ], - out_specs=[pl.BlockSpec(memory_space=memory_space)] * 2, - grid=(steps, 2), - scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] - + [pltpu.SemaphoreType.DMA] * 4 - + [ - make_pipeline_allocations( - memory_space((), x.dtype), - memory_space((), y.dtype), - memory_space((), out_dtype), - ) - ], - ), - compiler_params=dict( - mosaic=dict(collective_id=0, vmem_limit_bytes=int(134217728 * 0.9)) + out_specs=pl.BlockSpec( + (x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0) + ), + grid=8, ), ) + checked_call = checkify.checkify(pallas_call) + err, out = checked_call(s, x) + expected_error_value = s[jnp.argmax(s >= threshold)] + with self.assertRaisesRegex( + checkify.JaxRuntimeError, f'failed on value {expected_error_value}'): + err.throw() + np.testing.assert_allclose(out, x.reshape((8, 8, -1))[s].reshape(x.shape)) + + def test_checkify_with_scratch(self): + def body(x_ref, o_ref, scratch_ref): + scratch_ref[...] = x_ref[...] + o_ref[...] = scratch_ref[...] + all_nequal = ~jnp.all(o_ref[...] == x_ref[...]) + checkify.check(all_nequal, 'x_ref equals o_ref id=({x}, {y})', + x=pl.program_id(0), y=pl.program_id(1)) - shard = partial( - shard_map.shard_map, - mesh=jax.sharding.Mesh( - mesh_utils.create_device_mesh((num_devices,), jax.devices()), - ['x'], + x = jax.random.uniform(jax.random.key(0), (128, 128), dtype=jnp.float32) + pallas_call = self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec((32, 32), lambda i, j: (i, j)), + ], + out_specs=pl.BlockSpec((32, 32), lambda i, j: (i, j)), + scratch_shapes=[pltpu.VMEM((32, 32), dtype=jnp.float32)], + grid=(4, 4), ), - in_specs=(P(None, 'x'), P(None, None)), - out_specs=P(None, None), - check_rep=False, ) + checked_call = checkify.checkify(pallas_call) + err, out = checked_call(x) + with self.assertRaisesRegex( + checkify.JaxRuntimeError, r'x_ref equals o_ref id=\(0, 0\)'): + err.throw() + np.testing.assert_allclose(out, x) + + @parameterized.parameters((4,), (9,)) + def test_checkify_with_dynamic_grid(self, iteration): + grid_size = 4 + shape = (8, 128) + result_ty = jax.ShapeDtypeStruct(shape, jnp.float32) - test = jax.jit(shard(kernel)) + def kernel(y_ref): + @pl.when(pl.program_id(0) == 0) + def _init(): + y_ref[...] = jnp.zeros_like(y_ref) + y_ref[...] 
+= 1 + @pl.when(pl.program_id(0) == iteration) + def _(): + checkify.check(False, f"error on iteration {iteration}") @jax.jit - @shard - def reference(x, y): - x = jax.lax.all_gather(x, 'x', axis=1, tiled=True) - return jnp.dot(x, y, preferred_element_type=out_dtype) + def dynamic_kernel(steps): + pallas_call = self.pallas_call( + kernel, + grid=(steps * 2,), + out_specs=pl.BlockSpec(shape, lambda i: (0, 0)), + out_shape=result_ty, + ) + return checkify.checkify(pallas_call)() - jax.block_until_ready(test(x, y)) - jax.block_until_ready(reference(x, y)) + err, result = dynamic_kernel(jnp.int32(grid_size)) + if iteration < grid_size * 2: + with self.assertRaisesRegex( + checkify.JaxRuntimeError, f"error on iteration {iteration}"): + err.throw() + np.testing.assert_array_equal( + result, np.full(shape, grid_size * 2.0, np.float32) + ) - out = jax.block_until_ready(test(x, y)[0]) - expected_out = jax.block_until_ready(reference(x, y)) - np.testing.assert_allclose( - out.astype(jnp.float32), - expected_out.astype(jnp.float32), - atol=1 if out_dtype == jnp.float32 else 5, +class MiscellaneousTest(PallasTPUTest): + """Tests for recently reported bugs; only pass in interpret mode.""" + + interpret: bool = True + + def test_float32_stack(self): + """b/347761105""" + x = np.arange(128, dtype=jnp.float32).reshape(1, 128) + y = x + 128 + + def kernel(x_ref, y_ref, out_ref): + out_ref[...] = jnp.stack([x_ref[...], y_ref[...]], axis=1) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((1, 2, 128), jnp.float32) + )(x, y) + np.testing.assert_array_equal(out, np.stack([x, y], axis=1)) + + def test_lane_to_chunk_reshape_bf16(self): + """b/348038320""" + x = np.arange(256 * 1024, dtype=jnp.bfloat16).reshape(1, 256, 1024) + + def kernel(x_ref, out_ref): + out_ref[...] = jnp.reshape(x_ref[...], (1, 256, 8, 128)) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((1, 256, 8, 128), jnp.bfloat16) + )(x) + np.testing.assert_array_equal(out, np.reshape(x, (1, 256, 8, 128))) + + def test_lane_to_chunk_broadcast_fp32(self): + """b/348033362""" + x = np.arange(256 * 128, dtype=jnp.float32).reshape(1, 256, 128) + + def kernel(x_ref, out_ref): + out_ref[...] = jnp.broadcast_to( + jnp.expand_dims(x_ref[...], 2), (1, 256, 8, 128) + ) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((1, 256, 8, 128), jnp.float32) + )(x) + np.testing.assert_array_equal( + out, np.broadcast_to(np.expand_dims(x, 2), (1, 256, 8, 128)) ) + def test_lane_dynamic_slice(self): + """b/346849973""" + x = np.arange(128, dtype=jnp.float32) + + def kernel(x_ref, out_ref): + out_ref[...] = lax.dynamic_slice_in_dim(x_ref[...], 64, 1, 0) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((1,), jnp.float32) + )(x) + np.testing.assert_array_equal(out, x[64:65]) + + def test_lane_broadcast_bf16(self): + """b/346654106""" + x = np.arange(256, dtype=jnp.bfloat16).reshape(256, 1) + + def kernel(x_ref, out_ref): + out_ref[...] = jnp.broadcast_to(x_ref[...], (256, 512)) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((256, 512), jnp.bfloat16) + )(x) + np.testing.assert_array_equal(out, np.broadcast_to(x, (256, 512))) + + def test_bfloat16_to_uint32_bitcast(self): + """b/347771903""" + x = np.arange(16 * 2 * 256, dtype=jnp.bfloat16).reshape(16, 2, 256) + + def kernel(x_ref, out_ref): + out_ref[...] 
= pltpu.bitcast(x_ref[...], jnp.uint32) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((16, 1, 256), jnp.uint32) + )(x) + # FIXME: Add correctness test for result. + + def test_roll_partial(self): + """b/337384645""" + x = np.arange(8192, dtype=jnp.float32).reshape(128, 64) + + def kernel(x_ref, out_ref): + out_ref[...] = pltpu.roll(x_ref[...], 3, 1) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((128, 64), jnp.float32) + )(x) + np.testing.assert_array_equal(out, np.roll(x, 3, 1)) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/pallas_pipeline_tpu_test.py b/tests/pallas/pallas_pipeline_tpu_test.py new file mode 100644 index 000000000000..cc5ef51e9c22 --- /dev/null +++ b/tests/pallas/pallas_pipeline_tpu_test.py @@ -0,0 +1,1519 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test TPU-specific extensions to pallas_call.""" + +import functools +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import lax +from jax._src import test_util as jtu +from jax.experimental import mesh_utils +from jax.experimental import pallas as pl +from jax.experimental import shard_map +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np +try: + import hypothesis as hp + import hypothesis.strategies as hps + CAN_USE_HYPOTHESIS = True +except (ModuleNotFoundError, ImportError): + CAN_USE_HYPOTHESIS = False + + +if CAN_USE_HYPOTHESIS: + hp.settings.register_profile( + 'deterministic', + database=None, + derandomize=True, + deadline=None, + max_examples=200, + print_blob=True, + ) + hp.settings.load_profile('deterministic') + + +jax.config.parse_flags_with_absl() +P = jax.sharding.PartitionSpec +partial = functools.partial + + +def mod(a, n): + return lax.rem(a + n, n) + + +def make_ds(idx, stride): + return pl.ds(pl.multiple_of(idx * stride, stride), stride) + + +def _grid_size(grid): + size = jnp.array(1, jnp.int32) + for dim in grid: + size *= dim + return size + + +@jax.named_scope('compute') +def basic_matmul_kernel( + lhs_ref, + rhs_ref, + out_ref, + acc_scratch_ref, + *, + k: int, +): + k_index = pl.program_id(2) + num_k = pl.num_programs(2) + bk = lhs_ref.shape[1] + @pl.when(k_index == 0) + def _zero_acc(): + acc_scratch_ref[...] = jnp.zeros( + acc_scratch_ref.shape, acc_scratch_ref.dtype) + + divisible_k = k % bk == 0 + if divisible_k: + acc_scratch_ref[...] 
+= jnp.dot( + lhs_ref[...], + rhs_ref[...], + preferred_element_type=acc_scratch_ref.dtype, + ) + else: + def _last_block(): + accum_dtype = acc_scratch_ref.dtype + lhs_mask = ( + k_index * bk + jax.lax.broadcasted_iota(jnp.int32, lhs_ref.shape, 1) + < k + ) + rhs_mask = ( + k_index * bk + jax.lax.broadcasted_iota(jnp.int32, rhs_ref.shape, 0) + < k + ) + dtype = lhs_ref.dtype + lhs = lhs_ref[...].astype(accum_dtype) + lhs = jnp.where(lhs_mask, lhs, 0).astype(dtype) + rhs = rhs_ref[...].astype(accum_dtype) + rhs = jnp.where(rhs_mask, rhs, 0).astype(dtype) + acc_scratch_ref[...] += jnp.dot( + lhs, rhs, preferred_element_type=acc_scratch_ref.dtype) + def _not_last_block(): + acc_scratch_ref[...] += jnp.dot( + lhs_ref[...], + rhs_ref[...], + preferred_element_type=acc_scratch_ref.dtype, + ) + jax.lax.cond( + k_index == num_k - 1, _last_block, _not_last_block + ) + + @pl.when(k_index == num_k - 1) + def _reduce_out(): + out_ref[...] = acc_scratch_ref[...].astype(out_ref.dtype) + + +class PallasCallPipelineTest(parameterized.TestCase): + + def setUp(self): + if jax.device_count() < 2: + self.skipTest('Only >=2 devices are supported.') + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Only works with TPU v5') + + super().setUp() + + @parameterized.named_parameters( + ('vmem', pltpu.TPUMemorySpace.VMEM), + ('hbm', pltpu.TPUMemorySpace.ANY), + ) + def test_pipeline_matmul(self, memory_space): + k1, k2 = jax.random.split(jax.random.key(0)) + x = jax.random.uniform(k1, (512, 512)) + y = jax.random.uniform(k2, (512, 512)) + + def matmul_pipeline(x_ref, y_ref, z_ref): + @pl.when(pl.program_id(2) == 0) + def _(): + z_ref[...] = jnp.zeros(z_ref.shape, jnp.float32) + + z_ref[...] += x_ref[...] @ y_ref[...] + + def matmul_kernel(x_ref, y_ref, z_ref): + pltpu.emit_pipeline( + matmul_pipeline, + grid=(4, 4, 4), + in_specs=[ + pl.BlockSpec((128, 128), lambda i, j, k: (i, k)), + pl.BlockSpec((128, 128), lambda i, j, k: (k, j)), + ], + out_specs=pl.BlockSpec((128, 128), lambda i, j, k: (i, j)), + )(x_ref, y_ref, z_ref) + + z = pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), + in_specs=[ + pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=memory_space), + ], + out_specs=pl.BlockSpec(memory_space=memory_space), + ) + + jax.block_until_ready(z(x, y)) + jax.block_until_ready(jnp.dot(x, y)) + + out = jax.block_until_ready(z(x, y)) + expected_out = jax.block_until_ready(jnp.dot(x, y)) + + np.testing.assert_allclose(out, expected_out) + + @parameterized.named_parameters( + ('vmem', pltpu.TPUMemorySpace.VMEM), + ('hbm', pltpu.TPUMemorySpace.ANY), + ) + def test_double_pipeline_matmul(self, memory_space): + k1, k2 = jax.random.split(jax.random.key(0)) + x = jax.random.uniform(k1, (512, 512)) + y = jax.random.uniform(k2, (512, 512)) + + def matmul_pipeline(x_ref, y_ref, z_ref): + @pl.when(pl.program_id(2) == 0) + def _(): + z_ref[...] = jnp.zeros(z_ref.shape, jnp.float32) + + z_ref[...] += x_ref[...] @ y_ref[...] 
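+      # Note: matmul_pipeline zeroes its output block on the first k step and
+      # accumulates one (128, 128) partial product per step; matmul_kernel
+      # below emits this pipeline twice, the second time with
+      # should_accumulate_out=True, so the final check expects
+      # jnp.dot(x, y) + jnp.dot(x, y).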
+ + def matmul_kernel(x_ref, y_ref, z_ref): + + def emit_pipeline(should_accumulate_out): + pltpu.emit_pipeline( + matmul_pipeline, + grid=(4, 4, 4), + in_specs=[ + pl.BlockSpec((128, 128), lambda i, j, k: (i, k)), + pl.BlockSpec((128, 128), lambda i, j, k: (k, j)), + ], + out_specs=pl.BlockSpec((128, 128), lambda i, j, k: (i, j)), + should_accumulate_out=should_accumulate_out, + )(x_ref, y_ref, z_ref) + + emit_pipeline(False) + emit_pipeline(True) + + z = pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), + in_specs=[ + pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=memory_space), + ], + out_specs=pl.BlockSpec(memory_space=memory_space), + )(x, y) + + np.testing.assert_allclose(z, jnp.dot(x, y) + jnp.dot(x, y)) + + +class PallasCallCollectivePipelineTest(parameterized.TestCase): + + def setUp(self): + if jax.device_count() < 2: + self.skipTest('Only >=2 devices are supported.') + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Only works with TPU v5') + + super().setUp() + + @parameterized.named_parameters( + ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 2, 2), + ('hbm_float32_112', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 2), + ('hbm_float32_111', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 1), + ) + def test_pipeline_latency_optimized_allgather_matmul( + self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): + input_dtype = jnp.float32 + + num_devices = jax.local_device_count() + + tn = 128 * 1 + tm = 128 * 1 + tk = 128 * 1 + n = tn * n_tiles + m = tm * m_tiles + k = tk * k_tiles * num_devices + + outer_steps = num_devices // 2 + + sharded_k = k // num_devices + inner_grid = (n // tn, m // tm, sharded_k // tk) + + inner_kernel = partial(basic_matmul_kernel, k=sharded_k) + + inner_allocs = [ + pltpu.BufferedRef.input( + pl.BlockSpec((tm, tk), lambda n, m, k: (m, k)), input_dtype), + pltpu.BufferedRef.input( + pl.BlockSpec((tk, tn), lambda n, m, k: (k, n)), input_dtype), + pltpu.BufferedRef.accumulator( + pl.BlockSpec((tm, tn), lambda n, m, k: (m, n)), out_dtype), + ] + + def all_gather_lhs_matmul_kernel( + # in refs: + lhs_ref, # [m, sharded_k] + rhs_ref, # [k, n] + # out refs: + out_ref, # [m, n] + # used as scratch but defined as an output so that it's resident in HBM: + lhs_scratch_ref, # [2 (bwd/fwd), 2 (work/buffer), m, sharded_k] + # scratch refs: + acc_scratch_ref, # [tm, tn] + bwd_recv_sem, + bwd_send_sem, + fwd_recv_sem, + fwd_send_sem, + lhs_bref, + rhs_bref, + out_bref + ): + # Outer Pipeline + # Coordinates collective rDMAs and prefetching around the inner compute + # pipeline. + + # Given shapes: + # lhs: A sharded 2d, jnp.ndarray with shape [m, k // axis_size]. + # rhs: A 2d, jnp.ndarray with shape [k, n]. + # Results: + # out: jnp.ndarray with shape [m, n]. + + # A bidirectional collective allgather-matmul sends two "streams" of LHS + # chunks in the forward and backward directions. These are matmul'd with + # their corresponding chunks in the RHS contracting dimension and added to + # a running accumulator. + + # We run the computation in N / 2 steps with 2 "phases" per step: + # - phase 0: in which we compute the backwards stream matmul. + # - phase 1: in which we compute the forward stream matmul. 
+ + # In the prologue we initialize the backwards stream by using the local + # LHS shard, and the forwards stream by sending the local LHS shard + # "right" along the contractive sharding axis. + + # At each step afterwards we roll the fwd copies right, and the bwd copies + # left via rDMAs that are overlapped with matmuls. + + # At step n, phase p, we compute the following: + # let: + # idx = (axis_index + step) if p == 0 else (axis_index - step - 1) + # contraction_slice = idx * (k // axis_size) : (idx+1) * (k // axis_size) + # out[m, n] += lhs[m, contraction_slice] @ rhs[contraction_slice, n] + + # Where LHS slices are the corresponding shards passed along the "fwd" + # and "bwd" streams, and RHS is sliced directly from an unsharded array. + + outer_step = pl.program_id(0) # range [0, steps-1] + phase = pl.program_id(1) # range [0, 1] 0 == BWD Matmul, 1 == FWD Matmul + + # kernel start / end booleans + is_start = jnp.logical_and(outer_step == 0, phase == 0) + is_end = jnp.logical_and(outer_step == outer_steps - 1, phase == 1) + + # slots for double-buffered LHS scratch accumulator + # at each sub-step, working slot --> buffering slot + working_slot = lax.rem(outer_step, 2) + buffering_slot = 1 - working_slot + + # IDs of self and neighbors + my_id = lax.axis_index('x') + right_neighbor = mod(my_id + 1, num_devices) + left_neighbor = mod(my_id - 1, num_devices) + + # Async copy definitions. + + # NB: The send semaphore is what the sender uses to wait until the + # destination buffer is free. The recv semaphore is only readable by the + # destination core to wait until all data has arrived. (The completion of + # these sync flags can be multiple microseconds apart.) async wait() + # calls will only unblock after both for remote copies. + + # Initialize backwards stream by transfer of LHS chunks into local + # working copies. + initial_bwd_copy = pltpu.make_async_copy( + lhs_ref, + lhs_scratch_ref.at[0, working_slot], + bwd_send_sem, + ) + + # Initialize forwards stream by transfer of initial LHS chunks to right + # neighbors' working copies. + initial_fwd_copy = pltpu.make_async_remote_copy( + src_ref=lhs_ref, + dst_ref=lhs_scratch_ref.at[1, working_slot], + send_sem=fwd_send_sem, + recv_sem=fwd_recv_sem, + device_id=right_neighbor, + ) + + # Transfer working copies of LHS chunks backwards to left neighbors' + # buffering copies. + bwd_copy = pltpu.make_async_remote_copy( + src_ref=lhs_scratch_ref.at[0, working_slot], + dst_ref=lhs_scratch_ref.at[0, buffering_slot], + send_sem=bwd_send_sem, + recv_sem=bwd_recv_sem, + device_id=left_neighbor, + ) + + # Transfer working copies of LHS chunks forwards to right neighbors' + # buffering copies. + fwd_copy = pltpu.make_async_remote_copy( + src_ref=lhs_scratch_ref.at[1, working_slot], + dst_ref=lhs_scratch_ref.at[1, buffering_slot], + send_sem=fwd_send_sem, + recv_sem=fwd_recv_sem, + device_id=right_neighbor, + ) + + # Slice RHS to match LHS slices in bwd/fwd phases for contractions. + def get_rhs_slice(outer_step, phase): + bwd_rhs_offset = mod(my_id + outer_step, num_devices) + fwd_rhs_offset = mod(my_id - outer_step - 1, num_devices) + offset = jnp.where(phase, fwd_rhs_offset, bwd_rhs_offset) + return pl.ds( + pl.multiple_of(offset * sharded_k, sharded_k), + sharded_k, + ) + + # Fixed Ref schedule, only really needed to prevent HBM data race in the + # degenerate case of a trivial (single-step) inner loop. 
+ accum_schedule = pltpu.get_pipeline_schedule('fixed') + # Tweak schedule to skip copying in initial accumulator data as we zero it + # out anyway. + for k in ['prologue_copy_in', 'wait_in', 'copy_in']: + accum_schedule[k] = functools.partial( # avoid cell-var-from-loop + lambda original_pred_fn, *a: original_pred_fn(*a) & ~is_start, + accum_schedule[k]) + + # Outer loop prologue + @pl.when(is_start) + @jax.named_scope('sync_and_bwd_init') + def _sync_and_bwd_init(): + # barrier at start + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) + pltpu.semaphore_signal(barrier_sem, device_id=right_neighbor) + pltpu.semaphore_wait(barrier_sem, 2) + + # initializing copies + initial_bwd_copy.start() + initial_fwd_copy.start() + initial_bwd_copy.wait() + + @pl.when(jnp.logical_and(outer_step != outer_steps - 1, phase == 0)) + @jax.named_scope('send_next_dma') + def _send_next_dma(): + bwd_copy.start() + @pl.when(jnp.logical_not(is_start)) + def _send_next_fwd_dma(): + fwd_copy.start() + + # Cross-loop prefetch + def prefetch(lhs_bref, rhs_bref, out_bref, scheduler): + @pl.when(is_start) + @jax.named_scope('fwd_init') + def _fwd_init(): + initial_fwd_copy.wait() + fwd_copy.start() + + @pl.when(jnp.logical_and(outer_step != outer_steps - 1, phase == 1)) + @jax.named_scope('wait_on_prev_dma') + def _wait_on_prev_dma(): + bwd_copy.wait() + fwd_copy.wait() + + # prefetch next loop's inputs + prefetch_working_slot = jnp.where( + phase == 0, working_slot, buffering_slot) + prefetch_step = jnp.where(phase == 0, outer_step, outer_step + 1) + prefetch_phase = lax.rem(phase + 1, 2) + scheduler.prefetch( + lhs_bref, lhs_scratch_ref.at[prefetch_phase, prefetch_working_slot]) + scheduler.prefetch( + rhs_bref, rhs_ref.at[get_rhs_slice(prefetch_step, prefetch_phase)]) + scheduler.prefetch(out_bref, out_ref, accum_schedule) + + pltpu.emit_pipeline(inner_kernel, grid=inner_grid)( + lhs_scratch_ref.at[phase, working_slot], + rhs_ref.at[get_rhs_slice(outer_step, phase)], + out_ref, + allocations=[lhs_bref, rhs_bref, out_bref], + scratches=[acc_scratch_ref], + first_cycle=is_start, + last_cycle=is_end, + init_accumulators=is_start, + prefetch=prefetch, + schedule=[None, None, accum_schedule] + ) + + kernel = pl.pallas_call( + all_gather_lhs_matmul_kernel, + out_shape=[ + jax.ShapeDtypeStruct((m, n), out_dtype), + jax.ShapeDtypeStruct((2, 2, m, sharded_k), input_dtype), + ], + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=memory_space), + ], + out_specs=[pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=memory_space)], + grid=(outer_steps, 2), + scratch_shapes=[ + pltpu.VMEM((tm, tn), jnp.float32)] + + [pltpu.SemaphoreType.DMA] * 4 + + inner_allocs + ), + compiler_params=dict( + mosaic=dict(collective_id=0, + # must set scoped vmem flag *larger* than below! 
e.g.: + # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB + ) + ), + ) + + shard = partial( + shard_map.shard_map, + mesh=jax.sharding.Mesh( + mesh_utils.create_device_mesh((num_devices,), jax.devices()), + ['x'], + ), + in_specs=(P(None, 'x'), P(None, None)), + out_specs=P(None, None), + check_rep=False, + ) + + test = jax.jit(shard(kernel)) + + @jax.jit + @shard + def reference(x, y): + x = jax.lax.all_gather(x, 'x', axis=1, tiled=True) + return jnp.dot(x, y, preferred_element_type=out_dtype) + + k1, k2 = jax.random.split(jax.random.key(42)) + x = jax.random.uniform( + k1, (m, k), dtype=input_dtype, minval=-1, maxval=1) + y = jax.random.uniform( + k2, (k, n), dtype=input_dtype, minval=-1, maxval=1 + ) + + out = jax.block_until_ready(test(x, y)[0]) + expected_out = jax.block_until_ready(reference(x, y)) + + np.testing.assert_allclose( + out.astype(jnp.float32), + expected_out.astype(jnp.float32), + atol=1 if out_dtype == jnp.float32 else 5, + ) + + @parameterized.named_parameters( + ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 2, 2), + ('hbm_float32_122', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 2), + ('hbm_float32_121', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 1), + ) + def test_pipeline_throughput_optimized_allgather_matmul( + self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): + input_dtype = out_dtype + num_devices = jax.local_device_count() + + tn = 128 + tm = 128 + tk = 128 + n = tn * n_tiles + m = tm * m_tiles # subsplit on this dim! + k = tk * k_tiles * num_devices + + outer_steps = num_devices + + sharded_k = k // num_devices + half_m = m // 2 + inner_grid = (n // tn, half_m // tm, sharded_k // tk) + + inner_kernel = partial(basic_matmul_kernel, k=sharded_k) + + inner_allocs = [ + pltpu.BufferedRef.input( + pl.BlockSpec((tm, tk), lambda n, m, k: (m, k)), input_dtype), + pltpu.BufferedRef.input( + pl.BlockSpec((tk, tn), lambda n, m, k: (k, n)), input_dtype), + pltpu.BufferedRef.accumulator( + pl.BlockSpec((tm, tn), lambda n, m, k: (m, n)), out_dtype), + ] + + def all_gather_lhs_matmul_kernel( + # in refs: + lhs_ref, # [m, sharded_k] + rhs_ref, # [k, n] + # out refs: + out_ref, # [m, n] + # used as scratch but defined as an output so that it's resident in HBM: + lhs_scratch_ref, # [2 (bwd/fwd), 2 (work/buffer), m//2, sharded_k] + # scratch refs: + acc_scratch_ref, # [tm, tn] + bwd_recv_sem, + bwd_send_sem, + fwd_recv_sem, + fwd_send_sem, + lhs_bref, + rhs_bref, + out_bref + ): + outer_step = pl.program_id(0) # range [0, steps-1] + phase = pl.program_id(1) # range [0, 1] 0 == BWD Matmul, 1 == FWD Matmul + + # kernel start / end booleans + is_start = jnp.logical_and(outer_step == 0, phase == 0) + is_end = jnp.logical_and(outer_step == outer_steps - 1, phase == 1) + + # slots for double-buffered LHS scratch accumulator + # at each sub-step, working slot --> buffering slot + working_slot = lax.rem(outer_step, 2) + buffering_slot = 1 - working_slot + + # IDs of self and neighbors + my_id = lax.axis_index('x') + right_neighbor = mod(my_id + 1, num_devices) + left_neighbor = mod(my_id - 1, num_devices) + + # Initialize backwards stream by transfer of LHS chunks into local + # working copies. 
+ initial_bwd_copy = pltpu.make_async_copy( + lhs_ref.at[make_ds(1, m//2)], + lhs_scratch_ref.at[0, working_slot], + bwd_send_sem, + ) + + # Initialize forwards stream by transfer of LHS chunks into local + # working copies. + initial_fwd_copy = pltpu.make_async_copy( + lhs_ref.at[make_ds(0, m//2)], + lhs_scratch_ref.at[1, working_slot], + bwd_send_sem, + ) + + # Transfer working copies of LHS chunks backwards to left neighbors' + # buffering copies. + bwd_copy = pltpu.make_async_remote_copy( + src_ref=lhs_scratch_ref.at[0, working_slot], + dst_ref=lhs_scratch_ref.at[0, buffering_slot], + send_sem=bwd_send_sem, + recv_sem=bwd_recv_sem, + device_id=left_neighbor, + ) + + # Transfer working copies of LHS chunks forwards to right neighbors' + # buffering copies. + fwd_copy = pltpu.make_async_remote_copy( + src_ref=lhs_scratch_ref.at[1, working_slot], + dst_ref=lhs_scratch_ref.at[1, buffering_slot], + send_sem=fwd_send_sem, + recv_sem=fwd_recv_sem, + device_id=right_neighbor, + ) + + # Slice RHS to match LHS slices in bwd/fwd phases for contractions. + def get_rhs_slice(outer_step, phase): + bwd_rhs_offset = mod(my_id + outer_step, num_devices) + fwd_rhs_offset = mod(my_id - outer_step, num_devices) + offset = jnp.where(phase, fwd_rhs_offset, bwd_rhs_offset) + return make_ds(offset, sharded_k) + + def get_half(phase): + return make_ds(jnp.where(phase, 0, 1), m//2) + + # Loop Prologue + @pl.when(is_start) + @jax.named_scope('sync_and_bwd_init') + def _sync_and_bwd_init(): + # barrier at start + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) + pltpu.semaphore_signal(barrier_sem, device_id=right_neighbor) + pltpu.semaphore_wait(barrier_sem, 2) + # initializing copies + initial_bwd_copy.start() + initial_fwd_copy.start() + initial_bwd_copy.wait() + + @pl.when(jnp.logical_and(outer_step != outer_steps - 1, phase == 0)) + @jax.named_scope('send_next_dma') + def _send_next_dma(): + bwd_copy.start() + @pl.when(jnp.logical_not(is_start)) + def _send_next_fwd_dma(): + fwd_copy.start() + + # Loop Prefetch + def prefetch(lhs_bref, rhs_bref, out_bref, scheduler): + @pl.when(is_start) + @jax.named_scope('fwd_init') + def _fwd_init(): + initial_fwd_copy.wait() + fwd_copy.start() + + @pl.when(jnp.logical_and(outer_step != outer_steps - 1, phase == 1)) + @jax.named_scope('wait_on_prev_dma') + def _wait_on_prev_dma(): + bwd_copy.wait() + fwd_copy.wait() + + # prefetch next inputs + next_working_slot = jnp.where( + phase == 0, working_slot, buffering_slot) + next_step = jnp.where(phase == 0, outer_step, outer_step + 1) + next_phase = lax.rem(phase + 1, 2) + scheduler.prefetch( + lhs_bref, lhs_scratch_ref.at[next_phase, next_working_slot]) + scheduler.prefetch( + rhs_bref, rhs_ref.at[get_rhs_slice(next_step, next_phase)]) + scheduler.prefetch( + out_bref, out_ref.at[get_half(next_phase)]) + + pltpu.emit_pipeline(inner_kernel, grid=inner_grid)( + lhs_scratch_ref.at[phase, working_slot], + rhs_ref.at[get_rhs_slice(outer_step, phase)], + out_ref.at[get_half(phase)], + allocations=[lhs_bref, rhs_bref, out_bref], + scratches=[acc_scratch_ref], + first_cycle=is_start, + last_cycle=is_end, + init_accumulators=outer_step == 0, + prefetch=prefetch, + ) + + kernel = pl.pallas_call( + all_gather_lhs_matmul_kernel, + out_shape=[ + jax.ShapeDtypeStruct((m, n), out_dtype), + jax.ShapeDtypeStruct((2, 2, half_m, sharded_k), input_dtype), + ], + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=memory_space), + 
pl.BlockSpec(memory_space=memory_space), + ], + out_specs=[pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=memory_space)], + grid=(outer_steps, 2), + scratch_shapes=[ + pltpu.VMEM((tm, tn), jnp.float32)] + + [pltpu.SemaphoreType.DMA] * 4 + + inner_allocs + ), + compiler_params=dict( + mosaic=dict(collective_id=0, + # must set scoped vmem flag *larger* than below! e.g.: + # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB + ) + ), + ) + + shard = partial( + shard_map.shard_map, + mesh=jax.sharding.Mesh( + mesh_utils.create_device_mesh((num_devices,), jax.devices()), + ['x'], + ), + in_specs=(P(None, 'x'), P(None, None)), + out_specs=P(None, None), + check_rep=False, + ) + + test = jax.jit(shard(kernel)) + + @jax.jit + @shard + def reference(x, y): + x = jax.lax.all_gather(x, 'x', axis=1, tiled=True) + return jnp.dot(x, y, preferred_element_type=out_dtype) + + k1, k2 = jax.random.split(jax.random.key(42)) + x = jax.random.uniform( + k1, (m, k), dtype=input_dtype, minval=-1, maxval=1) + y = jax.random.uniform( + k2, (k, n), dtype=input_dtype, minval=-1, maxval=1 + ) + + out = jax.block_until_ready(test(x, y)[0]) + expected_out = jax.block_until_ready(reference(x, y)) + + np.testing.assert_allclose( + out.astype(jnp.float32), + expected_out.astype(jnp.float32), + atol=1 if out_dtype == jnp.float32 else 5, + ) + + @parameterized.named_parameters( + ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 4, 2), + ('hbm_float32_112', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 2), + ('hbm_float32_111', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 1), + ) + def test_pipeline_latency_optimized_matmul_reducescatter( + self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): + input_dtype = jnp.float32 + num_devices = jax.device_count() + + tn = 128 * 1 + tm = 128 * 1 + tk = 128 * 1 + n = tn * n_tiles + m = tm * m_tiles * num_devices + k = tk * k_tiles * num_devices + + sharded_m = m // num_devices + sharded_k = k // num_devices + inner_grid = (n // tn, sharded_m // tm, sharded_k // tk) + outer_steps = num_devices // 2 + reduce_grid = (sharded_m // tm,) + + inner_kernel = partial(basic_matmul_kernel, k=sharded_k) + + def reduce_kernel( + out_ref, # [tm, tn] + rs_accum_scratch_ref, # [tm, tn] + ): + rs_accum_scratch_ref[...] = out_ref[...] + + inner_allocs = [ + pltpu.BufferedRef.input( + pl.BlockSpec((tm, tk), lambda n, m, k: (m, k)), input_dtype), + pltpu.BufferedRef.input( + pl.BlockSpec((tk, tn), lambda n, m, k: (k, n)), input_dtype), + pltpu.BufferedRef.accumulator( + pl.BlockSpec((tm, tn), lambda n, m, k: (m, n)), out_dtype), + # only used for final addition of fwd + bwd streams. 
+ pltpu.BufferedRef.input( + pl.BlockSpec((tm, n), lambda m: (m, 0)), out_dtype), + pltpu.BufferedRef.accumulator( + pl.BlockSpec((tm, n), lambda m: (m, 0)), out_dtype), + ] + + def reduce_scatter_lhs_matmul_kernel( + # in refs: + lhs_ref, # [sharded_m, sharded_k] + rhs_ref, # [sharded_k, n] + # out refs: + accumulator_ref, # [2 (bwd/fwd), 2 (work/buffer), sharded_m, n] + # scratch refs: + acc_scratch_ref, # [tm, tn] + bwd_recv_sem, + bwd_send_sem, + fwd_recv_sem, + fwd_send_sem, + lhs_bref, + rhs_bref, + out_bref, + reduce_in_bref, + reduce_out_bref, + ): + outer_step = pl.program_id(0) # range [0, outer_steps-1] + phase = pl.program_id(1) # range [0, 1] 0 == BWD Matmul, 1 == FWD Matmul + + num_inner_steps = _grid_size(inner_grid) + trivial_loop = num_inner_steps == 1 + + # kernel start / end booleans + is_start = jnp.logical_and(outer_step == 0, phase == 0) + is_end = jnp.logical_and(outer_step == outer_steps - 1, phase == 1) + + # slots for double-buffered accumulator + # at each sub-step, working slot --> buffering slot + working_slot = lax.rem(outer_step, 2) + buffering_slot = 1 - working_slot + + # IDs of self and neighbors + my_id = lax.axis_index('x') + right_neighbor = mod(my_id + 1, num_devices) + left_neighbor = mod(my_id - 1, num_devices) + + # Async copy definitions: + + # Transfer accumulator chunks backwards to left neighbors + bwd_copy = pltpu.make_async_remote_copy( + # buffering <--> working swapped as this is run in a subsequent step. + src_ref=accumulator_ref.at[1, buffering_slot], + dst_ref=accumulator_ref.at[1, working_slot], + send_sem=bwd_send_sem, + recv_sem=bwd_recv_sem, + device_id=left_neighbor, + ) + + # Transfer accumulator chunks forwards to right neighbors + fwd_copy = pltpu.make_async_remote_copy( + src_ref=accumulator_ref.at[0, working_slot], + dst_ref=accumulator_ref.at[0, buffering_slot], + send_sem=fwd_send_sem, + recv_sem=fwd_recv_sem, + device_id=right_neighbor, + ) + + # Slice RHS slices in bwd/fwd phases for contractions. + def get_lhs_slice(step, phase): + bwd_lhs_offset = mod(my_id + step + num_devices//2 + 1, num_devices) + fwd_lhs_offset = mod(my_id - step - num_devices//2, num_devices) + offset = jnp.where(phase, bwd_lhs_offset, fwd_lhs_offset) + return ( + pl.ds(pl.multiple_of(offset * sharded_m, sharded_m), sharded_m), + pl.ds(pl.multiple_of(0, sharded_k), sharded_k), + ) + + # Outer Loop Prologue + @pl.when(is_start) + @jax.named_scope('sync') + def _sync_barrier(): + # barrier at start + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) + pltpu.semaphore_signal(barrier_sem, device_id=right_neighbor) + pltpu.semaphore_wait(barrier_sem, 2) + + # Writeback previous outputs on first step of our present cycle + def postyeet(lhs_bref, rhs_bref, out_bref, scheduler): + del lhs_bref, rhs_bref + + @pl.when(~is_start) + def _rdmas(): + @pl.when(phase == 1) + @jax.named_scope('send_prev_fwd_dma') + def _send_prev_fwd_dma(): + fwd_copy.start() + @pl.when(phase == 0) + @jax.named_scope('send_prev_bwd_dma') + def _send_prev_bwd_dma(): + bwd_copy.start() + + # When the inner matmul loop consists of a single iteration, we have + # no opportunity to overlap in this loop and must block immediately. 
+ @pl.when(trivial_loop) + def _prefetch_accumulator_late(): + @pl.when(~is_start) + def _rdmas(): + @pl.when(phase == 1) + @jax.named_scope('send_prev_fwd_dma') + def _send_prev_fwd_dma(): + fwd_copy.wait() + @pl.when(phase == 0) + @jax.named_scope('send_prev_bwd_dma') + def _send_prev_bwd_dma(): + bwd_copy.wait() + + # deferred "prefetch" + next_slot = jnp.where(phase == 0, working_slot, buffering_slot) + next_phase = 1 - phase + scheduler.prefetch(out_bref, + accumulator_ref.at[next_phase, next_slot]) + + # Prefetch next inputs on last step of our present cycle + def prefetch(lhs_bref, rhs_bref, out_bref, scheduler): + @pl.when(~is_start & ~trivial_loop) + def _wait_dmas(): + @pl.when(phase == 1) + @jax.named_scope('wait_prev_fwd_dma') + def _wait_prev_fwd_dma(): + fwd_copy.wait() + @pl.when(phase == 0) + @jax.named_scope('wait_prev_bwd_dma') + def _wait_prev_bwd_dma(): + bwd_copy.wait() + + # prefetch next inputs + next_working_slot = jnp.where(phase == 0, working_slot, buffering_slot) + next_step = jnp.where(phase == 0, outer_step, outer_step + 1) + next_phase = lax.rem(phase + 1, 2) + scheduler.prefetch( + lhs_bref, lhs_ref.at[get_lhs_slice(next_step, next_phase)]) + scheduler.prefetch( + rhs_bref, rhs_ref) + # When the inner matmul loop consists of a single iteration, we need + # to avoid optimistic prefetch to avoid a data race. + @pl.when(~trivial_loop) + def _prefetch_accum(): + scheduler.prefetch( + out_bref, accumulator_ref.at[next_phase, next_working_slot]) + + # Run matmul pipeline + pltpu.emit_pipeline(inner_kernel, grid=inner_grid)( + lhs_ref.at[get_lhs_slice(outer_step, phase)], + rhs_ref, + accumulator_ref.at[phase, working_slot], + allocations=[lhs_bref, rhs_bref, out_bref], + scratches=[acc_scratch_ref], + first_cycle=is_start, + last_cycle=is_end, + init_accumulators=outer_step == 0, + prefetch=prefetch, + postyeet=postyeet, + ) + + # Add forwards and backwards stream results together + # Is it really advantageous to do this here rather than doing a simple + # addition outside? + @pl.when(is_end) + def _loop_epilogue(): + pltpu.emit_pipeline(reduce_kernel, grid=reduce_grid)( + accumulator_ref.at[1, 1], # <-- 1,1/0,0 always correct? + accumulator_ref.at[0, 0], + allocations=[reduce_in_bref, reduce_out_bref], + scratches=[], + first_cycle=True, + last_cycle=True, + init_accumulators=False, + ) + + kernel = pl.pallas_call( + reduce_scatter_lhs_matmul_kernel, + out_shape=jax.ShapeDtypeStruct((2, 2, sharded_m, n), out_dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=memory_space), + ], + out_specs=pl.BlockSpec(memory_space=memory_space), + grid=(outer_steps, 2), + scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + + [pltpu.SemaphoreType.DMA] * 4 + + inner_allocs + ), + compiler_params=dict( + mosaic=dict( + collective_id=0, + # must set scoped vmem flag *larger* than below! + # e.g. 
flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB + ) + ), + ) + + shard = partial( + shard_map.shard_map, + mesh=jax.sharding.Mesh( + mesh_utils.create_device_mesh( + (num_devices,), jax.devices()[:num_devices]), + ['x'], + ), + in_specs=(P(None, 'x'), P('x', None)), + out_specs=P('x', None), + check_rep=False, + ) + + test = jax.jit(shard(lambda x, y: kernel(x, y)[0, 0])) + + @jax.jit + @shard + def reference(x, y): + unreduced = jnp.dot(x, y, preferred_element_type=out_dtype) + return lax.psum_scatter( + unreduced, 'x', scatter_dimension=0, tiled=True) + + k1, k2 = jax.random.split(jax.random.key(42)) + x = jax.random.uniform( + k1, (m, k), dtype=input_dtype, minval=-1, maxval=1) + y = jax.random.uniform( + k2, (k, n), dtype=input_dtype, minval=-1, maxval=1 + ) + + out = jax.block_until_ready(test(x, y)) + expected_out = jax.block_until_ready(reference(x, y)) + + np.testing.assert_allclose( + out.astype(jnp.float32), + expected_out.astype(jnp.float32), + atol=1 if out_dtype == jnp.float32 else 5, + ) + + np.mean(np.abs(out - expected_out)) + + @parameterized.named_parameters( + ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 4, 2), + ('hbm_float32_112', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 2), + ('hbm_float32_111', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 1), + ) + def test_pipeline_throughput_optimized_matmul_reducescatter( + self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): + input_dtype = jnp.float32 + num_devices = jax.device_count() + + tn = 128 * 1 + tm = 128 * 1 + tk = 128 * 1 + n = tn * n_tiles + m = tm * m_tiles * num_devices # subsplit dim + k = tk * k_tiles * num_devices + + sharded_m = m // num_devices + half_m = sharded_m // 2 + sharded_k = k // num_devices + inner_grid = (n // tn, half_m // tm, sharded_k // tk) + outer_steps = num_devices + inner_kernel = partial(basic_matmul_kernel, k=sharded_k) + + inner_allocs = [ + pltpu.BufferedRef.input( + pl.BlockSpec((tm, tk), lambda n, m, k: (m, k)), input_dtype), + pltpu.BufferedRef.input( + pl.BlockSpec((tk, tn), lambda n, m, k: (k, n)), input_dtype), + pltpu.BufferedRef.accumulator( + pl.BlockSpec((tm, tn), lambda n, m, k: (m, n)), out_dtype), + ] + + def reduce_scatter_lhs_matmul_kernel( + # in refs: + lhs_ref, # [sharded_m, sharded_k] + rhs_ref, # [sharded_k, n] + # out refs: + rs_accum_scratch_ref, # [2 (work/buffer), 2*sharded_m//2 (fwd/bwd), n] + # scratch refs: + acc_scratch_ref, # [tm, tn] + bwd_recv_sem, + bwd_send_sem, + fwd_recv_sem, + fwd_send_sem, + lhs_bref, + rhs_bref, + out_bref, + ): + outer_step = pl.program_id(0) # range [0, outer_steps-1] + phase = pl.program_id(1) # range [0, 1] 0 == BWD Matmul, 1 == FWD Matmul + + num_inner_steps = _grid_size(inner_grid) + trivial_loop = num_inner_steps == 1 + + # kernel start / end booleans + is_start = jnp.logical_and(outer_step == 0, phase == 0) + is_end = jnp.logical_and(outer_step == outer_steps - 1, phase == 1) + + # slots for double-buffered accumulator + # at each sub-step, working slot --> buffering slot + working_slot = lax.rem(outer_step, 2) + buffering_slot = 1 - working_slot + + # IDs of self and neighbors + my_id = lax.axis_index('x') + right_neighbor = mod(my_id + 1, num_devices) + left_neighbor = mod(my_id - 1, num_devices) + + # Async copy definitions: + + # Transfer accumulator chunks backwards to left neighbors + bwd_copy = 
pltpu.make_async_remote_copy( + # buffering <--> working swapped as this is run in a subsequent step. + src_ref=rs_accum_scratch_ref.at[buffering_slot, make_ds(1, half_m)], + dst_ref=rs_accum_scratch_ref.at[working_slot, make_ds(1, half_m)], + send_sem=bwd_send_sem, + recv_sem=bwd_recv_sem, + device_id=left_neighbor, + ) + + # Transfer accumulator chunks forwards to right neighbors + fwd_copy = pltpu.make_async_remote_copy( + src_ref=rs_accum_scratch_ref.at[working_slot, make_ds(0, half_m)], + dst_ref=rs_accum_scratch_ref.at[buffering_slot, make_ds(0, half_m)], + send_sem=fwd_send_sem, + recv_sem=fwd_recv_sem, + device_id=right_neighbor, + ) + + # Slice RHS slices in bwd/fwd phases for contractions. + def get_lhs_slice(step, phase): + bwd_lhs_offset = 2 * mod(my_id + step + 1, num_devices) + 1 + fwd_lhs_offset = 2 * mod(my_id - step - 1, num_devices) + offset = jnp.where(phase, bwd_lhs_offset, fwd_lhs_offset) + return ( + pl.ds(pl.multiple_of(offset * half_m, half_m), half_m), + pl.ds(pl.multiple_of(0, sharded_k), sharded_k), + ) + def get_accum_slice(phase, slot): + return (slot, make_ds(phase, half_m)) + + # Outer Loop Prologue + @pl.when(is_start) + @jax.named_scope('sync') + def _sync_and_bwd_init(): + # barrier at start + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) + pltpu.semaphore_signal(barrier_sem, device_id=right_neighbor) + pltpu.semaphore_wait(barrier_sem, 2) + + # Writeback previous outputs on first step of our present cycle + def postyeet(lhs_bref, rhs_bref, out_bref, scheduler): + del lhs_bref, rhs_bref + @pl.when(~is_start & ~is_end) + def _rdmas(): + @pl.when(phase == 1) + @jax.named_scope('send_prev_fwd_dma') + def _send_prev_fwd_dma(): + fwd_copy.start() + @pl.when(phase == 0) + @jax.named_scope('send_prev_bwd_dma') + def _send_prev_bwd_dma(): + bwd_copy.start() + + # When the inner matmul loop consists of a single iteration, we have + # no opportunity to overlap in this loop and must block immediately. 
+ @pl.when(trivial_loop) + def _prefetch_accumulator_late(): + @pl.when(~is_start & ~is_end) + def _wait_dmas(): + @pl.when(phase == 1) + @jax.named_scope('wait_prev_fwd_dma') + def _wait_prev_fwd_dma(): + fwd_copy.wait() + @pl.when(phase == 0) + @jax.named_scope('wait_prev_bwd_dma') + def _wait_prev_bwd_dma(): + bwd_copy.wait() + # deferred "prefetch" + next_working_slot = jnp.where( + phase == 0, working_slot, buffering_slot) + next_phase = 1 - phase + scheduler.prefetch( + out_bref, rs_accum_scratch_ref.at[ + get_accum_slice(next_phase, next_working_slot)]) + + # Prefetch next inputs on last step of our present cycle + def prefetch(lhs_bref, rhs_bref, out_bref, scheduler): + @pl.when(~is_start & ~is_end & ~trivial_loop) + def _wait_dmas(): + @pl.when(phase == 1) + @jax.named_scope('wait_prev_fwd_dma') + def _wait_prev_fwd_dma(): + fwd_copy.wait() + @pl.when(phase == 0) + @jax.named_scope('wait_prev_bwd_dma') + def _wait_prev_bwd_dma(): + bwd_copy.wait() + + # prefetch next inputs + next_working_slot = jnp.where(phase == 0, working_slot, buffering_slot) + next_step = jnp.where(phase == 0, outer_step, outer_step + 1) + next_phase = lax.rem(phase + 1, 2) + scheduler.prefetch(lhs_bref, + lhs_ref.at[get_lhs_slice(next_step, next_phase)]) + scheduler.prefetch(rhs_bref, rhs_ref) + @pl.when(~trivial_loop) + def _prefetch_accumulator(): + scheduler.prefetch( + out_bref, rs_accum_scratch_ref.at[ + get_accum_slice(next_phase, next_working_slot)]) + + # Run matmul pipeline + pltpu.emit_pipeline(inner_kernel, grid=inner_grid)( + lhs_ref.at[get_lhs_slice(outer_step, phase)], + rhs_ref, + rs_accum_scratch_ref.at[get_accum_slice(phase, working_slot)], + allocations=[lhs_bref, rhs_bref, out_bref], + scratches=[acc_scratch_ref], + first_cycle=is_start, + last_cycle=is_end, + init_accumulators=outer_step == 0, + prefetch=prefetch, + postyeet=postyeet, + ) + + kernel = pl.pallas_call( + reduce_scatter_lhs_matmul_kernel, + out_shape=jax.ShapeDtypeStruct((2, sharded_m, n), out_dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=memory_space), + ], + out_specs=pl.BlockSpec(memory_space=memory_space), + grid=(outer_steps, 2), + scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + + [pltpu.SemaphoreType.DMA] * 4 + + inner_allocs + ), + compiler_params=dict( + mosaic=dict( + collective_id=0, + # must set scoped vmem flag *larger* than below! + # e.g. 
flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB + ) + ), + ) + + shard = partial( + shard_map.shard_map, + mesh=jax.sharding.Mesh( + mesh_utils.create_device_mesh( + (num_devices,), jax.devices()[:num_devices]), + ['x'], + ), + in_specs=(P(None, 'x'), P('x', None)), + out_specs=P('x', None), + check_rep=False, + ) + + test = jax.jit(shard(lambda x, y: kernel(x, y)[1])) + + @jax.jit + @shard + def reference(x, y): + unreduced = jnp.dot(x, y, preferred_element_type=out_dtype) + return lax.psum_scatter( + unreduced, 'x', scatter_dimension=0, tiled=True) + + k1, k2 = jax.random.split(jax.random.key(42)) + x = jax.random.uniform( + k1, (m, k), dtype=input_dtype, minval=-1, maxval=1) + y = jax.random.uniform( + k2, (k, n), dtype=input_dtype, minval=-1, maxval=1 + ) + + out = jax.block_until_ready(test(x, y)) + expected_out = jax.block_until_ready(reference(x, y)) + + np.testing.assert_allclose( + out.astype(jnp.float32), + expected_out.astype(jnp.float32), + atol=1 if out_dtype == jnp.float32 else 5, + ) + + +class PallasCallMegacoreTest(parameterized.TestCase): + + def setUp(self): + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only works with TPU v4') + + super().setUp() + + def test_can_partition_nondivisible_grid_with_dynamic_dimensions(self): + + def mul_pipeline(x_ref, y_ref): + y_ref[...] = x_ref[...] * 2 + + def mul_kernel(iters_ref, x_ref, y_ref): + pltpu.emit_pipeline( + mul_pipeline, + grid=(iters_ref[0], 5), + in_specs=[ + pl.BlockSpec((128, 128), lambda i, j: (i, j)), + ], + out_specs=pl.BlockSpec((128, 128), lambda i, j: (i, j)), + core_axis=0, + dimension_semantics=(pltpu.PARALLEL, pltpu.PARALLEL), + )(x_ref, y_ref) + + num_cores = jax.devices()[0].num_cores + func = pl.pallas_call( + mul_kernel, + out_shape=jax.ShapeDtypeStruct((640, 640), jnp.float32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + grid=(num_cores,), + ), + compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + ) + x = jax.random.uniform(jax.random.key(0), (640, 640)) + np.testing.assert_allclose(func(jnp.array([5]), x), x * 2) + + def test_megacore_mul(self): + x = jax.random.uniform(jax.random.key(0), (512, 512)) + + def matmul_pipeline(x_ref, y_ref): + y_ref[...] = x_ref[...] 
* 2 + + def matmul_kernel(x_ref, y_ref): + pltpu.emit_pipeline( + matmul_pipeline, + grid=(4, 4), + in_specs=[ + pl.BlockSpec((128, 128), lambda i, j: (i, j)), + ], + out_specs=pl.BlockSpec((128, 128), lambda i, j: (i, j)), + core_axis=0, + dimension_semantics=(pltpu.ARBITRARY, pltpu.PARALLEL) + )(x_ref, y_ref) + + num_cores = jax.devices()[0].num_cores + func = pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + grid=(num_cores,), + compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + ) + np.testing.assert_allclose(func(x), x * 2) + + @parameterized.parameters( + (1024, 1024, 1024, 256, 512, 256), + (768, 1024, 1024, 256, 512, 256), + (1024, 1024, 768, 256, 512, 256), + (768, 1024, 768, 256, 512, 256), + ) + def test_megacore_matmul(self, m, k, n, bm, bk, bn): + k1, k2 = jax.random.split(jax.random.key(42)) + x = jax.random.uniform(k1, (m, k)) + y = jax.random.uniform(k2, (k, n)) + + def matmul_pipeline(x_ref, y_ref, z_ref): + @pl.when(pl.program_id(2) == 0) + def _(): + z_ref[...] = jnp.zeros_like(z_ref) + z_ref[...] += x_ref[...] @ y_ref[...] + + def matmul_kernel(x_ref, y_ref, z_ref, *, bm, bk, bn): + m, k = x_ref.shape + _, n = y_ref.shape + assert k % bk == 0 + pltpu.emit_pipeline( + matmul_pipeline, + grid=(pl.cdiv(m, bm), pl.cdiv(n, bn), pl.cdiv(k, bk)), + in_specs=[ + pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)), + pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)), + ], + out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), + core_axis=0, + dimension_semantics=(pltpu.PARALLEL, pltpu.PARALLEL, pltpu.ARBITRARY) + )(x_ref, y_ref, z_ref) + + num_cores = jax.devices()[0].num_cores + func = pl.pallas_call( + functools.partial(matmul_kernel, bm=bm, bk=bk, bn=bn), + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + grid=(num_cores,), + compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + ) + np.testing.assert_allclose(func(x, y), x @ y, atol=7e-5) + + +if CAN_USE_HYPOTHESIS: + + @partial(jax.jit, static_argnames=['bm', 'bk', 'bn']) + def matmul(x: jax.Array, y: jax.Array, *, bm: int, bk: int, bn: int): + + m, k = x.shape + _, n = y.shape + + def kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref): + + grid = (pl.cdiv(m, bm), pl.cdiv(n, bn), pl.cdiv(k, bk)) + + def run(acc_scratch_ref): + pltpu.emit_pipeline( + partial(basic_matmul_kernel, acc_scratch_ref=acc_scratch_ref, k=k), + in_specs=[ + pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)), + pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)), + ], + out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), + grid=grid, + core_axis=0, + dimension_semantics=( + pltpu.PARALLEL, + pltpu.PARALLEL, + pltpu.ARBITRARY, + ), + )(x_hbm_ref, y_hbm_ref, o_hbm_ref) + + accum_dtype = ( + jnp.float32 if jnp.issubdtype(x.dtype, jnp.floating) else jnp.int32 + ) + pltpu.run_scoped(run, pltpu.VMEM((bm, bn), accum_dtype)) + + num_cores = jax.devices()[0].num_cores + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + grid=(num_cores,), + )(x, y) + + class PaddedPipelineEmitterTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + if not 
jtu.is_device_tpu_at_least(4): + self.skipTest('Only TPU v4+ allowed.') + + @hp.given( + hps.sampled_from(['float32', 'bfloat16', 'int8']), + hps.integers(1, 1024), + hps.integers(1, 1024), + hps.integers(1, 1024), + hps.sampled_from([8, 16, 32, 128, 256, 512]), + hps.sampled_from([128, 256, 512]), + hps.sampled_from([128, 256, 512]), + hps.integers(0, 4), + ) + def test_padded_matmul(self, dtype, m, k, n, bm, bk, bn, seed): + hp.assume(bm <= m) + hp.assume(bn <= n) + hp.assume(bk <= k) + if dtype == 'bfloat16': + hp.assume(bm >= 16) + if dtype == 'int8': + hp.assume(bm >= 32) + hp.assume(jtu.is_device_tpu_at_least(5)) + k1, k2 = jax.random.split(jax.random.key(seed)) + x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype) + y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype) + + out = matmul(x, y, bm=bm, bk=bk, bn=bn) + expected = x @ y + atol = rtol = 1e-5 + if dtype == 'bfloat16': + out = out.astype('float32') + expected = expected.astype('float32') + atol = rtol = 1e-2 + np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/pallas_shape_poly_test.py b/tests/pallas/pallas_shape_poly_test.py new file mode 100644 index 000000000000..fb5065ebf52f --- /dev/null +++ b/tests/pallas/pallas_shape_poly_test.py @@ -0,0 +1,220 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F401 + +import functools +import math +import os +import sys + +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" + +from absl.testing import absltest +from absl.testing import parameterized + +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.pallas.pallas_call import _trace_to_jaxpr +from jax._src import tpu_custom_call # For configuration values +import jax.numpy as jnp +from jax.experimental import pallas as pl +from jax import export +import numpy as np + + +config.update("jax_traceback_filtering", "off") +config.parse_flags_with_absl() + + +# TODO(necula): support an activation +def matmul_kernel(x_ref, y_ref, o_ref): + # x shape: (m, l), y shape (l, n), o shape: (m ,n) + block_m, block_l = x_ref.shape + block_l2, block_n = y_ref.shape + assert block_l2 == block_l + assert o_ref.shape == (block_m, block_n) + @pl.when(pl.program_id(axis=2) == 0) + def _(): + o_ref[...] = jnp.zeros_like(o_ref) + + o_ref[...] += x_ref[...] @ y_ref[...] 
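[Editor's note: illustrative sketch, not part of the patch. The matmul_kernel above zeroes its output block only when the reduction index (grid axis 2) is 0 and then accumulates one partial product per step; the plain-NumPy reference below mirrors that discipline so the accumulation semantics are easy to check. Block sizes and names here are made up for the example.]

# Illustrative sketch only; mirrors the kernel's zero-on-first-k-step,
# accumulate-per-k-step pattern. Not part of this patch.
import numpy as np

def blocked_matmul_reference(x, y, *, block_m=2, block_n=2, block_l=2):
  m, l = x.shape
  _, n = y.shape
  out = np.zeros((m, n), dtype=np.float32)
  for i in range(m // block_m):
    for j in range(n // block_n):
      for k in range(l // block_l):
        xs = x[i * block_m:(i + 1) * block_m, k * block_l:(k + 1) * block_l]
        ys = y[k * block_l:(k + 1) * block_l, j * block_n:(j + 1) * block_n]
        if k == 0:  # first reduction step: zero the output block
          out[i * block_m:(i + 1) * block_m, j * block_n:(j + 1) * block_n] = 0.0
        out[i * block_m:(i + 1) * block_m, j * block_n:(j + 1) * block_n] += xs @ ys
  return out

x = np.arange(16.0, dtype=np.float32).reshape(4, 4)
y = np.arange(16.0, dtype=np.float32).reshape(4, 4)
np.testing.assert_allclose(blocked_matmul_reference(x, y), x @ y, rtol=1e-6)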
+ + +@functools.partial(jax.jit, static_argnames=['block_shape']) +def matmul( + x: jax.Array, + y: jax.Array, + *, + block_shape=(128, 128, 128) +): + m, l = x.shape + l2, n = y.shape + assert l2 == l + block_m, block_n, block_l = block_shape + assert l % block_l == 0, f"{l=}, {block_l=}" + assert m % block_m == 0, f"{m=}, {block_m=}" + assert n % block_n == 0, f"{n=}, {block_n=}" + grid = (m // block_m, n // block_n, l // block_l) + fused_matmul = pl.pallas_call( + functools.partial(matmul_kernel), + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + in_specs=[ + pl.BlockSpec((block_m, block_l), lambda i, j, k: (i, k)), + pl.BlockSpec((block_l, block_n), lambda i, j, k: (k, j)), + ], + out_specs=pl.BlockSpec((block_m, block_n), lambda i, j, k: (i, j)), + grid=grid, + interpret=jtu.test_device_matches(["cpu"]), + ) + return fused_matmul(x, y) + + +class ShapePolyTest(jtu.JaxTestCase, + parameterized.TestCase): + + def setUp(self): + if jax.config.x64_enabled: + self.skipTest("Only works in 32-bit") + if (jtu.test_device_matches(["cuda"]) and + not jtu.is_cuda_compute_capability_at_least("8.0")): + self.skipTest("Only works on GPU with capability >= sm80") + if sys.platform == "win32": + self.skipTest("Only works on non-Windows platforms") + super().setUp() + _trace_to_jaxpr.cache_clear() + + def test_copy(self): + # The blocks are static, but the input and the grid are of polymorphic + # dimensions. + block_shape = (8, 128) + def f(x, *, eager=False): # x: i32[w, h] + def copy_kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + # Use both pl.cdiv and // for specifying the grid + grid = (pl.cdiv(x.shape[0], block_shape[0]), + (x.shape[1] + 1) // block_shape[1]) + return pl.pallas_call( + copy_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + in_specs=[pl.BlockSpec(block_shape, lambda i, j: (i, j))], + out_specs=pl.BlockSpec(block_shape, lambda i, j: (i, j)), + grid=grid, + interpret=eager and jtu.test_device_matches(["cpu"]))(x) + + shape1 = (128, 256) + x1 = jnp.arange(math.prod(shape1), dtype=np.int32).reshape(shape1) + res = f(x1, eager=True) + self.assertAllClose(res, x1) + + w, h = export.symbolic_shape("w, h") + exp = export.export( + jax.jit(f), + platforms=["tpu"])(jax.ShapeDtypeStruct((w, h), jnp.int32)) + + if jtu.test_device_matches(["tpu"]): + res_exp_1 = exp.call(x1) + self.assertAllClose(res_exp_1, x1) + + shape2 = block_shape + x2 = jnp.arange(math.prod(shape2), dtype=np.int32).reshape(shape2) + res_exp_2 = exp.call(x2) + self.assertAllClose(res_exp_2, x2) + + # TODO(necula): support shape polymorphism for GPU + with self.assertRaisesRegex( + NotImplementedError, + "dynamic grid bounds not supported in the Triton backend"): + export.export( + jax.jit(f), + platforms=["cuda"])(jax.ShapeDtypeStruct((w, h), jnp.int32)) + + def test_block_sizes_must_be_static_no_grid(self): + def f(x, *, eager=False): # x: f32[w, h] + def copy_one(x_ref, o_ref): + o_ref[...] = x_ref[...] 
+ return pl.pallas_call( + copy_one, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + interpret=eager and jtu.test_device_matches(["cpu"]))(x) + shape1 = (128, 256) + x1 = jnp.arange(math.prod(shape1), dtype=np.int32).reshape(shape1) + res = f(x1, eager=True) + self.assertAllClose(res, x1) + + w, h = export.symbolic_shape("w, h") + with self.assertRaisesRegex( + ValueError, + "shape polymorphism for Pallas does not support dynamically-shaped blocks"): + export.export( + jax.jit(f), + platforms=["tpu"])(jax.ShapeDtypeStruct((w, h), jnp.int32)) + + def test_block_sizes_must_be_static(self): + def f(x, *, eager=False): # x: f32[w, h] + def copy_one(x_ref, o_ref): + o_ref[...] = x_ref[...] + grid = (2, 2) + block_shape = (x.shape[0] // grid[0], x.shape[1] // grid[1]) + return pl.pallas_call( + copy_one, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + in_specs=[pl.BlockSpec(block_shape, lambda i, j: (i, j))], + out_specs=pl.BlockSpec(block_shape, lambda i, j: (i, j)), + grid=grid, + interpret=eager and jtu.test_device_matches(["cpu"]))(x) + shape1 = (128, 256) + x1 = jnp.arange(math.prod(shape1), dtype=np.int32).reshape(shape1) + res = f(x1, eager=True) + self.assertAllClose(res, x1) + + w, h = export.symbolic_shape("w, h") + with self.assertRaisesRegex( + ValueError, + "shape polymorphism for Pallas does not support dynamically-shaped blocks"): + + export.export( + jax.jit(f), + platforms=["tpu"])(jax.ShapeDtypeStruct((w, h), jnp.int32)) + + @jtu.run_on_devices("tpu") + def test_matmul(self): + x_shape = (1024, 256) + y_shape = (256, 2048) + + key = jax.random.key(42) + key1, key2 = jax.random.split(key, 2) + x = jax.random.normal(key1, x_shape, dtype=np.float32) + y = jax.random.normal(key2, y_shape, dtype=np.float32) + + res = matmul(x, y) + self.assertAllClose(res, x @ y, atol=1e-4) + + m, n, l = export.symbolic_shape("m, n, l", + constraints=["mod(m, 128) == 0", + "mod(n, 128) == 0", + "mod(l, 128) == 0"]) + exp = export.export( + matmul, + platforms=["tpu"])( + jax.ShapeDtypeStruct((m, l), jnp.float32), + jax.ShapeDtypeStruct((l, n), jnp.float32)) + if jtu.test_device_matches(["tpu"]): + res_exp = exp.call(x, y) + self.assertAllClose(res_exp, x @ y, atol=1e-4) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index bde8e56bb7c6..1cd206f4788f 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -12,51 +12,47 @@ # See the License for the specific language governing permissions and # limitations under the License. 
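[Editor's note: illustrative sketch, not part of the patch. Much of the churn in the pallas_test.py hunks that follow comes from the pl.BlockSpec argument reordering: the block shape now comes first and the index map second (the old order was pl.BlockSpec(index_map, block_shape)). The tiny kernel below, modeled on test_add_vector_block_spec in this file, uses the new order; interpret=True is added only so the sketch runs on CPU.]

# Illustrative sketch only; not part of this patch.
# Old order: pl.BlockSpec(lambda i: i, (1,))
# New order: pl.BlockSpec((1,), lambda i: i)
import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    in_specs=[pl.BlockSpec((1,), lambda i: i)],
    out_specs=pl.BlockSpec((1,), lambda i: i),
    grid=8,
    interpret=True,  # run the sketch on CPU
)
def add_one(x_ref, o_ref):
  o_ref[0] = x_ref[0] + 1

assert (add_one(jnp.arange(8)) == jnp.arange(8) + 1).all()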
+import contextlib import functools import itertools import os -import unittest +import sys os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" from absl.testing import absltest from absl.testing import parameterized - import jax from jax import lax from jax import random +from jax._src import checkify from jax._src import config from jax._src import linear_util as lu -from jax._src import test_util as jtu from jax._src import state +from jax._src import test_util as jtu from jax._src.lax.control_flow.for_loop import for_loop from jax._src.pallas.pallas_call import _trace_to_jaxpr +from jax.experimental import pallas as pl +from jax.experimental.pallas.ops.gpu import attention +from jax.experimental.pallas.ops.gpu import layer_norm +from jax.experimental.pallas.ops.gpu import rms_norm +from jax.experimental.pallas.ops.gpu import softmax from jax.interpreters import partial_eval as pe import jax.numpy as jnp -from jax.experimental import pallas as pl -from jax.experimental.pallas.ops import attention -from jax.experimental.pallas.ops import layer_norm -from jax.experimental.pallas.ops import rms_norm -from jax.experimental.pallas.ops import softmax -try: - from jax._src.pallas.triton.pallas_call_registration import ( - compile_jaxpr, - _TRITON_COMPILE_VIA_XLA, - ) - from jax.experimental.pallas import gpu as plgpu -except ModuleNotFoundError: - compile_jaxpr = None - _TRITON_COMPILE_VIA_XLA = None import numpy as np +if sys.platform != "win32": + from jax.experimental.pallas import gpu as plgpu +else: + plgpu = None + # TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs. # pylint: disable=no-value-for-parameter - -config.update("jax_traceback_filtering", "off") config.parse_flags_with_absl() + @functools.partial(jax.jit, static_argnames=["bm", "bn", "gm", "bk", "interpret", "debug"]) def matmul(x, y, *, bm, bn, gm, bk, interpret, debug=False): @@ -100,20 +96,23 @@ def body(i, acc_ref): o_ref[o_idx] = acc return matmul_kernel(x, y) + @functools.partial(jax.jit, static_argnames=["bm", "bn", "bk", "interpret", "debug"]) def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False): m, n, k = x.shape[0], y.shape[1], x.shape[1] @functools.partial( - pl.pallas_call, out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), interpret=interpret, debug=debug, in_specs=[ - pl.BlockSpec(lambda i, _: (i, 0), (bm, x.shape[1])), - pl.BlockSpec(lambda _, j: (0, j), (y.shape[0], bn)) + pl.BlockSpec((bm, x.shape[1]), lambda i, _: (i, 0)), + pl.BlockSpec((y.shape[0], bn), lambda _, j: (0, j)), ], - out_specs=pl.BlockSpec(lambda i, j: (i, j), (bm, bn)), - grid=(pl.cdiv(m, bm), pl.cdiv(n, bn))) + out_specs=pl.BlockSpec((bm, bn), lambda i, j: (i, j)), + grid=(pl.cdiv(m, bm), pl.cdiv(n, bn)), + ) def matmul_kernel(x_ref, y_ref, o_ref): acc = jnp.zeros(o_ref.shape, dtype=jnp.float32) def body(i, acc_ref): @@ -125,41 +124,40 @@ def body(i, acc_ref): return matmul_kernel(x, y) -class PallasTest(parameterized.TestCase): +@jtu.with_config(jax_traceback_filtering="off") +class PallasTest(jtu.JaxTestCase): INTERPRET = False def setUp(self): - if jax.config.x64_enabled: - self.skipTest("Only works in 32-bit") - if not self.INTERPRET: - if not jtu.test_device_matches(["gpu"]): - self.skipTest("Only works on GPU") - try: - import triton # noqa: F401 - except ImportError: - if ( - _TRITON_COMPILE_VIA_XLA is not None - and not _TRITON_COMPILE_VIA_XLA.value - ): - self.skipTest("Triton is not installed.") + if 
jtu.test_device_matches(["cpu"]) and not self.INTERPRET: + self.skipTest("On CPU the test works only in interpret mode") + if jtu.test_device_matches(["gpu"]) and jax.config.x64_enabled: + self.skipTest("On GPU the test works only in 32-bit") + if (jtu.test_device_matches(["cuda"]) and + not jtu.is_cuda_compute_capability_at_least("8.0")): + self.skipTest("Only works on GPU with capability >= sm80") + if sys.platform == "win32" and not self.INTERPRET: + self.skipTest("Only works on non-Windows platforms") + super().setUp() - if compile_jaxpr: - compile_jaxpr.cache_clear() _trace_to_jaxpr.cache_clear() def pallas_call(self, *args, **kwargs): return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) - def check_gpu_capability_at_least(self, capability, - device: int = 0): - if self.INTERPRET: - return True - return plgpu.get_compute_capability(device) >= capability - class PallasCallTest(PallasTest): + def setUp(self): + super().setUp() + if jtu.test_device_matches(["tpu"]): + # TODO: most tests fail on TPU in non-interpreter mode + self.skipTest("On TPU the test works only in interpret mode") + def test_add_one(self): + if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: + # TODO: assertion failures on CPU in 64-bit mode + self.skipTest("On CPU the test works only in 32-bit mode") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32)) def add_one(x_ref, o_ref): @@ -179,22 +177,32 @@ def add_one(x_ref, o_ref): np.testing.assert_allclose(add_one(x), jnp.array([1.], jnp.float32)) def test_add_vector_block_spec(self): + if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: + # TODO: assertion failures on CPU in 64-bit mode + self.skipTest("On CPU the test works only in 32-bit mode") @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32), - in_specs=[pl.BlockSpec(lambda i: i, (1,))], - out_specs=pl.BlockSpec(lambda i: i, (1,)), - grid=8, debug=False) + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8,), jnp.int32), + in_specs=[pl.BlockSpec((1,), lambda i: i)], + out_specs=pl.BlockSpec((1,), lambda i: i), + grid=8, + ) def add_one(x_ref, o_ref): o_ref[0] = x_ref[0] + 1 np.testing.assert_allclose(add_one(jnp.arange(8)), jnp.arange(8) + 1) def test_add_matrix_block_spec(self): + if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: + # TODO: assertion failures on CPU in 64-bit mode + self.skipTest("On CPU the test works only in 32-bit mode") @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 8), jnp.int32), - in_specs=[pl.BlockSpec(lambda i, j: (i, j), (2, 2))], - out_specs=pl.BlockSpec(lambda i, j: (i, j), (2, 2)), - grid=(4, 4)) + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 8), jnp.int32), + in_specs=[pl.BlockSpec((2, 2), lambda i, j: (i, j))], + out_specs=pl.BlockSpec((2, 2), lambda i, j: (i, j)), + grid=(4, 4), + ) def add_one(x_ref, o_ref): o_ref[:, :] = x_ref[:, :] + 1 @@ -211,6 +219,9 @@ def logical_and(x_ref, o_ref): self.assertTrue(jnp.all(logical_and(x))) def test_vector_indexing(self): + if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: + # TODO: assertion failures on CPU in 64-bit mode + self.skipTest("On CPU the test works only in 32-bit mode") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), grid=1) @@ -221,7 +232,24 @@ def index(x_ref, i_ref, o_ref): for i in range(5): np.testing.assert_allclose(index(x, i), x[i]) + def test_hoisted_consts(self): + # See 
https://github.com/google/jax/issues/21557. + x = jnp.zeros(32) + indices = jnp.arange(4).reshape((2, 2)) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + def kernel(src, dst): + dst[indices] = src[indices] + + jax.block_until_ready(kernel(x)) + def test_vector_slicing(self): + if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: + # TODO: assertion failures on CPU in 64-bit mode + self.skipTest("On CPU the test works only in 32-bit mode") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), grid=1) @@ -234,100 +262,6 @@ def index(x_ref, idx_ref, o_ref): idx = jnp.arange(i, i + 2) np.testing.assert_allclose(index(x, idx), x[idx]) - def test_num_programs(self): - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4,), jnp.int32), - grid=4, - ) - def kernel(o_ref): - o_ref[pl.program_id(0)] = pl.num_programs(0) - - np.testing.assert_array_equal( - kernel(), np.asarray([4, 4, 4, 4], dtype=np.int32) - ) - - def test_where_broadcasting(self): - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4, 2, 2), jnp.float32), - grid=1) - def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref): - mask = (jnp.arange(o_ref.shape[0]) == out_idx_ref[()])[:, None, None] - o_ref[...] = jnp.where(mask, x_ref[in_idx_ref[()]], 0) - - x = jnp.arange(7 * 2 * 2.).reshape(7, 2, 2) - for ii in range(7): - for oi in range(4): - out = copyitem(x, ii, oi) - self.assertEqual((4, 2, 2), out.shape) - np.testing.assert_allclose(out[:oi], jnp.zeros_like(out[:oi])) - np.testing.assert_allclose(out[oi], x[ii]) - np.testing.assert_allclose(out[oi + 1:], jnp.zeros_like(out[oi + 1:])) - - @parameterized.parameters(*[ - ((), (2,), ()), - ((1,), (2,), (0,)), - ((1, 1), (2, 2), (0, 1)), - ((), (2, 2), ()), - ]) - def test_broadcast_in_dim(self, in_shape, out_shape, dims): - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), - grid=1) - def f(x_ref, o_ref): - x = x_ref[...] - o_ref[...] = jax.lax.broadcast_in_dim(x, out_shape, dims) - - x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) - expected = jax.lax.broadcast_in_dim(x, out_shape, dims) - np.testing.assert_allclose(f(x), expected) - - @parameterized.parameters(*[ - ((2, 4), (8,)), - ((2, 4), (8, 1)), - ((2, 4), (1, 8)), - ((64,), (32, 2)), - ]) - def test_reshape(self, in_shape, out_shape): - # TODO(sharadmv): re-enable when `reshape` works again - if not self.INTERPRET: - self.skipTest("Reshape not yet supported in Triton-MLIR") - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), - grid=1) - def f(x_ref, o_ref): - o_ref[...] = x_ref[...].reshape(out_shape) - - x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) - expected = x.reshape(out_shape) - np.testing.assert_allclose(f(x), expected) - - @parameterized.parameters(*[ - ((), (1,)), - ((), (1, 1)), - ((2, 4), (2, 4)), - ((2, 4), (2, 4, 1)), - ((2, 4, 1), (2, 4)), - ((2, 4), (1, 2, 4)), - ((1, 2, 4), (2, 4)), - ((2, 4), (2, 1, 4)), - ((1, 2, 1, 4, 1), (2, 4)), - ((2, 4,), (1, 2, 1, 4)), - ((2, 4,), (1, 2, 4, 1)), - ((1, 2, 4, 1), (1, 2, 1, 4, 1)), - ]) - def test_reshape_noop_or_singleton_dims(self, in_shape, out_shape): - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), - grid=1) - def f(x_ref, o_ref): - o_ref[...] 
= x_ref[...].reshape(out_shape) - - x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) - expected = x.reshape(out_shape) - np.testing.assert_allclose(f(x), expected) - @parameterized.named_parameters(*[ (f"m_{m}_n_{n}_k_{k}_dtype_{dtype}_bm_{block_size_m}_" f"bn_{block_size_n}_bk_{block_size_k}_gm_{group_size_m}", m, n, k, dtype, @@ -337,20 +271,15 @@ def f(x_ref, o_ref): for n in [512, 1024] for dtype in ["float32", "float16"] for block_size_m in [64, 128] - for block_size_n in [128, 256] + for block_size_n in [64, 128] for block_size_k in [32] for group_size_m in [8] if block_size_m <= m and block_size_n <= n and block_size_k <= k ]) def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm): - if not self.check_gpu_capability_at_least(70): - raise unittest.SkipTest( - "Matmul only works on GPUs with capability >= sm70") - if not self.INTERPRET and ( - plgpu.get_compute_capability(0) <= 75 - and (bm >= 128 or bn > 128 or bk > 32) - ): - raise unittest.SkipTest("Block sizes too big for sm70.") + if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: + # TODO: all sort of assertion failures on CPU in 64-bit mode + self.skipTest("On CPU the test works only in 32-bit mode") k1, k2 = random.split(random.key(0)) x = random.normal(k1, (m, k), dtype=dtype) y = random.normal(k2, (k, n), dtype=dtype) @@ -367,20 +296,14 @@ def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm): for n in [512, 1024] for dtype in ["float32", "float16"] for block_size_m in [64, 128] - for block_size_n in [128, 256] + for block_size_n in [64, 128] for block_size_k in [32] if block_size_m <= m and block_size_n <= n and block_size_k <= k ]) def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk): - if not self.check_gpu_capability_at_least(70): - raise unittest.SkipTest( - "Matmul only works on GPUs with capability >= sm70") - if not self.INTERPRET and ( - plgpu.get_compute_capability(0) <= 75 - and (bm >= 128 or bn > 128 or bk > 32) - ): - raise unittest.SkipTest("Block sizes too big for sm70.") - + if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: + # TODO: all sort of assertion failures on CPU in 64-bit mode + self.skipTest("On CPU the test works only in 32-bit mode") k1, k2 = random.split(random.key(0)) x = random.normal(k1, (m, k), dtype=dtype) y = random.normal(k2, (k, n), dtype=dtype) @@ -388,36 +311,6 @@ def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk): interpret=self.INTERPRET), jnp.matmul(x, y) np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05) - @parameterized.product( - size=[16, 32, 64], - dtype=["float32", "float16"], - trans_a=[False, True], - trans_b=[False, True], - ) - def test_dot(self, size, dtype, trans_a, trans_b): - if not self.check_gpu_capability_at_least(70): - raise unittest.SkipTest( - "Matmul only works on GPUs with capability >= sm70") - if trans_a or trans_b: - # TODO(slebedev): Remove this once the problematic Triton pass is fixed. 
- raise unittest.SkipTest( - "Triton crashes if any of the operands are transposed") - - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((size, size), dtype), - grid=1) - def dot(x_ref, y_ref, o_ref): - x = x_ref[:, :] - y = y_ref[:, :] - o_ref[:, :] = pl.dot(x, y, trans_a, trans_b).astype(o_ref.dtype) - - k1, k2 = random.split(random.key(0)) - x = random.normal(k1, (size, size), dtype=dtype) - y = random.normal(k2, (size, size), dtype=dtype) - out, expected = dot(x, y), jnp.dot(x, y) - np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05) - @parameterized.named_parameters(*( dict(testcase_name=f"{batch_size}_{size}_{block_size}_{dtype}", batch_size=batch_size, size=size, block_size=block_size, dtype=dtype) @@ -448,83 +341,6 @@ def softmax(x_ref, o_ref): np.testing.assert_allclose(softmax(x), jax.nn.softmax(x, axis=-1), atol=1e-5, rtol=1e-5) - @parameterized.parameters(*( - (size, block_size) - for size in [1, 2, 64, 129, 1021] - for block_size in [1, 2, 32, 64, 128] - )) - def test_masked_load_store(self, size, block_size): - @functools.partial(self.pallas_call, - out_shape=( - jax.ShapeDtypeStruct((size,), jnp.float32) - ), - grid=pl.cdiv(size, block_size)) - def add_one(x_ref, o_ref): - idx = pl.program_id(0) * block_size + jnp.arange(block_size) - mask = idx < x_ref.shape[0] - x = pl.load(x_ref, (idx,), mask=mask) - pl.store(o_ref, (idx,), x + 1., mask=mask) - - key = random.key(0) - x = random.normal(key, (size,)) - np.testing.assert_allclose(add_one(x), x + 1., atol=1e-5, rtol=1e-5) - - def test_broadcasted_load_store(self): - m, n = 16, 32 - @functools.partial( - self.pallas_call, - out_shape=( - jax.ShapeDtypeStruct((m, n), jnp.float32) - ), grid=1) - def load(x_ref, o_ref): - x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :])) - pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]), x + 1.) 
- - key = random.key(0) - x = random.normal(key, (m, n)) - np.testing.assert_allclose(load(x), x + 1., atol=1e-5, rtol=1e-5) - - def test_swap(self): - m, n = 16, 32 - - @functools.partial( - self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2, - grid=1, - input_output_aliases={0: 0, 1: 1}, - ) - def swap(_, _2, x_ref, y_ref): - x = x_ref[:] - y = pl.swap(y_ref, (slice(None),), x) - x_ref[:] = y - - x = random.normal(random.key(0), (m, n)) - y = random.normal(random.key(1), (m, n)) - out = swap(x, y) - np.testing.assert_array_equal(out[0], y) - np.testing.assert_array_equal(out[1], x) - - def test_masked_swap(self): - m, n = 16, 32 - - @functools.partial( - self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2, - grid=1, - input_output_aliases={0: 0, 1: 1}, - ) - def masked_swap(_, _2, mask_ref, x_ref, y_ref): - x = x_ref[:] - y = pl.swap(y_ref, (slice(None),), x, mask=mask_ref[:]) - x_ref[:] = y - - x = random.normal(random.key(0), (m, n)) - y = random.normal(random.key(1), (m, n)) - mask = random.bernoulli(random.key(2), shape=(m, n)) - out = masked_swap(x, y, mask) - np.testing.assert_array_equal(out[0], jnp.where(mask, y, x)) - np.testing.assert_array_equal(out[1], jnp.where(mask, x, y)) - def test_unused_ref(self): m, n = 16, 32 @functools.partial( @@ -541,7 +357,9 @@ def dummy(_, o_ref): np.testing.assert_allclose(dummy(x), jnp.ones_like(x), atol=1e-5, rtol=1e-5) def test_pallas_call_with_input_output_aliasing(self): - + if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: + # TODO: assertion failures on CPU in 64-bit mode + self.skipTest("On CPU the test works only in 32-bit mode") def add_inplace_kernel(_, o_ref, *, block_size): pid = pl.program_id(axis=0) # we use a 1d launch grid so axis is 0 block_start = pid * block_size @@ -565,191 +383,26 @@ def add_inplace_kernel(_, o_ref, *, block_size): expected = x + 1 np.testing.assert_allclose(out, expected) - @parameterized.named_parameters(*[ - ("add_i32", pl.atomic_add, np.array([1, 2, 3, 4], np.int32), np.sum), - ("max_i", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max), - ("min_i32", pl.atomic_min, np.array([1, 2, 3, 4], np.int32), np.min), - ("add_f16", pl.atomic_add, np.array([1, 2, 3, 4], np.float16), np.sum), - ("add_f32", pl.atomic_add, np.array([1, 2, 3, 4], np.float32), np.sum), - ("max_f32", pl.atomic_max, np.array([1, 2, 3, 4], np.float32), np.max), - ("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min), - ]) - def test_scalar_atomic(self, op, value, numpy_op): - if not self.check_gpu_capability_at_least(70): - raise unittest.SkipTest( - "Atomic ops onl works on GPUs with capability >= sm70") - - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((), value.dtype), - grid=value.shape[0], - input_output_aliases={1: 0}) - def atomic_kernel(x_ref, _, o_ref): - pid = pl.program_id(axis=0) - op(o_ref, (), x_ref[pid]) - if op == pl.atomic_add: - neutral = np.array(0, dtype=value.dtype) - elif op == pl.atomic_max: - if np.issubdtype(value.dtype, np.integer): - neutral = np.array(np.iinfo(value.dtype).min, value.dtype) - else: - neutral = np.array(-float('inf'), value.dtype) - elif op == pl.atomic_min: - if np.issubdtype(value.dtype, np.integer): - neutral = np.array(np.iinfo(value.dtype).max, value.dtype) - else: - neutral = np.array(float('inf'), value.dtype) - elif op == pl.atomic_or: - neutral = np.array(False, value.dtype) - else: - raise NotImplementedError() - out = atomic_kernel(value, neutral) - 
np.testing.assert_allclose(out, numpy_op(value)) - - @parameterized.parameters(*[(0,), (1,)]) - def test_array_atomic_add(self, axis): - if not self.check_gpu_capability_at_least(70): - raise unittest.SkipTest( - "Atomic ops onl works on GPUs with capability >= sm70") - - m, n = 32, 8 - if axis == 0: - grid = m - else: - grid = n - out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), jnp.float32) + def test_using_pallas_slice(self): + if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: + # TODO: assertion failures on CPU in 64-bit mode + self.skipTest("On CPU the test works only in 32-bit mode") + m, n = 32, 4 + out_shape = jax.ShapeDtypeStruct((4, n), jnp.float32) @functools.partial( self.pallas_call, out_shape=out_shape, - grid=grid, - input_output_aliases={1: 0}) - def reduce(x_ref, _, y_ref): - i = pl.program_id(axis=0) - if axis == 0: - idx = (i, jnp.arange(n)) - else: - idx = (jnp.arange(m), i) - x = pl.load(x_ref, idx) - pl.atomic_add(y_ref, (jnp.arange(y.shape[0]),), x) + grid=1) + def slice_kernel(x_ref, y_ref): + x = pl.load(x_ref, (pl.dslice(0, 4), pl.dslice(0, 4))) + pl.store(y_ref, (pl.dslice(4), pl.dslice(4)), x) x = random.normal(random.key(0), (m, n)) - y = jnp.zeros(out_shape.shape, out_shape.dtype) - y = reduce(x, y) - y_ref = np.sum(x, axis=axis) + y = slice_kernel(x) + y_ref = x[:4] np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2) - @parameterized.parameters(False, True) - def test_reduce_only_dim(self, use_store): - m = 32 - x = random.normal(random.key(0), (m,), dtype=jnp.float32) - out_shape = jax.ShapeDtypeStruct((), x.dtype) - @functools.partial( - self.pallas_call, - out_shape=out_shape, - grid=1, debug=False) - def reduce(x_ref, y_ref): - x = pl.load(x_ref, (jnp.arange(m),)) - y = jnp.sum(x, axis=-1) - if use_store: - pl.store(y_ref, (), y) - else: - y_ref[...] 
= y - y = reduce(x) - y_ref = jnp.sum(x, axis=-1) - np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2) - - @parameterized.named_parameters(*[ - (f"{op_name}_{dtype}_{axis}", op, dtype, axis) - for op_name, op in [ - ("add", jnp.sum), - ("max", jnp.max), - ("min", jnp.min), - ("argmax", jnp.argmax), - ("argmin", jnp.argmin), - ] - for axis in [0, 1, (1,), (0, 1)] - for dtype in ["float16", "float32", "int32", "uint32"] - if isinstance(axis, int) or "arg" not in op_name - ]) - def test_array_reduce(self, op, dtype, axis): - m, n = 32, 8 - out_dtype = dtype - if op in {jnp.argmin, jnp.argmax}: - out_dtype = jnp.int32 - def make_x(key): - if jnp.issubdtype(dtype, jnp.integer): - return random.permutation( - key, jnp.arange(m * n, dtype=dtype), independent=True - ).reshape(m, n) - else: - return random.normal(key, (m, n), dtype=dtype) - out_shape = jax.ShapeDtypeStruct( - op(make_x(random.key(0)), axis=axis).shape, out_dtype) - if isinstance(axis, int): - grid = tuple(a for i, a in enumerate((m, n)) if i != axis) - else: - grid = tuple(a for i, a in enumerate((m, n)) if i not in axis) - @functools.partial( - self.pallas_call, - out_shape=out_shape, - grid=grid) - def reduce(x_ref, y_ref): - x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None])) - y = op(x, axis=axis) - pl.store(y_ref, tuple(jnp.arange(d) for d in y.shape), y) - for i, key in enumerate(random.split(random.key(0), 20)): - x = make_x(key) - y = reduce(x) - y_ref = op(x, axis=axis) - np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) - - @parameterized.named_parameters(*[ - (f"{dtype}_{axis}", dtype, axis) - for axis in [0, 1] - for dtype in ["float16", "float32", "int32", "uint32"] - if isinstance(axis, int) - ]) - def test_cumsum(self, dtype, axis): - m, n = 32, 8 - out_dtype = dtype - def make_x(key): - if jnp.issubdtype(dtype, jnp.integer): - return random.permutation( - key, jnp.arange(m * n, dtype=dtype), independent=True - ).reshape(m, n) - else: - return random.normal(key, (m, n), dtype=dtype) - out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) - grid = () - @functools.partial( - self.pallas_call, - out_shape=out_shape, - grid=grid) - def reduce(x_ref, y_ref): - x = x_ref[...] - y_ref[...] = jnp.cumsum(x, axis=axis) - for i, key in enumerate(random.split(random.key(0), 20)): - x = make_x(key) - y = reduce(x) - y_ref = jnp.cumsum(x, axis=axis) - np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) - - def test_using_pallas_slice(self): - m, n = 32, 4 - out_shape = jax.ShapeDtypeStruct((4, n), jnp.float32) - @functools.partial( - self.pallas_call, - out_shape=out_shape, - grid=1) - def slice_kernel(x_ref, y_ref): - x = pl.load(x_ref, (pl.dslice(0, 4), pl.dslice(0, 4))) - pl.store(y_ref, (pl.dslice(4), pl.dslice(4)), x) - x = random.normal(random.key(0), (m, n)) - y = slice_kernel(x) - y_ref = x[:4] - np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2) - - def test_pallas_trace_cache(self): - trace_count = 0 + def test_pallas_trace_cache(self): + trace_count = 0 @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), grid=1) @@ -766,81 +419,6 @@ def f(x): self.assertEqual(f(x), 2.) 
self.assertEqual(trace_count, 1) - def test_pallas_compilation_cache(self): - if not compile_jaxpr: - self.skipTest("No Triton GPU.") - if self.INTERPRET: - raise unittest.SkipTest("No Triton compilation in interpreter mode.") - if _TRITON_COMPILE_VIA_XLA.value: - raise unittest.SkipTest("Triton is compiled via XLA.") - - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), - grid=1) - def add_one(x_ref, o_ref): - o_ref[()] = x_ref[()] + 1. - - @jax.jit - def f(x): - return add_one(add_one(x)) - - x = jnp.array(0., dtype=jnp.float32) - self.assertEqual(f(x), 2.) - num_misses = compile_jaxpr.cache_info().misses - self.assertEqual(num_misses, 1) - - @parameterized.parameters(*[ - (0, 0, 1), - (0, 1, 1), - (1, 0, 1), - (1, 1, 1), - (2, 1, 1), - (2, 1, 1), - ]) - def test_atomic_cas(self, init_value, cmp, new_value): - if not self.check_gpu_capability_at_least(70): - raise unittest.SkipTest("requires a GPU with compute capability >= sm70") - - @functools.partial( - self.pallas_call, out_shape=( - jax.ShapeDtypeStruct((), jnp.int32), - jax.ShapeDtypeStruct((), jnp.int32)), - input_output_aliases={0: 0}) - def swap(_, lock_ref, out_ref): - out_ref[()] = pl.atomic_cas(lock_ref, cmp, new_value) - - lock, out = swap(init_value) - np.testing.assert_allclose(lock, new_value if cmp == init_value else - init_value) - np.testing.assert_allclose(out, init_value) - - @parameterized.parameters(*[ - 1, 2, 3, 4, 8 - ]) - def test_atomic_counter(self, num_threads): - if self.INTERPRET: - self.skipTest("While loop not supported in interpreter mode.") - - if not self.check_gpu_capability_at_least(70): - raise unittest.SkipTest("requires a GPU compute capability >= sm70") - - @functools.partial( - self.pallas_call, out_shape=( - jax.ShapeDtypeStruct((), jnp.int32), - jax.ShapeDtypeStruct((), jnp.int32)), - input_output_aliases={0: 0, 1: 1}, - grid=(num_threads,)) - def increment(_, __, lock_ref, counter_ref): - def _cond(_): - return pl.atomic_cas(lock_ref, 0, 1) == 1 - lax.while_loop(_cond, lambda a: a, 0) - counter_ref[...] += 1 - pl.atomic_xchg(lock_ref, (), 0) - - lock, count = increment(0, 0) - np.testing.assert_allclose(lock, 0) - np.testing.assert_allclose(count, num_threads) - def test_custom_jvp_call(self): @functools.partial(jax.custom_jvp, nondiff_argnums=(1,)) def softmax(x, axis=-1): @@ -873,6 +451,9 @@ def setUp(self): super().setUp() if self.INTERPRET: self.skipTest("Control flow not supported in interpreter mode yet.") + if jtu.test_device_matches(["tpu"]): + # TODO: most tests fail on TPU in non-interpreter mode + self.skipTest("On TPU the test works only in interpret mode") def test_loop_with_float64_carry(self): # Test that the jnp.zeros(f64) loop init_val is actually f64, and that @@ -883,7 +464,7 @@ def test_loop_with_float64_carry(self): @functools.partial(self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float64), grid=1, - debug=False) + ) def f(x_ref, y_ref): def body(i, acc): # TODO(sharadmv): DCE loop index but retain carry breaks scan pattern. @@ -899,7 +480,7 @@ def test_cond_simple(self): arg = jnp.float32(0.) @functools.partial(self.pallas_call, out_shape=jax.ShapeDtypeStruct(arg.shape, jnp.float32), - debug=False) + ) def f(branch_ref, x_ref, y_ref): y_ref[...] = lax.switch( branch_ref[...], @@ -915,7 +496,7 @@ def test_cond_threebranch(self): @functools.partial(self.pallas_call, out_shape=jax.ShapeDtypeStruct(arg.shape, jnp.float32), grid=1, - debug=False) + ) def f(branch_ref, x_ref, y_ref): y_ref[...] 
= lax.switch( branch_ref[...], @@ -931,13 +512,16 @@ def f(branch_ref, x_ref, y_ref): @parameterized.parameters(1, 2, 4, 8) def test_cond_vectors(self, block_size): arg = jnp.float32([0.] * 8) - @functools.partial(self.pallas_call, - out_shape=jax.ShapeDtypeStruct(arg.shape, jnp.float32), - in_specs=[pl.BlockSpec(lambda _: (), ()), - pl.BlockSpec(lambda i: i, (block_size,))], - out_specs=pl.BlockSpec(lambda i: i, (block_size,)), - grid=pl.cdiv(arg.shape[0], block_size), - debug=False) + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(arg.shape, jnp.float32), + in_specs=[ + pl.BlockSpec((), lambda _: ()), + pl.BlockSpec((block_size,), lambda i: i), + ], + out_specs=pl.BlockSpec((block_size,), lambda i: i), + grid=pl.cdiv(arg.shape[0], block_size), + ) def f(branch_ref, x_ref, y_ref): y_ref[...] = lax.switch( branch_ref[...], @@ -951,13 +535,16 @@ def f(branch_ref, x_ref, y_ref): @parameterized.parameters(1, 2, 4, 8) def test_cond_threebranch_vectors(self, block_size): arg = jnp.float32([0.] * 8) - @functools.partial(self.pallas_call, - out_shape=jax.ShapeDtypeStruct(arg.shape, jnp.float32), - in_specs=[pl.BlockSpec(lambda _: (), ()), - pl.BlockSpec(lambda i: i, (block_size,))], - out_specs=pl.BlockSpec(lambda i: i, (block_size,)), - grid=pl.cdiv(arg.shape[0], block_size), - debug=False) + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(arg.shape, jnp.float32), + in_specs=[ + pl.BlockSpec((), lambda _: ()), + pl.BlockSpec((block_size,), lambda i: i), + ], + out_specs=pl.BlockSpec((block_size,), lambda i: i), + grid=pl.cdiv(arg.shape[0], block_size), + ) def f(branch_ref, x_ref, y_ref): y_ref[...] = lax.switch( branch_ref[...], @@ -973,18 +560,19 @@ def f(branch_ref, x_ref, y_ref): @parameterized.parameters(*itertools.product([1, 8], [1, 2, 4])) def test_cond_threebranch_matrix_out(self, bx, by): x = jnp.arange(64.)[:, None] - y = jnp.arange(128.)[None, :] - # TODO(sharadmv): Renaming in_specs->in_spec silently breaks. + y = jnp.arange(128.0)[None, :] + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), jnp.float32), in_specs=[ - pl.BlockSpec(lambda _, __: (), ()), - pl.BlockSpec(lambda i, _: (i, 0), (bx, 1)), - pl.BlockSpec(lambda _, j: (0, j), (1, by))], - out_specs=pl.BlockSpec(lambda i, j: (i, j), (bx, by)), + pl.BlockSpec((), lambda _, __: ()), + pl.BlockSpec((bx, 1), lambda i, _: (i, 0)), + pl.BlockSpec((1, by), lambda _, j: (0, j)), + ], + out_specs=pl.BlockSpec((bx, by), lambda i, j: (i, j)), grid=(pl.cdiv(x.shape[0], bx), pl.cdiv(y.shape[1], by)), - debug=False) + ) def f(branch_ref, x_ref, y_ref, o_ref): o_ref[...] = lax.switch( branch_ref[...], @@ -1001,7 +589,7 @@ def test_conditional_write(self): arg = jnp.arange(8, dtype=jnp.float32) @functools.partial(self.pallas_call, out_shape=jax.ShapeDtypeStruct(arg.shape, jnp.float32), - debug=False) + ) def f(branch_ref, x_ref, out_ref): out_ref[...] = -x_ref[...] def if_true(z): @@ -1023,6 +611,10 @@ def if_true(z): # dx, jnp.float32([0., 2, 4, 6, 0, 10, 12 + 12, 14])) def test_scan_cond_vm_explicit_ref_arg(self): + if jtu.test_device_matches(["cpu"]): + # TODO: fix this + self.skipTest("Fails on CPU: assertion error") + program = jnp.int32([0, 1, 2, 3, 2]) params = jnp.arange(len(program) * 3.).reshape(len(program), 3) x = jnp.arange(7.) 
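[Editor's note: illustrative sketch, not part of the patch. The hunks that follow rewrite in_specs so that small operands (program, params) are mapped whole to every grid step while x is tiled in bx-sized blocks. The stripped-down kernel below shows the same whole-operand-plus-tiled-operand pattern; the name add_bias_sum, the sizes, and interpret=True are made up so the sketch runs on CPU.]

# Illustrative sketch only; not part of this patch.
import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

bx = 4  # block size for the tiled operand

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((16,), jnp.float32),
    in_specs=[
        pl.BlockSpec((3,), lambda i: (0,)),   # whole bias on every grid step
        pl.BlockSpec((bx,), lambda i: (i,)),  # i-th block of x
    ],
    out_specs=pl.BlockSpec((bx,), lambda i: (i,)),
    grid=(16 // bx,),
    interpret=True,  # run the sketch on CPU
)
def add_bias_sum(bias_ref, x_ref, o_ref):
  o_ref[...] = x_ref[...] + jnp.sum(bias_ref[...])

bias = jnp.float32([1.0, 2.0, 3.0])
x = jnp.arange(16, dtype=jnp.float32)
assert jnp.allclose(add_bias_sum(bias, x), x + 6.0)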
@@ -1033,12 +625,13 @@ def test_scan_cond_vm_explicit_ref_arg(self): self.pallas_call, out_shape=jax.ShapeDtypeStruct((x.shape[0],), jnp.float32), in_specs=[ - pl.BlockSpec(lambda _: (0,), program.shape), # program - pl.BlockSpec(lambda _: (0, 0), params.shape), # params - pl.BlockSpec(lambda i: (i,), (bx,))], # x - out_specs=pl.BlockSpec(lambda i: (i,), (bx,)), + pl.BlockSpec(program.shape, lambda _: (0,)), # program + pl.BlockSpec(params.shape, lambda _: (0, 0)), # params + pl.BlockSpec((bx,), lambda i: (i,)), + ], # x + out_specs=pl.BlockSpec((bx,), lambda i: (i,)), grid=pl.cdiv(x.shape[0], bx), - debug=False) + ) def f(program_ref, params_ref, x_ref, out_ref): x = x_ref[...] @@ -1070,6 +663,10 @@ def body_fn(i, args): params, x) def test_scan_cond_vm_closing_over_ref(self): + if jtu.test_device_matches(["cpu"]): + # TODO: fix this + self.skipTest("Fails on CPU: assertion error") + # ** Difference is the closure over params_ref in the switch branches. ** program = jnp.int32([0, 1, 2, 3, 2, -1]) params = jnp.arange(len(program) * 3.).reshape(len(program), 3) @@ -1081,12 +678,13 @@ def test_scan_cond_vm_closing_over_ref(self): self.pallas_call, out_shape=jax.ShapeDtypeStruct((x.shape[0],), jnp.float32), in_specs=[ - pl.BlockSpec(lambda _: (0,), program.shape), # program - pl.BlockSpec(lambda _: (0, 0), params.shape), # params - pl.BlockSpec(lambda i: (i,), (bx,))], # x - out_specs=pl.BlockSpec(lambda i: (i,), (bx,)), + pl.BlockSpec(program.shape, lambda _: (0,)), # program + pl.BlockSpec(params.shape, lambda _: (0, 0)), # params + pl.BlockSpec((bx,), lambda i: (i,)), + ], # x + out_specs=pl.BlockSpec((bx,), lambda i: (i,)), grid=pl.cdiv(x.shape[0], bx), - debug=False) + ) def f(program_ref, params_ref, x_ref, out_ref): x = x_ref[...] @@ -1258,6 +856,7 @@ def body(i): x = jnp.array([1, 4, 100]) np.testing.assert_array_equal(jax.vmap(f)(x), x) + class PallasControlFlowInterpreterTest(PallasControlFlowTest): INTERPRET = True @@ -1275,13 +874,26 @@ class PallasControlFlowInterpreterTest(PallasControlFlowTest): ("tanh", jnp.tanh), ] + class PallasCallAutodifferentiationTest(PallasTest): + def setUp(self): + super().setUp() + if jtu.test_device_matches(["tpu"]): + # TODO: most tests fail on TPU in non-interpreter mode + self.skipTest("On TPU the test works only in interpret mode") + # TODO: improve tolerance setting + self.tol = 1e-5 + self.grad_tol = jtu.default_gradient_tolerance[np.dtype(jnp.float32)] + @parameterized.named_parameters(*AD_TEST_CASES) def test_jvp(self, impl): + grad_tol = self.grad_tol + if jtu.test_device_matches(["tpu"]) and "recip_exp_sq" in self._testMethodName: + grad_tol = 1e-1 + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), - debug=False, grid=1) def pallas_impl(x_ref, o_ref): x = x_ref[()] @@ -1292,10 +904,12 @@ def pallas_impl(x_ref, o_ref): t = random.normal(k2) out_primal, out_tangent = jax.jvp(pallas_impl, (x,), (t,)) out_primal_ref, out_tangent_ref = jax.jvp(impl, (x,), (t,)) - np.testing.assert_allclose(out_primal, out_primal_ref, atol=1e-5, rtol=1e-5) - np.testing.assert_allclose(out_tangent, out_tangent_ref, atol=1e-5, - rtol=1e-5) - jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2) + np.testing.assert_allclose(out_primal, out_primal_ref, atol=self.tol, + rtol=self.tol) + np.testing.assert_allclose(out_tangent, out_tangent_ref, atol=self.tol, + rtol=self.tol) + jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2, + atol=grad_tol, rtol=grad_tol) @parameterized.named_parameters(*AD_TEST_CASES) def 
test_pallas_around_grad(self, impl): @@ -1303,7 +917,6 @@ def test_pallas_around_grad(self, impl): self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), name=self.id().split(".")[-1], - debug=True, grid=1) def pallas_impl(x_ref, o_ref): x = x_ref[()] @@ -1316,9 +929,12 @@ def pallas_impl(x_ref, o_ref): @parameterized.named_parameters(*AD_TEST_CASES) def test_jvp_slice(self, impl): + grad_tol = self.grad_tol + if jtu.test_device_matches(["tpu"]) and "tanh" in self._testMethodName: + grad_tol = 1e-1 + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), - debug=False, grid=1) def pallas_impl(x_ref, o_ref): x = x_ref[jnp.arange(2)] @@ -1331,10 +947,12 @@ def pallas_impl(x_ref, o_ref): out_primal, out_tangent = jax.jvp(pallas_impl, (x,), (t,)) out_primal_ref, out_tangent_ref = jax.jvp( lambda x: jnp.concatenate([jnp.zeros(2), impl(x[:2])]), (x,), (t,)) - np.testing.assert_allclose(out_primal, out_primal_ref, atol=1e-5, rtol=1e-5) - np.testing.assert_allclose(out_tangent, out_tangent_ref, atol=1e-5, - rtol=1e-5) - jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2) + np.testing.assert_allclose(out_primal, out_primal_ref, atol=self.tol, + rtol=self.tol) + np.testing.assert_allclose(out_tangent, out_tangent_ref, atol=self.tol, + rtol=self.tol) + jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2, + atol=grad_tol, rtol=grad_tol) # TODO(sharadmv): enable this when we update Triton # def test_jvp_matmul(self): @@ -1348,12 +966,14 @@ def pallas_impl(x_ref, o_ref): def test_slicing_block_spec(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), in_specs=[ - pl.BlockSpec(lambda _: (0, 0), (None, 4)), - pl.BlockSpec(lambda _: (1, 0), (None, 4)), + pl.BlockSpec((None, 4), lambda _: (0, 0)), + pl.BlockSpec((None, 4), lambda _: (1, 0)), ], - debug=False, grid=1) + grid=1, + ) def add_vectors(x_ref, y_ref, o_ref): o_ref[:] = x_ref[:] + y_ref[:] xy = jnp.arange(8.).reshape((2, 4)) @@ -1362,12 +982,28 @@ def add_vectors(x_ref, y_ref, o_ref): np.testing.assert_allclose(out, out_ref) +class PallasCallAutodifferentiationInterpreterTest(PallasCallAutodifferentiationTest): + INTERPRET = True + + def setUp(self): + super().setUp() + if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: + # TODO: assertion failures on CPU in 64-bit mode + self.skipTest("On CPU the test works only in 32-bit mode") + + class PallasCallVmapTest(PallasTest): + def setUp(self): + super().setUp() + if jtu.test_device_matches(["tpu"]): + # TODO: most tests fail on TPU in non-interpreter mode + self.skipTest("On TPU the test works only in interpret mode") + def test_vmap_of_simple_kernel(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), - debug=False) + ) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 out = jax.vmap(add_one)(jnp.arange(8)) @@ -1377,7 +1013,7 @@ def add_one(x_ref, o_ref): def test_vmap_of_simple_kernel_with_in_axes_None(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), - debug=False) + ) def add(x_ref, y_ref, o_ref): o_ref[()] = x_ref[()] + y_ref[()] out = jax.vmap(add, in_axes=(0, None))(jnp.arange(8), 1) @@ -1387,7 +1023,7 @@ def add(x_ref, y_ref, o_ref): def test_double_vmap_of_simple_kernel(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), - debug=False) + ) def add_one(x_ref, o_ref): o_ref[()] = 
x_ref[()] + 1 out = jax.vmap(jax.vmap(add_one))(jnp.arange(8).reshape((4, 2))) @@ -1397,7 +1033,7 @@ def add_one(x_ref, o_ref): def test_quadruple_vmap_of_simple_kernel(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), - debug=False) + ) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 out = jax.vmap(jax.vmap(jax.vmap(jax.vmap(add_one))))( @@ -1408,7 +1044,6 @@ def add_one(x_ref, o_ref): def test_quadruple_vmap_of_batched_kernel(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((7,), jnp.int32), - debug=False, grid=(7,)) def add_one(x_ref, o_ref): i = pl.program_id(0) @@ -1421,7 +1056,6 @@ def add_one(x_ref, o_ref): def test_vmap_of_slicing_kernel(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), - debug=False, grid=(2,)) def add_one(x_ref, o_ref): i = pl.program_id(0) @@ -1433,7 +1067,6 @@ def add_one(x_ref, o_ref): def test_vmap_of_kernel_with_input_output_aliases(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), - debug=False, input_output_aliases={1:0}, grid=()) def add(x_ref, _, o_ref): @@ -1442,118 +1075,921 @@ def add(x_ref, _, o_ref): out_ref = jnp.arange(2, 10) np.testing.assert_allclose(out, out_ref) + def test_vmap_of_kernel_with_input_output_aliases_different_axes(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4,), jnp.int32), + input_output_aliases={0: 0}, + grid=(), + ) + def add(x_ref, o_ref): + o_ref[()] = x_ref[()] + 1 + + out = jax.vmap(add, in_axes=1)(jnp.arange(8).reshape((4, 2))) + out_ref = jnp.arange(1, 9).reshape((4, 2)).swapaxes(0, 1) + np.testing.assert_allclose(out, out_ref) + def test_vmap_of_slicing_kernel_different_axes(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), - debug=False, - grid=(2,)) - def add_one(x_ref, o_ref): - i = pl.program_id(0) - o_ref[i] = x_ref[i] + 1 - add_one_ref = lambda x: x + 1 - x = jnp.arange(8).reshape((2, 4)) + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + grid=(2,)) + def add_one(x_ref, o_ref): + i = pl.program_id(0) + o_ref[i] = x_ref[i] + 1 + add_one_ref = lambda x: x + 1 + x = jnp.arange(8).reshape((2, 4)) + + out = jax.vmap(add_one, in_axes=1, out_axes=1)(x) + out_ref = jax.vmap(add_one_ref, in_axes=1, out_axes=1)(x) + np.testing.assert_allclose(out, out_ref) + + out = jax.vmap(add_one, in_axes=1, out_axes=0)(x) + out_ref = jax.vmap(add_one_ref, in_axes=1, out_axes=0)(x) + np.testing.assert_allclose(out, out_ref) + + def test_double_vmap_of_slicing_kernel_different_axes(self): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + grid=(4,)) + def sin(x_ref, o_ref): + i = pl.program_id(0) + o_ref[i] = jnp.sin(x_ref[i]) + sin_ref = jnp.sin + x = jnp.arange(64.).reshape((8, 4, 2)) + + out = jax.vmap(jax.vmap(sin, in_axes=1), in_axes=0)(x) + out_ref = jax.vmap(jax.vmap(sin_ref, in_axes=1), in_axes=0)(x) + np.testing.assert_allclose(out, out_ref, atol=1e-3, rtol=1e-3) + + @jtu.skip_on_flag("jax_skip_slow_tests", True) + def test_small_large_vmap(self): + # Catches https://github.com/google/jax/issues/18361 + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + grid=(2,)) + def add_one(x_ref, o_ref): + o_ref[()] = x_ref[()] + 1 + + add_one = jax.vmap(jax.vmap(add_one)) + add_one_ref = lambda x: x + 1 + + x = random.randint(random.key(0), (4, 65536, 2), 0, 10000) + + out = add_one(x) + out_ref = 
add_one_ref(x) + + np.testing.assert_allclose(out, out_ref) + + def test_small_small_large_vmap(self): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + grid=(2,)) + def add_one(x_ref, o_ref): + o_ref[()] = x_ref[()] + 1 + + add_one = jax.vmap(jax.vmap(jax.vmap(add_one))) + add_one_ref = lambda x: x + 1 + + x = random.randint(random.key(0), (2, 2, 65536, 2), 0, 10000) + + out = add_one(x) + out_ref = add_one_ref(x) + + np.testing.assert_allclose(out, out_ref) + + +class PallasCallVmapInterpreterTest(PallasCallVmapTest): + INTERPRET = True + + def setUp(self): + super().setUp() + if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: + # TODO: assertion failures on CPU in 64-bit mode + self.skipTest("On CPU the test works only in 32-bit mode") + + +class PallasOpsTest(PallasTest): + + def setUp(self): + super().setUp() + if jtu.test_device_matches(["tpu"]): + # TODO: most tests fail on TPU in non-interpreter mode + self.skipTest("On TPU the test works only in interpret mode") + + ELEMENTWISE_OPS = [ + ( + [jnp.abs, jnp.negative], + ["int16", "int32", "int64", "float16", "float32", "float64"], + ), + ([jnp.ceil, jnp.floor], ["float32", "float64", "int32"]), + ( + [jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt], + ["float16", "float32", "float64"], + ), + ( + # fmt: off + [jnp.expm1, jnp.log1p, jnp.cbrt, lax.rsqrt, jnp.tan, jnp.asin, + jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.asinh, jnp.acosh, + jnp.atanh], + # fmt: on + ["float32", "float64"], + ), + ([lax.population_count, lax.clz, jnp.invert], ["int32", "int64"]), + ] + + @parameterized.named_parameters( + (f"{fn.__name__}_{dtype}", fn, dtype) + for args in ELEMENTWISE_OPS + for fn, dtype in itertools.product(*args) + ) + def test_elementwise(self, fn, dtype): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), dtype), grid=1 + ) + def kernel(x_ref, o_ref): + o_ref[:] = fn(x_ref[...]) + + with contextlib.ExitStack() as stack: + if jnp.dtype(dtype).itemsize == 8: + stack.enter_context(config.enable_x64(True)) + x = jnp.array([0.42, 2.4]).astype(dtype) + np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6) + + @parameterized.parameters( + ("float32", "int32"), + ("float64", "int32"), + ("float32", "float32"), + ("float64", "float64"), + ) + def test_pow(self, x_dtype, y_dtype): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), x_dtype), grid=1 + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[:] = lax.pow(x_ref[...], y_ref[...]) + + with contextlib.ExitStack() as stack: + if jnp.dtype(x_dtype).itemsize == 8: + stack.enter_context(config.enable_x64(True)) + x = jnp.array([1, 2, 3, 4]).astype(x_dtype) + y = jnp.array([1, 2, 3, 4]).astype(y_dtype) + np.testing.assert_allclose(kernel(x, y), lax.pow(x, y)) + + @parameterized.parameters(0, 1, 2, 3, 4, 5, -1, -2, -3) + def test_integer_pow(self, y): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[:] = lax.integer_pow(x_ref[...], y) + + x = jnp.array([1, 2, 3, 4]).astype(jnp.float32) / 10 + np.testing.assert_allclose(kernel(x), lax.integer_pow(x, y)) + + @parameterized.parameters("float32", "float64") + def test_nextafter(self, dtype): + if jtu.test_device_matches(["tpu"]) and dtype == "float64": + self.skipTest("float64 disabled on TPU.") + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), grid=1 + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[:] = 
jnp.nextafter(x_ref[...], y_ref[...]) + + with contextlib.ExitStack() as stack: + if jnp.dtype(dtype).itemsize == 8: + stack.enter_context(config.enable_x64(True)) + x = jnp.array([1, 2, 3, 4]).astype(dtype) + y = jnp.array([1, 2, 3, 4]).astype(dtype) + np.testing.assert_allclose(kernel(x, y), jnp.nextafter(x, y)) + + COMPARISON_OPS = [ + jnp.equal, + jnp.not_equal, + jnp.less, + jnp.less_equal, + jnp.greater, + jnp.greater_equal, + ] + + @parameterized.named_parameters( + (f"{fn.__name__}_{dtype}", fn, dtype) + for fn, dtype in itertools.product( + COMPARISON_OPS, ["int32", "uint32", "float16", "float32"] + ) + ) + def test_comparison(self, fn, dtype): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), + grid=1) + def kernel(x_ref, y_ref, o_ref): + o_ref[:] = fn(x_ref[...], y_ref[...]) + + x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype) + y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype) + np.testing.assert_allclose(kernel(x, y), fn(x, y)) + + def test_isnan(self): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), + grid=1) + def isnan(x_ref, o_ref): + o_ref[:] = jnp.isnan(x_ref[...]) + + x = jnp.arange(8.) + x = x.at[3].set(jnp.nan) + np.testing.assert_allclose(isnan(x), jnp.isnan(x)) + + @parameterized.parameters( + ("int32", "float32"), + ("float32", "float32"), + ) + def test_true_divide(self, dtype, out_dtype): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8,), out_dtype), + grid=1, + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...]) + + x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype) + y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype) + np.testing.assert_allclose(jnp.true_divide(x, y), kernel(x, y)) + + @parameterized.parameters("float16", "bfloat16") + def test_true_divide_unsupported(self, dtype): + if self.INTERPRET: + self.skipTest("No lowering in interpreter mode") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), dtype), + grid=1, + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...]) + + x = jnp.array([2.4, 4.2]).astype(dtype) + y = jnp.array([4.2, 2.4]).astype(dtype) + with self.assertRaises(Exception): + kernel(x, y) + + BINARY_OPS = [ + ([jnp.floor_divide], ["int32", "uint32"]), + ( + [jnp.add, jnp.subtract, jnp.multiply], + ["int16", "int32", "uint32", "float16", "float32"], + ), + ([jnp.remainder], ["int32", "uint32", "float32"]), + ( + # fmt: off + [jnp.bitwise_and, jnp.bitwise_or, jnp.bitwise_xor, + jnp.bitwise_left_shift, jnp.bitwise_right_shift], + # fmt: on + ["int32", "uint32"], + ), + ] + + @parameterized.named_parameters( + (f"{fn.__name__}_{dtype}", fn, dtype) + for args in BINARY_OPS + for fn, dtype in itertools.product(*args) + ) + def test_binary(self, f, dtype): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), dtype), grid=1 + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] 
= f(x_ref[...], y_ref[...]) + + x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype) + if (f == jnp.bitwise_left_shift): + y = jnp.array([3, 1, 4, 5, 2, 2, 2, 4]).astype(dtype) + else: + y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype) + + np.testing.assert_allclose(f(x, y), kernel(x, y)) + + @parameterized.parameters( + ((8, 4), jnp.int32, 0), + ((8, 16), jnp.float32, 1), + ((8, 16, 2), jnp.int8, 1), + ) + def test_broadcasted_iota(self, shape, dtype, dimension): + f = lambda: jax.lax.broadcasted_iota(dtype, shape, dimension) + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, dtype), grid=1 + ) + def kernel(o_ref): + o_ref[...] = f() + + np.testing.assert_allclose(f(), kernel()) + + @parameterized.parameters("float16", "bfloat16", "float32") + def test_approx_tanh(self, dtype): + if self.INTERPRET: + self.skipTest("approx_tanh is not supported in interpreter mode") + if (dtype == "bfloat16" and + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), grid=1 + ) + def kernel(x_ref, o_ref): + o_ref[...] = plgpu.approx_tanh(x_ref[...]) + + x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) + # We upcast to float32 because NumPy <2.0 does not handle custom dtypes + # properly. See https://github.com/google/jax/issues/11014. + np.testing.assert_allclose( + kernel(x).astype(jnp.float32), + jnp.tanh(x).astype(jnp.float32), + atol=5e-3, + rtol=5e-3, + ) + + def test_elementwise_inline_asm(self): + if self.INTERPRET: + self.skipTest( + "elementwise_inline_asm is not supported in interpreter mode" + ) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((256,), jnp.float16), + grid=1, + ) + def kernel(x_ref, o_ref): + [o_ref[...]] = plgpu.elementwise_inline_asm( + "tanh.approx.f16x2 $0, $1;", + args=[x_ref[...]], + constraints="=r,r", + pack=2, + result_shape_dtypes=[jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype)], + ) + + x = jnp.arange(256).astype(jnp.float16) + np.testing.assert_allclose(kernel(x), jnp.tanh(x), atol=5e-3, rtol=5e-3) + + def test_debug_print(self): + # TODO: this test flakes on gpu + if jtu.test_device_matches(["gpu"]): + self.skipTest("This test flakes on gpu") + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + grid=1, + compiler_params=dict(triton=dict(num_warps=1, num_stages=1)) + ) + def kernel(x_ref, o_ref): + pl.debug_print("It works!") + + x = jnp.array([4.2, 2.4]).astype(jnp.float32) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + jax.effects_barrier() + + self.assertIn("It works!", output()) + + def test_debug_print_with_values(self): + # TODO: this test flakes on gpu + if jtu.test_device_matches(["gpu"]): + self.skipTest("This test flakes on gpu") + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + grid=1, + compiler_params=dict(triton=dict(num_warps=1, num_stages=1)) + ) + def kernel(x_ref, o_ref): + pl.debug_print("x[0] =", x_ref[0]) + + x = jnp.array([4.2, 2.4]).astype(jnp.float32) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + jax.effects_barrier() + + self.assertIn("x[0] = 4.2", output()) + + @parameterized.parameters( + ((2, 4), (8,)), + ((2, 4), (8, 1)), + ((2, 4), (1, 8)), + ((64,), (32, 2)), + ) + def test_reshape(self, in_shape, out_shape): + @functools.partial( + self.pallas_call, + 
out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), + grid=1, + ) + def f(x_ref, o_ref): + o_ref[...] = x_ref[...].reshape(out_shape) + + x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) + expected = x.reshape(out_shape) + np.testing.assert_allclose(f(x), expected) + + @parameterized.parameters( + # fmt: off + ((), (1,)), + ((), (1, 1)), + ((2, 4), (2, 4)), + ((2, 4), (2, 4, 1)), + ((2, 4, 1), (2, 4)), + ((2, 4), (1, 2, 4)), + ((1, 2, 4), (2, 4)), + ((2, 4), (2, 1, 4)), + ((1, 2, 1, 4, 1), (2, 4)), + ((2, 4,), (1, 2, 1, 4)), + ((2, 4,), (1, 2, 4, 1)), + ((1, 2, 4, 1), (1, 2, 1, 4, 1)), + # fmt: on + ) + def test_reshape_noop_or_singleton_dims(self, in_shape, out_shape): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), + grid=1, + ) + def f(x_ref, o_ref): + o_ref[...] = x_ref[...].reshape(out_shape) + + x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) + expected = x.reshape(out_shape) + np.testing.assert_allclose(f(x), expected) + + def test_num_programs(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4,), jnp.int32), + grid=4, + ) + def kernel(o_ref): + o_ref[pl.program_id(0)] = pl.num_programs(0) + + np.testing.assert_array_equal( + kernel(), np.asarray([4, 4, 4, 4], dtype=np.int32) + ) + + def test_where_broadcasting(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4, 2, 2), jnp.float32), + grid=1, + ) + def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref): + mask = (jnp.arange(o_ref.shape[0]) == out_idx_ref[()])[:, None, None] + o_ref[...] = jnp.where(mask, x_ref[in_idx_ref[()]], 0) + + x = jnp.arange(7 * 2 * 2.0).reshape(7, 2, 2) + for ii in range(7): + for oi in range(4): + out = copyitem(x, ii, oi) + self.assertEqual((4, 2, 2), out.shape) + np.testing.assert_allclose(out[:oi], jnp.zeros_like(out[:oi])) + np.testing.assert_allclose(out[oi], x[ii]) + np.testing.assert_allclose(out[oi + 1 :], jnp.zeros_like(out[oi + 1 :])) + + @parameterized.parameters( + ((), (2,), ()), + ((1,), (2,), (0,)), + ((1, 1), (2, 2), (0, 1)), + ((), (2, 2), ()), + ) + def test_broadcast_in_dim(self, in_shape, out_shape, dims): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), + grid=1, + ) + def f(x_ref, o_ref): + x = x_ref[...] + o_ref[...] 
= jax.lax.broadcast_in_dim(x, out_shape, dims) + + x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) + expected = jax.lax.broadcast_in_dim(x, out_shape, dims) + np.testing.assert_allclose(f(x), expected) + + @parameterized.product( + size=[16, 32, 64], + dtype=["float32", "float16"], + trans_x=[False, True], + trans_y=[False, True], + ) + def test_dot(self, size, dtype, trans_x, trans_y): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((size, size), dtype), + grid=1, + ) + def dot(x_ref, y_ref, o_ref): + x = x_ref[:, :] + y = y_ref[:, :] + o_ref[:, :] = pl.dot(x, y, trans_x, trans_y).astype(o_ref.dtype) + + k1, k2 = random.split(random.key(0)) + x = random.normal(k1, (size, size), dtype=dtype) + y = random.normal(k2, (size, size), dtype=dtype) + out = dot(x, y) + expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y) + np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05) + + @parameterized.product( + size=[1, 2, 64, 129, 1021], + block_size=[1, 2, 32, 64, 128], + ) + def test_masked_load_store(self, size, block_size): + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((size,), jnp.float32)), + grid=pl.cdiv(size, block_size), + ) + def kernel(x_ref, o_ref): + idx = pl.program_id(0) * block_size + jnp.arange(block_size) + mask = idx < x_ref.shape[0] + x = pl.load(x_ref, (idx,), mask=mask) + pl.store(o_ref, (idx,), x + 1.0, mask=mask) + + key = random.key(0) + x = random.normal(key, (size,)) + np.testing.assert_allclose(kernel(x), x + 1.0, atol=1e-5, rtol=1e-5) + + def test_masked_oob_load_store_slice(self): + n = 16 + + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32)), + grid=1, + ) + def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): + x = pl.load(x_ref, (pl.dslice(start_idx_ref[()], n)), + mask=mask_ref[:], other=-1.) + pl.store(o_ref, (pl.dslice(None),), x) + + x = random.normal(random.key(0), (n,)) + slice_start = random.randint(random.key(2), (), 1, n) + indices = jnp.arange(n) + slice_start + mask = indices < n + out = masked_oob_load_store_slice(x, mask, slice_start) + o_new = jnp.where(mask, x[indices], jnp.full_like(x, -1.)) + np.testing.assert_array_equal(out, o_new) + + def test_strided_load(self): + if self.INTERPRET: + # TODO(b/329733289): Remove this once the bug is fixed. + self.skipTest("Strided load not yet supported in interpreter mode") + + # Reproducer from https://github.com/google/jax/issues/20895. + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] 
= x_ref[::4] + + x = jnp.arange(16, dtype=jnp.float32) + np.testing.assert_array_equal(kernel(x), x[::4]) + + def test_broadcasted_load_store(self): + m, n = 16, 32 + + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32)), + grid=1, + ) + def load(x_ref, o_ref): + x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :])) + pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]), x + 1.0) + + key = random.key(0) + x = random.normal(key, (m, n)) + np.testing.assert_allclose(load(x), x + 1.0, atol=1e-5, rtol=1e-5) + + @parameterized.parameters( + ((16, 32), (16,)), + ((16, 32), (32,)), + ((16, 32), (16, 31)), + ) + def test_invalid_broadcasted_load(self, x_shape, mask_shape): + if self.INTERPRET: + self.skipTest("No broadcasting checks in pl.load in interpreter mode") + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32) + ) + def kernel(x_ref, mask_ref, o_ref): + del o_ref # Unused. + pl.load(x_ref, slice(None), mask=mask_ref[:]) + + x = jnp.ones(x_shape, dtype=jnp.float32) + mask = jnp.ones(mask_shape, dtype=jnp.bool_) + # assertRaises* methods do not support inspecting the __cause__, so + # we have to check it manually. + try: + kernel(x, mask) + except Exception as e: + self.assertIn("Cannot broadcast", str(e.__cause__)) + else: + self.fail("Expected exception due to invalid broadcasting") + + def test_swap(self): + m, n = 16, 32 + + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2, + grid=1, + input_output_aliases={0: 0, 1: 1}, + ) + def swap(_, _2, x_ref, y_ref): + x = x_ref[:] + y = pl.swap(y_ref, (slice(None),), x) + x_ref[:] = y + + x = random.normal(random.key(0), (m, n)) + y = random.normal(random.key(1), (m, n)) + out = swap(x, y) + np.testing.assert_array_equal(out[0], y) + np.testing.assert_array_equal(out[1], x) + + def test_masked_swap(self): + m, n = 16, 32 + + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2, + grid=1, + input_output_aliases={0: 0, 1: 1}, + ) + def masked_swap(_, _2, mask_ref, x_ref, y_ref): + x = x_ref[:] + y = pl.swap(y_ref, (slice(None),), x, mask=mask_ref[:]) + x_ref[:] = y + + x = random.normal(random.key(0), (m, n)) + y = random.normal(random.key(1), (m, n)) + mask = random.bernoulli(random.key(2), shape=(m, n)) + out = masked_swap(x, y, mask) + np.testing.assert_array_equal(out[0], jnp.where(mask, y, x)) + np.testing.assert_array_equal(out[1], jnp.where(mask, x, y)) + + def test_masked_oob_swap_slice(self): + m, n = 32, 16 + + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32), + jax.ShapeDtypeStruct((m,), jnp.float32)), + grid=1, + input_output_aliases={0: 0, 1: 1}, + ) + def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref): + x, mask = x_ref[:], mask_ref[:] + y = pl.swap(y_ref, (pl.dslice(start_idx_ref[()], n)), x, mask=mask) + x_ref[:] = y + + x = random.normal(random.key(0), (n,)) + y = random.normal(random.key(1), (m,)) + slice_start = random.randint(random.key(2), (), m-n+1, m) + indices = jnp.arange(n) + slice_start + mask = indices < m + out = masked_oob_swap_slice(x, y, mask, slice_start) + + # the unjittable masked indexing equivalent + unmasked_idx = indices[mask] + x_new = x.at[mask].set(y[unmasked_idx]) + y_new = y.at[unmasked_idx].set(x[mask]) + np.testing.assert_array_equal(out[0], x_new) + np.testing.assert_array_equal(out[1], y_new) + + @parameterized.named_parameters( + 
("add_i32", pl.atomic_add, np.array([1, 2, 3, 4], np.int32), np.sum), + ("max_i", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max), + ("min_i32", pl.atomic_min, np.array([1, 2, 3, 4], np.int32), np.min), + ("add_f16", pl.atomic_add, np.array([1, 2, 3, 4], np.float16), np.sum), + ("add_f32", pl.atomic_add, np.array([1, 2, 3, 4], np.float32), np.sum), + ("max_f32", pl.atomic_max, np.array([1, 2, 3, 4], np.float32), np.max), + ("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min), + ) + def test_scalar_atomic(self, op, value, numpy_op): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((), value.dtype), + grid=value.shape[0], + input_output_aliases={1: 0}, + ) + def atomic_kernel(x_ref, _, o_ref): + pid = pl.program_id(axis=0) + op(o_ref, (), x_ref[pid]) - out = jax.vmap(add_one, in_axes=1, out_axes=1)(x) - out_ref = jax.vmap(add_one_ref, in_axes=1, out_axes=1)(x) - np.testing.assert_allclose(out, out_ref) + if op == pl.atomic_add: + neutral = np.array(0, dtype=value.dtype) + elif op == pl.atomic_max: + if np.issubdtype(value.dtype, np.integer): + neutral = np.array(np.iinfo(value.dtype).min, value.dtype) + else: + neutral = np.array(-float("inf"), value.dtype) + elif op == pl.atomic_min: + if np.issubdtype(value.dtype, np.integer): + neutral = np.array(np.iinfo(value.dtype).max, value.dtype) + else: + neutral = np.array(float("inf"), value.dtype) + elif op == pl.atomic_or: + neutral = np.array(False, value.dtype) + else: + raise NotImplementedError() + out = atomic_kernel(value, neutral) + np.testing.assert_allclose(out, numpy_op(value)) - out = jax.vmap(add_one, in_axes=1, out_axes=0)(x) - out_ref = jax.vmap(add_one_ref, in_axes=1, out_axes=0)(x) - np.testing.assert_allclose(out, out_ref) + @parameterized.parameters((0,), (1,)) + def test_array_atomic_add(self, axis): + m, n = 32, 8 + if axis == 0: + grid = m + else: + grid = n + out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), jnp.float32) - def test_double_vmap_of_slicing_kernel_different_axes(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), - debug=False, - grid=(4,)) - def sin(x_ref, o_ref): - i = pl.program_id(0) - o_ref[i] = jnp.sin(x_ref[i]) - sin_ref = jnp.sin - x = jnp.arange(64.).reshape((8, 4, 2)) + self.pallas_call, + out_shape=out_shape, + grid=grid, + input_output_aliases={1: 0}, + ) + def reduce(x_ref, _, y_ref): + i = pl.program_id(axis=0) + if axis == 0: + idx = (i, jnp.arange(n)) + else: + idx = (jnp.arange(m), i) + x = pl.load(x_ref, idx) + pl.atomic_add(y_ref, (jnp.arange(y.shape[0]),), x) - out = jax.vmap(jax.vmap(sin, in_axes=1), in_axes=0)(x) - out_ref = jax.vmap(jax.vmap(sin_ref, in_axes=1), in_axes=0)(x) - np.testing.assert_allclose(out, out_ref, atol=1e-3, rtol=1e-3) + x = random.normal(random.key(0), (m, n)) + y = jnp.zeros(out_shape.shape, out_shape.dtype) + y = reduce(x, y) + y_ref = np.sum(x, axis=axis) + np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2) - def test_small_large_vmap(self): - # Catches https://github.com/google/jax/issues/18361 + @parameterized.parameters( + (0, 0, 1), + (0, 1, 1), + (1, 0, 1), + (1, 1, 1), + (2, 1, 1), + (2, 1, 1), + ) + def test_atomic_cas(self, init_value, cmp, new_value): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), - debug=False, - grid=(2,)) - def add_one(x_ref, o_ref): - o_ref[()] = x_ref[()] + 1 + self.pallas_call, out_shape=( + jax.ShapeDtypeStruct((), jnp.int32), + jax.ShapeDtypeStruct((), jnp.int32)), + 
input_output_aliases={0: 0}) + def swap(_, lock_ref, out_ref): + out_ref[()] = pl.atomic_cas(lock_ref, cmp, new_value) - add_one = jax.vmap(jax.vmap(add_one)) - add_one_ref = lambda x: x + 1 + lock, out = swap(init_value) + np.testing.assert_allclose(lock, new_value if cmp == init_value else + init_value) + np.testing.assert_allclose(out, init_value) - x = random.randint(random.key(0), (4, 65536, 2), 0, 10000) + @parameterized.parameters(1, 2, 3, 4, 8) + def test_atomic_counter(self, num_threads): + if self.INTERPRET: + self.skipTest("While loop not supported in interpreter mode.") - out = add_one(x) - out_ref = add_one_ref(x) + @functools.partial( + self.pallas_call, out_shape=( + jax.ShapeDtypeStruct((), jnp.int32), + jax.ShapeDtypeStruct((), jnp.int32)), + input_output_aliases={0: 0, 1: 1}, + grid=(num_threads,)) + def increment(_, __, lock_ref, counter_ref): + def _cond(_): + return pl.atomic_cas(lock_ref, 0, 1) == 1 + lax.while_loop(_cond, lambda a: a, 0) + counter_ref[...] += 1 + pl.atomic_xchg(lock_ref, (), 0) - np.testing.assert_allclose(out, out_ref) + lock, count = increment(0, 0) + np.testing.assert_allclose(lock, 0) + np.testing.assert_allclose(count, num_threads) - def test_small_small_large_vmap(self): - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), - debug=False, - grid=(2,)) - def add_one(x_ref, o_ref): - o_ref[()] = x_ref[()] + 1 + @parameterized.parameters(False, True) + def test_reduce_only_dim(self, use_store): + m = 32 + x = random.normal(random.key(0), (m,), dtype=jnp.float32) + out_shape = jax.ShapeDtypeStruct((), x.dtype) - add_one = jax.vmap(jax.vmap(jax.vmap(add_one))) - add_one_ref = lambda x: x + 1 + @functools.partial( + self.pallas_call, out_shape=out_shape, grid=1 + ) + def reduce(x_ref, y_ref): + x = pl.load(x_ref, (jnp.arange(m),)) + y = jnp.sum(x, axis=-1) + if use_store: + pl.store(y_ref, (), y) + else: + y_ref[...] 
= y - x = random.randint(random.key(0), (2, 2, 65536, 2), 0, 10000) + y = reduce(x) + y_ref = jnp.sum(x, axis=-1) + np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2) - out = add_one(x) - out_ref = add_one_ref(x) + @parameterized.named_parameters(*[ + (f"{op_name}_{dtype}_{axis}", op, dtype, axis) + for op_name, op in [ + ("add", jnp.sum), + ("max", jnp.max), + ("min", jnp.min), + ("argmax", jnp.argmax), + ("argmin", jnp.argmin), + ] + for axis in [0, 1, (1,), (0, 1)] + for dtype in ["float16", "float32", "int32", "uint32"] + if isinstance(axis, int) or "arg" not in op_name + ]) + def test_array_reduce(self, op, dtype, axis): + m, n = 32, 8 + out_dtype = dtype + if op in {jnp.argmin, jnp.argmax}: + out_dtype = jnp.int32 - np.testing.assert_allclose(out, out_ref) + def make_x(key): + if jnp.issubdtype(dtype, jnp.integer): + return random.permutation( + key, jnp.arange(m * n, dtype=dtype), independent=True + ).reshape(m, n) + else: + return random.normal(key, (m, n), dtype=dtype) + out_shape = jax.ShapeDtypeStruct( + op(make_x(random.key(0)), axis=axis).shape, out_dtype + ) + if isinstance(axis, int): + grid = tuple(a for i, a in enumerate((m, n)) if i != axis) + else: + grid = tuple(a for i, a in enumerate((m, n)) if i not in axis) -class PallasCallInterpreterVmapTest(PallasCallVmapTest): - INTERPRET = True + @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid) + def reduce(x_ref, y_ref): + x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None])) + y = op(x, axis=axis) + pl.store(y_ref, tuple(jnp.arange(d) for d in y.shape), y) -class PallasOpsTest(PallasTest): + for i, key in enumerate(random.split(random.key(0), 20)): + x = make_x(key) + y = reduce(x) + y_ref = op(x, axis=axis) + np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) - def test_pow_weak_dtype(self): - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32)) - def square(x_ref, o_ref): - o_ref[()] = x_ref[()]**2.0 + @parameterized.product( + axis=[0, 1], + dtype=["float16", "float32", "int32", "uint32"], + ) + def test_cumsum(self, dtype, axis): + m, n = 32, 8 + out_dtype = dtype - x = jnp.array(42.0) - np.testing.assert_allclose(square(x), x*x) + def make_x(key): + if jnp.issubdtype(dtype, jnp.integer): + return random.permutation( + key, jnp.arange(m * n, dtype=dtype), independent=True + ).reshape(m, n) + else: + return random.normal(key, (m, n), dtype=dtype) - def test_ne(self): - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), - grid=1) - def ne(x_ref, y_ref, o_ref): - o_ref[:] = x_ref[...] != y_ref[...] + out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) + grid = () - x = jnp.ones(8, dtype=jnp.int32) - y = jnp.arange(8, dtype=jnp.int32) - not_equal = ne(x, y) - np.testing.assert_allclose(not_equal, x != y) + @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid) + def reduce(x_ref, y_ref): + x = x_ref[...] + y_ref[...] = jnp.cumsum(x, axis=axis) - def test_isnan(self): - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), - grid=1) - def isnan(x_ref, o_ref): - o_ref[:] = jnp.isnan(x_ref[...]) + for i, key in enumerate(random.split(random.key(0), 20)): + x = make_x(key) + y = reduce(x) + y_ref = jnp.cumsum(x, axis=axis) + np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) - x = jnp.arange(8.) 
- x = x.at[3].set(jnp.nan) - np.testing.assert_allclose(isnan(x), jnp.isnan(x)) -class PallasOpsInterpretTest(PallasOpsTest): +class PallasOpsInterpreterTest(PallasOpsTest): INTERPRET = True + def setUp(self): + super().setUp() + if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: + # TODO: assertion failures on CPU in 64-bit mode + self.skipTest("On CPU the test works only in 32-bit mode") + + class PallasPrimitivesTest(PallasTest): @parameterized.parameters(*[ @@ -1606,8 +2042,19 @@ def body(x_ref): lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) self.assertIn(expected, jaxpr.pretty_print(use_color=False)) + +class PallasPrimitivesInterpreterTest(PallasPrimitivesTest): + INTERPRET = True + + class FusedAttentionTest(PallasTest): + def setUp(self): + super().setUp() + # TODO: fix for other platforms. On TPU if fails even in interpret mode. + if jtu.test_device_matches(["cpu", "tpu"]): + self.skipTest("Works only on GPU") + @parameterized.named_parameters( *[ ( @@ -1658,10 +2105,6 @@ def test_fused_attention_fwd( use_segment_ids, kwargs, ): - if not self.check_gpu_capability_at_least(80): - raise unittest.SkipTest( - "Fused attention only works on GPUs with capability >= sm80") - k1, k2, k3 = random.split(random.key(0), 3) q = random.normal( k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 @@ -1735,9 +2178,6 @@ def impl(q, k, v): def test_fused_attention_bwd( self, batch_size, seq_len, num_heads, head_dim, causal, use_segment_ids ): - if not self.check_gpu_capability_at_least(80): - raise unittest.SkipTest( - "Fused attention only works on GPUs with capability >= sm80") k1, k2, k3 = random.split(random.key(0), 3) q = random.normal( k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 @@ -1769,19 +2209,23 @@ def f_ref(q, k, v): np.testing.assert_allclose(dv, dv_ref, atol=0.05) -class FusedAttentionInterpreterTest(PallasTest): +class FusedAttentionInterpreterTest(FusedAttentionTest): INTERPRET = True + class FusedLayerNormTest(PallasTest): + def setUp(self): + super().setUp() + # TODO: fix for other platforms; on TPU fails even in interpret mode + if jtu.test_device_matches(["cpu", "tpu"]): + self.skipTest("Works only on GPU") + @parameterized.parameters(*[ (1, 384, 192), (2, 384, 192), ]) def test_fused_layernorm_fwd(self, batch_size, seq_len, embed_dim): - if not self.check_gpu_capability_at_least(70): - raise unittest.SkipTest( - "Fused layernorm only works on GPUs with capability >= sm70") k1, k2, k3 = random.split(random.key(0), 3) x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32) w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32) @@ -1796,9 +2240,6 @@ def test_fused_layernorm_fwd(self, batch_size, seq_len, embed_dim): (2, 384, 192), ]) def test_fused_layernorm_bwd(self, batch_size, seq_len, embed_dim): - if not self.check_gpu_capability_at_least(70): - raise unittest.SkipTest( - "Fused layernorm only works on GPUs with capability >= sm70") k1, k2, k3 = random.split(random.key(0), 3) x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32) w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32) @@ -1817,20 +2258,23 @@ def f_ref(x, w, b): np.testing.assert_allclose(db, db_ref, rtol=1e-2, atol=1e-2) -class FusedLayerNormInterpreterTest(PallasTest): +class FusedLayerNormInterpreterTest(FusedLayerNormTest): INTERPRET = True class RmsNormTest(PallasTest): + def setUp(self): + super().setUp() + # TODO: fix for other platforms + if jtu.test_device_matches(["cpu", "tpu"]): + 
self.skipTest("Works only on GPU") + @parameterized.parameters(*[ (1, 384, 192), (2, 384, 192), ]) def test_rms_fwd(self, batch_size, seq_len, embed_dim): - if not self.check_gpu_capability_at_least(70): - raise unittest.SkipTest( - "Rms norm only works on GPUs with capability >= sm70") k1, k2, k3 = random.split(random.key(0), 3) x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32) w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32) @@ -1845,9 +2289,6 @@ def test_rms_fwd(self, batch_size, seq_len, embed_dim): (2, 384, 192), ]) def test_rms_norm_bwd(self, batch_size, seq_len, embed_dim): - if not self.check_gpu_capability_at_least(70): - raise unittest.SkipTest( - "Rms norm only works on GPUs with capability >= sm70") k1, k2, k3 = random.split(random.key(0), 3) x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32) w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32) @@ -1865,21 +2306,24 @@ def f_ref(x, w, b): np.testing.assert_allclose(dw, dw_ref, rtol=1e-2, atol=1e-2) np.testing.assert_allclose(db, db_ref, rtol=1e-2, atol=1e-2) -class RmsNormInterpreterTest(PallasTest): + +class RmsNormInterpreterTest(RmsNormTest): INTERPRET = True + class SoftmaxTest(PallasTest): - @parameterized.parameters( - (shape, dtype) - for shape in [(1024, 125), (4, 1024, 125)] - for dtype in (jnp.bfloat16, jnp.float16, jnp.float32) + def setUp(self): + super().setUp() + # TODO: fix for other platforms + if jtu.test_device_matches(["cpu", "tpu"]): + self.skipTest("Works only on GPU") + + @parameterized.product( + shape=[(1024, 125), (4, 1024, 125)], + dtype=[jnp.bfloat16, jnp.float16, jnp.float32] ) def test_softmax(self, shape, dtype): - # TODO(bchetioui): add Triton bug reference when filed - if dtype == jnp.bfloat16: - raise absltest.SkipTest("Disabled due to Triton lowering bug") - x = jax.random.normal(random.key(0), shape, dtype=dtype) atol, rtol = { @@ -1888,16 +2332,207 @@ def test_softmax(self, shape, dtype): jnp.float32: (1e-7, 1e-6), }[dtype] + # We upcast to float32 because NumPy <2.0 does not handle custom dtypes + # properly. See https://github.com/google/jax/issues/11014. np.testing.assert_allclose( - softmax.softmax(x, axis=-1), - jax.nn.softmax(x, axis=-1), + softmax.softmax(x, axis=-1).astype(jnp.float32), + jax.nn.softmax(x, axis=-1).astype(jnp.float32), atol=atol, rtol=rtol, ) -class SoftmaxInterpreterTest(PallasTest): +class SoftmaxInterpreterTest(SoftmaxTest): INTERPRET = True + +class PallasOutOfBoundsInterpreterTest(PallasTest): + + INTERPRET: bool = True + + def test_interpret_mode_out_of_bounds_access(self): + block_size = 32 + dtype = jnp.float32 + # Create input tensors which require a reduction along an axis + # not divisible by block_size. + x = jax.random.normal(jax.random.key(0), + (block_size, block_size + 1), + dtype=dtype) + y = jax.random.normal(jax.random.key(1), + (block_size + 1, block_size), + dtype=dtype) + expected = x @ y + + in_specs = [ + pl.BlockSpec((block_size, block_size), lambda i, j, k: (i, k)), + pl.BlockSpec((block_size, block_size), lambda i, j, k: (k, j)), + ] + out_spec = pl.BlockSpec((block_size, block_size), lambda i, j, k: (i, j)) + + def _unmasked_matmul_kernel(x_ref, y_ref, o_ref): + @pl.when(pl.program_id(2) == 0) + def _(): + o_ref[...] = jnp.zeros_like(o_ref) + + o_ref[...] += x_ref[...] @ y_ref[...] 
+ + out = self.pallas_call( + _unmasked_matmul_kernel, + out_shape=expected, + grid=(1, 1, 2), + in_specs=in_specs, + out_specs=out_spec)(x, y) + + # With a naive matmul implementation, using uninitialized values (NaN) will + # cause the overall output to be NaN. + with self.subTest('UnmaskedIsNaN'): + np.testing.assert_allclose( + np.isnan(out), jnp.ones_like(out, dtype=jnp.bool_) + ) + + def _masked_matmul_kernel(x_ref, y_ref, o_ref): + @pl.when(pl.program_id(2) == 0) + def _(): + o_ref[:, :] = jnp.zeros_like(o_ref) + + # Create a validity mask for OOB values. + num_valid = x.shape[1] - pl.program_id(2) * block_size + num_valid = jnp.minimum(num_valid, block_size) + mask = jnp.tril(jnp.ones_like(x_ref[:, :]))[num_valid - 1][jnp.newaxis, :] + mask = jnp.repeat(mask, block_size, axis=0) + + # Mask and multiply. + masked_x = jnp.where(mask, x_ref[:, :], 0.0) + masked_y = jnp.where(mask.T, y_ref[:, :], 0.0) + o_ref[:, :] += masked_x @ masked_y + + out = self.pallas_call( + _masked_matmul_kernel, + out_shape=expected, + grid=(1, 1, 2), + in_specs=in_specs, + out_specs=out_spec)(x, y) + + # TODO(justinfu): This test has low precision on GPU. Improve precision. + if jtu.test_device_matches(["gpu"]): + atol = 1e-2 + else: + atol = 1e-5 + + # With a masked matmul implementation, uninitialized values will be + # masked before computation. This should return the correct result. + with self.subTest('MaskedOutputIsCorrect'): + np.testing.assert_allclose(out, expected, atol=atol) + + +class PallasCheckifyInterpreterTest(PallasTest): + # TODO(b/346651778): Support non-interpret mode checkify. + INTERPRET: bool = True + + def test_no_checkify(self,): + def kernel(y_ref): + y_ref[...] = jnp.zeros_like(y_ref[...]) + out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32) + pallas_call = self.pallas_call(kernel, + out_shape=out_shape) + checked_call = checkify.checkify(pallas_call) + err, result = checked_call() + err.throw() # Should not raise. + np.testing.assert_allclose(result, jnp.zeros_like(result)) + + def test_does_not_clobber_previous_error(self,): + def kernel(y_ref): + y_ref[...] = jnp.zeros_like(y_ref[...]) + checkify.check(False, "error in kernel") + out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32) + pallas_call = self.pallas_call(kernel, + out_shape=out_shape) + def error_before_call(): + checkify.check(False, "error before call") + return pallas_call() + checked_call = checkify.checkify(error_before_call) + err, result = checked_call() + with self.assertRaisesRegex( + checkify.JaxRuntimeError, "error before call"): + err.throw() + np.testing.assert_allclose(result, jnp.zeros_like(result)) + + @parameterized.parameters((False,), (True,)) + def test_trivial_check(self, assert_cond): + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] + checkify.check(assert_cond, "pallas check failed") + input = jnp.arange(4, dtype=jnp.int32) + out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype) + pallas_call = self.pallas_call(kernel, + out_shape=out_shape) + checked_call = checkify.checkify(pallas_call) + err, result = checked_call(input) + if not assert_cond: + with self.assertRaisesRegex( + checkify.JaxRuntimeError, "pallas check failed"): + err.throw() + np.testing.assert_allclose(result, input) + + def test_nan_error(self): + def kernel(x_ref, y_ref): + y_ref[...] 
= jnp.log(x_ref[...]) + input = jnp.arange(4, dtype=jnp.float32) - 2 + out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype) + pallas_call = self.pallas_call(kernel, + out_shape=out_shape) + checked_call = checkify.checkify(pallas_call, + errors=checkify.all_checks) + err, result = checked_call(input) + with self.assertRaisesRegex( + checkify.JaxRuntimeError, "nan generated by primitive: log"): + err.throw() + is_nan = jnp.isnan(result) + np.testing.assert_allclose(is_nan, input < 0) + + def test_nan_error_with_assertion(self): + # TODO(b/346842088): Fix check asserts clobbering other errors. + self.skipTest('Known failure.') + # Test NaN error is not clobbered by an assertion failure + def kernel(x_ref, y_ref): + y_ref[...] = jnp.log(x_ref[...]) + checkify.check(False, "do not raise") + input = jnp.arange(4, dtype=jnp.float32) - 10 + out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype) + pallas_call = self.pallas_call(kernel, + out_shape=out_shape) + checked_call = checkify.checkify(pallas_call, + errors=checkify.all_checks) + err, _ = checked_call(input) + with self.assertRaisesRegex( + checkify.JaxRuntimeError, "nan generated by primitive: log"): + err.throw() + + @parameterized.parameters((5, 0), (8, 3), (4, 3)) + def test_checkify_returns_first_error_in_grid( + self, num_loops, fail_iteration): + # Check that checkify returns the first error that occurs + # TODO(justinfu): This test doesn't make sense on GPU, where threads run + # in parallel. Update checkify to return a grid of errors. + def kernel(x_ref, _): + value = jnp.squeeze(x_ref[...]) + checkify.check( + value < fail_iteration, "failed on loop {itr}", itr=value) + input_arr = jnp.arange(num_loops, dtype=jnp.float32) + in_specs = [pl.BlockSpec((1,), lambda x: (x,))] + out_shape = jax.ShapeDtypeStruct((1,), dtype=jnp.float32) + pallas_call = self.pallas_call(kernel, + grid=(num_loops,), + in_specs=in_specs, + out_shape=out_shape) + + checked_call = checkify.checkify(pallas_call, + errors=checkify.all_checks) + err, _ = checked_call(input_arr) + with self.assertRaisesRegex( + checkify.JaxRuntimeError, f"failed on loop {fail_iteration}"): + err.throw() + + if __name__ == "__main__": absltest.main() diff --git a/tests/pallas/splash_attention_kernel_test.py b/tests/pallas/splash_attention_kernel_test.py index c874ce9925a0..e6132a1966a3 100644 --- a/tests/pallas/splash_attention_kernel_test.py +++ b/tests/pallas/splash_attention_kernel_test.py @@ -15,9 +15,10 @@ """Tests for splash_attention.""" from __future__ import annotations +from collections.abc import Callable import dataclasses import functools -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar import unittest from absl.testing import absltest @@ -224,7 +225,13 @@ def sequence_length_strategy(draw: Draw) -> tuple[int, int]: def attention_strategy(draw: Draw) -> tuple[int, int, int, np.dtype]: q_seq_len, kv_seq_len = draw(sequence_length_strategy()) head_dim = draw(hps.sampled_from([128, 256])) - dtype = draw(hps.sampled_from([np.dtype("float32"), np.dtype(jnp.bfloat16)])) + if q_seq_len >= 4096 and kv_seq_len >= 4096: + # Do not draw bfloat16 on longer sequence lengths, as this increases + # the risk of numerical precision errors causing false positives in + # tests. 
+ dtype = np.dtype("float32") + else: + dtype = draw(hps.sampled_from([np.dtype("float32"), np.dtype(jnp.bfloat16)])) return q_seq_len, kv_seq_len, head_dim, dtype @@ -301,13 +308,14 @@ def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]: class AttentionTest(jtu.JaxTestCase): def setUp(self): - super().setUp() if not jtu.test_device_matches(["tpu"]): self.skipTest("Need TPU devices") # TODO(b/327487669): selectively re-enable tests that works on TPU v3. if not jtu.is_device_tpu_at_least(4): self.skipTest("Not supported on TPU generations <= 3") + super().setUp() + def _assert_allclose(self, x, y, **kwargs): if x.dtype == np.dtype(jnp.bfloat16): x = x.astype(np.float32) @@ -353,7 +361,7 @@ def test_splash_attention(self, is_mqa, is_segmented, data): attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy()) masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) - mask = mask_lib.MultiHeadMask(tuple((m.get_mask() for m in masks))) + mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: @@ -414,7 +422,7 @@ def test_splash_attention_fwd( segment_ids = data.draw(segment_ids_strategy(q_seq_len)) attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy()) masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) - mask = mask_lib.MultiHeadMask(tuple((m.get_mask() for m in masks))) + mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: attn_ref = splash.make_masked_mqa_reference(mask) @@ -510,21 +518,18 @@ def test_splash_attention_custom_bwd(self, is_segmented, data): atols["dk"] = 0.09 else: raise NotImplementedError - with self.subTest("dv"): - self._assert_allclose( - dv_vanilla, dv_ref, atol=atols_v["dv"], rtol=rtols_v["dv"] - ) - self._assert_allclose(dv, dv_ref, atol=atols["dv"], rtol=rtols["dv"]) - with self.subTest("dq"): - self._assert_allclose( - dq_vanilla, dq_ref, atol=atols_v["dq"], rtol=rtols_v["dq"] - ) - self._assert_allclose(dq, dq_ref, atol=atols["dq"], rtol=rtols["dq"]) - with self.subTest("dk"): - self._assert_allclose( - dk_vanilla, dk_ref, atol=atols_v["dk"], rtol=rtols_v["dk"] - ) - self._assert_allclose(dk, dk_ref, atol=atols["dk"], rtol=rtols["dk"]) + self._assert_allclose( + dv_vanilla, dv_ref, atol=atols_v["dv"], rtol=rtols_v["dv"] + ) + self._assert_allclose(dv, dv_ref, atol=atols["dv"], rtol=rtols["dv"]) + self._assert_allclose( + dq_vanilla, dq_ref, atol=atols_v["dq"], rtol=rtols_v["dq"] + ) + self._assert_allclose(dq, dq_ref, atol=atols["dq"], rtol=rtols["dq"]) + self._assert_allclose( + dk_vanilla, dk_ref, atol=atols_v["dk"], rtol=rtols_v["dk"] + ) + self._assert_allclose(dk, dk_ref, atol=atols["dk"], rtol=rtols["dk"]) @parameterized.product( is_mqa=(False, True), diff --git a/tests/pallas/splash_attention_mask_test.py b/tests/pallas/splash_attention_mask_test.py index a408872100c4..ce7d8fd09182 100644 --- a/tests/pallas/splash_attention_mask_test.py +++ b/tests/pallas/splash_attention_mask_test.py @@ -15,7 +15,6 @@ """Tests for splash_attention_masks.""" from __future__ import annotations -from typing import List from absl.testing import absltest from absl.testing import parameterized import jax @@ -733,7 +732,7 @@ def _expected_local_mask_next(self, mask_base_index: int): _expected_local_mask_next_dkv = _expected_local_mask_next - def _stack(self, arrays: List[np.ndarray]) -> np.ndarray: + def _stack(self, arrays: 
list[np.ndarray]) -> np.ndarray: return np.stack(arrays, axis=0) # For each test, check both the lazy and the dense versions of the mask. diff --git a/tests/pallas/tpu/BUILD b/tests/pallas/tpu/BUILD new file mode 100644 index 000000000000..4b0ffa941510 --- /dev/null +++ b/tests/pallas/tpu/BUILD @@ -0,0 +1,47 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "jax_generate_backend_suites", + "jax_test", + "py_deps", +) + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//visibility:private"], +) + +jax_generate_backend_suites() + +jax_test( + name = "pallas_random_test", + srcs = [ + "pallas_random_test.py", + ], + disable_backends = [ + "cpu", + "gpu", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_tpu", + "//jax/_src/pallas/mosaic:random", + "//third_party/py/absl/testing:absltest", + "//third_party/py/absl/testing:parameterized", + ] + py_deps("numpy"), +) diff --git a/tests/pallas/tpu/pallas_random_test.py b/tests/pallas/tpu/pallas_random_test.py new file mode 100644 index 000000000000..64e2184e925d --- /dev/null +++ b/tests/pallas/tpu/pallas_random_test.py @@ -0,0 +1,211 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
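# The new test file that follows exercises the Pallas TPU PRNG primitives
# (pltpu.prng_seed / pltpu.prng_random_bits and the plrandom key helpers).
# The basic pattern the tests build on, seeding the on-core generator and
# then filling the output block with raw random bits, looks roughly like the
# sketch below (requires a TPU backend; the (8, 128) shape is an arbitrary
# example, not part of the patch):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def random_bits_kernel(o_ref):
    pltpu.prng_seed(0)  # seed the per-core PRNG
    o_ref[...] = pltpu.prng_random_bits(o_ref[...].shape)

bits = pl.pallas_call(
    random_bits_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32),
)()
# Re-running with the same seed should reproduce `bits` exactly, which is what
# test_seeded_reproducibility below asserts.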
+"""Tests for random ops in Pallas + Mosaic.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import random as jax_random +from jax._src import test_util as jtu +from jax._src.pallas.mosaic import random as plrandom +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() + + +class PRNGTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Need TPU devices") + super().setUp() + + def test_to_pallas_key_under_vmap(self): + key = jax.random.key(42, impl="rbg") + key = jax.random.split(key, 10) + batched_key = plrandom.to_pallas_key(key) + batched_key_data = jax.random.key_data(batched_key) + vmapped_key = jax.vmap(plrandom.to_pallas_key)(key) + vmapped_key_data = jax.random.key_data(vmapped_key) + np.testing.assert_array_equal(batched_key_data, vmapped_key_data) + + def test_pallas_key_raise_not_implemented_outside_of_kernel(self): + key = jax_random.key(0, impl="rbg") + pallas_key = plrandom.to_pallas_key(key) + # Using a pallas key outside of a kernel should raise an error when + # trying to lower TPU-specific ops to XLA. + # TODO(justinfu): Make this error more specific to pallas PRNG usage. + with self.assertRaisesRegex(NotImplementedError, + "MLIR translation rule .* not found"): + jax.random.uniform( + pallas_key, shape=(1,), minval=0.0, maxval=1.0) + + def test_seeded_reproducibility(self): + # Test whether generating random bits with the same seed + # produces the same result (and different seeds produce + # different results). + def seeded_body(seed: int): + def body(o_ref): + pltpu.prng_seed(seed) + o_ref[...] = pltpu.prng_random_bits(o_ref[...].shape) + return body + + out = jax.ShapeDtypeStruct((8, 128), jnp.int32) + result_1a = pl.pallas_call(seeded_body(0), out_shape=out)() + result_1b = pl.pallas_call(seeded_body(0), out_shape=out)() + result_2 = pl.pallas_call(seeded_body(1), out_shape=out)() + with self.subTest("same_seed_same_result"): + np.testing.assert_array_equal(result_1a, result_1b) + with self.subTest("diff_seed_diff_result"): + np.testing.assert_array_compare(np.not_equal, result_1a, result_2) + + @parameterized.parameters( + ((32, 256),), + ((8, 16),), + ) + def test_prng_non_vreg_shape_output(self, shape): + # Tests that RNG generation works with output shapes + # not equal to a native-sized VREG. + # This test makes sure that vector layout tiling + # is implemented correctly. + def body(o_ref): + pltpu.prng_seed(0) + samples = pltpu.prng_random_bits(o_ref[...].shape) + o_ref[...] = samples + + o_shape = jax.ShapeDtypeStruct(shape, jnp.int32) + result = pl.pallas_call(body, out_shape=o_shape)() + # Check that random_bits generates (mostly) unique values. + unique_frac = float(len(jnp.unique(result))) / np.prod(shape) + self.assertGreater(unique_frac, 0.99) + self.assertLessEqual(jnp.max(result), np.iinfo(jnp.int32).max) + self.assertGreaterEqual(jnp.min(result), np.iinfo(jnp.int32).min) + + def test_stateful_uniform_sample(self): + # Test stateful RNG using the jax.random API wrappers. + def body(key_ref, o_ref): + plrandom.set_seed(key_ref[...]) + o_ref[...] 
= plrandom.uniform( + shape=o_ref[...].shape, minval=0.0, maxval=1.0) + + rbg_key = jax_random.key(0, impl="rbg") + key = plrandom.to_pallas_key(rbg_key) + o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) + result = pl.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + out_shape=o_shape, + )(key) + self.assertGreaterEqual(jnp.min(result), 0) + self.assertLessEqual(jnp.max(result), 1.0) + + def test_stateless_uniform_sample(self): + # Test keyed RNG using the jax.random API. + def body(key_ref, o_ref): + o_ref[...] = jax_random.uniform( + key_ref[...], shape=o_ref[...].shape, minval=0.0, maxval=1.0 + ) + + rbg_key = jax_random.key(0, impl="rbg") + key = plrandom.to_pallas_key(rbg_key) + o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) + result = pl.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + out_shape=o_shape, + )(key) + self.assertGreaterEqual(jnp.min(result), 0) + self.assertLessEqual(jnp.max(result), 1.0) + + def test_fold_in(self): + # Test that folding in a value results in different random numbers. + def body(key_ref, o_ref): + key = key_ref[...] + o_ref[0, ...] = jax_random.uniform( + key, shape=o_ref[0, ...].shape, minval=0.0, maxval=1.0 + ) + + key = jax_random.fold_in(key, 2) + o_ref[1, ...] = jax_random.uniform( + key, shape=o_ref[1, ...].shape, minval=0.0, maxval=1.0 + ) + + rbg_key = jax_random.key(0, impl="rbg") + key = plrandom.to_pallas_key(rbg_key) + o_shape = jax.ShapeDtypeStruct((2, 8, 128), jnp.float32) + result = pl.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + out_shape=o_shape, + )(key) + result_a = result[0] + result_b = result[1] + np.testing.assert_array_compare(np.not_equal, result_a, result_b) + + +class BlockInvarianceTest(parameterized.TestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Need TPU devices") + super().setUp() + + def test_block_invariance(self): + + def make_kernel_body(index_map): + def body(key_ref, o_ref): + key = key_ref[0, 0] + samples = plrandom.sample_block( + jax.random.uniform, + key, + block_size=o_ref[...].shape, + tile_size=(16, 128), + total_size=(64, 512), + block_index=index_map(pl.program_id(0), pl.program_id(1)), + minval=0.0, + maxval=1.0) + o_ref[...] 
= samples + return body + + global_key = jax_random.key(0, impl="pallas_tpu") + o_shape = jnp.ones((64, 512), dtype=jnp.float32) + key_spec = pl.BlockSpec( + (1, 1), lambda i, j: (0, 0), memory_space=pltpu.TPUMemorySpace.SMEM + ) + out_spec = pl.BlockSpec((16, 128), lambda i, j: (i, j)) + result_16x128 = pl.pallas_call( + make_kernel_body(index_map=lambda i, j: (i, j)), + out_shape=o_shape, + in_specs=[key_spec], + out_specs=out_spec, + grid=(4, 4), + )(global_key) + + out_spec = pl.BlockSpec((32, 256), lambda i, j: (j, i)) + result_32x256 = pl.pallas_call( + make_kernel_body(index_map=lambda i, j: (j, i)), + in_specs=[key_spec], + out_shape=o_shape, + out_specs=out_spec, + grid=(2, 2), + )(global_key) + np.testing.assert_array_equal(result_16x128, result_32x256) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 466da7f27067..7dc015c90bca 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -18,39 +18,250 @@ import math import os import tempfile +import unittest from absl.testing import absltest import jax -from jax import config +from jax._src import config +from jax._src import profiler +from jax._src import pjit +from jax._src import monitoring from jax._src import test_util as jtu -from jax.sharding import NamedSharding +from jax._src import api from jax.experimental import profiler as exp_profiler import jax.numpy as jnp -from jax.sharding import PartitionSpec as P +from jax.sharding import NamedSharding, PartitionSpec +from jax._src import compilation_cache as cc import numpy as np -config.parse_flags_with_absl() +from jax.experimental.serialize_executable import ( + deserialize_and_load, + serialize, +) + +jax.config.parse_flags_with_absl() @jtu.pytest_mark_if_available('multiaccelerator') class PgleTest(jtu.JaxTestCase): + def setUp(self): + super().setUp() + cc.reset_cache() + + def tearDown(self): + cc.reset_cache() + super().tearDown() + + @unittest.skip("Test failing in CI") + def testPGLEProfilerGetFDOProfile(self): + mesh = jtu.create_global_mesh((2,), ('x',)) + + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + ) + def f(x, y): + return x @ y + + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + y = x + 1 + + with config.pgle_profiling_runs(0): + f_lowered = f.lower(x, y) + compiled = f_lowered.compile() + + pgle_profiler = profiler.PGLEProfiler(1, 90) + with config.enable_pgle(False): + with profiler.PGLEProfiler.trace(pgle_profiler): + compiled(x, y) + + fdo_profile = pgle_profiler.consume_fdo_profile() + self.assertIsNotNone(fdo_profile) + self.assertIn(b'custom', fdo_profile) + + @unittest.skip("Test failing in CI") + def testPGLEProfilerGetFDOProfileLarge(self): + mesh = jtu.create_global_mesh((2,), ('x',)) + its = 500 + + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + ) + def f(x): + agg = x + for _ in range(its): + agg = agg @ x + return agg + + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + + with config.pgle_profiling_runs(0): + f_lowered = f.lower(x) + f_compiled = f_lowered.compile() + + pgle_profiler = profiler.PGLEProfiler(1, 90) + with config.enable_pgle(False): + with profiler.PGLEProfiler.trace(pgle_profiler): + f_compiled(x) + fdo_profile = pgle_profiler.consume_fdo_profile() + 
self.assertEqual(fdo_profile.count(b'custom'), its)
+
+  def testAutoPgle(self):
+    mesh = jtu.create_global_mesh((2,), ('x',))
+
+    @partial(
+        jax.jit,
+        in_shardings=NamedSharding(mesh, PartitionSpec('x')),
+        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
+    )
+    def f(x):
+      return x * 2
+
+    shape = (16, 16)
+    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
+    expected = x * 2
+
+    with config.pgle_profiling_runs(2), config.enable_pgle(True):
+      # Run 1: Module should be compiled without FDO. Two modules are
+      # expected: one is the function f, the other is the multi-slice module.
+      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
+        self.assertArraysEqual(f(x), expected)
+      self.assertEqual(cache_miss_count[0], 2)
+
+      # Run 2: Second PGLE run should not recompile the module
+      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
+        self.assertArraysEqual(f(x), expected)
+      self.assertEqual(cache_miss_count[0], 0)
+
+      # Run 3: The module should be recompiled with FDO profiles
+      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
+        self.assertArraysEqual(f(x), expected)
+      self.assertEqual(cache_miss_count[0], 2)
+
+      # Run 4: Fast-path should be used after PGLE is done
+      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
+        self.assertArraysEqual(f(x), expected)
+      self.assertEqual(cache_miss_count[0], 0)
+
+  def testAutoPgleWithAot(self):
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    x = jnp.arange(1)
+    expected = x * 2
+
+    f_lowered = f.lower(x)
+    serialized, in_tree, out_tree = serialize(f_lowered.compile())
+    compiled = deserialize_and_load(serialized, in_tree, out_tree)
+
+    with config.pgle_profiling_runs(1), config.enable_pgle(True):
+      # Run 1
+      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
+        self.assertArraysEqual(compiled(x), expected)
+      self.assertEqual(cache_miss_count[0], 0)
+
+      # Run 2
+      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
+        self.assertArraysEqual(compiled(x), expected)
+      self.assertEqual(cache_miss_count[0], 0)
+
+  @unittest.skip("Test failing in CI")
+  def testAutoPgleWithPersistentCache(self):
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    x = jnp.arange(1)
+    expected = x * 2
+
+    profilers_dict = (
+        pjit._most_recent_pjit_call_executable.weak_pgle_profiler_dict)
+    with (config.enable_compilation_cache(True),
+          config.enable_pgle(True),
+          config.raise_persistent_cache_errors(True),
+          config.persistent_cache_min_entry_size_bytes(0),
+          config.persistent_cache_min_compile_time_secs(0),
+          config.pgle_profiling_runs(2),
+          tempfile.TemporaryDirectory() as tmpdir):
+      cc.set_cache_dir(tmpdir)
+      # Run 1: Module should be compiled without FDO
+      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
+        self.assertArraysEqual(f(x), expected)
+      self.assertEqual(cache_miss_count[0], 1)
+
+      # Non-PGLE-profiled version of the module should be saved
+      non_pgle_profiled_files = os.listdir(tmpdir)
+      self.assertLen(non_pgle_profiled_files, 1)
+
+      # Run 2: Compilation should not be called
+      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
+        self.assertArraysEqual(f(x), expected)
+      self.assertEqual(cache_miss_count[0], 0)
+
+      # Run 3: Module should be compiled with FDO and stored to persistent cache
+      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
+        self.assertArraysEqual(f(x), expected)
+      self.assertEqual(cache_miss_count[0], 1)
+
+      for pgle_profiler in profilers_dict.values():
+        self.assertTrue(pgle_profiler.is_enabled())
+        self.assertTrue(pgle_profiler.is_fdo_consumed())
+      # One module is the PGLE-profiled version, the other is not.
+      self.assertLen(os.listdir(tmpdir), 2)
+
+      # Remove the non-PGLE-profiled module from the cache to check that the
+      # PGLE-profiled version is used later.
+      os.remove(os.path.join(tmpdir, non_pgle_profiled_files[0]))
+
+      api.clear_caches()
+      profilers_dict.clear()
+
+      # Run 4: Persistent compilation cache should be hit and the PGLE
+      # profiler should be disabled
+      cache_hit = 0
+      def check_if_cache_hit(event):
+        nonlocal cache_hit
+        if event == '/jax/compilation_cache/cache_hits':
+          cache_hit += 1
+
+      monitoring.register_event_listener(check_if_cache_hit)
+      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
+        self.assertArraysEqual(f(x), expected)
+      monitoring._unregister_event_listener_by_callback(check_if_cache_hit)
+
+      self.assertEqual(cache_miss_count[0], 1)
+      self.assertEqual(cache_hit, 1)
+      self.assertLen(profilers_dict, 1)
+      for pgle_profiler in profilers_dict.values():
+        self.assertFalse(pgle_profiler.is_enabled())
+        self.assertFalse(pgle_profiler.is_fdo_consumed())
 
   def testPassingFDOProfile(self):
     mesh = jtu.create_global_mesh((2,), ('x',))
+
     @partial(
         jax.jit,
-        in_shardings=NamedSharding(mesh, P('x',)),
-        out_shardings=NamedSharding(mesh, P('x',)),
+        in_shardings=NamedSharding(mesh, PartitionSpec('x')),
+        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
     )
     def f(x, y):
-      z = x @ y
-      return z @ y
+      return x @ y
 
-    shape = (8, 8)
+    shape = (16, 16)
     x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
     y = x + 1
-    f_lowered = f.lower(x, y)
-    compiled = f_lowered.compile()
+
+    with config.pgle_profiling_runs(0):
+      f_lowered = f.lower(x, y)
+      compiled = f_lowered.compile()
 
     with tempfile.TemporaryDirectory() as tmpdir:
       jax.profiler.start_trace(tmpdir)
diff --git a/tests/pickle_test.py b/tests/pickle_test.py
index 8fa6613cf895..1dede34d2bf1 100644
--- a/tests/pickle_test.py
+++ b/tests/pickle_test.py
@@ -26,14 +26,13 @@
 import jax
 from jax import numpy as jnp
-from jax import config
 from jax.interpreters import pxla
 from jax._src import test_util as jtu
 from jax._src.lib import xla_client as xc
 import numpy as np
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 def _get_device_by_id(device_id: int) -> xc.Device:
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index a3f17e7b6637..9e1ca442a8f7 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -13,9 +13,9 @@
 # limitations under the License.
from collections import OrderedDict, namedtuple -import os +import contextlib import re -from functools import partial, lru_cache +from functools import partial import logging import math import textwrap @@ -32,7 +32,6 @@ import jax.numpy as jnp from jax._src import core from jax._src import config -from jax._src import maps from jax._src import test_util as jtu from jax import dtypes from jax import stages @@ -40,18 +39,18 @@ from jax import lax from jax.lax import with_sharding_constraint from jax._src import prng -from jax.sharding import PartitionSpec as P -from jax.experimental.maps import xmap +from jax.sharding import PartitionSpec as P, Mesh from jax.experimental import multihost_utils from jax.experimental.custom_partitioning import custom_partitioning from jax._src import array -from jax._src.sharding import Sharding +from jax._src.sharding import Sharding, common_devices_indices_map from jax._src import op_shardings from jax._src import sharding_impls from jax._src.sharding_impls import ( AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding, SingleDeviceSharding, parse_flatten_op_sharding) import jax._src.pjit as pjit_lib +from jax._src.maps import xmap from jax._src.pjit import pjit, pjit_p from jax._src import mesh as mesh_lib from jax._src.interpreters import pxla @@ -59,37 +58,19 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() -prev_xla_flags = None -prev_spmd_lowering_flag = None - +# Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. 
- xla_bridge.get_backend.cache_clear() - global prev_spmd_lowering_flag - prev_spmd_lowering_flag = maps.SPMD_LOWERING.value - config.update('experimental_xmap_spmd_lowering', True) + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + _exit_stack.enter_context(jtu.global_config_context(experimental_xmap_spmd_lowering=True)) def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() - config.update('experimental_xmap_spmd_lowering', prev_spmd_lowering_flag) - + _exit_stack.close() def create_array(global_shape, global_mesh, mesh_axes, global_data=None, dtype=np.float32): @@ -106,11 +87,6 @@ def create_array(global_shape, global_mesh, mesh_axes, global_data=None, global_shape, sharding, lambda idx: global_data[idx]), global_data -@lru_cache -def simulated_cached_fun(s): - return s - - def _check_instance(self, x): self.assertIsInstance(x, array.ArrayImpl) @@ -394,7 +370,6 @@ def f(inp1, inp2, inp3): jax.tree.map(self.assertDeleted, y_tree) jax.tree.map(self.assertNotDeleted, z_tree) - @unittest.skipIf(xla_extension_version < 220, 'jaxlib version too old') @jtu.run_on_devices('tpu', 'cpu', 'gpu') def testBufferDonationWithOutputShardingInference(self): mesh = jtu.create_global_mesh((2,), 'x') @@ -426,14 +401,13 @@ def f(inp1, inp2, inp3): jax.tree.map(self.assertDeleted, y_tree) jax.tree.map(self.assertDeleted, z_tree) - @unittest.skipIf(xla_extension_version < 220, 'jaxlib version too old') @jtu.run_on_devices('tpu') def testBufferDonationWithOutputShardingInferenceAndTokens(self): mesh = jtu.create_global_mesh((2,), 'x') s = NamedSharding(mesh, P('x')) def _callback(x): - self.assertIs(type(x), np.ndarray) + self.assertIsInstance(x, jax.Array) @partial(pjit, donate_argnames=('x')) def f(x): @@ -448,7 +422,6 @@ def f(x): jax.effects_barrier() self.assertDeleted(x) - @unittest.skipIf(xla_extension_version < 220, 'jaxlib version too old') @jtu.run_on_devices('tpu', 'cpu', 'gpu') def testBufferDonationNotDonated(self): mesh = jtu.create_global_mesh((2,), 'x') @@ -515,7 +488,7 @@ def testShardingConstraintWithArrayOpSharding(self): shape = (8, 8) mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P(None)) - ops = pjit_lib.to_gspmd_sharding( + ops = pxla.to_gspmd_sharding( NamedSharding(mesh, P('x', 'y')), len(shape)) @partial(pjit, in_shardings=s, out_shardings=s) @@ -748,7 +721,7 @@ def testVMapShardingConstraint(self): jaxpr = jax.make_jaxpr(jax.vmap(f))(x) pjit_eqn, = jaxpr.eqns constraint_eqn, = pjit_eqn.params['jaxpr'].eqns - op = constraint_eqn.params['sharding']._hlo_sharding + op = constraint_eqn.params['sharding']._to_xla_hlo_sharding(x.ndim) self.assertTrue(op.is_tiled()) self.assertListEqual(op.tile_assignment_dimensions(), [1, 2]) self.assertListEqual(op.tile_assignment_devices(), [0, 1]) @@ -768,7 +741,7 @@ def testVMapShardingConstraintWithSpmdAxis(self): jaxpr = jax.make_jaxpr(f)(x) pjit_eqn, = jaxpr.eqns constraint_eqn, = pjit_eqn.params['jaxpr'].eqns - op = constraint_eqn.params['sharding']._hlo_sharding + op = constraint_eqn.params['sharding']._to_xla_hlo_sharding(x.ndim) self.assertTrue(op.is_tiled()) self.assertListEqual(op.tile_assignment_dimensions(), [2, 1]) self.assertListEqual(op.tile_assignment_devices(), [0, 1]) @@ -785,13 +758,6 @@ def testShardingInXMap(self): def _test_rule(*args, **kwargs): nonlocal test_rule_called test_rule_called = True - in_shardings = kwargs['in_shardings'] - self.assertLen(in_shardings, 1) - 
self.assertListEqual(in_shardings[0]._hlo_sharding.tile_assignment_dimensions(), - [1, 1, 2]) - self.assertFalse(op_shardings.is_op_sharding_replicated( - in_shardings[0]._hlo_sharding)) - return rule(*args, **kwargs) try: mlir._lowerings[pjit_p] = _test_rule @@ -1298,27 +1264,37 @@ def f(x): y = jnp.array([4.2, 2.4], dtype=jnp.float32) jaxpr = jax.make_jaxpr(g)(x, y) self.assertEqual( - jaxpr.pretty_print(), + jaxpr.pretty_print(use_color=False), textwrap.dedent(""" - let f = { lambda ; a:f32[1]. let in (a,) } in - let f1 = { lambda ; b:f32[2]. let in (b,) } in + let f = { lambda ; a:f32[1]. let in () } in + let f1 = { lambda ; b:f32[2]. let in () } in { lambda ; c:f32[1] d:f32[2]. let e:f32[2] = pjit[ name=g jaxpr={ lambda ; g:f32[1] h:f32[2]. let - i:f32[1] = pjit[name=f jaxpr=f] g - j:f32[1] = pjit[name=f jaxpr=f] g - k:f32[1] = mul i j - l:f32[2] = pjit[name=f jaxpr=f1] h - m:f32[2] = pjit[name=f jaxpr=f1] h - n:f32[2] = mul l m - o:f32[2] = add k n - in (o,) } + pjit[name=f jaxpr=f] g + pjit[name=f jaxpr=f] g + i:f32[1] = mul g g + pjit[name=f jaxpr=f1] h + pjit[name=f jaxpr=f1] h + j:f32[2] = mul h h + k:f32[2] = add i j + in (k,) } ] c d in (e,) } """).strip(), ) + def test_with_sharding_constraint_vmap_spmd_axis_name_error(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + + def f(x): + return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('x'))) + + xs = jnp.arange(4 * 16.).reshape(4, 16) + with self.assertRaisesRegex(ValueError, "spmd_axis_name"): + jax.vmap(f, spmd_axis_name='x')(xs) + @jtu.pytest_mark_if_available('multiaccelerator') class CustomPartitionerTest(jtu.JaxTestCase): @@ -1326,8 +1302,6 @@ class CustomPartitionerTest(jtu.JaxTestCase): def skip_if_custom_partitioning_not_supported(self): if jtu.is_cloud_tpu(): raise unittest.SkipTest("Custom partitioning is not supported on libtpu.") - if xla_bridge.using_pjrt_c_api(): - raise unittest.SkipTest('custom partitioning not implemented in PJRT C API') @jtu.skip_on_devices('cpu') # Collectives don't seem to work on CPU. @jtu.with_mesh([('x', 4), ('y', 2)]) @@ -1530,6 +1504,37 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape): pjit_f = pjit(jit_f, in_shardings=(P('x')), out_shardings=P('x')) self.assertArraysEqual(x, pjit_f(x)) + @jtu.with_mesh([('x', 4)]) + def test_custom_partitioner_with_scan(self): + self.skip_if_custom_partitioning_not_supported() + + # This is a reproducer from https://github.com/google/jax/issues/20864. 
+ + @custom_partitioning + def f(x): + return jnp.sum(x) + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(xs): + def f(carry, x): + return carry + jax.lax.psum(jnp.sum(x), axis_name='x'), None + + carry, _ = jax.lax.scan(f, 0, xs) + return carry + + result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) + arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) + return mesh, lower_fn, result_shardings, arg_shardings + + f.def_partition( + partition, + infer_sharding_from_operands=lambda mesh, *_: NamedSharding(mesh, P()), + propagate_user_sharding=lambda _, user_shape: user_shape.sharding) + + pjit_f = pjit(f, in_shardings=P(None, 'x')) + xs = jnp.ones([32, 16]) + self.assertEqual(pjit_f(xs), xs.sum()) + @jtu.pytest_mark_if_available('multiaccelerator') class AutoShardingPjitTest(jtu.JaxTestCase): @@ -1559,7 +1564,7 @@ def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape, def test_xla_arr_sharding_mismatch(self): global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - global_input_shape = (4, 2) + global_input_shape = (6, 2) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1637,6 +1642,41 @@ def test_jit_different_mesh_in_auto(self): "Received incompatible devices for jitted computation"): f.lower(inp, inp).compile() + @parameterized.named_parameters( + ('2d_array', (4, 2), ('x', 'y')), + ('1d_array', (8,), ('x')), + ) + def test_jit_auto_sharding_partial_tuple_input_shardings( + self, mesh_shape, mesh_axis_names): + if not jtu.test_device_matches(["tpu"]): + self.skipTest('Parameters are tupled only on TPU if >2000 parameters') + + mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) + global_input_shape = (8, 4) + input_data = np.arange( + math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) + input_sharding = NamedSharding(mesh, P(mesh_axis_names)) # sharded + input_sharding_annotations = [AUTO(mesh)] * 2001 + output_sharding = NamedSharding(mesh, P()) # replicated + output_sharding_annotations = [AUTO(mesh)] * 2001 + for i in range(1000): + input_sharding_annotations[2*i] = input_sharding + output_sharding_annotations[2*i] = output_sharding + + jit_tuple_identity_fn = jax.jit( + lambda *x: x, + in_shardings=input_sharding_annotations, + out_shardings=tuple(output_sharding_annotations)) + + inp = core.ShapedArray(input_data.shape, input_data.dtype) + compiled = jit_tuple_identity_fn.lower(*([inp] * 2001)).compile() + + + # Check sharding preservation for even numbered inputs. + for i in range(1000): + self.assertEqual(compiled.input_shardings[0][2*i], input_sharding) + self.assertEqual(compiled.output_shardings[2*i], output_sharding) + @unittest.skip('The error is not raised yet. 
Enable this back once we raise ' 'the error in pjit again.') def test_pjit_array_error(self): @@ -1824,6 +1864,18 @@ def f(tree): for s in out4.addressable_shards: self.assertArraysEqual(s.data, input_data) + def test_sds_full_like(self): + # https://github.com/google/jax/issues/20390 + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + x = jax.ShapeDtypeStruct((4, 4), jnp.float32, sharding=s) + y = jnp.zeros_like(x) + z = jnp.zeros_like(x, device=y.sharding) + + self.assertEqual(x.sharding, s) + self.assertEqual(y.sharding, s) + self.assertEqual(z.sharding, s) + def test_in_axis_resources_mismatch_error(self): global_input_shape = (8, 2) global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) @@ -1892,9 +1944,9 @@ def test_array_lower_compile(self): with self.assertRaisesRegex( ValueError, - r"Compiled object called with input sharding\(s\) does not match the " - r"sharding\(s\) the computation was compiled with. " - "Here are 5 mismatches out of 6"): + r"Compiled object called with input sharding.*does not match the " + r"sharding.*the computation was compiled with. " + "Here are.*mismatches.*"): compiled(a2, a2, a2, a2, a2, a2) with global_mesh: @@ -1906,9 +1958,9 @@ def test_array_lower_compile(self): inp2 = {'x': a2, 'y': {'y1': a2}} with self.assertRaisesRegex( ValueError, - r"Compiled object called with input sharding\(s\) does not match the " - r"sharding\(s\) the computation was compiled with. " - "Here are the 2 mismatches"): + r"Compiled object called with input sharding.*does not match the " + r"sharding.*the computation was compiled with. " + "Here are the.*mismatches"): compiled(inp2) def test_globally_sharded_key_array_result_8x4_single_device(self): @@ -2139,30 +2191,6 @@ def test_fast_path_array(self): self.assertTrue(out2.sharding.is_equivalent_to(out.sharding, out.ndim)) self.assertArraysEqual(out2, inp_data) - def test_not_xlacompatible_sharding_error(self): - shape = (8, 2) - inp_data = np.arange(math.prod(shape)).reshape(shape) - ts = TempSharding(jax.devices()) - arr = array.make_array_from_callback( - shape, ts, lambda idx: inp_data[idx]) - with self.assertRaisesRegex( - ValueError, - 'One of the argument to pjit got sharding.*which is not a subclass of ' - 'XLACompatibleSharding.'): - pjit(lambda x: x)(arr) - - with self.assertRaisesRegex( - ValueError, - 'One of in_shardings leaf specifications got sharding.*which is ' - 'not a subclass of XLACompatibleSharding.'): - pjit(lambda x: x, in_shardings=ts)(arr) - - with self.assertRaisesRegex( - ValueError, - 'One of out_shardings leaf specifications got sharding.*which is ' - 'not a subclass of XLACompatibleSharding.'): - pjit(lambda x: x, out_shardings=ts)(arr) - def test_array_enabled_non_empty_mesh_with_pspec(self): arr = jnp.array([1, 2, 3]) with self.assertRaisesRegex( @@ -2308,18 +2336,18 @@ def test_out_sharding_indices_id_cache_hit(self): out1 = f(arr) self.assertIsInstance(out1.sharding, NamedSharding) out1.sharding.devices_indices_map(shape) - cache_info1 = sharding_impls.common_devices_indices_map.cache_info() + cache_info1 = common_devices_indices_map.cache_info() out2 = f(out1) self.assertIsInstance(out2.sharding, NamedSharding) out2.sharding.devices_indices_map(shape) - cache_info2 = sharding_impls.common_devices_indices_map.cache_info() + cache_info2 = common_devices_indices_map.cache_info() self.assertEqual(cache_info2.hits, cache_info1.hits + 1) out3 = f(out2) self.assertIsInstance(out3.sharding, NamedSharding) out3.sharding.devices_indices_map(shape) - 
cache_info3 = sharding_impls.common_devices_indices_map.cache_info() + cache_info3 = common_devices_indices_map.cache_info() self.assertEqual(cache_info3.hits, cache_info2.hits + 1) def test_aot_compile_in_tree_mismatch(self): @@ -2436,6 +2464,8 @@ def f(x, y, z): r"pjit inside pjit with device ids \[0, 1, 2, 3\].*"): my_nested_pjit(committed_inp, committed_inp, committed_inp) + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def test_jit_device_with_sharding_constraint_error(self): mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) @@ -2836,6 +2866,20 @@ def f(x, y, z, a, b, c): # pylint: disable=unused-argument self.assertEqual(compiled._executable._kept_var_idx, {5}) self.assertLen(compiled._executable.in_avals, 1) + def test_pjit_relayout_multi_slice(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + + @jax.jit + def mul(x): + return x @ x.T + + x = jnp.arange(8).reshape(4, 2) + y = jax.device_put(x, jax.sharding.NamedSharding(mesh, P('x', 'y'))) + compiled = mul.lower(jax.ShapeDtypeStruct( + y.shape, y.dtype, sharding=y.sharding)).compile() + out = compiled(y) + self.assertArraysEqual(out, x @ x.T) + def test_pjit_with_device_arg(self): def mul(x): return x @ x.T @@ -2847,7 +2891,9 @@ def _check(out, expected_device, expected_out): mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - f = pjit(mul, device=jax.devices()[1]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + f = pjit(mul, device=jax.devices()[1]) x = jnp.arange(8).reshape(4, 2) f_out = f(x) f_out2 = f(f_out) @@ -2862,7 +2908,9 @@ def _check(out, expected_device, expected_out): self.assertEqual(cache_info2.hits, cache_info1.hits + 1) - h = pjit(mul, device=jax.devices()[-1]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + h = pjit(mul, device=jax.devices()[-1]) h_out = h(y) cache_info3 = pjit_lib._pjit_lower_cached.cache_info() _check(h_out, jax.devices()[-1], y) @@ -2882,7 +2930,9 @@ def test_pjit_with_device_arg_input_from_another_pjit(self): out = pjit(lambda x: x * 2)(y) expected_device = jax.devices()[2] - final_out = pjit(lambda x: x * 3, device=expected_device)(out) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + final_out = pjit(lambda x: x * 3, device=expected_device)(out) self.assertEqual(final_out.devices(), {expected_device}) self.assertLen(final_out.sharding.device_set, 1) @@ -2896,7 +2946,9 @@ def _check(out, expected_device, expected_out): self.assertArraysEqual(out, expected_out) x = jnp.arange(8) - g = pjit(lambda x: x, backend='tpu') + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + g = pjit(lambda x: x, backend='tpu') g_out = g(x) _check(g_out, jax.devices()[0], x) @@ -2909,8 +2961,10 @@ def test_autodiff_with_device_arg(self): self.skipTest('Test requires more >1 device.') # Add a constant captured by the nested pjit to make things more complicated h = jnp.arange(4.) 
- f = pjit(lambda x: x.sum(1) * h.sum(), device=jax.devices()[1]) - g = pjit(lambda x: f(jnp.sin(x * 4 + 2)), device=jax.devices()[1]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + f = pjit(lambda x: x.sum(1) * h.sum(), device=jax.devices()[1]) + g = pjit(lambda x: f(jnp.sin(x * 4 + 2)), device=jax.devices()[1]) jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2) def test_pjit_device_backend_axis_resources_error(self): @@ -2919,18 +2973,34 @@ def test_pjit_device_backend_axis_resources_error(self): ValueError, 'If backend or device is specified on jit, then ' 'in_shardings should not be specified.'): - pjit(lambda x: x, in_shardings=s, backend='cpu') + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + pjit(lambda x: x, in_shardings=s, backend='cpu') with self.assertRaisesRegex( ValueError, 'If backend or device is specified on jit, then ' 'out_shardings should not be specified.'): - pjit(lambda x: x, out_shardings=s, device=jax.devices()[0]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + pjit(lambda x: x, out_shardings=s, device=jax.devices()[0]) + + def test_check_arg_error(self): + sds = jax.ShapeDtypeStruct((4, 2), np.int32) + inp = np.arange(8).reshape(4, 2) + + with self.assertRaisesRegex( + TypeError, + r"Argument 'x\['b'\]\['c'\]' of shape int32\[4,2\] of " + "type.*ShapeDtypeStruct.*is not a valid JAX type."): + jax.jit(lambda x: x)({'a': inp, 'b': {'c': sds}}) def test_pjit_device_backend_both_error(self): with self.assertRaisesRegex( ValueError, "can't specify both a device and a backend for jit"): - pjit(lambda x: x, device=jax.devices()[0], backend='cpu') + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + pjit(lambda x: x, device=jax.devices()[0], backend='cpu') def test_pjit_mesh_with_device_or_backend_error(self): mesh = jtu.create_global_mesh((1,), ('x',)) @@ -2939,7 +3009,9 @@ def test_pjit_mesh_with_device_or_backend_error(self): ValueError, "Mesh context manager should not be used with jit when backend or " "device is also specified as an argument to jit."): - pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8)) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8)) def test_pjit_inline(self): @partial(pjit, inline=False) @@ -3087,7 +3159,7 @@ def test_jit_with_mesh_context_manager(self): mesh = jtu.create_global_mesh((1,), ('x',)) with self.assertRaisesRegex( RuntimeError, - "jax.jit only supports `XLACompatibleSharding`s being passed to " + "jax.jit only supports `Sharding`s being passed to " "in_shardings"): with mesh: jax.jit(lambda x: x, in_shardings=P('x'), @@ -3156,7 +3228,9 @@ def test_pjit_no_global_cache_hit_axis_resources(self): with jtu.count_pjit_cpp_cache_miss() as count: for _ in range(10): - pjit(lambda x: x * 2, device=jax.devices()[0])(inp) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + pjit(lambda x: x * 2, device=jax.devices()[0])(inp) self.assertEqual(count[0], 10) pf = pjit(lambda x: x * 2, in_shardings=s, out_shardings=s) @@ -3165,7 +3239,9 @@ def test_pjit_no_global_cache_hit_axis_resources(self): pf(inp) self.assertEqual(count[0], 1) - pf1 = pjit(lambda x: x * 2, device=jax.devices()[0]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device 
argument"): + pf1 = pjit(lambda x: x * 2, device=jax.devices()[0]) with jtu.count_pjit_cpp_cache_miss() as count: for _ in range(10): pf1(inp) @@ -3411,6 +3487,12 @@ def mul(x): self.assertEqual(cache_info2.hits, cache_info1.hits + 1) self.assertEqual(cache_info2.misses, cache_info1.misses) + def test_list_in_pspec(self): + mesh = jtu.create_global_mesh((2,), ('x',)) + with mesh: + out = with_sharding_constraint(jnp.arange(8), P(['x'])) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + def test_sharding_preserved_trivial(self): mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) @@ -3574,7 +3656,9 @@ def test_different_named_sharding_object_replicated(self): self.assertNotEqual(x.sharding, y.sharding) def test_vmap_pjit_single_device(self): - jf = pjit(lambda x: x, device=jax.devices()[0]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + jf = pjit(lambda x: x, device=jax.devices()[0]) out = jax.vmap(jf)(jnp.ones((3,))) # doesn't crash self.assertIsInstance(out.sharding, SingleDeviceSharding) @@ -3591,8 +3675,10 @@ def identity(x): self.assertEqual(out.devices(), {jax.devices()[0]}) self.assertArraysEqual(out, np_inp) - out2 = jax.jit(identity, device=jax.devices()[0])( - jax.device_put(np_inp, NamedSharding(mesh, P('x')))) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + out2 = jax.jit(identity, device=jax.devices()[0])( + jax.device_put(np_inp, NamedSharding(mesh, P('x')))) self.assertEqual(out2.devices(), {jax.devices()[0]}) self.assertArraysEqual(out2, np_inp) @@ -3748,8 +3834,8 @@ def f(inp): ' manager.*SingleDeviceSharding'): jax.jit(jax.vmap(f, spmd_axis_name='x'))(arr) - @jtu.skip_on_devices("tpu") - def test_device_put_memory_kind_not_tpu(self): + @jtu.skip_on_devices("tpu", "gpu") + def test_device_put_memory_kind_not_tpu_gpu(self): @jax.jit def f(x): y = x * 2 @@ -3942,27 +4028,217 @@ def f(x, y, z, a, b): self.assertArraysEqual(out4, np_inp * 3) self.assertArraysEqual(out5, np_inp.T) + def test_input_shardings_aot(self): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) + + @jax.jit + def f(x, y): + return x * 2, y.T + + arg_shardings, _ = f.lower(arr, np_inp).compile().input_shardings + for s in arg_shardings: + self.assertIsInstance(s, NamedSharding) + + def test_parameter_tupled_jit(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest('Parameters are tupled only on TPU if >2000 parameters') + + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x')) + + @jax.jit + def f(*args): + return args * 2 + + inp = np.arange(8) + arr = jax.device_put(inp, s) + inps = [arr, *[inp] * 2001] + f(inps) # doesn't crash + + def test_spmd_preserves_input_sharding_vmap_grad(self): + # https://github.com/google/jax/issues/20710 + n_devices = jax.device_count() + sharding = PositionalSharding(jax.devices()) + + def model(params, x): + return x @ params + + feature_dim = 3 + batch_size_total = 8 + + # Get example data + x = jnp.ones((batch_size_total, feature_dim)) + params = jnp.ones(feature_dim) + + # Shard data, replicate params + x = jax.device_put(x, sharding.reshape(n_devices, 1)) + params = jax.device_put(params, sharding.replicate(axis=0)) + + model(params, x) # doesn't crash + + jax.vmap(model, in_axes=(None, 0))(params, x) # doesn't crash + + jax.grad(lambda p: model(p, x).sum())(params) # doesn't crash + 
+ jax.vmap(jax.grad(model), in_axes=(None, 0))(params, x) # doesn't crash + + def test_jit_token_input(self): + x = jnp.arange(8) + token = jax.lax.create_token(None) + device = jax.devices()[0] + x = jax.device_put(x, device=device) + out1, out2 = jax.jit(lambda x, t: (x, t))(x, token) + self.assertArraysEqual(out1, x) + self.assertIsInstance(out2, core.Token) + + def test_uneven_sharding_wsc(self): + mesh = jtu.create_global_mesh( + (2, 1, 1, 1, 1), ('data', 'expert', 'fsdp', 'seq', 'model') + ) + + @jax.jit + def fn(key): + x = jnp.arange(113003) + x = with_sharding_constraint(x, P('data')) + y = jnp.arange(65536) + y = with_sharding_constraint(y.reshape(-1), P('data')) + z = jnp.concatenate([x, y], axis=0) + z = with_sharding_constraint(z, P('data')) + return x, y, z + + with mesh: + x, y, z = fn(jax.random.key(42)) + + expected_x = np.arange(113003) + expected_y = np.arange(65536) + expected_z = np.concatenate([x, y], axis=0) + + self.assertArraysEqual(expected_x.max(), x.max()) + self.assertArraysEqual(expected_y.max(), y.max()) + self.assertArraysEqual(expected_z.max(), z.max()) + + def test_threefry_partitionable_context_within_jit(self): + with jax.threefry_partitionable(False): + def f(x): + return x + jax.random.randint(jax.random.key(72), (), 0, 10) + + def g(x): + with jax.threefry_partitionable(True): # False by default + return x + jax.random.randint(jax.random.key(72), (), 0, 10) + + h = jax.jit(g) + + self.assertNotEqual(f(1), g(1)) + self.assertEqual(g(1), h(1)) + + def test_wsc_vmap_unconstrained_spmd_axis_name(self): + def get_wsc_eqn_sharding(jaxpr): + for eqn in jaxpr.eqns: + if str(eqn.primitive) == 'sharding_constraint': + return eqn.params['sharding'], eqn.params['unconstrained_dims'] + for s in core.subjaxprs(jaxpr): + return get_wsc_eqn_sharding(s) + + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + inp = jnp.ones((10, 10)) + + def a_function(x): + return with_sharding_constraint(x, NamedSharding(mesh, P(P.UNCONSTRAINED))) + + def vmap_the_function_spmd(y): + return jax.vmap(a_function, spmd_axis_name='x')(y) + + f1 = jax.jit(vmap_the_function_spmd) + f1(inp) # doesn't crash + jaxpr1 = jax.make_jaxpr(f1)(inp) + s1, u1 = get_wsc_eqn_sharding(jaxpr1) + self.assertEqual(s1.spec, P('x', P.UNCONSTRAINED)) + self.assertEqual(u1, {1}) + + def vmap_the_function_no_spmd(y): + return jax.vmap(a_function)(y) + + f2 = jax.jit(vmap_the_function_no_spmd) + f2(inp) # doesn't crash + jaxpr2 = jax.make_jaxpr(f2)(inp) + s2, u2 = get_wsc_eqn_sharding(jaxpr2) + self.assertEqual(s2.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) + self.assertEqual(u2, {0, 1}) + + def test_aot_sharding_dce(self): + inp = np.arange(8) + + @jax.jit + def f(x, y): + return x + + input_shardings, _ = f.lower(inp, inp).compile().input_shardings + self.assertLen(input_shardings, 2) + + def test_aot_out_info(self): + inp = np.arange(8, dtype=np.int32) + out_info = jax.jit(lambda x: x).lower((inp, inp)).out_info + self.assertEqual(out_info[0].shape, (8,)) + self.assertEqual(out_info[1].shape, (8,)) + self.assertEqual(out_info[0].dtype, np.int32) + self.assertEqual(out_info[1].dtype, np.int32) + self.assertEqual(out_info[0].sharding, None) + self.assertEqual(out_info[1].sharding, None) -class TempSharding(Sharding): + def test_jit_trace(self): + def f(x): + return x * 2 + + traced = jax.jit(f).trace(jnp.arange(8, dtype=jnp.int32)) + self.assertLen(traced.jaxpr.eqns, 1) + self.assertEqual(jax.tree.structure(traced.out_info).num_leaves, 1) + self.assertEqual(traced.out_info.shape, (8,)) + 
self.assertEqual(traced.out_info.dtype, jnp.int32) + # one for args, one for kwargs (though kwargs is empty) + self.assertLen(traced.in_avals, 2) + self.assertLen(traced.in_avals[0], 1) + self.assertLen(traced.in_avals[1], 0) # empty kwarg + + def test_jit_trace_lower_and_compile(self): + def f(x): + return x * 2 - def __init__(self, devices): - if xla_extension_version >= 235: - super().__init__() - self._devices = devices + lowered = jax.jit(f).trace(jnp.arange(8)).lower() + self.assertEqual(lowered.args_info[0][0].shape, (8,)) - @property - def device_set(self): - return set(self._devices) + compiled = lowered.compile() + out = compiled(jnp.arange(8)) + self.assertArraysEqual(out, np.arange(8) * 2) + + # fast-forward + lowered2 = jax.jit(f).lower(jnp.arange(8)) + self.assertEqual(lowered2.args_info[0][0].shape, (8,)) + + compiled2 = lowered2.compile() + out2 = compiled2(jnp.arange(8)) + self.assertArraysEqual(out2, np.arange(8) * 2) + + def test_device_put_efficient_reshard_single_host(self): + if jax.device_count() < 4: + self.skipTest('Requires >= 4 devices') - def devices_indices_map(self, global_shape): - return {d: (slice(None),) * len(global_shape) for d in self.device_set} + dev = jax.devices() + mesh1 = Mesh(np.array([dev[0], dev[1], dev[2], dev[3]]).reshape(2, 2), + ('x', 'y')) + mesh2 = Mesh(np.array([dev[3], dev[2], dev[1], dev[0]]).reshape(2, 2), + ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + s1 = NamedSharding(mesh1, P('x', 'y')) + s2 = NamedSharding(mesh2, P('x')) - def shard_shape(self, global_shape): - return global_shape + x_s1 = jax.device_put(np_inp, s1) - @property - def is_fully_replicated(self): - return True + with jax.transfer_guard('disallow_explicit'): + out = jax.device_put(x_s1, s2) + self.assertArraysEqual(out, np_inp) + self.assertEqual(out.sharding, s2) def spec_regex(s): @@ -4071,8 +4347,7 @@ def testRankTooLowConstraint(self): r".*rank at least 2, but was applied to a value of rank 1", re.M | re.S) with self.assertRaisesRegex(ValueError, error): pjit( - lambda x: with_sharding_constraint(x, spec), - in_shardings=None, + lambda x: with_sharding_constraint(x, spec), in_shardings=None, out_shardings=None, )(x) @@ -4295,32 +4570,6 @@ def f(x, y): ' compiled'): g(x, y2) - def test_aot_error_on_dced_shardings_mismatch(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) - shape = (8, 2) - np_inp = np.arange(math.prod(shape)).reshape(shape) - - x = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) - y1 = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) - y2 = jax.device_put(np_inp, NamedSharding(mesh, P('y'))) - - @jax.jit - def f(x, y): - return x + 1 - - f_out1 = f(x, y1) - f(x, y2) - - g = f.lower(x, y1).compile() - g_out1 = g(x, y1) - self.assertArraysEqual(f_out1, g_out1) - - with self.assertRaisesRegex( - ValueError, - r"Compiled object called with input sharding.*does not match the " - r"sharding.*the computation was compiled with"): - g(x, y2) - def test_dce_no_array(self): mesh = jtu.create_global_mesh((2,), ('x',)) arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x'))) @@ -4533,19 +4782,19 @@ def test_device_indices_cache(self): ops = GSPMDSharding(devices, op1) ops.devices_indices_map(shape) - cache_info1 = sharding_impls.gspmd_sharding_devices_indices_map.cache_info() + cache_info1 = common_devices_indices_map.cache_info() ops.devices_indices_map(shape) - cache_info2 = sharding_impls.gspmd_sharding_devices_indices_map.cache_info() + cache_info2 = common_devices_indices_map.cache_info() 
self.assertEqual(cache_info2.hits, cache_info1.hits + 1) ops = GSPMDSharding(devices, op2) ops.devices_indices_map(shape) - cache_info3 = sharding_impls.gspmd_sharding_devices_indices_map.cache_info() + cache_info3 = common_devices_indices_map.cache_info() self.assertEqual(cache_info3.hits, cache_info2.hits + 1) ops.devices_indices_map(shape) - cache_info4 = sharding_impls.gspmd_sharding_devices_indices_map.cache_info() + cache_info4 = common_devices_indices_map.cache_info() self.assertEqual(cache_info4.hits, cache_info3.hits + 1) def test_op_sharding_semantically_replicated(self): @@ -4681,24 +4930,6 @@ def test_op_sharding_cache_on_mesh_pspec_sharding(self): self.assertEqual(cache_info2.misses, cache_info1.misses) self.assertEqual(cache_info2.currsize, cache_info1.currsize) - def test_simulated_training_cache_in_pjit(self): - ndim = 2 - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) - - mps1 = NamedSharding(mesh, P('x', 'y')) - gspmd_sharding = pjit_lib.to_gspmd_sharding(mps1, ndim) - next_loop_sharding = simulated_cached_fun(gspmd_sharding) - cache_info1 = simulated_cached_fun.cache_info() - - next_gspmd_sharding = pjit_lib.to_gspmd_sharding( - next_loop_sharding, ndim) - simulated_cached_fun(next_gspmd_sharding) - cache_info2 = simulated_cached_fun.cache_info() - - self.assertEqual(cache_info2.hits, cache_info1.hits + 1) - self.assertEqual(cache_info2.misses, cache_info1.misses) - self.assertEqual(id(next_gspmd_sharding), id(gspmd_sharding)) - def test_get_partition_spec(self): mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y', None)) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 4cf7cc06502e..5e279a5e6daa 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -15,11 +15,11 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor +import contextlib from functools import partial import itertools as it import gc import math -import os from random import shuffle import re from typing import Union, cast @@ -35,6 +35,7 @@ from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr, linearize, device_put) from jax import lax +import jax.scipy.linalg from jax import random from jax.ad_checkpoint import checkpoint as new_checkpoint import jax.numpy as jnp @@ -45,7 +46,6 @@ from jax._src import sharding_impls from jax._src import sharding_specs from jax._src import test_util as jtu -from jax._src import xla_bridge from jax._src.internal_test_util import lax_test_util from jax._src.interpreters import mlir from jax._src.interpreters import pxla @@ -55,7 +55,14 @@ config.parse_flags_with_absl() -prev_xla_flags = None +# Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() + +def setUpModule(): + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + +def tearDownModule(): + _exit_stack.close() compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]] @@ -84,26 +91,6 @@ def args_slicer(args, bdims): slicers = safe_map(slicer, args, bdims) return lambda i: [sl(i) for sl in slicers] -# Run all tests with 8 CPU devices. -def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. 
- xla_bridge.get_backend.cache_clear() - -# Reset to previous configuration in case other test modules will be run. -def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() - ignore_jit_of_pmap_warning = partial( jtu.ignore_warning, message=".*jit-of-pmap.*") @@ -1109,12 +1096,41 @@ def testAxisGroups(self): self.assertEqual((tuple(sorted(groups[0])),), ((0, 1, 2, 3, 4, 5, 6, 7,),)) # order doesn't matter + @jtu.run_on_devices("gpu") + def testCollectiveBroadcast(self): + device_count = jax.device_count() + f = lambda x: lax.pbroadcast(x, source=0, axis_name='i') + f = self.pmap(f, 'i') + x = jnp.arange(4 * device_count).reshape((device_count, 4)) + ans = f(x) + expected = np.take(x, [0] * device_count, axis=0) + self.assertAllClose(ans, expected, check_dtypes=False) + + @jtu.run_on_devices("gpu") + def testCollectiveBroadcastVmap(self): + device_count = jax.device_count() + f = lambda x: lax.pbroadcast(x, source=0, axis_name='i') + x = np.arange(device_count * 16, dtype=np.float32) + x = x.reshape((device_count, 4, 4)) + ans = self.pmap(vmap(f), 'i')(x) + expected = jnp.broadcast_to(x[0:1], x.shape) + self.assertAllClose(ans, expected, check_dtypes=False) + + @jtu.run_on_devices("gpu") + def testCollectiveBroadcastGrad(self): + device_count = jax.device_count() + f = lambda x: lax.pbroadcast(x, source=0, axis_name='i') + x = np.arange(device_count, dtype=np.float32) + ans = self.pmap(grad(f), 'i')(x) + expected = np.zeros_like(x) + expected[0] = device_count + self.assertAllClose(ans, expected, check_dtypes=False) + def testCollectivePermute(self): device_count = jax.device_count() rotation = [(i, (i + 1) % device_count) for i in range(device_count)] f = lambda x: lax.ppermute(x, perm=rotation, axis_name='i') f = self.pmap(f, 'i') - x = jnp.arange(4 * device_count).reshape((device_count, 4)) ans = f(x) expected = np.roll(x, shift=1, axis=0) @@ -1213,7 +1229,7 @@ def print_board(board): boards.append(''.join('*' if x else ' ' for x in board.ravel())) print_board(reshaped_board) - for _ in range(20): + for _ in range(9): reshaped_board = step(reshaped_board) print_board(reshaped_board) @@ -1229,17 +1245,6 @@ def print_board(board): ' ** **** ****** ', ' ** * *** * ', ' ** **** ** * *** ', - ' ** * * **** ** * ', - ' ** **** ** * * **** ', - ' ** * *** ** ** * * ', - ' ** **** ** *** *** ** *** ', - ' ** * * *** * *** * * ', - ' ** **** ** * * ***** ******* ', - ' ** * *** **** * *** * ', - ' ** **** ** *** ** ** * *** ', - ' ** * * *** * ** *** **** ** * ', - ' ** **** ** * ****** * * *** ****', - ' * * *** **** **** *** ** * ', )) print(ans) @@ -1787,8 +1792,8 @@ def matrix_vector(x, y, parallel=True): res = fv(x) return res - x = random.normal(random.PRNGKey(1), (80, 5)) - y = random.normal(random.PRNGKey(1), (10, 5)) + x = random.normal(random.PRNGKey(1), (40, 5)) + y = random.normal(random.PRNGKey(1), (5, 5)) result1 = vmap(lambda b: matrix_vector(x, b, True))(y) # vmap + pmap result2 = lax.map(lambda b: matrix_vector(x, b, False), y) # map + map @@ -2045,13 +2050,10 @@ def test_grad_of_pmap_compilation_caching(self, axis_size): def f(x): return jnp.sin(x) + # warm-up the cache x = jnp.ones(axis_size) - f(x) # warm-up any dispatching compilations - - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 - _, f_bwd = jax.vjp(f, x) - _ = f_bwd(x) - self.assertEqual(count[0], 2) # one for fwd, one for bwd + _, f_bwd = jax.vjp(f, x) + _ = f_bwd(x) with 
jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 _, f_bwd2 = jax.vjp(f, x) @@ -2166,6 +2168,26 @@ def test_axis_name_shadowing_with_vmap(self): jax.vmap(jax.vmap(lambda x: 2 * x, axis_name='i'), axis_name='i')(jax.numpy.ones((1, 1))) # don't crash + @jtu.run_on_devices("cpu") + def test_pmap_stack_size(self): + # Regression test for https://github.com/google/jax/issues/20428 + # pmap isn't particularly important here, but it guarantees that the CPU + # client runs the computation on a threadpool rather than inline. + if jax.device_count() < 2: + raise SkipTest("test requires at least two devices") + x = jnp.eye(200) + y = jax.pmap(jax.scipy.linalg.expm)(jnp.array([x, x])) + y.block_until_ready() # doesn't crash + + def test_pmap_of_prng_key(self): + # Regression test for https://github.com/google/jax/issues/20392 + keys = jax.random.split(jax.random.key(0), jax.device_count()) + result1 = jax.pmap(jax.random.bits)(keys) + with jtu.ignore_warning( + category=UserWarning, message="The jitted function bits includes a pmap"): + result2 = jax.jit(jax.pmap(jax.random.bits))(keys) + self.assertArraysEqual(result1, result2) + @jtu.pytest_mark_if_available('multiaccelerator') class CppPmapTest(PythonPmapTest): @@ -2480,7 +2502,7 @@ def testOneDevice(self): f = lambda x: jnp.dot(x, x.T) f0 = pmap(f, devices=[d0]) f1 = pmap(f, devices=[d1]) - x = self.rng().rand(1, 1000, 1000) + x = self.rng().rand(1, 500, 500) r0 = f0(x) r1 = f1(x) expected = np.expand_dims(np.dot(x.squeeze(), x.squeeze().T), 0) diff --git a/tests/polynomial_test.py b/tests/polynomial_test.py index ccba4c2ef11f..3eeaec482719 100644 --- a/tests/polynomial_test.py +++ b/tests/polynomial_test.py @@ -19,12 +19,12 @@ from absl.testing import absltest +import jax from jax._src import dtypes from jax import numpy as jnp from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex diff --git a/tests/pretty_printer_test.py b/tests/pretty_printer_test.py new file mode 100644 index 000000000000..d87708c9d91c --- /dev/null +++ b/tests/pretty_printer_test.py @@ -0,0 +1,36 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest + +from jax._src import test_util as jtu +from jax._src import pretty_printer as pp + + +class PrettyPrinterTest(jtu.JaxTestCase): + + def testSourceMap(self): + doc = pp.concat([ + pp.text("abc"), pp.source_map(pp.text("def"), 101), + pp.source_map(pp.concat([pp.text("gh"), pp.brk(""), pp.text("ijkl")]), 77), + pp.text("mn"), + ]) + source_map = [] + out = doc.format(width=8, source_map=source_map) + self.assertEqual(out, "abcdefgh\nijklmn") + self.assertEqual(source_map, [[(3, 6, 101), (6, 8, 77)], [(0, 4, 77)]]) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/profiler_test.py b/tests/profiler_test.py index c232c3afd699..b67b078aec02 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -26,7 +26,6 @@ import jax import jax.numpy as jnp import jax.profiler -from jax import config import jax._src.test_util as jtu from jax._src import profiler @@ -50,7 +49,7 @@ except ImportError: pass -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ProfilerTest(unittest.TestCase): diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index b519e8bb1ce1..8eccffb9b773 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -13,9 +13,9 @@ # limitations under the License. import collections +import contextlib import functools import logging -import textwrap import time import unittest @@ -30,9 +30,9 @@ from jax._src import test_util as jtu from jax._src import util from jax._src.lib import xla_client +from jax._src.maps import xmap from jax.experimental import io_callback from jax.experimental import pjit -from jax.experimental.maps import xmap from jax.experimental.shard_map import shard_map import jax.numpy as jnp from jax.sharding import Mesh @@ -40,22 +40,13 @@ config.parse_flags_with_absl() - -def _format_multiline(text): - return textwrap.dedent(text).lstrip() - -prev_xla_flags = None - +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(2) + _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - prev_xla_flags() + _exit_stack.close() map, unsafe_map = util.safe_map, map @@ -71,6 +62,7 @@ def tearDownModule(): for flavor in ("io_unordered", "io_ordered", "pure") ) + class PythonCallbackTest(jtu.JaxTestCase): def setUp(self): @@ -86,12 +78,78 @@ def tearDown(self): def test_callback_with_scalar_values(self, *, callback): @jax.jit def f(x): - return callback(lambda x: x + np.float32(1.), - core.ShapedArray(x.shape, x.dtype), x) + return callback(lambda x: x + 1.0, core.ShapedArray(x.shape, x.dtype), x) out = f(0.) self.assertEqual(out, 1.) 
+ @parameterized.named_parameters( + dict( + testcase_name=f"{flavor}_expect_dtype_{expect_dtype}", + callback=dict( + io_unordered=io_calback_unordered, + io_ordered=io_callback_ordered, + pure=jax.pure_callback, + )[flavor], + expect_dtype=expect_dtype, + ) + for flavor in ("io_unordered", "io_ordered", "pure") + for expect_dtype in (np.int32, np.int64, np.float32, np.float64) + ) + def test_callback_returning_python_literal(self, *, callback, expect_dtype): + returned_literal = 42 if expect_dtype in (np.int32, np.int64) else 42.0 + + @jax.jit + def f(x): + return callback( + lambda x: returned_literal, core.ShapedArray((), expect_dtype), x + ) + + if not config.enable_x64.value and expect_dtype in (np.int64, np.float64): + ctx = self.assertRaisesRegex(Exception, "result_shape_dtypes cannot specify 64-bit types") + elif config.enable_x64.value and expect_dtype in (np.int32, np.float32): + ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value") + else: + ctx = contextlib.nullcontext() + + with ctx: + out = f(0.0) + jax.effects_barrier() + self.assertEqual(out, returned_literal) + + @with_pure_and_io_callbacks + def test_callback_returning_custom_array(self, *, callback): + # Some users write the callback in TF, returning a tf.Tensor. We don't + # want to add TF as a dependency, but simulate that use case with a + # custom array class. + class CustomArray: + + def __init__(self, a: np.ndarray): + self.a = a + + @property + def shape(self): + return self.a.shape + + @property + def dtype(self): + return self.a.dtype + + def __array__(self): + return self.a + + @jax.jit + def f(x): + return callback( + lambda x: CustomArray(np.array(42.0, dtype=np.float32)), + core.ShapedArray((), np.float32), + x, + ) + + out = f(0.0) + jax.effects_barrier() + self.assertEqual(out, 42.0) + @parameterized.named_parameters( dict(testcase_name=f"{flavor}_{dtype}", dtype=dtype, @@ -187,9 +245,14 @@ def f(): # Calling a function expected a f32 return value but getting f64 return callback(_cb, core.ShapedArray((1,), np.float32)) - with self.assertRaises(RuntimeError): - f() + if config.enable_x64.value: + ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value") + else: + ctx = contextlib.nullcontext() + with ctx: + res = f() jax.effects_barrier() + self.assertAllClose(res, np.array([1], np.float32)) @with_pure_and_io_callbacks def test_callback_with_wrongly_specified_64_bit_dtype(self, *, callback): @@ -494,6 +557,43 @@ def f(x): out, np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.) 
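The CustomArray test above leans on NumPy's __array__ protocol: an object that exposes shape, dtype, and __array__ can be converted with np.asarray, which is how a callback's return value is expected to be ingested here. A standalone sketch of that protocol, independent of JAX (FakeTensor is an illustrative name, not part of this diff):

import numpy as np

class FakeTensor:
  """Stand-in for a foreign array type such as a tf.Tensor."""

  def __init__(self, a):
    self._a = np.asarray(a)

  @property
  def shape(self):
    return self._a.shape

  @property
  def dtype(self):
    return self._a.dtype

  def __array__(self):
    # np.asarray(FakeTensor(...)) calls this to obtain a plain ndarray.
    return self._a

print(np.asarray(FakeTensor([1.0, 2.0])).sum())  # 3.0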
+ @with_pure_and_io_callbacks + def test_exception_in_callback(self, *, callback): + def fail(x): + raise RuntimeError("Ooops") + + @jax.jit + def f(x): + return callback(fail, core.ShapedArray(x.shape, x.dtype), x) + + with self.assertLogs(level="ERROR") as l: + try: + f(0.0).block_until_ready() + except RuntimeError: + pass + + api_name = ( + "pure_callback" if callback is jax.pure_callback else "io_callback" + ) + output = "\n".join(l.output) + self.assertIn(f"jax.{api_name} failed", output) + self.assertIn("Traceback (most recent call last)", output) + + @with_pure_and_io_callbacks + def test_compilation_caching(self, *, callback): + def f_outside(x): + return 2 * x + + def fun(x): + return callback(f_outside, x, x) + + x = np.arange(6, dtype=np.int32).reshape((2, 3)) + with jtu.count_primitive_compiles() as count: + for _ in range(3): + self.assertAllClose(2 * x, fun(x)) + self.assertEqual(count[0], 1) + + class PureCallbackTest(jtu.JaxTestCase): def setUp(self): @@ -505,10 +605,10 @@ def tearDown(self): super().tearDown() dispatch.runtime_tokens.clear() - def test_pure_callback_passes_ndarrays_without_jit(self): + def test_pure_callback_passes_jax_arrays_without_jit(self): def cb(x): - self.assertIs(type(x), np.ndarray) + self.assertIsInstance(x, jax.Array) return x def f(x): @@ -616,7 +716,6 @@ def h(x, y): out = h(jnp.arange(4.)[None], 4.) np.testing.assert_allclose(out, np.sin(np.arange(4.)[None]) + 4.) - def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self): def cb(x): @@ -693,6 +792,30 @@ def f(x): ValueError, "Pure callbacks do not support JVP."): f(2.) + def test_error_propagation(self): + def throws_error_fn(x): + raise RuntimeError("Errors should propagate.") + + @jax.jit + def f(x): + return jax.pure_callback(throws_error_fn, x, x) + + with self.assertRaisesRegex(Exception, "Errors should propagate."): + print(np.array(f(2.0)), flush=True) + + def test_reentrant_error_propagation(self): + reentrant_fn = jax.jit(jnp.sin).lower(2.0).compile() + + @jax.jit + def f(x): + return jax.pure_callback(reentrant_fn, x, x) + + try: + np.array(f(2.0)) + except: + # Only should not deadlock. + pass + def test_can_take_grad_of_pure_callback_with_custom_jvp(self): @jax.custom_jvp @@ -810,7 +933,6 @@ def f(self, ys): # callback alive. 
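For readers skimming these tests, a minimal usage sketch of jax.pure_callback as exercised above: the second argument describes the result's shape and dtype, and the wrapped Python function runs on the host, outside the traced computation (the function and values below are illustrative):

import jax
import jax.numpy as jnp
import numpy as np

def host_double(x):
  # Executed on the host; the input is convertible with np.asarray.
  return np.asarray(2.0 * np.asarray(x))

@jax.jit
def f(x):
  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(host_double, result_shape, x)

print(f(jnp.arange(4.0)))  # [0. 2. 4. 6.]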
np.testing.assert_allclose(out, np.full((num_devices, 4), 11, np.float32)) - def test_callback_inside_xmap(self): def _callback(x): @@ -918,6 +1040,18 @@ def f(x): out, inp + y ) + def test_does_not_deadlock(self): + if jtu.device_under_test() == "tpu": + self.skipTest("The test raises an exception on TPU") + + def f(x): + y = jnp.asarray(x) + 1 + return np.asarray(2 * jnp.log(y)) + + x = jnp.array([1.0, 2.0, 3.0, 4.0]) + out = jax.pure_callback(f, jax.ShapeDtypeStruct(x.shape, x.dtype), x) + np.testing.assert_allclose(out, 2 * jnp.log(x + 1)) + class IOCallbackTest(jtu.JaxTestCase): @@ -1032,8 +1166,8 @@ def body(i): io_callback(check, None, i) return i + 1 return lax.while_loop(cond, body, i) - with self.assertRaisesRegex(NotImplementedError, - "IO effect not supported in vmap-of-while."): + with self.assertRaisesRegex( + Exception, "not supported in while_loop with batched predicate"): jax.vmap(f)(jnp.array([0, 4])) def test_cannot_use_io_callback_in_checkpoint(self): @@ -1189,6 +1323,40 @@ def f(shard_ids, x): self.assertLen(shard, 2) np.testing.assert_array_equal(shard[0] + 1, shard[1]) + def test_batching_with_side_effects(self): + # https://github.com/google/jax/issues/20628#issuecomment-2050800195 + x_lst = [] + def append_x(x): + nonlocal x_lst + x_lst.append(x) + + @jax.jit + def f(x): + io_callback(append_x, None, x, ordered=False) + io_callback(append_x, None, 2 * x, ordered=False) + + jax.vmap(f)(jnp.arange(3.)) + jax.effects_barrier() + self.assertAllClose(x_lst, [0., 1., 2., 0., 2., 4.], check_dtypes=False) + + def test_batching_with_side_effects_while_loop(self): + # https://github.com/google/jax/issues/20628#issuecomment-2050921219 + x_lst = [] + def append_x(x): + nonlocal x_lst + x_lst.append(x) + + @jax.jit + def f(x): + def body(i): + io_callback(append_x, None, x, ordered=False) + io_callback(append_x, None, 2 * x, ordered=False) + return i + 1 + jax.lax.while_loop(lambda i: i < 2, body, 0) + + jax.vmap(f)(jnp.arange(3.)) # don't crash + jax.effects_barrier() + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py index 48cbc0b60b84..8d00b5eedaf4 100644 --- a/tests/pytorch_interoperability_test.py +++ b/tests/pytorch_interoperability_test.py @@ -88,8 +88,6 @@ def testJaxToTorch(self, shape, dtype): @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes) def testJaxArrayToTorch(self, shape, dtype): - if xla_bridge.using_pjrt_c_api(): - self.skipTest("DLPack support is incomplete in the PJRT C API") if not config.enable_x64.value and dtype in [ jnp.int64, jnp.float64, @@ -111,8 +109,6 @@ def testJaxArrayToTorch(self, shape, dtype): self.assertAllClose(np, y.cpu().numpy()) def testTorchToJaxInt64(self): - if xla_bridge.using_pjrt_c_api(): - self.skipTest("DLPack support is incomplete in the PJRT C API") # See https://github.com/google/jax/issues/11895 x = jax.dlpack.from_dlpack( torch.utils.dlpack.to_dlpack(torch.ones((2, 3), dtype=torch.int64))) @@ -121,8 +117,6 @@ def testTorchToJaxInt64(self): @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes) def testTorchToJax(self, shape, dtype): - if xla_bridge.using_pjrt_c_api(): - self.skipTest("DLPack support is incomplete in the PJRT C API") if not config.enable_x64.value and dtype in [ jnp.int64, jnp.float64, diff --git a/tests/qdwh_test.py b/tests/qdwh_test.py index 51179fee9b79..705c96f00f05 100644 --- a/tests/qdwh_test.py +++ b/tests/qdwh_test.py @@ -15,217 +15,192 @@ """Tests for the library of 
QDWH-based polar decomposition.""" import functools +from absl.testing import absltest import jax -import jax.numpy as jnp -import numpy as np -import scipy.linalg as osp_linalg from jax._src import config from jax._src import test_util as jtu from jax._src.lax import qdwh - -from absl.testing import absltest - +import jax.numpy as jnp +import numpy as np config.parse_flags_with_absl() -_JAX_ENABLE_X64_QDWH = config.enable_x64.value - -# Input matrix data type for QdwhTest. -_QDWH_TEST_DTYPE = np.float64 if _JAX_ENABLE_X64_QDWH else np.float32 - -# Machine epsilon used by QdwhTest. -_QDWH_TEST_EPS = jnp.finfo(_QDWH_TEST_DTYPE).eps - -# Largest log10 value of condition numbers used by QdwhTest. -_MAX_LOG_CONDITION_NUM = np.log10(int(1 / _QDWH_TEST_EPS)) - -def _check_symmetry(x: jax.Array) -> bool: - """Check if the array is symmetric.""" - m, n = x.shape - eps = jnp.finfo(x.dtype).eps - tol = 50.0 * eps - is_hermitian = False - if m == n: - if np.linalg.norm(x - x.T.conj()) / np.linalg.norm(x) < tol: - is_hermitian = True +float_types = jtu.dtypes.floating +complex_types = jtu.dtypes.complex - return is_hermitian -def _compute_relative_diff(actual, expected): +def _compute_relative_normwise_diff(actual, expected): """Computes relative difference between two matrices.""" return np.linalg.norm(actual - expected) / np.linalg.norm(expected) -_dot = functools.partial(jnp.dot, precision="highest") +_dot = functools.partial(jnp.dot, precision='highest') -class QdwhTest(jtu.JaxTestCase): - @jtu.sample_product( - [dict(m=m, n=n) for m, n in [(8, 6), (10, 10), (20, 18)]], - log_cond=np.linspace(1, _MAX_LOG_CONDITION_NUM, 4), - ) - def testQdwhUnconvergedAfterMaxNumberIterations( - self, m, n, log_cond): - """Tests unconvergence after maximum number of iterations.""" - a = jnp.triu(jnp.ones((m, n))) - u, s, v = jnp.linalg.svd(a, full_matrices=False) - cond = 10**log_cond - s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1)) - with jax.numpy_dtype_promotion('standard'): - a = (u * s) @ v - is_hermitian = _check_symmetry(a) - max_iterations = 2 +class QdwhTest(jtu.JaxTestCase): - _, _, actual_num_iterations, is_converged = qdwh.qdwh( - a, is_hermitian=is_hermitian, max_iterations=max_iterations) + def _testReconstruction(self, a, u, h, tol): + """Tests that a = u*p.""" + with self.subTest('Test reconstruction'): + diff = _compute_relative_normwise_diff(_dot(u, h), a) + self.assertLessEqual(diff, tol) - with self.subTest('Number of iterations.'): - self.assertEqual(max_iterations, actual_num_iterations) + def _testUnitary(self, u, tol): + """Tests that u is unitary.""" + with self.subTest('Test unitary'): + m, n = u.shape + self.assertAllClose( + _dot(u.conj().T, u), np.eye(n, dtype=u.dtype), atol=tol, rtol=tol + ) - with self.subTest('Converged.'): - self.assertFalse(is_converged) + def _testHermitian(self, h, tol): + """Tests that h is Hermitian.""" + with self.subTest('Test hermitian'): + self.assertAllClose(h, h.conj().T, atol=tol, rtol=tol) + + def _testPolarDecomposition(self, a, u, h, tol): + """Tests that u*h is the polar decomposition of a""" + self._testReconstruction(a, u, h, tol) + self._testUnitary(u, tol) + self._testHermitian(h, tol) + + def _testQdwh(self, a, dynamic_shape=None): + """Computes the polar decomposition and tests its basic properties.""" + eps = jnp.finfo(a.dtype).eps + u, h, iters, conv = qdwh.qdwh(a, dynamic_shape=dynamic_shape) + tol = 13 * eps + if dynamic_shape is not None: + m, n = dynamic_shape + a = a[:m, :n] + u = u[:m, :n] + h = h[:n, :n] + 
self._testPolarDecomposition(a, u, h, tol=tol)

   @jtu.sample_product(
-      [dict(m=m, n=n) for m, n in [(8, 6), (10, 10), (20, 18)]],
-      log_cond=np.linspace(1, _MAX_LOG_CONDITION_NUM, 4),
+      shape=[(8, 6), (10, 10), (20, 18)],
+      dtype=float_types + complex_types,
   )
-  def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond):
+  def testQdwhWithUpperTriangularInputAllOnes(self, shape, dtype):
     """Tests qdwh with upper triangular input of all ones."""
-    a = jnp.triu(jnp.ones((m, n))).astype(_QDWH_TEST_DTYPE)
-    u, s, v = jnp.linalg.svd(a, full_matrices=False)
-    cond = 10**log_cond
-    s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
-    a = (u * s) @ v
-    is_hermitian = _check_symmetry(a)
-    max_iterations = 10
-
-    actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=is_hermitian,
-                                         max_iterations=max_iterations)
-    expected_u, expected_h = osp_linalg.polar(a)
-
-    # Sets the test tolerance.
-    rtol = 1E6 * _QDWH_TEST_EPS
+    eps = jnp.finfo(dtype).eps
+    m, n = shape
+    a = jnp.triu(jnp.ones((m, n))).astype(dtype)
+    self._testQdwh(a)
 
-    with self.subTest('Test u.'):
-      relative_diff_u = _compute_relative_diff(actual_u, expected_u)
-      np.testing.assert_almost_equal(relative_diff_u, 1E-6, decimal=5)
-
-    with self.subTest('Test h.'):
-      relative_diff_h = _compute_relative_diff(actual_h, expected_h)
-      np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5)
-
-    with self.subTest('Test u.dot(h).'):
-      a_round_trip = _dot(actual_u, actual_h)
-      relative_diff_a = _compute_relative_diff(a_round_trip, a)
-      np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5)
-
-    with self.subTest('Test orthogonality.'):
-      actual_results = _dot(actual_u.T, actual_u)
-      expected_results = np.eye(n)
-      self.assertAllClose(
-          actual_results, expected_results, rtol=rtol, atol=1E-5)
+  @jtu.sample_product(
+      shape=[(2, 2), (5, 5), (8, 5), (10, 10)],
+      dtype=float_types + complex_types,
+  )
+  def testQdwhWithDynamicShape(self, shape, dtype):
+    """Tests qdwh with dynamic shapes."""
+    rng = jtu.rand_uniform(self.rng())
+    a = rng((10, 10), dtype)
+    self._testQdwh(a, dynamic_shape=shape)
 
   @jtu.sample_product(
-      [dict(m=m, n=n) for m, n in [(6, 6), (8, 4)]],
-      padding=(None, (3, 2)),
-      log_cond=np.linspace(1, 4, 4),
+      shape=[(8, 6), (10, 10), (20, 18), (300, 300)],
+      log_cond=np.linspace(0, 1, 4),
+      dtype=float_types + complex_types,
   )
-  def testQdwhWithRandomMatrix(self, m, n, log_cond, padding):
-    """Tests qdwh with random input."""
-    rng = jtu.rand_uniform(self.rng(), low=0.3, high=0.9)
-    a = rng((m, n), _QDWH_TEST_DTYPE)
-    u, s, v = jnp.linalg.svd(a, full_matrices=False)
+  def testQdwhWithRandomMatrix(self, shape, log_cond, dtype):
+    """Tests qdwh with a random matrix of prescribed condition number."""
+    eps = jnp.finfo(dtype).eps
+    m, n = shape
+    max_cond = np.log10(1.0 / eps)
+    log_cond = log_cond * max_cond
     cond = 10**log_cond
+
+    # Generates input matrix with prescribed condition number.
+    rng = jtu.rand_uniform(self.rng())
+    a = rng((m, n), dtype)
+    u, _, v = jnp.linalg.svd(a, full_matrices=False)
     s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
-    a = (u * s) @ v
-    is_hermitian = _check_symmetry(a)
-    max_iterations = 10
+    a = (u * s.astype(u.dtype)) @ v
+    self._testQdwh(a)
+
+  @jtu.sample_product(
+      [dict(m=m, n=n) for m, n in [(6, 6), (8, 4)]],
+      padding=(None, (3, 2)),
+      dtype=float_types + complex_types,
+  )
+  def testQdwhJitCompatibility(self, m, n, padding, dtype):
+    """Tests JIT compilation of QDWH with and without dynamic shape."""
+    rng = jtu.rand_uniform(self.rng())
+    a = rng((m, n), dtype)
 
     def lsp_linalg_fn(a):
       if padding is not None:
         pm, pn = padding
         a = jnp.pad(a, [(0, pm), (0, pn)], constant_values=jnp.nan)
-      u, h, _, _ = qdwh.qdwh(
-          a, is_hermitian=is_hermitian, max_iterations=max_iterations,
-          dynamic_shape=(m, n) if padding else None)
+      u, h, _, _ = qdwh.qdwh(a, dynamic_shape=(m, n) if padding else None)
       if padding is not None:
         u = u[:m, :n]
         h = h[:n, :n]
       return u, h
 
     args_maker = lambda: [a]
-
-    # Sets the test tolerance.
-    rtol = 1E6 * _QDWH_TEST_EPS
-
     with self.subTest('Test JIT compatibility'):
       self._CompileAndCheck(lsp_linalg_fn, args_maker)
 
-    with self.subTest('Test against numpy.'):
-      self._CheckAgainstNumpy(osp_linalg.polar, lsp_linalg_fn, args_maker,
-                              rtol=rtol, atol=1E-3)
-
   @jtu.sample_product(
-      [dict(m=m, n=n) for m, n in [(10, 10), (8, 8)]],
-      log_cond=np.linspace(1, 4, 4),
+      [dict(m=m, n=n, r=r) for m, n, r in [(10, 10, 8), (8, 8, 7), (12, 8, 5)]],
+      log_cond=np.linspace(0, 1, 4),
+      dtype=float_types + complex_types,
   )
-  def testQdwhWithOnRankDeficientInput(self, m, n, log_cond):
-    """Tests qdwh with rank-deficient input."""
-    a = np.triu(np.ones((m, n))).astype(_QDWH_TEST_DTYPE)
-
-    # Generates a rank-deficient input.
-    u, s, v = np.linalg.svd(a, full_matrices=False)
-    cond = 10**log_cond
-    s = jnp.linspace(cond, 1, min(m, n))
-    s = jnp.expand_dims(s.at[-1].set(0), range(u.ndim - 1))
-    a = (u * s) @ v
-
-    is_hermitian = _check_symmetry(a)
-    max_iterations = 15
-    actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=is_hermitian,
-                                         max_iterations=max_iterations)
-    _, expected_h = osp_linalg.polar(a)
-
-    # For rank-deficient matrix, `u` is not unique.
-    with self.subTest('Test h.'):
-      relative_diff_h = _compute_relative_diff(actual_h, expected_h)
-      np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5)
-
-    with self.subTest('Test u.dot(h).'):
-      a_round_trip = _dot(actual_u, actual_h)
-      relative_diff_a = _compute_relative_diff(a_round_trip, a)
-      np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5)
-
+  def testQdwhOnRankDeficientInput(self, m, n, r, log_cond, dtype):
+    """Tests qdwh on rank-deficient input."""
+    eps = jnp.finfo(dtype).eps
+    a = np.triu(np.ones((m, n))).astype(dtype)
+
+    # Generates a rank-deficient input with prescribed condition number.
+    max_cond = np.log10(1.0 / eps)
+    log_cond = log_cond * max_cond
+    u, _, vh = np.linalg.svd(a, full_matrices=False)
+    s = 10**jnp.linspace(log_cond, 0, min(m, n))
+    s = jnp.expand_dims(s.at[r:].set(0), range(u.ndim - 1))
+    a = (u * s.astype(u.dtype)) @ vh
+
+    actual_u, actual_h, _, _ = qdwh.qdwh(a)
+
+    self._testHermitian(actual_h, 10 * eps)
+    self._testReconstruction(a, actual_u, actual_h, 60 * eps)
+
+    # QDWH gives U_p = U Σₖ V* for input A with SVD A = U Σ V*. For full rank
+    # input, we expect convergence Σₖ → I, giving the correct polar factor
+    # U_p = U V*.
Zero singular values stay at 0 in exact arithmetic, but can + # end up anywhere in [0, 1] as a result of rounding errors---in particular, + # we do not generally expect convergence to 1. As a result, we can only + # expect (U_p V_r) to be orthogonal, where V_r are the columns of V + # corresponding to nonzero singular values. with self.subTest('Test orthogonality.'): - actual_results = _dot(actual_u.T.conj(), actual_u) - expected_results = np.eye(n) + vr = vh.conj().T[:, :r] + uvr = _dot(actual_u, vr) + actual_results = _dot(uvr.T.conj(), uvr) + expected_results = np.eye(r, dtype=actual_u.dtype) self.assertAllClose( - actual_results, expected_results, rtol=_QDWH_TEST_EPS, atol=1e-6 + actual_results, expected_results, atol=25 * eps, rtol=25 * eps ) @jtu.sample_product( - [dict(m=m, n=n, r=r, c=c) for m, n, r, c in [(4, 3, 1, 1), (5, 2, 0, 0)]], - dtype=jtu.dtypes.floating, + [dict(m=m, n=n, r=r, c=c) for m, n, r, c in [(4, 3, 1, 1), (5, 2, 0, 0)]], + dtype=float_types + complex_types, ) def testQdwhWithTinyElement(self, m, n, r, c, dtype): """Tests qdwh on matrix with zeros and close-to-zero entries.""" a = jnp.zeros((m, n), dtype=dtype) - tiny_elem = jnp.finfo(a.dtype).tiny + one = dtype(1.0) + tiny_elem = dtype(jnp.finfo(a.dtype).tiny) a = a.at[r, c].set(tiny_elem) - is_hermitian = _check_symmetry(a) - max_iterations = 10 - @jax.jit def lsp_linalg_fn(a): - u, h, _, _ = qdwh.qdwh( - a, is_hermitian=is_hermitian, max_iterations=max_iterations) + u, h, _, _ = qdwh.qdwh(a) return u, h actual_u, actual_h = lsp_linalg_fn(a) expected_u = jnp.zeros((m, n), dtype=dtype) - expected_u = expected_u.at[r, c].set(1.0) + expected_u = expected_u.at[r, c].set(one) with self.subTest('Test u.'): np.testing.assert_array_equal(expected_u, actual_u) diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 52e8cbc7b262..f69687ddc6cd 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -65,6 +65,17 @@ def _CheckKolmogorovSmirnovCDF(self, samples, cdf, pval=None): # whether RBG keys may be involved, but that's no longer exact. if config.enable_custom_prng.value and samples.dtype == jnp.bfloat16: return + # kstest does not understand bfloat16 input, so cast to float32. + if samples.dtype == jnp.bfloat16: + samples = samples.astype('float32') + # kstest fails for infinities starting in scipy 1.12 + # (https://github.com/scipy/scipy/issues/20386) + # TODO(jakevdp): remove this logic if/when fixed upstream. 
+ scipy_version = jtu.parse_version(scipy.__version__) + if scipy_version >= (1, 12) and np.issubdtype(samples.dtype, np.floating): + samples = np.array(samples, copy=True) + samples[np.isposinf(samples)] = 0.01 * np.finfo(samples.dtype).max + samples[np.isneginf(samples)] = 0.01 * np.finfo(samples.dtype).min self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob) def _CheckChiSquared(self, samples, pmf, *, pval=None): @@ -202,22 +213,6 @@ def testTruncatedNormal(self, dtype): for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.truncnorm(-0.3, 0.3).cdf) - @jtu.sample_product(dtype=jtu.dtypes.floating + jtu.dtypes.integer) - def testShuffle(self, dtype): - key = lambda: self.make_key(0) - x = np.arange(100).astype(dtype) - rand = lambda key: random.shuffle(key, x) - crand = jax.jit(rand) - - with self.assertWarns((DeprecationWarning, FutureWarning)): - perm1 = rand(key()) - with self.assertWarns((DeprecationWarning, FutureWarning)): - perm2 = crand(key()) - - self.assertAllClose(perm1, perm2) - self.assertFalse(np.all(perm1 == x)) # seems unlikely! - self.assertAllClose(np.sort(perm1), x, check_dtypes=False) - @jtu.sample_product( [dict(shape=shape, replace=replace, axis=axis, input_range_or_shape=input_range_or_shape) @@ -1083,7 +1078,7 @@ def f(): # TODO(jakevdp): key reuse checks for this OOM because of slice masking. # Can we fix this? - with jax.enable_key_reuse_checks(False): + with jax.debug_key_reuse(False): # just lower, don't run, takes too long jax.jit(f).lower() @@ -1248,29 +1243,27 @@ def testBinomialCornerCases(self): self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False) self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False) - def test_batched_key_warnings(self): + def test_batched_key_errors(self): keys = lambda: jax.random.split(self.make_key(0)) msg = "{} accepts a single key, but was given a key array of shape.*" - # Check a handful of functions that are expected to warn. - with self.assertWarnsRegex(FutureWarning, msg.format('bits')): + # Check a handful of functions that are expected to error. + with self.assertRaisesRegex(ValueError, msg.format('bits')): jax.random.bits(keys(), shape=(2,)) - with self.assertWarnsRegex(FutureWarning, msg.format('chisquare')): + with self.assertRaisesRegex(ValueError, msg.format('chisquare')): jax.random.chisquare(keys(), 1.0, shape=(2,)) - with self.assertWarnsRegex(FutureWarning, msg.format('dirichlet')): + with self.assertRaisesRegex(ValueError, msg.format('dirichlet')): jax.random.dirichlet(keys(), jnp.arange(2.0), shape=(2,)) - with self.assertWarnsRegex(FutureWarning, msg.format('gamma')): + with self.assertRaisesRegex(ValueError, msg.format('gamma')): jax.random.gamma(keys(), 1.0, shape=(2,)) - with self.assertWarnsRegex(FutureWarning, msg.format('loggamma')): + with self.assertRaisesRegex(ValueError, msg.format('loggamma')): jax.random.loggamma(keys(), 1.0, shape=(2,)) - - # Other functions should error; test a few cases. 
with self.assertRaisesRegex(ValueError, msg.format('fold_in')): jax.random.fold_in(keys(), 0) with self.assertRaisesRegex(ValueError, msg.format('split')): jax.random.split(keys()) - # Some shouldn't error or warn + # Shouldn't error or warn: with self.assertNoWarnings(): jax.random.key_data(keys()) jax.random.key_impl(keys()) @@ -1348,6 +1341,7 @@ def test_vmap_fold_in_shape(self): out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys(), msgs.T) self.assertEqual(out.shape, (3, 2)) + @jax.debug_key_reuse(False) def test_vmap_split_mapped_key(self): key = self.make_key(73) mapped_keys = random.split(key, num=3) @@ -1382,7 +1376,7 @@ def test_split_shape(self): keys = random.split(key, 10) self.assertEqual(keys.shape, (10, *key.shape)) - @jax.enable_key_reuse_checks(False) + @jax.debug_key_reuse(False) def test_vmap_fold_in_shape(self): # broadcast with scalar keys = random.split(self.make_key(73), 2) @@ -1398,7 +1392,7 @@ def test_vmap_fold_in_shape(self): out = vmap(random.fold_in, in_axes=(0, None))(keys, msgs[0]) self.assertEqual(out.shape, keys.shape) - @jax.enable_key_reuse_checks(False) + @jax.debug_key_reuse(False) def test_vmap_split_not_mapped_key(self): key = self.make_key(73) single_split_key = random.split(key) @@ -1408,24 +1402,58 @@ def test_vmap_split_not_mapped_key(self): self.assertArraysEqual(random.key_data(vk), random.key_data(single_split_key)) - def test_vmap_split_mapped_key(self): + @jax.debug_key_reuse(False) + def test_vmap_split_mapped_key_shape(self): key = self.make_key(73) mapped_keys = random.split(key, num=3) - forloop_keys = [random.split(k) for k in mapped_keys] vmapped_keys = vmap(random.split)(mapped_keys) self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape)) - for fk, vk in zip(forloop_keys, vmapped_keys): - self.assertArraysEqual(random.key_data(fk), + + @jax.debug_key_reuse(False) + def test_vmap_split_mapped_key_values(self): + key = self.make_key(73) + mapped_keys = random.split(key, num=3) + vmapped_keys = vmap(random.split)(mapped_keys) + ref_keys = [random.split(k) for k in mapped_keys] + for rk, vk in zip(ref_keys, vmapped_keys): + self.assertArraysEqual(random.key_data(rk), random.key_data(vk)) - def test_vmap_random_bits(self): - rand_fun = lambda key: random.randint(key, (), 0, 100) + @jax.debug_key_reuse(False) + def test_vmap_random_bits_shape(self): + rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100) key = self.make_key(73) mapped_keys = random.split(key, num=3) - forloop_rand_nums = [rand_fun(k) for k in mapped_keys] rand_nums = vmap(rand_fun)(mapped_keys) self.assertEqual(rand_nums.shape, (3,)) - self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums)) + + @jtu.skip_on_devices("tpu") + @jax.debug_key_reuse(False) + def test_vmap_random_bits_value(self): + rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100) + key = self.make_key(73) + mapped_keys = random.split(key, num=3) + rand_nums = vmap(rand_fun)(mapped_keys) + ref_nums = rand_fun(mapped_keys[0], shape=(3,)) + self.assertArraysEqual(rand_nums, ref_nums) + + def test_vmap_random_bits_distribution(self): + dtype = jnp.float32 + keys = lambda: jax.random.split(self.make_key(0), 10) + + def rand(key): + nums = jax.vmap(lambda key: random.uniform(key, (1000,), dtype))(key) + return nums.flatten() + + crand = jax.jit(rand) + + uncompiled_samples = rand(keys()) + compiled_samples = crand(keys()) + + for samples in [uncompiled_samples, compiled_samples]: + self._CheckCollisions(samples, jnp.finfo(dtype).nmant) + 
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf, + pval=0.005) def test_cannot_add(self): key = self.make_key(73) @@ -1455,6 +1483,15 @@ class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest): def make_key(self, seed): return random.PRNGKey(seed, impl="unsafe_rbg") + @jtu.skip_on_devices("tpu") + @jax.debug_key_reuse(False) + def test_vmap_split_mapped_key_values(self): + key = self.make_key(73) + mapped_keys = random.split(key, num=3) + vmapped_keys = vmap(random.split)(mapped_keys) + ref_keys = random.split(mapped_keys[0], (3, 2)) + self.assertArraysEqual(random.key_data(vmapped_keys), + random.key_data(ref_keys)) def _sampler_unimplemented_with_custom_prng(*args, **kwargs): raise SkipTest('sampler only implemented for default RNG') diff --git a/tests/random_test.py b/tests/random_test.py index d8212106e211..2c45d60cc64d 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -33,7 +33,6 @@ from jax import random from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax import vmap @@ -46,6 +45,9 @@ PRNG_IMPLS = list(prng_internal.prngs.items()) +# Remove Pallas keys from this test, which do not run in XLA. +PRNG_IMPLS = [ + (name, impl) for (name, impl) in PRNG_IMPLS if "pallas" not in name] class OnX64(enum.Enum): @@ -386,6 +388,22 @@ def testPRNGValues(self, make_key): random.key_data(random.fold_in(make_key(seed), 4)), np.array([2285895361, 433833334], dtype='uint32')) + @jtu.run_on_devices("gpu") + def test_threefry_gpu_kernel_lowering(self): + f = lambda key: jax.random.uniform(key, (1,)) + with jax._src.config.threefry_gpu_kernel_lowering(False): + hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text() + if jtu.is_device_rocm(): + self.assertNotIn("hip_threefry2x32", hlo_text) + else: + self.assertNotIn("cu_threefry2x32", hlo_text) + with jax._src.config.threefry_gpu_kernel_lowering(True): + hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text() + if jtu.is_device_rocm(): + self.assertIn("hip_threefry2x32", hlo_text) + else: + self.assertIn("cu_threefry2x32", hlo_text) + @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def test_random_seed_offset(self, make_key): k1 = make_key(17) @@ -588,6 +606,14 @@ def test_construction(self): key = random.key(42) self.assertIsInstance(key, prng_internal.PRNGKeyArray) + def test_random_clone(self): + # Here we test value semantics and compatibility with jit/vmap + # key reuse semantics are tested in key_reuse_test.py + keys = jax.random.split(jax.random.key(0), 5) + self.assertKeysEqual(keys, jax.random.clone(keys)) + self.assertKeysEqual(keys, jax.jit(jax.random.clone)(keys)) + self.assertKeysEqual(keys, jax.vmap(jax.random.clone)(keys)) + def test_issubdtype(self): key = random.key(42) @@ -669,7 +695,7 @@ def test_key_copy(self): self.assertKeysEqual(key, jax.jit(lambda k: k.copy())(key)) # TODO(jakevdp) remove this decorator when reuse checks move to C++ - @jax.enable_key_reuse_checks(False) + @jax.debug_key_reuse(False) def test_cpp_dispatch_normal(self): # Ensure we stay on the C++ dispatch path when calling a jitted # function with a key array as an argument. 
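The test_random_clone test above checks value semantics only: a cloned key carries the same key data as the original, so both produce the same random stream, while key-reuse semantics are left to key_reuse_test.py as noted in the test. A short sketch of that behavior, using jax.random.clone as the test does:

import jax

key = jax.random.key(0)
key_copy = jax.random.clone(key)

# Identical key data, hence identical draws from either key.
assert (jax.random.key_data(key) == jax.random.key_data(key_copy)).all()
assert jax.random.bits(key) == jax.random.bits(key_copy)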
@@ -686,7 +712,7 @@ def f(key): self.assertEqual(count[0], 1) # TODO(jakevdp) remove this decorator when reuse checks move to C++ - @jax.enable_key_reuse_checks(False) + @jax.debug_key_reuse(False) def test_cpp_dispatch_split(self): # Ensure we stay on the C++ dispatch path when calling a jitted # function with a key arrays as inputs and as outputs. @@ -1011,12 +1037,11 @@ def test_array_impl_attributes(self): self.assertEqual(key.is_fully_addressable, key._base_array.is_fully_addressable) self.assertEqual(key.is_fully_replicated, key._base_array.is_fully_replicated) - if not deprecations.is_accelerated('jax._src.array', 'device-method'): - with jtu.ignore_warning(category=DeprecationWarning, message="arr.device"): - self.assertEqual(key.device(), key._base_array.device()) self.assertEqual(key.devices(), key._base_array.devices()) - self.assertEqual(key.on_device_size_in_bytes, key._base_array.on_device_size_in_bytes) - self.assertEqual(key.unsafe_buffer_pointer, key._base_array.unsafe_buffer_pointer) + self.assertEqual(key.on_device_size_in_bytes(), + key._base_array.on_device_size_in_bytes()) + self.assertEqual(key.unsafe_buffer_pointer(), + key._base_array.unsafe_buffer_pointer()) self.assertArraysEqual(key.addressable_data(0)._base_array, key._base_array.addressable_data(0)) self.assertLen(key.addressable_shards, len(key._base_array.addressable_shards)) @@ -1205,9 +1230,9 @@ def test_reshape(self): key = random.key(123) keys = random.split(key, 4) - newshape = (2, 2) - key_func = partial(jnp.reshape, newshape=newshape) - arr_func = partial(jnp.reshape, newshape=(*newshape, *key._impl.key_shape)) + shape = (2, 2) + key_func = partial(jnp.reshape, shape=shape) + arr_func = partial(jnp.reshape, shape=(*shape, *key._impl.key_shape)) self.check_shape(key_func, keys) self.check_against_reference(key_func, arr_func, keys) @@ -1267,7 +1292,7 @@ def test_append(self): self.check_shape(key_func, keys(), key()) self.check_shape(arr_func, keys(), key()) - with jax.enable_key_reuse_checks(False): + with jax.debug_key_reuse(False): self.check_against_reference(key_func, arr_func, keys(), key()) def test_ravel(self): @@ -1275,7 +1300,7 @@ def test_ravel(self): keys = random.split(key, 4).reshape(2, 2) key_func = jnp.ravel - arr_func = partial(jnp.reshape, newshape=(4, *key._impl.key_shape)) + arr_func = partial(jnp.reshape, shape=(4, *key._impl.key_shape)) self.check_shape(key_func, keys) self.check_against_reference(key_func, arr_func, keys) @@ -1312,7 +1337,7 @@ def test_getitem(self, idx): key_func = arr_func = lambda x: x[idx] self.check_shape(key_func, keys()) - with jax.enable_key_reuse_checks(False): + with jax.debug_key_reuse(False): self.check_against_reference(key_func, arr_func, keys()) @parameterized.parameters([ @@ -1326,10 +1351,10 @@ def test_gather(self, idx): key_func = arr_func = lambda key: key.at[idx].get() self.check_shape(key_func, keys()) - with jax.enable_key_reuse_checks(False): + with jax.debug_key_reuse(False): self.check_against_reference(key_func, arr_func, keys()) - @jax.enable_key_reuse_checks(False) + @jax.debug_key_reuse(False) def test_equality(self): key = random.key(123) key2 = random.key(456) diff --git a/tests/scipy_fft_test.py b/tests/scipy_fft_test.py index 77ee057d29b4..17c1e9c2d1d0 100644 --- a/tests/scipy_fft_test.py +++ b/tests/scipy_fft_test.py @@ -15,13 +15,12 @@ from absl.testing import absltest +import jax from jax._src import test_util as jtu import jax.scipy.fft as jsp_fft import scipy.fft as osp_fft -from jax import config - 
-config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() float_dtypes = jtu.dtypes.floating real_dtypes = float_dtypes + jtu.dtypes.integer + jtu.dtypes.boolean diff --git a/tests/scipy_interpolate_test.py b/tests/scipy_interpolate_test.py index ee905b7f0112..1fead634ab7b 100644 --- a/tests/scipy_interpolate_test.py +++ b/tests/scipy_interpolate_test.py @@ -18,13 +18,13 @@ from functools import reduce import numpy as np +import jax from jax._src import test_util as jtu import scipy.interpolate as sp_interp import jax.scipy.interpolate as jsp_interp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class LaxBackedScipyInterpolateTests(jtu.JaxTestCase): diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py index 7ce0df8736cd..701b7c570937 100644 --- a/tests/scipy_ndimage_test.py +++ b/tests/scipy_ndimage_test.py @@ -21,13 +21,13 @@ from absl.testing import absltest import scipy.ndimage as osp_ndimage +import jax from jax import grad from jax._src import test_util as jtu from jax import dtypes from jax.scipy import ndimage as lsp_ndimage -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() float_dtypes = jtu.dtypes.floating @@ -112,10 +112,6 @@ def testMapCoordinatesErrors(self): with self.assertRaisesRegex(ValueError, 'sequence of length'): lsp_ndimage.map_coordinates(x, [c, c], order=1) - def testMapCoordinateDocstring(self): - self.assertIn("Only nearest neighbor", - lsp_ndimage.map_coordinates.__doc__) - @jtu.sample_product( dtype=float_dtypes + int_dtypes, order=[0, 1], diff --git a/tests/scipy_optimize_test.py b/tests/scipy_optimize_test.py index e07455e06f81..70a00e14c468 100644 --- a/tests/scipy_optimize_test.py +++ b/tests/scipy_optimize_test.py @@ -17,13 +17,13 @@ import scipy import scipy.optimize +import jax from jax import numpy as jnp from jax._src import test_util as jtu from jax import jit -from jax import config import jax.scipy.optimize -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def rosenbrock(np): diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index 70a367a04e74..11923257a9dd 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -21,14 +21,14 @@ import numpy as np import scipy.signal as osp_signal +import jax from jax import lax import jax.numpy as jnp from jax._src import dtypes from jax._src import test_util as jtu import jax.scipy.signal as jsp_signal -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() onedim_shapes = [(1,), (2,), (5,), (10,)] twodim_shapes = [(1, 1), (2, 2), (2, 3), (3, 4), (4, 4)] diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py index f51ad49adc22..5acbdc0ddb6b 100644 --- a/tests/scipy_spatial_test.py +++ b/tests/scipy_spatial_test.py @@ -25,9 +25,8 @@ import jax.numpy as jnp import numpy as onp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() scipy_version = jtu.parse_version(scipy.version.version) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 7679b1a96810..501f4cbe5e5f 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -23,12 +23,12 @@ import scipy.version import jax +import jax.numpy as jnp from jax._src import dtypes, test_util as jtu from jax.scipy import stats as lsp_stats from jax.scipy.special import expit -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() scipy_version = 
jtu.parse_version(scipy.version.version) @@ -130,7 +130,6 @@ def testPoissonLogPmf(self, shapes, dtypes): def args_maker(): k, mu, loc = map(rng, shapes, dtypes) - k = np.floor(k) # clipping to ensure that rate parameter is strictly positive mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype) loc = np.floor(loc) @@ -149,7 +148,6 @@ def testPoissonPmf(self, shapes, dtypes): def args_maker(): k, mu, loc = map(rng, shapes, dtypes) - k = np.floor(k) # clipping to ensure that rate parameter is strictly positive mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype) loc = np.floor(loc) @@ -253,7 +251,7 @@ def args_maker(): @genNamedParametersNArgs(5) def testBetaLogPdf(self, shapes, dtypes): - rng = jtu.rand_positive(self.rng()) + rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.beta.logpdf lax_fun = lsp_stats.beta.logpdf @@ -323,6 +321,23 @@ def testBetaLogPdfZero(self): osp_stats.beta.pdf(x, a, b), lsp_stats.beta.pdf(x, a, b), atol=1e-5, rtol=2e-5) + def testBetaLogPdfNegativeConstants(self): + a = b = -1.1 + x = jnp.array([0., 0.5, 1.]) + self.assertAllClose( + osp_stats.beta.pdf(x, a, b), lsp_stats.beta.pdf(x, a, b), atol=1e-5, + rtol=2e-5) + + def testBetaLogPdfNegativeScale(self): + a = b = 1. + x = jnp.array([0., 0.5, 1.]) + loc = 0 + scale = -1 + self.assertAllClose( + osp_stats.beta.pdf(x, a, b, loc, scale), + lsp_stats.beta.pdf(x, a, b, loc, scale), atol=1e-5, + rtol=2e-5) + @genNamedParametersNArgs(3) def testCauchyLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) @@ -1476,14 +1491,14 @@ def resample(key, dataset, weights, *, shape): ndim = shape[0] if len(shape) > 1 else 1 func = partial(resample, shape=()) - with jax.enable_key_reuse_checks(False): + with jax.debug_key_reuse(False): self._CompileAndCheck( func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) result = func(*args_maker()) assert result.shape == (ndim,) func = partial(resample, shape=(4,)) - with jax.enable_key_reuse_checks(False): + with jax.debug_key_reuse(False): self._CompileAndCheck( func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) result = func(*args_maker()) @@ -1618,21 +1633,25 @@ def testRankData(self, shape, dtype, axis, method): self._CompileAndCheck(lax_fun, args_maker, rtol=tol) @jtu.sample_product( - [dict(shape=shape, axis=axis, ddof=ddof, nan_policy=nan_policy) + [dict(shape=shape, axis=axis, ddof=ddof, nan_policy=nan_policy, keepdims=keepdims) for shape in [(5,), (5, 6), (5, 6, 7)] for axis in [None, *range(len(shape))] for ddof in [0, 1, 2, 3] for nan_policy in ["propagate", "omit"] + for keepdims in [True, False] ], dtype=jtu.dtypes.integer + jtu.dtypes.floating, ) - def testSEM(self, shape, dtype, axis, ddof, nan_policy): + def testSEM(self, shape, dtype, axis, ddof, nan_policy, keepdims): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy) - lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy) + kwds = {} if scipy_version < (1, 11) else {'keepdims': keepdims} + scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy, + **kwds) + lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy, + **kwds) tol_spec = {np.float32: 2e-4, np.float64: 5e-6} tol = jtu.tolerance(dtype, tol_spec) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 930263b891fc..f079d6753edd 100644 --- 
a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -16,13 +16,13 @@ from __future__ import annotations import enum -from collections.abc import Sequence +from collections.abc import Callable, Sequence import cProfile import itertools import math import os from pstats import Stats -from typing import Any, Callable +from typing import Any import unittest from absl import logging @@ -35,9 +35,7 @@ import re import jax -from jax.experimental import export -from jax.experimental.export import _shape_poly as shape_poly -from jax.experimental.export import _shape_poly_decision as shape_poly_decision +from jax import export from jax.experimental import pjit from jax import lax import jax.numpy as jnp @@ -46,6 +44,8 @@ from jax._src import config from jax._src import core from jax._src import test_util as jtu +from jax._src.export import shape_poly +from jax._src.export import shape_poly_decision from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow from jax._src.lib import xla_client @@ -1267,11 +1267,11 @@ def log_message(extra: str): tst.assertEqual(getattr(jax.config, fname), fvalue, ( f"Flag {fname} current value {getattr(jax.config, fname)} != {fvalue}")) - f_jax = self.dyn_fun + f_jax = jax.jit(self.dyn_fun) args = self.dyn_args_maker(tst.rng()) args = jax.tree.map(jnp.array, args) args_specs = export.symbolic_args_specs(args, self.polymorphic_shapes, - symbolic_constraints=self.symbolic_constraints) + constraints=self.symbolic_constraints) if self.expect_error is not None: with tst.assertRaisesRegex(self.expect_error[0], self.expect_error[1]): @@ -1283,7 +1283,7 @@ def log_message(extra: str): return None # Run the JAX natively and then the exported function and compare res_jax_native = f_jax(*args) - res_jax_exported = export.call_exported(exp)(*args) + res_jax_exported = exp.call(*args) custom_assert_lims = [ l for l in self.limitations if l.custom_assert is not None] assert len(custom_assert_lims) <= 1, custom_assert_lims @@ -1315,7 +1315,7 @@ def check_shape_poly(tst, f_jax: Callable, *, symbolic_constraints: Sequence[str] = (), expect_error=None) -> jax.Array | None: # Builds a PolyHarness and runs the test. See PolyHarness documentation. 
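The shape_poly_test changes in this file follow the move from jax.experimental.export to the jax.export API: export.export now wraps a jitted function, symbolic_args_specs takes constraints=..., and the exported artifact is invoked through exp.call rather than export.call_exported. A condensed sketch of the new-style calls as they appear in this diff (f below is a stand-in function, not taken from the tests):

import jax
from jax import export
import jax.numpy as jnp
import numpy as np

def f(x):
  return jnp.sum(x ** 2)

# Export with a symbolic (shape-polymorphic) leading dimension "b".
spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), np.float32)
exp = export.export(jax.jit(f))(spec)

# Call the exported artifact directly; the argument is checked against "b, 4".
x = np.ones((3, 4), np.float32)
np.testing.assert_allclose(exp.call(x), f(x))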
- h = PolyHarness("", "", f_jax, + h = PolyHarness("", "", jax.jit(f_jax), arg_descriptors=arg_descriptors, polymorphic_shapes=polymorphic_shapes, symbolic_constraints=symbolic_constraints, @@ -1408,11 +1408,10 @@ def test_kwargs(self): def f_jax(x, *, y): return x + jnp.sin(y) - f_exported = export.call_exported( - export.export(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), - x.dtype), - y=jax.ShapeDtypeStruct(y.shape, y.dtype))) - self.assertAllClose(f_jax(x, y=y), f_exported(x, y=y)) + exp = export.export(jax.jit(f_jax))( + jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype), + y=jax.ShapeDtypeStruct(y.shape, y.dtype)) + self.assertAllClose(f_jax(x, y=y), exp.call(x, y=y)) def test_arg_avals_errors(self): """Test error reporting for shape polymorphism.""" @@ -1617,8 +1616,8 @@ def f(x): # x: i32[a, b] acc += jnp.sum(slice, axis=0) return acc - _ = export.export(f)(jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), - np.int32)) + _ = export.export(jax.jit(f))( + jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), np.int32)) def test_constraints_compile_time_check(self): @@ -1630,29 +1629,30 @@ def f(x): # x: i32[a] x_spec = jax.ShapeDtypeStruct( export.symbolic_shape("a", constraints=["a >= 2", "a <= 4"]), np.int32) - exp = export.export(f)(x_spec) + exp = export.export(jax.jit(f))(x_spec) x_2 = np.arange(2, dtype=np.int32) - res_2 = export.call_exported(exp)(x_2) + res_2 = exp.call(x_2) self.assertAllClose(x_2[0:2], res_2) x_4 = np.arange(4, dtype=np.int32) - res_4 = export.call_exported(exp)(x_4) + res_4 = exp.call(x_4) self.assertAllClose(x_4[1:3], res_4) with self.assertRaisesRegex( ValueError, re.escape("Expected 'a - 2' to be greater or equal to 0, but found -1")): - export.call_exported(exp)(np.arange(1, dtype=np.int32)) + exp.call(np.arange(1, dtype=np.int32)) with self.assertRaisesRegex( ValueError, re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")): - export.call_exported(exp)(np.arange(5, dtype=np.int32)) + exp.call(np.arange(5, dtype=np.int32)) def test_caching_with_scopes(self): f_tracing_count = 0 expected_a_bounds = (1, np.inf) + @jax.jit def f(x): # x: i32[a] nonlocal f_tracing_count f_tracing_count += 1 @@ -1757,7 +1757,7 @@ def f(x): polymorphic_shapes=["(b,)"]) self.assertAllClose(f(x), res_tf) - @jax.enable_key_reuse_checks(False) + @jax.debug_key_reuse(False) def test_prng(self): # The PRNG implementation uses opaque types, test shape polymorphism with config.enable_custom_prng(True): @@ -1997,6 +1997,30 @@ def test_vmap_error(self): + jnp.sin(x))), arg_descriptors=[RandArg((3, 4), _f32)], polymorphic_shapes=["b, ..."]), + [ # approx_max_k + # x: f32[b, {n}, 32] with n being either 8 or the symbol "n" + # we reduce on dim=1, with size n + # k is either the constant 4 or the symbol "k" + PolyHarness("approx_max_k", f"n_{n}_k_{k}_agg={agg}", + lambda x, x_k, agg: lax.approx_max_k( + x, k=x_k.shape[0], reduction_dimension=1, + aggregate_to_topk=agg), + arg_descriptors=[RandArg((3, 8, 32), _f32), + RandArg((4,), _f32), + StaticArg(agg)], + polymorphic_shapes=[f"b, {n}, 32", f"{k},"], + # k must be at most the reduction dimension size + symbolic_constraints=[f"{k} <= {n}"], + expect_error=( + (NotImplementedError, "aggregate_to_topk=False") if ( + not agg and (isinstance(k, str) or + isinstance(n, str))) else + None + )) + for n in [8, "n"] + for k in [4, "k"] + for agg in [True, False] + ], [ # arange PolyHarness("arange", name, f_jax, @@ -3071,6 +3095,12 @@ def test_vmap_error(self): lambda x: jnp.tri(x.shape[0], M=x.shape[0] + 
2) + x, arg_descriptors=[RandArg((3, 1), _f32)], polymorphic_shapes=["b, ..."]), + PolyHarness("tril", "", + lambda x: jnp.tril(jnp.ones((x.shape[0], x.shape[0] + x.shape[1]), + dtype=_f32), + k=x.shape[1]), + arg_descriptors=[RandArg((3, 4), _f32)], + polymorphic_shapes=["m, n"]), [ PolyHarness("triangular_solve", f"shape={jtu.format_shape_dtype_string(a_shape, dtype)}_{left_side=}_{a_poly=}_{b_poly=}", @@ -3259,7 +3289,7 @@ def test_harness(self, harness: PolyHarness): if harness.group_name == "vmap_eigh" and jtu.test_device_matches(["gpu"]): # For eigh on GPU with shape polymorphism under native serialization, - # we use a different lowering for small matrices. See README.md. + # we use a different lowering for small matrices. shape = harness.original_harness.params["shape"] if 0 < shape[-1] <= 32: harness.check_result = False @@ -3294,16 +3324,15 @@ def test_harness(self, harness: PolyHarness): # Update this here rather than in harness object because vmap_random_gamma is derived # from test_harnesses.all_harnesses, which strips override_jax_config_flags. if "random_gamma" in harness.group_name: - config_flags = {**config_flags, "jax_enable_key_reuse_checks": False} + config_flags = {**config_flags, "jax_debug_key_reuse": False} - prev_jax_config_flags = {fname: getattr(jax.config, fname) for fname in config_flags} - try: - for fname, fvalue in config_flags.items(): - jax.config.update(fname, fvalue) + # TPU precision is a little lower since we swap the order of matmul operands. + if "cholesky" in harness.group_name and jtu.test_device_matches(["tpu"]): + harness.tol = 5e-5 + + with jtu.global_config_context(**config_flags): harness.run_test(self) - finally: - for fname, _ in config_flags.items(): - jax.config.update(fname, prev_jax_config_flags[fname]) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 291ee5360864..11305e937a08 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -12,42 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +import contextlib import jax import jax.numpy as jnp import numpy as np from absl.testing import absltest -from jax._src import xla_bridge from jax._src import test_util as jtu from jax.sharding import NamedSharding, PartitionSpec as P from jax.experimental.shard_alike import shard_alike from jax.experimental.shard_map import shard_map -from jax._src.lib import xla_extension_version -from jax import config -config.parse_flags_with_absl() - -prev_xla_flags = None +jax.config.parse_flags_with_absl() +# Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. 
- xla_bridge.get_backend.cache_clear() + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() + _exit_stack.close() class ShardAlikeDownstreamTest(jtu.JaxTestCase): @@ -64,8 +49,6 @@ class ShardAlikeTest(jtu.JaxTestCase): def setUp(self): super().setUp() - if xla_extension_version < 227: - self.skipTest('Requires xla_extension_version >= 227') def test_basic(self): mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 8e47e71838d9..ca9d813e2571 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -14,14 +14,14 @@ from __future__ import annotations -from collections.abc import Sequence, Iterable, Iterator, Generator +from collections.abc import Callable, Generator, Iterable, Iterator, Sequence +import contextlib from functools import partial import itertools as it import math import operator as op -import os from types import SimpleNamespace -from typing import Any, NamedTuple, Callable, TypeVar +from typing import Any, NamedTuple, TypeVar import unittest from absl.testing import absltest @@ -36,7 +36,6 @@ from jax._src import config from jax._src import core from jax._src import test_util as jtu -from jax._src import xla_bridge from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals from jax._src.interpreters import mlir @@ -67,31 +66,17 @@ def create_inputs(a_sharding, b_sharding): jax.sharding.NamedSharding(mesh, b_sharding)) return mesh, m1, m2 -# Run all tests with 8 CPU devices. -prev_xla_flags = None # Run all tests with 8 CPU devices. -def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() +_exit_stack = contextlib.ExitStack() +def setUpModule(): + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) if len(jax.devices()) < 8: raise unittest.SkipTest("tests require 8 devices") -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() + _exit_stack.close() class ShardMapTest(jtu.JaxTestCase): @@ -1352,13 +1337,13 @@ def f(x): @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P()) def g(x): - return jax.jit(lambda x: x)(x) + return jax.jit(lambda x: 1. * x)(x) jaxpr = jax.make_jaxpr(jax.vjp(g, 1.)[1])(1.) 
e, = jaxpr.jaxpr.eqns e1, e2 = e.params['jaxpr'].eqns self.assertEmpty(e1.outvars) - self.assertEmpty(e2.params['jaxpr'].eqns) + self.assertLen(e2.params['jaxpr'].eqns, 1) def test_fanout_specs_transpose_to_psum(self): mesh = jtu.create_global_mesh((4,), ('x',)) @@ -1482,7 +1467,7 @@ def f(x): e1, _, e2 = jaxpr.eqns self.assertLen(e1.outvars, 1) # only primal output self.assertLen(e2.invars, 2) # res and cotangent inputs - self.assertEqual(sum([e1.outvars[0] is v for v in e2.invars]), 1) + self.assertEqual(sum(e1.outvars[0] is v for v in e2.invars), 1) @parameterized.parameters(it.product([True, False], repeat=2)) def test_res_forwarding_optimization_complex(self, jit, remat): @@ -1505,7 +1490,7 @@ def f(x): e1, _, e2 = jaxpr.eqns self.assertLen(e1.outvars, 2) # one primal and one res output self.assertLen(e2.invars, 4) # two res and two cotangent inputs - self.assertEqual(sum([e1.outvars[-1] is v for v in e2.invars]), 1) + self.assertEqual(sum(e1.outvars[-1] is v for v in e2.invars), 1) @parameterized.parameters([True, False]) def test_check_rep_failure_inside_rule(self, jit): @@ -1604,6 +1589,327 @@ def f(x): with jax.disable_jit(): f(x) # don't crash + @parameterized.parameters(it.product(range(4), repeat=3)) + @jtu.run_on_devices("cpu") + def test_forwarding_correctness(self, seed, num_input_fwd, num_output_fwd): + num_args = 3 + rng = np.random.RandomState(seed) + mesh = Mesh(np.array(jax.devices()[:1]), ('i',)) + + in_perm = rng.permutation(num_args) + out_perm = rng.permutation(num_args) + + @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) + def f(inputs): + inputs = [inputs[i] for i in in_perm] + outputs = inputs[:num_input_fwd] + [ + jnp.exp(inputs[i]) if i < num_output_fwd else jnp.sin(inputs[i]) + for i in range(num_args - num_input_fwd)] + return [outputs[i] for i in out_perm] + + jtu.check_grads(f, (list(jnp.arange(float(num_args))[:,None]),), order=1, + modes=['rev'], atol=1e-3, rtol=1e-3) + + def test_partial_auto(self): + mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + + def g(x): + x = jax.lax.with_sharding_constraint( + x, jax.sharding.NamedSharding(mesh, P(None, 'j'))) + return x * x + + @jax.jit + def f(x): + x = shard_map(g, mesh, + in_specs=P('i', None), + out_specs=P('i', None), + check_rep=False, + auto=frozenset({'j'}))(x) + return jax.lax.with_sharding_constraint( + x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + + v = jnp.arange(32.).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + self.assertIn( + 'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual, replicated}}', + f.lower(v).as_text('hlo'), + ) + self.assertAllClose(v*v, f(v), check_dtypes=False) + + def test_partial_auto_error_wsc_manual(self): + mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + + def g(x): + x = jax.lax.with_sharding_constraint( + x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + return x * x + + @jax.jit + def f(x): + x = shard_map(g, mesh, + in_specs=P('i', None), + out_specs=P('i', None), + check_rep=False, + auto=frozenset({'j'}))(x) + return jax.lax.with_sharding_constraint( + x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + + v = jnp.arange(32.).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + with self.assertRaisesRegex(ValueError, "manual"): + f(v) + + def test_partial_auto_error_invalid_auto(self): + mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + + def g(x): + x = jax.lax.with_sharding_constraint( + x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + return x * x + + 
@jax.jit + def f(x): + x = shard_map(g, mesh, + in_specs=P('i', None), + out_specs=P('i', None), + check_rep=False, + auto=frozenset({'k'}))(x) + return jax.lax.with_sharding_constraint( + x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + + v = jnp.arange(32.).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + with self.assertRaisesRegex(ValueError, "to be a subset of mesh.axis_names"): + f(v) + + def test_partial_auto_error_wrong_in_specs(self): + mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + + def g(x): + x = jax.lax.with_sharding_constraint( + x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + return x * x + + @jax.jit + def f(x): + x = shard_map(g, mesh, + in_specs=P('i', 'j'), + out_specs=P('i', None), + check_rep=False, + auto=frozenset({'j'}))(x) + return jax.lax.with_sharding_constraint( + x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + + v = jnp.arange(32.).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + with self.assertRaisesRegex(ValueError, "in_specs refers to 'j'"): + f(v) + + def test_nested_partial_auto(self): + mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + + def g(x): + return x * x + + def h(x): + return shard_map(g, mesh, + in_specs=P(None, 'j'), + out_specs=P(None, 'j'))(x) + + @jax.jit + def f(x): + return shard_map(h, mesh, + in_specs=P('i', None), + out_specs=P('i', None), + check_rep=False, + auto=frozenset({'j'}))(x) + + v = jnp.arange(32.).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + self.assertAllClose(v*v, f(v), check_dtypes=False) + + + def test_vmap_grad_shmap_spmd_axis_name_residuals(self): + # https://github.com/google/jax/pull/21032 + mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + + @partial( + shard_map, + mesh=mesh, + in_specs=P('j'), + out_specs=P('j'), + ) + def f(x): + return jnp.sin(x) + + xs = jnp.arange(4 * 16.).reshape(4, 16) + + jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs) # don't crash + + def test_vmap_grad_remat_shmap_spmd_axis_name_residuals(self): + # https://github.com/google/jax/pull/21056 + mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + + @partial(jax.remat, policy=lambda *_, **__: True) + @partial( + shard_map, + mesh=mesh, + in_specs=P('j'), + out_specs=P('j'), + ) + def f(x): + return jnp.sin(x) + + xs = jnp.arange(4 * 16.).reshape(4, 16) + + jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs) # don't crash + + def test_grad_shmap_residuals_axis_names_in_mesh_order(self): + # https://github.com/google/jax/issues/21236 + mesh = jtu.create_global_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a')) + + @partial( + shard_map, + mesh=mesh, + in_specs=P('j'), + out_specs=P('j'), + ) + def f(x): + return jnp.sin(x) + + xs = jnp.arange(16.) 
+ + ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs) + self.assertIn( + '{jax.result_info = "[(\'i\', \'j\', \'k\', \'a\')]"}', + ir.as_text() + ) + + def test_vmap_spmd_axis_name_error(self): + mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + + @partial( + shard_map, + mesh=mesh, + in_specs=P('i'), + out_specs=P('i'), + ) + def f(x): + return jnp.sin(x) + + xs = jnp.arange(4 * 16.).reshape(4, 16) + with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"): + jax.vmap(f, spmd_axis_name='i')(xs) + + @partial( + shard_map, + mesh=mesh, + in_specs=P('j'), + out_specs=P(('i', 'j')), + check_rep=False, + ) + def g(x): + return jnp.sin(x) + + xs = jnp.arange(4 * 16.).reshape(4, 16) + with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"): + jax.vmap(g, spmd_axis_name='i')(xs) + + def test_in_spec_none(self): + mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + + x = jnp.arange(8).reshape(4, 2) + + def f(o, x): + self.assertIs(o, obj) + return jnp.sin(x) + + obj = object() + y = shard_map(f, mesh, (None, P('i')), P('i'))(obj, x) + self.assertAllClose(y, jnp.sin(x), check_dtypes=False) + + obj = None + y = shard_map(f, mesh, (None, P('i')), P('i'))(None, x) + self.assertAllClose(y, jnp.sin(x), check_dtypes=False) + + def f2(o, x): + self.assertIsInstance(o, dict) + self.assertIs(o['a'], obj['a']) + return jnp.sin(x) + + obj = {'a': object()} + y = shard_map(f2, mesh, ({'a': None}, P('i')), P('i'))(obj, x) + self.assertAllClose(y, jnp.sin(x), check_dtypes=False) + + def f3(x, o): + self.assertIs(o, obj) + return jnp.sin(x) + + obj = object() + y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj) + self.assertAllClose(y, jnp.sin(x), check_dtypes=False) + + obj = None + y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj) + self.assertAllClose(y, jnp.sin(x), check_dtypes=False) + + def f4(o1, o2, x, o3): + self.assertIs(o1, obj1) + self.assertIs(o2[0], obj2[0]) + self.assertIs(o2[1], obj2[1]) + self.assertIs(o3, obj3) + return jnp.sin(x) + + obj1 = object() + obj2 = (object(), object()) + obj3 = object() + y = shard_map(f4, mesh, (None, None, P('i'), None), P('i'))(obj1, obj2, x, obj3) + self.assertAllClose(y, jnp.sin(x), check_dtypes=False) + + def test_in_spec_none_divisibility_errors(self): + mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + x = jnp.arange(4).reshape(2, 2) + + with self.assertRaisesRegex(ValueError, 'divisible'): + shard_map(lambda *_: None, mesh, (None, P('i')), None)(object(), x) + + with self.assertRaisesRegex(ValueError, 'divisible'): + shard_map(lambda *_: None, mesh, (P('i'), None), None)(x, object()) + + with self.assertRaisesRegex(ValueError, 'divisible'): + shard_map(lambda *_: None, mesh, (P('i'), None), None + )(x, (object(), object())) + + with self.assertRaisesRegex(ValueError, 'divisible'): + shard_map(lambda *_: None, mesh, (P('i'), (None, None)), None, + )(x, (object(), object())) + + with self.assertRaisesRegex(ValueError, 'divisible'): + shard_map(lambda *_: None, mesh, ((None, None), P('i')), None, + )((object(), object()), x) + + def test_in_spec_none_rank_errors(self): + mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + x = jnp.arange(4) + + with self.assertRaisesRegex(ValueError, 'rank'): + shard_map(lambda *_: None, mesh, (None, P('i', 'j')), None)(object(), x) + + with self.assertRaisesRegex(ValueError, 'rank'): + shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None)(x, object()) + + with self.assertRaisesRegex(ValueError, 'rank'): + shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None + )(x, 
(object(), object())) + + with self.assertRaisesRegex(ValueError, 'rank'): + shard_map(lambda *_: None, mesh, (P('i', 'j'), (None, None)), None, + )(x, (object(), object())) + + with self.assertRaisesRegex(ValueError, 'rank'): + shard_map(lambda *_: None, mesh, ((None, None), P('i', 'j')), None, + )((object(), object()), x) + class FunSpec(NamedTuple): name: str @@ -1970,8 +2276,6 @@ class CustomPartitionerTest(jtu.JaxTestCase): def skip_if_custom_partitioning_not_supported(self): if jtu.is_cloud_tpu(): raise unittest.SkipTest("Custom partitioning is not supported on libtpu.") - if xla_bridge.using_pjrt_c_api(): - raise unittest.SkipTest('custom partitioning not implemented in PJRT C API') def test_custom_partitioning(self): self.skip_if_custom_partitioning_not_supported() diff --git a/tests/source_info_test.py b/tests/source_info_test.py index aaa3abf552d8..0f876de1c20f 100644 --- a/tests/source_info_test.py +++ b/tests/source_info_test.py @@ -19,11 +19,10 @@ import jax from jax import lax -from jax import config from jax._src import source_info_util from jax._src import test_util as jtu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class SourceInfoTest(jtu.JaxTestCase): diff --git a/tests/sourcemap_test.py b/tests/sourcemap_test.py new file mode 100644 index 000000000000..bb741324bdd9 --- /dev/null +++ b/tests/sourcemap_test.py @@ -0,0 +1,89 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
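A condensed sketch of the round-trip pattern the new tests/sourcemap_test.py below relies on. This is illustrative only and not part of the patch; it assumes the jax._src.sourcemap helpers behave exactly as the assertions in the tests describe (Base64 VLQ round-trips, and a MappingsGenerator whose groups correspond to generated lines and whose segments follow the Source Map v3 field layout):

from jax._src import sourcemap

# A signed integer survives a VLQ encode/decode round trip.
assert sourcemap.decode_vlq(sourcemap.encode_vlq(123)) == 123

# Groups correspond to generated lines; segments to mapped spans within a line.
# Segment fields follow the Source Map v3 layout: generated column, source
# index, source line, source column (and optionally a name index).
gen = sourcemap.MappingsGenerator()
gen.new_group()
gen.new_segment(0)             # bare column with no source attribution
gen.new_segment(0, 0, 0, 1)    # column 0 mapped to source 0, line 0, column 1
mappings = sourcemap.serialize_mappings(gen.mappings())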
+ +import json + +from absl.testing import absltest +from absl.testing import parameterized +from jax._src import sourcemap +from jax._src import test_util as jtu + + +class SourceMapTest(jtu.JaxTestCase): + + @parameterized.parameters( + (0,), + (1,), + (2,), + (3,), + (4,), + (5,), + (-1,), + (-2,), + (-3,), + (-4,), + (123,), + (456,), + (1024,), + (1025,), + (2**16,), + (2**31 - 1,), + ) + def test_roundtrip_vlq(self, value): + actual = sourcemap.decode_vlq(sourcemap.encode_vlq(value)) + self.assertEqual(actual, value) + + @parameterized.parameters( + (b"A",), + (b"C",), + (b"AAAA",), + (b"ACDE",), + (b"AACAA",), + ) + def test_roundtrip_segment(self, enc): + actual = sourcemap.encode_segment(sourcemap.decode_segment(enc)) + self.assertEqual(actual, enc) + + def test_roundtrip_sourcemap_json(self): + data = { + "version": 3, + # "file": "out.js", + # "sourceRoot": "", + "sources": ["foo.js", "bar.js"], + "sourcesContent": [None, None], + "names": ["src", "maps", "are", "fun"], + "mappings": "A,AAAC;;AACDE", + } + json_data = json.dumps(data) + json_data_roundtripped = sourcemap.SourceMap.from_json(json_data).to_json() + self.assertEqual(json.loads(json_data_roundtripped), data) + + def test_generate_mappings(self): + expected = "A,AAAC;;AACDE" + gen = sourcemap.MappingsGenerator() + # A + gen.new_group() + gen.new_segment(0) + # ,AAAC + gen.new_segment(0, 0, 0, 1) + # ; + gen.new_group() + # ;AACDE + gen.new_group() + gen.new_segment(0, 0, 1, 0, 2) + self.assertEqual(sourcemap.serialize_mappings(gen.mappings()), expected) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index e9091a9b9c02..680bdda5675a 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -22,7 +22,6 @@ from absl.testing import absltest import jax -from jax import config from jax import jit from jax import lax from jax import vmap @@ -40,7 +39,7 @@ from jax.util import split_list import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() COMPATIBLE_SHAPE_PAIRS = [ [(), ()], @@ -151,7 +150,7 @@ def _is_required_cuda_version_satisfied(cuda_version): class BCOOTest(sptu.SparseTestCase): def gpu_matmul_warning_context(self, msg): - if sptu.GPU_LOWERING_ENABLED and config.jax_bcoo_cusparse_lowering: + if jax.config.jax_bcoo_cusparse_lowering: return self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, msg) return contextlib.nullcontext() @@ -479,9 +478,6 @@ def test_bcoo_dot_general( dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") - @unittest.skipIf( - not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" - ) @jtu.run_on_devices("gpu") def test_bcoo_dot_general_cusparse( self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting @@ -528,9 +524,6 @@ def f_sparse(lhs_bcoo, lhs, rhs): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") - @unittest.skipIf( - not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" - ) @jtu.run_on_devices("gpu") def test_bcoo_batched_matmat_cusparse( self, @@ -581,9 +574,6 @@ def f_sparse(lhs_bcoo, lhs, rhs): ], dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) - @unittest.skipIf( - not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" - ) @jtu.run_on_devices("gpu") def test_bcoo_batched_matmat_default_lowering( self, @@ -615,9 +605,6 @@ def test_bcoo_batched_matmat_default_lowering( 
matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs) self.assertArraysEqual(matmat_expected, matmat_default_lowering_fallback) - @unittest.skipIf( - not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" - ) @jtu.run_on_devices("gpu") def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self): """Tests bcoo dot general with out-of-bound and unsorted indices.""" @@ -1952,6 +1939,18 @@ def test_bcsr_concatenate(self, shape, dtype, n_batch, n_dense, dimension): if jnp.issubdtype(dtype, jnp.floating): self._CheckGradsSparse(dense_func, sparse_func, args_maker) + def test_bcoo_spdot_abstract_eval_bug(self): + # Regression test for https://github.com/google/jax/issues/21921 + lhs = sparse.BCOO( + (jnp.float32([[1]]), lax.broadcasted_iota(jnp.int32, (10, 1, 1), 0)), + shape=(10, 10)) + rhs = sparse.BCOO( + (jnp.float32([1]), jnp.int32([[3]])), + shape=(10,)) + args_maker = lambda: [lhs, rhs] + def func(lhs, rhs): + return (lhs @ rhs).todense() + self._CompileAndCheck(func, args_maker) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/sparse_nm_test.py b/tests/sparse_nm_test.py new file mode 100644 index 000000000000..9ecf30eb6229 --- /dev/null +++ b/tests/sparse_nm_test.py @@ -0,0 +1,209 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
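The new tests/sparse_nm_test.py below targets jax.experimental.sparse.nm, the 2:4 structured-sparsity matmul helpers. A minimal sketch of the pattern the tests use, condensed from their own setup and not part of the patch; it assumes a CUDA GPU with compute capability >= 8.0, as the tests themselves require:

import numpy as np
import jax.numpy as jnp
from jax.experimental.sparse import nm

m, k, n = 64, 128, 32
lhs = (np.arange(m * k) % 17).astype(jnp.bfloat16).reshape(m, k)
rhs = (np.arange(k * n) % 19).astype(jnp.bfloat16).reshape(k, n)

# Keep 2 of every 4 values along the contracting dimension, compress the kept
# values, and pack the boolean mask into metadata with nm.nm_pack.
mask = np.tile([True, False], m * k // 2).reshape(m, k)
sparse_lhs = lhs[mask].reshape(m, k // 2)
meta = nm.nm_pack(mask)

# The sparse matmul should agree with the masked dense product (lhs * mask) @ rhs.
out = nm.nm_spmm(sparse_lhs, rhs, meta)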
+ +import math + +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +import jax +import jax.numpy as jnp +from jax import dtypes +from jax._src import config +from jax._src import test_util as jtu +from jax.experimental.sparse import nm + +jax.config.parse_flags_with_absl() + + +class SpmmTest(jtu.JaxTestCase): + def setUp(self): + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Only works on GPU") + if (jtu.test_device_matches(["cuda"]) and + not jtu.is_cuda_compute_capability_at_least("8.0")): + self.skipTest("Only works on GPUs with capability >= sm80") + super().setUp() + + # ----- Test different input shapes + @parameterized.product( + tile_m=(32, 128), + tile_n=(32, 128), + tile_k=(32, 128), + batch=(None, 5), + sparse_idx=(0, 1), + ) + @jtu.run_on_devices("gpu") + def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx): + # Build keyword arguments + kwargs = { + "dimension_numbers": (((1,), (1,)), (tuple(), tuple())), + "sparse_operand_idx": sparse_idx, + } + if batch: + kwargs["dimension_numbers"] = (((2,), (2,)), ((0,), (0,))) + + # Build input data + batch_dims = (batch,) if batch else tuple() + lhs = ( + (np.arange((batch or 1) * tile_m * tile_k) % 11) + .astype(dtypes.bfloat16) + .reshape(batch_dims + (tile_m, tile_k)) + ) + rhs = ( + (np.arange((batch or 1) * tile_n * tile_k) % 13) + .astype(dtypes.bfloat16) + .reshape(batch_dims + (tile_n, tile_k)) + ) + + # Build sparsity mask and metadata + sp = [lhs, rhs][sparse_idx] + mask = np.tile([True, False], math.prod(sp.shape) // 2).reshape(sp.shape) + sparse = sp[mask].reshape(sp.shape[:-1] + (sp.shape[-1] // 2,)) + meta = nm.nm_pack(mask) + + # Calculate sparse and dense dots + if sparse_idx == 0: + dot_sparse = nm.nm_spmm(sparse, rhs, meta, **kwargs) + dot_dense = jnp.einsum("...mk,...nk->...mn", (lhs * mask), rhs) + else: + dot_sparse = nm.nm_spmm(lhs, sparse, meta, **kwargs) + dot_dense = jnp.einsum("...mk,...nk->...mn", lhs, (rhs * mask)) + + # Verify the result + jtu.check_eq(dot_sparse, dot_dense.astype(dtypes.bfloat16)) + + # ----- Test different input types + @parameterized.product( + lhs_type=[jnp.int8, jnp.int16, jnp.float16, jnp.bfloat16], + rhs_type=[jnp.bfloat16], + output_type=[jnp.bfloat16, jnp.float32], + ) + @jtu.run_on_devices("gpu") + def test_types(self, lhs_type, rhs_type, output_type): + tile_m, tile_n, tile_k = 64, 32, 128 + + # Build input data + lhs = ( + (np.arange(tile_m * tile_k) % 17) + .astype(lhs_type) + .reshape((tile_m, tile_k)) + ) + rhs = ( + (np.arange(tile_k * tile_n) % 19) + .astype(rhs_type) + .reshape((tile_k, tile_n)) + ) + + # Build sparsity mask and metadata + mask = np.tile([True, False], tile_m * tile_k // 2).reshape(lhs.shape) + sparse = lhs[mask].reshape(tile_m, tile_k // 2) + meta = nm.nm_pack(mask) + + # Calculate sparse and dense dots + dot_sparse = nm.nm_spmm(sparse, rhs, meta, output_dtype=output_type) + dot_dense = (lhs * mask) @ rhs + + # Verify the result + jtu.check_close(dot_sparse, dot_dense.astype(output_type), rtol=0.01) + + # ----- Test validation + @jtu.run_on_devices("gpu") + def test_validate_nm_pack(self): + with self.assertRaisesRegex(TypeError, "Mask should be bool"): + nm.nm_pack(jnp.zeros(16, jnp.int8)) + with self.assertRaisesRegex( + TypeError, "Inner dimension size should be divisible by 16" + ): + nm.nm_pack(jnp.array([False] * 8)) + + @jtu.run_on_devices("gpu") + def test_validate_nm_spmm(self): + batch, tile_m, tile_n, tile_k = 2, 64, 32, 128 + lhs = jnp.zeros((batch, tile_m, tile_k // 2), 
dtype=jnp.bfloat16) + rhs = jnp.zeros((batch, tile_k, tile_n), dtype=jnp.bfloat16) + meta = jnp.zeros((batch, tile_m, tile_k // 16), dtype=jnp.uint16) + + if config.enable_x64.value: + with self.assertRaisesRegex(TypeError, "Unsupported lhs input type"): + nm.nm_spmm(jnp.zeros(lhs.shape, dtype=jnp.int64), rhs, meta) + with self.assertRaisesRegex(TypeError, "Unsupported rhs input type"): + nm.nm_spmm(lhs, jnp.zeros(rhs.shape, dtype=jnp.int64), meta) + with self.assertRaisesRegex(TypeError, "Unsupported output type"): + nm.nm_spmm(lhs, rhs, meta, output_dtype=jnp.int64) + + # Check dimension numbers + nm_spmm_with_dnums = lambda c, b: nm.nm_spmm( + lhs, rhs, meta, dimension_numbers=(c, b) + ) + with self.assertRaisesRegex( + TypeError, "Only single contracting dimension is supported" + ): + nm_spmm_with_dnums(((0, 2), (0, 1)), (tuple(), tuple())) + with self.assertRaisesRegex( + TypeError, "Incorrect dimension numbers for lhs" + ): + nm_spmm_with_dnums(((2,), (1,)), ((2,), (0,))) + with self.assertRaisesRegex( + TypeError, "Incorrect dimension numbers for rhs" + ): + nm_spmm_with_dnums(((2,), (1,)), ((0,), (1,))) + with self.assertRaisesRegex( + TypeError, "Only single non-contracting dimension is supported" + ): + nm_spmm_with_dnums(((2,), (1,)), (tuple(), tuple())) + with self.assertRaisesRegex( + TypeError, "Batch dimension sizes do not match" + ): + nm.nm_spmm( + lhs, + rhs.reshape(1, tile_k, tile_n * batch), + meta, + dimension_numbers=(((2,), (1,)), ((0,), (0,))), + ) + + # Check metadata + nm_spmm_with_meta = lambda m: nm.nm_spmm( + lhs, rhs, m, dimension_numbers=(((2,), (1,)), ((0,), (0,))) + ) + with self.assertRaisesRegex(TypeError, "Metadata must be uint16"): + nm_spmm_with_meta(jnp.zeros(meta.shape, dtype=jnp.uint8)) + with self.assertRaisesRegex( + TypeError, "Metadata shape must match the operand shape" + ): + nm_spmm_with_meta(meta.reshape(1, batch * tile_m, tile_k // 16)) + with self.assertRaisesRegex( + TypeError, + "Metadata must be exactly 8 times less than the contracting dimension" + " for 2:4 structured sparsity", + ): + nm_spmm_with_meta(jnp.repeat(meta, 2, axis=-1)) + with self.assertRaisesRegex( + TypeError, "Contracting dimension must be the minor one" + ): + nm.nm_spmm(lhs, rhs, meta, dimension_numbers=(((1,), (1,)), ((0,), (0,)))) + with self.assertRaisesRegex( + TypeError, "Contracting dimension sizes should have 2:4 ratio" + ): + nm.nm_spmm( + lhs, + jnp.repeat(rhs, 2, axis=1), + meta, + dimension_numbers=(((2,), (1,)), ((0,), (0,))), + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 153c0d5d7a16..49438f411ff5 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -16,14 +16,12 @@ from functools import partial import itertools import math -import unittest from absl.testing import absltest from absl.testing import parameterized import jax import jax.random -from jax import config from jax import dtypes from jax.experimental import sparse from jax.experimental.sparse import coo as sparse_coo @@ -44,7 +42,7 @@ import numpy as np import scipy.sparse -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex @@ -351,9 +349,6 @@ def test_coo_sorted_indices(self): mat_resorted = mat_unsorted._sort_indices() self.assertArraysEqual(mat.todense(), mat_resorted.todense()) - @unittest.skipIf( - not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" - ) 
@jtu.run_on_devices("gpu") def test_coo_sorted_indices_gpu_lowerings(self): dtype = jnp.float32 @@ -544,9 +539,7 @@ def test_coo_matmul_ad(self, shape, dtype, bshape): dtype=_lowerings.SUPPORTED_DATA_DTYPES, transpose=[True, False], ) - @unittest.skipIf( - not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" - ) + @jtu.run_on_devices("gpu") def test_coo_spmv(self, shape, dtype, transpose): rng_sparse = sptu.rand_sparse(self.rng()) rng_dense = jtu.rand_default(self.rng()) @@ -569,9 +562,7 @@ def test_coo_spmv(self, shape, dtype, transpose): dtype=_lowerings.SUPPORTED_DATA_DTYPES, transpose=[True, False], ) - @unittest.skipIf( - not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" - ) + @jtu.run_on_devices("gpu") def test_coo_spmm(self, shape, dtype, transpose): rng_sparse = sptu.rand_sparse(self.rng()) rng_dense = jtu.rand_default(self.rng()) @@ -594,9 +585,7 @@ def test_coo_spmm(self, shape, dtype, transpose): dtype=_lowerings.SUPPORTED_DATA_DTYPES, transpose=[True, False], ) - @unittest.skipIf( - not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" - ) + @jtu.run_on_devices("gpu") def test_csr_spmv(self, shape, dtype, transpose): rng_sparse = sptu.rand_sparse(self.rng()) rng_dense = jtu.rand_default(self.rng()) @@ -617,9 +606,7 @@ def test_csr_spmv(self, shape, dtype, transpose): dtype=_lowerings.SUPPORTED_DATA_DTYPES, transpose=[True, False], ) - @unittest.skipIf( - not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" - ) + @jtu.run_on_devices("gpu") def test_csr_spmm(self, shape, dtype, transpose): rng_sparse = sptu.rand_sparse(self.rng()) rng_dense = jtu.rand_default(self.rng()) @@ -1083,8 +1070,6 @@ class SparseSolverTest(sptu.SparseTestCase): ) @jtu.run_on_devices("cpu", "cuda") def test_sparse_qr_linear_solver(self, size, reorder, dtype): - if jtu.test_device_matches(["cuda"]) and not sptu.GPU_LOWERING_ENABLED: - raise unittest.SkipTest('test requires cusparse/cusolver') rng = sptu.rand_sparse(self.rng()) a = rng((size, size), dtype) nse = (a != 0).sum() @@ -1110,8 +1095,6 @@ def sparse_solve(data, indices, indptr, b): ) @jtu.run_on_devices("cpu", "cuda") def test_sparse_qr_linear_solver_grads(self, size, dtype): - if jtu.test_device_matches(["cuda"]) and not sptu.GPU_LOWERING_ENABLED: - raise unittest.SkipTest('test requires cusparse/cusolver') rng = sptu.rand_sparse(self.rng()) a = rng((size, size), dtype) nse = (a != 0).sum() diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py index 998ce1c4067d..46086511d8b5 100644 --- a/tests/sparsify_test.py +++ b/tests/sparsify_test.py @@ -22,7 +22,7 @@ import numpy as np import jax -from jax import config, jit, lax +from jax import jit, lax import jax.numpy as jnp import jax._src.test_util as jtu from jax.experimental.sparse import BCOO, BCSR, sparsify, todense, SparseTracer @@ -31,7 +31,7 @@ from jax.experimental.sparse.util import CuSparseEfficiencyWarning from jax.experimental.sparse import test_util as sptu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default): def _rand_sparse(shape, dtype, nse=nse): diff --git a/tests/stack_test.py b/tests/stack_test.py index acefc0630018..655a42571b01 100644 --- a/tests/stack_test.py +++ b/tests/stack_test.py @@ -17,13 +17,13 @@ from absl.testing import absltest +import jax import jax.numpy as jnp from jax._src.lax.stack import Stack from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() 
+jax.config.parse_flags_with_absl() class StackTest(jtu.JaxTestCase): diff --git a/tests/state_test.py b/tests/state_test.py index 1f109536fc16..b6dbb490b794 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -14,10 +14,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial import itertools as it -from typing import Any, Callable, NamedTuple, Union +from typing import Any, NamedTuple, Union from absl.testing import absltest from absl.testing import parameterized @@ -29,7 +29,6 @@ from jax._src import config from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe -from jax._src import prng from jax._src import test_util as jtu from jax._src.util import tuple_insert import jax.numpy as jnp @@ -49,7 +48,7 @@ ref_addupdate, ref_get, ref_set, ref_swap) from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect, - AccumEffect, RefEffect, AbstractRef) + AccumEffect, AbstractRef) config.parse_flags_with_absl() @@ -1496,55 +1495,6 @@ def _body(ref): jtu.check_grads(f, (0.5,), order=3) -class MutableArrayTest(jtu.JaxTestCase): - - @parameterized.parameters([True, False]) - def test_basic(self, jit): - def f(x_mut): - x_mut[...] += 1. - x_mut[0] += 1 - x_mut[1] += 5 - - if jit: - f = jax.jit(f) - - x_mut = core.mutable_array(jnp.zeros(3)) - f(x_mut) - - self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]), - check_dtypes=False) - - jaxpr = jax.make_jaxpr(f)(x_mut) - self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects)) - - def test_staging_error(self): - x = jnp.zeros(3) - with self.assertRaises(Exception): - jax.jit(core.mutable_array)(x) - - @parameterized.parameters([True, False]) - def test_multiple_inputs_and_outputs(self, jit): - def f(x_mut, y, z_mut, w): - x_mut[...] += 1 - z_mut[...] += 1 - return x_mut[...] + y + z_mut[...] 
+ w, y + w - - if jit: - f = jax.jit(f) - - x_mut = core.mutable_array(jnp.zeros((1, 3))) - y = jnp.ones((2, 3)) - z_mut = core.mutable_array(jnp.zeros((2, 3))) - w = jnp.ones((2, 1)) - - out1, out2 = f(x_mut, y, z_mut, w) - - self.assertAllClose(x_mut[...], jnp.ones((1, 3)), check_dtypes=False) - self.assertAllClose(z_mut[...], jnp.ones((2, 3)), check_dtypes=False) - self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False) - self.assertAllClose(out2, y + w, check_dtypes=False) - - if CAN_USE_HYPOTHESIS: class FuncSpec(NamedTuple): @@ -1735,8 +1685,8 @@ def ref(x): y, impl_vjp = jax.vjp(impl, x) y_ref, ref_vjp = jax.vjp(ref, x) self.assertAllClose(y, y_ref) - t = random.normal(prng.reuse_key(k2), x.shape) - y2 = random.normal(prng.reuse_key(k1), y.shape) + t = random.normal(jax.random.clone(k2), x.shape) + y2 = random.normal(jax.random.clone(k1), y.shape) self.assertAllClose(impl_vjp(t), ref_vjp(t)) # Second order @@ -1752,7 +1702,7 @@ def ref(x): (x,), impl_vjp2 = jax.vjp(impl_vjp, t2) (x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2) self.assertAllClose(x, x_ref) - y2 = random.normal(prng.reuse_key(k1), y.shape) + y2 = random.normal(jax.random.clone(k1), y.shape) self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,))) if __name__ == '__main__': diff --git a/tests/stax_test.py b/tests/stax_test.py index 351a0fdb3d71..6850f36a02ea 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -18,13 +18,13 @@ import numpy as np +import jax from jax._src import test_util as jtu from jax import random from jax.example_libraries import stax from jax import dtypes -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def random_inputs(rng, input_shape): diff --git a/tests/svd_test.py b/tests/svd_test.py index 7284ff60d3a8..52833b4349ab 100644 --- a/tests/svd_test.py +++ b/tests/svd_test.py @@ -149,6 +149,26 @@ def testSvdWithOnRankDeficientInput(self, m, r, log_cond): np.testing.assert_almost_equal(diff, 1E-4, decimal=2) + @jtu.sample_product( + [dict(m=m, r=r) for m, r in zip([8, 8, 8, 10], [3, 5, 7, 9])], + ) + def testSvdWithOnRankDeficientInputZeroColumns(self, m, r): + """Tests SVD with rank-deficient input.""" + with jax.default_matmul_precision('float32'): + np.random.seed(1235) + a = np.random.randn(m, m).astype(_SVD_TEST_DTYPE) + d = np.ones(m).astype(_SVD_TEST_DTYPE) + d[r:m] = 0 + a = a @ np.diag(d) + + with jax.default_matmul_precision('float32'): + u, s, v = svd.svd(a, full_matrices=True, hermitian=False) + diff = np.linalg.norm(a - (u * s) @ v) + np.testing.assert_almost_equal(diff, 1e-4, decimal=2) + # Check that u and v are orthogonal. 
+ self.assertAllClose(u.T.conj() @ u, np.eye(m), atol=10 * _SVD_TEST_EPS) + self.assertAllClose(v.T.conj() @ v, np.eye(m), atol=10 * _SVD_TEST_EPS) + @jtu.sample_product( [dict(m=m, n=n) for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])], log_cond=np.linspace(1, _MAX_LOG_CONDITION_NUM, 4), diff --git a/tests/third_party/scipy/line_search_test.py b/tests/third_party/scipy/line_search_test.py index 5e7d9a943352..5a22372b732a 100644 --- a/tests/third_party/scipy/line_search_test.py +++ b/tests/third_party/scipy/line_search_test.py @@ -3,13 +3,12 @@ import jax from jax import grad -from jax import config import jax.numpy as jnp import jax._src.test_util as jtu from jax._src.scipy.optimize.line_search import line_search -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class TestLineSearch(jtu.JaxTestCase): @@ -87,6 +86,7 @@ def bind_index(func, idx): @jtu.sample_product( name=['_line_func_1', '_line_func_2'], ) + @jax.default_matmul_precision("float32") def test_line_search_wolfe2(self, name): def bind_index(func, idx): # Remember Python's closure semantics! diff --git a/tests/transfer_guard_test.py b/tests/transfer_guard_test.py index fa08c52b6aff..6a255b0a1b09 100644 --- a/tests/transfer_guard_test.py +++ b/tests/transfer_guard_test.py @@ -25,9 +25,7 @@ import jax._src.test_util as jtu import jax.numpy as jnp -from jax import config - -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _host_to_device_funcs(): @@ -101,12 +99,16 @@ def _all_funcs(): ] +# TransferGuardTest disables `--jax_enable_checks` because it +# can prematurely fetch the value of device arrays and make +# device-to-host tests to incur no transfers unexpectedly. +@jtu.with_config(jax_enable_checks=False) class TransferGuardTest(jtu.JaxTestCase): - # `_default_config` is used by `jtu.JaxTestCase` to update the JAX config for - # every test case. TransferGuardTest disables `--jax_enable_checks` because it - # can prematurely fetch the value of device arrays and make device-to-host - # tests to incur no transfers unexpectedly. - _default_config = {"jax_enable_checks": False} + def setUp(self): + super().setUp() + # Nearly all test methods use the deprecated device argument to JIT. 
+ self.enter_context(jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument")) @contextlib.contextmanager def assertAllows(self, func_name): diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index f6204ff3ba03..23ddf73904b5 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -17,15 +17,15 @@ import functools import pickle import re +from typing import TypeVar from absl.testing import absltest from absl.testing import parameterized - import jax -from jax import tree_util from jax import flatten_util +from jax import tree_util from jax._src import test_util as jtu -from jax._src.tree_util import prefix_errors, flatten_one_level +from jax._src.tree_util import flatten_one_level, prefix_errors import jax.numpy as jnp @@ -42,6 +42,19 @@ class ANamedTupleSubclass(ATuple): tree_util.register_pytree_node(ATuple2, lambda o: ((o.foo,), o.bar), lambda bar, foo: ATuple2(foo[0], bar)) +BadFlattenNonTuple = collections.namedtuple("ATuple2", ("foo", "bar")) +tree_util.register_pytree_node(BadFlattenNonTuple, lambda o: "hello", + lambda bar, foo: ATuple2(foo[0], bar)) + +BadFlattenBadArityTuple = collections.namedtuple("ATuple2", ("foo", "bar")) +tree_util.register_pytree_node(BadFlattenBadArityTuple, lambda o: (2, 3, 4), + lambda bar, foo: ATuple2(foo[0], bar)) + +BadFlattenNonIterableLeaves = collections.namedtuple("ATuple2", ("foo", "bar")) +tree_util.register_pytree_node(BadFlattenNonIterableLeaves, lambda o: (7, 7), + lambda bar, foo: ATuple2(foo[0], bar)) + + class AnObject: def __init__(self, x, y, z): @@ -129,6 +142,27 @@ def tree_unflatten(cls, meta, data): data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data)) return FlatCache(None, leaves=data, treedef=meta) +_T = TypeVar("_T") + + +# Inspired by Flax. 
+def pytree_node_dataclass(clz: _T, **kwargs) -> _T: + data_clz = dataclasses.dataclass(**kwargs)(clz) # type: ignore + meta_fields = [] + data_fields = [] + for field_info in dataclasses.fields(data_clz): + is_pytree_node = field_info.metadata.get("pytree_node", True) + if is_pytree_node: + data_fields.append(field_info.name) + else: + meta_fields.append(field_info.name) + + jax.tree_util.register_dataclass( + data_clz, data_fields, meta_fields + ) + + return data_clz + @tree_util.register_static class StaticInt(int): @@ -170,19 +204,11 @@ def __eq__(self, other): ([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],), ([AnObject(3, None, [4, "foo"])],), ([AnObject2(3, None, [4, "foo"])],), - (Special(2, 3.),), + (Special(2, 3.0),), ({"a": 1, "b": 2},), (StaticInt(1),), (StaticTuple((2, 3)),), (StaticDict(foo=4, bar=5),), - (collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),), - (collections.defaultdict(dict, - [("foo", 34), ("baz", 101), ("something", -42)]),), - (ANamedTupleSubclass(foo="hello", bar=3.5),), - (FlatCache(None),), - (FlatCache(1),), - (FlatCache({"a": [1, 2]}),), - (BlackBox(value=2),), ) @@ -205,6 +231,42 @@ def __eq__(self, other): "PyTreeDef(CustomNode(StaticDict[{'foo': 4, 'bar': 5}], []))", ) +@pytree_node_dataclass +class ADataclass: + x: tuple[int, int] + y: int + +@pytree_node_dataclass +class ADataclassWithMeta: + x: tuple[int, int] + y: int + z: int = dataclasses.field(metadata={"pytree_node": False}) + +TREES += ( + (ADataclass(x=(1, 2), y=3),), + (ADataclassWithMeta(x=(1, 2), y=3, z=4),), +) +TREE_STRINGS += ( + "PyTreeDef(CustomNode(ADataclass[()], [(*, *), *]))", + "PyTreeDef(CustomNode(ADataclassWithMeta[(4,)], [(*, *), *]))", +) + + +TREES += ( + (collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),), + ( + collections.defaultdict( + dict, [("foo", 34), ("baz", 101), ("something", -42)] + ), + ), + (ANamedTupleSubclass(foo="hello", bar=3.5),), + (FlatCache(None),), + (FlatCache(1),), + (FlatCache({"a": [1, 2]}),), + (BlackBox(value=2),), +) + + # pytest expects "tree_util_test.ATuple" STRS = [] for tree_str in TREE_STRINGS: @@ -567,6 +629,7 @@ def testTransposeWithCustomObject(self): def testStringRepresentation(self, tree, correct_string): """Checks that the string representation of a tree works.""" treedef = tree_util.tree_structure(tree) + print(TREES) self.assertRegex(str(treedef), correct_string) def testTreeDefWithEmptyDictStringRepresentation(self): @@ -762,6 +825,34 @@ def testNamedTupleRegisteredWithoutKeysIsntTreatedAsLeaf(self): leaves, _ = tree_util.tree_flatten_with_path(ATuple2(1, 'hi')) self.assertLen(leaves, 1) + def testBadFlattenNonTuple(self): + t = BadFlattenNonTuple(3, 4) + with self.assertRaisesRegex( + ValueError, + "The to_iterable function for a custom PyTree node should return a" + r" \(children, aux_data\) tuple, got 'hello'", + ): + tree_util.tree_flatten(t) + + def testBadFlattenBadArityTuple(self): + t = BadFlattenBadArityTuple(3, 4) + with self.assertRaisesRegex( + ValueError, + "The to_iterable function for a custom PyTree node should return a" + r" \(children, aux_data\) tuple, got \(2, 3, 4\)", + ): + tree_util.tree_flatten(t) + + def testBadFlattenNonIterableLeaves(self): + t = BadFlattenNonIterableLeaves(3, 4) + with self.assertRaisesRegex( + ValueError, + "The to_iterable function for a custom PyTree node should return a" + r" \(children, aux_data\) tuple where 'children' is iterable, got " + r"\(7, 7\)", + ): + tree_util.tree_flatten(t) + class 
StaticTest(parameterized.TestCase): @@ -822,6 +913,15 @@ def fn(x: int, static_y: BlackBox): self.assertEqual(fn(3, BlackBox(1)), 5) self.assertEqual(num_called, 1) + def test_serialize_treedef(self): + tree_structure = jax.tree_util.tree_structure([1, [2], (3,), {'a': 4, 'b': 5}]) + serialized = tree_structure.serialize_using_proto() + new_structure = jax.tree_util.PyTreeDef.deserialize_using_proto( + jax.tree_util.default_registry, + serialized + ) + self.assertEqual(tree_structure, new_structure) + class RavelUtilTest(jtu.JaxTestCase): @@ -1044,20 +1144,117 @@ def test_different_structure_no_children(self): class TreeAliasTest(jtu.JaxTestCase): - @parameterized.parameters( - ('all', 'tree_all'), - ('flatten', 'tree_flatten'), - ('leaves', 'tree_leaves'), - ('map', 'tree_map'), - ('reduce', 'tree_reduce'), - ('structure', 'tree_structure'), - ('transpose', 'tree_transpose'), - ('unflatten', 'tree_unflatten'), - ) - def test_tree_aliases(self, tree_name, tree_util_name): - wrapper = getattr(jax.tree, tree_name) - original = getattr(jax.tree_util, tree_util_name) - self.assertIs(wrapper.__wrapped__, original) + """Simple smoke-tests for tree_util aliases under jax.tree""" + + def test_tree_all(self): + obj = [True, True, (True, False)] + self.assertEqual( + jax.tree.all(obj), + tree_util.tree_all(obj), + ) + + def test_tree_all_is_leaf(self): + obj = [True, True, (True, False)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.all(obj, is_leaf=is_leaf), + tree_util.tree_all(obj, is_leaf=is_leaf), + ) + + def test_tree_flatten(self): + obj = [1, 2, (3, 4)] + self.assertEqual( + jax.tree.flatten(obj), + tree_util.tree_flatten(obj), + ) + + def test_tree_flatten_is_leaf(self): + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.flatten(obj, is_leaf=is_leaf), + tree_util.tree_flatten(obj, is_leaf=is_leaf), + ) + + def test_tree_leaves(self): + obj = [1, 2, (3, 4)] + self.assertEqual( + jax.tree.leaves(obj), + tree_util.tree_leaves(obj), + ) + + def test_tree_leaves_is_leaf(self): + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.leaves(obj, is_leaf=is_leaf), + tree_util.tree_leaves(obj, is_leaf=is_leaf), + ) + + def test_tree_map(self): + func = lambda x: x * 2 + obj = [1, 2, (3, 4)] + self.assertEqual( + jax.tree.map(func, obj), + tree_util.tree_map(func, obj), + ) + + def test_tree_map_is_leaf(self): + func = lambda x: x * 2 + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.map(func, obj, is_leaf=is_leaf), + tree_util.tree_map(func, obj, is_leaf=is_leaf), + ) + + def test_tree_reduce(self): + func = lambda a, b: a + b + obj = [1, 2, (3, 4)] + self.assertEqual( + jax.tree.reduce(func, obj), + tree_util.tree_reduce(func, obj), + ) + + def test_tree_reduce_is_leaf(self): + func = lambda a, b: a + b + obj = [(1, 2), (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.reduce(func, obj, is_leaf=is_leaf), + tree_util.tree_reduce(func, obj, is_leaf=is_leaf), + ) + + def test_tree_structure(self): + obj = [1, 2, (3, 4)] + self.assertEqual( + jax.tree.structure(obj), + tree_util.tree_structure(obj), + ) + + def test_tree_structure_is_leaf(self): + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.structure(obj, is_leaf=is_leaf), + tree_util.tree_structure(obj, is_leaf=is_leaf), + ) + + def test_tree_transpose(self): + obj = [(1, 2), (3, 4), (5, 6)] + outer_treedef = 
tree_util.tree_structure(['*', '*', '*']) + inner_treedef = tree_util.tree_structure(('*', '*')) + self.assertEqual( + jax.tree.transpose(outer_treedef, inner_treedef, obj), + tree_util.tree_transpose(outer_treedef, inner_treedef, obj) + ) + + def test_tree_unflatten(self): + leaves, treedef = jax.tree.flatten([1, 2, (3, 4)]) + self.assertEqual( + jax.tree.unflatten(treedef, leaves), + tree_util.tree_unflatten(treedef, leaves) + ) if __name__ == "__main__": diff --git a/tests/util_test.py b/tests/util_test.py index e06df8b3fa70..5f07d2f50880 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -16,13 +16,13 @@ from absl.testing import absltest +import jax from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import util -from jax import config from jax._src.util import weakref_lru_cache -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() try: from jax._src.lib import utils as jaxlib_utils diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py index 13274a132fa6..58cf4a2baae3 100644 --- a/tests/x64_context_test.py +++ b/tests/x64_context_test.py @@ -24,12 +24,11 @@ import jax from jax import lax from jax import random -from jax import config from jax.experimental import enable_x64, disable_x64 import jax.numpy as jnp import jax._src.test_util as jtu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class X64ContextTests(jtu.JaxTestCase): @@ -49,12 +48,12 @@ def test_make_array(self, jit): ) def test_correctly_capture_default(self, jit, enable_or_disable): # The fact we defined a jitted function with a block with a different value - # of `config.enable_x64` has no impact on the output. + # of `jax.config.enable_x64` has no impact on the output. with enable_or_disable(): func = jit(lambda: jnp.array(np.float64(0))) func() - expected_dtype = "float64" if config._read("jax_enable_x64") else "float32" + expected_dtype = "float64" if jax.config._read("jax_enable_x64") else "float32" self.assertEqual(func().dtype, expected_dtype) with enable_x64(): @@ -112,7 +111,7 @@ def func_x64(): self.assertEqual(x32.result(), jnp.int32) @jax.legacy_prng_key('allow') - @jax.enable_key_reuse_checks(False) + @jax.debug_key_reuse(False) @jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype float64 is not available") def test_jit_cache(self): diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 9ea345e26d30..7e778cc99d2c 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -20,13 +20,13 @@ from absl import logging from absl.testing import absltest +from jax import version from jax._src import compiler from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.interpreters import xla from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -56,9 +56,7 @@ def test_set_fdo_profile(self): num_replicas=1, num_partitions=1, fdo_profile=b"test_profile" ) self.assertEqual( - compile_options.executable_build_options.fdo_profile, - b"test_profile" if xla_extension_version >= 242 else "test_profile" - ) + compile_options.executable_build_options.fdo_profile, b"test_profile") def test_autofdo_profile(self): @@ -146,7 +144,7 @@ def test_timer_tpu_warning(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - def _mock_tpu_client(library_path=None): + def _mock_tpu_client_with_options(library_path=None, options=None): 
time_to_wait = 5 start = time.time() while not w: @@ -160,8 +158,11 @@ def _mock_tpu_client(library_path=None): msg = str(w[-1].message) self.assertIn("Did you run your code on all TPU hosts?", msg) + def _mock_tpu_client(library_path=None): + _mock_tpu_client_with_options(library_path=library_path, options=None) + with mock.patch.object(xc, "make_tpu_client", - side_effect=_mock_tpu_client): + side_effect=_mock_tpu_client_with_options): xb.tpu_client_timer_callback(0.01) def test_register_plugin(self): @@ -196,7 +197,12 @@ def test_register_plugin(self): self.assertIn("name2", xb._backend_factories) self.assertEqual(registration.priority, 400) self.assertTrue(registration.experimental) - mock_make.assert_called_once_with("name1", {}, None) + + options = {} + if xb.get_backend().platform == 'tpu': + options["ml_framework_name"] = "JAX" + options["ml_framework_version"] = version.__version__ + mock_make.assert_called_once_with("name1", options, None) def test_register_plugin_with_config(self): test_json_file_path = os.path.join( @@ -223,16 +229,19 @@ def test_register_plugin_with_config(self): self.assertIn("name1", xb._backend_factories) self.assertEqual(registration.priority, 400) self.assertTrue(registration.experimental) - mock_make.assert_called_once_with( - "name1", - { - "int_option": 64, - "int_list_option": [32, 64], - "string_option": "string", - "float_option": 1.0, - }, - None, - ) + + # The expectation is specified in example_pjrt_plugin_config.json. + options = { + "int_option": 64, + "int_list_option": [32, 64], + "string_option": "string", + "float_option": 1.0, + } + if xb.get_backend().platform == 'tpu': + options["ml_framework_name"] = "JAX" + options["ml_framework_version"] = version.__version__ + + mock_make.assert_called_once_with("name1", options, None) class GetBackendTest(jtu.JaxTestCase): diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 2f4f91b6174e..428c7fc66801 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -15,11 +15,11 @@ from __future__ import annotations from collections.abc import Generator, Iterator +import contextlib import functools import itertools as it from itertools import product, permutations import math -import os import re from unittest import SkipTest from typing import Union @@ -37,13 +37,13 @@ from jax import lax from jax.ad_checkpoint import checkpoint from jax.errors import JAXTypeError -from jax.experimental.maps import xmap, serial_loop, SerialLoop from jax.experimental.pjit import pjit from jax.interpreters import batching from jax.sharding import PartitionSpec as P from jax._src import array from jax._src import core from jax._src import maps +from jax._src.maps import xmap, serial_loop, SerialLoop from jax._src import xla_bridge from jax._src.core import NamedShape from jax._src.lax import parallel as lax_parallel @@ -53,30 +53,16 @@ from jax._src.sharding_impls import NamedSharding from jax._src.util import unzip2 -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() - -# TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py # Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() + def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. 
- if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - -# Reset to previous configuration in case other test modules will be run. + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() + _exit_stack.close() def create_array(global_shape, global_mesh, mesh_axes, global_data=None): @@ -248,10 +234,10 @@ class SPMDTestMixin: def setUp(self): super().setUp() self.spmd_lowering = maps.SPMD_LOWERING.value - config.update('experimental_xmap_spmd_lowering', True) + jax.config.update('experimental_xmap_spmd_lowering', True) def tearDown(self): - config.update('experimental_xmap_spmd_lowering', self.spmd_lowering) + jax.config.update('experimental_xmap_spmd_lowering', self.spmd_lowering) class ManualSPMDTestMixin: @@ -261,12 +247,12 @@ def setUp(self): super().setUp() self.spmd_lowering = maps.SPMD_LOWERING.value self.spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value - config.update('experimental_xmap_spmd_lowering', True) - config.update('experimental_xmap_spmd_lowering_manual', True) + jax.config.update('experimental_xmap_spmd_lowering', True) + jax.config.update('experimental_xmap_spmd_lowering_manual', True) def tearDown(self): - config.update('experimental_xmap_spmd_lowering', self.spmd_lowering) - config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering) + jax.config.update('experimental_xmap_spmd_lowering', self.spmd_lowering) + jax.config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering) @jtu.pytest_mark_if_available('multiaccelerator') @@ -845,13 +831,13 @@ def testFixedSharding(self): # TODO(apaszke): Add support for extracting XLA computations generated by # xmap and make this less of a smoke test. try: - config.update("experimental_xmap_ensure_fixed_sharding", True) + jax.config.update("experimental_xmap_ensure_fixed_sharding", True) f = xmap(lambda x: jnp.sin(2 * jnp.sum(jnp.cos(x) + 4, 'i')), in_axes=['i'], out_axes={}, axis_resources={'i': 'x'}) x = jnp.arange(20, dtype=jnp.float32) f(x) finally: - config.update("experimental_xmap_ensure_fixed_sharding", False) + jax.config.update("experimental_xmap_ensure_fixed_sharding", False) @jtu.with_mesh([('x', 2)]) def testConstantsInLowering(self): diff --git a/third_party/nanobind/BUILD.bazel b/third_party/nanobind/BUILD.bazel deleted file mode 100644 index fa975bc2d002..000000000000 --- a/third_party/nanobind/BUILD.bazel +++ /dev/null @@ -1,22 +0,0 @@ -licenses(["notice"]) - -package(default_visibility = ["//visibility:public"]) - -cc_library( - name = "nanobind", - srcs = glob([ - "src/*.cpp", - ]), - copts = ["-fexceptions"], - includes = ["include"], - textual_hdrs = glob( - [ - "include/**/*.h", - "src/*.h", - ], - ), - deps = [ - "@robin_map", - "@xla//third_party/python_runtime:headers", - ], -) diff --git a/third_party/nanobind/workspace.bzl b/third_party/nanobind/workspace.bzl deleted file mode 100644 index bb6e298157d5..000000000000 --- a/third_party/nanobind/workspace.bzl +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Loads the nanobind library.""" - -load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") - -def repo(): - tf_http_archive( - name = "nanobind", - strip_prefix = "nanobind-1.9.2", - sha256 = "149a3da40b0a988513d8cf5e71db3037373823505a3c92f87b988c92d7e0ab34", - urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v1.9.2.tar.gz"), - build_file = "//third_party/nanobind:BUILD.bazel", - ) diff --git a/third_party/robin_map/BUILD.bazel b/third_party/robin_map/BUILD.bazel deleted file mode 100644 index b649dda31766..000000000000 --- a/third_party/robin_map/BUILD.bazel +++ /dev/null @@ -1,17 +0,0 @@ -licenses(["notice"]) - -package(default_visibility = ["//visibility:public"]) - -cc_library( - name = "robin_map", - hdrs = [ - "include/tsl/robin_growth_policy.h", - "include/tsl/robin_hash.h", - "include/tsl/robin_map.h", - "include/tsl/robin_set.h", - ], - copts = ["-fexceptions"], - features = ["-use_header_modules"], # Incompatible with -fexceptions. - includes = ["."], - strip_include_prefix = "include", -) diff --git a/third_party/robin_map/workspace.bzl b/third_party/robin_map/workspace.bzl deleted file mode 100644 index 3b16856b0014..000000000000 --- a/third_party/robin_map/workspace.bzl +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Loads the robin_map library.""" - -load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") - -def repo(): - tf_http_archive( - name = "robin_map", - strip_prefix = "robin-map-1.2.1", - sha256 = "2b54d2c1de2f73bea5c51d5dcbd64813a08caf1bfddcfdeee40ab74e9599e8e3", - urls = tf_mirror_urls("https://github.com/Tessil/robin-map/archive/refs/tags/v1.2.1.tar.gz"), - build_file = "//third_party/robin_map:BUILD.bazel", - ) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a6ac82fe100f..ae4f70cd6666 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# buildifier: disable=module-docstring load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # To update XLA to a new revision, @@ -20,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. 
-XLA_COMMIT = "421b738e400a15c053b02924712a0e915b73cf7b" -XLA_SHA256 = "86beb00e75e235a3c7c481840304a54bb8fac233b5e9f8cdcd2947a2b924cdc0" +XLA_COMMIT = "6fd83ac4474a58b3afd5ec832612e59a441153f4" +XLA_SHA256 = "af39b1786ce6e5944f123e30d3ef0932423ab421a5c04ac58407ab11a2002823" def repo(): tf_http_archive(