|
1 | | -import unittest |
2 | | - |
3 | 1 | import numpy as np |
4 | 2 | import numpy.testing as np_test |
| 3 | +import pytest |
5 | 4 | from scipy.special import softmax |
6 | 5 |
|
7 | 6 | from pgmpy.factors.discrete import TabularCPD |
|
13 | 12 | # Huange Wang, Xiaoguang Gao, Chris P. Thompson |
14 | 13 |
|
15 | 14 |
|
16 | | -class TestDBNInference(unittest.TestCase): |
17 | | - def setUp(self): |
18 | | - dbn_1 = DynamicBayesianNetwork() |
19 | | - dbn_1.add_edges_from([(("Z", 0), ("X", 0)), (("Z", 0), ("Y", 0)), (("Z", 0), ("Z", 1))]) |
20 | | - cpd_start_z_1 = TabularCPD(("Z", 0), 2, [[0.8], [0.2]]) |
21 | | - cpd_x_1 = TabularCPD(("X", 0), 2, [[0.9, 0.6], [0.1, 0.4]], [("Z", 0)], [2]) |
22 | | - cpd_y_1 = TabularCPD(("Y", 0), 2, [[0.7, 0.2], [0.3, 0.8]], [("Z", 0)], [2]) |
23 | | - cpd_trans_z_1 = TabularCPD(("Z", 1), 2, [[0.9, 0.1], [0.1, 0.9]], [("Z", 0)], [2]) |
24 | | - dbn_1.add_cpds(cpd_start_z_1, cpd_trans_z_1, cpd_x_1, cpd_y_1) |
25 | | - dbn_1.initialize_initial_state() |
26 | | - self.dbn_inference_1 = DBNInference(dbn_1) |
27 | | - dbn_2 = DynamicBayesianNetwork() |
28 | | - dbn_2.add_edges_from([(("Z", 0), ("X", 0)), (("X", 0), ("Y", 0)), (("Z", 0), ("Z", 1))]) |
29 | | - cpd_start_z_2 = TabularCPD(("Z", 0), 2, [[0.5], [0.5]]) |
30 | | - cpd_x_2 = TabularCPD(("X", 0), 2, [[0.6, 0.9], [0.4, 0.1]], [("Z", 0)], [2]) |
31 | | - cpd_y_2 = TabularCPD(("Y", 0), 2, [[0.2, 0.3], [0.8, 0.7]], [("X", 0)], [2]) |
32 | | - cpd_z_2 = TabularCPD(("Z", 1), 2, [[0.4, 0.7], [0.6, 0.3]], [("Z", 0)], [2]) |
33 | | - dbn_2.add_cpds(cpd_x_2, cpd_y_2, cpd_z_2, cpd_start_z_2) |
34 | | - dbn_2.initialize_initial_state() |
35 | | - self.dbn_inference_2 = DBNInference(dbn_2) |
36 | | - |
37 | | - def test_forward_inf_single_variable(self): |
38 | | - query_result = self.dbn_inference_1.forward_inference([("X", 0)]) |
| 15 | +@pytest.fixture |
| 16 | +def dbn_inference_instances(): |
| 17 | + dbn_1 = DynamicBayesianNetwork() |
| 18 | + dbn_1.add_edges_from([(("Z", 0), ("X", 0)), (("Z", 0), ("Y", 0)), (("Z", 0), ("Z", 1))]) |
| 19 | + cpd_start_z_1 = TabularCPD(("Z", 0), 2, [[0.8], [0.2]]) |
| 20 | + cpd_x_1 = TabularCPD(("X", 0), 2, [[0.9, 0.6], [0.1, 0.4]], [("Z", 0)], [2]) |
| 21 | + cpd_y_1 = TabularCPD(("Y", 0), 2, [[0.7, 0.2], [0.3, 0.8]], [("Z", 0)], [2]) |
| 22 | + cpd_trans_z_1 = TabularCPD(("Z", 1), 2, [[0.9, 0.1], [0.1, 0.9]], [("Z", 0)], [2]) |
| 23 | + dbn_1.add_cpds(cpd_start_z_1, cpd_trans_z_1, cpd_x_1, cpd_y_1) |
| 24 | + dbn_1.initialize_initial_state() |
| 25 | + dbn_inference_1 = DBNInference(dbn_1) |
| 26 | + dbn_2 = DynamicBayesianNetwork() |
| 27 | + dbn_2.add_edges_from([(("Z", 0), ("X", 0)), (("X", 0), ("Y", 0)), (("Z", 0), ("Z", 1))]) |
| 28 | + cpd_start_z_2 = TabularCPD(("Z", 0), 2, [[0.5], [0.5]]) |
| 29 | + cpd_x_2 = TabularCPD(("X", 0), 2, [[0.6, 0.9], [0.4, 0.1]], [("Z", 0)], [2]) |
| 30 | + cpd_y_2 = TabularCPD(("Y", 0), 2, [[0.2, 0.3], [0.8, 0.7]], [("X", 0)], [2]) |
| 31 | + cpd_z_2 = TabularCPD(("Z", 1), 2, [[0.4, 0.7], [0.6, 0.3]], [("Z", 0)], [2]) |
| 32 | + dbn_2.add_cpds(cpd_x_2, cpd_y_2, cpd_z_2, cpd_start_z_2) |
| 33 | + dbn_2.initialize_initial_state() |
| 34 | + dbn_inference_2 = DBNInference(dbn_2) |
| 35 | + return dbn_inference_1, dbn_inference_2 |
| 36 | + |
| 37 | + |
| 38 | +class TestDBNInference: |
| 39 | + def test_forward_inf_single_variable(self, dbn_inference_instances): |
| 40 | + dbn_inference_1, _ = dbn_inference_instances |
| 41 | + query_result = dbn_inference_1.forward_inference([("X", 0)]) |
39 | 42 | np_test.assert_array_almost_equal(query_result[("X", 0)].values, np.array([0.84, 0.16])) |
40 | 43 |
|
41 | | - def test_forward_inf_multiple_variable(self): |
42 | | - query_result = self.dbn_inference_1.forward_inference([("X", 0), ("Y", 0)]) |
| 44 | + def test_forward_inf_multiple_variable(self, dbn_inference_instances): |
| 45 | + dbn_inference_1, _ = dbn_inference_instances |
| 46 | + query_result = dbn_inference_1.forward_inference([("X", 0), ("Y", 0)]) |
43 | 47 | np_test.assert_array_almost_equal(query_result[("X", 0)].values, np.array([0.84, 0.16])) |
44 | 48 | np_test.assert_array_almost_equal(query_result[("Y", 0)].values, np.array([0.6, 0.4])) |
45 | 49 |
|
46 | | - def test_forward_inf_single_variable_with_evidence(self): |
47 | | - query_result = self.dbn_inference_1.forward_inference([("Z", 1)], {("Y", 0): 0, ("Y", 1): 0}) |
| 50 | + def test_forward_inf_single_variable_with_evidence(self, dbn_inference_instances): |
| 51 | + dbn_inference_1, dbn_inference_2 = dbn_inference_instances |
| 52 | + query_result = dbn_inference_1.forward_inference([("Z", 1)], {("Y", 0): 0, ("Y", 1): 0}) |
48 | 53 | np_test.assert_array_almost_equal(query_result[("Z", 1)].values, np.array([0.95080214, 0.04919786])) |
49 | | - query_result = self.dbn_inference_2.forward_inference([("X", 2)], {("Y", 0): 1, ("Y", 1): 0, ("Y", 2): 1}) |
| 54 | + query_result = dbn_inference_2.forward_inference([("X", 2)], {("Y", 0): 1, ("Y", 1): 0, ("Y", 2): 1}) |
50 | 55 | np_test.assert_array_almost_equal(query_result[("X", 2)].values, np.array([0.76738736, 0.23261264])) |
51 | 56 |
|
52 | | - def test_forward_inf_multiple_variable_with_evidence(self): |
53 | | - query_result = self.dbn_inference_1.forward_inference([("Z", 1), ("X", 1)], {("Y", 0): 0, ("Y", 1): 0}) |
| 57 | + def test_forward_inf_multiple_variable_with_evidence(self, dbn_inference_instances): |
| 58 | + dbn_inference_1, _ = dbn_inference_instances |
| 59 | + query_result = dbn_inference_1.forward_inference([("Z", 1), ("X", 1)], {("Y", 0): 0, ("Y", 1): 0}) |
54 | 60 | np_test.assert_array_almost_equal(query_result[("Z", 1)].values, np.array([0.95080214, 0.04919786])) |
55 | 61 |
|
56 | 62 | np_test.assert_array_almost_equal(query_result[("X", 1)].values, np.array([0.88524064, 0.11475936])) |
57 | 63 |
|
58 | | - def test_backward_inf_single_variable(self): |
59 | | - query_result = self.dbn_inference_2.backward_inference([("Y", 0)]) |
| 64 | + def test_backward_inf_single_variable(self, dbn_inference_instances): |
| 65 | + _, dbn_inference_2 = dbn_inference_instances |
| 66 | + query_result = dbn_inference_2.backward_inference([("Y", 0)]) |
60 | 67 | np_test.assert_array_almost_equal(query_result[("Y", 0)].values, np.array([0.225, 0.775])) |
61 | 68 |
|
62 | | - def test_backward_inf_multiple_variables(self): |
63 | | - query_result = self.dbn_inference_2.backward_inference([("X", 0), ("Y", 0)]) |
| 69 | + def test_backward_inf_multiple_variables(self, dbn_inference_instances): |
| 70 | + _, dbn_inference_2 = dbn_inference_instances |
| 71 | + query_result = dbn_inference_2.backward_inference([("X", 0), ("Y", 0)]) |
64 | 72 | np_test.assert_array_almost_equal(query_result[("X", 0)].values, np.array([0.75, 0.25])) |
65 | 73 | np_test.assert_array_almost_equal(query_result[("Y", 0)].values, np.array([0.225, 0.775])) |
66 | 74 |
|
67 | | - def test_backward_inf_single_variable_with_evidence(self): |
68 | | - query_result = self.dbn_inference_2.backward_inference([("X", 0)], {("Y", 0): 0, ("Y", 1): 1, ("Y", 2): 1}) |
| 75 | + def test_backward_inf_single_variable_with_evidence(self, dbn_inference_instances): |
| 76 | + dbn_inference_1, dbn_inference_2 = dbn_inference_instances |
| 77 | + query_result = dbn_inference_2.backward_inference([("X", 0)], {("Y", 0): 0, ("Y", 1): 1, ("Y", 2): 1}) |
69 | 78 | np_test.assert_array_almost_equal(query_result[("X", 0)].values, np.array([0.66594382, 0.33405618])) |
70 | 79 |
|
71 | | - query_result = self.dbn_inference_1.backward_inference([("Z", 1)], {("Y", 0): 0, ("Y", 1): 0, ("Y", 2): 0}) |
| 80 | + query_result = dbn_inference_1.backward_inference([("Z", 1)], {("Y", 0): 0, ("Y", 1): 0, ("Y", 2): 0}) |
72 | 81 | np_test.assert_array_almost_equal(query_result[("Z", 1)].values, np.array([0.98048698, 0.01951302])) |
73 | 82 |
|
74 | | - def test_backward_inf_multiple_variables_with_evidence(self): |
75 | | - query_result = self.dbn_inference_2.backward_inference( |
76 | | - [("X", 0), ("X", 1)], {("Y", 0): 0, ("Y", 1): 1, ("Y", 2): 1} |
77 | | - ) |
| 83 | + def test_backward_inf_multiple_variables_with_evidence(self, dbn_inference_instances): |
| 84 | + _, dbn_inference_2 = dbn_inference_instances |
| 85 | + query_result = dbn_inference_2.backward_inference([("X", 0), ("X", 1)], {("Y", 0): 0, ("Y", 1): 1, ("Y", 2): 1}) |
78 | 86 | np_test.assert_array_almost_equal(query_result[("X", 0)].values, np.array([0.677533, 0.322467])) |
79 | 87 | np_test.assert_array_almost_equal(query_result[("X", 1)].values, np.array([0.7621772, 0.2378228])) |
80 | 88 |
|
@@ -184,7 +192,7 @@ def test_github_issue_1794(self): |
184 | 192 | inference = DBNInference(dbn) |
185 | 193 |
|
186 | 194 | # Basic sanity checks |
187 | | - self.assertIsNotNone(inference.start_bayesian_model) |
188 | | - self.assertIsNotNone(inference.one_and_half_model) |
189 | | - self.assertIsNotNone(inference.start_junction_tree) |
190 | | - self.assertIsNotNone(inference.one_and_half_junction_tree) |
| 195 | + assert inference.start_bayesian_model is not None |
| 196 | + assert inference.one_and_half_model is not None |
| 197 | + assert inference.start_junction_tree is not None |
| 198 | + assert inference.one_and_half_junction_tree is not None |
0 commit comments