[Feature Request] Enable Inter-Parameter Constraints in optimize_acqf_mixed
#2280
Comments
Thanks for raising this! Seems like an interesting advanced use case of BoTorch, would be great if we could properly support it.
Yes, that would be the obvious/naive thing to do. I guess one could safeguard this by only doing it within reason (e.g. only if the total number of discrete combinations is below some maximum threshold, and erroring out otherwise). That would at least make this work for some use cases (not sure how prevalent these are in practice).
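A minimal sketch of that safeguard in plain Python, independent of BoTorch. The helper name `enumerate_joint_fixed_features` and the `max_combinations` threshold are hypothetical. It assigns one entry of `fixed_features_list` to each point of the q-batch and refuses to proceed when the len(fixed_features_list)^q joint combinations exceed the limit.

```python
from itertools import product
from typing import Dict, List, Tuple


def enumerate_joint_fixed_features(
    fixed_features_list: List[Dict[int, float]],
    q: int,
    max_combinations: int = 1000,
) -> List[Tuple[Dict[int, float], ...]]:
    """Hypothetical helper: every assignment of one fixed-feature dict per q-batch point.

    Raises ValueError instead of silently enumerating an exploding search space.
    """
    n_combos = len(fixed_features_list)
    total = n_combos**q  # number of joint discrete combinations
    if total > max_combinations:
        raise ValueError(
            f"{n_combos}^{q} = {total} joint combinations exceed the limit of "
            f"{max_combinations}; fall back to sequential greedy optimization."
        )
    # product(..., repeat=q) yields tuples of length q, one dict per batch point.
    return list(product(fixed_features_list, repeat=q))
```

Each returned tuple would still need a joint optimization of the continuous parameters over the whole q-batch; the sketch only covers the enumeration and the guard.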
I am not sure I fully understand this suggestion - is this just the standard …
To prevent a combinatorial explosion, one could also think about performing a random optimization over the combinatorial search space, which would mean that one samples …
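One way to read the random-sampling idea, sketched under the assumption that "sampling" means drawing a fixed number of distinct joint assignments uniformly at random. The helper below is hypothetical (not part of BoTorch): it draws one index per q-batch point and rejects repeats, so it never materializes the full product space.

```python
import random
from typing import Dict, List, Optional, Tuple


def sample_joint_fixed_features(
    fixed_features_list: List[Dict[int, float]],
    q: int,
    num_samples: int,
    seed: Optional[int] = None,
) -> List[Tuple[Dict[int, float], ...]]:
    """Hypothetical helper: distinct random joint assignments of fixed features."""
    n_combos = len(fixed_features_list)
    target = min(num_samples, n_combos**q)  # cannot draw more than exist
    rng = random.Random(seed)
    seen = set()
    samples = []
    # Rejection sampling: draw one index per batch point, keep unseen index tuples.
    while len(samples) < target:
        idx = tuple(rng.randrange(n_combos) for _ in range(q))
        if idx not in seen:
            seen.add(idx)
            samples.append(tuple(fixed_features_list[i] for i in idx))
    return samples
```

Rejection sampling is cheap when `num_samples` is much smaller than n_combos^q, which is exactly the regime where full enumeration is infeasible; each sampled combination would then be optimized like one entry of the enumerated list.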
To implement this, one would need to do the following things:
For me, this would be an interesting feature that opens up new opportunities. @Balandat, what do you think?
Hi, thanks for your replies!

```python
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor

from botorch.acquisition.acquisition import (
    AcquisitionFunction,
    OneShotAcquisitionFunction,
)

# These names are already in scope inside botorch/optim/optimize.py; the import is
# only needed when this function is used as a standalone patch.
from botorch.optim.optimize import (
    TGenInitialConditions,
    _raise_deprecation_warning_if_kwargs,
    optimize_acqf,
)


def optimize_acqf_mixed(
    acq_function: AcquisitionFunction,
    bounds: Tensor,
    q: int,
    num_restarts: int,
    fixed_features_list: List[Dict[int, float]],
    raw_samples: Optional[int] = None,
    options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
    inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
    equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
    nonlinear_inequality_constraints: Optional[List[Tuple[Callable, bool]]] = None,
    post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
    batch_initial_conditions: Optional[Tensor] = None,
    ic_generator: Optional[TGenInitialConditions] = None,
    ic_gen_kwargs: Optional[Dict] = None,
    **kwargs: Any,
) -> Tuple[Tensor, Tensor]:
    if not fixed_features_list:
        raise ValueError("fixed_features_list must be non-empty.")

    if isinstance(acq_function, OneShotAcquisitionFunction):
        if not hasattr(acq_function, "evaluate") and q > 1:
            raise ValueError(
                "`OneShotAcquisitionFunction`s that do not implement `evaluate` "
                "are currently not supported when `q > 1`. This is needed to "
                "compute the joint acquisition value."
            )
    _raise_deprecation_warning_if_kwargs("optimize_acqf_mixed", kwargs)

    ic_gen_kwargs = ic_gen_kwargs or {}

    if q == 1:
        ff_candidate_list, ff_acq_value_list = [], []
        for fixed_features in fixed_features_list:
            candidate, acq_value = optimize_acqf(
                acq_function=acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options or {},
                inequality_constraints=inequality_constraints,
                equality_constraints=equality_constraints,
                nonlinear_inequality_constraints=nonlinear_inequality_constraints,
                fixed_features=fixed_features,
                post_processing_func=post_processing_func,
                batch_initial_conditions=batch_initial_conditions,
                ic_generator=ic_generator,
                return_best_only=True,
                **ic_gen_kwargs,
            )
            ff_candidate_list.append(candidate)
            ff_acq_value_list.append(acq_value)

        ff_acq_values = torch.stack(ff_acq_value_list)
        best = torch.argmax(ff_acq_values)
        return ff_candidate_list[best], ff_acq_values[best]

    # For batch optimization with q > 1 we do not want to enumerate all n_combos^n
    # possible combinations of discrete choices. Instead, we use sequential greedy
    # optimization.
    base_X_pending = acq_function.X_pending
    candidates = torch.tensor([], device=bounds.device, dtype=bounds.dtype)

    # Check if inter-point constraints are present (their index tensors are 2-D).
    inter_point = any(
        len(indices.shape) > 1
        for constraints in (inequality_constraints or [], equality_constraints or [])
        for indices, _, _ in constraints
    )
    # Split the constraints into intra-point (1-D indices) and inter-point (2-D indices).
    equality_constraints_intra = [
        constr for constr in equality_constraints or [] if len(constr[0].shape) == 1
    ] or None
    equality_constraints_inter = [
        constr for constr in equality_constraints or [] if len(constr[0].shape) > 1
    ] or None
    inequality_constraints_intra = [
        constr for constr in inequality_constraints or [] if len(constr[0].shape) == 1
    ] or None
    inequality_constraints_inter = [
        constr for constr in inequality_constraints or [] if len(constr[0].shape) > 1
    ] or None
    base_fixed_features_list = [dict(ff_dict) for ff_dict in fixed_features_list]

    for batch in range(q):
        candidate, acq_value = optimize_acqf_mixed(
            acq_function=acq_function,
            bounds=bounds,
            q=1,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            fixed_features_list=fixed_features_list,
            options=options or {},
            inequality_constraints=inequality_constraints_intra,
            equality_constraints=equality_constraints_intra,
            nonlinear_inequality_constraints=nonlinear_inequality_constraints,
            post_processing_func=post_processing_func,
            batch_initial_conditions=batch_initial_conditions,
            ic_generator=ic_generator,
            ic_gen_kwargs=ic_gen_kwargs,
        )
        candidates = torch.cat([candidates, candidate], dim=-2)
        acq_function.set_X_pending(
            torch.cat([base_X_pending, candidates], dim=-2)
            if base_X_pending is not None
            else candidates
        )
        if inter_point:
            # Adjust bounds and fixed_features_list according to the inter-parameter
            # constraints. In other words: which parameter values are still feasible,
            # given the values chosen for the previous batch points and the flexibility
            # left for the remaining ones? (Possible approach: optimal control.)
            # If the feasible values form a range: adjust the bounds.
            # If they reduce to one specific value: adjust fixed_features_list and drop
            # duplicates in case the constrained parameter was already fixed there.
            # Copy each dict: a shallow list copy would share the dicts, so updates
            # below would leak into base_fixed_features_list.
            fixed_features_list = [dict(ff) for ff in base_fixed_features_list]
            for eq_constr_inter in equality_constraints_inter or []:
                indices, coefficients, rhs = eq_constr_inter
                # Test implementation for the case x_m(batch_i)*1 + x_n(batch_j)*(-1) == 0
                if indices.size(dim=0) == 2:
                    if coefficients[0] == -coefficients[1] and rhs == 0:
                        if (indices[0, 0] <= batch or indices[1, 0] <= batch) and (
                            indices[0, 0] == batch + 1 or indices[1, 0] == batch + 1
                        ):
                            past_batch = indices[0] if indices[0, 0] <= batch else indices[1]
                            next_batch = indices[0] if indices[0, 0] == batch + 1 else indices[1]
                            # Value already chosen for (q-index, parameter index) of past_batch.
                            past_value = float(
                                candidates[int(past_batch[0]), int(past_batch[1])]
                            )
                            for fixed_features in fixed_features_list:
                                fixed_features.update({int(next_batch[1]): past_value})
                            # Drop duplicate fixed-feature dicts.
                            fixed_features_list = [
                                ff
                                for n, ff in enumerate(fixed_features_list)
                                if ff not in fixed_features_list[:n]
                            ]
                    else:
                        raise NotImplementedError(
                            "Only equality inter-parameter constraints constraining two "
                            "candidate parameters to be equal to each other (e.g. "
                            "x1(batch_0)*1 + x1(batch_1)*(-1) == 0) are implemented so "
                            "far in mixed acquisition function optimization."
                        )
                else:
                    raise NotImplementedError(
                        "Only equality inter-parameter constraints involving exactly two "
                        "candidate parameters are implemented so far in mixed "
                        "acquisition function optimization."
                    )
            for ieq_constr_inter in inequality_constraints_inter or []:
                raise NotImplementedError(
                    "The consideration of inequality inter-parameter constraints in "
                    "mixed acquisition function optimization is not yet implemented."
                )
    acq_function.set_X_pending(base_X_pending)

    # compute joint acquisition value
    if isinstance(acq_function, OneShotAcquisitionFunction):
        acq_value = acq_function.evaluate(X=candidates, bounds=bounds)
    else:
        acq_value = acq_function(candidates)
    return candidates, acq_value
```
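For reference, BoTorch's inter-point constraints use a 2-D `indices` tensor whose first column is the q-batch index and whose second column is the parameter index. The torch-free sketch below mimics the equality-propagation step of the snippet with plain lists in place of tensors; all helper names are hypothetical. `shared_parameter_constraints` builds the constraints needed for the temperature use case from the feature request (one parameter equal across all q points).

```python
from typing import List, Sequence, Tuple, Union

# Plain-list stand-in for BoTorch's (indices, coefficients, rhs) tuples.
# Intra-point: indices is a flat list of parameter indices.
# Inter-point: indices is a list of (q_index, parameter_index) pairs.
Index = Union[int, Tuple[int, int]]
Constraint = Tuple[List[Index], List[float], float]


def shared_parameter_constraints(param: int, q: int) -> List[Constraint]:
    """x_param(point 0) * 1 + x_param(point b) * (-1) == 0 for b = 1..q-1."""
    return [([(0, param), (b, param)], [1.0, -1.0], 0.0) for b in range(1, q)]


def is_inter_point(constraint: Constraint) -> bool:
    """Plain-list analogue of the len(indices.shape) > 1 check."""
    return isinstance(constraint[0][0], (tuple, list))


def propagate_equalities(
    base_fixed_features_list: List[dict],
    equality_constraints_inter: List[Constraint],
    candidates: Sequence[Sequence[float]],
    next_point: int,
) -> List[dict]:
    """Fixed features for point `next_point`, implied by the points chosen so far."""
    # Fresh dict copies, so the base list is never mutated.
    fixed_features_list = [dict(ff) for ff in base_fixed_features_list]
    for indices, coefficients, rhs in equality_constraints_inter:
        if len(indices) != 2 or coefficients[0] != -coefficients[1] or rhs != 0:
            raise NotImplementedError("only x_m(i) - x_n(j) == 0 is handled")
        (i, m), (j, n) = indices
        if j < next_point and i == next_point:
            # Orient the pair so that (i, m) refers to the already chosen point.
            (i, m), (j, n) = (j, n), (i, m)
        if i < next_point and j == next_point:
            for ff in fixed_features_list:
                ff[n] = candidates[i][m]
    deduped: List[dict] = []
    for ff in fixed_features_list:
        if ff not in deduped:  # drop duplicates created by the new fixed value
            deduped.append(ff)
    return deduped
```

Calling `propagate_equalities` with `next_point=1`, then `next_point=2`, and so on corresponds to the `batch + 1` logic inside the greedy loop.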
@jduerholt yes, I think this could be useful. I'm not sure if I would choose to modify the signature of …
@MoritzEckhoff So this is basically a greedy solution with respect to the constraints. How would we ensure that committing to some values in the first elements of the q-batch will render the subsequent optimizations subject to those fixed features feasible? It seems like this is not straightforward to ensure unless we consider the fixed features jointly across the q-batch (as @jduerholt's solution above would).
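A toy illustration of this concern, under made-up assumptions: one parameter per point, point 0's continuous value discretized to a small grid, point 1's value a discrete choice, and an inter-point equality x(point 0) == x(point 1). A greedy pick that ignores the constraint can leave no feasible completion, while enumerating both points jointly finds feasible pairs. The names and the toy acquisition score are hypothetical; nothing here uses BoTorch.

```python
from itertools import product
from typing import Callable, List, Sequence

Constraint = Callable[[List[float]], bool]


def has_feasible_completion(
    chosen: Sequence[float],
    remaining_domains: Sequence[Sequence[float]],
    constraints: Sequence[Constraint],
) -> bool:
    """Brute-force check whether the values committed so far can still be completed."""
    for rest in product(*remaining_domains):
        full = list(chosen) + list(rest)
        if all(c(full) for c in constraints):
            return True
    return False


domain_point0 = [0.0, 0.3, 0.5, 1.0]  # discretized continuous parameter
domain_point1 = [0.0, 1.0]  # discrete choice
constraints = [lambda x: x[0] == x[1]]  # inter-point equality

# Greedy: pick point 0 by a toy acquisition value that ignores the constraint.
greedy_choice = max(domain_point0, key=lambda v: -((v - 0.3) ** 2))

# Joint: enumerate both points together and keep only feasible pairs.
joint_feasible = [
    pair
    for pair in product(domain_point0, domain_point1)
    if all(c(list(pair)) for c in constraints)
]
```

Here the greedy choice 0.3 admits no feasible value for point 1, whereas the joint enumeration returns (0.0, 0.0) and (1.0, 1.0).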
Yes, one would need to jointly recalculate a feasible set of …
@Balandat: I will add it to my list, but it can take some time until this lands ;) I see advantages not only for inter-point constraints but also for point-specific fixed features in a q-batch in combination with joint optimization.
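A small sketch of how point-specific fixed features could be represented, assuming (hypothetically) that they are keyed by (q-index, parameter index), matching the row layout of inter-point constraint indices, and that a joint optimizer sees the q-batch as a row-major q x d vector. Both helper names are hypothetical and not part of BoTorch.

```python
from typing import Dict, Tuple


def broadcast_fixed_features(
    fixed_features: Dict[int, float], q: int
) -> Dict[Tuple[int, int], float]:
    """Expand batch-wide fixed features {parameter_index: value} to every q-batch point."""
    return {(q_idx, d_idx): v for q_idx in range(q) for d_idx, v in fixed_features.items()}


def flatten_point_fixed_features(
    point_fixed_features: Dict[Tuple[int, int], float], q: int, d: int
) -> Dict[int, float]:
    """Map {(q_index, parameter_index): value} to flat indices of a row-major q x d batch."""
    flat: Dict[int, float] = {}
    for (q_idx, d_idx), value in point_fixed_features.items():
        if not (0 <= q_idx < q and 0 <= d_idx < d):
            raise IndexError(f"({q_idx}, {d_idx}) lies outside a {q} x {d} batch")
        flat[q_idx * d + d_idx] = value  # row-major position of entry (q_idx, d_idx)
    return flat
```

With this representation, today's batch-wide `fixed_features` become a special case (every point gets the same entry), while per-point discrete choices can differ between points of the same q-batch.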
🚀 Feature Request
The function `optimize_acqf_mixed` in `botorch/optim/optimize.py` currently runs a sequential greedy optimization for `q>1`. Because of this, the function cannot consider inter-parameter constraints and even encounters an `IndexError` when trying to normalize the inter-parameter constraints.
Motivation
Is your feature request related to a problem? Please describe.
The motivation may be two-fold:
1. … `optimize_acqf_mixed` to `optimize_acqf`. After implementing this request, both functions would accept the same types of constraints.
2. … `optimize_acqf_mixed` is the optimization function to go with. Moreover, batched optimization with `q>1` allows the execution of optimized experiments in parallel, saving lab resources. Then, however, inter-parameter constraints are needed when dealing with parameters such as the experiment's temperature (which will be applied to all samples in a batch). Inter-parameter constraints considered by `optimize_acqf_mixed` would enable the planning of batched chemistry experiments.
Pitch
Describe the solution you'd like
One option would be to implement a case distinction:
In the joint optimization, one can either (a) enumerate all n_combos^n possible combinations, which will probably be expensive, or (b) directly use the provided `fixed_features_list` without enumerating all combinations, as proposed in the code snippet below. If we implement option (b), will there be any loss in optimality if only inter-parameter constraints are present? I am open to discussion.
Are you willing to open a pull request? (See CONTRIBUTING)
It would be my first one, but if we find a good solution, I am happy to help implement it.