Skip to content

Commit b6d8076

Browse files
authored
Merge pull request #483 from pynapple-org/restrict_info
Add `restrict_info` method for metadata
2 parents 089d2c9 + af75d92 commit b6d8076

File tree

7 files changed

+261
-1
lines changed

7 files changed

+261
-1
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,5 @@ your
165165
# Ignore npz files from testing:
166166
tests/*.npz
167167
.vscode/settings.json
168+
doc/user_guide/MyProject/sub-A2929/A2929-200711/stimulus-fish.json
169+
doc/user_guide/memmap.dat

doc/user_guide/03_metadata.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,18 @@ tsgroup.drop_info("coords")
214214
print(tsgroup)
215215
```
216216

217+
## Restricting metadata
218+
Instead of dropping multiple metadata fields, you may want to restrict to a set of specified fields, i.e. select which columns to keep. For this operation, use the [`restrict_info()`](pynapple.TsGroup.restrict_info) method. Multiple metadata columns can be kept by passing a list of metadata names.
219+
```{code-cell} ipython3
220+
import copy
221+
tsgroup2 = copy.deepcopy(tsgroup)
222+
tsgroup2.restrict_info("region")
223+
print(tsgroup2)
224+
```
225+
```{admonition} Note
226+
The `rate` column will always be kept for a `TsGroup`.
227+
```
228+
217229
## Using metadata to slice objects
218230
Metadata can be used to slice or filter objects based on metadata values.
219231

pynapple/core/interval_set.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,45 @@ def drop_info(self, key):
11981198
"""
11991199
return _MetadataMixin.drop_info(self, key)
12001200

1201+
@add_meta_docstring("restrict_info")
1202+
def restrict_info(self, key):
1203+
"""
1204+
Examples
1205+
--------
1206+
>>> import pynapple as nap
1207+
>>> import numpy as np
1208+
>>> times = np.array([[0, 5], [10, 12], [20, 33]])
1209+
>>> metadata = {"l1": [1, 2, 3], "l2": ["x", "x", "y"], "l3": [4, 5, 6]}
1210+
>>> ep = nap.IntervalSet(tmp,metadata=metadata)
1211+
>>> ep
1212+
index start end l1 l2 l3
1213+
0 0 5 1 x 4
1214+
1 10 12 2 x 5
1215+
2 20 33 3 y 6
1216+
shape: (3, 2), time unit: sec.
1217+
1218+
To restrict to multiple metadata columns:
1219+
1220+
>>> ep.restrict_info(["l2", "l3"])
1221+
>>> ep
1222+
index start end l2 l3
1223+
0 0 5 x 4
1224+
1 10 12 x 5
1225+
2 20 33 y 6
1226+
shape: (3, 2), time unit: sec.
1227+
1228+
To restrict to a single metadata column:
1229+
1230+
>>> ep.restrict_info("l2")
1231+
>>> ep
1232+
index start end l2
1233+
0 0 5 x
1234+
1 10 12 x
1235+
2 20 33 y
1236+
shape: (3, 2), time unit: sec.
1237+
"""
1238+
return _MetadataMixin.restrict_info(self, key)
1239+
12011240
@add_or_convert_metadata
12021241
@add_meta_docstring("groupby")
12031242
def groupby(self, by, get_group=None):

pynapple/core/metadata_class.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,39 @@ def drop_info(self, key):
407407
f"Invalid metadata column {key}. Metadata columns are {self.metadata_columns}"
408408
)
409409

410+
def restrict_info(self, key):
411+
"""
412+
Restrict metadata columns to a key or list of keys.
413+
414+
Parameters
415+
----------
416+
key : str or list of str
417+
Metadata column name(s) to restrict to.
418+
419+
Returns
420+
-------
421+
None
422+
"""
423+
if isinstance(key, Number):
424+
raise TypeError(
425+
f"Invalid metadata column {key}. Metadata columns are {self.metadata_columns}"
426+
)
427+
if isinstance(key, str):
428+
key = [key]
429+
430+
no_keep = [k for k in key if k not in self.metadata_columns]
431+
if no_keep:
432+
raise KeyError(
433+
f"Metadata column(s) {no_keep} not found. Metadata columns are {self.metadata_columns}"
434+
)
435+
436+
drop_keys = set(self.metadata_columns) - set(key)
437+
for k in drop_keys:
438+
if (self.nap_class == "TsGroup") and (k == "rate"):
439+
continue # cannot drop TsGroup 'rate'
440+
else:
441+
del self._metadata[k]
442+
410443
def groupby(self, by, get_group=None):
411444
"""
412445
Group pynapple object by metadata name(s).

pynapple/core/time_series.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2371,6 +2371,65 @@ def drop_info(self, key):
23712371
"""
23722372
return _MetadataMixin.drop_info(self, key)
23732373

2374+
@add_meta_docstring("restrict_info")
2375+
def restrict_info(self, key):
2376+
"""
2377+
Examples
2378+
--------
2379+
>>> import pynapple as nap
2380+
>>> import numpy as np
2381+
>>> metadata = {"l1": [1, 2, 3], "l2": ["x", "x", "y"], "l3": [4, 5, 6]}
2382+
>>> tsdframe = nap.TsdFrame(t=np.arange(5), d=np.ones((5, 3)), metadata=metadata)
2383+
>>> print(tsdframe)
2384+
Time (s) 0 1 2
2385+
---------- --- --- ---
2386+
0.0 1.0 1.0 1.0
2387+
1.0 1.0 1.0 1.0
2388+
2.0 1.0 1.0 1.0
2389+
3.0 1.0 1.0 1.0
2390+
4.0 1.0 1.0 1.0
2391+
Metadata
2392+
---------- --- --- ---
2393+
l1 1 2 3
2394+
l2 x x y
2395+
l3 4 5 6
2396+
dtype: float64, shape: (5, 3)
2397+
2398+
To restrict to multiple metadata rows:
2399+
2400+
>>> tsdframe.restrict_info(["l2", "l3"])
2401+
>>> tsdframe
2402+
Time (s) 0 1 2
2403+
---------- --- --- ---
2404+
0.0 1.0 1.0 1.0
2405+
1.0 1.0 1.0 1.0
2406+
2.0 1.0 1.0 1.0
2407+
3.0 1.0 1.0 1.0
2408+
4.0 1.0 1.0 1.0
2409+
Metadata
2410+
---------- --- --- ---
2411+
l2 x x y
2412+
l3 4 5 6
2413+
dtype: float64, shape: (5, 3)
2414+
2415+
To restrict to a single metadata row:
2416+
2417+
>>> tsdframe.restrict_info("l2")
2418+
>>> tsdframe
2419+
Time (s) 0 1 2
2420+
---------- --- --- ---
2421+
0 1 1 1
2422+
1 1 1 1
2423+
2 1 1 1
2424+
3 1 1 1
2425+
4 1 1 1
2426+
Metadata
2427+
---------- --- --- ---
2428+
l2 x x y
2429+
dtype: float64, shape: (5, 3)
2430+
"""
2431+
return _MetadataMixin.restrict_info(self, key)
2432+
23742433
@add_or_convert_metadata
23752434
@add_meta_docstring("groupby")
23762435
def groupby(self, by, get_group=None):

pynapple/core/ts_group.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,6 +1843,52 @@ def drop_info(self, key):
18431843
"""
18441844
return _MetadataMixin.drop_info(self, key)
18451845

1846+
@add_meta_docstring("restrict_info")
1847+
def restrict_info(self, key):
1848+
"""
1849+
Note
1850+
----
1851+
The `rate` column is always kept in the metadata, even if it is not specified in `key`.
1852+
1853+
Examples
1854+
--------
1855+
>>> import pynapple as nap
1856+
>>> import numpy as np
1857+
>>> tmp = {0:nap.Ts(t=np.arange(0,200), time_units='s'),
1858+
... 1:nap.Ts(t=np.arange(0,200,0.5), time_units='s'),
1859+
... 2:nap.Ts(t=np.arange(0,300,0.25), time_units='s'),
1860+
... }
1861+
>>> metadata = {"l1": [1, 2, 3], "l2": ["x", "x", "y"], "l3": [4, 5, 6]}
1862+
>>> tsgroup = nap.TsGroup(tmp,metadata=metadata)
1863+
>>> print(tsgroup)
1864+
Index rate l1 l2 l3
1865+
------- ------- ---- ---- ----
1866+
0 0.66722 1 x 4
1867+
1 1.33445 2 x 5
1868+
2 4.00334 3 y 6
1869+
1870+
To restrict to multiple metadata columns:
1871+
1872+
>>> tsgroup.restrict_info(["l2", "l3"])
1873+
>>> tsgroup
1874+
Index rate l2 l3
1875+
------- ------- ---- ----
1876+
0 0.66722 x 4
1877+
1 1.33445 x 5
1878+
2 4.00334 y 6
1879+
1880+
To restrict to a single metadata column:
1881+
1882+
>>> tsgroup.drop_info("l2")
1883+
>>> tsgroup
1884+
Index rate l2
1885+
------- ------- ----
1886+
0 0.66722 x
1887+
1 1.33445 x
1888+
2 4.00334 y
1889+
"""
1890+
return _MetadataMixin.restrict_info(self, key)
1891+
18461892
@add_or_convert_metadata
18471893
@add_meta_docstring("groupby")
18481894
def groupby(self, by, get_group=None):

tests/test_metadata.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1465,6 +1465,75 @@ def test_drop_metadata_error(self, obj, obj_len, drop, error):
14651465
if isinstance(drop, list) and ("label" in drop):
14661466
assert "label" in obj.metadata_columns
14671467

1468+
def test_restrict_metadata(self, obj, obj_len):
1469+
"""
1470+
Test for restricting metadata with restrict_info.
1471+
"""
1472+
info = np.ones(obj_len)
1473+
obj.set_info(l1=info, l2=info * 2, l3=info * 3)
1474+
for col in ["l1", "l2", "l3"]:
1475+
assert col in obj.metadata_columns
1476+
1477+
# restrict to 1 key
1478+
obj.restrict_info("l1")
1479+
assert "l1" in obj.metadata_columns
1480+
for col in ["l2", "l3"]:
1481+
assert col not in obj.metadata_columns
1482+
1483+
# rate should always be present in TsGroup
1484+
if isinstance(obj, nap.TsGroup):
1485+
assert "rate" in obj.metadata_columns
1486+
1487+
# restrict to multiple keys
1488+
obj.set_info(l2=info * 2, l3=info * 3, l4=info * 4)
1489+
obj.restrict_info(["l1", "l2"])
1490+
for col in ["l1", "l2"]:
1491+
assert col in obj.metadata_columns
1492+
for col in ["l3", "l4"]:
1493+
assert col not in obj.metadata_columns
1494+
1495+
# rate should always be present in TsGroup
1496+
if isinstance(obj, nap.TsGroup):
1497+
assert "rate" in obj.metadata_columns
1498+
1499+
@pytest.mark.parametrize(
1500+
"keep, error",
1501+
[
1502+
(
1503+
"not_info",
1504+
pytest.raises(
1505+
KeyError,
1506+
match=r"Metadata column\(s\) \['not_info'\] not found",
1507+
),
1508+
),
1509+
(
1510+
["not_info", "not_info2"],
1511+
pytest.raises(
1512+
KeyError,
1513+
match=r"Metadata column\(s\) \['not_info', 'not_info2'\] not found",
1514+
),
1515+
),
1516+
(
1517+
["label", 0],
1518+
pytest.raises(KeyError, match=r"Metadata column\(s\) \[0\] not found"),
1519+
),
1520+
(0, pytest.raises(TypeError, match="Invalid metadata column")),
1521+
],
1522+
)
1523+
def test_restrict_metadata_error(self, obj, obj_len, keep, error):
1524+
"""
1525+
Test for errors when dropping metadata.
1526+
"""
1527+
info = np.ones(obj_len)
1528+
obj.set_info(label=info, other=info * 2)
1529+
1530+
with error:
1531+
obj.restrict_info(keep)
1532+
1533+
# make sure nothing gets dropped
1534+
assert "label" in obj.metadata_columns
1535+
assert "other" in obj.metadata_columns
1536+
14681537
# test naming overlap of shared attributes
14691538
@pytest.mark.parametrize(
14701539
"name",
@@ -2527,7 +2596,7 @@ def test_no_conflict_between_class_and_metadatamixin(nap_class):
25272596
conflicting_members = iset_members.intersection(metadatamixin_members)
25282597

25292598
# set_info, get_info, drop_info, groupby, and groupby_apply are overwritten for class-specific examples in docstrings
2530-
assert len(conflicting_members) == 5, (
2599+
assert len(conflicting_members) == 6, (
25312600
f"Conflict detected! The following methods/attributes are "
25322601
f"overwritten in IntervalSet: {conflicting_members}"
25332602
)

0 commit comments

Comments
 (0)