Skip to content

Commit

Permalink
don't count hours
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jun 13, 2024
1 parent edcfbea commit 46f8c68
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,11 +431,11 @@ def reset_optimization(self, learning_rate=None):
self.optimizer, mode="min", factor=0.5, patience=10
)

def execution_hours(self):
def execution_time(self):
"""
Return time in hours (rounded to 3 decimal places) since the Burrito was created.
Return time since the Burrito was created.
"""
return round((time() - self.start_time) / 3600, 3)
return time() - self.start_time

def multi_train(self, epochs, max_tries=3):
"""
Expand All @@ -456,7 +456,7 @@ def multi_train(self, epochs, max_tries=3):
return train_history

def write_loss(self, loss_name, loss, step):
self.writer.add_scalar(loss_name, loss, step, walltime=self.execution_hours())
self.writer.add_scalar(loss_name, loss, step, walltime=self.execution_time())

def write_cuda_memory_info(self):
megabyte_scaling_factor = 1 / 1024**2
Expand Down Expand Up @@ -695,7 +695,7 @@ def mark_branch_lengths_optimized(self, cycle):
"branch length optimization",
cycle,
self.global_epoch,
walltime=self.execution_hours(),
walltime=self.execution_time(),
)

def joint_train(
Expand Down Expand Up @@ -725,9 +725,7 @@ def joint_train(
optimize_branch_lengths()
self.mark_branch_lengths_optimized(0)
for cycle in range(cycle_count):
print(
f"### Beginning cycle {cycle + 1}/{cycle_count} using optimizer {self.optimizer_name}"
)
print(f"### Beginning cycle {cycle + 1}/{cycle_count} using optimizer {self.optimizer_name}")
self.mark_branch_lengths_optimized(cycle + 1)
current_lr = self.optimizer.param_groups[0]["lr"]
# set new_lr to be the geometric mean of current_lr and the
Expand Down Expand Up @@ -967,10 +965,10 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
def write_loss(self, loss_name, loss, step):
rate_loss, csp_loss = loss.unbind()
self.writer.add_scalar(
"Rate " + loss_name, rate_loss.item(), step, walltime=self.execution_hours()
"Rate " + loss_name, rate_loss.item(), step, walltime=self.execution_time()
)
self.writer.add_scalar(
"CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_hours()
"CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_time()
)


Expand Down

0 comments on commit 46f8c68

Please sign in to comment.