Train method fails in MultipleNeuronsOutputError.getTotalErrorSamples #40

joelself opened this issue Sep 30, 2015
I have the following code (located at the end of the issue) to create, train, and test a neural network, but it fails in MultipleNeuronsOutputError here:

    for (OutputTargetTuple t : tuples) {
        if (!outputToTarget.get(t.outputPos).equals(t.targetPos)) {
            errorSamples++;
        }
    }

If I add `outputToTarget.get(t.outputPos) != null &&` to the if statement, it finishes successfully, but it reports zero samples and thus no error value.
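
For clarity, here is a sketch of the guarded version of that loop (the local variable `mapped` is only for illustration, and I'm assuming `outputToTarget` maps output neuron positions to boxed `Integer` target positions):

    for (OutputTargetTuple t : tuples) {
        Integer mapped = outputToTarget.get(t.outputPos);
        // skip tuples whose output position has no mapping instead of calling equals on a null lookup
        if (mapped != null && !mapped.equals(t.targetPos)) {
            errorSamples++;
        }
    }

With this guard the failure (presumably a NullPointerException from calling equals on the null lookup) goes away, but the zero-samples result makes me think the lookup is returning null rather than a valid target position for every tuple.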

I've checked that the data is read in correctly, and it seems to be fine. Training works; the problem is that it fails during testing.

Also, switching to the GPU makes training a single epoch take forever; I've never actually seen it finish.
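
For reference, the GPU attempt is the same code with only the execution mode line changed; I'm showing Aparapi's EXECUTION_MODE.GPU value here, so treat it as a sketch of the switch rather than a separate repro:

        Environment.getInstance().setExecutionMode(EXECUTION_MODE.GPU);

Here is the full code: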

        Environment.getInstance().setExecutionMode(EXECUTION_MODE.SEQ);

        // create multilayer perceptron with three hidden layers and bias
        Environment.getInstance().setUseWeightsSharedMemory(false);
        Environment.getInstance().setUseDataSharedMemory(false);
        NeuralNetworkImpl mlp = NNFactory.mlpSigmoid(new int[]{40, 75, 75, 75, 10}, true);

        // create the input provider (same data is used for training and testing)
        FileReader reader;
        System.out.println("Try read data");
        List<float[][]> data = new ArrayList<float[][]>();
        try {
            reader = new FileReader("C:\\Users\\jself\\Data\\training_data.data2");
            data = GetDataFromFile(reader);
        } catch (FileNotFoundException e) {
            // don't silently swallow a failure to open the data file
            e.printStackTrace();
        }
        System.out.println("Create input provider and trainer");
        SimpleInputProvider input = new SimpleInputProvider(data.get(0), data.get(1));
        // create backpropagation trainer for the network
        BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(mlp, input, input, new MultipleNeuronsOutputError(), new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.1f, 0.7f, 0f, 0f, 0f, 1, 1, 1);

        // add logging
        bpt.addEventListener(new LogTrainingListener(Thread.currentThread().getStackTrace()[1].getMethodName()));

        // early stopping
        //bpt.addEventListener(new EarlyStoppingListener(testingInput, 10, 0.1f));

        System.out.println("Start training");
        // train
        bpt.train();

        System.out.println("Start testing");
        // test
        bpt.test();
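
GetDataFromFile isn't shown above; here is a minimal sketch of what it does, assuming each line of training_data.data2 holds 40 comma-separated input values followed by a class index that is one-hot encoded into a 10-element target vector (the parsing details are specific to my data file, so treat them as illustrative):

    // Illustrative sketch only; requires java.io.BufferedReader, java.io.IOException,
    // java.util.ArrayList, and java.util.List in addition to the imports used above.
    private static List<float[][]> GetDataFromFile(FileReader reader) {
        List<float[]> inputs = new ArrayList<float[]>();
        List<float[]> targets = new ArrayList<float[]>();
        try (BufferedReader br = new BufferedReader(reader)) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] parts = line.split(",");
                // first 40 values are the network inputs
                float[] in = new float[40];
                for (int i = 0; i < 40; i++) {
                    in[i] = Float.parseFloat(parts[i].trim());
                }
                // last value is the class label, one-hot encoded over the 10 output neurons
                float[] target = new float[10];
                target[Integer.parseInt(parts[40].trim())] = 1f;
                inputs.add(in);
                targets.add(target);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        List<float[][]> result = new ArrayList<float[][]>();
        result.add(inputs.toArray(new float[inputs.size()][]));
        result.add(targets.toArray(new float[targets.size()][]));
        return result;
    }

The important part for this issue is the shape: data.get(0) is the float[][] of inputs and data.get(1) is the float[][] of one-hot targets that SimpleInputProvider receives.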