Skip to content

Commit

Permalink
Added support for all number of classes in binomial soft labels function
Browse files Browse the repository at this point in the history
  • Loading branch information
victormvy committed Jan 27, 2025
1 parent dc40d19 commit 9147cd6
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
13 changes: 2 additions & 11 deletions dlordinal/soft_labelling/binomial_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,11 @@ def get_binomial_soft_labels(J):
if J < 2 or not isinstance(J, int):
raise ValueError(f"{J=} must be a positive integer greater than 1")

params = {}

params["4"] = np.linspace(0.1, 0.9, 4)
params["5"] = np.linspace(0.1, 0.9, 5)
params["6"] = np.linspace(0.1, 0.9, 6)
params["7"] = np.linspace(0.1, 0.9, 7)
params["8"] = np.linspace(0.1, 0.9, 8)
params["10"] = np.linspace(0.1, 0.9, 10)
params["12"] = np.linspace(0.1, 0.9, 12)
params["14"] = np.linspace(0.1, 0.9, 14)
params = np.linspace(0.1, 0.9, J)

probs = []

for true_class in range(0, J):
probs.append(binom.pmf(np.arange(0, J), J - 1, params[str(J)][true_class]))
probs.append(binom.pmf(np.arange(0, J), J - 1, params[true_class]))

return np.array(probs)
30 changes: 30 additions & 0 deletions dlordinal/soft_labelling/tests/test_binomial_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,33 @@ def test_get_binomial_soft_labels():

# Individual probabilities should be within [0, 1]
assert np.all(result >= 0) and np.all(result <= 1)


def test_get_binomial_soft_labels_invalid_input():
with pytest.raises(ValueError):
get_binomial_soft_labels(1)

with pytest.raises(ValueError):
get_binomial_soft_labels(1.0)

with pytest.raises(ValueError):
get_binomial_soft_labels(0)

with pytest.raises(ValueError):
get_binomial_soft_labels(-1)


def test_get_binomial_soft_labels_valid_input():
for i in range(2, 11):
soft_labels = get_binomial_soft_labels(i)

# Sum of probabilities in each row should be approximately 1
row_sums = np.sum(soft_labels, axis=1)
for row_sum in row_sums:
assert row_sum == pytest.approx(1.0, abs=1e-6)

# Check that all the elements in the matrix are less than or equal to the
# element in the diagonal
diagonal = np.diag(soft_labels)
diff = soft_labels - diagonal[:, np.newaxis]
assert np.all(diff <= 1e-9)

0 comments on commit 9147cd6

Please sign in to comment.