diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 73e61de68008..a61f44847523 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -80,7 +80,7 @@ def cloud_tpu_init() -> None: os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu') os.environ['TPU_ML_PLATFORM'] = 'JAX' os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__ - os.environ['ENABLE_RUNTIME_UPTIME_TELEMETRY'] = '1' + os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') if hardware_utils.tpu_enhanced_barrier_supported(): os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"