-
Notifications
You must be signed in to change notification settings - Fork 46
Description

The roc_auc_score is used only in a binary classification context in this project.
The line if len(set(y)) <= 2: ensures that roc_auc_score is only called for binary classification.
In binary classification, multi_class is not required, hence avoiding the error.
However , I run into a problem that raised error like this:

After debug , I found that the shape of y_prob_all in classification task is (103,4) and the shape of y_true_all is (103 , ).But the original code use if len(set(y)) <= 2: as condition , which would still result in error. I would like to ask for ur help
Mandarin:
在处理多分类任务时会遇到 raise ValueError("multi_class must be in ('ovo', 'ovr')")的问题,因为y的形状被调整为了(num,),无论如何len(set(y))的长度都是小于2的,多分类情况下y_prob_all形状为(num,classes)这会导致判断失效一直进入该条件而导致报错,请问你们的解决方法是什么?