Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions exetera/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,8 +998,8 @@ def _write_groupby_keys(self, ddf: DataFrame, write_keys=True):
Write groupby keys to ddf only if write_key = True
"""
if write_keys:
by_fields = np.asarray([self._columns[k] for k in self._by])
for field in by_fields:
for k in self._by:
field = self._columns[k]
newfld = field.create_like(ddf, field.name)

if self._sorted_index is not None:
Expand Down
31 changes: 29 additions & 2 deletions exetera/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_spans(self):
"""
raise NotImplementedError("Please use get_spans() on specific fields, not the field base class.")

def apply_filter(self, filter_to_apply, dstfld=None):
def apply_filter(self, filter_to_apply, target=None, in_place=False):
"""
Apply filter on the field.
"""
Expand All @@ -143,6 +143,33 @@ def _ensure_valid(self):
if not self._valid_reference:
raise ValueError("This field no longer refers to a valid underlying field object")

def __getitem__(self, item: Union[slice, int, list, tuple, np.ndarray]):
    """
    Select elements of this field with ``field[item]`` and return them as a new field.

    :param item: The selector. Supported forms:

        * ``slice`` - the sliced range of ``self.data`` is written to a field
          obtained from ``self.create_like()``.
        * ``int`` / numpy integer - the single element at that position is
          written (as a one-element array) to a field from ``self.create_like()``.
        * ``list`` / ``tuple`` / ``np.ndarray`` whose numpy dtype is boolean -
          forwarded to ``self.apply_filter(..., target=None, in_place=False)``.
        * any other ``list`` / ``tuple`` / ``np.ndarray`` - forwarded to
          ``self.apply_index(..., target=None, in_place=False)``; lists and
          tuples are converted to ``np.int64`` first, arrays are passed as given.
    :return: The field produced by ``create_like`` / ``apply_filter`` /
        ``apply_index`` for the selected elements.
    :raises TypeError: If ``item`` is not one of the supported selector types.
    """
    if isinstance(item, slice):
        selected = self.data[item]
        memfield = self.create_like()
        memfield.data.write(selected)
        return memfield

    # np.integer is accepted as well so that indices taken from numpy arrays
    # (e.g. ``field[idx[0]]``) behave like plain Python ints.
    if isinstance(item, (int, np.integer)):
        selected = self.data[item]
        memfield = self.create_like()
        memfield.data.write(np.array([selected]))
        return memfield

    if isinstance(item, (list, tuple, np.ndarray)):
        selector = np.asarray(item)

        # Decide filter-vs-index from the numpy dtype rather than testing each
        # element with isinstance(x, bool): np.bool_ is not a subclass of bool,
        # so element-wise testing misroutes boolean ndarrays to apply_index.
        # An empty list/tuple has a non-boolean dtype and therefore becomes an
        # empty index selection rather than a zero-length mask.
        if selector.dtype.kind == 'b':
            return self.apply_filter(selector, target=None, in_place=False)

        if not isinstance(item, np.ndarray):
            selector = np.array(item, dtype=np.int64)
        return self.apply_index(selector, target=None, in_place=False)

    raise TypeError(
        "Field indices must be a slice, an integer, or a list, tuple or numpy array "
        f"of booleans or integers, not {type(item).__name__}")


class MemoryField(Field):

Expand Down Expand Up @@ -210,7 +237,7 @@ def __bool__(self):
# if f is not None:
return True

def apply_filter(self, filter_to_apply, dstfld=None):
def apply_filter(self, filter_to_apply, target=None, in_place=False):
"""
Apply filter on the field.
"""
Expand Down
68 changes: 68 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2288,3 +2288,71 @@ def test_argsort(self, creator, name, kwargs, data):
else:
with self.assertRaises(ValueError):
fields.argsort(f)


def _dereference_cases(selector):
    """
    Build one parameterized case per field type for a list-like ``selector``.

    Each case is ``(selector, creator_name, creator_kwargs, source_data)``.
    Every case receives its own copy of ``selector`` and freshly built
    kwargs/data literals, so no mutable object is shared between cases.
    """
    return [
        (list(selector), "create_indexed_string", {}, ['a', 'bb', 'ccc']),
        (list(selector), "create_fixed_string", {"length": 3}, ['a', 'b', 'c']),
        (list(selector), "create_numeric", {"nformat": "int8"}, [20, 30, 40]),
        (
            list(selector),
            "create_categorical",
            {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}},
            [1, 2, 3],
        ),
    ]


# Cases where the selector is a boolean mask.
ARRAY_DEREFERENCE_FILTER_TESTS = _dereference_cases([True, False, True])

