Skip to content

Commit

Permalink
handle categorical multiclass by shuffling the category order
Browse files: browse the repository at this point in the history
  • Loading branch information
paulbkoch committed Dec 30, 2024
1 parent 14b79a9 commit 3e91ca7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
23 changes: 19 additions & 4 deletions shared/libebm/PartitionOneDimensionalBoosting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
EBM_ASSERT(nullptr != pBoosterShell);
EBM_ASSERT(iDimension <= k_cDimensionsMax);
EBM_ASSERT(nullptr != apBins);
EBM_ASSERT(nullptr != apBinsEnd);
EBM_ASSERT(apBins < apBinsEnd); // if zero bins then we should have handled elsewhere
EBM_ASSERT(1 <= cSlices);
EBM_ASSERT(2 <= cBins);
EBM_ASSERT(cSlices <= cBins);
Expand Down Expand Up @@ -972,7 +974,6 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo

// TODO: use all of these!
UNUSED(bUnseen);
UNUSED(cCategorySamplesMin);
UNUSED(categoryHessianPercentMin);
UNUSED(categoricalThresholdMax);
UNUSED(categoricalInclusionPercent);
Expand All @@ -987,7 +988,7 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
pRootTreeNode->Init();

// we can only sort if there's a single sortable index, so 1 score value
bNominal = 1 == cCompilerScores && bNominal && (0 == (TermBoostFlags_DisableCategorical & flags));
bNominal = bNominal && (0 == (TermBoostFlags_DisableCategorical & flags));

size_t cNormalBins = cBins;
if(bMissing) {
Expand Down Expand Up @@ -1073,7 +1074,8 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
} while(pBinsEnd != pBin);

if(bNominal) {
if(apBins == ppBin) {
size_t cRemaining = ppBin - apBins;
if(0 == cRemaining) {
// all categories are dregs, so pretend there's just one bin and everything is inside it

const bool bUpdateWithHessian = bHessian && !(TermBoostFlags_DisableNewtonUpdate & flags);
Expand Down Expand Up @@ -1128,7 +1130,20 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo

*pTotalGain = 0;
return error;
} else {
}

// shuffle
while(size_t{1} != cRemaining) {
const size_t iSwap = pRng->NextFast(cRemaining);
auto* const pTemp = apBins[iSwap];
--cRemaining;
apBins[iSwap] = apBins[cRemaining];
apBins[cRemaining] = pTemp;
}

static constexpr bool bSingleScore = 1 == cCompilerScores;
if(bSingleScore) {
// with multiple grad/hess pairs there is no single key to sort on, so only sort when there is a single score; otherwise keep the random (shuffled) order.
std::sort(apBins,
ppBin,
CompareBin<bHessian, cCompilerScores>(
Expand Down
2 changes: 1 addition & 1 deletion shared/libebm/tests/boosting_unusual_inputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2268,7 +2268,7 @@ static double RandomizedTesting(const AccelerationFlags acceleration) {
}

TEST_CASE("stress test, boosting") {
const double expected = 17508883449920.195;
const double expected = 15044531333054.148;

double validationMetricExact = RandomizedTesting(AccelerationFlags_NONE);
CHECK(validationMetricExact == expected);
Expand Down

0 comments on commit 3e91ca7

Please sign in to comment.