Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial rust mpi support #2025

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ static const StringSet<> InactiveGlobals = {
"jl_small_typeof",
"ompi_request_null",
"ompi_mpi_double",
"RSMPI_DOUBLE",
"RSMPI_FLOAT",
"RSMPI_SUM",
"RSMPI_COMM_WORLD",
"RSMPI_COMM_SELF",
"ompi_mpi_comm_world",
"__cxa_thread_atexit_impl",
"stderr",
Expand Down
5 changes: 3 additions & 2 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,10 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
C = CE->getOperand(0);
}
if (auto GV = dyn_cast<GlobalVariable>(C)) {
if (GV->getName() == "ompi_mpi_double") {
auto name = GV->getName();
ZuseZ4 marked this conversation as resolved.
Show resolved Hide resolved
if (name == "ompi_mpi_double" || name == "RSMPI_DOUBLE") {
return ConstantInt::get(intType, 8, false);
} else if (GV->getName() == "ompi_mpi_float") {
} else if (name == "ompi_mpi_float" || name == "RSMPI_FLOAT") {
return ConstantInt::get(intType, 4, false);
}
}
Expand Down
11 changes: 9 additions & 2 deletions enzyme/Enzyme/CallDerivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
C = CE->getOperand(0);
}
if (auto GV = dyn_cast<GlobalVariable>(C)) {
if (GV->getName() == "ompi_mpi_op_sum") {
if (GV->getName() == "ompi_mpi_op_sum" ||
GV->getName() == "RSMPI_SUM") {
Comment on lines +1153 to +1154
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how to resolve this.

source_id: DefId(0:12 ~ dot_enzyme[f53a]::dot_parallel)
num_fnc_args: 5
input_activity.len(): 5
error: <unknown>:0:0: in function preprocess__ZN10dot_enzyme12dot_parallel17hd37f1f8a2c8de07dE double (ptr, ptr, i64, ptr, i64): Enzyme: cannot compute with global variable that doesn't have marked shadow global
@RSMPI_SUM = external local_unnamed_addr global ptr

This is the relevant code: rsmpi/rsmpi@840e01c#diff-9a676b0d0c142cd1e89e8174ddb007db982d8602bd374a04e40e9f6a421acaebR216-R228

Run with

$ RUSTFLAGS='-Z unstable-options' cargo +enzyme r --example=dot_enzyme --release

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Me neither, but can you share the module you got from ENZYME_OPT=1? Then I can experiment around to see if I find the right changes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jedbrown you need to add RSMPI_SUM to the ActivityAnalysis.cpp code [the message (poorly) warns that the global variable is differentiable, but Enzyme is unable to determine a differentiable version of the global. Of course it makes no sense to differentiate wrt MPI_SUM so we can mark that in ActivityAnalysis.cpp]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Why didn't ompi_mpi_op_sum need to be there?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think one has an extra pointer indirection causing a load which needs to be analyzed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah in this case, now the issue is that the op of MPI_allreduce cannot be detected (the earlier check if the argument was a literal global no longer applies, since this is a load of RSMPI_SUM). Changing the MPI_Allreduce check to consider something along the lines of:

if (LI = dyn_cast<LoadInst>(...)))
  if (auto GV = dyn_cast<GlobalVariable>(LI->getPointerOperand()))
    if (GV->getName() == "RSMPI_SUM")
      legal = true;

Copy link
Collaborator

@jedbrown jedbrown Aug 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pushed that (cb2d739; assuming it was the correct place), but now I have

source_id: DefId(0:12 ~ dot_enzyme[42dc]::dot_parallel)
num_fnc_args: 5
input_activity.len(): 5
error: <unknown>:0:0: in function preprocess__ZN10dot_enzyme12dot_parallel17h2f3ed146b457ca09E double (ptr, ptr, i64, ptr, i64): Enzyme: cannot compute with global variable that doesn't have marked shadow global
@RSMPI_COMM_WORLD = external local_unnamed_addr global ptr

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay that's the same as the first issue [a global which cannot be proven non-differentiable]. ActivityAnalysis.cpp is again the right place to add that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the hand-holding. I think this output is correct now.

$ RUSTFLAGS='-Z unstable-options' cargo +enzyme mpirun -n 2 --example=dot_enzyme --release | sort
[0] bx: [0.0, 2.0, 4.0, 6.0, 8.0], by: [0.0, 2.0, 4.0, 6.0, 8.0]
[0] local: 30
[1] bx: [200.0, 202.0, 204.0, 206.0, 208.0], by: [20.0, 22.0, 24.0, 26.0, 28.0]
[1] local: 6130
global: 6160
global: 6160

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that --release is required, otherwise I see

invertedPointers:
   invertedPointers[ptr %0] =   <badref> = load ptr, ptr %4, align 8
end invertedPointers
  <badref> = load ptr, ptr %4, align 8
rustc: /home/jed/src/rust-enzyme/src/tools/enzyme/enzyme/Enzyme/GradientUtils.cpp:8489: virtual void InvertedPointerVH::deleted(): Assertion `0 && "erasing something in invertedPointers map"' failed.
error: could not compile `mpi` (example "dot_enzyme")

isSum = true;
}
}
Expand Down Expand Up @@ -1391,7 +1392,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
C = CE->getOperand(0);
}
if (auto GV = dyn_cast<GlobalVariable>(C)) {
if (GV->getName() == "ompi_mpi_op_sum") {
if (GV->getName() == "ompi_mpi_op_sum" ||
GV->getName() == "RSMPI_SUM") {
isSum = true;
}
}
Expand All @@ -1402,6 +1404,11 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
}
}
}
if (auto LI = dyn_cast<LoadInst>(orig_op)) {
if (auto GV = dyn_cast<GlobalVariable>(LI->getPointerOperand()))
if (GV->getName() == "RSMPI_SUM")
isSum = true;
}
if (!isSum) {
std::string s;
llvm::raw_string_ostream ss(s);
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ static const char *KnownInactiveFunctionsContains[] = {
static const std::set<std::string> InactiveGlobals = {
"ompi_request_null", "ompi_mpi_double", "ompi_mpi_comm_world", "stderr",
"stdout", "stdin", "_ZSt3cin", "_ZSt4cout", "_ZSt5wcout", "_ZSt4cerr",
"RSMPI_DOUBLE", "RSMPI_FLOAT",
"_ZTVNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEEE",
"_ZTVSt15basic_streambufIcSt11char_traitsIcEE",
"_ZTVSt9basic_iosIcSt11char_traitsIcEE",
Expand Down
7 changes: 4 additions & 3 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4785,11 +4785,12 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
C = CE->getOperand(0);
}
if (auto GV = dyn_cast<GlobalVariable>(C)) {
if (GV->getName() == "ompi_mpi_double") {
auto name = GV->getName();
if (name == "ompi_mpi_double" || name == "RSMPI_DOUBLE") {
buf.insert({0}, Type::getDoubleTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_float") {
} else if (name == "ompi_mpi_float" || name == "RSMPI_FLOAT") {
buf.insert({0}, Type::getFloatTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_cxx_bool") {
} else if (name == "ompi_mpi_cxx_bool") {
buf.insert({0}, BaseType::Integer);
}
} else if (auto CI = dyn_cast<ConstantInt>(C)) {
Expand Down
5 changes: 3 additions & 2 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2344,9 +2344,10 @@ bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
C = CE->getOperand(0);
}
if (auto GV = dyn_cast<GlobalVariable>(C)) {
if (GV->getName() == "ompi_mpi_double") {
auto name = GV->getName();
if (name == "ompi_mpi_double" || name == "RSMPI_DOUBLE") {
type = ConcreteType(Type::getDoubleTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_float") {
} else if (name == "ompi_mpi_float" || name == "RSMPI_FLOAT") {
type = ConcreteType(Type::getFloatTy(C->getContext()));
}
}
Expand Down
115 changes: 115 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/mpi_rust.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
; RUN: if [ %llvmver -eq 15 ]; then %opt < %s %loadEnzyme -enzyme -opaque-pointers=1 -S | FileCheck %s; fi
; RUN: if [ %llvmver -ge 15 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -opaque-pointers=1 -S | FileCheck %s; fi

; ModuleID = 'enzyme-repro.ll'
source_filename = "dot_enzyme.3df87ea89a38df43-cgu.0"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

@RSMPI_DOUBLE = external local_unnamed_addr global ptr
@RSMPI_COMM_WORLD = external local_unnamed_addr global ptr
@RSMPI_COMM_SELF = external local_unnamed_addr global ptr

; Function Attrs: noinline nonlazybind sanitize_hwaddress uwtable
define hidden noundef "enzyme_type"="{[-1]:Float@double}" double @_ZN10dot_enzyme12dot_parallel17h7dfcd86d9e8c176bE(ptr noalias nocapture noundef readonly align 8 dereferenceable(16) "enzyme_type"="{[-1]:Pointer}" %0, ptr noalias nocapture noundef nonnull readonly align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %1, i64 noundef "enzyme_type"="{[-1]:Integer}" %2, ptr noalias nocapture noundef nonnull readonly align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %3, i64 noundef "enzyme_type"="{[-1]:Integer}" %4, ptr noundef "enzyme_type"="{[0]:Pointer}" %5) unnamed_addr #1 personality ptr @rust_eh_personality {
%7 = alloca double, align 8
%8 = alloca double, align 8
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %8)
%9 = alloca double, align 8
store double 1.000, ptr %8, align 8
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %7)
store double 0.000000e+00, ptr %7, align 8
tail call void @llvm.experimental.noalias.scope.decl(metadata !7)
%10 = load ptr, ptr @RSMPI_DOUBLE, align 8, !noalias !10, !noundef !13
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unused atm so this won't test any thing. Can you make this a minimal runnable case. Maybe a different mpi fn?

Also get rid of the other stuff like enzyme_type etc

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jedbrown I can't run things, can you test if some other function works on this branch?

%11 = load i64, ptr %0, align 8, !range !14, !alias.scope !15, !noalias !18, !noundef !13
switch i64 %11, label %12 [
i64 0, label %20
i64 1, label %13
i64 2, label %14
i64 3, label %16
i64 4, label %18
]

12: ; preds = %6
unreachable

13: ; preds = %6
br label %20

14: ; preds = %6
%15 = getelementptr inbounds { i64, ptr }, ptr %0, i64 0, i32 1
br label %20

16: ; preds = %6
%17 = getelementptr inbounds { i64, ptr }, ptr %0, i64 0, i32 1
br label %20

18: ; preds = %6
%19 = getelementptr inbounds { i64, ptr }, ptr %0, i64 0, i32 1
br label %20

20: ; preds = %18, %16, %14, %13, %6
%21 = phi ptr [ %19, %18 ], [ %17, %16 ], [ %15, %14 ], [ @RSMPI_COMM_WORLD, %13 ], [ @RSMPI_COMM_SELF, %6 ]
%22 = load ptr, ptr %21, align 8, !noalias !18, !noundef !13
%23 = alloca i32, align 4
;%23 = call noundef i32 @MPI_Allreduce(ptr noundef nonnull %8, ptr noundef nonnull %7, i32 noundef 1, ptr noundef %10, ptr noundef %5, ptr noundef %22), !noalias !7
%24 = load double, ptr %7, align 8, !noundef !13
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %7)
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %8)
ret double %24
}

; Function Attrs: nonlazybind sanitize_hwaddress uwtable
declare noundef i32 @MPI_Allreduce(ptr noundef, ptr noundef, i32 noundef, ptr noundef, ptr noundef, ptr noundef) unnamed_addr #2

; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite)
declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #3

; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite)
declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) #3

; Function Attrs: nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: readwrite)
declare void @llvm.experimental.noalias.scope.decl(metadata) #4

; Function Attrs: nounwind nonlazybind sanitize_hwaddress uwtable
declare noundef i32 @rust_eh_personality(i32 noundef, i32 noundef, i64, ptr noundef, ptr noundef) unnamed_addr #5

declare double @__enzyme_autodiff(...)

define double @enzyme_opt_helper_0(ptr %0, ptr %1, i64 %2, ptr %3, i64 %4, ptr %5) {
%7 = call double (...) @__enzyme_autodiff(ptr @_ZN10dot_enzyme12dot_parallel17h7dfcd86d9e8c176bE, metadata !"enzyme_const", ptr %0, metadata !"enzyme_dup", ptr %1, ptr %1, metadata !"enzyme_const", i64 %2, metadata !"enzyme_dup", ptr %3, ptr %3, metadata !"enzyme_const", i64 %4, metadata !"enzyme_const", ptr %5)
ret double %7
}

attributes #0 = { noinline nounwind nonlazybind sanitize_hwaddress uwtable "probe-stack"="inline-asm" "target-cpu"="x86-64" }
attributes #1 = { noinline nonlazybind sanitize_hwaddress uwtable "probe-stack"="inline-asm" "target-cpu"="x86-64" }
attributes #2 = { nonlazybind sanitize_hwaddress uwtable "probe-stack"="inline-asm" "target-cpu"="x86-64" }
attributes #3 = { nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) }
attributes #4 = { nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: readwrite) }
attributes #5 = { nounwind nonlazybind sanitize_hwaddress uwtable "probe-stack"="inline-asm" "target-cpu"="x86-64" }

!llvm.module.flags = !{!0, !1, !2, !3, !4, !5}
!llvm.ident = !{!6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6}
!llvm.dbg.cu = !{}

!0 = !{i32 8, !"PIC Level", i32 2}
!1 = !{i32 7, !"PIE Level", i32 2}
!2 = !{i32 2, !"RtLibUseGOT", i32 1}
!3 = !{i32 1, !"LTOPostLink", i32 1}
!4 = !{i32 2, !"Dwarf Version", i32 4}
!5 = !{i32 2, !"Debug Info Version", i32 3}
!6 = !{!"rustc version 1.77.0-nightly (ecb2f9cdf 2024-07-30)"}
!7 = !{!8}
!8 = distinct !{!8, !9, !"_ZN3mpi10collective23CommunicatorCollectives15all_reduce_into17h5bd43ff3d0a82648E: argument 0"}
!9 = distinct !{!9, !"_ZN3mpi10collective23CommunicatorCollectives15all_reduce_into17h5bd43ff3d0a82648E"}
!10 = !{!8, !11, !12}
!11 = distinct !{!11, !9, !"_ZN3mpi10collective23CommunicatorCollectives15all_reduce_into17h5bd43ff3d0a82648E: argument 1"}
!12 = distinct !{!12, !9, !"_ZN3mpi10collective23CommunicatorCollectives15all_reduce_into17h5bd43ff3d0a82648E: argument 2"}
!13 = !{}
!14 = !{i64 0, i64 5}
!15 = !{!16, !8}
!16 = distinct !{!16, !17, !"_ZN69_$LT$mpi..topology..SimpleCommunicator$u20$as$u20$mpi..raw..AsRaw$GT$6as_raw17h5ddd9d255d268465E: argument 0"}
!17 = distinct !{!17, !"_ZN69_$LT$mpi..topology..SimpleCommunicator$u20$as$u20$mpi..raw..AsRaw$GT$6as_raw17h5ddd9d255d268465E"}
!18 = !{!11, !12}

; CHECK: define internal void @diffe_ZN10dot_enzyme12dot_parallel17h7dfcd86d9e8c176bE(ptr noalias nocapture noundef readonly align 8 dereferenceable(16) "enzyme_type"="{[-1]:Pointer}" %0, ptr noalias nocapture noundef nonnull readonly align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %1, ptr nocapture align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %"'", i64 noundef "enzyme_type"="{[-1]:Integer}" %2, ptr noalias nocapture noundef nonnull readonly align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %3, ptr nocapture align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %"'1", i64 noundef "enzyme_type"="{[-1]:Integer}" %4, ptr noundef "enzyme_type"="{[0]:Pointer}" %5, double %differeturn)
Loading