From d9bd3587c52f0744057a42e5c724339320a82ed9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 2 Jul 2024 12:48:00 -0700 Subject: [PATCH] Enable runtime uptime telemetry for JAX on Cloud TPU, only if the flag is not set by the user explicitly. Otherwise, prefer the user preference. PiperOrigin-RevId: 648812558 --- jax/_src/cloud_tpu_init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 73e61de68008..a61f44847523 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -80,7 +80,7 @@ def cloud_tpu_init() -> None: os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu') os.environ['TPU_ML_PLATFORM'] = 'JAX' os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__ - os.environ['ENABLE_RUNTIME_UPTIME_TELEMETRY'] = '1' + os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') if hardware_utils.tpu_enhanced_barrier_supported(): os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"