Skip to content

Commit beabaaf

Browse files
authored
Merge pull request #1301 from thomcom/fea-ext-multiindex
Add MultiIndex support for Dataframes and Series
2 parents c78df5a + be1db05 commit beabaaf

File tree

18 files changed

+1033
-153
lines changed

18 files changed

+1033
-153
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
- PR #1466 Add GPU-accelerated ORC Reader
2626
- PR #1565 Add build script for nightly doc builds
2727
- PR #1508 Add Series isna, isnull, and notna
28+
- PR #1301 MultiIndex support
2829

2930
## Improvements
3031

ci/local/build.sh

100644100755
File mode changed.

cpp/src/binary/binary_ops.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ gdf_error gdf_div_f64(gdf_column *lhs, gdf_column *rhs, gdf_column *output) {
256256
gdf_error F##_generic(gdf_column *lhs, gdf_column *rhs, gdf_column *output) { \
257257
switch ( lhs->dtype ) { \
258258
case GDF_INT8: return F##_i8(lhs, rhs, output); \
259+
case GDF_STRING_CATEGORY: \
259260
case GDF_INT32: return F##_i32(lhs, rhs, output); \
260261
case GDF_INT64: return F##_i64(lhs, rhs, output); \
261262
case GDF_FLOAT32: return F##_f32(lhs, rhs, output); \

python/cudf/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
# Copyright (c) 2018, NVIDIA CORPORATION.
1+
# Copyright (c) 2018-2019, NVIDIA CORPORATION.
22

33
from cudf import dataframe
44
from cudf import datasets
55
from cudf.dataframe import DataFrame, from_pandas, merge
6-
from cudf.dataframe import Index
6+
from cudf.dataframe import Index, MultiIndex
77
from cudf.dataframe import Series
88
from cudf.multi import concat
99
from cudf.io import (read_csv, read_parquet, read_feather, read_json,

python/cudf/bindings/groupby.pyx

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -334,12 +334,10 @@ def agg(groupby_class, args):
334334
sort_results=sort_results
335335
)
336336
add_col_values = False # we only want to add them once
337-
# TODO: Do multindex here
338-
if(groupby_class._as_index) and 1 == len(groupby_class._by):
339-
idx = index.as_index(result[groupby_class._by[0]])
340-
idx.name = groupby_class._by[0]
341-
result = result.set_index(idx)
342-
result.drop_column(idx.name)
337+
if(groupby_class._as_index):
338+
result = groupby_class.apply_multiindex_or_single_index(result)
339+
if use_prefix:
340+
result = groupby_class.apply_multicolumn(result, args)
343341
elif isinstance(args, collections.abc.Mapping):
344342
if (len(args.keys()) == 1):
345343
if(len(list(args.values())[0]) == 1):
@@ -377,15 +375,13 @@ def agg(groupby_class, args):
377375
sort_results=sort_results
378376
)
379377
add_col_values = False # we only want to add them once
380-
# TODO: Do multindex here
381-
if(groupby_class._as_index) and 1 == len(groupby_class._by):
382-
idx = index.as_index(result[groupby_class._by[0]])
383-
idx.name = groupby_class._by[0]
384-
result = result.set_index(idx)
385-
result.drop_column(idx.name)
378+
if groupby_class._as_index:
379+
result = groupby_class.apply_multiindex_or_single_index(result)
380+
if use_prefix:
381+
result = groupby_class.apply_multicolumn_mapped(result, args)
386382
else:
387383
result = groupby_class.agg([args])
388-
384+
389385
free(ctx)
390386

391387
nvtx_range_pop()
@@ -431,18 +427,13 @@ def _apply_basic_agg(groupby_class, agg_type, sort_results=False):
431427
else:
432428
idx.name = groupby_class._by[0]
433429
result_series = result_series.set_index(idx)
430+
if groupby_class._as_index:
431+
result = groupby_class.apply_multiindex_or_single_index(result)
432+
result_series.index = result.index
434433
return result_series
435434

436-
# TODO: Do MultiIndex here
437-
if(groupby_class._as_index):
438-
idx = index.as_index(result[groupby_class._by[0]])
439-
idx.name = groupby_class._by[0]
440-
result.drop_column(idx.name)
441-
if groupby_class.level == 0:
442-
idx.name = groupby_class._original_index_name
443-
else:
444-
idx.name = groupby_class._by[0]
445-
result = result.set_index(idx)
435+
if groupby_class._as_index:
436+
result = groupby_class.apply_multiindex_or_single_index(result)
446437

447438
nvtx_range_pop()
448439

python/cudf/dataframe/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
# Copyright (c) 2018-2019, NVIDIA CORPORATION.
2+
13
from cudf.dataframe import (buffer, dataframe, series,
24
index, numerical, datetime, categorical, string)
35

46
from cudf.dataframe.dataframe import DataFrame, from_pandas, merge
57
from cudf.dataframe.index import (Index, GenericIndex,
68
RangeIndex, DatetimeIndex, CategoricalIndex)
9+
from cudf.dataframe.multiindex import MultiIndex
710
from cudf.dataframe.series import Series
811
from cudf.dataframe.buffer import Buffer
912
from cudf.dataframe.numerical import NumericalColumn

python/cudf/dataframe/dataframe.py

Lines changed: 88 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018, NVIDIA CORPORATION.
1+
# Copyright (c) 2018-2019, NVIDIA CORPORATION.
22

33
from __future__ import print_function, division
44

@@ -29,6 +29,7 @@
2929

3030
from librmm_cffi import librmm as rmm
3131

32+
import cudf
3233
from cudf import formatting
3334
from cudf.utils import cudautils, queryutils, applyutils, utils, ioutils
3435
from cudf.dataframe.index import as_index, Index, RangeIndex
@@ -224,10 +225,14 @@ def __getitem__(self, arg):
224225
>>> print(df[[True, False, True, False]]) # mask the entire dataframe,
225226
# returning the rows specified in the boolean mask
226227
"""
228+
if isinstance(self.columns, cudf.dataframe.multiindex.MultiIndex) and\
229+
isinstance(arg, tuple):
230+
return self.columns._get_column_major(self, arg)
227231
if isinstance(arg, str) or isinstance(arg, numbers.Integral) or \
228232
isinstance(arg, tuple):
229233
s = self._cols[arg]
230234
s.name = arg
235+
s.index = self.index
231236
return s
232237
elif isinstance(arg, slice):
233238
df = DataFrame()
@@ -247,7 +252,7 @@ def __getitem__(self, arg):
247252
index = self.index.take(selinds.to_gpu_array())
248253
for col in self._cols:
249254
df[col] = Series(self._cols[col][arg], index=index)
250-
df.set_index(index)
255+
df = df.set_index(index)
251256
else:
252257
for col in arg:
253258
df[col] = self[col]
@@ -272,7 +277,6 @@ def mask(self, other):
272277
def __setitem__(self, name, col):
273278
"""Add/set column by *name or DataFrame*
274279
"""
275-
# div[div < 0] = 0
276280
if isinstance(name, DataFrame):
277281
for col_name in self._cols:
278282
mask = name[col_name]
@@ -399,6 +403,11 @@ def to_string(self, nrows=NOTSET, ncols=NOTSET):
399403
>>> df.to_string()
400404
' key val\\n0 0 10.0\\n1 1 11.0\\n2 2 12.0'
401405
"""
406+
if isinstance(self.index, cudf.dataframe.multiindex.MultiIndex) or\
407+
isinstance(self.columns, cudf.dataframe.multiindex.MultiIndex):
408+
raise TypeError("You're trying to print a DataFrame that contains "
409+
"a MultiIndex. Print this dataframe with "
410+
".to_pandas()")
402411
if nrows is NOTSET:
403412
nrows = settings.formatting.get('nrows')
404413
if ncols is NOTSET:
@@ -420,9 +429,12 @@ def to_string(self, nrows=NOTSET, ncols=NOTSET):
420429
# Prepare cells
421430
cols = OrderedDict()
422431
dtypes = OrderedDict()
423-
use_cols = list(self.columns[:ncols - 1])
424-
if ncols > 0:
425-
use_cols.append(self.columns[-1])
432+
if hasattr(self, 'multi_cols'):
433+
use_cols = list(range(len(self.columns)))
434+
else:
435+
use_cols = list(self.columns[:ncols - 1])
436+
if ncols > 0:
437+
use_cols.append(self.columns[-1])
426438

427439
for h in use_cols:
428440
cols[h] = self[h].values_to_string(nrows=nrows)
@@ -664,19 +676,41 @@ def iloc(self):
664676
def columns(self):
665677
"""Returns a tuple of columns
666678
"""
667-
return pd.Index(self._cols)
679+
if hasattr(self, 'multi_cols'):
680+
return self.multi_cols
681+
else:
682+
return pd.Index(self._cols)
668683

669684
@columns.setter
670685
def columns(self, columns):
686+
if isinstance(columns, Index):
687+
if len(columns) != len(self.columns):
688+
msg = f"Length mismatch: Expected axis has %d elements, "\
689+
"new values have %d elements"\
690+
% (len(self.columns), len(columns))
691+
raise ValueError(msg)
692+
"""
693+
new_names = []
694+
for idx, name in enumerate(columns):
695+
new_names.append(name)
696+
self._rename_columns(new_names)
697+
"""
698+
self.multi_cols = columns
699+
else:
700+
if hasattr(self, 'multi_cols'):
701+
delattr(self, 'multi_cols')
702+
self._rename_columns(columns)
703+
704+
def _rename_columns(self, new_names):
671705
old_cols = list(self._cols.keys())
672706
l_old_cols = len(old_cols)
673-
l_new_cols = len(columns)
707+
l_new_cols = len(new_names)
674708
if l_new_cols != l_old_cols:
675709
msg = f'Length of new column names: {l_new_cols} does not ' \
676710
'match length of previous column names: {l_old_cols}'
677711
raise ValueError(msg)
678712

679-
mapper = dict(zip(old_cols, columns))
713+
mapper = dict(zip(old_cols, new_names))
680714
self.rename(mapper=mapper, inplace=True)
681715

682716
@property
@@ -687,12 +721,26 @@ def index(self):
687721

688722
@index.setter
689723
def index(self, _index):
724+
if isinstance(_index, cudf.dataframe.multiindex.MultiIndex):
725+
if len(_index) != len(self[self.columns[0]]):
726+
msg = f"Length mismatch: Expected axis has "\
727+
"%d elements, new values "\
728+
"have %d elements"\
729+
% (len(self[self.columns[0]]), len(_index))
730+
raise ValueError(msg)
731+
self._index = _index
732+
for k in self.columns:
733+
self[k].index = _index
734+
return
735+
690736
new_length = len(_index)
691737
old_length = len(self._index)
692738

693739
if new_length != old_length:
694-
msg = f'Length mismatch: Expected index has {old_length}' \
695-
' elements, new values have {new_length} elements'
740+
msg = f"Length mismatch: Expected axis has "\
741+
"%d elements, new values "\
742+
"have %d elements"\
743+
% (old_length, new_length)
696744
raise ValueError(msg)
697745

698746
# try to build an index from generic _index
@@ -906,8 +954,8 @@ def drop(self, labels, axis=None):
906954
if axis == 0:
907955
raise NotImplementedError("Can only drop columns, not rows")
908956

909-
columns = [labels] if isinstance(labels, str) else list(labels)
910-
957+
columns = [labels] if isinstance(
958+
labels, (str, numbers.Number)) else list(labels)
911959
outdf = self.copy()
912960
for c in columns:
913961
outdf._drop_column(c)
@@ -2240,6 +2288,13 @@ def to_pandas(self):
22402288
out = pd.DataFrame(index=index)
22412289
for c, x in self._cols.items():
22422290
out[c] = x.to_pandas(index=index)
2291+
if isinstance(self.columns, Index):
2292+
out.columns = self.columns
2293+
if isinstance(self.columns, cudf.dataframe.multiindex.MultiIndex):
2294+
if self.columns.names is not None:
2295+
out.columns.names = self.columns.names
2296+
else:
2297+
out.columns.name = self.columns.name
22432298
return out
22442299

22452300
@classmethod
@@ -2269,7 +2324,12 @@ def from_pandas(cls, dataframe, nan_as_null=True):
22692324
vals = dataframe[colk].values
22702325
df[colk] = Series(vals, nan_as_null=nan_as_null)
22712326
# Set index
2272-
return df.set_index(dataframe.index)
2327+
if isinstance(dataframe.index, pd.MultiIndex):
2328+
import cudf
2329+
index = cudf.from_pandas(dataframe.index)
2330+
else:
2331+
index = dataframe.index
2332+
return df.set_index(index)
22732333

22742334
def to_arrow(self, preserve_index=True):
22752335
"""
@@ -2696,6 +2756,13 @@ def __getitem__(self, arg):
26962756
row_slice = None
26972757
row_label = None
26982758

2759+
if isinstance(self._df.index, cudf.dataframe.multiindex.MultiIndex)\
2760+
and isinstance(arg, tuple): # noqa: E501
2761+
# Explicitly ONLY support tuple indexes into MultiIndex.
2762+
# Pandas allows non tuple indices and warns "results may be
2763+
# undefined."
2764+
return self._df._index._get_row_major(self._df, arg)
2765+
26992766
if isinstance(arg, int):
27002767
if arg < 0 or arg >= len(self._df):
27012768
raise IndexError("label scalar %s is out of bound" % arg)
@@ -2785,7 +2852,9 @@ def __setitem__(self, key, value):
27852852

27862853
def from_pandas(obj):
27872854
"""
2788-
Convert a Pandas DataFrame or Series object into the cudf equivalent
2855+
Convert certain Pandas objects into the cudf equivalent.
2856+
2857+
Supports DataFrame, Series, or MultiIndex.
27892858
27902859
Raises
27912860
------
@@ -2804,9 +2873,12 @@ def from_pandas(obj):
28042873
return DataFrame.from_pandas(obj)
28052874
elif isinstance(obj, pd.Series):
28062875
return Series.from_pandas(obj)
2876+
elif isinstance(obj, pd.MultiIndex):
2877+
return cudf.dataframe.multiindex.MultiIndex.from_pandas(obj)
28072878
else:
28082879
raise TypeError(
2809-
"from_pandas only accepts Pandas Dataframes and Series objects. "
2880+
"from_pandas only accepts Pandas Dataframes, Series, and "
2881+
"MultiIndex objects. "
28102882
"Got %s" % type(obj)
28112883
)
28122884

0 commit comments

Comments
 (0)