Skip to content

Commit 1b17082

Browse files
authored
Merge pull request #785 from JuliaAI/stack_cache_and_acceleration_rebased
Stack cache and acceleration (rebased)
2 parents adb341f + e02f6af commit 1b17082

File tree

8 files changed

+165
-51
lines changed

8 files changed

+165
-51
lines changed

Project.toml

+1-15
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,4 @@ ScientificTypes = "3"
4646
StatisticalTraits = "3"
4747
StatsBase = "0.32, 0.33"
4848
Tables = "0.2, 1.0"
49-
julia = "1.6"
50-
51-
[extras]
52-
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
53-
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
54-
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
55-
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
56-
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
57-
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
58-
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
59-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
60-
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
61-
62-
[targets]
63-
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
49+
julia = "1.6"

src/composition/learning_networks/machines.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ end
328328

329329
"""
330330
331-
return!(mach::Machine{<:Surrogate}, model, verbosity)
331+
return!(mach::Machine{<:Surrogate}, model, verbosity; acceleration=CPU1())
332332
333333
The last call in custom code defining the `MLJBase.fit` method for a
334334
new composite model type. Here `model` is the instance of the new type
@@ -345,7 +345,7 @@ the following:
345345
handles smart updating (namely, an `MLJBase.update` fallback for
346346
composite models).
347347
348-
- Calls `fit!(mach, verbosity=verbosity)`.
348+
- Calls `fit!(mach, verbosity=verbosity, acceleration=acceleration)`.
349349
350350
- Moves any data in source nodes of the learning network into `cache`
351351
(for data-anonymization purposes).
@@ -388,11 +388,12 @@ end
388388
"""
389389
function return!(mach::Machine{<:Surrogate},
390390
model::Union{Model,Nothing},
391-
verbosity)
391+
verbosity;
392+
acceleration=CPU1())
392393

393394
network_model_names_ = network_model_names(model, mach)
394395

395-
verbosity isa Nothing || fit!(mach, verbosity=verbosity)
396+
verbosity isa Nothing || fit!(mach, verbosity=verbosity, acceleration=acceleration)
396397
setfield!(mach.fitresult, :network_model_names, network_model_names_)
397398

398399
# anonymize the data

src/composition/learning_networks/nodes.jl

