-
Notifications
You must be signed in to change notification settings - Fork 4
implement where api #298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
implement where api #298
Changes from 12 commits
aa7301a
834b5a9
fd79955
a564305
fa62e3b
dea5b9c
bd51519
6da082d
0f83acd
213ddac
127511d
c0746fd
78c7b95
86fdeed
0d51c11
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,17 +9,19 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Callable, Optional, Union | ||
| from typing import Callable, ItemsView, Optional, Union | ||
| from datetime import datetime, timezone | ||
| import operator | ||
|
|
||
| import numpy as np | ||
| import h5py | ||
| import re | ||
|
|
||
| from exetera.core.abstract_types import Field | ||
| from exetera.core.data_writer import DataWriter | ||
| from exetera.core import operations as ops | ||
| from exetera.core import validation as val | ||
| from exetera.core import utils | ||
|
|
||
|
|
||
| def isin(field:Field, test_elements:Union[list, set, np.ndarray]): | ||
|
|
@@ -39,6 +41,97 @@ def isin(field:Field, test_elements:Union[list, set, np.ndarray]): | |
| return ret | ||
|
|
||
|
|
||
def where(cond: Union[list, tuple, np.ndarray, Field], a, b):
    """
    Module-level ``where``: choose between ``a`` and ``b`` according to ``cond``.

    :param cond: The condition. A list, tuple or ndarray is used as given. A
        numeric or categorical field is replaced by its ``data[:]`` contents.
        Any other non-callable value is passed through to ``where_helper``
        unchanged.
    :param a: Source of the result where ``cond`` is true.
    :param b: Source of the result where ``cond`` is false.
    :return: Whatever ``where_helper(cond, a, b)`` returns.
    :raises NotImplementedError: if ``cond`` is a field of an unsupported type,
        or if ``cond`` is callable (callables are only handled by the instance
        method ``where``, which has a field to apply them to).
    """
    if isinstance(cond, Field):
        # Accept the in-memory numeric field as well as the HDF5-backed one;
        # previously only the HDF5-backed types passed this check.
        # NOTE(review): if an in-memory categorical field type exists in this
        # module it should be added to this tuple too - confirm.
        if not isinstance(cond, (NumericField, NumericMemField, CategoricalField)):
            raise NotImplementedError("Where only support condition on numeric field and categorical field at present.")
        cond = cond.data[:]
    elif not isinstance(cond, (list, tuple, np.ndarray)) and callable(cond):
        raise NotImplementedError("module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.")

    return where_helper(cond, a, b)
|
|
||
|
|
||
def _string_dtype_length(dtype):
    """
    Return the element length encoded in a NumPy string dtype.

    The length is parsed from ``str(dtype)``, which is expected to contain
    ``<U<n>``, ``>U<n>`` or ``S<n>``.

    :raises ValueError: if the dtype is not a unicode/bytes string dtype.
    """
    matches = re.findall(r"<U(\d+)|>U(\d+)|S(\d+)", str(dtype))
    if not matches:
        # Backslashes are escaped so the emitted text is unchanged while the
        # literal no longer contains an invalid escape sequence.
        raise ValueError("The return dtype of instance method `where` doesn't match '<U(\\d+)' or 'S(\\d+)' when one of the field is FixedStringField")
    # Only one alternative of the pattern can match, so exactly one group of
    # the first match is non-empty.
    return next((int(group) for group in matches[0] if group), 0)


