Skip to content

Commit

Permalink
mnist fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ivan-vasilev committed Apr 25, 2014
1 parent c54e424 commit 8f20c3c
Showing 1 changed file with 3 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,15 @@

import com.amd.aparapi.Kernel.EXECUTION_MODE;
import com.github.neuralnetworks.architecture.NeuralNetworkImpl;
import com.github.neuralnetworks.architecture.types.Autoencoder;
import com.github.neuralnetworks.architecture.types.NNFactory;
import com.github.neuralnetworks.architecture.types.RBM;
import com.github.neuralnetworks.input.MultipleNeuronsOutputError;
import com.github.neuralnetworks.input.ScalingInputFunction;
import com.github.neuralnetworks.samples.mnist.MnistInputProvider;
import com.github.neuralnetworks.training.Trainer;
import com.github.neuralnetworks.training.TrainerFactory;
import com.github.neuralnetworks.training.backpropagation.BackPropagationTrainer;
import com.github.neuralnetworks.training.events.LogTrainingListener;
import com.github.neuralnetworks.training.random.MersenneTwisterRandomInitializer;
import com.github.neuralnetworks.training.random.NNRandomInitializer;
import com.github.neuralnetworks.training.rbm.AparapiCDTrainer;
import com.github.neuralnetworks.util.Environment;

/**
Expand Down Expand Up @@ -72,44 +68,6 @@ public void testSigmoidHiddenBP() {
assertEquals(0, bpt.getOutputError().getTotalNetworkError(), 0.1);
}

@Test
public void testRBM() {
RBM rbm = NNFactory.rbm(784, 10, false);
MnistInputProvider trainInputProvider = new MnistInputProvider("train-images.idx3-ubyte", "train-labels.idx1-ubyte");
trainInputProvider.addInputModifier(new ScalingInputFunction(255));
MnistInputProvider testInputProvider = new MnistInputProvider("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte");
testInputProvider.addInputModifier(new ScalingInputFunction(255));

AparapiCDTrainer t = TrainerFactory.cdSigmoidBinaryTrainer(rbm, trainInputProvider, testInputProvider, new MultipleNeuronsOutputError(), new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.01f, 0.5f, 0f, 0f, 1, 1, 1, false);

t.addEventListener(new LogTrainingListener(Thread.currentThread().getStackTrace()[1].getMethodName(), false, true));
Environment.getInstance().setExecutionMode(EXECUTION_MODE.CPU);
t.train();
t.test();

assertEquals(0, t.getOutputError().getTotalNetworkError(), 0.8);
}

@Test
public void testAE() {
Autoencoder nn = NNFactory.autoencoderSigmoid(784, 10, true);

MnistInputProvider trainInputProvider = new MnistInputProvider("train-images.idx3-ubyte", "train-labels.idx1-ubyte");
trainInputProvider.addInputModifier(new ScalingInputFunction(255));
MnistInputProvider testInputProvider = new MnistInputProvider("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte");
testInputProvider.addInputModifier(new ScalingInputFunction(255));

Trainer<?> t = TrainerFactory.backPropagationAutoencoder(nn, trainInputProvider, testInputProvider, new MultipleNeuronsOutputError(), new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.01f, 0.5f, 0f, 0f, 0f, 1, 1000, 1);

t.addEventListener(new LogTrainingListener(Thread.currentThread().getStackTrace()[1].getMethodName(), false, true));
Environment.getInstance().setExecutionMode(EXECUTION_MODE.CPU);
t.train();
nn.removeLayer(nn.getOutputLayer());
t.test();

assertEquals(0, t.getOutputError().getTotalNetworkError(), 0.1);
}

@Test
public void testLeNetSmall() {
// cpu execution mode
Expand Down Expand Up @@ -180,6 +138,9 @@ public void testLeNetTiny() {
*/
@Test
public void testLeNetTiny2() {
Environment.getInstance().setUseDataSharedMemory(false);
Environment.getInstance().setUseWeightsSharedMemory(false);

// very simple convolutional network with a single convolutional layer with 6 5x5 filters and a single 2x2 max pooling layer
NeuralNetworkImpl nn = NNFactory.convNN(new int[][] { { 28, 28, 1 }, { 5, 5, 6, 1 }, {2, 2}, {10} }, true);
nn.setLayerCalculator(NNFactory.lcSigmoid(nn, null));
Expand Down

0 comments on commit 8f20c3c

Please sign in to comment.