
Commit

fixed utils.py
thomaswmorris committed Jul 19, 2023
1 parent 1567665 commit 18e3b3f
Showing 2 changed files with 7 additions and 21 deletions.
24 changes: 6 additions & 18 deletions bloptools/bayesian/__init__.py
@@ -123,9 +123,7 @@ def __init__(

         # make some test points for sampling

-        self.normalized_test_active_inputs = utils.normalized_sobol_sampler(
-            n=MAX_TEST_INPUTS, d=self.n_active_dofs
-        )
+        self.normalized_test_active_inputs = utils.normalized_sobol_sampler(n=MAX_TEST_INPUTS, d=self.n_active_dofs)

         n_per_active_dim = int(np.power(MAX_TEST_INPUTS, 1 / self.n_active_dofs))
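(Side note, not part of the commit.) The collapsed call above draws quasi-random test points in normalized coordinates. A minimal sketch of what a helper like utils.normalized_sobol_sampler might do, assuming it returns points in the unit cube [0, 1]^d; the actual implementation in bloptools may differ:

```python
from scipy.stats import qmc

def normalized_sobol_sampler(n, d):
    """Draw n quasi-random Sobol points in the d-dimensional unit cube (sketch only)."""
    sampler = qmc.Sobol(d=d, scramble=True)
    return sampler.random(n)  # array of shape (n, d) with values in [0, 1)

# Hypothetical values standing in for MAX_TEST_INPUTS and self.n_active_dofs.
test_inputs = normalized_sobol_sampler(n=1024, d=3)
```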

@@ -190,9 +188,7 @@ def test_active_inputs_grid(self):
     def input_transform(self):
         coefficient = torch.tensor(self.dof_bounds.ptp(axis=1)).unsqueeze(0)
         offset = torch.tensor(self.dof_bounds.min(axis=1)).unsqueeze(0)
-        return botorch.models.transforms.input.AffineInputTransform(
-            d=self.n_dofs, coefficient=coefficient, offset=offset
-        )
+        return botorch.models.transforms.input.AffineInputTransform(d=self.n_dofs, coefficient=coefficient, offset=offset)

     def save_data(self, filepath="./self_data.h5"):
         """
@@ -312,9 +308,7 @@ def tell(self, new_table=None, append=True, train=True, **kwargs):
                 raise ValueError("There must be at least two feasible data points per task!")

             train_inputs = torch.tensor(self.inputs.loc[task.feasibility].values).double().unsqueeze(0)
-            train_targets = (
-                torch.tensor(task.targets.loc[task.feasibility].values).double().unsqueeze(0).unsqueeze(-1)
-            )
+            train_targets = torch.tensor(task.targets.loc[task.feasibility].values).double().unsqueeze(0).unsqueeze(-1)

             if train_inputs.ndim == 1:
                 train_inputs = train_inputs.unsqueeze(-1)
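(Side note, not part of the commit.) The unsqueeze calls above arrange the feasible data into batched tensors, inputs of shape (1, n, d) and targets of shape (1, n, 1), which is the layout the downstream GP regressors consume. A toy illustration with made-up data:

```python
import torch

n, d = 5, 2                                         # hypothetical feasible points and DOFs
inputs = torch.rand(n, d).double()
targets = torch.rand(n).double()

train_inputs = inputs.unsqueeze(0)                  # shape (1, n, d)
train_targets = targets.unsqueeze(0).unsqueeze(-1)  # shape (1, n, 1)
```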
@@ -383,9 +377,7 @@ def train_models(self, **kwargs):

     def get_acquisition_function(self, acqf_identifier="ei", return_metadata=False, acqf_args={}, **kwargs):
         if not self._initialized:
-            raise RuntimeError(
-                f'Can\'t construct acquisition function "{acqf_identifier}" (the self is not initialized!)'
-            )
+            raise RuntimeError(f'Can\'t construct acquisition function "{acqf_identifier}" (the self is not initialized!)')

         if acqf_identifier.lower() in AVAILABLE_ACQFS["expected_improvement"]["identifiers"]:
             acqf = botorch.acquisition.analytic.LogExpectedImprovement(
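(Side note, not part of the commit.) botorch.acquisition.analytic.LogExpectedImprovement is constructed from a model and the best observed objective value, and analytic acquisition functions are evaluated on candidates shaped batch x 1 x d. A self-contained sketch with a toy single-task GP, not the agent's actual regressors:

```python
import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.analytic import LogExpectedImprovement

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)

model = SingleTaskGP(train_X, train_Y)
log_ei = LogExpectedImprovement(model=model, best_f=train_Y.max())

candidates = torch.rand(16, 1, 2, dtype=torch.double)  # 16 candidate points, q = 1
values = log_ei(candidates)                            # log expected improvement per candidate
```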
@@ -443,9 +435,7 @@ def ask(self, acqf_identifier="ei", n=1, route=True, return_metadata=False):
             x, acqf_meta = self.ask_single(acqf_identifier, return_metadata=True)

             if i < (n - 1):
-                task_samples = [
-                    task.regressor.posterior(torch.tensor(x)).sample().item() for task in self.tasks
-                ]
+                task_samples = [task.regressor.posterior(torch.tensor(x)).sample().item() for task in self.tasks]
                 fantasy_table = pd.DataFrame(
                     np.append(x, task_samples)[None], columns=[*self.dof_names, *self.task_names]
                 )
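(Side note, not part of the commit.) The collapsed line above is a simple fantasy step: each task's posterior is sampled once at the proposed point x, and the draws are appended to the table as if they had been measured, so the next point in the batch is chosen conditioned on them. A sketch of the posterior-sampling call with a toy model (assumed names, not the agent's internals):

```python
import torch
from botorch.models import SingleTaskGP

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.prod(dim=-1, keepdim=True)
regressor = SingleTaskGP(train_X, train_Y)               # stand-in for task.regressor

x = torch.rand(1, 2, dtype=torch.double)                 # a proposed point
fantasy_value = regressor.posterior(x).sample().item()   # one draw from the posterior at x
```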
@@ -996,9 +986,7 @@ def plot_history(self, x_key="index", show_all_tasks=False):
         )
         hist_axes = np.atleast_1d(hist_axes)

-        unique_strategies, acqf_index, acqf_inverse = np.unique(
-            self.table.acqf, return_index=True, return_inverse=True
-        )
+        unique_strategies, acqf_index, acqf_inverse = np.unique(self.table.acqf, return_index=True, return_inverse=True)

         sample_colors = np.array(DEFAULT_COLOR_LIST)[acqf_inverse]
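(Side note, not part of the commit.) np.unique with return_index and return_inverse yields the distinct acquisition-function labels, the first row where each appears, and an integer code per row; the inverse codes are what index into the color list. A tiny example with made-up labels:

```python
import numpy as np

acqf_labels = np.array(["qr", "ei", "ei", "ucb", "ei"])
unique_strategies, acqf_index, acqf_inverse = np.unique(acqf_labels, return_index=True, return_inverse=True)
# unique_strategies -> ['ei' 'qr' 'ucb'], acqf_index -> [1 0 3], acqf_inverse -> [1 0 0 2 0]

colors = np.array(["tab:blue", "tab:orange", "tab:green"])
sample_colors = colors[acqf_inverse]  # one color per row, keyed by strategy
```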
4 changes: 1 addition & 3 deletions bloptools/utils.py
@@ -46,9 +46,7 @@ def route(start_point, points):
     delay_matrix = np.sqrt(np.square(normalized_points[:, None, :] - normalized_points[None, :, :]).sum(axis=-1))
     delay_matrix = (1e4 * delay_matrix).astype(int) # it likes integers idk

-    manager = pywrapcp.RoutingIndexManager(
-        len(total_points), 1, 0
-    ) # number of depots, number of salesmen, starting index
+    manager = pywrapcp.RoutingIndexManager(len(total_points), 1, 0) # number of depots, number of salesmen, starting index
     routing = pywrapcp.RoutingModel(manager)

     def delay_callback(from_index, to_index):
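(Side note, not part of the commit.) In OR-Tools, RoutingIndexManager(num_nodes, num_vehicles, depot) takes the number of points to visit, the number of vehicles (salesmen), and the index of the starting node. A self-contained sketch of the surrounding routing pattern, assuming an integer delay matrix like the one above; the actual route() in bloptools/utils.py differs in detail:

```python
import numpy as np
from ortools.constraint_solver import pywrapcp, routing_enums_pb2

# Hypothetical integer "delay" (cost) matrix between 5 points.
points = np.random.default_rng(0).uniform(size=(5, 2))
delay_matrix = (1e4 * np.sqrt(np.square(points[:, None, :] - points[None, :, :]).sum(axis=-1))).astype(int)

manager = pywrapcp.RoutingIndexManager(len(points), 1, 0)  # nodes, vehicles, start index
routing = pywrapcp.RoutingModel(manager)

def delay_callback(from_index, to_index):
    # Callbacks receive routing indices; convert them to node indices before lookup.
    return int(delay_matrix[manager.IndexToNode(from_index)][manager.IndexToNode(to_index)])

transit_index = routing.RegisterTransitCallback(delay_callback)
routing.SetArcCostEvaluatorOfAllVehicles(transit_index)

params = pywrapcp.DefaultRoutingSearchParameters()
params.first_solution_strategy = routing_enums_pb2.FirstSolutionStrategy.PATH_CHEAPEST_ARC
solution = routing.SolveWithParameters(params)

# Walk the solved route from the start node until the end marker.
index, order = routing.Start(0), []
while not routing.IsEnd(index):
    order.append(manager.IndexToNode(index))
    index = solution.Value(routing.NextVar(index))
```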
