Skip to content

Commit 191e1b8

Browse files
committed
Test + fixes for cells (wip)
1 parent b3dcb0e commit 191e1b8

File tree

2 files changed

+112
-25
lines changed

2 files changed

+112
-25
lines changed

spm/__wrapper__.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,12 @@ def from_any(cls, other):
199199
else:
200200
return Struct.from_any(other)
201201

202-
if isinstance(other, (tuple, set)):
202+
if isinstance(other, (list, tuple, set)):
203203
# nested tuples are cells of cells, not cell arrays
204204
return Cell.from_any(other)
205205

206-
if isinstance(other, (list, np.ndarray, int, float, complex, bool)):
207-
try:
208-
return Array.from_any(other)
209-
except (ValueError, TypeError):
210-
return Cell.from_any(other)
206+
if isinstance(other, (np.ndarray, int, float, complex, bool)):
207+
return Array.from_any(other)
211208

212209
if isinstance(other, str):
213210
return other
@@ -602,7 +599,8 @@ def _resize_for_index(self, index):
602599

603600
def _finalize(self):
604601
"""Transform all DelayedArrays into concrete arrays."""
605-
opt = dict(flags=['refs_ok'], op_flags=['readwrite', 'no_broadcast'])
602+
opt = dict(flags=['refs_ok', 'zerosize_ok'],
603+
op_flags=['readwrite', 'no_broadcast'])
606604
with np.nditer(self, **opt) as iter:
607605
for elem in iter:
608606
item = elem.item()
@@ -1273,7 +1271,7 @@ def _from_runtime(cls, objdict: dict) -> "Cell":
12731271
raise TypeError('objdict is not a cell')
12741272
size = np.array(objdict['size__'], dtype=np.uint32).ravel()
12751273
data = np.array(objdict['data__'], dtype=object)
1276-
data = data.reshape(size)
1274+
data = data.reshape(size[::-1]).transpose()
12771275
try:
12781276
obj = data.view(cls)
12791277
except Exception:
@@ -1311,7 +1309,8 @@ def from_shape(cls, shape=tuple(), **kwargs) -> "Cell":
13111309
# Scalar cells are forbidden
13121310
shape = [0]
13131311
arr = np.ndarray(shape, **kwargs)
1314-
opt = dict(flags=['refs_ok'], op_flags=['write_only', 'no_broadcast'])
1312+
opt = dict(flags=['refs_ok', 'zerosize_ok'],
1313+
op_flags=['writeonly', 'no_broadcast'])
13151314
with np.nditer(arr, **opt) as iter:
13161315
for elem in iter:
13171316
elem[()] = cls._DEFAULT()
@@ -1367,13 +1366,16 @@ def from_any(cls, other, **kwargs) -> "Cell":
13671366
def asrecursive(other):
13681367
if isinstance(other, np.ndarray):
13691368
return other
1369+
elif isinstance(other, (str, bytes)):
1370+
return other
13701371
elif hasattr(other, "__iter__"):
13711372
other = list(map(asrecursive, other))
13721373
tmp = np.ndarray(len(other), dtype=object)
13731374
for i, x in enumerate(other):
13741375
tmp[i] = x
13751376
other = tmp
1376-
return np.ndarray.view(other, cls)
1377+
obj = np.ndarray.view(other, cls)
1378+
return obj
13771379
else:
13781380
return other
13791381

@@ -1394,7 +1396,8 @@ def asrecursive(other):
13941396
other = np.ndarray.view(other, cls)
13951397

