@@ -352,6 +352,10 @@ def __init__(self, data, attributes=None):
         self.vmax = None
         self._number_format = "%s"
         self.sort_key = None  # (kind='axis'|'column'|'row', idx_of_kind, direction (1, -1))
+        # caching support
+        self._cached_fragment = None
+        self._cached_fragment_v_start = None
+        self._cached_fragment_h_start = None

         # ================================ #
         # methods which MUST be overridden #
@@ -402,6 +406,39 @@ def close(self):
         """Close the resources used by the adapter"""
         pass

+    def _is_chunk_cached(self, h_start, v_start, h_stop, v_stop):
+        cached_fragment = self._cached_fragment
+        if cached_fragment is None:
+            return False
+        cached_h_start = self._cached_fragment_h_start
+        cached_v_start = self._cached_fragment_v_start
+        cached_width = cached_fragment.shape[1]
+        cached_height = cached_fragment.shape[0]
+        # the requested rectangle must lie entirely within the cached one
+        return (h_start >= cached_h_start and
+                h_stop <= cached_h_start + cached_width and
+                v_start >= cached_v_start and
+                v_stop <= cached_v_start + cached_height)
+
+    def _get_fragment_via_cache(self, h_start, v_start, h_stop, v_stop):
+        clsname = self.__class__.__name__
+        logger.debug(f"{clsname}._get_fragment_via_cache"
+                     f"({h_start}, {v_start}, {h_stop}, {v_stop})")
+        if self._is_chunk_cached(h_start, v_start, h_stop, v_stop):
+            fragment = self._cached_fragment
+            fragment_h_start = self._cached_fragment_h_start
+            fragment_v_start = self._cached_fragment_v_start
+            logger.debug(" -> cache hit! "
+                         f"({fragment_h_start=} {fragment_v_start=})")
+        else:
+            fragment, fragment_h_start, fragment_v_start = (
+                self._get_fragment_from_source(h_start, v_start,
+                                               h_stop, v_stop))
+            logger.debug(" -> cache miss! "
+                         f"({fragment_h_start=} {fragment_v_start=})")
+            self._cached_fragment = fragment
+            self._cached_fragment_h_start = fragment_h_start
+            self._cached_fragment_v_start = fragment_v_start
+        return fragment, fragment_h_start, fragment_v_start
+
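Together, `_is_chunk_cached` and `_get_fragment_via_cache` form a small template-method pair: a subclass only implements `_get_fragment_from_source` (returning a buffer covering at least the requested rectangle, plus that buffer's origin) and gets the single-fragment cache for free. A minimal sketch of how a concrete adapter could plug in; the `ListOfRowsAdapter` class, its `_rows` attribute, and the `AbstractAdapter` base class name are hypothetical, for illustration only:

    import numpy as np

    class ListOfRowsAdapter(AbstractAdapter):  # hypothetical subclass
        def _get_fragment_from_source(self, h_start, v_start, h_stop, v_stop):
            # Read at least the requested rectangle. Returning a larger
            # fragment is allowed and makes later cache hits more likely;
            # here we over-read whole rows.
            fragment = np.asarray(self._rows[v_start:v_stop])
            return fragment, 0, v_start

        def get_values(self, h_start, v_start, h_stop, v_stop):
            fragment, f_h_start, f_v_start = self._get_fragment_via_cache(
                h_start, v_start, h_stop, v_stop)
            # slice the exact request out of the (possibly larger) fragment
            return fragment[v_start - f_v_start:v_stop - f_v_start,
                            h_start - f_h_start:h_stop - f_h_start]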
     # TODO: factorize with LArrayArrayAdapter (so that we get the attributes
     # handling of LArrayArrayAdapter for all types and the larray adapter
     # can benefit from the generic code here
@@ -2105,70 +2142,105 @@ def get_values(self, h_start, v_start, h_stop, v_stop):
 @adapter_for('pyarrow.parquet.ParquetFile')
 class PyArrowParquetFileAdapter(AbstractColumnarAdapter):
     def __init__(self, data, attributes):
+        import json
         super().__init__(data=data, attributes=attributes)
         self._schema = data.schema
-        # TODO: take pandas metadata index columns into account:
-        # - display those columns as labels
-        # - remove those columns from shape
-        # - do not read those columns in get_values
-        # pandas_metadata = data.schema.to_arrow_schema().pandas_metadata
-        # index_columns = pandas_metadata['index_columns']
+        meta = data.metadata
+        self._num_cols = meta.num_columns
+        self._num_rows = meta.num_rows
+        # the key-value file metadata; pyarrow returns None when it is absent
+        meta_meta = meta.metadata
+        col_names = self._schema.names
+        self._col_names = col_names
+        self._pandas_idx_cols = []
+        if meta_meta is not None and b'pandas' in meta_meta:
+            pd_meta = json.loads(meta_meta[b'pandas'])
+
+            idx_col_names = pd_meta['index_columns']
+            if all(isinstance(col_name, str) for col_name in idx_col_names):
+                idx_col_indices = [col_names.index(col_name)
+                                   for col_name in idx_col_names]
+                # We only support the case where index columns are at the end
+                # and are sorted. It is the case in all files I have seen so
+                # far, but I don't know whether it is always the case.
+                expected_first_idx_col = self._num_cols - len(idx_col_indices)
+                idx_cols_at_the_end = all(idx >= expected_first_idx_col
+                                          for idx in idx_col_indices)
+                idx_cols_sorted = sorted(idx_col_indices) == idx_col_indices
+                if idx_cols_at_the_end and idx_cols_sorted:
+                    self._pandas_idx_cols = idx_col_indices
+                    self._num_cols -= len(idx_col_indices)
+
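For reference, `meta_meta[b'pandas']` holds a JSON document written by pandas/pyarrow. A hypothetical, abridged payload for a DataFrame with one data column and one stored index column could look like the dict below. Note that a RangeIndex is stored as a dict rather than a column name, which is exactly what the `isinstance(col_name, str)` check above filters out:

    # illustrative only: abridged result of json.loads(meta_meta[b'pandas'])
    pd_meta = {
        'index_columns': ['__index_level_0__'],  # a RangeIndex would appear
                                                 # as {'kind': 'range', ...}
        'columns': [
            {'name': 'a', 'pandas_type': 'int64'},                  # data
            {'name': '__index_level_0__', 'pandas_type': 'int64'},  # index
        ],
        'pandas_version': '2.2.0',
    }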
         meta = data.metadata
         num_rows_per_group = np.array([meta.row_group(i).num_rows
                                        for i in range(data.num_row_groups)])
         self._group_ends = num_rows_per_group.cumsum()
         assert self._group_ends[-1] == meta.num_rows
-        self._cached_table = None
-        self._cached_table_h_start = None
-        self._cached_table_v_start = None

     def shape2d(self):
-        meta = self.data.metadata
-        return meta.num_rows, meta.num_columns
+        return self._num_rows, self._num_cols

     def get_hlabels_values(self, start, stop):
-        return [self._schema.names[start:stop]]
+        return [self._col_names[start:stop]]

-    # TODO: provide caching in a base class
-    def _is_chunk_cached(self, h_start, v_start, h_stop, v_stop):
-        cached_table = self._cached_table
-        if cached_table is None:
-            return False
-        cached_v_start = self._cached_table_v_start
-        cached_h_start = self._cached_table_h_start
-        return (h_start >= cached_h_start and
-                h_stop <= cached_h_start + cached_table.shape[1] and
-                v_start >= cached_v_start and
-                v_stop <= cached_v_start + cached_table.shape[0])
+    def get_vnames(self):
+        if self._pandas_idx_cols:
+            return [self._col_names[i] for i in self._pandas_idx_cols]
+        else:
+            return ['']

-    def get_values(self, h_start, v_start, h_stop, v_stop):
-        if self._is_chunk_cached(h_start, v_start, h_stop, v_stop):
-            logger.debug("cache hit!")
-            table = self._cached_table
-            table_h_start = self._cached_table_h_start
-            table_v_start = self._cached_table_v_start
+    def get_vlabels_values(self, start, stop):
+        if self._pandas_idx_cols:
+            # This assumes that index columns are contiguous (which is
+            # implicitly tested in __init__ via the "all index columns at
+            # the end" test) and sorted (also tested in __init__)
+            h_start = self._pandas_idx_cols[0]
+            h_stop = self._pandas_idx_cols[-1] + 1
+            return self.get_values(h_start, start, h_stop, stop)
         else:
-            logger.debug("cache miss!")
-            start_row_group, stop_row_group = (
-                # - 1 because the last row is not included
-                np.searchsorted(self._group_ends, [v_start, v_stop - 1],
-                                side='right'))
-            # - 1 because _group_ends stores row group ends and we want the start
-            table_h_start = h_start
-            table_v_start = (
-                self._group_ends[start_row_group - 1] if start_row_group > 0 else 0)
-            row_groups = range(start_row_group, stop_row_group + 1)
-            column_names = self._schema.names[h_start:h_stop]
-            f = self.data
-            table = f.read_row_groups(row_groups, columns=column_names)
-            self._cached_table = table
-            self._cached_table_h_start = table_h_start
-            self._cached_table_v_start = table_v_start
-
-        chunk = table[v_start - table_v_start:v_stop - table_v_start]
+            return [[i] for i in range(start, stop)]
+
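As a hedged illustration of the two methods above (the adapter instance and all values are made up): for a file whose pandas index is a single 'year' column, the vertical labels come from that column via `get_values` (so they are returned as a rows-by-index-columns array), while without index columns plain row numbers are used:

    adapter.get_vnames()              # ['year']  (with a pandas index column)
    adapter.get_vlabels_values(0, 3)  # roughly array([[2020], [2021], [2022]])

    adapter.get_vnames()              # ['']      (no pandas index columns)
    adapter.get_vlabels_values(0, 3)  # [[0], [1], [2]]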
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        # fragment is a pyarrow.Table
+        fragment, fragment_h_start, fragment_v_start = (
+            self._get_fragment_via_cache(h_start, v_start, h_stop, v_stop))
+
+        # chunk is a list of pyarrow.ChunkedArray
+        chunk = self._fragment_to_chunk(fragment,
+                                        fragment_h_start, fragment_v_start,
+                                        h_start, v_start, h_stop, v_stop)
+        return self._chunk_to_numpy(chunk)
+
+    def _get_fragment_from_source(self, h_start, v_start, h_stop, v_stop):
+        start_row_group, stop_row_group = (
+            # - 1 because the last row (v_stop) is not included
+            np.searchsorted(self._group_ends, [v_start, v_stop - 1],
+                            side='right'))
+        # - 1 because _group_ends stores row group ends and we want the start
+        table_h_start = h_start
+        table_v_start = (
+            self._group_ends[start_row_group - 1] if start_row_group > 0 else 0)
+        row_groups = range(start_row_group, stop_row_group + 1)
+        column_names = self._schema.names[h_start:h_stop]
+        f = self.data
+        table = f.read_row_groups(row_groups, columns=column_names)
+        return table, table_h_start, table_v_start
+
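The row-group arithmetic is easiest to see on a concrete case. A worked example, assuming three row groups of 100 rows each (so `_group_ends == [100, 200, 300]`) and a request for rows 150..250:

    import numpy as np

    group_ends = np.array([100, 200, 300])   # cumulative row-group ends
    v_start, v_stop = 150, 250                # v_stop is exclusive

    # the last requested row is v_stop - 1 == 249, hence the "- 1"
    start_rg, stop_rg = np.searchsorted(group_ends, [v_start, v_stop - 1],
                                        side='right')
    assert (start_rg, stop_rg) == (1, 2)      # rows live in groups 1 and 2

    # the fragment therefore starts at the first row of group 1
    table_v_start = group_ends[start_rg - 1] if start_rg > 0 else 0
    assert table_v_start == 100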
+    # fragment is a native object representing the smallest buffer which can
+    # hold the requested rows and columns (not sliced in memory);
+    # chunk is the actual requested slice from that fragment, still in
+    # whatever format is most convenient for the adapter
+    def _fragment_to_chunk(self, fragment, fragment_h_start, fragment_v_start,
+                           h_start, v_start, h_stop, v_stop):
+        chunk = fragment[v_start - fragment_v_start:v_stop - fragment_v_start]
+        h_start_in_chunk = h_start - fragment_h_start
+        h_stop_in_chunk = h_stop - fragment_h_start
         # not going via to_pandas() because it "eats" index columns
-        columns = chunk.columns[h_start - table_h_start:h_stop - table_h_start]
-        np_columns = [c.to_numpy() for c in columns]
+        return chunk.columns[h_start_in_chunk:h_stop_in_chunk]
+
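A hedged sketch of what `_fragment_to_chunk` does, using a tiny in-memory table (values are made up; requires pyarrow). Slicing a `pyarrow.Table` by rows is zero-copy, and `.columns` yields one `pyarrow.ChunkedArray` per column, so the data stays in Arrow form until `_chunk_to_numpy`:

    import pyarrow as pa

    fragment = pa.table({'a': [1, 2, 3, 4], 'b': [10.0, 20.0, 30.0, 40.0]})
    # rows 1..3 of the fragment (zero-copy row slice)
    chunk = fragment[1:3]
    # .columns is a list of pyarrow.ChunkedArray, one per column
    cols = chunk.columns[0:2]
    print([c.to_numpy() for c in cols])   # [array([2, 3]), array([20., 30.])]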
+    def _chunk_to_numpy(self, chunk):
+        # chunk is a list of pyarrow.ChunkedArray
+        np_columns = [c.to_numpy() for c in chunk]
         try:
             return np.stack(np_columns, axis=1)
         except np.exceptions.DTypePromotionError: