Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions exetera/core/abstract_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def isin(self, test_elements:Union[list, set, np.ndarray]):
def unique(self, return_index=False, return_inverse=False, return_counts=False):
raise NotImplementedError()

@staticmethod
def where(cond, a, b):
raise NotImplementedError()


class Dataset(ABC):
"""
Expand Down
53 changes: 53 additions & 0 deletions exetera/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from exetera.core.data_writer import DataWriter
from exetera.core import operations as ops
from exetera.core import validation as val
from exetera.core import utils


def isin(field:Field, test_elements:Union[list, set, np.ndarray]):
Expand All @@ -39,6 +40,23 @@ def isin(field:Field, test_elements:Union[list, set, np.ndarray]):
return ret


def where(cond: Union[list, tuple, np.ndarray, Field], a, b):
if isinstance(cond, (list, tuple, np.ndarray)):
cond = cond
elif isinstance(cond, Field):
if cond.indexed:
raise NotImplementedError("Where does not support condition on indexed string fields at present")
cond = cond.data[:]
elif callable(cond):
raise NotImplementedError("module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: mehthod -> method

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo, please replace with:

"module method fields.where doesn't support callable cond parameter, please use the instance method where if you need to use a callable cond parameter"


if isinstance(a, Field):
a = a.data[:]
if isinstance(b, Field):
b = b.data[:]
return np.where(cond, a, b)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still returning a numpy array rather than a field

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still returning a numpy array rather than a field

The logic of module-level where API will be almost same as instance-level where API. Think we can focus on one first, e.g. instance-level where API.



class HDF5Field(Field):
def __init__(self, session, group, dataframe, write_enabled=False):
super().__init__()
Expand Down Expand Up @@ -143,6 +161,41 @@ def _ensure_valid(self):
if not self._valid_reference:
raise ValueError("This field no longer refers to a valid underlying field object")

def where(self, cond:Union[list, tuple, np.ndarray, Field], b, inplace=False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add the callable signature to cond's type information

if isinstance(cond, (list, tuple, np.ndarray)):
cond = cond
elif isinstance(cond, Field):
if cond.indexed:
raise NotImplementedError("Where does not support indexed string fields at present")
cond = cond.data[:]
elif callable(cond):
cond = cond(self.data[:])
else:
raise TypeError("'cond' parameter needs to be either callable lambda function, or array like, or NumericMemField")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we could just do return where(cond, self, b) and then the rest of the body of this method can be put into the global where function.

# if isinstance(b, str):
# b = b.encode()
if isinstance(b, Field):
b = b.data[:]

result_ndarray = np.where(cond, self.data[:], b)
result_mem_field = None
if str(result_ndarray.dtype) in utils.PERMITTED_NUMERIC_TYPES:
result_mem_field = NumericMemField(self._session, str(result_ndarray.dtype))
result_mem_field.data.write(result_ndarray)

elif isinstance(self, (IndexedStringField, FixedStringField)) or isinstance(b, (IndexedStringField, FixedStringField)):
result_mem_field = IndexedStringMemField(self._session)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem right. Why are we causing an operation with fixed string field to output an indexed string field?
It doesn't make the logic much more complicated. Also, I would make that a separate method probably, because I can imagine us needing it elsewhere in the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For FixedStringField, you can refer to the matrix I listed above. Only when two FixedStringField will generate FixedStringField, otherwise it will be IndexedStringField.

result_mem_field.data.write(result_ndarray)
else:
raise NotImplementedError(f"instance method where doesn't support the current input type")

# if inplace:
# self.data.clear()
# self.data.write(result)

return result_mem_field


class MemoryField(Field):

Expand Down
137 changes: 137 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2169,6 +2169,143 @@ def test_indexed_string_isin(self, data, isin_data, expected):
np.testing.assert_array_equal(expected, result)


WHERE_NUMERIC_TESTS = [
(lambda f: f > 5, "create_numeric", {"nformat": "int8"}, shuffle_randstate(list(range(-10,10))), None, None, 0, 'int8'),
(lambda f: f > 5, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), None, None, -1.0, 'float64'),
(lambda f: f > 5, "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), None, None, -1.0, 'float32'),
(lambda f: f > 5, "create_numeric", {"nformat": "int32"}, shuffle_randstate(list(range(-10,10))), None, None, shuffle_randstate(list(range(0,20))), 'int64'),
(lambda f: f > 5, "create_numeric", {"nformat": "int32"}, shuffle_randstate(list(range(-10,10))), None, None, np.array(shuffle_randstate(list(range(0,20))), dtype='int32'), 'int32'),
(lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), None, None, np.array(shuffle_randstate(list(range(-20,0))), dtype='float32'), 'float32'),
(lambda f: f > 5, "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), "create_categorical", {"nformat": "int8", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), 'float32'),
(lambda f: f > 5, "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), 'float64'),
(lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), 'float32'),
(lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), 'float32'),
(lambda f: f > 5, "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))),"create_numeric", {"nformat": "float64"}, shuffle_randstate(list(range(-10,10))), 'float64'),
(RAND_STATE.randint(0, 2, 20).tolist(), "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(),"create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), 'int32'),

]


WHERE_INDEXED_STRING_TESTS = [
(lambda f: f > 5, ['a', 'b', 'c'], [1,2,3]),
]

def where_oracle(cond, a, b):
if callable(cond):
if isinstance(a, fields.Field):
cond = cond(a.data[:])
elif isinstance(a, list):
cond = cond(np.array(a))
elif isinstance(a, np.ndarray):
cond = cond(a)
return np.where(cond, a, b)


class TestFieldWhereFunctions(SessionTestCase):

@parameterized.expand(WHERE_NUMERIC_TESTS)
def test_module_fields_where(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype):
"""
Test `where` for the numeric fields using `fields.where` function and the object's method.
"""
a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data)
expected_result = where_oracle(cond, a_field_data, b_data)

if b_kwarg is None:
with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_data)}"):
if callable(cond):
with self.assertRaises(NotImplementedError) as context:
result = fields.where(cond, a_field, b_data)
self.assertEqual(str(context.exception), "module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.")
else:
result = fields.where(cond, a_field, b_data)
np.testing.assert_array_equal(expected_result, result)

else:
b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_data)
with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"):
if callable(cond):
with self.assertRaises(NotImplementedError) as context:
result = fields.where(cond, a_field, b_field)
self.assertEqual(str(context.exception), "module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.")
else:
result = fields.where(cond, a_field, b_field)
np.testing.assert_array_equal(expected_result, result)


@parameterized.expand(WHERE_NUMERIC_TESTS)
def test_instance_field_where_return_numericmemfield(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype):
a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data)

expected_result = where_oracle(cond, a_field_data, b_data)

if b_kwarg is None:
with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_data)}"):
result = a_field.where(cond, b_data)
self.assertEqual(result._nformat, expected_dtype)
np.testing.assert_array_equal(result, expected_result)

else:
b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_data)

with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"):
result = a_field.where(cond, b_field)
self.assertIsInstance(result, fields.NumericMemField)
self.assertEqual(result._nformat, expected_dtype)
np.testing.assert_array_equal(result, expected_result)


@parameterized.expand(WHERE_INDEXED_STRING_TESTS)
def test_instance_field_where_return_numericmemfield(self, cond, a, b):
pass


# def test_field_where_fixed_string(self):
# def create_fixed_string(df, name):
# f = df.create_fixed_string(name, 6)
# f.data.write(np.asarray(['foo', '"foo"', '', 'bar', 'barn', 'bat'], dtype='S6'))
# return f

# self._test_module_where(create_fixed_string, lambda f: np.char.str_len(f.data[:]) > 3,
# 'boo', '_far',
# ['_far', 'boo', '_far', '_far', 'boo', '_far'])
# # [b'_far', b'boo', b'_far', b'_far', b'boo', b'_far'])

# self._test_instance_where(create_fixed_string, lambda f: np.char.str_len(f.data[:]) > 3,
# 'foobar',
# [b'foobar', b'"foo"', b'foobar', b'foobar', b'barn', b'foobar'])


# def test_field_where_indexed_string(self):
# def create_indexed_string(df, name):
# f = df.create_indexed_string(name)
# f.data.write(['foo', '"foo"', '', 'bar', 'barn', 'bat'])
# return f

# self._test_module_where(create_indexed_string, lambda f: np.char.str_len(f.data[:]) > 3,
# 'boo', '_far', ['_far', 'boo', '_far', '_far', 'boo', '_far'])

# self._test_module_where(create_indexed_string, lambda f: (f.indices[1:] - f.indices[:-1]) > 3,
# 'boo', '_far', ['_far', 'boo', '_far', '_far', 'boo', '_far'])

#
# def test_instance_where_numeric_inplace(self):
# input_data = [1,2,3,5,9,8,6,4,7,0]
# data = np.asarray(input_data, dtype=np.int32)
# bio = BytesIO()
# with session.Session() as s:
# src = s.open_dataset(bio, 'w', 'src')
# df = src.create_dataframe('df')
# f = df.create_numeric('foo', 'int32')
# f.data.write(data)
#
# r = f.where(f > 5, 0)
# self.assertEqual(list(f.data[:]), [1,2,3,5,9,8,6,4,7,0])
# r = f.where(f > 5, 0, inplace=True)
# self.assertEqual(list(f.data[:]), [0,0,0,0,9,8,6,0,7,0])
#


class TestFieldModuleFunctions(SessionTestCase):

@parameterized.expand(DEFAULT_FIELD_DATA)
Expand Down