Refactor Conditional GFlowNets #431
Conversation
younik left a comment
Just a few comments; good to go for me, but I would wait for @josephdviviano as he understands this code better
src/gfn/containers/trajectories.py (Outdated)

```python
# Concatenate conditions of the trajectories.
if self.conditions is not None and other.conditions is not None:
    self.conditions = torch.cat((self.conditions, other.conditions), dim=0)
else:
    self.conditions = None
```
can we maybe add a test for extending with conditions, and then try common ops like get_item to check the output is as expected?
> can we maybe add a test for extending with conditions

I will add one.

> and then try common ops like get_item to check the output is as expected?

I have no idea what this means. Could you elaborate more?
I mean in the test, after calling extend, check if the extend operation gave the expected result.
Like here:
torchgfn/testing/test_states.py, lines 432 to 454 in c3f3096:

```python
pre_extend_shape = state2.batch_shape
state1.extend(state2)
assert state2.batch_shape == pre_extend_shape
# Check final shape should be (max_len=3, B=4)
assert state1.batch_shape == (3, 4)
# The actual count might be higher due to padding with sink states
assert state1.tensor.x.size(0) == expected_nodes
assert state1.tensor.num_edges == expected_edges
# Check if states are extended as expected
assert (state1[0, 0].tensor.x == datas[0].x).all()
assert (state1[0, 1].tensor.x == datas[1].x).all()
assert (state1[0, 2].tensor.x == datas[4].x).all()
assert (state1[0, 3].tensor.x == datas[5].x).all()
assert (state1[1, 0].tensor.x == datas[2].x).all()
assert (state1[1, 1].tensor.x == datas[3].x).all()
assert (state1[1, 2].tensor.x == datas[6].x).all()
assert (state1[1, 3].tensor.x == datas[7].x).all()
assert (state1[2, 0].tensor.x == MyGraphStates.sf.x).all()
assert (state1[2, 1].tensor.x == MyGraphStates.sf.x).all()
assert (state1[2, 2].tensor.x == datas[8].x).all()
assert (state1[2, 3].tensor.x == datas[9].x).all()
```
I see. I will add a test soon!
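A sketch of what such a test might look like; `make_trajectories` is a hypothetical helper standing in for the library's fixtures, and the exact indexing semantics of `Trajectories.__getitem__` are assumptions:

```python
import torch

def test_extend_with_conditions():
    # Two small batches of trajectories, each carrying a (batch_size, condition_vector_dim)
    # conditions tensor. `make_trajectories` is hypothetical.
    traj1 = make_trajectories(batch_size=4, conditions=torch.randn(4, 2))
    traj2 = make_trajectories(batch_size=3, conditions=torch.randn(3, 2))
    cond1, cond2 = traj1.conditions.clone(), traj2.conditions.clone()

    traj1.extend(traj2)

    # Conditions should be concatenated along the batch dimension.
    assert traj1.conditions.shape == (7, 2)
    assert torch.equal(traj1.conditions[:4], cond1)
    assert torch.equal(traj1.conditions[4:], cond2)

    # Common ops like __getitem__ should index conditions consistently.
    sub = traj1[[0, 5]]
    assert torch.equal(sub.conditions[0], cond1[0])
    assert torch.equal(sub.conditions[1], cond2[1])
```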
src/gfn/env.py (Outdated)

```python
def reward(self, states: States, conditions: torch.Tensor) -> torch.Tensor:
    """Compute rewards for the conditional environment.

    Args:
        states: The states to compute rewards for.
            states.tensor.shape should be (batch_size, *state_shape)
        conditions: The conditions to compute rewards for.
            conditions.shape should be (batch_size, condition_vector_dim)

    Returns:
        A tensor of shape (batch_size,) containing the rewards.
    """
    raise NotImplementedError
```
aha, this is not a real subclass of Env, as conditions are mandatory (i.e., you can't call this function treating the object as an Env when it is actually a ConditionalEnv).
Would it make sense to have a default condition?
If not, this probably shouldn't inherit from Env.
> Would it make sense to have a default condition?

How could having a default condition solve the problem?

> If not, this shouldn't inherit from Env probably.

Maybe, but we still need a parent class that defines the default methods for Envs, like reward, step, etc.
> How could having a default condition solve the problem?

If we have a function like this:

```python
def get_reward(env: Env, states: States) -> torch.Tensor:
    return env.reward(states)
```

This should work with any Env object, given the interface of Env.
However, currently, if I pass a ConditionalEnv (which is an Env), this will fail because you need to specify the conditioning. If you have a default value for the conditioning, the get_reward function will work properly (indeed, with a default, the reward interface of ConditionalEnv becomes a subtype of the one of Env).
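A minimal sketch of that idea; the `default_condition` attribute and the `conditional_reward` hook are assumptions for illustration, not part of the PR (import paths follow the repo layout seen in the diffs):

```python
import torch

from gfn.env import Env
from gfn.states import States

# Sketch only: with a default, ConditionalEnv.reward becomes call-compatible
# with Env.reward, so generic code like get_reward(env, states) keeps working.
class ConditionalEnv(Env):
    default_condition: torch.Tensor  # hypothetical, shape (condition_vector_dim,)

    def reward(self, states: States, conditions: torch.Tensor | None = None) -> torch.Tensor:
        if conditions is None:
            # Broadcast the default condition to (batch_size, condition_vector_dim).
            conditions = self.default_condition.expand(len(states), -1)
        return self.conditional_reward(states, conditions)

    def conditional_reward(self, states: States, conditions: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError  # hypothetical hook for subclasses
```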
An alternative approach would be to have the conditions live inside the states themselves (states could have a conditioning field that is None unless conditioning is required, and then anything that accepts States follows a different path when conditioning is present).
The env itself would only be conditional or not depending on the logic the user defines in the reward and step functions. No actual ConditionalEnv class would be required.
The estimators would also optionally use the conditioning information, if it's present, just like how it's done currently.
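A sketch of that alternative; the field name, constructor signature, and the concatenation-based conditional path are illustrative assumptions:

```python
import torch

class States:
    def __init__(self, tensor: torch.Tensor, conditioning: torch.Tensor | None = None):
        self.tensor = tensor
        # None unless conditioning is required; (batch_size, condition_vector_dim).
        self.conditioning = conditioning

def estimator_forward(module: torch.nn.Module, states: States) -> torch.Tensor:
    # Anything that accepts States takes a different path when conditioning is present.
    if states.conditioning is not None:
        return module(torch.cat([states.tensor, states.conditioning], dim=-1))
    return module(states.tensor)
```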
Now I'm seeing something I never noticed before: this class makes the hot path for computing from the conditions tensor-based, which may or may not be more torch.compile-friendly than keeping conditions in the States class.
The use of a ConditionalEnv is growing on me. I don't mind the changing API, but I would prefer if this logic somehow lived in Env directly. I keep changing my mind on the best design; I suppose it depends on whether we think putting the conditions in States is ultimately a good design.
josephdviviano left a comment
Overall a really nice PR, but I have a few questions about changes that seem unrelated to the goal (in particular, I think we remove a few checks that might have side effects not captured in our test suites), and I wonder if it would be cleaner for the conditioning to live directly within the States class, which would help avoid a lot of added complexity. We can discuss in the standup. Great work!
```diff
  self.conditions = conditions
  assert self.conditions is None or (
-     self.conditions.shape[: len(batch_shape)] == batch_shape
+     len(self.conditions.shape) == 2
```
right, because we assume the conditioning would not change through the trajectory?
```diff
- self._log_rewards[self.is_terminating] = self.env.log_reward(
+ if isinstance(self.env, ConditionalEnv):
+     assert self.conditions is not None
+     log_reward_fn = partial(
```
nice!
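The diff above is truncated; a plausible completion of the pattern, where the keyword name, the mask-based indexing, and `terminating_states` are guesses from the visible lines rather than confirmed code:

```python
from functools import partial

if isinstance(self.env, ConditionalEnv):
    assert self.conditions is not None
    # Bind the per-trajectory conditions so log_reward_fn has the plain
    # States -> Tensor signature used below.
    log_reward_fn = partial(
        self.env.log_reward, conditions=self.conditions[self.is_terminating]
    )
else:
    log_reward_fn = self.env.log_reward
self._log_rewards[self.is_terminating] = log_reward_fn(self.terminating_states)
```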
```diff
  # Assign rewards to valid terminating states.
- terminating_mask = is_terminating & (
-     valid_batch_indices == (self.terminating_idx[valid_traj_indices] - 1)
+ log_rewards[self.terminating_idx - 1, torch.arange(len(self))] = (
```
really nice cleanup here!
```diff
  from gfn.containers import StatesContainer, Trajectories
- from gfn.env import DiscreteEnv
+ from gfn.env import Env
```
This is technically wrong because FlowMatching won't work for continuous environments.
```python
)

self._all_states_tensor = all_states_tensor
if self.store_all_states:
```
Nice, thanks for this addition :)
```python
valid_states = trajectories.states[state_mask]
valid_actions = trajectories.actions[action_mask]

if valid_states.batch_shape != valid_actions.batch_shape:
```
Why are you removing this stuff? I thought this was a useful check.
I disagree with removing this assert
```python
# Build distribution for active rows and compute step log-probs
# TODO: masking ctx with step_mask outside of compute_dist and log_probs,
# i.e., implement __getitem__ for ctx. (maybe we should contain only the
# tensors, and not additional metadata like the batch size, device, etc.)
```
Masking of ctx should already be handled. Or are you suggesting it should be handled in this logic here (i.e., generic)?
```python
    valid_step_actions.tensor, dist, ctx, step_mask, vectorized=False
)

# Pad back to full batch size.
```
Why did you remove this? It's important.
josephdviviano left a comment
For now, I'll leave comments; we can settle the other PR before deciding what to do with this one.
But I must say there's a lot of good work here. Thank you, I'm sure much of this will be a good improvement to the library!
src/gfn/containers/trajectories.py (Outdated)

```diff
  new_max_length = terminating_idx.max().item() if len(terminating_idx) > 0 else 0
  states = self.states[:, index]
- conditions = self.conditions[:, index] if self.conditions is not None else None
+ conditions = self.conditions[index] if self.conditions is not None else None
```
so this is indexing the batch dimension? since the condition is static for the whole trajectory?
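For reference, the shapes involved (illustrative sketch):

```python
# self.conditions: (n_trajectories, condition_vector_dim) -- one static vector
# per trajectory, so indexing the batch dimension alone is enough.
conditions = self.conditions[index]  # (len(index), condition_vector_dim)
```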
src/gfn/containers/trajectories.py (Outdated)

```diff
- # We need to index the conditions tensor to match the actions
- # The actions exclude the last step, so we need to exclude the last step from conditions
- conditions = self.conditions[:-1][~self.actions.is_dummy]
+ # The conditions tensor has shape (n_trajectories, condition_vector_dim)
```
is n_trajectories the batch_dim? That naming is a bit confusing because there's also trajectory_length.
src/gfn/containers/trajectories.py (Outdated)

```python
# The conditions tensor has shape (n_trajectories, condition_vector_dim)
# The actions have batch shape (max_length, n_trajectories)
# We need to repeat the condition vector tensor to match the actions
conditions = self.conditions.repeat(self.actions.batch_shape[0], 1, 1)
```
can you add inline batch dim notation here e.g., # (T, B, C) for trajectory_length, batch_dim, conditioning_dim.
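The requested style might look like this (sketch):

```python
conditions = self.conditions.repeat(self.actions.batch_shape[0], 1, 1)  # (T, B, C)
```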
src/gfn/containers/trajectories.py (Outdated)

```python
# The conditions tensor has shape (n_trajectories, condition_vector_dim)
# The states have batch shape (max_length, n_trajectories)
# We need to repeat the conditions to match the batch shape of the states.
conditions = self.conditions.repeat(self.states.batch_shape[0], 1, 1)
```
ditto as above
```python
if not env.is_discrete:
    raise NotImplementedError(
        "Flow Matching GFlowNet only supports discrete environments for now."
    )
```
I see, it's handled here.
Place conditions within States
Description
Major refactorings for conditional GFlowNets.
- `ConditionalEnv` as a new abstract class for an environment with a conditional reward
- `Trajectories.conditions` have a shape of `(n_trajectories, condition_vector_dim)`, simplifying many shape-related logics
- `train_conditional.py` example (before, `true_dist` for the validation was wrong)

TODO (maybe in another PR?)
- `ConditionalEnv` support conditional transitions
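As a closing illustration, a minimal sketch of a user-defined environment under the new abstract class, assuming the `reward` interface shown in the diffs above; the environment name, its reward rule, a 1-D state shape, and the omission of step/reset machinery are all illustrative assumptions:

```python
import torch

from gfn.env import ConditionalEnv  # added in this PR
from gfn.states import States

class ScaledRewardEnv(ConditionalEnv):
    """Toy illustration: the condition vector scales a base reward."""

    def reward(self, states: States, conditions: torch.Tensor) -> torch.Tensor:
        # states.tensor: (batch_size, state_dim); conditions: (batch_size, 1).
        base_reward = states.tensor.float().sum(dim=-1)  # (batch_size,)
        return base_reward * conditions.squeeze(-1)
```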