-
Notifications
You must be signed in to change notification settings - Fork 0
/
export_model.py
140 lines (113 loc) · 5.79 KB
/
export_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
import torchvision
from Chars import *
from OCRModels import *
from TorchScriptModels import *
# Target device shared by the export functions in this file: each model is
# moved here before it is scripted/traced and saved.
device = torch.device("cuda")
def export_rcnn_model():
    """Export the Faster R-CNN detector checkpoint as a TorchScript file.

    Builds a ``fasterrcnn_resnet50_fpn`` model, loads the weights stored under
    the ``'model'`` key of ``models/FasterRCNN_last_checkpoint.pt``, wraps the
    model in ``RCNNTorchScript``, scripts it on ``device`` and writes the
    result to ``models/object_detection.torchscript``.
    """
    # NOTE(review): the constructor arguments are kept exactly as they were;
    # presumably they must match the configuration the checkpoint was trained
    # with. pretrained_backbone=True may fetch backbone weights that the
    # checkpoint then overwrites -- confirm before changing it, since the flag
    # may influence the model structure in some torchvision versions.
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        pretrained=False,
        pretrained_backbone=True,
        num_classes=2,
        min_size=400,
        max_size=600,
    )

    # Load onto the CPU so the export does not depend on the device (or GPU
    # index) the checkpoint happened to be saved from; the model is moved to
    # ``device`` explicitly below.
    checkpoint = torch.load('models/FasterRCNN_last_checkpoint.pt', map_location='cpu')
    detector.load_state_dict(checkpoint['model'])

    wrapped = RCNNTorchScript(detector)
    wrapped.eval()
    wrapped.to(device)

    scripted = torch.jit.script(wrapped)
    torch.jit.save(scripted, 'models/object_detection.torchscript')
def _export_efficientnet_checkpoint(num_classes, checkpoint_path, output_path):
    """Convert one CRNNEfficientNetB3 checkpoint into a TorchScript file.

    Args:
        num_classes: First positional argument of ``CRNNEfficientNetB3``
            (the callers pass ``len(chars.chars)``).
        checkpoint_path: Checkpoint file whose ``'model'`` entry holds the
            state dict to load.
        output_path: Destination of the saved TorchScript module.
    """
    ocr_model = CRNNEfficientNetB3(num_classes, rnn_hidden=768, bidirectional=True)

    # Load onto the CPU so the export does not depend on the device (or GPU
    # index) the checkpoint was saved from; the model is moved to ``device``
    # explicitly below.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    ocr_model.load_state_dict(checkpoint['model'])

    # NOTE(review): presumably the non-memory-efficient swish is needed for the
    # backbone to be traceable -- confirm against the backbone implementation.
    ocr_model.backbone.set_swish(memory_efficient=False)
    ocr_model.to(device)
    ocr_model.eval()

    # The backbone is traced (not scripted) with a random 1x3x40x900 example
    # input; the traced module then replaces the original backbone.
    example_input = torch.rand(1, 3, 40, 900).to(device)
    ocr_model.backbone = torch.jit.trace(ocr_model.backbone, example_input)

    wrapped = OCRTorchScriptV3(ocr_model)
    wrapped.eval()
    wrapped.to(device)

    scripted = torch.jit.script(wrapped)
    torch.jit.save(scripted, output_path)


def export_ocr_efficientnet_model():
    """Export the four EfficientNet-B3 CRNN OCR models as TorchScript.

    For each character set (``SC5000Chars`` then ``TC5000Chars``) the
    character table is first written to ``models/ocr_<charset>.txt``; then the
    ``yuan`` and ``hei`` checkpoints of that character set are exported, in
    that order, to ``models/ocrV3_<charset>_<font>.torchscript``.
    """
    charsets = (
        (SC5000Chars, 'SC5000Chars'),
        (TC5000Chars, 'TC5000Chars'),
    )
    for chars_cls, charset_name in charsets:
        # Instantiate lazily so each character set is created only when its
        # exports start, as in the original straight-line code.
        chars = chars_cls()
        chars.export(f'models/ocr_{charset_name}.txt')

        for font in ('yuan', 'hei'):
            _export_efficientnet_checkpoint(
                len(chars.chars),
                f'models/ocr_v3_amp_{charset_name}_{font}/CRNNEfficientNetB3_768_bi/'
                'CRNNEfficientNetB3_768_bi_checkpoint.pt',
                f'models/ocrV3_{charset_name}_{font}.torchscript',
            )
def _export_resnext_checkpoint(num_classes, checkpoint_path, output_path):
    """Convert one CRNNResnext101 checkpoint into a TorchScript file.

    Args:
        num_classes: First positional argument of ``CRNNResnext101``
            (the callers pass ``len(chars.chars)``).
        checkpoint_path: Checkpoint file whose ``'model'`` entry holds the
            state dict to load.
        output_path: Destination of the saved TorchScript module.
    """
    ocr_model = CRNNResnext101(num_classes, rnn_hidden=1280)

    # Load onto the CPU so the export does not depend on the device (or GPU
    # index) the checkpoint was saved from; the model is moved to ``device``
    # explicitly below.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    ocr_model.load_state_dict(checkpoint['model'])

    wrapped = OCRTorchScript(ocr_model)
    wrapped.eval()
    wrapped.to(device)

    scripted = torch.jit.script(wrapped)
    torch.jit.save(scripted, output_path)


def export_ocr_resnet_model():
    """Export the four ResNeXt-101 CRNN OCR models as TorchScript.

    For each character set the ``yuan`` and ``hei`` checkpoints are exported,
    in that order, to ``models/ocr_<name>_<font>.torchscript``; afterwards the
    character table is written to ``models/ocr_<name>.txt``.
    """
    # NOTE(review): the first entry pairs SC5000Chars with checkpoints and
    # outputs named "SC3500Chars". This mirrors the original code; verify that
    # the character set really matches those checkpoints.
    charsets = (
        (SC5000Chars, 'SC3500Chars'),
        (TC3600Chars, 'TC3600Chars'),
    )
    for chars_cls, name in charsets:
        chars = chars_cls()

        for font in ('yuan', 'hei'):
            _export_resnext_checkpoint(
                len(chars.chars),
                f'models/{name}_{font}_CRNNResnext101_1280_checkpoint.pt',
                f'models/ocr_{name}_{font}.torchscript',
            )

        # The character table is written after both models of the set.
        chars.export(f'models/ocr_{name}.txt')
def export_mse():
    """Script an ``MSETorchScript`` module and save it as TorchScript.

    The module is moved to ``device``, compiled with ``torch.jit.script`` and
    written to ``models/mse.torchscript``.
    """
    module = MSETorchScript()
    module.to(device)
    torch.jit.save(torch.jit.script(module), 'models/mse.torchscript')
# Script entry point: running this file exports only the EfficientNet-based
# OCR models; the other export functions are not invoked here.
if __name__ == '__main__':
    export_ocr_efficientnet_model()