def where_helper(cond:Union[list, tuple, np.ndarray, Field], a, b) -> Field:
    """
    Build an in-memory field holding ``a`` where ``cond`` is true and ``b``
    elsewhere.

    If either ``a`` or ``b`` is an indexed string field, both are expressed as
    (indices, values) pairs and merged with
    ``ops.where_for_two_indexed_string_fields``; the result is an
    ``IndexedStringMemField``. Otherwise ``np.where`` is applied to the field
    data and the result is a ``FixedStringMemField`` (if either operand is a
    fixed string field) or a ``NumericMemField`` (if the result dtype is listed
    in ``utils.PERMITTED_NUMERIC_TYPES``).

    :param cond: Array-like condition, already resolved to plain data.
    :param a: Field used where ``cond`` is true; its ``_session`` is used to
        construct the result field.
    :param b: Field, or (on the non-indexed path) raw data, used where ``cond``
        is false.
    :return: The newly created in-memory field.
    :raises ValueError: if the indexed-string path receives operands whose row
        counts differ from ``len(cond)``, or a string result dtype cannot be
        parsed.
    :raises NotImplementedError: if the combination of input types is not
        supported.
    """
    indexed_types = (IndexedStringField, IndexedStringMemField)
    fixed_types = (FixedStringField, FixedStringMemField)

    def get_indices_and_values_from_non_indexed_string_field(other_field):
        # Read the data once, then convert it to a string array by letting
        # np.where promote it against an array of empty strings.
        other_data = other_field.data[:]
        row_count = len(other_data)
        data_as_str = np.where([True] * row_count, other_data, [""] * row_count)
        max_length = _string_dtype_length(data_as_str.dtype)

        # Convert the string array to the (indices, values) representation.
        indices = np.zeros(row_count + 1, dtype=np.int64)
        values = np.zeros(np.int64(row_count * max_length), dtype=np.uint8)
        for i, s in enumerate(data_as_str):
            encoded_s = np.array(list(s), dtype='S1').view(np.uint8)
            indices[i + 1] = indices[i] + len(encoded_s)
            values[indices[i]:indices[i + 1]] = encoded_s
        return indices, values

    def get_indices_and_values_from_all_field(f):
        if isinstance(f, indexed_types):
            return f.indices[:], f.values[:]
        return get_indices_and_values_from_non_indexed_string_field(f)

    if isinstance(a, indexed_types) or isinstance(b, indexed_types):
        a_indices, a_values = get_indices_and_values_from_all_field(a)
        b_indices, b_values = get_indices_and_values_from_all_field(b)

        if len(cond) != len(a_indices) - 1 or len(cond) != len(b_indices) - 1:
            raise ValueError(f"operands can't work with shapes ({len(cond)},) ({len(a_indices) - 1},) ({len(b_indices) - 1},)")

        # The result may take long entries from both operands, so it can need
        # more bytes than either operand alone (e.g. a=["aaaa", ""],
        # b=["", "bbbb"], cond=[True, False]). The sum of both value buffers is
        # a safe upper bound; the buffer is trimmed to its used length below.
        r_indices = np.zeros(len(a_indices), dtype=np.int64)
        r_values = np.zeros(len(a_values) + len(b_values), dtype=np.uint8)
        ops.where_for_two_indexed_string_fields(np.array(cond), a_indices, a_values, b_indices, b_values, r_indices, r_values)
        r_values = r_values[:r_indices[-1]]

        result_mem_field = IndexedStringMemField(a._session)
        result_mem_field.indices.write(r_indices)
        result_mem_field.values.write(r_values)
        return result_mem_field

    b_data = b.data[:] if isinstance(b, Field) else b
    r_ndarray = np.where(cond, a.data[:], b_data)

    if isinstance(a, fixed_types) or isinstance(b, fixed_types):
        result_mem_field = FixedStringMemField(a._session, _string_dtype_length(r_ndarray.dtype))
        result_mem_field.data.write(r_ndarray)
        return result_mem_field

    if str(r_ndarray.dtype) in utils.PERMITTED_NUMERIC_TYPES:
        result_mem_field = NumericMemField(a._session, str(r_ndarray.dtype))
        result_mem_field.data.write(r_ndarray)
        return result_mem_field

    raise NotImplementedError(f"instance method `where` doesn't support the current input type: {type(a)} and {type(b)}")
