Skip to content

Commit

Permalink
Monkey patch functions in APEX.
Browse files Browse the repository at this point in the history
Signed-off-by: Aditya Agrawal <[email protected]>
  • Loading branch information
adityaiitb authored and debermudez committed Jun 16, 2020
1 parent 6279485 commit 6a8f571
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions pyprof/nvtx/nvmarker.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,12 +334,28 @@ def new_iter(self, *args, **kwargs):

mod.DataLoader.__iter__ = new_iter

# Monkey-patch functions in APEX
#
def patch_apex():
import importlib
if importlib.util.find_spec("amp_C") is not None:
import amp_C
patchClass(amp_C)

if importlib.util.find_spec("fused_adam_cuda") is not None:
import fused_adam_cuda
patchClass(fused_adam_cuda)

if importlib.util.find_spec("fused_layer_norm_cuda") is not None:
import fused_layer_norm_cuda
patchClass(fused_layer_norm_cuda)

def init():
print("Initializing NVTX monkey patches")

patch_dataloader()
patch_torch_classes()
patch_torch_nn_forward_functions()
patch_apex()

print("Done with NVTX monkey patching")

0 comments on commit 6a8f571

Please sign in to comment.