diff --git a/empyrical/stats.py b/empyrical/stats.py index ef7a095..9583d57 100644 --- a/empyrical/stats.py +++ b/empyrical/stats.py @@ -237,7 +237,9 @@ def cum_returns(returns, starting_value=0, out=None): if returns.ndim == 1 and isinstance(returns, pd.Series): out = pd.Series(out, index=returns.index) elif isinstance(returns, pd.DataFrame): - out = pd.DataFrame(out, index=returns.index) + out = pd.DataFrame( + out, index=returns.index, columns=returns.columns, + ) return out diff --git a/empyrical/tests/test_stats.py b/empyrical/tests/test_stats.py index 86e5a5c..f11eb76 100644 --- a/empyrical/tests/test_stats.py +++ b/empyrical/tests/test_stats.py @@ -36,6 +36,10 @@ def assert_indexes_match(self, result, expected): """ assert_index_equal(result.index, expected.index) + if isinstance(result, pd.DataFrame) and \ + isinstance(expected, pd.DataFrame): + assert_index_equal(result.columns, expected.columns) + class TestStats(BaseTestCase):