# Cases where the selector is a list of positions.
ARRAY_DEREFERENCE_INDEX_TESTS = _dereference_cases([0, 2])

# Cases where the selector is a slice (first four) or a single integer (last four).
ARRAY_SLICE_AND_INT_TESTS = [
    (slice(0, 2, 1), "create_indexed_string", {}, ['a', 'bb', 'ccc']),
    (slice(0, 3, 1), "create_fixed_string", {"length": 3}, [b'a', b'b', b'c']),
    (slice(0, 2, 2), "create_numeric", {"nformat": "int8"}, [20, 30, 40]),
    (
        slice(0, 3, 2),
        "create_categorical",
        {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}},
        [1, 2, 3],
    ),
    (0, "create_indexed_string", {}, ['a', 'bb', 'ccc']),
    (1, "create_fixed_string", {"length": 3}, [b'a', b'b', b'c']),
    (2, "create_numeric", {"nformat": "int8"}, [20, 30, 40]),
    (
        0,
        "create_categorical",
        {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}},
        [1, 2, 3],
    ),
]
class TestArrayDereferenceFunctions(SessionTestCase):
    """
    Tests for square-bracket dereferencing of fields (``field[...]``) with
    boolean filters, index lists, slices and single integers.
    """

    def assertIfMemFieldAndIfSameTypeAsField(self, memfield, field):
        """
        Assert that ``memfield`` is a ``MemoryField`` and is the in-memory
        counterpart of ``field``'s type.

        :param memfield: The field returned by the dereference under test.
        :param field: The field that was dereferenced.
        :raises AssertionError: If ``memfield`` is not a ``MemoryField`` or the
            two fields do not form one of the known (field, memory field) pairs.
        """
        self.assertIsInstance(memfield, fields.MemoryField)

        # (field type, matching in-memory field type)
        counterparts = (
            (fields.IndexedStringField, fields.IndexedStringMemField),
            (fields.FixedStringField, fields.FixedStringMemField),
            (fields.NumericField, fields.NumericMemField),
            (fields.CategoricalField, fields.CategoricalMemField),
        )
        matched = any(isinstance(field, field_type) and isinstance(memfield, mem_type)
                      for field_type, mem_type in counterparts)
        if not matched:
            raise AssertionError(f"{type(memfield)} is not the MemField for {type(field)}")

    @parameterized.expand(ARRAY_DEREFERENCE_FILTER_TESTS)
    def test_field_filter_dereference(self, mask, creator, kwargs, data):
        """``field[mask]`` must give the same data as ``apply_filter`` with that boolean mask."""
        f = self.setup_field(self.df, creator, 'f', (), kwargs, data)
        result = f[mask]

        # The reference filter has to be a genuine boolean mask: an int8 array
        # of 0/1 values is an index array under numpy semantics, not a filter.
        filter_to_apply = mask if isinstance(mask, np.ndarray) else np.array(mask, dtype=bool)
        expected_result = f.apply_filter(filter_to_apply, target=None, in_place=False)

        self.assertIfMemFieldAndIfSameTypeAsField(result, f)
        np.testing.assert_array_equal(result.data[:], expected_result.data[:])

    @parameterized.expand(ARRAY_DEREFERENCE_INDEX_TESTS)
    def test_field_index_dereference(self, index, creator, kwargs, data):
        """``field[index]`` must give the same data as ``apply_index`` with those positions."""
        f = self.setup_field(self.df, creator, 'f', (), kwargs, data)
        result = f[index]

        # int64 keeps the reference index valid for positions beyond the int8 range.
        index_to_apply = index if isinstance(index, np.ndarray) else np.array(index, dtype=np.int64)
        expected_result = f.apply_index(index_to_apply, target=None, in_place=False)

        self.assertIfMemFieldAndIfSameTypeAsField(result, f)
        np.testing.assert_array_equal(result.data[:], expected_result.data[:])

    @parameterized.expand(ARRAY_SLICE_AND_INT_TESTS)
    def test_field_slice(self, selector, creator, kwargs, data):
        """``field[slice]`` and ``field[int]`` must return the matching elements of the source data."""
        f = self.setup_field(self.df, creator, 'f', (), kwargs, data)
        result = f[selector]

        expected_result = data[selector]
        if isinstance(selector, int):
            # Compare against a one-element sequence: a bare scalar would be
            # accepted by assert_array_equal for a result of any length.
            expected_result = [expected_result]

        self.assertIfMemFieldAndIfSameTypeAsField(result, f)
        np.testing.assert_array_equal(result.data[:], expected_result)