diff --git a/dlordinal/soft_labelling/binomial_distribution.py b/dlordinal/soft_labelling/binomial_distribution.py index f3b6971..60bfedd 100644 --- a/dlordinal/soft_labelling/binomial_distribution.py +++ b/dlordinal/soft_labelling/binomial_distribution.py @@ -46,20 +46,11 @@ def get_binomial_soft_labels(J): if J < 2 or not isinstance(J, int): raise ValueError(f"{J=} must be a positive integer greater than 1") - params = {} - - params["4"] = np.linspace(0.1, 0.9, 4) - params["5"] = np.linspace(0.1, 0.9, 5) - params["6"] = np.linspace(0.1, 0.9, 6) - params["7"] = np.linspace(0.1, 0.9, 7) - params["8"] = np.linspace(0.1, 0.9, 8) - params["10"] = np.linspace(0.1, 0.9, 10) - params["12"] = np.linspace(0.1, 0.9, 12) - params["14"] = np.linspace(0.1, 0.9, 14) + params = np.linspace(0.1, 0.9, J) probs = [] for true_class in range(0, J): - probs.append(binom.pmf(np.arange(0, J), J - 1, params[str(J)][true_class])) + probs.append(binom.pmf(np.arange(0, J), J - 1, params[true_class])) return np.array(probs) diff --git a/dlordinal/soft_labelling/tests/test_binomial_distribution.py b/dlordinal/soft_labelling/tests/test_binomial_distribution.py index bdeb513..c64062d 100644 --- a/dlordinal/soft_labelling/tests/test_binomial_distribution.py +++ b/dlordinal/soft_labelling/tests/test_binomial_distribution.py @@ -28,3 +28,33 @@ def test_get_binomial_soft_labels(): # Individual probabilities should be within [0, 1] assert np.all(result >= 0) and np.all(result <= 1) + + +def test_get_binomial_soft_labels_invalid_input(): + with pytest.raises(ValueError): + get_binomial_soft_labels(1) + + with pytest.raises(ValueError): + get_binomial_soft_labels(1.0) + + with pytest.raises(ValueError): + get_binomial_soft_labels(0) + + with pytest.raises(ValueError): + get_binomial_soft_labels(-1) + + +def test_get_binomial_soft_labels_valid_input(): + for i in range(2, 11): + soft_labels = get_binomial_soft_labels(i) + + # Sum of probabilities in each row should be approximately 1 + row_sums = np.sum(soft_labels, axis=1) + for row_sum in row_sums: + assert row_sum == pytest.approx(1.0, abs=1e-6) + + # Check that all the elements in the matrix are less than or equal to the + # element in the diagonal + diagonal = np.diag(soft_labels) + diff = soft_labels - diagonal[:, np.newaxis] + assert np.all(diff <= 1e-9)