distributed training #12

Open: wants to merge 47 commits into base: main

Changes from all commits (47 commits)
d609538 (Jul 19, 2020): Remove Zygote.Params arg from loss_and_gradient
f92fdee (Jul 20, 2020): Distributed train!
88a14f9 (Jul 20, 2020): Bugfixes
ded6689 (Jul 20, 2020): Remove duplicate method defs.
b403f9c (Jul 21, 2020): Differentiate calls to `loss_and_gradient` better.
3839b7a (Jul 21, 2020): Bugfix attempt 2
d19379e (Jul 21, 2020): Bugfix attempt 3
c177e81 (Jul 21, 2020): Bugfix 4
16877fc (Jul 22, 2020): Remove duplicate method defs.
28d02e0 (Jul 22, 2020): Adding straggler method back.
954ace2 (Jul 22, 2020): buffered_batch_loader
2ee1666 (Jul 23, 2020): Align grads and params on their params orders.
4942e32 (Jul 23, 2020): Bugfix
702852b (Jul 23, 2020): Bugfix
426a1c9 (Jul 24, 2020): Bugfixes; occasional "CUDNNError: CUDNN_STATUS_NOT_INITIALIZED"
81a590b (Jul 24, 2020): Updated Flux/Zygote/CuArrays->CUDA
b78825b (Jul 24, 2020): export buffered_batch_loader
11d7114 (Jul 24, 2020): add CUDA
85db31d (Jul 25, 2020): forgot to use CUDA
3cd605b (Jul 25, 2020): Add Renormalizer optimiser
b618117 (Jul 26, 2020): Send less over wire to avoid trouble.
1ee66a2 (Jul 26, 2020): Log __plot_data__
baed8e1 (Jul 28, 2020): Bugfixes to robust controller-based training loop.
2b7bb5a (Jul 28, 2020): Don't crash if cannot kill stalled worker.
ae527d9 (Jul 28, 2020): Bump up timeout for worker returning a gradient.
59a4121 (Jul 28, 2020): Try not to get stuck in rmprocs()
3ef496c (Jul 28, 2020): Merge remote-tracking branch 'origin' into ks/big_train
c602ec1 (Jul 28, 2020): Merge remote-tracking branch 'origin/ks/big_train' into ks/big_train
5e68d56 (Jul 28, 2020): Handle case where no worker returns a gradient.
0c6c5d5 (Jul 28, 2020): Typo bugfix
a3bf554 (Jul 28, 2020): Whoups, nuked wrong one.
ee2fe7a (Jul 29, 2020): Remove probable blocker.
069daa4 (Jul 29, 2020): Remove unresponsive workers
5fb0e53 (Jul 29, 2020): Widen type of logger in `loss_and_gradient`
cf794a7 (Jul 29, 2020): Merge remote-tracking branch 'origin/ks/big_train' into ks/big_train
5cf2d38 (Jul 29, 2020): @show train_loss
797767b (Jul 31, 2020): Tighten up memory consumption on master.
ccfabd2 (Jul 31, 2020): Merge branch 'ks/big_train' of github.com:beacon-biosignals/Lighthous…
3b10320 (Jul 31, 2020): gpu_free_memory() util
810ea29 (Jul 31, 2020): typo
44bdfd1 (Aug 4, 2020): Preserve order / alignment.
82c2ee1 (Aug 8, 2020): workers keep model between passes
7143b82 (Aug 9, 2020): can't assign to _model?
cc482b9 (Aug 18, 2020): Checkpoint before picking subset for PR.
e35b553 (Aug 18, 2020): Bugfix: apply! needs to be in module Flux.Optimise
0ca38f6 (Sep 9, 2020): Cleanup logger; it now lives in Lighthouse#general_logger
6346ca6 (Oct 5, 2020): avoid memory leak by calling GC on remote workers
13 changes: 10 additions & 3 deletions Project.toml
@@ -4,15 +4,22 @@ authors = ["Beacon Biosignals, Inc."]
version = "0.2.4"

[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FastS3 = "861778fc-d877-463f-8dc1-1c1e9d5ec150"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Lighthouse = "ac2c24cd-07f0-4848-96b2-1b82c3ea0e59"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Flux = "0.10.4, 0.11"
Flux = "0.11"
Lighthouse = "0.8, 0.9"
Zygote = "0.4.13, 0.5"
julia = "1.3"
Zygote = "0.5.3"
julia = "1.4"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
147 changes: 14 additions & 133 deletions src/LighthouseFlux.jl
@@ -1,148 +1,29 @@
module LighthouseFlux

using Dates: now
using Zygote: Zygote
using Flux: Flux
using Lighthouse: Lighthouse, classes, log_resource_info!, log_value!

export FluxClassifier
using Distributed, BSON
using FastS3
using Serialization
using CUDA

#####
##### `FluxClassifier`
#####
abstract type AbstractFluxClassifier <: Lighthouse.AbstractClassifier end

struct FluxClassifier{M,O,C,P,OH,OC} <: Lighthouse.AbstractClassifier
model::M
optimiser::O
classes::C
params::P
onehot::OH
onecold::OC
end
include("local.jl")

"""
FluxClassifier(model, optimiser, classes; params=Flux.params(model),
onehot=(label -> Flux.onehot(label, 1:length(classes))),
onecold=(label -> Flux.onecold(label, 1:length(classes))))
# everything in `distributed/` should go somewhere else; it has nothing to do with Lighthouse or Flux
include("distributed/dataloader.jl")
include("distributed/sharding.jl")

Return a `FluxClassifier <: Lighthouse.AbstractClassifier` with the given arguments:
include("distributed.jl")

- `model`: a Flux model. The model must additionally support LighthouseFlux's [`loss`](@ref)
and [`loss_and_prediction`](@ref) functions.
include("optimiser.jl")

- `optimiser`: a [Flux optimiser](https://fluxml.ai/Flux.jl/stable/training/optimisers/)
export FluxClassifier, DistributedFluxClassifier

- `classes`: a `Vector` or `Tuple` of possible class values; this is the return
value of `Lighthouse.classes(::FluxClassifier)`.

- `params`: The parameters to optimise during training; generally, a `Zygote.Params`
value or a value that can be passed to `Zygote.Params`.

- `onehot`: the function used to convert hard labels to soft labels when
`Lighthouse.onehot` is called with this classifier.

- `onecold`: the function used to convert soft labels to hard labels when
`Lighthouse.onecold` is called with this classifier.
"""
function FluxClassifier(model, optimiser, classes; params=Flux.params(model),
onehot=(label -> Flux.onehot(label, 1:length(classes))),
onecold=(label -> Flux.onecold(label, 1:length(classes))))
return FluxClassifier(Flux.testmode!(model), optimiser, classes, params, onehot, onecold)
end
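
For orientation, a minimal construction sketch; the `TinyModel` wrapper, the hand-written loss, and the hyperparameters below are illustrative assumptions, not part of this package:

using Flux, LighthouseFlux

# Hypothetical wrapper type so a `loss` method can be attached to a concrete model struct.
struct TinyModel{C}
    chain::C
end
Flux.@functor TinyModel
(m::TinyModel)(x) = m.chain(x)

# Required by the `FluxClassifier` contract (see `loss` below); hand-written cross-entropy on logits.
LighthouseFlux.loss(m::TinyModel, x, y) = -sum(y .* logsoftmax(m(x))) / size(y, 2)

model = TinyModel(Chain(Dense(4, 8, relu), Dense(8, 3)))
classifier = FluxClassifier(model, Flux.ADAM(0.001), ["a", "b", "c"])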

"""
loss(model, batch_arguments...)

Return the scalar loss of `model` given `batch_arguments`.

This method must be implemented for all `model`s passed to [`FluxClassifier`](@ref).
"""
function loss end

"""
loss_and_prediction(model, input_batch, other_batch_arguments...)

Return `(model_loss, model_prediction)` where:

- `model_loss` is equivalent to (and defaults to) `loss(model, input_batch, other_batch_arguments...)`.

- `model_prediction` is a matrix where the `i`th column is the soft label prediction for the `i`th
sample in `input_batch`. Thus, the number of columns should be `size(input_batch)[end]`, while the
number of rows is equal to the number of possible classes predicted by the model. `model_prediction`
defaults to `model(input_batch)`.

This method must be implemented for all `model`s passed to [`FluxClassifier`](@ref), but
has the default return values described above, so it only needs to be overloaded if the
default definitions do not yield the expected values for a given `model` type. It
additionally may be overloaded to avoid redundant computation if `model`'s loss
function computes soft labels as an intermediate result.
"""
function loss_and_prediction(model, input_batch, other_batch_arguments...)
return (loss(model, input_batch, other_batch_arguments...), model(input_batch))
end
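
If the loss already computes the soft labels, overloading this method avoids a second forward pass; a sketch using the hypothetical `TinyModel` from above:

function LighthouseFlux.loss_and_prediction(m::TinyModel, x, y)
    ŷ = m(x)  # computed once, reused for both the loss and the returned prediction
    return (-sum(y .* logsoftmax(ŷ)) / size(y, 2), ŷ)
end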

#####
##### Lighthouse `AbstractClassifier` Interface
#####

Lighthouse.classes(classifier::FluxClassifier) = classifier.classes

function Lighthouse.is_early_stopping_exception(::FluxClassifier, exception)
return exception isa Flux.Optimise.StopException
end

Lighthouse.onehot(classifier::FluxClassifier, label) = classifier.onehot(label)

Lighthouse.onecold(classifier::FluxClassifier, label) = classifier.onecold(label)

function Lighthouse.train!(classifier::FluxClassifier, batches, logger)
Flux.trainmode!(classifier.model)
weights = Zygote.Params(classifier.params)
for batch in batches
train_loss, back = log_resource_info!(logger, "train/forward_pass";
suffix="_per_batch") do
f = () -> loss(classifier.model, batch...)
return Zygote.pullback(f, weights)
end
log_value!(logger, "train/loss_per_batch", train_loss)
gradients = log_resource_info!(logger, "train/reverse_pass";
suffix="_per_batch") do
return back(Zygote.sensitivity(train_loss))
end
log_resource_info!(logger, "train/update"; suffix="_per_batch") do
Flux.Optimise.update!(classifier.optimiser, weights, gradients)
return nothing
end
end
Flux.testmode!(classifier.model)
return nothing
end
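
Each element of `batches` is splatted into `loss(model, batch...)`. A hypothetical call, reusing the `classifier` sketched above and assuming Lighthouse's `LearnLogger` constructor for the logger:

xs = [rand(Float32, 4, 16) for _ in 1:10]                  # 10 input batches of 16 samples
ys = [Flux.onehotbatch(rand(1:3, 16), 1:3) for _ in 1:10]  # matching one-hot targets
logger = Lighthouse.LearnLogger("runs", "train_example")   # assumed logger constructor
Lighthouse.train!(classifier, zip(xs, ys), logger)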

function Lighthouse.loss_and_prediction(classifier::FluxClassifier, batch...)
return Flux.cpu(loss_and_prediction(classifier.model, batch...))
end

#####
##### `CuIterator`
#####

Base.@deprecate_moved CuIterator CuArrays false

#####
##### miscellaneous utilities
#####

"""
evaluate_chain_in_debug_mode(chain::Flux.Chain, input)

Evaluate `chain(input)`, printing additional debug information at each layer.
"""
function evaluate_chain_in_debug_mode(chain::Flux.Chain, input)
for (i, layer) in enumerate(chain)
@info "Executing layer $i / $(length(chain))..." layer size(input)
input = layer(input)
@info output_size=size(input)
end
return input
end
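
A hypothetical call, handy when a chain errors or produces an unexpected shape partway through (names continue the sketch above):

chain = Chain(Dense(4, 8, relu), Dense(8, 3))
LighthouseFlux.evaluate_chain_in_debug_mode(chain, rand(Float32, 4, 16))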
export @defineat, remotecall_channel, sendto, buffered_batch_loader

end # module
198 changes: 198 additions & 0 deletions src/distributed.jl
@@ -0,0 +1,198 @@
struct DistributedFluxClassifier <: AbstractFluxClassifier
workerpool::AbstractWorkerPool
model::FluxClassifier
end
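
A minimal construction sketch, assuming workers have already been added and that `classifier` is a `FluxClassifier` like the one sketched earlier (the worker count is illustrative):

using Distributed
addprocs(4)                        # hypothetical worker count
@everywhere using LighthouseFlux   # workers need the package loaded
# (setup of the worker-local `_model` / `_test_batches` / `_training_batches` via the
#  distributed helpers is elided here)

dclassifier = DistributedFluxClassifier(WorkerPool(workers()), classifier)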

Lighthouse.classes(classifier::DistributedFluxClassifier) = classifier.model.classes

function Lighthouse.is_early_stopping_exception(::DistributedFluxClassifier, exception)
return exception isa Flux.Optimise.StopException
end

Lighthouse.onehot(classifier::DistributedFluxClassifier, label) = classifier.model.onehot(label)

Lighthouse.onecold(classifier::DistributedFluxClassifier, label) = classifier.model.onecold(label)

model(classifier::DistributedFluxClassifier) = classifier.model.model
params(classifier::DistributedFluxClassifier) = classifier.model.params
optimizer(classifier::DistributedFluxClassifier) = classifier.model.optimiser

"""
loss_and_prediction(model, input_batch, other_batch_arguments...)

Return `(model_loss, model_prediction)` where:

- `model_loss` is equivalent to (and defaults to) `loss(model, input_batch, other_batch_arguments...)`.

- `model_prediction` is a matrix where the `i`th column is the soft label prediction for the `i`th
sample in `input_batch`. Thus, the numnber of columns should be `size(input_batch)[end]`, while the
number of rows is equal to the number of possible classes predicted by model. `model_prediction`
defaults to `model(input_batch)`.

This method must be implemented for all `model`s passed to [`FluxClassifier`](@ref), but
has the default return values described above, so it only needs to be overloaded if the
default definitions do not yield the expected values for a given `model` type. It
additionally may be overloaded to avoid redundant computation if `model`'s loss
function computes soft labels as an intermediate result.
"""
function loss_and_prediction_and_votes(model)
for (dst, src) in zip(_model_params, Zygote.Params(Flux.params(model)))
copyto!(dst, src)
end
batch, votes = try
first(_test_batches)
catch e
@error "error taking _test_batches" exception=(e, catch_backtrace())
return nothing
end
Flux.testmode!(_model)
l, preds = loss_and_prediction(_model, batch...)
return (l, preds, votes)
end

function loss_and_prediction_and_votes(classifier::DistributedFluxClassifier; timeout_secs=42.0)
model_push!(classifier, :test; timeout_secs=timeout_secs)
Channel(42) do channel
try
remote_channel = RemoteChannel(() -> channel)
asyncmap(collect(classifier.workerpool)) do p
Distributed.remotecall_eval(LighthouseFlux, p, quote
_predictions = $remote_channel
for (test_batch, votes_indices) in _test_batches
l, preds = loss_and_prediction(_model, test_batch...)
put!(_predictions, (l, preds, votes_indices))
end
end)
end
catch e
@error "error pushing predictions to _predictions" exception=(e, catch_backtrace())
end
end
end

function model_update!(model, mode)
for (dst, src) in zip(_model_params, Zygote.Params(Flux.params(model)))
copyto!(dst, src)
end
if mode == :train
Flux.trainmode!(_model)
elseif mode == :test
Flux.testmode!(_model)
else
error("unknown model mode $mode not in [:train, :test]")
end
return nothing
end

function model_push!(classifier::DistributedFluxClassifier, mode; timeout_secs=42.0)
mode ∉ [:train, :test] && error("unknown model mode $mode not in [:train, :test]")
shards = Dict{Int,Any}( p => (model_update!, Flux.cpu(classifier.model.model), mode) for p in classifier.workerpool.workers)
return_channel = remotecall_fetch_all(shards)
asyncmap(shards) do (pid, _)
status = timedwait(() -> isready(return_channel), timeout_secs)
if status != :ok
@warn "timeout pushing model to pid $pid !!!"
end
end
end

function loss_and_gradient(model, logger::RemoteChannel)
model_update!(model, :train)
batch = try
first(_training_batches)
catch e
@error "loss_and_gradient on worker" exception=(e, catch_backtrace())
return nothing
end
train_loss, back = log_resource_info!(logger, "train/forward_pass";
suffix="_per_batch") do
f = () -> loss(_model, batch...)
return Zygote.pullback(f, _model_params)
end
log_value!(logger, "train/loss_per_batch", train_loss)
gradients = log_resource_info!(logger, "train/reverse_pass";
suffix="_per_batch") do
return back(Zygote.sensitivity(train_loss))
end
return (train_loss, [Flux.cpu(gradients[p]) for p in gradients.params])
end

function loss_and_gradient(classifier::DistributedFluxClassifier, weights, b, logger::RemoteChannel; timeout_secs=42.0)
shards = Dict{Int,Any}( p => (loss_and_gradient, Flux.cpu(classifier.model.model), logger) for p in classifier.workerpool.workers)
return_channel = remotecall_fetch_all(shards)
train_loss, gradients, count = nothing, nothing, 0.0
pids = []
for (pid, _) in shards
status = timedwait(() -> isready(return_channel), timeout_secs)
if status == :ok
p, r = take!(return_channel)
# @show p, r
push!(pids, pid)
if r !== nothing && eltype(r[2]) != Nothing
batch_loss, grad = r
count += 1.0
if train_loss === nothing
train_loss, gradients = batch_loss, grad
else
train_loss += batch_loss
gradients = map(+, gradients, grad)
end
end
else
pids = Set(pids)
unresponsive = setdiff(classifier.workerpool.workers, pids)
@warn "workers $unresponsive unresponsive, removing from worker pool, continuing tick without their batch data."
classifier.workerpool.workers = pids
for p in unresponsive
@async try
rmprocs(p; waitfor=1)
catch
rmprocs(p; waitfor=1)
nothing
# XXX call ClusterManagers kill
end
end
end
end
# train_loss /= count
# map(g -> g / count, gradients)
@show train_loss
return train_loss, reindex(gradients, weights)
end

function reindex(nope::Nothing, w)
grads = IdDict()
for wp in w
grads[wp] = zero(wp)
end
return Zygote.Grads(grads, w)
end

function reindex(g::Vector, w)
grads = IdDict()
for (gp, wp) in zip(g, w)
grads[wp] = gp
end
return Zygote.Grads(grads, w)
end
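
A small illustration of the alignment `reindex` restores (values hypothetical): each worker ships gradients back as a plain `Vector` in parameter order, and `reindex` rebuilds a `Zygote.Grads` keyed by the controller's own parameter arrays so `Flux.Optimise.update!` can consume it:

ps = Zygote.Params([rand(Float32, 2, 2), rand(Float32, 2)])
gvec = [ones(Float32, 2, 2), ones(Float32, 2)]  # gradients in the same order as `ps`
gs = reindex(gvec, ps)
gs[first(ps)]                                   # the 2×2 array of ones, looked up by parameter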

function Lighthouse.predict!(model::DistributedFluxClassifier,
predicted_soft_labels::AbstractMatrix,
batches::UnitRange{Int}, logger;
logger_prefix::AbstractString)
losses = []
for b in batches
for (batch_loss, soft_label_batch, votes) in loss_and_prediction_and_votes(model)
for (i, soft_label) in enumerate(eachcol(soft_label_batch))
predicted_soft_labels[votes[i], :] = soft_label
end
log_value!(logger, logger_prefix * "/loss_per_batch", batch_loss)
push!(losses, batch_loss)
end
end
# @info repr(losses)
mean_loss = sum(losses) / length(losses)
log_value!(logger, logger_prefix * "/mean_loss_per_epoch", mean_loss)
return mean_loss
end
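
A hypothetical call shape, assuming the worker-local `_test_batches` have already been distributed and reusing names from the sketches above (`dclassifier`, `logger`):

predicted = zeros(Float32, 32, 3)  # one row per evaluation sample, one column per class
Lighthouse.predict!(dclassifier, predicted, 1:4, logger; logger_prefix="test")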
