Skip to content

Commit 4328570

Browse files
committed
fix tutorial
1 parent 71f0f0d commit 4328570

File tree

2 files changed

+156
-1
lines changed

2 files changed

+156
-1
lines changed

docs/src/tutorials/warcraft.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,4 @@ final_gap = compute_gap(b, test_dataset, model, maximizer)
8585
#
8686
θ = model(x)
8787
y = maximizer(θ)
88-
plot_data(b, DataSample(; x, θ, y))
88+
plot_data(b, DataSample(; x, θ_true=θ, y_true=y))

docs/src/warcraft.md

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
```@meta
2+
EditURL = "tutorials/warcraft.jl"
3+
```
4+
5+
# Path-finding on image maps
6+
7+
In this tutorial, we showcase DecisionFocusedLearningBenchmarks.jl capabilities on one of its main benchmarks: the Warcraft benchmark.
8+
This benchmark problem is a simple path-finding problem where the goal is to find the shortest path between the top left and bottom right corners of a given image map.
9+
The map is represented as a 2D image representing a 12x12 grid, each cell having an unknown travel cost depending on the terrain type.
10+
11+
First, let's load the package and create a benchmark object as follows:
12+
13+
````@example warcraft
14+
using DecisionFocusedLearningBenchmarks
15+
b = WarcraftBenchmark()
16+
````
17+
18+
## Dataset generation
19+
20+
These benchmark objects behave as generators that can generate various needed elements in order to build an algorithm to tackle the problem.
21+
First of all, all benchmarks are capable of generating datasets as needed, using the [`generate_dataset`](@ref) method.
22+
This method takes as input the benchmark object for which the dataset is to be generated, and a second argument specifying the number of samples to generate:
23+
24+
````@example warcraft
25+
dataset = generate_dataset(b, 50);
26+
nothing #hide
27+
````
28+
29+
We obtain a vector of [`DataSample`](@ref) objects, containing all needed data for the problem.
30+
Subdatasets can be created through regular slicing:
31+
32+
````@example warcraft
33+
train_dataset, test_dataset = dataset[1:45], dataset[46:50]
34+
````
35+
36+
And getting an individual sample will return a [`DataSample`](@ref) with four fields: `x`, `instance`, `θ`, and `y`:
37+
38+
````@example warcraft
39+
sample = test_dataset[1]
40+
````
41+
42+
`x` correspond to the input features, i.e. the input image (3D array) in the Warcraft benchmark case:
43+
44+
````@example warcraft
45+
x = sample.x
46+
````
47+
48+
`θ_true` correspond to the true unknown terrain weights. We use the opposite of the true weights in order to formulate the optimization problem as a maximization problem:
49+
50+
````@example warcraft
51+
θ_true = sample.θ_true
52+
````
53+
54+
`y_true` correspond to the optimal shortest path, encoded as a binary matrix:
55+
56+
````@example warcraft
57+
y_true = sample.y_true
58+
````
59+
60+
`instance` is not used in this benchmark, therefore set to nothing:
61+
62+
````@example warcraft
63+
isnothing(sample.instance)
64+
````
65+
66+
For some benchmarks, we provide the following plotting method [`plot_data`](@ref) to visualize the data:
67+
68+
````@example warcraft
69+
plot_data(b, sample)
70+
````
71+
72+
We can see here the terrain image, the true terrain weights, and the true shortest path avoiding the high cost cells.
73+
74+
## Building a pipeline
75+
76+
DecisionFocusedLearningBenchmarks also provides methods to build an hybrid machine learning and combinatorial optimization pipeline for the benchmark.
77+
First, the [`generate_statistical_model`](@ref) method generates a machine learning predictor to predict cell weights from the input image:
78+
79+
````@example warcraft
80+
model = generate_statistical_model(b)
81+
````
82+
83+
In the case of the Warcraft benchmark, the model is a convolutional neural network built using the Flux.jl package.
84+
85+
````@example warcraft
86+
θ = model(x)
87+
````
88+
89+
Note that the model is not trained yet, and its parameters are randomly initialized.
90+
91+
Finally, the [`generate_maximizer`](@ref) method can be used to generate a combinatorial optimization algorithm that takes the predicted cell weights as input and returns the corresponding shortest path:
92+
93+
````@example warcraft
94+
maximizer = generate_maximizer(b; dijkstra=true)
95+
````
96+
97+
In the case o fthe Warcraft benchmark, the method has an additional keyword argument to chose the algorithm to use: Dijkstra's algorithm or Bellman-Ford algorithm.
98+
99+
````@example warcraft
100+
y = maximizer(θ)
101+
````
102+
103+
As we can see, currently the pipeline predicts random noise as cell weights, and therefore the maximizer returns a straight line path.
104+
105+
````@example warcraft
106+
plot_data(b, DataSample(; x, θ, y))
107+
````
108+
109+
We can evaluate the current pipeline performance using the optimality gap metric:
110+
111+
````@example warcraft
112+
starting_gap = compute_gap(b, test_dataset, model, maximizer)
113+
````
114+
115+
## Using a learning algorithm
116+
117+
We can now train the model using the InferOpt.jl package:
118+
119+
````@example warcraft
120+
using InferOpt
121+
using Flux
122+
using Plots
123+
124+
perturbed_maximizer = PerturbedMultiplicative(maximizer; ε=0.2, nb_samples=100)
125+
loss = FenchelYoungLoss(perturbed_maximizer)
126+
127+
starting_gap = compute_gap(b, test_dataset, model, maximizer)
128+
129+
opt_state = Flux.setup(Adam(1e-3), model)
130+
loss_history = Float64[]
131+
for epoch in 1:50
132+
val, grads = Flux.withgradient(model) do m
133+
sum(loss(m(sample.x), sample.y) for sample in train_dataset) / length(train_dataset)
134+
end
135+
Flux.update!(opt_state, model, grads[1])
136+
push!(loss_history, val)
137+
end
138+
139+
plot(loss_history; xlabel="Epoch", ylabel="Loss", title="Training loss")
140+
````
141+
142+
````@example warcraft
143+
final_gap = compute_gap(b, test_dataset, model, maximizer)
144+
````
145+
146+
````@example warcraft
147+
θ = model(x)
148+
y = maximizer(θ)
149+
plot_data(b, DataSample(; x, θ, y))
150+
````
151+
152+
---
153+
154+
*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
155+

0 commit comments

Comments
 (0)