Skip to content

Commit

Permalink
Avoid throwing exceptions in LAPACK CPU kernels.
Browse files Browse the repository at this point in the history
When an FFI kernel is executed, there isn't any global try/except block (I think!) so it's probably a good idea to avoid throwing.
Instead, it should be safer to handle mapping failures to ffi::Error manually.

PiperOrigin-RevId: 647348889
  • Loading branch information
dfm authored and jax authors committed Jun 27, 2024
1 parent 61185a2 commit 98b8754
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 8 deletions.
2 changes: 2 additions & 0 deletions jaxlib/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ cc_library(
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:dynamic_annotations",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
],
)
Expand Down
59 changes: 51 additions & 8 deletions jaxlib/cpu/lapack_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "absl/base/dynamic_annotations.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
Expand All @@ -39,28 +41,61 @@ static_assert(sizeof(jax::lapack_int) == sizeof(int32_t),

namespace ffi = xla::ffi;

// TODO(danfm): These macros and the casting functions should be moved to a
// separate header for use in other FFI kernels.
#define ASSIGN_OR_RETURN_FFI_ERROR(lhs, rhs) \
if (!rhs.ok()) { \
return ffi::Error(static_cast<XLA_FFI_Error_Code>(rhs.status().code()), \
std::string(rhs.status().message())); \
} \
lhs = rhs.value()

#define RETURN_IF_FFI_ERROR(...) \
do { \
ffi::Error err = (__VA_ARGS__); \
if (err.failure()) { \
return err; \
} \
} while (0)

namespace {

template <typename T>
inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) {
inline absl::StatusOr<T> MaybeCastNoOverflow(
int64_t value, const std::string& source = __FILE__) {
if constexpr (sizeof(T) == sizeof(int64_t)) {
return value;
} else {
if (value > std::numeric_limits<T>::max()) [[unlikely]] {
throw std::overflow_error{
return absl::InvalidArgumentError(
absl::StrFormat("%s: Value (=%d) exceeds the maximum representable "
"value of the desired type",
source, value)};
source, value));
}
return static_cast<T>(value);
}
}

template <typename T>
std::tuple<int64_t, int64_t, int64_t> SplitBatch2D(ffi::Span<T> dims) {
inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) {
auto result = MaybeCastNoOverflow<T>(value, source);
if (!result.ok()) {
throw std::overflow_error{std::string(result.status().message())};
}
return result.value();
}

template <typename T>
ffi::Error CheckMatrixDimensions(ffi::Span<T> dims) {
if (dims.size() < 2) {
throw std::invalid_argument("Matrix must have at least 2 dimensions");
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"Matrix must have at least 2 dimensions");
}
return ffi::Error::Success();
}

template <typename T>
std::tuple<int64_t, int64_t, int64_t> SplitBatch2D(ffi::Span<T> dims) {
auto matrix_dims = dims.last(2);
return std::make_tuple(absl::c_accumulate(dims.first(dims.size() - 2), 1,
std::multiplies<int64_t>()),
Expand Down Expand Up @@ -201,15 +236,18 @@ ffi::Error LuDecomposition<dtype>::Kernel(
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
ffi::ResultBuffer<LapackIntDtype> ipiv,
ffi::ResultBuffer<LapackIntDtype> info) {
RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions));
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions);
auto* x_out_data = x_out->data;
auto* ipiv_data = ipiv->data;
auto* info_data = info->data;

CopyIfDiffBuffer(x, x_out);

auto x_rows_v = CastNoOverflow<lapack_int>(x_rows);
auto x_cols_v = CastNoOverflow<lapack_int>(x_cols);
ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v,
MaybeCastNoOverflow<lapack_int>(x_rows));
ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v,
MaybeCastNoOverflow<lapack_int>(x_cols));
auto x_leading_dim_v = x_rows_v;

const int64_t x_out_step{x_rows * x_cols};
Expand Down Expand Up @@ -371,14 +409,16 @@ template <ffi::DataType dtype>
ffi::Error CholeskyFactorization<dtype>::Kernel(
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<LapackIntDtype> info) {
RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions));
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions);
auto* x_out_data = x_out->data;
auto* info_data = info->data;

CopyIfDiffBuffer(x, x_out);

auto uplo_v = static_cast<char>(uplo);
auto x_order_v = CastNoOverflow<lapack_int>(x.dimensions.back());
ASSIGN_OR_RETURN_FFI_ERROR(
auto x_order_v, MaybeCastNoOverflow<lapack_int>(x.dimensions.back()));
auto x_leading_dim_v = x_order_v;

const int64_t x_out_step{x_rows * x_cols};
Expand Down Expand Up @@ -1077,3 +1117,6 @@ template struct Sytrd<std::complex<float>>;
template struct Sytrd<std::complex<double>>;

} // namespace jax

#undef ASSIGN_OR_RETURN_FFI_ERROR
#undef RETURN_IF_FFI_ERROR

0 comments on commit 98b8754

Please sign in to comment.