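The original code showed no speed difference, so I changed it and left the changes as a comment on the Google Slides deck.
I was asked to file them as an issue, so I am writing them up here.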
Original code
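import torch

def Foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

foo = Foo
x = torch.rand(3, 4)
y = torch.rand(3, 4)

# Trace with example inputs, then call the traced function
traced_foo = torch.jit.trace(foo, (x, y))
traced_foo(x, y)

# Compile with torch.compile, then call the compiled function
optimized_foo = torch.compile(foo)
optimized_foo(x, y)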
Modified code
import torch
import time

def Foo(x, y):
    a = torch.matmul(x, y.T)  # Add matmul
    b = torch.sin(a)
    c = torch.cos(a)
    return b + c

foo = Foo
x = torch.rand(1000, 1000)
y = torch.rand(1000, 1000)

# torch.jit.trace part
traced_foo = torch.jit.trace(foo, (x, y))

# torch.compile part
optimized_foo = torch.compile(foo)
print("torch.compile warm up")
for i in range(5):
    optimized_foo(x, y)

def measure_time(func, x, y, num_iters=10):
    start = time.time()
    for _ in range(num_iters):
        func(x, y)
    end = time.time()
    return end - start

num_iters = 10
original_time = measure_time(foo, x, y, num_iters)
traced_time = measure_time(traced_foo, x, y, num_iters)
compiled_time = measure_time(optimized_foo, x, y, num_iters)

print(f"Original Time: {original_time:.6f} seconds")
print(f"Traced Time: {traced_time:.6f} seconds")
print(f"Compiled Time: {compiled_time:.6f} seconds")
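I made two changes: the input data is now larger, so that the performance difference becomes visible, and the function returned by torch.compile(foo) is warmed up before it is timed, because its first calls include compilation.
Thank you!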
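As a possible refinement of the measurement above, here is a minimal sketch of a timing helper that gives every variant the same untimed warm-up and reports the mean time per call. The names `benchmark` and `stand_in` are hypothetical and not part of the original code. The stand-in workload is only there so the sketch runs without PyTorch.

```python
import time
from statistics import mean


def benchmark(func, *args, warmup=5, iters=10):
    """Return mean wall-clock seconds per call, after `warmup` untimed calls."""
    # Untimed calls absorb one-time costs such as lazy compilation.
    for _ in range(warmup):
        func(*args)

    samples = []
    for _ in range(iters):
        start = time.perf_counter()  # monotonic, high-resolution clock
        func(*args)
        samples.append(time.perf_counter() - start)
    return mean(samples)


# Hypothetical stand-in workload so the sketch runs without PyTorch.
def stand_in(n):
    return sum(i * i for i in range(n))


if __name__ == "__main__":
    print(f"stand_in: {benchmark(stand_in, 10_000):.6f} s/call")
```

In the setting of this issue, one would presumably call `benchmark(foo, x, y)`, `benchmark(traced_foo, x, y)` and `benchmark(optimized_foo, x, y)`, so that the traced and eager versions are warmed up the same way as the compiled one. The tensors here live on the CPU. On a GPU the helper would additionally need `torch.cuda.synchronize()` before each clock read, since CUDA kernels run asynchronously.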