Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update CUDA custom call example code to use ffi_call #22141

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from

Commits on Jun 27, 2024

  1. Add ffi_call function with a similar signature to pure_callback.

    This could be useful for supporting the most common use cases for FFI custom
    calls. It has several benefits over using the `Primitive` based approach, but
    the biggest one (in my opinion) is that it doesn't require interacting with
    `mlir` at all. It does have the limitation that transforms would need to be
    registered using interfaces like `custom_vjp`, but many users of custom calls
    already do that.
    
    ~~The easiest to-do item (I think) is to implement batching using a
    `vectorized` parameter like `pure_callback`, but we could also think about more
    sophisticated vmapping interfaces in the future.~~ Done.
    
    The more difficult to-do is to think about how to support sharding, and we
    might actually want to expose an interface similar to the one from
    `custom_partitioning`. I have less experience with this part so I'll have to
    think some more about it, and feedback would be appreciated!
    dfm committed Jun 27, 2024
    Configuration menu
    Copy the full SHA
    ed56df0 View commit details
    Browse the repository at this point in the history
  2. Update CUDA custom call example code to use ffi_call.

    Following up on google#21925, we can update the example code in
    `docs/cuda_custom_call` to use `ffi_call` instead of manually
    registering `core.Primitive`s. This removes quite a bit of boilerplate
    and doesn't require direct use of MLIR.
    dfm committed Jun 27, 2024
    Configuration menu
    Copy the full SHA
    2aea52b View commit details
    Browse the repository at this point in the history