@@ -14,9 +14,9 @@ class MXNetModel(DifferentiableModel):
         The input to the model.
     logits : `mxnet.symbol.Symbol`
         The predictions of the model, before the softmax.
-    weights : `dictionary mapping str to mxnet.nd.array`
-        The weights of the model.
-    device : `mxnet.context.Context`
+    args : `dictionary mapping str to mxnet.nd.array`
+        The parameters of the model.
+    ctx : `mxnet.context.Context`
         The device, e.g. mxnet.cpu() or mxnet.gpu().
     num_classes : int
         The number of classes.
@@ -25,6 +25,8 @@ class MXNetModel(DifferentiableModel):
         (0, 1) or (0, 255).
     channel_axis : int
         The index of the axis that represents color channels.
+    aux_states : `dictionary mapping str to mxnet.nd.array`
+        The states of auxiliary parameters of the model.
     preprocessing: 2-element tuple with floats or numpy arrays
         Elementwise preprocessing of input; we first subtract the first
         element of preprocessing from the input and then divide the input by
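
For reference, a minimal construction sketch using the renamed parameters; the toy symbol, parameter names, and shapes below are illustrative only, not part of this change:

    import mxnet as mx
    import numpy as np
    from foolbox.models import MXNetModel

    # toy two-class linear model; names and shapes are hypothetical
    data = mx.symbol.Variable('data')
    logits = mx.symbol.FullyConnected(data, num_hidden=2, name='fc')

    args = {
        'fc_weight': mx.nd.array(np.random.randn(2, 784).astype(np.float32)),
        'fc_bias': mx.nd.zeros((2,)),
    }

    model = MXNetModel(
        data=data,
        logits=logits,
        args=args,        # previously `weights`
        ctx=mx.cpu(),     # previously `device`
        num_classes=2,
        bounds=(0, 1),
        aux_states=None)  # new: needed e.g. for BatchNorm statistics
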
@@ -36,11 +38,12 @@ def __init__(
             self,
             data,
             logits,
-            weights,
-            device,
+            args,
+            ctx,
             num_classes,
             bounds,
             channel_axis=1,
+            aux_states=None,
             preprocessing=(0, 1)):

         super(MXNetModel, self).__init__(
@@ -52,7 +55,7 @@ def __init__(

         self._num_classes = num_classes

-        self._device = device
+        self._device = ctx

         self._data_sym = data
         self._batch_logits_sym = logits
@@ -63,9 +66,18 @@ def __init__(
         loss = mx.symbol.softmax_cross_entropy(logits, label)
         self._loss_sym = loss

-        weight_names = list(weights.keys())
-        weight_arrays = [weights[name] for name in weight_names]
-        self._args_map = dict(zip(weight_names, weight_arrays))
+        self._args_map = args.copy()
+        self._aux_map = aux_states.copy() if aux_states is not None else None
+
+        # move all parameters to correct device
+        for k in self._args_map.keys():
+            self._args_map[k] = \
+                self._args_map[k].as_in_context(ctx)  # pragma: no cover
+
+        if aux_states is not None:
+            for k in self._aux_map.keys():  # pragma: no cover
+                self._aux_map[k] = \
+                    self._aux_map[k].as_in_context(ctx)  # pragma: no cover

     def num_classes(self):
         return self._num_classes
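
The `as_in_context` calls are cheap when a parameter already lives on the target device: MXNet returns the array itself and only copies across devices. A small sketch of that behavior, assuming a CPU-only setup:

    import mxnet as mx

    x = mx.nd.ones((2, 2))         # allocated on the default device, mx.cpu()
    y = x.as_in_context(mx.cpu())  # same context: returns x itself, no copy
    assert y is x
    # x.as_in_context(mx.gpu()) would instead return a copy in GPU memory
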
@@ -76,7 +88,8 @@ def batch_predictions(self, images):
         data_array = mx.nd.array(images, ctx=self._device)
         self._args_map[self._data_sym.name] = data_array
         model = self._batch_logits_sym.bind(
-            ctx=self._device, args=self._args_map, grad_req='null')
+            ctx=self._device, args=self._args_map, grad_req='null',
+            aux_states=self._aux_map)
         model.forward(is_train=False)
         logits_array = model.outputs[0]
         logits = logits_array.asnumpy()
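
Threading `aux_states` through `bind` matters for symbols with auxiliary states, most commonly the running mean and variance of BatchNorm, which are needed at inference time but are not trainable arguments. A sketch of where such states show up:

    import mxnet as mx

    data = mx.symbol.Variable('data')
    bn = mx.symbol.BatchNorm(data, name='bn')
    print(bn.list_arguments())         # ['data', 'bn_gamma', 'bn_beta']
    print(bn.list_auxiliary_states())  # ['bn_moving_mean', 'bn_moving_var']
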
@@ -99,7 +112,8 @@ def predictions_and_gradient(self, image, label):
             ctx=self._device,
             args=self._args_map,
             args_grad=grad_map,
-            grad_req='write')
+            grad_req='write',
+            aux_states=self._aux_map)
         model.forward(is_train=True)
         logits_array = model.outputs[0]
         model.backward([
@@ -119,7 +133,8 @@ def _loss_fn(self, image, label):
         self._args_map[self._data_sym.name] = data_array
         self._args_map[self._label_sym.name] = label_array
         model = self._loss_sym.bind(
-            ctx=self._device, args=self._args_map, grad_req='null')
+            ctx=self._device, args=self._args_map, grad_req='null',
+            aux_states=self._aux_map)
         model.forward(is_train=False)
         loss_array = model.outputs[0]
         loss = loss_array.asnumpy()[0]
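
Continuing the hypothetical toy model from the construction sketch above, the updated methods are called exactly as before:

    import numpy as np

    images = np.zeros((1, 784), dtype=np.float32)  # batch of one flat input
    logits = model.batch_predictions(images)
    print(logits.shape)  # expected: (1, 2) for the two-class toy model
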