Updated optimize to handle combination of multiple intervention templates #591
Conversation
- Updated typing for interventions input to optimize
- Created a lambda function to combine different intervention templates
@SamWitty Any ideas on how to make this cleaner, or whether there is a way to create a generic template for combining the different intervention builders into a lambda function?
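For context, the manual lambda combination being discussed looks roughly like this. This is a hypothetical sketch, not code from the PR: the builder names are made up, and plain floats stand in for torch.Tensor / Intervention values.

```python
from typing import Callable, Dict, List, Sequence

# Two hypothetical intervention templates produced by separate builders.
# Each maps an optimization vector x to {time: {param_name: value}}.
start_time_builder: Callable[[Sequence[float]], Dict[float, Dict[str, float]]] = (
    lambda x: {x[0]: {"beta": 0.5}}
)
value_builder: Callable[[Sequence[float]], Dict[float, Dict[str, float]]] = (
    lambda x: {2.0: {"gamma": x[1]}}
)

# Manual lambda combining the two templates into a list of dicts,
# analogous to the expansion used in the tests.
combined = lambda x: [start_time_builder(x), value_builder(x)]
```

A generic combinator would replace this per-test expansion with a single helper that accepts any list of builders.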
I think this PR addresses the issue, but it would be nice to have the explicit intervention template combinator provided (which I believe is now commented out) instead of combining them manually in a lambda function. This might be worth hopping on a chat and working through together.
pyciemss/interfaces.py
Outdated
@@ -771,8 +771,12 @@ def optimize(
    logging_step_size: float,
    qoi: Callable,
    risk_bound: float,
    static_parameter_interventions: Callable[
        [torch.Tensor], Dict[float, Dict[str, Intervention]]
    static_parameter_interventions: Union[
This is fine. If you want you can make this a bit more concise with a type alias.
E.g. https://github.com/BasisResearch/chirho/blob/master/chirho/dynamical/ops.py#L14
Here, you could add a line with the following:
InterventionFunc = Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]
Then this type could be:
Union[InterventionFunc, Callable[[torch.Tensor], List[InterventionFunc]]]
Looking again, is this type signature correct? Should it be:
Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]],
Callable[
[torch.Tensor],
List[Dict[float, Dict[str, Intervention]]],
]
I am not sure. It might be. I was thinking that once I create the lambda function it becomes a Callable of Callables. Let's discuss when we meet.
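The two readings being debated can be distinguished concretely. A hedged sketch (not pyciemss code; plain floats stand in for torch.Tensor, and the names are illustrative only):

```python
from typing import Callable, Dict, List, Sequence

# One "template" is a dict mapping intervention time to parameter assignments.
Template = Dict[float, Dict[str, float]]

# Reading 1: the lambda returns a list of already-evaluated dicts,
# matching Callable[[torch.Tensor], List[Dict[float, Dict[str, Intervention]]]].
returns_dicts: Callable[[Sequence[float]], List[Template]] = lambda x: [
    {1.0: {"beta": x[0]}},
    {2.0: {"gamma": x[1]}},
]

# Reading 2: a "Callable of Callables" — the lambda returns a list of
# functions that still need to be applied to x.
returns_callables: Callable[
    [Sequence[float]], List[Callable[[Sequence[float]], Template]]
] = lambda x: [lambda y: {1.0: {"beta": y[0]}}]
```

The diff in ouu.py extends intervention_list with the result of self.interventions(torch.from_numpy(x)), which suggests reading 1 (evaluated dicts) is what the optimizer actually consumes.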
@@ -102,6 +114,35 @@ def intervention_generator(
    return intervention_generator


# def combine_intervention_templates(
I do think it would be good to provide this instead of the manual lambda function expansion you have in the tests.
I completely agree. I couldn't write a bug-free one, so I left it commented out. We can hop on a call to make it happen.
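A minimal sketch of what such a combinator could look like. This is an assumption-laden illustration, not the pyciemss implementation: the function name mirrors the commented-out stub, but the signature and merge behavior are guesses, and plain floats stand in for torch.Tensor / Intervention values.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical type alias in the spirit of the suggestion above; floats
# stand in for torch.Tensor and Intervention.
InterventionFunc = Callable[[Sequence[float]], Dict[float, Dict[str, float]]]


def combine_intervention_templates(
    funcs: List[InterventionFunc],
) -> InterventionFunc:
    """Merge several intervention templates into a single callable."""

    def combined(x: Sequence[float]) -> Dict[float, Dict[str, float]]:
        merged: Dict[float, Dict[str, float]] = {}
        for f in funcs:
            for time, assignments in f(x).items():
                # Assignments at the same intervention time are merged;
                # later templates overwrite earlier ones on key collisions.
                merged.setdefault(time, {}).update(assignments)
        return merged

    return combined


# Usage: two single-template builders collapse into one callable.
f1 = lambda x: {1.0: {"beta": x[0]}}
f2 = lambda x: {1.0: {"gamma": x[1]}, 2.0: {"beta": x[0]}}
combined = combine_intervention_templates([f1, f2])
# combined([0.5, 0.1]) -> {1.0: {"beta": 0.5, "gamma": 0.1}, 2.0: {"beta": 0.5}}
```

Whether colliding assignments should overwrite, error, or sum is a design choice the real implementation would need to settle.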
…ook (which should be deleted) (#595) Co-authored-by: Sam Witty <[email protected]>
Looks good. One minor nitpick.
I think having the interventions be mapping from float to interventions is the right idea. We can make the interfaces.py consistent with this in a separate PR if they've drifted.
pyciemss/ouu/ouu.py
Outdated
    ]
    intervention_list.extend(
        [self.interventions(torch.from_numpy(x))]
        # if not isinstance(self.interventions(torch.from_numpy(x)), list)
Could you remove this commented out code?
Done
Sounds good. I will leave my comment from Slack here so that we have the specific details: type hinting needs to be fixed for interventions.
- Updated static_parameter_interventions input to optimize to accommodate lambda functions
- Updated ouu.py to handle lambda functions