diff --git a/examples/caffe/mnist/lenet_train_test.prototxt b/examples/caffe/mnist/lenet_train_test.prototxt index 4a76cbd95..d4063dbbf 100644 --- a/examples/caffe/mnist/lenet_train_test.prototxt +++ b/examples/caffe/mnist/lenet_train_test.prototxt @@ -11,26 +11,24 @@ layer { scale: 0.00390625 } data_param { - source: "mnist_train_lmdb" + source: "train.lmdb" batch_size: 64 backend: LMDB } } layer { name: "mnist" - type: "Data" + type: "MemoryData" top: "data" top: "label" include { phase: TEST } - transform_param { - scale: 0.00390625 - } - data_param { - source: "mnist_test_lmdb" - batch_size: 100 - backend: LMDB + memory_data_param { + batch_size: 64 + channels: 1 + height: 28 + width: 28 } } layer { diff --git a/src/caffelib.cc b/src/caffelib.cc index 0b7c32319..aa1757945 100644 --- a/src/caffelib.cc +++ b/src/caffelib.cc @@ -437,7 +437,10 @@ namespace dd else lparam = net_param.add_layer(); // training loss lparam->set_name("loss"); if (regression) - lparam->set_type("EuclideanLoss"); + { + lparam->set_type("EuclideanLoss"); + lparam->add_include()->set_phase(caffe::TRAIN); + } else lparam->set_type("SoftmaxWithLoss"); lparam->add_bottom(last_ip); lparam->add_bottom("label"); @@ -1222,9 +1225,7 @@ namespace dd && (solver->iter_ > 0 || solver->param_.test_initialization())) { if (!_net) - { - _net = new Net(this->_mlmodel._def,caffe::TEST); //TODO: this is loading deploy file, we could use the test net when it exists and if its source is memory data - } + _net = new Net(this->_mlmodel._trainf,caffe::TEST); //XXX: needs to be memory data input layer _net->ShareTrainedLayersWith(solver->net().get()); APIData meas_out; test(_net,ad,inputc,test_batch_size,has_mean_file,meas_out); diff --git a/src/caffemodel.cc b/src/caffemodel.cc index a6c0e7590..2e7e16bce 100644 --- a/src/caffemodel.cc +++ b/src/caffemodel.cc @@ -40,6 +40,7 @@ namespace dd else { _def = ad.get("def").get(); + _trainf = ad.get("trainf").get(); _weights = ad.get("weights").get(); _corresp = ad.get("corresp").get(); _solver = ad.get("solver").get(); @@ -51,6 +52,7 @@ namespace dd int CaffeModel::read_from_repository(const std::string &repo) { static std::string deploy = "deploy.prototxt"; + static std::string train = ".prototxt"; static std::string weights = ".caffemodel"; static std::string sstate = ".solverstate"; static std::string corresp = "corresp"; @@ -63,7 +65,7 @@ namespace dd LOG(ERROR) << "error reading or listing caffe models in repository " << repo << std::endl; return 1; } - std::string deployf,weightsf,correspf,solverf,sstatef; + std::string deployf,trainf,weightsf,correspf,solverf,sstatef; long int state_t=-1, weight_t=-1; auto hit = lfiles.begin(); while(hit!=lfiles.end()) @@ -100,9 +102,12 @@ namespace dd deployf = (*hit); else if ((*hit).find(solver)!=std::string::npos) solverf = (*hit); + else if ((*hit).find(train)!=std::string::npos) + trainf = (*hit); ++hit; } _def = deployf; + _trainf = trainf; _weights = weightsf; _corresp = correspf; _solver = solverf; diff --git a/src/caffemodel.h b/src/caffemodel.h index 6c9aced5e..e10ae11dc 100644 --- a/src/caffemodel.h +++ b/src/caffemodel.h @@ -50,6 +50,7 @@ namespace dd } std::string _def; /**< file name of the model definition in the form of a protocol buffer message description. */ + std::string _trainf; /**< file name of the training model definition. */ std::string _weights; /**< file name of the network's weights. */ std::string _corresp; /**< file name of the class correspondences (e.g. house / 23) */ std::unordered_map _hcorresp; /**< table of class correspondences. */ diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 187829c28..6d09f69df 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -46,13 +46,15 @@ if (GTEST_FOUND) set(MNIST_EXAMPLE_TEST_ARCHIVE "mnist_test_lmdb.tar.bz2") set(MNIST_EXAMPLE_TRAIN_OUT "mnist_train_lmdb") set(MNIST_EXAMPLE_TEST_OUT "mnist_test_lmdb") + set(MNIST_EXAMPLE_TRAIN_OUT_T "train.lmdb") + set(MNIST_EXAMPLE_TEST_OUT_T "test.lmdb") if (NOT EXISTS "${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TRAIN_ARCHIVE}") file(DOWNLOAD "http://www.deepdetect.com/dd/examples/caffe/mnist/mnist_train_lmdb.tar.bz2" "${MNIST_EXAMPLE_PATH}/${MNIST_EXAMPLE_TRAIN_ARCHIVE}") file(DOWNLOAD "http://www.deepdetect.com/dd/examples/caffe/mnist/mnist_test_lmdb.tar.bz2" "${MNIST_EXAMPLE_PATH}/${MNIST_EXAMPLE_TEST_ARCHIVE}") execute_process(COMMAND ${CMAKE_COMMAND} -E tar xvjf ${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TRAIN_ARCHIVE} -C ${MNIST_EXAMPLE_PATH}) # XXX: output directory doesn't work here, maybe because of special cmake handling execute_process(COMMAND ${CMAKE_COMMAND} -E tar xvjf ${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TEST_ARCHIVE} -C ${MNIST_EXAMPLE_PATH}) - execute_process(COMMAND ${CMAKE_COMMAND} -E rename ${MNIST_EXAMPLE_TRAIN_OUT} ${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TRAIN_OUT}) - execute_process(COMMAND ${CMAKE_COMMAND} -E rename ${MNIST_EXAMPLE_TEST_OUT} ${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TEST_OUT}) + execute_process(COMMAND ${CMAKE_COMMAND} -E rename ${MNIST_EXAMPLE_TRAIN_OUT} ${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TRAIN_OUT_T}) + execute_process(COMMAND ${CMAKE_COMMAND} -E rename ${MNIST_EXAMPLE_TEST_OUT} ${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TEST_OUT_T}) endif() set(FOREST_EXAMPLE_PATH "examples/all/forest_type/")