Skip to content

Commit

Permalink
[Mosaic GPU] Rearrange the pass pipeline (again)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 642256145
  • Loading branch information
apaszke authored and jax authors committed Jun 11, 2024
1 parent 3345952 commit 1256ceb
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions jaxlib/mosaic/gpu/custom_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
mlir::registerConvertFuncToLLVMPass();
mlir::registerConvertAffineToStandard();
mlir::registerReconcileUnrealizedCasts();
mlir::registerGpuToLLVMConversionPass();
// TODO(apaszke): Only register the passes we actually use.
mlir::memref::registerMemRefPasses();
mlir::registerConvertToLLVMPass();
mlir::registerGPUPasses();
mosaic::gpu::registerGpuLaunchLoweringPass();
mosaic::gpu::registerConvertGpuToLLVMPass();
Expand Down Expand Up @@ -140,11 +140,12 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
convert-math-to-llvm{approximate-log1p=true},
canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},
cse,
reconcile-unrealized-casts,)" +
)" +
(target != mlir::gpu::CompilationTarget::Assembly ? "gpu-launch-lowering,"
: "") +
R"(
convert-func-to-llvm{index-bitwidth=0 use-bare-ptr-memref-call-conv=false}
convert-to-llvm,
reconcile-unrealized-casts
)
)");
}
Expand All @@ -170,9 +171,9 @@ void InitContext(mlir::MLIRContext* context) {
mlir::registerConvertFuncToLLVMInterface(registry);
mlir::index::registerConvertIndexToLLVMInterface(registry);
mlir::cf::registerConvertControlFlowToLLVMInterface(registry);
mlir::ub::registerConvertUBToLLVMInterface(registry); // Arith needs this
mlir::ub::registerConvertUBToLLVMInterface(registry);
mlir::arith::registerConvertArithToLLVMInterface(registry);
mlir::registerFinalizeMemRefToLLVMConversionPass();
mlir::registerConvertMemRefToLLVMInterface(registry);
mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
mlir::NVVM::registerNVVMTargetInterfaceExternalModels(registry);
mlir::registerBuiltinDialectTranslation(registry);
Expand Down

0 comments on commit 1256ceb

Please sign in to comment.