-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathflask_api.py
358 lines (285 loc) · 12.9 KB
/
flask_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
from flask import Flask, request, jsonify
from prompt_template import create_prompt
from prompt_template import create_title_description_prompt
from chat_model import create_chain
from kiwi_module import KiwiProcessor, stopwords, apply_weights_to_similarity
import torch.nn as nn
import threading
import torch
import pandas as pd
import joblib
import numpy as np
from langchain_core.runnables import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.messages import HumanMessage, AIMessage
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage
app = Flask(__name__)
# 락 생성
store_lock = threading.Lock()
# 세션 기록을 저장할 딕셔너리
store = {}
# 프롬프트 생성
prompt = create_prompt()
# `TransformerModel` 클래스 정의
class TransformerModel(nn.Module):
def __init__(self, input_size, output_size, num_heads, hidden_dim, num_layers, dropout=0.3):
super(TransformerModel, self).__init__()
self.input_layer = nn.Linear(input_size, hidden_dim)
self.pos_encoder = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
self.transformer = nn.Transformer(
d_model=hidden_dim,
nhead=num_heads,
num_encoder_layers=num_layers,
num_decoder_layers=num_layers,
dropout=dropout
)
self.output_layer = nn.Linear(hidden_dim, output_size)
nn.init.xavier_uniform_(self.output_layer.weight)
def forward(self, x):
x = self.input_layer(x)
x = self.pos_encoder(x)
x = self.transformer(x.unsqueeze(0), x.unsqueeze(0)).squeeze(0)
return torch.sigmoid(self.output_layer(x))
# `MLP` 클래스 정의
class MLP(nn.Module):
def __init__(self, input_size, hidden_sizes, output_size):
super(MLP, self).__init__()
self.hidden_layers = nn.ModuleList()
current_input_size = input_size
for hidden_size in hidden_sizes:
self.hidden_layers.append(nn.Linear(current_input_size, hidden_size))
current_input_size = hidden_size
self.output_layer = nn.Linear(current_input_size, output_size)
def forward(self, x):
for layer in self.hidden_layers:
x = torch.relu(layer(x))
x = torch.sigmoid(self.output_layer(x))
return x
# Load pre-trained model and scaler
scaler = None
model = None
accord_scaler = None
accord_model = None
# 하이퍼파라미터
input_size = 52 # 입력 피처 수
output_size = 26 # 예측할 노트 수
num_heads = 4
hidden_dim = 128
num_layers = 3
def load_models():
global model, scaler, accord_model, accord_scaler
# 스케일러 로드
scaler = joblib.load('scalers/new_scaler.pkl')
#scaler = StandardScaler()
accord_scaler = joblib.load('scalers/accord_scaler.pkl')
# 전체 모델을 로드
model = TransformerModel(input_size, output_size, num_heads, hidden_dim, num_layers)
model.load_state_dict(torch.load('models/new_trained_model.pth'))
accord_model = torch.load('models/accord_model.pth')
model.eval()
accord_model.eval()
load_models()
# 세션 ID를 기반으로 세션 기록을 가져오는 함수
def get_session_history(session_id):
with store_lock:
if session_id not in store:
store[session_id] = ChatMessageHistory() # ChatMessageHistory 객체로 변경
return store[session_id]
# 대화 체인 생성
def create_chain_with_history():
chain, memory = create_chain(prompt)
# RunnableWithMessageHistory 생성
rag_with_history = RunnableWithMessageHistory(
chain,
get_session_history, # 세션 기록을 가져오는 함수
input_messages_key="question", # 질문 키
history_messages_key="chat_history", # 대화 기록 키
)
return rag_with_history
# 대화 기록 저장 함수
def store_conversation(session_id, user_input, assistant_response):
if session_id not in store:
store[session_id] = ChatMessageHistory() # 새로운 세션 기록 생성
# 메시지 기록 (사용자의 메시지와 어시스턴트의 응답)
store[session_id].add_message(HumanMessage(content=user_input)) # 사용자 메시지 추가
store[session_id].add_message(AIMessage(content=assistant_response)) # 어시스턴트의 응답 추가
# 대화 처리 API
@app.route("/api/chats", methods=['POST'])
def chat():
data = request.json
question = data.get('input', '')
session_id = data.get('sessionId', 'default_session')
extra_info = data.get('extraInfo', {})
nickname = extra_info.get('nickname', 'Scentasy 사용자')
# 대화 체인 생성
rag_with_history = create_chain_with_history()
# 대화 처리
response = rag_with_history.invoke(
{
"question": question,
"nickname": nickname
},
{"configurable": {"session_id": session_id}}
)
# 대화 기록 저장
store_conversation(session_id, question, response)
print(f"Updated store for session ID: {session_id} - {store[session_id]}")
return jsonify({'response': response})
# 유사도 계산 및 예측 API
@app.route("/api/recipe", methods=['POST'])
def similarity_and_predict():
data = request.json
print(f"Received data: {data}") # 요청 데이터 로그
session_id = data.get('sessionId')
extra_info = data.get('extraInfo', {})
liked_scents = extra_info.get('likedScents', [])
disliked_scents = extra_info.get('dislikedScents', [])
# 해당 세션의 대화 기록을 가져옴
if session_id not in store:
return jsonify({'error': 'No conversation history for this session.'}), 404
conversation_history = store[session_id]
# 대화 기록에서 실제 메시지를 추출
try:
conversation_text = ""
for message in conversation_history.messages:
if isinstance(message, HumanMessage):
conversation_text += f"User: {message.content} "
elif isinstance(message, AIMessage):
conversation_text += f"Assistant: {message.content} "
# 유사도 계산
kiwi_processor = KiwiProcessor()
similarity_results = kiwi_processor.calculate_similarity_with_synonyms(conversation_text)
# 가중치 적용
weighted_results = apply_weights_to_similarity(similarity_results)
# 예측 호출
predicted_notes = predict_internal(weighted_results) # 이진값
predicted_notes_array = [int(x.strip()) for x in predicted_notes.split(',')]
predicted_accords = predict_accords(predicted_notes_array)
# 이름 매핑
predicted_note_names = map_notes_to_columns(predicted_notes_array)
predicted_accords_names = map_accords_to_columns(predicted_accords)
# 노트 싫어하는 향 제외
filtered_note_names = [
note for note in predicted_note_names if all(dislike.lower() not in note.lower() for dislike in disliked_scents)
]
# 제목과 설명 생성
title, description = generate_title_and_description(conversation_text, filtered_note_names, predicted_accords_names)
return jsonify({
'input_data': weighted_results,
'binary_note_recipe': predicted_notes,
'predicted_notes': filtered_note_names,
'predicted_accords': predicted_accords_names,
'title': title,
'description': description
})
except Exception as e:
return jsonify({'error': str(e)}), 500
# 내부 예측 함수
def predict_internal(weighted_results):
# 가중치를 DataFrame으로 변환
input_df = pd.DataFrame([weighted_results])
# 스케일러 적용
input_scaled = scaler.transform(input_df.values)
# 텐서로 전환
input_tensor = torch.tensor(input_scaled, dtype=torch.float32)
# 예측
with torch.no_grad():
predicted_notes = model(input_tensor)
# 예측 데이터 이진화
predicted_notes_binary = (predicted_notes > 0.4).int().numpy().tolist()
# 이진화 (임계값 0.5 적용)
#predicted_notes_binary = np.zeros_like(predicted_notes) # 모든 값을 0으로 초기화
#top_5_indices = np.argsort(predicted_notes[0])[::-1][:5] # 상위 5개 인덱스 찾기
#predicted_notes_binary[0, top_5_indices] = 1 # 상위 5개만 1로 설정
# # 후처리 과정: Accords_floral이 0인 경우, 플로럴 관련 노트를 0으로 설정
# if weighted_results.get('Accords_floral', 0) == 0:
# floral_note_indices = [8, 9, 12, 13] # Freesia, Rose, Muguet, Magnolia
# for idx in floral_note_indices:
# predicted_notes_binary[0][idx] = 0
#
# # 후처리 과정: Accords_fruity가 0인 경우, Fruity 관련 노트를 0으로 설정
# if weighted_results.get('Accords_fruity', 0) == 0:
# fruity_note_indices = [16, 4, 5, 6] # Black Currant, Peach, Fig, Black Cherry
# for idx in fruity_note_indices:
# predicted_notes_binary[0][idx] = 0
# 결과 배열을 문자열로 변환
predicted_notes_string = ', '.join(map(str, predicted_notes_binary[0]))
return predicted_notes_string
# 노트 매핑 함수
def map_notes_to_columns(predicted_notes):
note_columns = [
"TOP_BERGAMOT", "TOP_MINT", "TOP_LEMON", "TOP_AQUA",
"TOP_GRAPEFRUIT", "TOP_PEACH", "TOP_FIG",
"TOP_BLACKCHERRY", "TOP_GREEN", "MIDDLE_FREESIA",
"MIDDLE_ROSE", "MIDDLE_PEPPER", "MIDDLE_ROSEMARY",
"MIDDLE_MUGUET", "MIDDLE_MAGNOLIA", "MIDDLE_OCEAN",
"MIDDLE_BLACKCURRANT", "BASE_MUSK", "BASE_VANILLA",
"BASE_SANDALWOOD", "BASE_LEATHER", "BASE_PATCHOULI",
"BASE_CEDAR", "BASE_AMBER", "BASE_FRANKINCENSE",
"BASE_HINOKI"
]
active_notes = [note_columns[i] for i, value in enumerate(predicted_notes) if value == 1]
return active_notes
# 어코드 예측 함수
def predict_accords(predicted_notes):
# 입력 데이터 검증
if not predicted_notes:
return jsonify({'error': 'Predicted notes are required.'}), 400
# 예측된 노트를 데이터프레임으로 변환
input_df = pd.DataFrame([predicted_notes])
# 스케일러 적용
input_scaled = accord_scaler.transform(input_df)
# 텐서로 전환
input_tensor = torch.tensor(input_scaled, dtype=torch.float32)
# 어코드 예측
with torch.no_grad():
predicted_accords = accord_model(input_tensor)
# 어코드 비율을 리스트로 반환
predicted_accords_ratios = predicted_accords.numpy().tolist()[0]
return predicted_accords_ratios
# 어코드 매핑 함수
def map_accords_to_columns(predicted_accords):
accord_columns = [
"시트러스", "구르망", "오리엔탈",
"머스크", "프루티", "우디",
"스파이시", "플로럴", "아쿠아틱",
"레더", "아로마틱", "스모키"
]
# 어코드 이름과 값을 매핑 및 필터링
predicted_accords_names = [
{"accord": accord_columns[i], "value": round(predicted_accords[i], 3)}
for i in range(len(predicted_accords)) if predicted_accords[i] >= 0.5
]
return predicted_accords_names
# 제목과 설명 생성 함수
def generate_title_and_description(conversation_text, predicted_note_names, predicted_accords_with_columns):
# ChatOpenAI 인스턴스 생성
llm = ChatOpenAI(model_name="gpt-4", temperature=0.7)
prompt_template = create_title_description_prompt()
notes_str = ', '.join(predicted_note_names)
accords_str = ', '.join([f"{accord['accord']} ({accord['value']})" for accord in predicted_accords_with_columns])
# print(notes_str)
# print(accords_str)
prompt = prompt_template.format(
conversation_text=conversation_text,
notes_str=notes_str,
accords_str=accords_str
)
messages = [
SystemMessage(content="You are an assistant that generates creative perfume titles and descriptions."),
HumanMessage(content=prompt)
]
# 모델 호출
response = llm(messages)
# 응답에서 제목과 설명 추출
generated_text = response.content.strip().split("\n")
title = generated_text[0].replace("제목: ", "").strip() if len(generated_text) > 0 and generated_text[0].strip() else "우디한 성공의 순간"
description = generated_text[1].replace("설명: ", "").strip() if len(generated_text) > 1 and generated_text[1].strip() else (
"중요한 발표를 앞둔 당신을 위한 향입니다. 상쾌한 레몬의 탑 노트로 시작하여, 중간에는 프리지아와 페퍼의 조합이 "
"당신에게 활력을 불어넣어줍니다. 마지막으로, 샌달우드와 패츌리의 베이스 노트가 깔끔하고 우아한 향으로 마무리되며, "
"가을의 따뜻한 우디한 느낌을 전해줍니다. 이 향은 당신의 중요한 순간을 더욱 특별하게 만들어줄 것입니다."
)
return title, description
if __name__ == '__main__':
app.run(debug=True, host="0.0.0.0", port=5001)