Skip to content

Commit 2e9e589

Browse files
wielandbrendel authored and jonasrauber committed
Fixed MXNet interface (#42)
* added aux_states to mxnet * fix coverage * fixed over-indentation
1 parent 60dd526 commit 2e9e589

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

foolbox/models/mxnet.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ class MXNetModel(DifferentiableModel):
1414
The input to the model.
1515
logits : `mxnet.symbol.Symbol`
1616
The predictions of the model, before the softmax.
17-
weights : `dictionary mapping str to mxnet.nd.array`
18-
The weights of the model.
19-
device : `mxnet.context.Context`
17+
args : `dictionary mapping str to mxnet.nd.array`
18+
The parameters of the model.
19+
ctx : `mxnet.context.Context`
2020
The device, e.g. mxnet.cpu() or mxnet.gpu().
2121
num_classes : int
2222
The number of classes.
@@ -25,6 +25,8 @@ class MXNetModel(DifferentiableModel):
2525
(0, 1) or (0, 255).
2626
channel_axis : int
2727
The index of the axis that represents color channels.
28+
aux_states : `dictionary mapping str to mxnet.nd.array`
29+
The states of auxiliary parameters of the model.
2830
preprocessing: 2-element tuple with floats or numpy arrays
2931
Elementwise preprocessing of input; we first subtract the first
3032
element of preprocessing from the input and then divide the input by
@@ -36,11 +38,12 @@ def __init__(
3638
self,
3739
data,
3840
logits,
39-
weights,
40-
device,
41+
args,
42+
ctx,
4143
num_classes,
4244
bounds,
4345
channel_axis=1,
46+
aux_states=None,
4447
preprocessing=(0, 1)):
4548

4649
super(MXNetModel, self).__init__(
@@ -52,7 +55,7 @@ def __init__(
5255

5356
self._num_classes = num_classes
5457

55-
self._device = device
58+
self._device = ctx
5659

5760
self._data_sym = data
5861
self._batch_logits_sym = logits
@@ -63,9 +66,18 @@ def __init__(
6366
loss = mx.symbol.softmax_cross_entropy(logits, label)
6467
self._loss_sym = loss
6568

66-
weight_names = list(weights.keys())
67-
weight_arrays = [weights[name] for name in weight_names]
68-
self._args_map = dict(zip(weight_names, weight_arrays))
69+
self._args_map = args.copy()
70+
self._aux_map = aux_states.copy() if aux_states is not None else None
71+
72+
# move all parameters to correct device
73+
for k in self._args_map.keys():
74+
self._args_map[k] = \
75+
self._args_map[k].as_in_context(ctx) # pragma: no cover
76+
77+
if aux_states is not None:
78+
for k in self._aux_map.keys(): # pragma: no cover
79+
self._aux_map[k] = \
80+
self._aux_map[k].as_in_context(ctx) # pragma: no cover
6981

7082
def num_classes(self):
7183
return self._num_classes
@@ -76,7 +88,8 @@ def batch_predictions(self, images):
7688
data_array = mx.nd.array(images, ctx=self._device)
7789
self._args_map[self._data_sym.name] = data_array
7890
model = self._batch_logits_sym.bind(
79-
ctx=self._device, args=self._args_map, grad_req='null')
91+
ctx=self._device, args=self._args_map, grad_req='null',
92+
aux_states=self._aux_map)
8093
model.forward(is_train=False)
8194
logits_array = model.outputs[0]
8295
logits = logits_array.asnumpy()
@@ -99,7 +112,8 @@ def predictions_and_gradient(self, image, label):
99112
ctx=self._device,
100113
args=self._args_map,
101114
args_grad=grad_map,
102-
grad_req='write')
115+
grad_req='write',
116+
aux_states=self._aux_map)
103117
model.forward(is_train=True)
104118
logits_array = model.outputs[0]
105119
model.backward([
@@ -119,7 +133,8 @@ def _loss_fn(self, image, label):
119133
self._args_map[self._data_sym.name] = data_array
120134
self._args_map[self._label_sym.name] = label_array
121135
model = self._loss_sym.bind(
122-
ctx=self._device, args=self._args_map, grad_req='null')
136+
ctx=self._device, args=self._args_map, grad_req='null',
137+
aux_states=self._aux_map)
123138
model.forward(is_train=False)
124139
loss_array = model.outputs[0]
125140
loss = loss_array.asnumpy()[0]

foolbox/tests/test_models_mxnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def mean_brightness_net(images):
2121
images,
2222
logits,
2323
{},
24-
device=mx.cpu(),
24+
ctx=mx.cpu(),
2525
num_classes=num_classes,
2626
bounds=bounds,
2727
channel_axis=1)
@@ -68,7 +68,7 @@ def mean_brightness_net(images):
6868
images,
6969
logits,
7070
{},
71-
device=mx.cpu(),
71+
ctx=mx.cpu(),
7272
num_classes=num_classes,
7373
bounds=bounds,
7474
preprocessing=preprocessing,

0 commit comments

Comments
 (0)