|
|
||
|
|
||
| class HDF5Field(Field): | ||
| def __init__(self, session, group, dataframe, write_enabled=False): | ||
| super().__init__() | ||
|
|
@@ -143,6 +236,21 @@ def _ensure_valid(self): | |
| if not self._valid_reference: | ||
| raise ValueError("This field no longer refers to a valid underlying field object") | ||
|
|
||
def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace=False):
    """
    Choose between this field and ``b`` according to ``cond``.

    :param cond: The condition. A list, tuple or ndarray is used as given. A
        numeric or categorical field is replaced by its ``data[:]`` contents.
        A callable is invoked with this field's ``data[:]`` and its return
        value is used as the condition.
    :param b: Source of the result where the condition is false.
    :param inplace: Accepted for signature compatibility; this method does not
        read it.
    :return: Whatever ``where_helper(cond, self, b)`` returns.
    :raises NotImplementedError: if ``cond`` is a field of an unsupported type.
    :raises TypeError: if ``cond`` is neither array-like, a field, nor callable.
    """
    if isinstance(cond, Field):
        # Accept the in-memory numeric field as well as the HDF5-backed one, so
        # the check agrees with the TypeError message below, which already
        # lists NumericMemField as a valid condition.
        # NOTE(review): if an in-memory categorical field type exists in this
        # module it should be added to this tuple too - confirm.
        if not isinstance(cond, (NumericField, NumericMemField, CategoricalField)):
            raise NotImplementedError("Where only support condition on numeric field and categorical field at present.")
        cond = cond.data[:]
    elif isinstance(cond, (list, tuple, np.ndarray)):
        pass  # already array-like; used as given
    elif callable(cond):
        cond = cond(self.data[:])
    else:
        raise TypeError("'cond' parameter needs to be either callable lambda function, or array like, or NumericMemField.")

    return where_helper(cond, self, b)
|
|
||
|
|
||
| class MemoryField(Field): | ||
|
|
||
|
|
@@ -223,6 +331,22 @@ def apply_index(self, index_to_apply, dstfld=None): | |
| raise NotImplementedError("Please use apply_index() on specific fields, not the field base class.") | ||
|
|
||
|
|
||
def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace=False):
    """
    Choose between this in-memory field and ``b`` according to ``cond``.

    :param cond: The condition. A list, tuple or ndarray is used as given. A
        numeric or categorical field is replaced by its ``data[:]`` contents.
        A callable is invoked with this field's ``data[:]`` and its return
        value is used as the condition.
    :param b: Source of the result where the condition is false.
    :param inplace: Accepted for signature compatibility; this method does not
        read it.
    :return: Whatever ``where_helper(cond, self, b)`` returns.
    :raises NotImplementedError: if ``cond`` is a field of an unsupported type.
    :raises TypeError: if ``cond`` is neither array-like, a field, nor callable.
    """
    if isinstance(cond, Field):
        # Accept the in-memory numeric field as well as the HDF5-backed one, so
        # the check agrees with the TypeError message below, which already
        # lists NumericMemField as a valid condition.
        # NOTE(review): if an in-memory categorical field type exists in this
        # module it should be added to this tuple too - confirm.
        if not isinstance(cond, (NumericField, NumericMemField, CategoricalField)):
            raise NotImplementedError("Where only support condition on numeric field and categorical field at present.")
        cond = cond.data[:]
    elif isinstance(cond, (list, tuple, np.ndarray)):
        pass  # already array-like; used as given
    elif callable(cond):
        cond = cond(self.data[:])
    else:
        raise TypeError("'cond' parameter needs to be either callable lambda function, or array like, or NumericMemField.")

    return where_helper(cond, self, b)
|
|
||
|
|
||
| class ReadOnlyFieldArray: | ||
| def __init__(self, field, dataset_name): | ||
| self._field = field | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This still doesn't check for both the HDF5 and the in-memory field types.