Skip to content

Commit 2b07002

Browse files
Bycob authored and mergify[bot] committed
fix(torch): black&white image now working with crnn & dataaug
1 parent c675876 commit 2b07002

File tree

5 files changed

+99
-5
lines changed

5 files changed

+99
-5
lines changed

src/backends/torch/native/templates/crnn.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,8 @@ namespace dd
306306

307307
if (_timesteps > 0)
308308
{
309-
at::Tensor dummy_img = torch::zeros({ 1, 3, _img_height, _img_width });
309+
at::Tensor dummy_img
310+
= torch::zeros({ 1, _input_channels, _img_height, _img_width });
310311
at::Tensor dummy = _backbone(dummy_img).reshape({ 1, -1, _timesteps });
311312
output_channel = dummy.size(1);
312313
// XXX should use logger

src/backends/torch/torchdataaug.cc

+16
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,15 @@ namespace dd
804804
if (_noise_params._prob == 0.0)
805805
return;
806806

807+
// sanity check
808+
bool img_is_bw = src.channels() == 1;
809+
if (img_is_bw
810+
&& (_noise_params._hist_eq || _noise_params._decolorize
811+
|| _noise_params._jpg || _noise_params._convert_to_hsv
812+
|| _noise_params._convert_to_lab))
813+
throw std::runtime_error(
814+
"Image has one channel when 3 channel dataaug is enabled");
815+
807816
if (_noise_params._rgb)
808817
{
809818
cv::Mat bgr;
@@ -847,6 +856,13 @@ namespace dd
847856
if (_distort_params._prob == 0.0)
848857
return;
849858

859+
bool img_is_bw = src.channels() == 1;
860+
if (img_is_bw
861+
&& (_distort_params._saturation || _distort_params._hue
862+
|| _distort_params._channel_order))
863+
throw std::runtime_error(
864+
"Image has one channel when 3 channel dataaug is enabled");
865+
850866
if (_distort_params._rgb)
851867
{
852868
cv::Mat bgr;

src/backends/torch/torchdataaug.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ namespace dd
156156
class NoiseParams
157157
{
158158
public:
159-
NoiseParams()
159+
// Builds the noise parameters.
// With the default bw = false, the flags initialized below (_hist_eq,
// _decolorize, _jpg, _convert_to_hsv, _convert_to_lab) are all set to true;
// with bw = true they are all set to false.
// NOTE(review): these appear to be the same flags that the noise sanity
// check in torchdataaug.cc rejects for single-channel images — confirm that
// the list here and the list in that check stay in sync.
NoiseParams(bool bw = false)
160+
: _hist_eq(!bw), _decolorize(!bw), _jpg(!bw), _convert_to_hsv(!bw),
161+
_convert_to_lab(!bw)
160162
{
161163
}
162164

@@ -192,7 +194,8 @@ namespace dd
192194
class DistortParams
193195
{
194196
public:
195-
DistortParams()
197+
// Builds the distortion parameters.
// With the default bw = false, _saturation, _hue and _channel_order are all
// set to true; with bw = true they are all set to false.
// NOTE(review): these appear to be the same flags that the distort sanity
// check in torchdataaug.cc rejects for single-channel images — confirm that
// the list here and the list in that check stay in sync.
DistortParams(bool bw = false)
198+
: _saturation(!bw), _hue(!bw), _channel_order(!bw)
196199
{
197200
}
198201

src/backends/torch/torchlib.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -756,15 +756,15 @@ namespace dd
756756
ad_geometry.get("pad_mode").get<std::string>());
757757
}
758758
auto *img_ic = reinterpret_cast<ImgTorchInputFileConn *>(&inputc);
759-
NoiseParams noise_params;
759+
NoiseParams noise_params(img_ic->_bw);
760760
noise_params._rgb = img_ic->_rgb;
761761
APIData ad_noise = ad_mllib.getobj("noise");
762762
if (!ad_noise.empty())
763763
{
764764
noise_params._prob = ad_noise.get("prob").get<double>();
765765
this->_logger->info("noise: {}", noise_params._prob);
766766
}
767-
DistortParams distort_params;
767+
DistortParams distort_params(img_ic->_bw);
768768
distort_params._rgb = img_ic->_rgb;
769769
APIData ad_distort = ad_mllib.getobj("distort");
770770
if (!ad_distort.empty())

tests/ut-torchapi.cc

+74
Original file line numberDiff line numberDiff line change
@@ -1610,6 +1610,80 @@ TEST(torchapi, service_train_images_ctc_native)
16101610
fileops::remove_dir(resnet18_ocr_train_repo + "test_0.lmdb");
16111611
}
16121612

1613+
// Smoke test: trains the native CRNN template on single-channel input
// ("bw":true) with geometry, noise and distort data augmentation enabled,
// then runs one prediction. It only checks that training and prediction
// return success codes and that a checkpoint is written; it does not check
// accuracy.
TEST(torchapi, service_train_ctc_native_bw)
{
  // Just check that there are no errors when training in black&white
  setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
  torch::manual_seed(torch_seed);
  at::globalContext().setDeterministicCuDNN(true);

  // Create service ("bw":true requests single-channel images)
  JsonAPI japi;
  std::string sname = "imgserv";
  std::string jstr
      = "{\"mllib\":\"torch\",\"description\":\"image\",\"type\":"
        "\"supervised\",\"model\":{\"repository\":\""
        + resnet18_ocr_train_repo
        + "\"},\"parameters\":{\"input\":{\"connector\":\"image\","
          "\"width\":112,\"height\":32,\"bw\":true,\"db\":true,\"ctc\":true},"
          "\"mllib\":{\"template\":\"crnn\",\"gpu\":true,\"timesteps\":128}}}";
  std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
  ASSERT_EQ(created_str, joutstr);

  // Train (few iterations); noise and distort are given a non-zero
  // probability so that the data augmentation code is not bypassed.
  std::string jtrainstr
      = "{\"service\":\"imgserv\",\"async\":false,\"parameters\":{"
        "\"mllib\":{\"solver\":{\"iterations\":3,\"base_lr\":1e-4"
        ",\"iter_size\":4,\"solver_type\":\"ADAM\",\"test_"
        "interval\":200},\"net\":{\"batch_size\":32},"
        "\"resume\":false,\"mirror\":false,\"rotate\":false,"
        "\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
        "false,\"persp_vertical\":false,\"zoom_in\":true,\"zoom_out\":true,"
        "\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
        "\"prob\":0.01},\"dataloader_threads\":4},"
        "\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true},"
        "\"output\":{\"measure\":[\"acc\"]}},\"data\":[\""
        + ocr_train_data + "\",\"" + ocr_test_data + "\"]}";
  joutstr = japi.jrender(japi.service_train(jtrainstr));
  JDoc jd;
  std::cout << "joutstr=" << joutstr << std::endl;
  jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
  ASSERT_TRUE(!jd.HasParseError());
  ASSERT_EQ(201, jd["status"]["code"]);

  // predict
  std::string jpredictstr = "{\"service\":\"imgserv\",\"parameters\":{"
                            "\"output\":{\"best\":1,\"ctc\":true}},"
                            "\"data\":[\""
                            + ocr_test_image + "\"]}";

  joutstr = japi.jrender(japi.service_predict(jpredictstr));
  jd = JDoc();
  std::cout << "joutstr=" << joutstr << std::endl;
  jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
  ASSERT_TRUE(!jd.HasParseError());
  ASSERT_EQ(200, jd["status"]["code"]);

  // remove files: the 3-iteration run must have produced checkpoint-3.npt,
  // which is then deleted along with the solver state.
  std::unordered_set<std::string> lfiles;
  fileops::list_directory(resnet18_ocr_train_repo, true, false, false, lfiles);
  ASSERT_TRUE(
      fileops::file_exists(resnet18_ocr_train_repo + "checkpoint-3.npt"));
  // Iterate by const reference: the paths are only read, no copy is needed.
  for (const std::string &ff : lfiles)
    {
      if (ff.find("checkpoint") != std::string::npos
          || ff.find("solver") != std::string::npos)
        remove(ff.c_str());
    }
  ASSERT_FALSE(
      fileops::file_exists(resnet18_ocr_train_repo + "checkpoint-3.npt"));

  // Drop the LMDB databases created for this run.
  fileops::clear_directory(resnet18_ocr_train_repo + "train.lmdb");
  fileops::clear_directory(resnet18_ocr_train_repo + "test_0.lmdb");
  fileops::remove_dir(resnet18_ocr_train_repo + "train.lmdb");
  fileops::remove_dir(resnet18_ocr_train_repo + "test_0.lmdb");
}
1686+
16131687
TEST(torchapi, service_publish_trained_model)
16141688
{
16151689
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);

0 commit comments

Comments
 (0)