diff --git a/recommenders/models/deeprec/DataModel/ImplicitCF.py b/recommenders/models/deeprec/DataModel/ImplicitCF.py index f490c48f3..3cfbb2821 100644 --- a/recommenders/models/deeprec/DataModel/ImplicitCF.py +++ b/recommenders/models/deeprec/DataModel/ImplicitCF.py @@ -80,6 +80,7 @@ def _data_processing(self, train, test): user_idx = df[[self.col_user]].drop_duplicates().reindex() user_idx[self.col_user + "_idx"] = np.arange(len(user_idx)) self.n_users = len(user_idx) + self.n_users_in_train = train[self.col_user].nunique() self.user_idx = user_idx self.user2id = dict( @@ -210,7 +211,7 @@ def sample_neg(x): if neg_id not in x: return neg_id - indices = range(self.n_users) + indices = range(self.n_users_in_train) if self.n_users < batch_size: users = [random.choice(indices) for _ in range(batch_size)] else: