Skip to content

Commit

Permalink
test(sampler): fit to add --ratios
Browse files Browse the repository at this point in the history
  • Loading branch information
breakthewall committed Jul 26, 2024
1 parent c518895 commit ad69e2b
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setUp(self):
def test_generate_lhs_samples_normal(self, mock_read_csv):
mock_read_csv.return_value = self.components_df

result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None, self.seed)
result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None, None, self.seed)

self.assertEqual(result.shape, (self.num_samples, 3))
self.assertListEqual(list(result.columns), ['A', 'B', 'C'])
Expand All @@ -28,7 +28,7 @@ def test_generate_lhs_samples_normal(self, mock_read_csv):
def test_generate_lhs_samples_no_seed(self, mock_read_csv):
mock_read_csv.return_value = self.components_df

result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None, None)
result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None, None, None)

self.assertEqual(result.shape, (self.num_samples, 3))
self.assertListEqual(list(result.columns), ['A', 'B', 'C'])
Expand All @@ -39,7 +39,7 @@ def test_generate_lhs_samples_edge_case_zero_maxValue(self, mock_read_csv):
edge_case_df.loc[0, 'maxValue'] = 0 # Set maxValue of component 'A' to 0
mock_read_csv.return_value = edge_case_df

result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None, self.seed)
result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None, None, self.seed)

self.assertEqual(result.shape, (self.num_samples, 3))
self.assertTrue((result['A'] == 0).all()) # All values in column 'A' should be zero
Expand All @@ -49,20 +49,20 @@ def test_generate_lhs_samples_invalid_step(self, mock_read_csv):
mock_read_csv.return_value = self.components_df

with self.assertRaises(IndexError):
generate_lhs_samples("fake_path.csv", self.num_samples, -2.5, None, self.seed) # Negative step size should raise an error
generate_lhs_samples("fake_path.csv", self.num_samples, -2.5, None, None, self.seed) # Negative step size should raise an error

@patch("icfree.sampler.pd.read_csv")
def test_generate_lhs_samples_invalid_input_file(self, mock_read_csv):
mock_read_csv.side_effect = FileNotFoundError

with self.assertRaises(FileNotFoundError):
generate_lhs_samples("invalid_path.csv", self.num_samples, self.step, None, self.seed)
generate_lhs_samples("invalid_path.csv", self.num_samples, self.step, None, None, self.seed)

@patch("icfree.sampler.pd.read_csv")
def test_generate_lhs_samples_fix_component_value(self, mock_read_csv):
mock_read_csv.return_value = self.components_df

result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, {'A': 5}, self.seed)
result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None, {'A': 5}, self.seed)

self.assertEqual(result.shape, (self.num_samples, 3))
self.assertTrue((result['A'] == 5).all())
Expand Down

0 comments on commit ad69e2b

Please sign in to comment.