Skip to content

Commit e46cd3d

Browse files
authored
Merge pull request #108 from kookmin-sw/mhsong-dev
Fix several errors
2 parents b92114d + 0f564b1 commit e46cd3d

File tree

2 files changed: +24 -5 lines changed

2 files changed: +24 -5 lines changed

automation/deploy_train/create_rayjob.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -256,7 +256,7 @@ def getTrainInfo(ModelClass, optimstr, lossstr, lr, model_dir):
256256
import torch.nn as nn
257257
model = ModelClass()
258258
if os.path.exists(f'{{model_dir}}/torch.pt'):
259-
model.load_state_dict(torch.load("{{model_dir}}/torch.pt"))
259+
model.load_state_dict(torch.load(f"{{model_dir}}/torch.pt"))
260260
261261
optimizer = eval("optim."+optimstr+"(model.parameters(), lr=lr)")
262262
criterion = eval("nn."+lossstr+"()")
@@ -300,7 +300,7 @@ def train_func(config):
300300
test_loader = prepare_data_loader(test_loader)
301301
302302
# 모델, 손실 함수 및 옵티마이저 설정
303-
model, criterion, optimizer = getTrainInfo(ModelClass, optimstr=OPTIMIZER_STR, lossstr=LOSS_STR, lr=config["lr"], model_dir)
303+
model, criterion, optimizer = getTrainInfo(ModelClass, optimstr=OPTIMIZER_STR, lossstr=LOSS_STR, lr=config["lr"], model_dir=model_dir)
304304
model = ray.train.torch.prepare_model(model)
305305
if torch.cuda.is_available():
306306
model = model.cuda()

inference/template_code/lambda_app.py

+22-3
Original file line number | Diff line number | Diff line change
@@ -30,10 +30,29 @@
3030

3131
try:
3232
model = ModelClass()
33-
if os.path.exists('/tmp/model/torch.pt'):
34-
model.load_state_dict(torch.load("/tmp/model/torch.pt", map_location=torch.device('cpu')))
33+
model_path = '/tmp/model/torch.pt'
34+
35+
if os.path.exists(model_path):
36+
try:
37+
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
38+
39+
# DataParallel로 래핑된 모델인지 확인
40+
if 'module.' in list(state_dict.keys())[0]:
41+
# 키에서 'module.'을 제거
42+
from collections import OrderedDict
43+
new_state_dict = OrderedDict()
44+
for k, v in state_dict.items():
45+
new_state_dict[k.replace('module.', '')] = v
46+
model.load_state_dict(new_state_dict)
47+
else:
48+
model.load_state_dict(state_dict)
49+
except Exception as e:
50+
print(f"모델 상태 딕셔너리를 로드하는 중 오류 발생: {e}")
51+
os._exit(0)
52+
else:
53+
raise FileNotFoundError(f"모델 파일을 찾을 수 없습니다: {model_path}")
3554
except Exception as e:
36-
print(f"Model load failed: {e}")
55+
print(f"모델 로드 실패: {e}")
3756
os._exit(0)
3857

3958
def handler(event, context):

0 commit comments

Comments
 (0)