Skip to content

Commit 6e5ff8b

Browse files
committed
[Enhancement] squeeze first axis of row arrays/structs
1 parent b5498b9 commit 6e5ff8b

File tree

4 files changed

+80
-36
lines changed

4 files changed

+80
-36
lines changed

spm/__wrapper__.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import warnings
22
import numpy as np
3+
from functools import partial
34
from collections.abc import (
45
MutableSequence, MutableMapping, KeysView, ValuesView, ItemsView
56
)
@@ -222,7 +223,7 @@ class MatlabType(object):
222223
"""Generic type for objects that have an exact matlab equivalent."""
223224

224225
@classmethod
225-
def from_any(cls, other):
226+
def from_any(cls, other, **kwargs):
226227
"""
227228
Convert python/matlab objects to `MatlabType` objects
228229
(`Cell`, `Struct`, `Array`, `MatlabClass`).
@@ -234,6 +235,9 @@ def from_any(cls, other):
234235
# - we do not convert to types that can be passed directly to
235236
# the matlab runtime;
236237
# - instead, we convert to python types that mimic matlab types.
238+
_from_any = partial(cls.from_any, **kwargs)
239+
_from_runtime = kwargs.pop("_from_runtime", False)
240+
237241
if isinstance(other, MatlabType):
238242
if isinstance(other, AnyDelayedArray):
239243
other._error_is_not_finalized()
@@ -278,15 +282,25 @@ def from_any(cls, other):
278282
raise ValueError("Don't know what to do with type", type__)
279283

280284
else:
285+
other = type(other)(
286+
zip(other.keys(),
287+
map(_from_any, other.values()))
288+
)
281289
return Struct.from_any(other)
282290

283291
if isinstance(other, (list, tuple, set)):
284292
# nested tuples are cells of cells, not cell arrays
285-
return Cell.from_any(other)
293+
if _from_runtime:
294+
return Cell._from_runtime(other)
295+
else:
296+
return Cell.from_any(other)
286297

287298
if isinstance(other, (np.ndarray, int, float, complex, bool)):
288299
# [array of] numbers -> Array
289-
return Array.from_any(other)
300+
if _from_runtime:
301+
return Array._from_runtime(other)
302+
else:
303+
return Array.from_any(other)
290304

291305
if isinstance(other, str):
292306
return other
@@ -305,17 +319,20 @@ def from_any(cls, other):
305319
return MatlabFunction.from_any(other)
306320

307321
if type(other) in _matlab_array_types():
308-
dtype = _matlab_array_types()[type(other)]
309-
return Array.from_any(other, dtype=dtype)
322+
return Array._from_runtime(other)
310323

311324
if hasattr(other, "__iter__"):
312325
# Iterable -> let's try to make it a cell
313-
return cls.from_any(list(other))
326+
return cls.from_any(list(other), _from_runtime=_from_runtime)
314327

315328
raise TypeError(
316329
f"Cannot convert {type(other)} into a matlab object."
317330
)
318331

332+
@classmethod
333+
def _from_runtime(cls, obj):
334+
return cls.from_any(obj, _from_runtime=True)
335+
319336
@classmethod
320337
def _to_runtime(cls, obj):
321338
"""
@@ -1157,7 +1174,10 @@ def _as_runtime(self) -> np.ndarray:
11571174

11581175
@classmethod
11591176
def _from_runtime(cls, other) -> "Array":
1160-
return cls.from_any(other)
1177+
other = np.asarray(other)
1178+
if len(other.shape) == 2 and other.shape[0] == 1:
1179+
other = other.squeeze(0)
1180+
return np.ndarray.view(other, cls)
11611181

11621182
@classmethod
11631183
def from_shape(cls, shape=tuple(), **kwargs) -> "Array":
@@ -1855,9 +1875,19 @@ def _as_runtime(self) -> dict:
18551875

