Skip to content

Commit de0686e

Browse files
committed
FEAT: added support for pandas index columns in pyarrow.Parquet and made cache system more generic
1 parent dc70860 commit de0686e

File tree

1 file changed

+121
-49
lines changed

1 file changed

+121
-49
lines changed

larray_editor/arrayadapter.py

Lines changed: 121 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,10 @@ def __init__(self, data, attributes=None):
352352
self.vmax = None
353353
self._number_format = "%s"
354354
self.sort_key = None # (kind='axis'|'column'|'row', idx_of_kind, direction (1, -1))
355+
# caching support
356+
self._cached_fragment = None
357+
self._cached_fragment_v_start = None
358+
self._cached_fragment_h_start = None
355359

356360
# ================================ #
357361
# methods which MUST be overridden #
@@ -402,6 +406,39 @@ def close(self):
402406
"""Close the resources used by the adapter"""
403407
pass
404408

409+
def _is_chunk_cached(self, h_start, v_start, h_stop, v_stop):
    """Tell whether the requested cell range lies within the cached fragment.

    Returns False when nothing is cached. Otherwise returns a true value
    only if the [h_start, h_stop) x [v_start, v_stop) rectangle is entirely
    contained in the rectangle covered by the cached fragment.
    """
    fragment = self._cached_fragment
    if fragment is None:
        return False
    left = self._cached_fragment_h_start
    top = self._cached_fragment_v_start
    # fragment.shape is read as (height, width)
    width = fragment.shape[1]
    height = fragment.shape[0]
    right = left + width
    bottom = top + height
    return (left <= h_start and
            right >= h_stop and
            top <= v_start and
            bottom >= v_stop)
422+
def _get_fragment_via_cache(self, h_start, v_start, h_stop, v_stop):
    """Return the fragment covering the requested range, using the cache.

    On a cache hit the cached fragment and its offsets are returned
    unchanged. On a miss the fragment is obtained from
    _get_fragment_from_source() and replaces whatever was cached before.

    Returns a (fragment, fragment_h_start, fragment_v_start) tuple.
    """
    clsname = self.__class__.__name__
    logger.debug(f"{clsname}._get_fragment_via_cache({h_start, v_start, h_stop, v_stop})")
    if not self._is_chunk_cached(h_start, v_start, h_stop, v_stop):
        loaded = self._get_fragment_from_source(h_start, v_start,
                                                h_stop, v_stop)
        # the local names below are part of the debug message (via the
        # "=" f-string specifier), so they must not be renamed
        fragment, fragment_h_start, fragment_v_start = loaded
        logger.debug(" -> cache miss ! "
                     f"({fragment_h_start=} {fragment_v_start=})")
        # only one fragment is kept: the new one replaces the previous one
        self._cached_fragment = fragment
        self._cached_fragment_h_start = fragment_h_start
        self._cached_fragment_v_start = fragment_v_start
        return fragment, fragment_h_start, fragment_v_start

    fragment = self._cached_fragment
    fragment_h_start = self._cached_fragment_h_start
    fragment_v_start = self._cached_fragment_v_start
    logger.debug(" -> cache hit ! "
                 f"({fragment_h_start=} {fragment_v_start=})")
    return fragment, fragment_h_start, fragment_v_start
441+
405442
# TODO: factorize with LArrayArrayAdapter (so that we get the attributes
406443
# handling of LArrayArrayAdapter for all types and the larray adapter
407444
# can benefit from the generic code here
@@ -2105,70 +2142,105 @@ def get_values(self, h_start, v_start, h_stop, v_stop):
21052142
@adapter_for('pyarrow.parquet.ParquetFile')
21062143
class PyArrowParquetFileAdapter(AbstractColumnarAdapter):
21072144
def __init__(self, data, attributes):
    """Initialize the adapter for a pyarrow.parquet.ParquetFile.

    Parameters
    ----------
    data : pyarrow.parquet.ParquetFile
        File to display.
    attributes
        Passed as is to the base class constructor.

    Notes
    -----
    If the file carries pandas metadata declaring index columns, and those
    columns are all named, present in the schema, located at the end of the
    schema and in increasing order, they are recorded in
    ``self._pandas_idx_cols`` and excluded from ``self._num_cols``.
    In every other case, all columns are treated as ordinary columns.
    """
    import json

    super().__init__(data=data, attributes=attributes)
    self._schema = data.schema
    meta = data.metadata
    self._num_cols = meta.num_columns
    self._num_rows = meta.num_rows
    # the key-value metadata may be None (instead of an empty dict) when
    # the file does not contain any
    meta_meta = meta.metadata or {}
    col_names = self._schema.names
    self._col_names = col_names
    self._pandas_idx_cols = []
    if b'pandas' in meta_meta:
        pd_meta = json.loads(meta_meta[b'pandas'])

        idx_col_names = pd_meta.get('index_columns', [])
        # only index columns given by name and actually present in the
        # schema are supported; anything else is ignored instead of
        # failing to open the file
        if all(isinstance(col_name, str) and col_name in col_names
               for col_name in idx_col_names):
            idx_col_indices = [col_names.index(col_name)
                               for col_name in idx_col_names]
            # We only support the case where index columns are at the end
            # and are sorted. It is the case in all files I have seen so
            # far but I don't know whether it is always the case
            expected_first_idx_col = self._num_cols - len(idx_col_indices)
            idx_cols_at_the_end = all(idx >= expected_first_idx_col
                                      for idx in idx_col_indices)
            idx_cols_sorted = sorted(idx_col_indices) == idx_col_indices
            if idx_cols_at_the_end and idx_cols_sorted:
                self._pandas_idx_cols = idx_col_indices
                self._num_cols -= len(idx_col_indices)

    # explicit integer dtype so that a file without any row group still
    # produces an (empty) integer array
    num_rows_per_group = np.array([meta.row_group(i).num_rows
                                   for i in range(data.num_row_groups)],
                                  dtype=np.int64)
    # _group_ends[i] is the (exclusive) end row of row group i
    self._group_ends = num_rows_per_group.cumsum()
    # a file without row groups has no last element to look at
    total_rows = self._group_ends[-1] if len(self._group_ends) else 0
    assert total_rows == meta.num_rows
