From 7322026005de0ad339eb00176f603f83ad843ec3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 25 Mar 2024 04:34:39 +0100 Subject: [PATCH] Update Thunder's GPT traces in the README (#1190) --- extensions/thunder/README.md | 108 +++-------------------------------- 1 file changed, 9 insertions(+), 99 deletions(-) diff --git a/extensions/thunder/README.md b/extensions/thunder/README.md index b84da74e5f..df2d0461a7 100644 --- a/extensions/thunder/README.md +++ b/extensions/thunder/README.md @@ -41,27 +41,8 @@ print(forward_trace) @no_autocast() def augmented_forward_fn(*args): # args: "Collection" - t0, \ - t1, \ - t2, \ - t3, \ - t4, \ - t5, \ - t6, \ - t7, \ - t8, \ - t9, \ - t10, \ - t11, \ - t12, \ - t13, \ - t14, \ - t15, \ - t16, \ - t17, \ - t18, \ - t19, \ - = args + t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, \ + t18, t19, = args del args t24 = torch.nn.functional.embedding(t0, t19, None, None, 2.0, False, False) # t24: "cuda:0 f32[2, 5, 4096]" t20 = torch_slice_prim_impl(t1, [0, 0], [5, 128], [1, 1]) # t20: "cuda:0 f32[5, 128]" @@ -247,90 +228,19 @@ print(backward_trace) def backward_fn(saved_for_backward, cotangents): # saved_for_backward: "Collection" # cotangents: "Collection" - C0, \ - C1, \ - = saved_for_backward + C0, C1, = saved_for_backward clear_collection(saved_for_backward) del saved_for_backward - t178, \ - = cotangents + t178, = cotangents clear_collection(cotangents) del cotangents - t0, \ - t101, \ - t104, \ - t105, \ - t114, \ - t136, \ - t138, \ - t139, \ - t140, \ - t141, \ - t142, \ - t144, \ - t146, \ - t15, \ - t152, \ - t155, \ - t156, \ - t157, \ - t158, \ - t16, \ - t164, \ - t166, \ - t17, \ - t172, \ - t175, \ - t176, \ - t18, \ - t24, \ - t3, \ - t30, \ - t33, \ - t34, \ - t4, \ - t43, \ - t49, \ - t5, \ - t51, \ - t6, \ - t65, \ - t67, \ - t68, \ - t69, \ - t7, \ - t70, \ - t71, \ - t73, \ - t75, \ - t8, \ - t81, \ - t84, \ - t85, \ - t86, \ - t87, \ - t9, \ - t93, \ - t95, \ - = C0 + t0, t101, t104, t105, t114, t136, t138, t139, t140, t141, t142, t144, t146, \ + t15, t152, t155, t156, t157, t158, t16, t164, t166, t17, t172, t175, t176, t18, \ + t24, t3, t30, t33, t34, t4, t43, t49, t5, t51, t6, t65, t67, t68, t69, t7, t70, \ + t71, t73, t75, t8, t81, t84, t85, t86, t87, t9, t93, t95, = C0 clear_collection(C0) del C0 - b1, \ - b2, \ - b41, \ - b91, \ - f101, \ - f106, \ - f40, \ - f42, \ - f51, \ - f56, \ - f6, \ - f90, \ - f92, \ - i0, \ - i23, \ - i73, \ + b1, b2, b41, b91, f101, f106, f40, f42, f51, f56, f6, f90, f92, i0, i23, i73, \ = C1 clear_collection(C1) del C1