Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions exetera/core/abstract_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def isin(self, test_elements:Union[list, set, np.ndarray]):
def unique(self, return_index=False, return_inverse=False, return_counts=False):
    """
    Abstract placeholder for a ``unique`` operation on a field.

    This base implementation performs no work; it always raises so that
    concrete field types must supply their own version.

    :param return_index: not used by this placeholder
    :param return_inverse: not used by this placeholder
    :param return_counts: not used by this placeholder
    :raises NotImplementedError: always
    """
    raise NotImplementedError()

@staticmethod
def where(cond, a, b):
raise NotImplementedError()


class Dataset(ABC):
"""
Expand Down
126 changes: 125 additions & 1 deletion exetera/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Union
from typing import Callable, ItemsView, Optional, Union
from datetime import datetime, timezone
import operator

import numpy as np
import h5py
import re

from exetera.core.abstract_types import Field
from exetera.core.data_writer import DataWriter
from exetera.core import operations as ops
from exetera.core import validation as val
from exetera.core import utils


def isin(field:Field, test_elements:Union[list, set, np.ndarray]):
Expand All @@ -39,6 +41,97 @@ def isin(field:Field, test_elements:Union[list, set, np.ndarray]):
return ret


def where(cond: Union[list, tuple, np.ndarray, Field], a, b):
    """
    Module-level ``where``: normalise ``cond`` and delegate to ``where_helper``.

    :param cond: a list, tuple or numpy ndarray (used as-is), or a numeric /
        categorical field (its ``data[:]`` is used). Any other non-field,
        non-callable value is passed through to ``where_helper`` unchanged.
    :param a: forwarded to ``where_helper`` as the source for truthy rows
    :param b: forwarded to ``where_helper`` as the source for falsy rows
    :return: whatever ``where_helper(cond, a, b)`` returns
    :raises NotImplementedError: if ``cond`` is a field of an unsupported type,
        or if ``cond`` is callable (callables are only handled by the instance
        method ``where``)
    """
    # Plain sequences and ndarrays need no conversion.
    if not isinstance(cond, (list, tuple, np.ndarray)):
        if isinstance(cond, Field):
            field_types_usable_as_cond = (NumericField, NumericMemField, CategoricalField, CategoricalMemField)
            if not isinstance(cond, field_types_usable_as_cond):
                raise NotImplementedError("where only supports python sequences, numpy ndarrays, and numeric field and categorical field types for the cond parameter at present.")
            # Materialise the field contents so the helper sees an array.
            cond = cond.data[:]
        elif callable(cond):
            raise NotImplementedError("module method fields.where doesn't support callable cond parameter, please use the instance method where if you need to use a callable cond parameter.")

    return where_helper(cond, a, b)


def where_helper(cond:Union[list, tuple, np.ndarray, Field], a, b) -> Field:
    """
    Build a new in-memory field by taking, row by row, the entry of ``a`` where
    ``cond`` is truthy and the entry of ``b`` otherwise.

    Two paths exist:

    * If either ``a`` or ``b`` is an indexed string field, both operands are
      expressed as (indices, values) byte arrays and merged by
      ``ops.where_for_two_indexed_string_fields``; the result is an
      ``IndexedStringMemField``.
    * Otherwise ``np.where`` is applied to ``a.data[:]`` and ``b`` (or
      ``b.data[:]`` when ``b`` is a field); the result is a
      ``FixedStringMemField`` if either operand is a fixed string field, or a
      ``NumericMemField`` if the resulting dtype is listed in
      ``utils.PERMITTED_NUMERIC_TYPES``.

    :param cond: row selector; must support ``len()`` on the indexed string path
    :param a: field supplying the truthy rows; its ``_session`` is used to
        construct the result and its ``data`` is read on the non-indexed path
    :param b: field supplying the falsy rows; on the non-indexed path a
        non-field value is handed to ``np.where`` directly
    :return: the newly created memory field
    :raises ValueError: if ``cond``, ``a`` and ``b`` disagree on row count on the
        indexed string path, or if a string length cannot be parsed from a dtype
    :raises NotImplementedError: if the combination of operand types is not
        handled
    """

    def fixed_string_length(dtype):
        # Extract the declared string length from a numpy dtype such as '<U5' or '|S5'.
        matches = re.findall(r"<U(\d+)|>U(\d+)|S(\d+)", str(dtype))
        if not matches:
            # '\\d' keeps the emitted message identical while avoiding an
            # invalid escape sequence in a non-raw string literal.
            raise ValueError("The return dtype of instance method where doesn't match '<U(\\d+)' or 'S(\\d+)' when one of the fields is a fixed string field")
        length = 0
        # Only one alternative of the pattern matches, so at most one group is non-empty.
        for group in matches[0]:
            if group:
                length = int(group)
        return length

    def indices_and_values_from_non_indexed_field(other_field):
        # Convert the field's data to a numpy unicode array by promoting it
        # against an empty-string array; typed arrays (rather than Python
        # lists) keep the promotion well defined even for zero rows.
        field_data = other_field.data[:]
        row_count = len(field_data)
        data_as_str = np.where(np.ones(row_count, dtype=bool), field_data, np.full(row_count, "", dtype="U1"))
        max_length = fixed_string_length(data_as_str.dtype)

        # Re-express the strings in indexed form: one offset per row boundary
        # plus a flat byte buffer (sized for the worst case, filled up to indices[-1]).
        indices = np.zeros(row_count + 1, dtype=np.int64)
        values = np.zeros(np.int64(row_count*max_length), dtype=np.uint8)
        for i, s in enumerate(data_as_str):
            # NOTE(review): one byte per character via 'S1' - presumably the
            # data is ASCII only; confirm before using with other text.
            encoded_s = np.array(list(s), dtype='S1').view(np.uint8)
            indices[i + 1] = indices[i] + len(encoded_s)
            values[indices[i]:indices[i + 1]] = encoded_s
        return indices, values

    def indices_and_values(f):
        # Indexed string fields already carry the representation we need.
        if isinstance(f, (IndexedStringField, IndexedStringMemField)):
            return f.indices[:], f.values[:]
        return indices_and_values_from_non_indexed_field(f)

    indexed_string_types = (IndexedStringField, IndexedStringMemField)
    fixed_string_types = (FixedStringField, FixedStringMemField)

    if isinstance(a, indexed_string_types) or isinstance(b, indexed_string_types):
        a_indices, a_values = indices_and_values(a)
        b_indices, b_values = indices_and_values(b)

        if len(cond) != len(a_indices) - 1 or len(cond) != len(b_indices) - 1:
            raise ValueError(f"operands can't work with shapes ({len(cond)},) ({len(a_indices) - 1},) ({len(b_indices) - 1},)")

        cond_array = np.array(cond)

        # Size the output from the lengths of the rows that will actually be
        # copied. Sizing by max(len(a_values), len(b_values)) is too small when
        # the selection alternates between long entries of a and long entries
        # of b (e.g. a=["aaa", ""], b=["", "bbb"], cond=[True, False]).
        selected_lengths = np.where(cond_array, np.diff(a_indices), np.diff(b_indices))
        r_indices = np.zeros(len(a_indices), dtype=np.int64)
        r_values = np.zeros(int(selected_lengths.sum()), dtype=np.uint8)
        ops.where_for_two_indexed_string_fields(cond_array, a_indices, a_values, b_indices, b_values, r_indices, r_values)

        result_mem_field = IndexedStringMemField(a._session)
        result_mem_field.indices.write(r_indices)
        result_mem_field.values.write(r_values)
        return result_mem_field

    b_data = b.data[:] if isinstance(b, Field) else b
    r_ndarray = np.where(cond, a.data[:], b_data)

    if isinstance(a, fixed_string_types) or isinstance(b, fixed_string_types):
        result_mem_field = FixedStringMemField(a._session, fixed_string_length(r_ndarray.dtype))
        result_mem_field.data.write(r_ndarray)
        return result_mem_field

    if str(r_ndarray.dtype) in utils.PERMITTED_NUMERIC_TYPES:
        result_mem_field = NumericMemField(a._session, str(r_ndarray.dtype))
        result_mem_field.data.write(r_ndarray)
        return result_mem_field

    raise NotImplementedError(f"instance method `where` doesn't support the current input type: {type(a)} and {type(b)}")


class HDF5Field(Field):
def __init__(self, session, group, dataframe, write_enabled=False):
super().__init__()
Expand Down Expand Up @@ -143,6 +236,21 @@ def _ensure_valid(self):
if not self._valid_reference:
raise ValueError("This field no longer refers to a valid underlying field object")

def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace=False):
    """
    Select between this field and ``b`` row by row, delegating to ``where_helper``.

    :param cond: a list, tuple or numpy ndarray (used as-is), a numeric or
        categorical field (its ``data[:]`` is used), or a callable that is
        applied to ``self.data[:]`` to produce the condition
    :param b: forwarded to ``where_helper`` as the source for falsy rows
    :param inplace: accepted in the signature but not referenced by this method
    :return: whatever ``where_helper(cond, self, b)`` returns
    :raises NotImplementedError: if ``cond`` is a field of an unsupported type
    :raises TypeError: if ``cond`` is none of the supported kinds
    """
    # Plain sequences and ndarrays need no conversion.
    if not isinstance(cond, (list, tuple, np.ndarray)):
        if isinstance(cond, Field):
            field_types_usable_as_cond = (NumericField, NumericMemField, CategoricalField, CategoricalMemField)
            if not isinstance(cond, field_types_usable_as_cond):
                raise NotImplementedError("where only supports python sequences, numpy ndarrays, and numeric field and categorical field types for the cond parameter at present.")
            cond = cond.data[:]
        elif callable(cond):
            # Let the caller derive the condition from this field's own data.
            cond = cond(self.data[:])
        else:
            raise TypeError("where only supports callables, python sequences, numpy ndarrays, and numeric field and categorical field types for the cond parameter at present.")

    return where_helper(cond, self, b)


class MemoryField(Field):

Expand Down Expand Up @@ -223,6 +331,22 @@ def apply_index(self, index_to_apply, dstfld=None):
raise NotImplementedError("Please use apply_index() on specific fields, not the field base class.")


def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace=False):
    """
    Select between this field and ``b`` row by row, delegating to ``where_helper``.

    :param cond: a list, tuple or numpy ndarray (used as-is), a numeric or
        categorical field (its ``data[:]`` is used), or a callable that is
        applied to ``self.data[:]`` to produce the condition
    :param b: forwarded to ``where_helper`` as the source for falsy rows
    :param inplace: accepted in the signature but not referenced by this method
    :return: whatever ``where_helper(cond, self, b)`` returns
    :raises NotImplementedError: if ``cond`` is a field of an unsupported type
    :raises TypeError: if ``cond`` is none of the supported kinds
    """
    # Plain sequences and ndarrays need no conversion.
    if not isinstance(cond, (list, tuple, np.ndarray)):
        if isinstance(cond, Field):
            field_types_usable_as_cond = (NumericField, NumericMemField, CategoricalField, CategoricalMemField)
            if not isinstance(cond, field_types_usable_as_cond):
                raise NotImplementedError("where only supports python sequences, numpy ndarrays, and numeric field and categorical field types for the cond parameter at present.")
            cond = cond.data[:]
        elif callable(cond):
            # Let the caller derive the condition from this field's own data.
            cond = cond(self.data[:])
        else:
            raise TypeError("where only supports callables, python sequences, numpy ndarrays, and numeric field and categorical field types for the cond parameter at present.")

    return where_helper(cond, self, b)


class ReadOnlyFieldArray:
def __init__(self, field, dataset_name):
self._field = field
Expand Down
10 changes: 10 additions & 0 deletions exetera/core/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3111,3 +3111,13 @@ def compare_arrays(a, b):
return 1
return 0


@exetera_njit
def where_for_two_indexed_string_fields(cond, a_indices, a_values, b_indices, b_values, r_indices, r_values):
    # Merge two indexed-string representations row by row into the
    # caller-supplied output arrays: for each row, the entry of `a` is copied
    # when cond is truthy and the entry of `b` otherwise. r_indices[row + 1] is
    # written as the running end offset; r_values receives the copied bytes.
    # Nothing is returned - r_indices and r_values are filled in place.
    for row in range(len(cond)):
        nxt = row + 1
        if cond[row]:
            r_indices[nxt] = r_indices[row] + a_indices[nxt] - a_indices[row]
            r_values[r_indices[row]:r_indices[nxt]] = a_values[a_indices[row]:a_indices[nxt]]
        else:
            r_indices[nxt] = r_indices[row] + b_indices[nxt] - b_indices[row]
            r_values[r_indices[row]:r_indices[nxt]] = b_values[b_indices[row]:b_indices[nxt]]
Loading