diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index e6fafc8b1b14c..94f3a97b3c955 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -720,8 +720,10 @@ MultiIndex - :func:`MultiIndex.get_level_values` accessing a :class:`DatetimeIndex` does not carry the frequency attribute along (:issue:`58327`, :issue:`57949`) - Bug in :class:`DataFrame` arithmetic operations in case of unaligned MultiIndex columns (:issue:`60498`) - Bug in :class:`DataFrame` arithmetic operations with :class:`Series` in case of unaligned MultiIndex (:issue:`61009`) +- Bug in :func:`concat` with a :class:`MultiIndex` where extension dtypes such as ``timestamp[pyarrow]`` were silently coerced to ``object`` instead of preserving their original dtype (:issue:`58421`) - Bug in :meth:`MultiIndex.from_tuples` causing wrong output with input of type tuples having NaN values (:issue:`60695`, :issue:`60988`) + I/O ^^^ - Bug in :class:`DataFrame` and :class:`Series` ``repr`` of :py:class:`collections.abc.Mapping`` elements. 
(:issue:`57915`) diff --git a/pandas/core/reshape/concat.py b/pandas/core/reshape/concat.py index 5efaf0dc051bd..231e6fc9fb07e 100644 --- a/pandas/core/reshape/concat.py +++ b/pandas/core/reshape/concat.py @@ -22,6 +22,7 @@ from pandas.core.dtypes.common import ( is_bool, + is_extension_array_dtype, is_scalar, ) from pandas.core.dtypes.concat import concat_compat @@ -36,6 +37,7 @@ factorize_from_iterables, ) import pandas.core.common as com +from pandas.core.construction import array as pd_array from pandas.core.indexes.api import ( Index, MultiIndex, @@ -824,7 +826,20 @@ def _get_sample_object( def _concat_indexes(indexes) -> Index: - return indexes[0].append(indexes[1:]) + # try to preserve extension types such as timestamp[pyarrow] + values = [] + for idx in indexes: + values.extend(idx._values if hasattr(idx, "_values") else idx) + + # use the first index as a sample to infer the desired dtype + sample = indexes[0] + try: + # this helps preserve extension types like timestamp[pyarrow] + arr = pd_array(values, dtype=sample.dtype) + except Exception: + arr = pd_array(values) # fallback + + return Index(arr) def validate_unique_levels(levels: list[Index]) -> None: @@ -881,14 +896,32 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiInde concat_index = _concat_indexes(indexes) - # these go at the end if isinstance(concat_index, MultiIndex): levels.extend(concat_index.levels) codes_list.extend(concat_index.codes) else: - codes, categories = factorize_from_iterable(concat_index) - levels.append(categories) - codes_list.append(codes) + # handle the case where the resulting index is a flat Index + # but contains tuples (i.e., a collapsed MultiIndex) + if isinstance(concat_index[0], tuple): + # retrieve the original dtypes + original_dtypes = [lvl.dtype for lvl in indexes[0].levels] + + unzipped = list(zip(*concat_index)) + for i, level_values in enumerate(unzipped): + # reconstruct each level using original dtype + arr = 
import pytest

import pandas as pd

# Shared column schema: every column uses a pyarrow-backed dtype so the
# test exercises extension-dtype preservation end to end.
schema = {
    "id": "int64[pyarrow]",
    "time": "timestamp[s][pyarrow]",
    "value": "float[pyarrow]",
}


@pytest.mark.parametrize("dtype", ["timestamp[s][pyarrow]"])
def test_concat_preserves_pyarrow_timestamp(dtype):
    """concat with keys must not coerce a pyarrow timestamp level to object (GH#58421)."""
    # Skip cleanly when pyarrow is not installed instead of erroring in
    # the dtype construction below.
    pytest.importorskip("pyarrow")

    def make_frame(rows):
        # One place builds the schema-typed frame with an (id, time) index.
        return (
            pd.DataFrame(rows, columns=list(schema))
            .astype(schema)
            .set_index(["id", "time"])
        )

    dfA = make_frame(
        [
            (0, "2021-01-01 00:00:00", 5.3),
            (1, "2021-01-01 00:01:00", 5.4),
            (2, "2021-01-01 00:01:00", 5.4),
            (3, "2021-01-01 00:02:00", 5.5),
        ]
    )
    dfB = make_frame(
        [
            (1, "2022-01-01 08:00:00", 6.3),
            (2, "2022-01-01 08:01:00", 6.4),
            (3, "2022-01-01 08:02:00", 6.5),
        ]
    )

    df = pd.concat([dfA, dfB], keys=[0, 1], names=["run"])

    # the concatenated index must remain a MultiIndex
    assert isinstance(df.index, pd.MultiIndex), (
        f"Expected MultiIndex, but received {type(df.index)}"
    )

    # the timestamp[s][pyarrow] level must survive the concat intact
    assert df.index.levels[2].dtype == dtype, (
        f"Expected {dtype}, but received {df.index.levels[2].dtype}"
    )