Skip to content

Commit

Permalink
Enable runtime uptime telemetry for JAX on Cloud TPU, only if the fla…
Browse files Browse the repository at this point in the history
…g is not set by the user explicitly. Otherwise, prefer the user preference.

PiperOrigin-RevId: 648812558
  • Loading branch information
Google-ML-Automation authored and jax authors committed Jul 2, 2024
1 parent 242c993 commit d9bd358
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/cloud_tpu_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def cloud_tpu_init() -> None:
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
os.environ['TPU_ML_PLATFORM'] = 'JAX'
os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__
os.environ['ENABLE_RUNTIME_UPTIME_TELEMETRY'] = '1'
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
if hardware_utils.tpu_enhanced_barrier_supported():
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"

Expand Down

0 comments on commit d9bd358

Please sign in to comment.