Marking a CUDA custom call as command buffer-compatible has no effect #14889

Closed
andportnoy opened this issue Jul 13, 2024 · 2 comments

@andportnoy (Member)

This seems to happen because this piece of logic only looks at registrations for the generic platform name "gpu":

auto registration = ffi::FindHandler(hlo->custom_call_target(), "gpu");
return registration.ok()
? ffi::IsCommandBufferCompatible(registration->traits)
: false;
Hence custom calls registered for CUDA are not taken into account.

@ezhulenev has suggested offline that the fix might be to do platform name canonicalization more thoroughly.
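For illustration, here is a minimal sketch of what canonicalizing the platform name before the lookup could look like. The helper, its alias mapping, and the platform_name variable are hypothetical and only show the idea, not the actual fix:

// Hypothetical sketch only (not the actual XLA change): normalize platform
// aliases to one canonical key so that a handler registered under "CUDA" and a
// lookup made while compiling for a CUDA device use the same string.
// Uses absl::AsciiStrToLower from absl/strings/ascii.h.
std::string CanonicalPlatformName(absl::string_view platform) {
  std::string name = absl::AsciiStrToLower(platform);
  if (name == "nvidia gpu") return "cuda";  // example alias, assumed
  if (name == "amd gpu") return "rocm";     // example alias, assumed
  return name;
}

// The lookup shown above would then use the canonicalized name of the platform
// the module is being compiled for (platform_name is assumed to be available
// at that call site) instead of the hard-coded generic "gpu":
auto registration = ffi::FindHandler(hlo->custom_call_target(),
                                     CanonicalPlatformName(platform_name));
return registration.ok() ? ffi::IsCommandBufferCompatible(registration->traits)
                         : false;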

A quick way to repro is to modify the JAX cuda_custom_call test as follows:

diff --git a/docs/cuda_custom_call/cuda_custom_call_test.py b/docs/cuda_custom_call/cuda_custom_call_test.py
index 563462feb..0e3a5453b 100644
--- a/docs/cuda_custom_call/cuda_custom_call_test.py
+++ b/docs/cuda_custom_call/cuda_custom_call_test.py
@@ -72,7 +72,8 @@ library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
 xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
                                        fn=ffi.pycapsule(library.FooFwd),
                                        platform=XLA_PLATFORM,
-                                       api_version=XLA_CUSTOM_CALL_API_VERSION)
+                                       api_version=XLA_CUSTOM_CALL_API_VERSION,
+                                       traits=1)


 # our forward primitive will also return the intermediate output b+1
@@ -111,7 +112,8 @@ mlir.register_lowering(foo_fwd_p, _foo_fwd_lowering, platform=JAX_PLATFORM)
 xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
                                        fn=ffi.pycapsule(library.FooBwd),
                                        platform=XLA_PLATFORM,
-                                       api_version=XLA_CUSTOM_CALL_API_VERSION)
+                                       api_version=XLA_CUSTOM_CALL_API_VERSION,
+                                       traits=1)

then run the following commands (you'll need the Nsight Systems CLI installed), which will show whether each kernel was launched as part of a CUDA graph or not:

XLA_FLAGS=--xla_gpu_graph_min_graph_size=1 nsys profile --cuda-graph-trace=node -o custom-call-graph --force-overwrite=true python cuda_custom_call_test.py
nsys stats -r cuda_kern_exec_trace --force-export=true custom-call-graph.nsys-rep
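For reference, traits=1 in the diff above corresponds, as I read xla/ffi/api/c_api.h, to the low bit of the FFI handler traits bitmask, which marks a handler as command-buffer (CUDA graph) compatible. Treat the exact names below as an assumption and check the header in your XLA checkout:

// Per my reading of the XLA FFI C API (xla/ffi/api/c_api.h); names may differ
// between versions. Handler traits are a bitmask, and bit 0 marks a handler as
// command-buffer (CUDA graph) compatible, so traits=1 in the Python
// registration sets exactly that bit.
typedef enum {
  XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0,
} XLA_FFI_Handler_TraitsBits;

static_assert(XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE == 1,
              "traits=1 above is the command-buffer-compatible bit");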
@andportnoy changed the title from "Marking a custom call as command buffer-compatible has no effect" to "Marking a CUDA custom call as command buffer-compatible has no effect" on Jul 13, 2024
copybara-service bot pushed a commit that referenced this issue Jul 17, 2024
…tion and lookup

+ Use xla:util error constructors instead of absl::XyzError to automatically capture error stack trace

Fix for #14889

PiperOrigin-RevId: 653278962
copybara-service bot pushed a commit that referenced this issue Jul 17, 2024
…tion and lookup

+ Use xla:util error constructors instead of absl::XyzError to automatically capture error stack trace

Fix for #14889

PiperOrigin-RevId: 653319684
@phu0ngng (Contributor)

Hi @ezhulenev,
I confirm that CUDA graphs showed up in the nsys reports with the fixes introduced in #14921 and #15021.
We can close this issue.
Many thanks.

@hawkinsp (Member)

Closing, per @phu0ngng's report that this is fixed.
