Skip to content

Commit

Permalink
finally using the Caffe inner test net while training, and the deploy…
Browse files Browse the repository at this point in the history
… net for predictions
  • Loading branch information
beniz committed Nov 27, 2015
1 parent a062871 commit 010b4fe
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 16 deletions.
16 changes: 7 additions & 9 deletions examples/caffe/mnist/lenet_train_test.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,24 @@ layer {
scale: 0.00390625
}
data_param {
source: "mnist_train_lmdb"
source: "train.lmdb"
batch_size: 64
backend: LMDB
}
}
layer {
name: "mnist"
type: "Data"
type: "MemoryData"
top: "data"
top: "label"
include {
phase: TEST
}
transform_param {
scale: 0.00390625
}
data_param {
source: "mnist_test_lmdb"
batch_size: 100
backend: LMDB
memory_data_param {
batch_size: 64
channels: 1
height: 28
width: 28
}
}
layer {
Expand Down
9 changes: 5 additions & 4 deletions src/caffelib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,10 @@ namespace dd
else lparam = net_param.add_layer(); // training loss
lparam->set_name("loss");
if (regression)
lparam->set_type("EuclideanLoss");
{
lparam->set_type("EuclideanLoss");
lparam->add_include()->set_phase(caffe::TRAIN);
}
else lparam->set_type("SoftmaxWithLoss");
lparam->add_bottom(last_ip);
lparam->add_bottom("label");
Expand Down Expand Up @@ -1222,9 +1225,7 @@ namespace dd
&& (solver->iter_ > 0 || solver->param_.test_initialization()))
{
if (!_net)
{
_net = new Net<float>(this->_mlmodel._def,caffe::TEST); //TODO: this is loading deploy file, we could use the test net when it exists and if its source is memory data
}
_net = new Net<float>(this->_mlmodel._trainf,caffe::TEST); //XXX: needs to be memory data input layer
_net->ShareTrainedLayersWith(solver->net().get());
APIData meas_out;
test(_net,ad,inputc,test_batch_size,has_mean_file,meas_out);
Expand Down
7 changes: 6 additions & 1 deletion src/caffemodel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace dd
else
{
_def = ad.get("def").get<std::string>();
_trainf = ad.get("trainf").get<std::string>();
_weights = ad.get("weights").get<std::string>();
_corresp = ad.get("corresp").get<std::string>();
_solver = ad.get("solver").get<std::string>();
Expand All @@ -51,6 +52,7 @@ namespace dd
int CaffeModel::read_from_repository(const std::string &repo)
{
static std::string deploy = "deploy.prototxt";
static std::string train = ".prototxt";
static std::string weights = ".caffemodel";
static std::string sstate = ".solverstate";
static std::string corresp = "corresp";
Expand All @@ -63,7 +65,7 @@ namespace dd
LOG(ERROR) << "error reading or listing caffe models in repository " << repo << std::endl;
return 1;
}
std::string deployf,weightsf,correspf,solverf,sstatef;
std::string deployf,trainf,weightsf,correspf,solverf,sstatef;
long int state_t=-1, weight_t=-1;
auto hit = lfiles.begin();
while(hit!=lfiles.end())
Expand Down Expand Up @@ -100,9 +102,12 @@ namespace dd
deployf = (*hit);
else if ((*hit).find(solver)!=std::string::npos)
solverf = (*hit);
else if ((*hit).find(train)!=std::string::npos)
trainf = (*hit);
++hit;
}
_def = deployf;
_trainf = trainf;
_weights = weightsf;
_corresp = correspf;
_solver = solverf;
Expand Down
1 change: 1 addition & 0 deletions src/caffemodel.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ namespace dd
}

std::string _def; /**< file name of the model definition in the form of a protocol buffer message description. */
std::string _trainf; /**< file name of the training model definition. */
std::string _weights; /**< file name of the network's weights. */
std::string _corresp; /**< file name of the class correspondences (e.g. house / 23) */
std::unordered_map<int,std::string> _hcorresp; /**< table of class correspondences. */
Expand Down
6 changes: 4 additions & 2 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ if (GTEST_FOUND)
set(MNIST_EXAMPLE_TEST_ARCHIVE "mnist_test_lmdb.tar.bz2")
set(MNIST_EXAMPLE_TRAIN_OUT "mnist_train_lmdb")
set(MNIST_EXAMPLE_TEST_OUT "mnist_test_lmdb")
set(MNIST_EXAMPLE_TRAIN_OUT_T "train.lmdb")
set(MNIST_EXAMPLE_TEST_OUT_T "test.lmdb")
if (NOT EXISTS "${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TRAIN_ARCHIVE}")
file(DOWNLOAD "http://www.deepdetect.com/dd/examples/caffe/mnist/mnist_train_lmdb.tar.bz2" "${MNIST_EXAMPLE_PATH}/${MNIST_EXAMPLE_TRAIN_ARCHIVE}")
file(DOWNLOAD "http://www.deepdetect.com/dd/examples/caffe/mnist/mnist_test_lmdb.tar.bz2" "${MNIST_EXAMPLE_PATH}/${MNIST_EXAMPLE_TEST_ARCHIVE}")
execute_process(COMMAND ${CMAKE_COMMAND} -E tar xvjf ${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TRAIN_ARCHIVE} -C ${MNIST_EXAMPLE_PATH}) # XXX: output directory doesn't work here, maybe because of special cmake handling
execute_process(COMMAND ${CMAKE_COMMAND} -E tar xvjf ${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TEST_ARCHIVE} -C ${MNIST_EXAMPLE_PATH})
execute_process(COMMAND ${CMAKE_COMMAND} -E rename ${MNIST_EXAMPLE_TRAIN_OUT} ${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TRAIN_OUT})
execute_process(COMMAND ${CMAKE_COMMAND} -E rename ${MNIST_EXAMPLE_TEST_OUT} ${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TEST_OUT})
execute_process(COMMAND ${CMAKE_COMMAND} -E rename ${MNIST_EXAMPLE_TRAIN_OUT} ${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TRAIN_OUT_T})
execute_process(COMMAND ${CMAKE_COMMAND} -E rename ${MNIST_EXAMPLE_TEST_OUT} ${MNIST_EXAMPLE_PATH}${MNIST_EXAMPLE_TEST_OUT_T})
endif()

set(FOREST_EXAMPLE_PATH "examples/all/forest_type/")
Expand Down

0 comments on commit 010b4fe

Please sign in to comment.