18561876
@classmethod
18571877
def _from_runtime(cls, objdict: dict) -> "Cell":
1878+
if isinstance(objdict, (list, tuple, set)):
1879+
shape = [len(objdict)]
1880+
objdict = dict(type__='cell', size__=shape, data__=objdict)
1881+
18581882
if objdict['type__'] != 'cell':
18591883
raise TypeError('objdict is not a cell')
1884+
18601885
size = np.array(objdict['size__'], dtype=np.uint64).ravel()
1886+
if len(size) == 2 and size[0] == 1:
1887+
# NOTE: should not be needed for Cell, as this should
1888+
# have been taken care of by MPython, but I am keeping it
1889+
# here for symmetry with Array and Struct.
1890+
size = size[1:]
18611891
data = np.fromiter(objdict['data__'], dtype=object)
18621892
data = data.reshape(size[::-1]).transpose()
18631893
try:
@@ -1874,7 +1904,7 @@ def _from_runtime(cls, objdict: dict) -> "Cell":
18741904
op_flags=['readwrite', 'no_broadcast'])
18751905
with np.nditer(data, **opt) as iter:
18761906
for elem in iter:
1877-
elem[()] = MatlabType.from_any(elem.item())
1907+
elem[()] = MatlabType._from_runtime(elem.item())
18781908

18791909
return obj
18801910

@@ -2485,6 +2515,11 @@ def _from_runtime(cls, objdict: dict) -> "Struct":
24852515
if objdict['type__'] != 'structarray':
24862516
raise TypeError('objdict is not a structarray')
24872517
size = np.array(objdict['size__'], dtype=np.uint64).ravel()
2518+
if len(size) == 2 and size[0] == 1:
2519+
# NOTE: should not be needed for Cell, as this should
2520+
# have been taken care of by MPython, but I am keeping it
2521+
# here for symmetry with Array and Struct.
2522+
size = size[1:]
24882523
data = np.array(objdict['data__'], dtype=object)
24892524
data = data.reshape(size)
24902525
try:
@@ -2495,7 +2530,17 @@ def _from_runtime(cls, objdict: dict) -> "Struct":
24952530
f' data={data}\n'
24962531
f' objdict={objdict}'
24972532
)
2498-
return MatlabType.from_any(obj)
2533+
2534+
# recurse
2535+
opt = dict(flags=['refs_ok', 'zerosize_ok'],
2536+
op_flags=['readonly', 'no_broadcast'])
2537+
with np.nditer(data, **opt) as iter:
2538+
for elem in iter:
2539+
item = elem.item()
2540+
for key, val in item.items():
2541+
item[key] = MatlabType._from_runtime(val)
2542+
2543+
return obj
24992544

25002545
@classmethod
25012546
def from_shape(cls, shape=tuple(), **kwargs) -> "Struct":

tests/test_array.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_array_instantiate_shape_2d_col(self):
4444
self.assertEqual(a.shape, (3, 1))
4545
self.assertEqual(a.dtype, np.float64)
4646
self.assertTrue((a == 0).all())
47-
47+
4848
def test_array_instantiate_shape_2d(self):
4949
a = Array(2, 3, 4, 5)
5050

@@ -91,9 +91,9 @@ def test_array_from_matlab_2d_row(self):
9191
self.fail('2D row array to Matlab failed.')
9292

9393
self.assertIsInstance(a, Array)
94-
self.assertEqual(a.shape, (1, 3))
94+
self.assertEqual(a.shape, (3,))
9595
self.assertEqual(a.dtype, np.float64)
96-
self.assertEqual(a.tolist(), [[1, 2, 3]])
96+
self.assertEqual(a.tolist(), [1, 2, 3])
9797

9898
def test_array_from_matlab_2d_col(self):
9999
try:
@@ -116,20 +116,20 @@ def test_array_from_matlab_array2d(self):
116116
self.assertEqual(a.shape, (2, 3))
117117
self.assertEqual(a.dtype, np.float64)
118118
self.assertEqual(a.tolist(), [[1, 2, 3], [4, 5, 6]])
119-
119+
120120
def test_array_to_matlab_empty(self):
121121
# Construct an empty array
122122
a = Array()
123-
123+
124124
# Get properties in Matlab
125-
try:
125+
try:
126126
size = Runtime.call('size', a)
127127
type = Runtime.call('class', a)
128128
except Exception:
129129
self.fail('Empty array to Matlab failed.')
130130

131131
# Check properties in Matlab
132-
self.assertEqual(size.tolist(), [[1, 1]])
132+
self.assertEqual(size.tolist(), [1, 1])
133133
self.assertEqual(type, 'double')
134134