+23-2
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,17 @@ end
201201
acceleration=CPU1())
202202
203203
Train all machines required to call the node `N`, in an appropriate
204-
order. These machines are those returned by `machines(N)`.
204+
order, but parallelizing where possible using the specified `acceleration`
205+
mode. These machines are those returned by `machines(N)`.
206+
207+
Supported modes of `acceleration`: `CPU1()`, `CPUThreads()`.
205208
206209
"""
207210
fit!(y::Node; acceleration=CPU1(), kwargs...) =
208211
fit!(y::Node, acceleration; kwargs...)
209212

210213
fit!(y::Node, ::AbstractResource; kwargs...) =
211-
error("Only `acceleration=CPU1()` currently supported")
214+
error("Only `acceleration=CPU1()` and `acceleration=CPUThreads()` currently supported")
212215

213216
function fit!(y::Node, ::CPU1; kwargs...)
214217

@@ -230,6 +233,24 @@ function fit!(y::Node, ::CPU1; kwargs...)
230233

231234
return y
232235
end
236+
237+
function fit!(y::Node, ::CPUThreads; kwargs...)
238+
_machines = machines(y)
239+
240+
# flush the fit_okay channels:
241+
for mach in _machines
242+
flush!(mach.fit_okay)
243+
end
244+
245+
# fit the machines in Multithreading mode
246+
@sync for mach in _machines
247+
Threads.@spawn fit_only!(mach, true; kwargs...)
248+
end
249+
250+
return y
251+
252+
end
253+
233254
fit!(S::Source; args...) = S
234255

235256
# allow arguments of `Nodes` and `Machine`s to appear

src/composition/models/stacking.jl

+37-23
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ mutable struct DeterministicStack{modelnames, inp_scitype, tg_scitype} <: Determ
3131
metalearner::Deterministic
3232
resampling
3333
measures::Union{Nothing,AbstractVector}
34-
function DeterministicStack(modelnames, models, metalearner, resampling, measures)
34+
cache::Bool
35+
acceleration::AbstractResource
36+
function DeterministicStack(modelnames, models, metalearner, resampling, measures, cache, acceleration)
3537
inp_scitype, tg_scitype = input_target_scitypes(models, metalearner)
36-
return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures)
38+
return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures, cache, acceleration)
3739
end
3840
end
3941

@@ -42,9 +44,11 @@ mutable struct ProbabilisticStack{modelnames, inp_scitype, tg_scitype} <: Probab
4244
metalearner::Probabilistic
4345
resampling
4446
measures::Union{Nothing,AbstractVector}
45-
function ProbabilisticStack(modelnames, models, metalearner, resampling, measures)
47+
cache::Bool
48+
acceleration::AbstractResource
49+
function ProbabilisticStack(modelnames, models, metalearner, resampling, measures, cache, acceleration)
4650
inp_scitype, tg_scitype = input_target_scitypes(models, metalearner)
47-
return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures)
51+
return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures, cache, acceleration)
4852
end
4953
end
5054

@@ -54,7 +58,7 @@ const Stack{modelnames, inp_scitype, tg_scitype} =
5458
ProbabilisticStack{modelnames, inp_scitype, tg_scitype}}
5559

5660
"""
57-
Stack(;metalearner=nothing, resampling=CV(), name1=model1, name2=model2, ...)
61+
Stack(; metalearner=nothing, name1=model1, name2=model2, ..., keyword_options...)
5862
5963
Implements the two-layer generalized stack algorithm introduced by
6064
[Wolpert
@@ -89,12 +93,17 @@ When training a machine bound to such an instance:
8993
model will optimize the squared error.
9094
9195
- `resampling`: The resampling strategy used
92-
to prepare out-of-sample predictions of the base learners.
96+
to prepare out-of-sample predictions of the base learners.
9397
94-
- `measures`: A measure or iterable over measures, to perform an internal
98+
- `measures`: A measure or iterable over measures, to perform an internal
9599
evaluation of the learners in the Stack while training. This is not for the
96100
evaluation of the Stack itself.
97101
102+
- `cache`: Whether machines created in the learning network will cache data or not.
103+
104+
- `acceleration`: A supported `AbstractResource` to define the training parallelization
105+
mode of the stack.
106+
98107
- `name1=model1, name2=model2, ...`: the `Supervised` model instances
99108
to be used as base learners. The provided names become properties
100109
of the instance created to allow hyper-parameter access
@@ -139,15 +148,15 @@ evaluate!(mach; resampling=Holdout(), measure=rmse)
139148
140149
```
141150
142-
The internal evaluation report can be accessed like this
151+
The internal evaluation report can be accessed like this
143152
and provides a PerformanceEvaluation object for each model:
144153
145154
```julia
146155
report(mach).cv_report
147156
```
148157
149158
"""
150-
function Stack(;metalearner=nothing, resampling=CV(), measure=nothing, measures=measure, named_models...)
159+
function Stack(;metalearner=nothing, resampling=CV(), measure=nothing, measures=measure, cache=true, acceleration=CPU1(), named_models...)
151160
metalearner === nothing &&
152161
throw(ArgumentError("No metalearner specified. Use Stack(metalearner=...)"))
153162

@@ -159,9 +168,9 @@ function Stack(;metalearner=nothing, resampling=CV(), measure=nothing, measures=
159168
end
160169

161170
if metalearner isa Deterministic
162-
stack = DeterministicStack(modelnames, models, metalearner, resampling, measures)
171+
stack = DeterministicStack(modelnames, models, metalearner, resampling, measures, cache, acceleration)
163172
elseif metalearner isa Probabilistic
164-
stack = ProbabilisticStack(modelnames, models, metalearner, resampling, measures)
173+
stack = ProbabilisticStack(modelnames, models, metalearner, resampling, measures, cache, acceleration)
165174
else
166175
throw(ArgumentError("The metalearner should be a subtype
167176
of $(Union{Deterministic, Probabilistic})"))
@@ -202,13 +211,16 @@ function MMI.clean!(stack::Stack{modelnames, inp_scitype, tg_scitype}) where {mo
202211
end
203212

204213

205-
Base.propertynames(::Stack{modelnames}) where modelnames = tuple(:resampling, :metalearner, modelnames...)
214+
Base.propertynames(::Stack{modelnames}) where modelnames =
215+
tuple(:metalearner, :resampling, :measures, :cache, :acceleration, modelnames...)
206216

207217

208218
function Base.getproperty(stack::Stack{modelnames}, name::Symbol) where modelnames
209219
name === :metalearner && return getfield(stack, :metalearner)
210220
name === :resampling && return getfield(stack, :resampling)
211221
name == :measures && return getfield(stack, :measures)
222+
name === :cache && return getfield(stack, :cache)
223+
name == :acceleration && return getfield(stack, :acceleration)
212224
models = getfield(stack, :models)
213225
for j in eachindex(modelnames)
214226
name === modelnames[j] && return models[j]
@@ -221,6 +233,8 @@ function Base.setproperty!(stack::Stack{modelnames}, _name::Symbol, val) where m
221233
_name === :metalearner && return setfield!(stack, :metalearner, val)
222234
_name === :resampling && return setfield!(stack, :resampling, val)
223235
_name === :measures && return setfield!(stack, :measures, val)
236+
_name === :cache && return setfield!(stack, :cache, val)
237+
_name === :acceleration && return setfield!(stack, :acceleration, val)
224238
idx = findfirst(==(_name), modelnames)
225239
idx isa Nothing || return getfield(stack, :models)[idx] = val
226240
error("type Stack has no property $name")
@@ -272,7 +286,7 @@ internal_stack_report(m::Stack, verbosity::Int, tt_pairs, folds_evaluations::Var
272286
"""
273287
internal_stack_report(m::Stack, verbosity::Int, y::AbstractNode, folds_evaluations::Vararg{AbstractNode})
274288
275-
When measure/measures is provided, the folds_evaluation will have been filled by `store_for_evaluation`. This function is
289+
When measure/measures is provided, the folds_evaluation will have been filled by `store_for_evaluation`. This function is
276290
not doing any heavy work (not constructing nodes corresponding to measures) but just unpacking all the folds_evaluations in a single node that
277291
can be evaluated later.
278292
"""
@@ -304,10 +318,10 @@ function internal_stack_report(stack::Stack{modelnames,}, verbosity::Int, tt_pai
304318
fitted_params_per_fold=[],
305319
report_per_fold=[],
306320
train_test_pairs=tt_pairs
307-
)
321+
)
308322
for model in getfield(stack, :models)]
309323
)
310-
324+
311325
# Update the results
312326
index = 1
313327
for foldid in 1:nfolds
@@ -330,7 +344,7 @@ function internal_stack_report(stack::Stack{modelnames,}, verbosity::Int, tt_pai
330344
end
331345

332346
# Update per_fold
333-
model_results.per_fold[i][foldid] =
347+
model_results.per_fold[i][foldid] =
334348
reports_each_observation(measure) ? MLJBase.aggregate(loss, measure) : loss
335349
end
336350
index += 1
@@ -366,7 +380,7 @@ end
366380
oos_set(m::Stack, folds::AbstractNode, Xs::Source, ys::Source)
367381
368382
This function builds the out-of-sample dataset that is later used by the `judge`
369-
for its own training. It also returns the folds_evaluations object if internal
383+
for its own training. It also returns the folds_evaluations object if internal
370384
cross-validation results are requested.
371385
"""
372386
function oos_set(m::Stack, Xs::Source, ys::Source, tt_pairs)
@@ -384,7 +398,7 @@ function oos_set(m::Stack, Xs::Source, ys::Source, tt_pairs)
384398
# predictions are subsequently used as an input to the metalearner
385399
Zfold = []
386400
for model in getfield(m, :models)
387-
mach = machine(model, Xtrain, ytrain)
401+
mach = machine(model, Xtrain, ytrain, cache=m.cache)
388402
ypred = predict(mach, Xtest)
389403
# Internal evaluation on the fold if required
390404
push!(folds_evaluations, store_for_evaluation(mach, Xtest, ytest, m.measures))
@@ -417,15 +431,15 @@ function fit(m::Stack, verbosity::Int, X, y)
417431

418432
Xs = source(X)
419433
ys = source(y)
420-
434+
421435
Zval, yval, folds_evaluations = oos_set(m, Xs, ys, tt_pairs)
422436

423-
metamach = machine(m.metalearner, Zval, yval)
437+
metamach = machine(m.metalearner, Zval, yval, cache=m.cache)
424438

425439
# Each model is retrained on the original full training set
426440
Zpred = []
427441
for model in getfield(m, :models)
428-
mach = machine(model, Xs, ys)
442+
mach = machine(model, Xs, ys, cache=m.cache)
429443
ypred = predict(mach, Xs)
430444
ypred = pre_judge_transform(ypred, typeof(model), target_scitype(model))
431445
push!(Zpred, ypred)
@@ -438,6 +452,6 @@ function fit(m::Stack, verbosity::Int, X, y)
438452

439453
# We can infer the Surrogate by two calls to supertype
440454
mach = machine(supertype(supertype(typeof(m)))(), Xs, ys; predict=ŷ, internal_report...)
441-
442-
return!(mach, m, verbosity)
455+
456+
return!(mach, m, verbosity, acceleration=m.acceleration)
443457
end

test/Project.toml

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
[deps]
2+
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
3+
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
4+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
5+
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
6+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
7+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
8+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
9+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
11+
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
12+
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
13+
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
14+
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
15+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
16+
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
17+
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
18+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
19+
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
20+
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
21+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
22+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
23+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
24+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
25+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
26+
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

test/composition/learning_networks/machines.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,9 @@ zhat = inverse_transform(standM, uhat)
155155
yhat = exp(zhat)
156156
enode = @node mae(ys, yhat)
157157

158-
@testset "replace method for learning network machines" begin
158+
@testset "replace method for learning network machines, acceleration: $(typeof(accel))" for accel in (CPU1(), CPUThreads())
159159

160-
fit!(yhat, verbosity=0)
160+
fit!(yhat, verbosity=0, acceleration=accel)
161161

162162
# test nested reporting:
163163
r = MLJBase.report(yhat)
@@ -199,7 +199,7 @@ enode = @node mae(ys, yhat)
199199
knnM2 = machines(yhat2, knn) |> first
200200
hotM2 = machines(yhat2, hot) |> first
201201

202-
@test_mach_sequence(fit!(yhat2, force=true),
202+
@test_mach_sequence(fit!(yhat2, force=true, acceleration=accel),
203203
[(:train, standM2), (:train, hotM2),
204204
(:train, knnM2), (:train, oakM2)],
205205
[(:train, hotM2), (:train, standM2),
@@ -218,7 +218,7 @@ enode = @node mae(ys, yhat)
218218
# this change should trigger retraining of all machines except the
219219
# univariate standardizer:
220220
hot2.drop_last = true
221-
@test_mach_sequence(fit!(yhat2),
221+
@test_mach_sequence(fit!(yhat2, acceleration=accel),
222222
[(:skip, standM2), (:update, hotM2),
223223
(:train, knnM2), (:train, oakM2)],
224224
[(:update, hotM2), (:skip, standM2),

0 commit comments

Comments
 (0)