Skip to content

Commit

Permalink
revise tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SamWitty committed Oct 30, 2023
1 parent 025b2ae commit c47195b
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
def test_sample_no_interventions(
url, start_time, end_time, logging_step_size, num_samples
):
result = sample(url, start_time, end_time, logging_step_size, num_samples)
result = sample(
url, end_time, logging_step_size, num_samples, start_time=start_time
)
assert isinstance(result, dict)
check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)

Expand Down Expand Up @@ -51,14 +53,16 @@ def test_sample_with_static_interventions(

intervened_result = sample(
url,
start_time,
end_time,
logging_step_size,
num_samples,
start_time=start_time,
static_interventions=static_interventions,
)

result = sample(url, start_time, end_time, logging_step_size, num_samples)
result = sample(
url, end_time, logging_step_size, num_samples, start_time=start_time
)

check_states_match_in_all_but_values(result, intervened_result)
check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)
Expand Down Expand Up @@ -97,14 +101,16 @@ def intervention_event_fn_2(time: torch.Tensor, *args, **kwargs):

intervened_result = sample(
url,
start_time,
end_time,
logging_step_size,
num_samples,
start_time=start_time,
dynamic_interventions=dynamic_interventions,
)

result = sample(url, start_time, end_time, logging_step_size, num_samples)
result = sample(
url, end_time, logging_step_size, num_samples, start_time=start_time
)

check_states_match_in_all_but_values(result, intervened_result)
check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)
Expand Down Expand Up @@ -139,15 +145,17 @@ def intervention_event_fn_1(time: torch.Tensor, *args, **kwargs):

intervened_result = sample(
url,
start_time,
end_time,
logging_step_size,
num_samples,
start_time=start_time,
static_interventions=static_interventions,
dynamic_interventions=dynamic_interventions,
)

result = sample(url, start_time, end_time, logging_step_size, num_samples)
result = sample(
url, end_time, logging_step_size, num_samples, start_time=start_time
)

check_states_match_in_all_but_values(result, intervened_result)
check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)
Expand Down

0 comments on commit c47195b

Please sign in to comment.