@@ -15,16 +15,16 @@ class DeepFMModdelTrain:
1515 def __init__ (self , data_path ):
1616 self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
1717 self .data = pd .read_csv (data_path )
18- self .sparse_features = ["user_id " , "user_name " , "age" , "gender" , "place_id" , "place_name" ,"category" , "sub_category " ]
18+ self .sparse_features = ["userid " , "name " , "age" , "gender" , "place_id" , "place_name" ,"category" , "subcategory " ]
1919 self .sequence_feature = "like_list"
2020 self .linear_feature_columns = None
2121 self .dnn_feature_columns = None
2222 self .feature_names = None
2323 self .model_input = None
2424 self .target = "yn"
25- self .model_path = "/home/ubuntu/working/ MLOps/MLOps/app/model/deepfm_model.pt"
26- self .encoders_path = "/home/ubuntu/working/ MLOps/MLOps/app/model/label_encoders.pkl"
27- self .key2index_path = "/home/ubuntu/working/ MLOps/MLOps/app/model/key2index.pkl"
25+ self .model_path = "/home/ubuntu/MLOps/MLOps/app/model/deepfm_model.pt"
26+ self .encoders_path = "/home/ubuntu/MLOps/MLOps/app/model/label_encoders.pkl"
27+ self .key2index_path = "/home/ubuntu/MLOps/MLOps/app/model/key2index.pkl"
2828 self .model = None
2929 self .max_len = None
3030 self .label_encoders = {}
@@ -107,6 +107,24 @@ def predict(self, input_data):
107107 self .label_encoders = pickle .load (f )
108108 with open (self .key2index_path , 'rb' ) as f :
109109 self .key2index = pickle .load (f )
110+
111+ # 예측에 필요한 메타데이터 재구성
112+ temp_like_list = self .data [self .sequence_feature ].apply (ast .literal_eval )
113+ self .max_len = max (len (x ) for x in temp_like_list )
114+
115+ sparse_feature_names = ["userid" , "name" , "age" , "gender" , "place_id" , "place_name" , "category" , "subcategory" ]
116+
117+ reconstructed_sparse_features = [SparseFeat (feat , vocabulary_size = len (self .label_encoders [feat ].classes_ ), embedding_dim = 4 )
118+ for feat in sparse_feature_names ]
119+
120+ reconstructed_sequence_feature = [VarLenSparseFeat (SparseFeat (self .sequence_feature ,
121+ vocabulary_size = len (self .key2index ) + 1 ,
122+ embedding_dim = 4 ),
123+ maxlen = self .max_len , combiner = 'mean' )]
124+
125+ self .linear_feature_columns = reconstructed_sparse_features + reconstructed_sequence_feature
126+ self .dnn_feature_columns = reconstructed_sparse_features + reconstructed_sequence_feature
127+ self .feature_names = get_feature_names (self .linear_feature_columns + self .dnn_feature_columns )
110128
111129 # 입력 데이터를 DataFrame으로 변환
112130 if isinstance (input_data , dict ):
@@ -115,13 +133,19 @@ def predict(self, input_data):
115133 input_df = input_data .copy ()
116134
117135 # sparse feature 전처리
118- sparse_feature_names = ["user_id" , "user_name" , "age" , "gender" , "place_id" , "place_name" ,"category" , "sub_category" ]
119136 for feature in sparse_feature_names :
120- input_df [feature ] = input_df [feature ].fillna ("unknown" )
121- # 학습 시 보지 못한 값은 'unknown'으로 처리
122- input_df [feature ] = input_df [feature ].apply (
123- lambda x : x if x in self .label_encoders [feature ].classes_ else "unknown"
124- )
137+ encoder = self .label_encoders [feature ]
138+ known_classes = set (encoder .classes_ )
139+
140+ # 'unknown'이 학습되었는지 확인
141+ unknown_in_classes = 'unknown' in known_classes
142+
143+ def transform_element (x ):
144+ if pd .isna (x ) or x not in known_classes :
145+ return 'unknown' if unknown_in_classes else encoder .classes_ [0 ]
146+ return x
147+
148+ input_df [feature ] = input_df [feature ].apply (transform_element )
125149 input_df [feature ] = self .label_encoders [feature ].transform (input_df [feature ])
126150
127151 # sequence feature 전처리
@@ -153,19 +177,19 @@ def encode_sequence(x):
153177 return model .predict (model_input )
154178
155179if __name__ == "__main__" :
156- deepfm_train = DeepFMModdelTrain ("/home/ubuntu/working/MLOps /data/final_click_log.csv" )
180+ deepfm_train = DeepFMModdelTrain (".. /data/final_click_log.csv" )
157181 deepfm_train .preprocess ()
158182 model = deepfm_train .train ()
159183 # 예시 데이터
160184 input_data = {
161- "user_id " : ["0x06fa1ba7a7e44621a2338e6093e53341" , "0x6d132cda535848e295b8e489486ea841" , "0x0fa0a9c4a283451181b77d91e3229c91" ],
162- "user_name " : ["딩딩이" , "댕댕이 언니" , "에구궁" ],
185+ "userid " : ["0x06fa1ba7a7e44621a2338e6093e53341" , "0x6d132cda535848e295b8e489486ea841" , "0x0fa0a9c4a283451181b77d91e3229c91" ],
186+ "name " : ["딩딩이" , "댕댕이 언니" , "에구궁" ],
163187 "age" : [30 , 60 , 50 ],
164188 "gender" : [1 , 1 , 0 ],
165189 "place_id" : ["0xeb37b72b1fa54dc6a3867517ac2df6ef" , "0x0528fbb073104d51974112a71d72b4e4" , "0x1226fc5501194d2eba00383748045c20" ],
166190 "place_name" : ["롯데월드 쇼핑몰" , "청아라 생선구이" , "시골보쌈" ],
167191 "category" : ["쇼핑" , "음식점&카페" , "음식점&카페" ],
168- "sub_category " : ["전문매장/상가" , "한식" , "한식" ],
192+ "subcategory " : ["전문매장/상가" , "한식" , "한식" ],
169193 "like_list" : ["[11, 12, 13, 14, 15, 16, 17, 18, 19, 20]" , "[26, 22, 29, 44]" , "[11, 28, 14, 29, 10, 22, 8, 25, 30]" ]
170194 }
171195 prediction = deepfm_train .predict (input_data )
0 commit comments