135135
def test_array_to_matlab_empty_1d(self):
@@ -144,7 +144,7 @@ def test_array_to_matlab_empty_1d(self):
144144
self.fail('1D shape array to Matlab failed.')
145145

146146
# Check properties in Matlab
147-
self.assertEqual(size.tolist(), [[1, 3]])
147+
self.assertEqual(size.tolist(), [1, 3])
148148
self.assertEqual(type, 'double')
149149

150150
def test_array_to_matlab_empty_2d_row(self):
@@ -159,7 +159,7 @@ def test_array_to_matlab_empty_2d_row(self):
159159
self.fail('1D row array to Matlab failed.')
160160

161161
# Check properties in Matlab
162-
self.assertEqual(size.tolist(), [[1, 3]])
162+
self.assertEqual(size.tolist(), [1, 3])
163163
self.assertEqual(type, 'double')
164164

165165
def test_array_to_matlab_empty_2d_col(self):
@@ -174,7 +174,7 @@ def test_array_to_matlab_empty_2d_col(self):
174174
self.fail('1D col array to Matlab failed.')
175175

176176
# Check properties in Matlab
177-
self.assertEqual(size.tolist(), [[3, 1]])
177+
self.assertEqual(size.tolist(), [3, 1])
178178
self.assertEqual(type, 'double')
179179

180180
def test_array_to_matlab_empty_2d(self):
@@ -189,9 +189,9 @@ def test_array_to_matlab_empty_2d(self):
189189
self.fail('2D array to Matlab failed.')
190190

191191
# Check properties in Matlab
192-
self.assertEqual(size.tolist(), [[3, 2]])
192+
self.assertEqual(size.tolist(), [3, 2])
193193
self.assertEqual(type, 'double')
194-
194+
195195
def test_array_to_matlab_empty_nd(self):
196196
# Construct a 2x3x4x5 array
197197
a = Array(2, 3, 4, 5)
@@ -204,7 +204,7 @@ def test_array_to_matlab_empty_nd(self):
204204
self.fail('N-D array to Matlab failed.')
205205

206206
# Check properties in Matlab
207-
self.assertEqual(size.tolist(), [[2, 3, 4, 5]])
207+
self.assertEqual(size.tolist(), [2, 3, 4, 5])
208208
self.assertEqual(type, 'double')
209209

210210
def test_array_append_1d(self):
@@ -258,9 +258,8 @@ def test_array_indexing_1d(self):
258258
self.assertEqual(a.tolist(), [2])
259259

260260
def test_array_roundtrip_1d(self):
261-
a = Array.from_any([1, 2, 3])
262261
identity = Runtime.call('eval', '@(x) x')
263-
a = Array.from_any([[1, 2, 3]])
262+
a = Array.from_any([1, 2, 3])
264263
d = identity(a)
265264
self.assertListEqual(a.tolist(), d.tolist())
266265

tests/test_cell.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_cell_to_matlab_empty(self):
108108
self.fail('Empty cell to Matlab failed.')
109109

110110
# Check properties in Matlab
111-
self.assertEqual(size.tolist(), [[1, 0]])
111+
self.assertEqual(size.tolist(), [1, 0])
112112
self.assertEqual(type, 'cell')
113113

114114
def test_cell_to_matlab_empty_1d(self):
@@ -123,7 +123,7 @@ def test_cell_to_matlab_empty_1d(self):
123123
self.fail('1D shape cell to Matlab failed.')
124124

125125
# Check properties in Matlab
126-
self.assertEqual(size.tolist(), [[1, 3]])
126+
self.assertEqual(size.tolist(), [1, 3])
127127
self.assertEqual(type, 'cell')
128128

129129
def test_cell_to_matlab_empty_2d_row(self):
@@ -138,7 +138,7 @@ def test_cell_to_matlab_empty_2d_row(self):
138138
self.fail('1D row cell to Matlab failed.')
139139

140140
# Check properties in Matlab
141-
self.assertEqual(size.tolist(), [[1, 3]])
141+
self.assertEqual(size.tolist(), [1, 3])
142142
self.assertEqual(type, 'cell')
143143

