Commit
Update trainerfunc.jl to handle errors more gracefully
nityajoshi authored Nov 12, 2023
1 parent ce3266e commit f0cb1e8
29 changes: 17 additions & 12 deletions _tutorials/trainerfunc.jl
@@ -26,18 +26,23 @@ function trainer(m, func, ln::Real=5)
schedule = Interpolator(Step(initial_lr, 0.5, [25, 10]), es)
optim = func(initial_lr)
# callbacks
-logger = TensorBoardBackend("tblogs")
-# schcb = Scheduler(LearningRate => schedule)
-logcb = (LogMetrics(logger),)# LogHyperParams(logger))
-valcb = Metrics(Metric(accfn; phase = TrainingPhase, name = "train_acc"),
-Metric(accfn; phase = ValidationPhase, name = "val_acc"))
-learner = Learner(m, lossfn;
-data = (trainloader, valloader),
-optimizer = optim,
-callbacks = [ToGPU(), logcb..., valcb])
-
-FluxTraining.fit!(learner, ln)
-close(logger.logger)
+logger = TensorBoardBackend("tblog")
+try
+
+# schcb = Scheduler(LearningRate => schedule)
+logcb = (LogMetrics(logger),)# LogHyperParams(logger))
+valcb = Metrics(Metric(accfn; phase = TrainingPhase, name = "train_acc"),
+Metric(accfn; phase = ValidationPhase, name = "val_acc"))
+learner = Learner(m, lossfn;
+data = (trainloader, valloader),
+optimizer = optim,
+callbacks = [ToGPU(), logcb..., valcb])
+
+FluxTraining.fit!(learner, ln)
+finally
+close(logger.logger)
+end

## save model
m = learner.model |> cpu
@@ -56,4 +61,4 @@ end
#m = resize_nobn(m)
#m = desaturate(m)
#m = rebn(m)
-#m = trainer(m, 1)
\ No newline at end of file
+#m = trainer(m, 1)
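
Note: the diff above wraps the training run in try/finally so the TensorBoard log file is closed even when FluxTraining.fit! throws. Below is a minimal, self-contained Julia sketch of the same pattern; run_with_cleanup and the temporary file handle are hypothetical stand-ins used only for illustration, not part of the tutorial code.

# A minimal sketch of the cleanup pattern this commit introduces: acquire a
# resource, run work that may throw, and guarantee the resource is released.
function run_with_cleanup(work)
    io = open(tempname(), "w")    # stand-in for TensorBoardBackend("tblog")
    try
        return work(io)           # may throw, just as FluxTraining.fit! can
    finally
        close(io)                 # runs on success and on error alike
    end
end

# Usage: the handle is still closed even though the body errors.
try
    run_with_cleanup(io -> error("simulated training failure"))
catch err
    @warn "training failed" exception = err
end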