21242178

21252179
def shape2d(self):
    """Return (number of rows, number of columns) as stored by __init__."""
    shape = (self._num_rows, self._num_cols)
    return shape
21282181

21292182
def get_hlabels_values(self, start, stop):
    """Return the column names in [start, stop) as a single row of labels."""
    names = self._col_names
    label_row = names[start:stop]
    return [label_row]
21312184

2132-
# TODO: provide caching in a base class
2133-
def _is_chunk_cached(self, h_start, v_start, h_stop, v_stop):
2134-
cached_table = self._cached_table
2135-
if cached_table is None:
2136-
return False
2137-
cached_v_start = self._cached_table_v_start
2138-
cached_h_start = self._cached_table_h_start
2139-
return (h_start >= cached_h_start and
2140-
h_stop <= cached_h_start + cached_table.shape[1] and
2141-
v_start >= cached_v_start and
2142-
v_stop <= cached_v_start + cached_table.shape[0])
2185+
def get_vnames(self):
    """Return the names of the vertical label columns.

    These are the names of the pandas index columns when some were
    recorded, otherwise a single empty name.
    """
    idx_cols = self._pandas_idx_cols
    if not idx_cols:
        return ['']
    names = self._col_names
    return [names[col_idx] for col_idx in idx_cols]
21432190

2144-
def get_values(self, h_start, v_start, h_stop, v_stop):
2145-
if self._is_chunk_cached(h_start, v_start, h_stop, v_stop):
2146-
logger.debug("cache hit !")
2147-
table = self._cached_table
2148-
table_h_start = self._cached_table_h_start
2149-
table_v_start = self._cached_table_v_start
2191+
def get_vlabels_values(self, start, stop):
    """Return the vertical labels for rows in [start, stop).

    Without pandas index columns, each row is labelled by its own position.
    With index columns, the labels are the values of those columns, as
    returned by get_values().
    """
    idx_cols = self._pandas_idx_cols
    if not idx_cols:
        return [[row] for row in range(start, stop)]

    # This assumes that index columns are contiguous (which is
    # implicitly tested in __init__ via the "all index at the end" test)
    # and sorted (tested in __init__)
    first_idx_col = idx_cols[0]
    last_idx_col = idx_cols[-1]
    return self.get_values(first_idx_col, start, last_idx_col + 1, stop)
2201+
2202+
def get_values(self, h_start, v_start, h_stop, v_stop):
    """Return the requested cell range converted by _chunk_to_numpy().

    The data goes through three steps: fetch the enclosing fragment
    (possibly from the cache), slice the requested chunk out of it, then
    convert that chunk.
    """
    # fragment is a pyarrow.Table
    cached = self._get_fragment_via_cache(h_start, v_start, h_stop, v_stop)
    fragment, fragment_h_start, fragment_v_start = cached

    # chunk is a list of pyarrow.ChunkedArray
    chunk = self._fragment_to_chunk(
        fragment, fragment_h_start, fragment_v_start,
        h_start, v_start, h_stop, v_stop,
    )
    return self._chunk_to_numpy(chunk)
2212+
2213+
def _get_fragment_from_source(self, h_start, v_start, h_stop, v_stop):
    """Read from the file the row groups and columns covering the request.

    Returns a (table, table_h_start, table_v_start) tuple where table is
    the result of reading whole row groups restricted to the requested
    columns, table_h_start is h_start and table_v_start is the first row
    of the first row group read.
    """
    group_ends = self._group_ends
    # v_stop - 1 because the last row is not included
    bounds = np.searchsorted(group_ends, [v_start, v_stop - 1], side='right')
    first_group, last_group = bounds
    # - 1 because _group_ends stores row group ends and we want the start
    if first_group > 0:
        table_v_start = group_ends[first_group - 1]
    else:
        table_v_start = 0
    column_names = self._schema.names[h_start:h_stop]
    table = self.data.read_row_groups(range(first_group, last_group + 1),
                                      columns=column_names)
    return table, h_start, table_v_start
2227+
2228+
# fragment is a native object representing the smallest buffer which can
2229+
# hold the requested rows and columns (not sliced in memory)
2230+
# chunk is the actual requested slice from that fragment, still in
2231+
# whatever format is most convenient for the adapter
2232+
def _fragment_to_chunk(self, fragment, fragment_h_start, fragment_v_start,
                       h_start, v_start, h_stop, v_stop):
    """Slice the requested rows and columns out of a fragment.

    The requested bounds are absolute, so they are first translated into
    positions relative to the fragment's own start offsets. Returns the
    selected columns of the row-sliced fragment.
    """
    row_slice = slice(v_start - fragment_v_start, v_stop - fragment_v_start)
    col_slice = slice(h_start - fragment_h_start, h_stop - fragment_h_start)
    # not going via to_pandas() because it "eats" index columns
    return fragment[row_slice].columns[col_slice]
2240+
2241+
def _chunk_to_numpy(self, chunk):
2242+
# chunk is a list of pyarrow.ChunkedArray
2243+
np_columns = [c.to_numpy() for c in chunk]
21722244
try:
21732245
return np.stack(np_columns, axis=1)
21742246
except np.exceptions.DTypePromotionError:

0 commit comments

Comments
 (0)