144144
def test_cell_to_matlab_empty_2d_col(self):
@@ -153,7 +153,7 @@ def test_cell_to_matlab_empty_2d_col(self):
153153
self.fail('1D col cell to Matlab failed.')
154154

155155
# Check properties in Matlab
156-
self.assertEqual(size.tolist(), [[3, 1]])
156+
self.assertEqual(size.tolist(), [3, 1])
157157
self.assertEqual(type, 'cell')
158158

159159
def test_cell_to_matlab_empty_2d(self):
@@ -168,7 +168,7 @@ def test_cell_to_matlab_empty_2d(self):
168168
self.fail('2D cell to Matlab failed.')
169169

170170
# Check properties in Matlab
171-
self.assertEqual(size.tolist(), [[3, 2]])
171+
self.assertEqual(size.tolist(), [3, 2])
172172
self.assertEqual(type, 'cell')
173173

174174
def test_cell_to_matlab_empty_nd(self):
@@ -183,7 +183,7 @@ def test_cell_to_matlab_empty_nd(self):
183183
self.fail('N-D cell to Matlab failed.')
184184

185185
# Check properties in Matlab
186-
self.assertEqual(size.tolist(), [[2, 3, 4, 5]])
186+
self.assertEqual(size.tolist(), [2, 3, 4, 5])
187187
self.assertEqual(type, 'cell')
188188

189189
def test_cell_as_struct(self):

tests/test_struct.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def test_struct_from_matlab_with_array(self):
172172
# Check Runtime conversion
173173
self.assertIsInstance(s_matlab, Struct)
174174
self.assertIsInstance(s_matlab.field1, Array)
175-
self.assertListEqual(s_matlab.field1.tolist(), [[1, 2, 3]])
175+
self.assertListEqual(s_matlab.field1.tolist(), [1, 2, 3])
176176
self.assertEqual(s_matlab.field2, 2)
177177
self.assertListEqual(list(s_matlab.keys()), ["field1", "field2"])
178178
self.assertListEqual(list(s_matlab.values()), [s_matlab.field1, 2])
@@ -215,7 +215,7 @@ def test_struct_to_matlab_empty(self):
215215
self.fail('Empty struct to Matlab failed.')
216216

217217
# Check properties in Matlab
218-
self.assertEqual(size.tolist(), [[1, 1]])
218+
self.assertEqual(size.tolist(), [1, 1])
219219
self.assertEqual(type, 'struct')
220220
self.assertListEqual(fieldnames.tolist(), [])
221221

@@ -232,7 +232,7 @@ def test_struct_to_matlab_empty_1d(self):
232232
self.fail('1D shape struct to Matlab failed.')
233233

234234
# Check properties in Matlab
235-
self.assertEqual(size.tolist(), [[1, 3]])
235+
self.assertEqual(size.tolist(), [1, 3])
236236
self.assertEqual(type, 'struct')
237237
self.assertListEqual(fieldnames.tolist(), [])
238238

@@ -249,7 +249,7 @@ def test_struct_to_matlab_empty_2d_row(self):
249249
self.fail('1D row struct to Matlab failed.')
250250

251251
# Check properties in Matlab
252-
self.assertEqual(size.tolist(), [[1, 3]])
252+
self.assertEqual(size.tolist(), [1, 3])
253253
self.assertEqual(type, 'struct')
254254
self.assertListEqual(fieldnames.tolist(), [])
255255

@@ -266,7 +266,7 @@ def test_struct_to_matlab_empty_2d_col(self):
266266
self.fail('1D col struct to Matlab failed.')
267267

268268
# Check properties in Matlab
269-
self.assertEqual(size.tolist(), [[3, 1]])
269+
self.assertEqual(size.tolist(), [3, 1])
270270
self.assertEqual(type, 'struct')
271271
self.assertListEqual(fieldnames.tolist(), [])
272272

@@ -283,7 +283,7 @@ def test_struct_to_matlab_empty_2d(self):
283283
self.fail('2D struct to Matlab failed.')
284284

285285
# Check properties in Matlab
286-
self.assertEqual(size.tolist(), [[3, 2]])
286+
self.assertEqual(size.tolist(), [3, 2])
287287
self.assertEqual(type, 'struct')
288288
self.assertListEqual(fieldnames.tolist(), [])
289289

0 commit comments

Comments
 (0)