LSTMWithJava.java
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.Random;
public class LSTMWithJava {

    public static void main( String[] args ) throws Exception {
        int lstmLayerSize = 200;                    //Number of units in each GravesLSTM layer
        int miniBatchSize = 32;                     //Size of mini batch to use when training
        int exampleLength = 1000;                   //Length of each training example sequence to use
        int tbpttLength = 50;                       //Length for truncated backpropagation through time
        int numEpochs = 1;                          //Total number of training epochs
        int generateSamplesEveryNMinibatches = 10;  //How frequently to generate samples from the network
        int nSamplesToGenerate = 4;                 //Number of samples to generate after each training epoch
        int nCharactersToSample = 300;              //Length of each sample to generate
        String generationInitialization = null;     //Optional character initialization; a random character is used if null
        // Enable GPU usage
        CudaEnvironment.getInstance().getConfiguration()
            // key option enabled
            .allowMultiGPU(true)
            // we're allowing larger memory caches
            .setMaximumDeviceCache(2L * 1024L * 1024L * 1024L)
            // cross-device access is used for faster model averaging over PCIe
            .allowCrossDeviceAccess(true);
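        // Note: CudaEnvironment comes from the nd4j-cuda (JITA) backend; this block
        // assumes a CUDA backend is on the classpath and would be removed for CPU-only runs.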
        Random rng = new Random(12345);

        //Get a DataSetIterator that handles vectorization of text into something we can use to train
        //our GravesLSTM network.
        CharacterIterator iter = getShakespeareIterator(miniBatchSize, exampleLength);
        int nOut = iter.totalOutcomes();
        //Set up network configuration:
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .learningRate(0.1)
            .rmsDecay(0.95)
            .seed(12345)
            .regularization(true)
            .l2(0.001)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.RMSPROP)
            .list()
            .layer(0, new GravesLSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize)
                .activation(Activation.TANH).build())
            .layer(1, new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
                .activation(Activation.TANH).build())
            .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX) //MCXENT + softmax for classification
                .nIn(lstmLayerSize).nOut(nOut).build())
            .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
            .pretrain(false).backprop(true)
            .build();
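        // Resulting architecture: one-hot character input (nIn = iter.inputColumns())
        // -> GravesLSTM(200) -> GravesLSTM(200) -> softmax RnnOutputLayer over nOut characters,
        // trained with truncated BPTT in segments of tbpttLength = 50 time steps.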
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(1), new IterationListener() {
            @Override
            public boolean invoked() {
                return true;
            }

            @Override
            public void invoke() {
            }

            @Override
            public void iterationDone(Model model, int iteration) {
                System.out.println("--------------------");
                System.out.println("Sampling characters from network given initialization \"" + (generationInitialization == null ? "" : generationInitialization) + "\"");
                String[] samples = sampleCharactersFromNetwork(generationInitialization, (MultiLayerNetwork) model, iter, rng, nCharactersToSample, nSamplesToGenerate);
                for( int j = 0; j < samples.length; j++ ){
                    System.out.println("----- Sample " + j + " -----");
                    System.out.println(samples[j]);
                    System.out.println();
                }
            }
        });
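        // Note: iterationDone(...) fires after every parameter update, so samples are
        // printed on every iteration; generateSamplesEveryNMinibatches is declared above
        // but not consulted by this listener.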
        //Print the number of parameters in the network (and for each layer)
        Layer[] layers = net.getLayers();
        int totalNumParams = 0;
        for( int i = 0; i < layers.length; i++ ){
            int nParams = layers[i].numParams();
            System.out.println("Number of parameters in layer " + i + ": " + nParams);
            totalNumParams += nParams;
        }
        System.out.println("Total number of network parameters: " + totalNumParams);
        // ParallelWrapper will take care of load balancing between GPUs.
        ParallelWrapper wrapper = new ParallelWrapper.Builder(net)
            // DataSet prefetching options. Set this value with respect to the number of actual devices
            .prefetchBuffer(24)
            // set number of workers equal to or higher than the number of available devices. x1-x2 are good values to start with
            .workers(4)
            // rare averaging improves performance, but might reduce model accuracy
            .averagingFrequency(3)
            // if set to TRUE, the model score will be reported after every averaging
            .reportScoreAfterAveraging(true)
            // optional parameter; set to false ONLY if your system supports P2P memory access across PCIe (hint: AWS does not support P2P)
            .useLegacyAveraging(true)
            .build();
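        // fit(...) replicates the model across the workers, trains each replica on its own
        // mini-batches, and averages parameters every averagingFrequency iterations.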
        wrapper.fit(iter);

        System.out.println("\n\nExample complete");
    }
    /** Downloads Shakespeare training data and stores it locally (temp directory). Then sets up and returns a simple
     * DataSetIterator that does vectorization based on the text.
     * @param miniBatchSize Number of text segments in each training mini-batch
     * @param sequenceLength Number of characters in each text segment
     */
    public static CharacterIterator getShakespeareIterator(int miniBatchSize, int sequenceLength) throws Exception {
        //The Complete Works of William Shakespeare
        //5.3MB file in UTF-8 encoding, ~5.4 million characters
        String url = "https://s3.amazonaws.com/dl4j-distribution/pg100.txt";
        String tempDir = System.getProperty("java.io.tmpdir");
        String fileLocation = tempDir + "/Data/Shakespeare.txt"; //Storage location for the downloaded file
        File f = new File(fileLocation);
        if( !f.exists() ){
            FileUtils.copyURLToFile(new URL(url), f);
            System.out.println("File downloaded to " + f.getAbsolutePath());
        } else {
            System.out.println("Using existing text file at " + f.getAbsolutePath());
        }
        if( !f.exists() ) throw new IOException("File does not exist: " + fileLocation); //Download problem?

        char[] validCharacters = CharacterIterator.getMinimalCharacterSet(); //Which characters are allowed? Others will be removed
        return new CharacterIterator(fileLocation, Charset.forName("UTF-8"),
            miniBatchSize, sequenceLength, validCharacters, new Random(12345));
    }
    /** Generate a sample from the network, given an (optional, possibly null) initialization. Initialization
     * can be used to 'prime' the RNN with a sequence you want to extend/continue.
     * Note that the initialization is used for all samples.
     * @param initialization String, may be null. If null, select a random character as initialization for all samples
     * @param net MultiLayerNetwork with one or more GravesLSTM/RNN layers and a softmax output layer
     * @param iter CharacterIterator. Used for going from indexes back to characters
     * @param rng Random number generator used when sampling
     * @param charactersToSample Number of characters to sample from network (excluding initialization)
     * @param numSamples Number of samples to generate in parallel
     */
    private static String[] sampleCharactersFromNetwork(String initialization, MultiLayerNetwork net,
                                                        CharacterIterator iter, Random rng, int charactersToSample, int numSamples ){
        //Set up initialization. If no initialization: use a random character
        if( initialization == null ){
            initialization = String.valueOf(iter.getRandomCharacter());
        }

        //Create input for initialization
        INDArray initializationInput = Nd4j.zeros(numSamples, iter.inputColumns(), initialization.length());
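        // Shape [numSamples, inputColumns, initLength]: one one-hot encoded copy of the
        // initialization string per sample to be generated.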
        char[] init = initialization.toCharArray();
        for( int i = 0; i < init.length; i++ ){
            int idx = iter.convertCharacterToIndex(init[i]);
            for( int j = 0; j < numSamples; j++ ){
                initializationInput.putScalar(new int[]{j, idx, i}, 1.0f);
            }
        }

        StringBuilder[] sb = new StringBuilder[numSamples];
        for( int i = 0; i < numSamples; i++ ) sb[i] = new StringBuilder(initialization);

        //Sample from network (and feed samples back into input) one character at a time (for all samples)
        //Sampling is done in parallel here
        net.rnnClearPreviousState();
        INDArray output = net.rnnTimeStep(initializationInput);
        output = output.tensorAlongDimension(output.size(2)-1, 1, 0); //Gets the last time step output
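        // rnnTimeStep returns activations of shape [numSamples, nOut, initLength]; the
        // tensor along dimensions (1,0) at index size(2)-1 is the [numSamples, nOut]
        // output distribution for the final time step.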
        for( int i = 0; i < charactersToSample; i++ ){
            //Set up next input (single time step) by sampling from previous output
            INDArray nextInput = Nd4j.zeros(numSamples, iter.inputColumns());
            //Output is a probability distribution. Sample from this for each example we want to generate, and add it to the new input
            for( int s = 0; s < numSamples; s++ ){
                double[] outputProbDistribution = new double[iter.totalOutcomes()];
                for( int j = 0; j < outputProbDistribution.length; j++ ) outputProbDistribution[j] = output.getDouble(s, j);
                int sampledCharacterIdx = sampleFromDistribution(outputProbDistribution, rng);

                nextInput.putScalar(new int[]{s, sampledCharacterIdx}, 1.0f);       //Prepare next time step input
                sb[s].append(iter.convertIndexToCharacter(sampledCharacterIdx));    //Add sampled character to StringBuilder (human readable output)
            }

            output = net.rnnTimeStep(nextInput); //Do one time step of forward pass
        }

        String[] out = new String[numSamples];
        for( int i = 0; i < numSamples; i++ ) out[i] = sb[i].toString();
        return out;
    }
    /** Given a probability distribution over discrete classes, sample from the distribution
     * and return the generated class index.
     * @param distribution Probability distribution over classes. Must sum to 1.0
     */
    public static int sampleFromDistribution( double[] distribution, Random rng ){
        double d = rng.nextDouble();
        double sum = 0.0;
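        // Inverse-CDF sampling: e.g. for distribution {0.2, 0.5, 0.3} and d = 0.65,
        // the running sums are 0.2 then 0.7; 0.65 <= 0.7, so index 1 is returned.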
        for( int i = 0; i < distribution.length; i++ ){
            sum += distribution[i];
            if( d <= sum ) return i;
        }
        //Should never happen if distribution is a valid probability distribution
        throw new IllegalArgumentException("Distribution is invalid? d=" + d + ", sum=" + sum);
    }
}
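// Note: this class also depends on the CharacterIterator helper from the DL4J examples
// (not shown in this file), which handles one-hot vectorization of the text corpus and
// the character/index mappings used above.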