13961398
# recurse
1397-
opt = dict(flags=['refs_ok'], op_flags=['readwrite', 'no_broadcast'])
1399+
opt = dict(flags=['refs_ok', 'zerosize_ok'],
1400+
op_flags=['readwrite', 'no_broadcast'])
13981401
with np.nditer(other, **opt) as iter:
13991402
for elem in iter:
14001403
elem[()] = MatlabType.from_any(elem.item())
@@ -1412,7 +1415,8 @@ def _unroll_build(cls, arr):
14121415
# them to lists, and recurse.
14131416
rebuild = False
14141417
arr = np.asarray(arr)
1415-
opt = dict(flags=['refs_ok'], op_flags=['readwrite', 'no_broadcast'])
1418+
opt = dict(flags=['refs_ok', 'zerosize_ok'],
1419+
op_flags=['readwrite', 'no_broadcast'])
14161420
with np.nditer(arr, **opt) as iter:
14171421
for elem in iter:
14181422
item = elem.item()
@@ -1593,7 +1597,7 @@ def __getitem__(self, key):
15931597
# Otherwise the default value is [], as per matlab.
15941598
isnewkey = key not in self.keys()
15951599
arr = np.ndarray.view(self, np.ndarray)
1596-
opt = dict(flags=['refs_ok'], op_flags=['readonly'])
1600+
opt = dict(flags=['refs_ok', 'zerosize_ok'], op_flags=['readonly'])
15971601
with np.nditer(arr, **opt) as iter:
15981602
for elem in iter:
15991603
elem.item().setdefault(
@@ -1617,7 +1621,8 @@ def __setitem__(self, key, value):
16171621
# Each element in the struct array is matched with an element
16181622
# in the "unpack" array.
16191623
value = value.broadcast_to_struct(self)
1620-
opt = dict(flags=['refs_ok', 'multi_index'], op_flags=['readonly'])
1624+
opt = dict(flags=['refs_ok', 'zerosize_ok', 'multi_index'],
1625+
op_flags=['readonly'])
16211626
with np.nditer(arr, **opt) as iter:
16221627
for elem in iter:
16231628
val = value[iter.multi_index]
@@ -1627,7 +1632,7 @@ def __setitem__(self, key, value):
16271632

16281633
else:
16291634
# Assign the same value to all elements in the struct array.
1630-
opt = dict(flags=['refs_ok'], op_flags=['readonly'])
1635+
opt = dict(flags=['refs_ok', 'zerosize_ok'], op_flags=['readonly'])
16311636
value = MatlabType.from_any(value)
16321637
with np.nditer(arr, **opt) as iter:
16331638
for elem in iter:
@@ -1637,7 +1642,7 @@ def __delitem__(self, key):
16371642
if key not in self._allkeys():
16381643
raise KeyError(key)
16391644
arr = np.ndarray.view(self, np.ndarray)
1640-
opt = dict(flags=['refs_ok'], op_flags=['readonly'])
1645+
opt = dict(flags=['refs_ok', 'zerosize_ok'], op_flags=['readonly'])
16411646
with np.nditer(arr, **opt) as iter:
16421647
for elem in iter:
16431648
del elem.item()[key]
@@ -1655,7 +1660,7 @@ def values(self):
16551660

16561661
def setdefault(self, key, value):
16571662
arr = np.ndarray.view(self, np.ndarray)
1658-
opt = dict(flags=['refs_ok'], op_flags=['readonly'])
1663+
opt = dict(flags=['refs_ok', 'zerosize_ok'], op_flags=['readonly'])
16591664
with np.nditer(arr, **opt) as iter:
16601665
for elem in iter:
16611666
item = elem.item()
@@ -1667,7 +1672,8 @@ def update(self, other):
16671672
other = np.broadcast_to(other, self.shape)
16681673

16691674
arr = np.ndarray.view(self, np.ndarray)
1670-
opt = dict(flags=['refs_ok', 'multi_index'], op_flags=['readonly'])
1675+
opt = dict(flags=['refs_ok', 'zerosize_ok', 'multi_index'],
1676+
op_flags=['readonly'])
16711677
with np.nditer(arr, **opt) as iter:
16721678
for elem in iter:
16731679
other_elem = other[iter.multi_index]
@@ -1844,7 +1850,7 @@ def from_shape(cls, shape=tuple(), **kwargs) -> "Struct":
18441850
"""
18451851
kwargs["dtype"] = dict
18461852
arr = np.ndarray(shape, **kwargs)
1847-
flags = dict(flags=['refs_ok'], op_flags=['readwrite'])
1853+
flags = dict(flags=['refs_ok', 'zerosize_ok'], op_flags=['readwrite'])
18481854
with np.nditer(arr, **flags) as iter:
18491855
for elem in iter:
18501856
elem[()] = dict()
@@ -1892,7 +1898,7 @@ def from_any(cls, other, **kwargs) -> "Struct":
18921898
# convert to array[dict]
18931899
other = np.asarray(other, **kwargs)
18941900
other = cls._unroll_build(other)
1895-
opt = dict(flags=['refs_ok'], op_flags=['readonly'])
1901+
opt = dict(flags=['refs_ok', 'zerosize_ok'], op_flags=['readonly'])
18961902
with np.nditer(other, **opt) as iter:
18971903
if not all(isinstance(elem.item(), dict) for elem in iter):
18981904
raise TypeError("Not an array of dictionaries")
@@ -1902,7 +1908,7 @@ def from_any(cls, other, **kwargs) -> "Struct":
19021908
other = _copy_if_needed(other, inp, copy)
19031909

19041910
# nested from_any
1905-
opt = dict(flags=['refs_ok'], op_flags=['readonly'])
1911+
opt = dict(flags=['refs_ok', 'zerosize_ok'], op_flags=['readonly'])
19061912
with np.nditer(other, **opt) as iter:
19071913
for elem in iter:
19081914
item: dict = elem.item()
@@ -1929,7 +1935,7 @@ def _unroll_build(cls, arr):
19291935
# them to lists, and recurse.
19301936
rebuild = False
19311937
arr = np.ndarray.view(arr, np.ndarray)
1932-
flags = dict(flags=['refs_ok'], op_flags=['readwrite'])
1938+
flags = dict(flags=['refs_ok', 'zerosize_ok'], op_flags=['readwrite'])
19331939
with np.nditer(arr, **flags) as iter:
19341940
for elem in iter:
19351941
item = elem.item()
@@ -1962,7 +1968,7 @@ def as_dict(self) -> dict:
19621968
if np.ndim(self) == 0:
19631969
return np.ndarray.item(self)
19641970
arr = np.ndarray.view(self, np.ndarray)
1965-
opt = dict(flags=['refs_ok'], op_flags=['readwrite'])
1971+
opt = dict(flags=['refs_ok', 'zerosize_ok'], op_flags=['readwrite'])
19661972
with np.nditer(arr, **opt) as iter:
19671973
return {
19681974
key: Cell.from_any([
@@ -1976,7 +1982,7 @@ def _allkeys(self):
19761982
# Keys are ordered by (1) element (2) within-element order
19771983
mock = {}
19781984
arr = np.ndarray.view(self, np.ndarray)
1979-
opt = dict(flags=['refs_ok'], op_flags=['readwrite'])
1985+
opt = dict(flags=['refs_ok', 'zerosize_ok'], op_flags=['readwrite'])
19801986
with np.nditer(arr, **opt) as iter:
19811987
for elem in iter:
19821988
mock.update({key: None for key in elem.item().keys()})
@@ -2065,7 +2071,7 @@ def __delattr__(self, key):
20652071

20662072
def _finalize(self):
20672073
arr = np.ndarray.view(self, np.ndarray)
2068-
opt = dict(flags=['refs_ok'], op_flags=['readwrite'])
2074+
opt = dict(flags=['refs_ok', 'zerosize_ok'], op_flags=['readwrite'])
20692075
with np.nditer(arr, **opt) as iter:
20702076
for elem in iter:
20712077
item = elem.item()

tests/test_cell.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import unittest
2+
import numpy as np
3+
4+
from spm import Cell, Runtime
5+
6+
7+
class CellTestCase(unittest.TestCase):
8+
9+
def test_instantiate_empty(self):
10+
c = Cell()
11+
a = np.asarray([])
12+
self.assertIsInstance(c, Cell)
13+
self.assertListEqual(c.tolist(), a.tolist())
14+
15+
def test_instantiate_shape_1d(self):
16+
c = Cell([3])
17+
self.assertIsInstance(c, Cell)
18+
self.assertEqual(c.shape, (3,))
19+
self.assertTrue(all(x.shape == (0,) for x in c))
20+
21+
def test_instantiate_shape_2d(self):
22+
c = Cell([3, 2])
23+
self.assertIsInstance(c, Cell)
24+
self.assertEqual(c.shape, (3, 2))
25+
self.assertTrue(all(x.shape == (0,) for x in c.flat))
26+
27+
def test_instantiate_shape_order(self):
28+
c = Cell([3, 2], order="C")
29+
self.assertIsInstance(c, Cell)
30+
self.assertEqual(c.shape, (3, 2))
31+
self.assertEqual(c.strides, (2*8, 8))
32+
33+
c = Cell([3, 2], order="F")
34+
self.assertIsInstance(c, Cell)
35+
self.assertEqual(c.shape, (3, 2))
36+
self.assertEqual(c.strides, (8, 8*3))
37+
38+
def test_instantiate_list(self):
39+
c = Cell.from_any(['a', 'b', 'c'])
40+
self.assertEqual(c.shape, (3,))
41+
self.assertEqual(repr(c), "Cell(['a', 'b', 'c'])")
42+
self.assertListEqual(c.tolist(), ['a', 'b', 'c'])
43+
44+
def test_instantiate_nested_list(self):
45+
c = Cell.from_any([['a', 'b'], ['c', 'd']])
46+
self.assertEqual(c.shape, (2,))
47+
self.assertEqual(repr(c), "Cell([Cell(['a', 'b']), Cell(['c', 'd'])])")
48+
self.assertTrue(all(isinstance(x, Cell) for x in c.tolist()))
49+
50+
def test_instantiate_nested_list_deepcat(self):
51+
c = Cell.from_any([['a', 'b'], ['c', 'd']], deepcat=True)
52+
self.assertEqual(c.shape, (2, 2))
53+
self.assertEqual(repr(c), "Cell([['a', 'b'],\n ['c', 'd']])")
54+
self.assertTrue(all(isinstance(x, list) for x in c.tolist()))
55+
56+
def test_as_struct(self):
57+
self.assertRaises(TypeError, lambda: Cell().as_struct)
58+
59+
def test_as_num(self):
60+
self.assertRaises(TypeError, lambda: Cell().as_num)
61+
62+
def test_as_cell(self):
63+
c = Cell()
64+
self.assertTrue(c.as_cell is c)
65+
66+
def test_cell1d_from_matlab(self):
67+
c_matlab = Runtime.call("eval", "{1, 2, 3}")
68+
c_python = Cell.from_any([1, 2, 3])
69+
self.assertListEqual(c_matlab.tolist(), c_python.tolist())
70+
71+
def test_cell2d_from_matlab(self):
72+
c_matlab = Runtime.call("eval", "{1, 2, 3; 4, 5, 6}")
73+
c_python = Cell.from_any([[1, 2, 3], [4, 5, 6]], deepcat=True)
74+
self.assertTrue(c_matlab.tolist() == c_python.tolist())
75+
76+
def test_nested_cell_from_matlab(self):
77+
c_matlab = Runtime.call("eval", "{{1, 2, 3}, {4, 5, 6}}")
78+
c_python = Cell.from_any([[1, 2, 3], [4, 5, 6]])
79+
c_matlab = [x.tolist() for x in c_matlab.tolist()]
80+
c_python = [x.tolist() for x in c_python.tolist()]
81+
self.assertListEqual(c_matlab, c_python)

0 commit comments

Comments
 (0)