diff --git a/unfair/model/data.py b/unfair/model/data.py index f468178..11fbfc0 100644 --- a/unfair/model/data.py +++ b/unfair/model/data.py @@ -216,7 +216,9 @@ def extract_fets(dat, split_name, net): or np.isinf(dat_out[features.LABEL_FET]).any() ), f'Warning: NaNs or Infs in ground truth for split "{split_name}".' - replace_unknowns(dat_in, is_dt) + # HGBDT can handle unknowns, but other model types cannot. + if not isinstance(net, models.HistGbdtSklearnWrapper): + replace_unknowns(dat_in, is_dt) # Convert output features to class labels. dat_out = net.convert_to_class(dat_out)