@@ -1610,6 +1610,80 @@ TEST(torchapi, service_train_images_ctc_native)
1610
1610
fileops::remove_dir (resnet18_ocr_train_repo + " test_0.lmdb" );
1611
1611
}
1612
1612
1613
+ TEST (torchapi, service_train_ctc_native_bw)
1614
+ {
1615
+ // Just check that there are no errors when training in black&white
1616
+ setenv (" CUBLAS_WORKSPACE_CONFIG" , " :4096:8" , true );
1617
+ torch::manual_seed (torch_seed);
1618
+ at::globalContext ().setDeterministicCuDNN (true );
1619
+
1620
+ // Create service
1621
+ JsonAPI japi;
1622
+ std::string sname = " imgserv" ;
1623
+ std::string jstr
1624
+ = " {\" mllib\" :\" torch\" ,\" description\" :\" image\" ,\" type\" :"
1625
+ " \" supervised\" ,\" model\" :{\" repository\" :\" "
1626
+ + resnet18_ocr_train_repo
1627
+ + " \" },\" parameters\" :{\" input\" :{\" connector\" :\" image\" ,"
1628
+ " \" width\" :112,\" height\" :32,\" bw\" :true,\" db\" :true,\" ctc\" :true},"
1629
+ " \" mllib\" :{\" template\" :\" crnn\" ,\" gpu\" :true,\" timesteps\" :128}}}" ;
1630
+ std::string joutstr = japi.jrender (japi.service_create (sname, jstr));
1631
+ ASSERT_EQ (created_str, joutstr);
1632
+
1633
+ // Train (few iterations)
1634
+ std::string jtrainstr
1635
+ = " {\" service\" :\" imgserv\" ,\" async\" :false,\" parameters\" :{"
1636
+ " \" mllib\" :{\" solver\" :{\" iterations\" :3,\" base_lr\" :1e-4"
1637
+ " ,\" iter_size\" :4,\" solver_type\" :\" ADAM\" ,\" test_"
1638
+ " interval\" :200},\" net\" :{\" batch_size\" :32},"
1639
+ " \" resume\" :false,\" mirror\" :false,\" rotate\" :false,"
1640
+ " \" geometry\" :{\" prob\" :0.1,\" persp_horizontal\" :"
1641
+ " false,\" persp_vertical\" :false,\" zoom_in\" :true,\" zoom_out\" :true,"
1642
+ " \" pad_mode\" :\" constant\" },\" noise\" :{\" prob\" :0.01},\" distort\" :{"
1643
+ " \" prob\" :0.01},\" dataloader_threads\" :4},"
1644
+ " \" input\" :{\" seed\" :12345,\" db\" :true,\" shuffle\" :true},"
1645
+ " \" output\" :{\" measure\" :[\" acc\" ]}},\" data\" :[\" "
1646
+ + ocr_train_data + " \" ,\" " + ocr_test_data + " \" ]}" ;
1647
+ joutstr = japi.jrender (japi.service_train (jtrainstr));
1648
+ JDoc jd;
1649
+ std::cout << " joutstr=" << joutstr << std::endl;
1650
+ jd.Parse <rapidjson::kParseNanAndInfFlag >(joutstr.c_str ());
1651
+ ASSERT_TRUE (!jd.HasParseError ());
1652
+ ASSERT_EQ (201 , jd[" status" ][" code" ]);
1653
+
1654
+ // predict
1655
+ std::string jpredictstr = " {\" service\" :\" imgserv\" ,\" parameters\" :{"
1656
+ " \" output\" :{\" best\" :1,\" ctc\" :true}},"
1657
+ " \" data\" :[\" "
1658
+ + ocr_test_image + " \" ]}" ;
1659
+
1660
+ joutstr = japi.jrender (japi.service_predict (jpredictstr));
1661
+ jd = JDoc ();
1662
+ std::cout << " joutstr=" << joutstr << std::endl;
1663
+ jd.Parse <rapidjson::kParseNanAndInfFlag >(joutstr.c_str ());
1664
+ ASSERT_TRUE (!jd.HasParseError ());
1665
+ ASSERT_EQ (200 , jd[" status" ][" code" ]);
1666
+
1667
+ // remove files
1668
+ std::unordered_set<std::string> lfiles;
1669
+ fileops::list_directory (resnet18_ocr_train_repo, true , false , false , lfiles);
1670
+ ASSERT_TRUE (
1671
+ fileops::file_exists (resnet18_ocr_train_repo + " checkpoint-3.npt" ));
1672
+ for (std::string ff : lfiles)
1673
+ {
1674
+ if (ff.find (" checkpoint" ) != std::string::npos
1675
+ || ff.find (" solver" ) != std::string::npos)
1676
+ remove (ff.c_str ());
1677
+ }
1678
+ ASSERT_TRUE (
1679
+ !fileops::file_exists (resnet18_ocr_train_repo + " checkpoint-3.npt" ));
1680
+
1681
+ fileops::clear_directory (resnet18_ocr_train_repo + " train.lmdb" );
1682
+ fileops::clear_directory (resnet18_ocr_train_repo + " test_0.lmdb" );
1683
+ fileops::remove_dir (resnet18_ocr_train_repo + " train.lmdb" );
1684
+ fileops::remove_dir (resnet18_ocr_train_repo + " test_0.lmdb" );
1685
+ }
1686
+
1613
1687
TEST (torchapi, service_publish_trained_model)
1614
1688
{
1615
1689
setenv (" CUBLAS_WORKSPACE_CONFIG" , " :4096:8" , true );
0 commit comments