Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated optimize to handle combination of multiple intervention templates #591

Merged
merged 21 commits into from
Jul 31, 2024

Conversation

anirban-chaudhuri
Copy link
Contributor

@anirban-chaudhuri anirban-chaudhuri commented Jul 16, 2024

  • created a lambda function to combine different intervention templates
  • updated type hinting for static_parameter_interventions input to optimize to accommodate lambda functions
  • downstream changes in ouu.py to handle lambda functions

- updated typing for interventions input to optimize
- created a lambda function to combine different intervention templates
@anirban-chaudhuri anirban-chaudhuri self-assigned this Jul 16, 2024
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@anirban-chaudhuri anirban-chaudhuri linked an issue Jul 16, 2024 that may be closed by this pull request
@anirban-chaudhuri anirban-chaudhuri added WIP PR submitter still making changes, not ready for review integration Tasks for integration with TA4 labels Jul 16, 2024
@anirban-chaudhuri
Copy link
Contributor Author

anirban-chaudhuri commented Jul 16, 2024

@SamWitty Any ideas on how to make this cleaner and if there is a way to create a generic template for combining the different intervention builders into a lambda function?

  • Currently the type hint has become a bit complicated for interventions in optimize and we might want to think of a way to simplify it, if possible
  • The end of the optimize_interface has the lambda function implementation and the downstream changes required to accommodate that change in ouu.py

@anirban-chaudhuri anirban-chaudhuri added awaiting review PR submitter awaiting code review from reviewer WIP PR submitter still making changes, not ready for review and removed WIP PR submitter still making changes, not ready for review awaiting review PR submitter awaiting code review from reviewer labels Jul 16, 2024
@anirban-chaudhuri anirban-chaudhuri added awaiting review PR submitter awaiting code review from reviewer and removed WIP PR submitter still making changes, not ready for review labels Jul 16, 2024
Copy link
Contributor

@SamWitty SamWitty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this PR addresses the issue, but it would be nice to have the explicit intervention template combinator provided (which I believe is now commented out) instead of combining them manually in a lambda function. This might be worth hopping on a chat and working through together.

@@ -771,8 +771,12 @@ def optimize(
logging_step_size: float,
qoi: Callable,
risk_bound: float,
static_parameter_interventions: Callable[
[torch.Tensor], Dict[float, Dict[str, Intervention]]
static_parameter_interventions: Union[
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine. If you want you can make this a bit more concise with a type alias.

E.g. https://github.com/BasisResearch/chirho/blob/master/chirho/dynamical/ops.py#L14

Here, you could add a line with the following:

InterventionFunc = Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]

Then this type could be:

Union[InterventionFunc, Callable[[torch.Tensor], List[InterventionFunc]]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking again, is this type signature correct? Should it be:

Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]],
        Callable[
            [torch.Tensor],
            List[Dict[float, Dict[str, Intervention]]],
        ]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure. It might be. I was thinking once I create the lambda function it becomes a Callable of Callable functions.. Let's discuss when we meet.

@@ -102,6 +114,35 @@ def intervention_generator(
return intervention_generator


# def combine_intervention_templates(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do think it would be good to provide this instead of the manual lambda function expansion you have in the tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I completely agree. I couldn't write a bug free one so I left it. We can hop on a call to make it happen.

@SamWitty SamWitty added awaiting response PR reviewer awaiting response from submitter and removed awaiting review PR submitter awaiting code review from reviewer labels Jul 24, 2024
@anirban-chaudhuri anirban-chaudhuri added WIP PR submitter still making changes, not ready for review and removed awaiting response PR reviewer awaiting response from submitter labels Jul 31, 2024
@anirban-chaudhuri anirban-chaudhuri added awaiting review PR submitter awaiting code review from reviewer and removed WIP PR submitter still making changes, not ready for review labels Jul 31, 2024
Copy link
Contributor

@SamWitty SamWitty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. One minor nitpick.

I think having the interventions be mapping from float to interventions is the right idea. We can make the interfaces.py consistent with this in a separate PR if they've drifted.

]
intervention_list.extend(
[self.interventions(torch.from_numpy(x))]
# if not isinstance(self.interventions(torch.from_numpy(x)), list)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remove this commented out code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@SamWitty SamWitty added awaiting response PR reviewer awaiting response from submitter and removed awaiting review PR submitter awaiting code review from reviewer labels Jul 31, 2024
@anirban-chaudhuri
Copy link
Contributor Author

Looks good. One minor nitpick.

I think having the interventions be mapping from float to interventions is the right idea. We can make the interfaces.py consistent with this in a separate PR if they've drifted.

Sounds good. I will leave my comment from slack here so that we have the specific details.

Type hinting needs to be fixed for interventions: Should it be Dict[float, Dict[str, Intervention]] or Dict[torch.Tensor, Dict[str, Intervention]] ? Both are used in the interfaces.py.

@anirban-chaudhuri anirban-chaudhuri added awaiting review PR submitter awaiting code review from reviewer and removed awaiting response PR reviewer awaiting response from submitter labels Jul 31, 2024
@SamWitty SamWitty merged commit adeb6b9 into main Jul 31, 2024
5 checks passed
@anirban-chaudhuri anirban-chaudhuri deleted the ac-multiInterventions branch September 4, 2024 19:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
awaiting review PR submitter awaiting code review from reviewer integration Tasks for integration with TA4
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Combine different intervention builder callables in optimize
2 participants