@hawkinsp I chatted with Peter offline, and I believe he has some ideas for improving the C API to address this problem.
Could you please share some thoughts here? Thanks so much.
Description
Hi JAX team,
We identified a bug in the JAX CUDA plugin. Here is the writeup:
https://docs.google.com/document/d/1ldlD8XQ6XYX4zcSRCUIVQyAUBJQZX6v9PdE2qX2_FGw/edit?usp=sharing
To summarize,
We accidentally found that `XlaDebugInfoManager`, an object that is supposed to be a global singleton, ends up with two copies in JAX. The singleton is linked into both `xla_extension.so` and `cuda_plugin.so`, so different parts of the Python code reference different copies. As a direct consequence, some metadata is missing from the profiler output and `jax.profiler` does not function correctly.

This is a bug report but also a feature request: we want to make sure that anything intended to be global cannot escape the control of the C API, as a future safety mechanism.
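The failure mode can be reproduced in miniature. The sketch below is a Python analogy, not XLA code: `DebugInfoManager`, `load_copy`, and the module names are hypothetical. Executing the same singleton source into two independent modules plays the role of linking the singleton into both `xla_extension.so` and `cuda_plugin.so`: each copy lazily creates its own instance, so state registered through one copy is invisible through the other.

```python
import types

# Hypothetical toy singleton standing in for XlaDebugInfoManager (not XLA code).
SINGLETON_SRC = '''
class DebugInfoManager:
    """Registry that is meant to exist exactly once per process."""
    _instance = None

    def __init__(self):
        self.modules = []

    @classmethod
    def get(cls):
        # Lazily create the shared instance on first use.
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance
'''


def load_copy(name):
    """Execute the same source into a fresh module, mimicking a separate .so copy."""
    module = types.ModuleType(name)
    exec(SINGLETON_SRC, module.__dict__)
    return module


extension_copy = load_copy("xla_extension_copy")
plugin_copy = load_copy("cuda_plugin_copy")

# The plugin side registers metadata in its own copy of the singleton...
plugin_mgr = plugin_copy.DebugInfoManager.get()
plugin_mgr.modules.append("jit_f")

# ...while the profiler side reads the other copy and finds nothing.
ext_mgr = extension_copy.DebugInfoManager.get()
print(ext_mgr is plugin_mgr)  # False
print(ext_mgr.modules)        # []
```

In this toy, the two sides agree only if a single owner creates the instance and the other side receives it explicitly rather than instantiating its own, which is the kind of guarantee the feature request asks the C API boundary to enforce.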
System info (python version, jaxlib version, accelerator, etc.)
This is a general issue with JAX plugins. I tested on the latest JAX release and at HEAD.