
Commit e6d0b76

[snake-dqn] Initial commit (tensorflow#265)
This PR checks in the following parts of the Snake game-based DQN example:

- snake_game.js: Game logic (without any graphics)
- dqn.js: DQN network definition, along with some utility functions
- replay_memory.js: The replay buffer used for DQN training
- agent.js: The agent based on the epsilon-greedy algorithm
- train.js: Training logic

All modules are accompanied by unit tests.
1 parent 59b606d commit e6d0b76

17 files changed: +8309 -0 lines changed

snake-dqn/.babelrc

+18
@@ -0,0 +1,18 @@
{
  "presets": [
    [
      "env",
      {
        "esmodules": false,
        "targets": {
          "browsers": [
            "> 3%"
          ]
        }
      }
    ]
  ],
  "plugins": [
    "transform-runtime"
  ]
}

snake-dqn/README.md

+61
@@ -0,0 +1,61 @@
# Using Deep Q-Learning to Solve the Snake Game

Deep Q-Learning is a reinforcement-learning (RL) algorithm. It is frequently
used to solve arcade-style games such as the Snake game in this example.

## The Snake game

The Snake game is a grid-world action game in which the player controls a
virtual snake that keeps moving on the game board (9x9 by default). At each
step, there are four possible actions: left, right, up, and down. To achieve
higher scores (rewards), the player should guide the snake to the fruits on
the screen and "eat" them, while avoiding

- its head going off the board, and
- its head bumping into its own body.
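
Below is a minimal sketch of driving the game logic with random actions. It is
illustrative only and assumes the `SnakeGame` API as it is used elsewhere in
this PR (constructor options as in `agent_test.js`, and a `step()` method
returning `{state, reward, done}` as in `agent.js`).

```js
import {SnakeGame, getRandomAction} from './snake_game';

// A 9x9 board with one fruit and an initial snake length of 2.
const game = new SnakeGame({height: 9, width: 9, numFruits: 1, initLen: 2});

let totalReward = 0;
while (true) {
  // Pick one of the four actions (left, right, up, down) at random.
  const action = getRandomAction();
  const {reward, done} = game.step(action);
  totalReward += reward;
  if (done) {
    // The game ends when the snake goes off the board or bumps into its body.
    break;
  }
}
console.log(`Game over. Total reward: ${totalReward}`);
```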

This example consists of two parts:

1. Training the Deep Q-Network (DQN) in Node.js
2. Live demo in the browser

## Training the Deep Q-Network in Node.js

To train the DQN, use the command:

```sh
yarn
yarn train
```
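
For orientation, here is a simplified, hypothetical sketch of the kind of loop
the training script drives, written against the `SnakeGameAgent` API checked
in with this PR; the hyperparameter values are illustrative only, not
necessarily the ones `train.js` uses.

```js
import * as tf from '@tensorflow/tfjs-node';

import {SnakeGameAgent} from './agent';
import {SnakeGame} from './snake_game';

const game = new SnakeGame({height: 9, width: 9, numFruits: 1, initLen: 2});
const agent = new SnakeGameAgent(game, {
  replayBufferSize: 10000,
  epsilonInit: 0.5,
  epsilonFinal: 0.01,
  epsilonDecayFrames: 100000,
  learningRate: 1e-3
});

// Fill the replay buffer before training starts.
for (let i = 0; i < agent.replayBufferSize; ++i) {
  agent.playStep();
}

// Interleave game play with training on randomly sampled replay batches.
const optimizer = tf.train.adam(1e-3);
for (let step = 0; step < 10000; ++step) {
  agent.trainOnReplayBatch(64, 0.99, optimizer);
  const {cumulativeReward, done} = agent.playStep();
  if (done) {
    console.log(`Game over. Cumulative reward: ${cumulativeReward}`);
  }
  // train.js also periodically copies the online network's weights to the
  // frozen target network; that step is omitted from this sketch.
}
```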

If you have a CUDA-enabled GPU installed on your system, along with all the
required drivers and libraries, append the `--gpu` flag to the command above
to use the GPU for training, which will lead to a significant increase in
training speed:

```sh
yarn train --gpu
```

To monitor the training progress using TensorBoard, use the `--logDir` flag
and point it to a log directory, e.g.,

```sh
yarn train --logDir /tmp/snake_logs
```

During the training, you can use TensorBoard to visualize curves such as:

- the cumulative reward values from the games,
- the training speed (game frames per second), and
- the value of epsilon from the epsilon-greedy algorithm.

Specifically, open a separate terminal. In that terminal, install TensorBoard
and launch its backend server:

```sh
pip install tensorboard
tensorboard --logdir /tmp/snake_logs
```

Once started, the TensorBoard backend process will print an `http://` URL to
the console. Open your browser and navigate to the URL to see the logged
curves.
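
One of the logged curves is the value of epsilon itself. A minimal sketch of
the linear decay schedule implemented in `agent.js` (the `epsilonAtFrame`
helper and the numbers below are introduced here purely for illustration):

```js
// Epsilon decreases linearly from epsilonInit to epsilonFinal over
// epsilonDecayFrames frames, then stays at epsilonFinal.
function epsilonAtFrame(frameCount, epsilonInit, epsilonFinal, epsilonDecayFrames) {
  const increment = (epsilonFinal - epsilonInit) / epsilonDecayFrames;
  return frameCount >= epsilonDecayFrames ?
      epsilonFinal :
      epsilonInit + increment * frameCount;
}

console.log(epsilonAtFrame(0, 1, 0.1, 10));   // 1 (fully random at the start)
console.log(epsilonAtFrame(5, 1, 0.1, 10));   // ≈ 0.55 (halfway through the decay)
console.log(epsilonAtFrame(20, 1, 0.1, 10));  // 0.1 (stays at the final value)
```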

snake-dqn/agent.js

+156
@@ -0,0 +1,156 @@
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '@tensorflow/tfjs';

import {createDeepQNetwork} from './dqn';
import {getRandomAction, SnakeGame, NUM_ACTIONS, ALL_ACTIONS, getStateTensor} from './snake_game';
import {ReplayMemory} from './replay_memory';
import {assertPositiveInteger} from './utils';

export class SnakeGameAgent {
  /**
   * Constructor of SnakeGameAgent.
   *
   * @param {SnakeGame} game A game object.
   * @param {object} config The configuration object with the following keys:
   *   - `replayBufferSize` {number} Size of the replay memory. Must be a
   *     positive integer.
   *   - `epsilonInit` {number} Initial value of epsilon (for the epsilon-
   *     greedy algorithm). Must be >= 0 and <= 1.
   *   - `epsilonFinal` {number} The final value of epsilon. Must be >= 0 and
   *     <= 1.
   *   - `epsilonDecayFrames` {number} The # of frames over which the value of
   *     `epsilon` decreases from `epsilonInit` to `epsilonFinal`, via a linear
   *     schedule.
   */
  constructor(game, config) {
    assertPositiveInteger(config.epsilonDecayFrames);

    this.game = game;

    this.epsilonInit = config.epsilonInit;
    this.epsilonFinal = config.epsilonFinal;
    this.epsilonDecayFrames = config.epsilonDecayFrames;
    this.epsilonIncrement_ = (this.epsilonFinal - this.epsilonInit) /
        this.epsilonDecayFrames;

    this.onlineNetwork =
        createDeepQNetwork(game.height, game.width, NUM_ACTIONS);
    this.targetNetwork =
        createDeepQNetwork(game.height, game.width, NUM_ACTIONS);
    // Freeze target network: its weights are updated only through copying from
    // the online network.
    this.targetNetwork.trainable = false;

    this.optimizer = tf.train.adam(config.learningRate);

    this.replayBufferSize = config.replayBufferSize;
    this.replayMemory = new ReplayMemory(config.replayBufferSize);
    this.frameCount = 0;
    this.reset();
  }

  reset() {
    this.cumulativeReward_ = 0;
    this.game.reset();
  }

  /**
   * Play one step of the game.
   *
   * @returns {object} An object with the following keys:
   *   - `action` {number} The action taken at this step.
   *   - `cumulativeReward` {number} The cumulative reward of the game so far.
   *   - `done` {boolean} Whether this step ends the game.
   */
  playStep() {
    this.epsilon = this.frameCount >= this.epsilonDecayFrames ?
        this.epsilonFinal :
        this.epsilonInit + this.epsilonIncrement_ * this.frameCount;
    this.frameCount++;

    // The epsilon-greedy algorithm.
    let action;
    const state = this.game.getState();
    if (Math.random() < this.epsilon) {
      // Pick an action at random.
      action = getRandomAction();
    } else {
      // Greedily pick an action based on online DQN output.
      tf.tidy(() => {
        const stateTensor =
            getStateTensor(state, this.game.height, this.game.width);
        action = ALL_ACTIONS[
            this.onlineNetwork.predict(stateTensor).argMax(-1).dataSync()[0]];
      });
    }

    const {state: nextState, reward, done} = this.game.step(action);

    this.replayMemory.append([state, action, reward, done, nextState]);

    this.cumulativeReward_ += reward;
    const output = {
      action,
      cumulativeReward: this.cumulativeReward_,
      done
    };
    if (done) {
      this.reset();
    }
    return output;
  }

  /**
   * Perform training on a randomly sampled batch from the replay buffer.
   *
   * @param {number} batchSize Batch size.
   * @param {number} gamma Reward discount rate. Must be >= 0 and <= 1.
   * @param {tf.train.Optimizer} optimizer The optimizer object used to update
   *   the weights of the online network.
   */
  trainOnReplayBatch(batchSize, gamma, optimizer) {
    // Get a batch of examples from the replay buffer.
    const batch = this.replayMemory.sample(batchSize);
    const lossFunction = () => tf.tidy(() => {
      const stateTensor = getStateTensor(
          batch.map(example => example[0]), this.game.height, this.game.width);
      const actionTensor = tf.tensor1d(
          batch.map(example => example[1]), 'int32');
      const qs = this.onlineNetwork.predict(
          stateTensor).mul(tf.oneHot(actionTensor, NUM_ACTIONS)).sum(-1);

      const rewardTensor = tf.tensor1d(batch.map(example => example[2]));
      const nextStateTensor = getStateTensor(
          batch.map(example => example[4]), this.game.height, this.game.width);
      const nextMaxQTensor =
          this.targetNetwork.predict(nextStateTensor).max(-1);
      const doneMask = tf.scalar(1).sub(
          tf.tensor1d(batch.map(example => example[3])).asType('float32'));
      // Bellman target: reward + gamma * max_a' Q_target(s', a'), with the
      // next-state term zeroed out for terminal (done) states.
      const targetQs =
          rewardTensor.add(nextMaxQTensor.mul(doneMask).mul(gamma));
      return tf.losses.meanSquaredError(targetQs, qs);
    });

    // TODO(cais): Remove the second argument when `variableGrads()` obeys the
    //   trainable flag.
    const grads =
        tf.variableGrads(lossFunction, this.onlineNetwork.getWeights());
    optimizer.applyGradients(grads.grads);
    tf.dispose(grads);
    // TODO(cais): Return the loss value here?
  }
}

snake-dqn/agent_test.js

+122
@@ -0,0 +1,122 @@
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '@tensorflow/tfjs-node';

import {SnakeGameAgent} from "./agent";
import {SnakeGame} from "./snake_game";

describe('SnakeGameAgent', () => {
  it('playStep', () => {
    const game = new SnakeGame({
      height: 9,
      width: 9,
      numFruits: 1,
      initLen: 2
    });
    const agent = new SnakeGameAgent(game, {
      replayBufferSize: 100,
      epsilonInit: 1,
      epsilonFinal: 0.1,
      epsilonDecayFrames: 10
    });

    const numGames = 40;
    let bufferIndex = 0;
    for (let n = 0; n < numGames; ++n) {
      // At the beginning of a game, the cumulative reward ought to be 0.
      expect(agent.cumulativeReward_).toEqual(0);
      let out = null;
      let outPrev = null;
      for (let m = 0; m < 10; ++m) {
        const currentState = agent.game.getState();
        out = agent.playStep();
        // Check the content of the replay buffer.
        expect(agent.replayMemory.buffer[bufferIndex % 100][0])
            .toEqual(currentState);
        expect(agent.replayMemory.buffer[bufferIndex % 100][1])
            .toEqual(out.action);
        expect(agent.replayMemory.buffer[bufferIndex % 100][2]).toEqual(
            outPrev == null ? out.cumulativeReward :
            out.cumulativeReward - outPrev.cumulativeReward);
        expect(agent.replayMemory.buffer[bufferIndex % 100][3]).toEqual(out.done);
        expect(agent.replayMemory.buffer[bufferIndex % 100][4])
            .toEqual(out.done ? undefined : agent.game.getState());
        bufferIndex++;
        if (out.done) {
          break;
        }
        outPrev = out;
      }
      agent.reset();
    }
  });

  it('trainOnReplayBatch', () => {
    const game = new SnakeGame({
      height: 9,
      width: 9,
      numFruits: 1,
      initLen: 2
    });
    const replayBufferSize = 1000;
    const agent = new SnakeGameAgent(game, {
      replayBufferSize,
      epsilonInit: 1,
      epsilonFinal: 0.1,
      epsilonDecayFrames: 1000,
      learningRate: 1e-2
    });

    const oldOnlineWeights =
        agent.onlineNetwork.getWeights().map(x => x.dataSync());
    const oldTargetWeights =
        agent.targetNetwork.getWeights().map(x => x.dataSync());

    for (let i = 0; i < replayBufferSize; ++i) {
      agent.playStep();
    }
    // Burn-in run for memory leak check below.
    const batchSize = 512;
    const gamma = 0.99;
    const optimizer = tf.train.adam();
    agent.trainOnReplayBatch(batchSize, gamma, optimizer);

    const numTensors0 = tf.memory().numTensors;
    agent.trainOnReplayBatch(batchSize, gamma, optimizer);
    expect(tf.memory().numTensors).toEqual(numTensors0);

    const newOnlineWeights =
        agent.onlineNetwork.getWeights().map(x => x.dataSync());
    const newTargetWeights =
        agent.targetNetwork.getWeights().map(x => x.dataSync());

    // Verify that the online network's weights are updated.
    for (let i = 0; i < oldOnlineWeights.length; ++i) {
      expect(tf.tensor1d(newOnlineWeights[i])
          .sub(tf.tensor1d(oldOnlineWeights[i]))
          .abs().max().arraySync()).toBeGreaterThan(0);
    }
    // Verify that the target network's weights have not changed.
    for (let i = 0; i < oldOnlineWeights.length; ++i) {
      expect(tf.tensor1d(newTargetWeights[i])
          .sub(tf.tensor1d(oldTargetWeights[i]))
          .abs().max().arraySync()).toEqual(0);
    }
  });
});
