Skip to content

Commit

Permalink
INTPYTHON-538 Add support for PyArrow Decimal128 type (#278)
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 authored Feb 26, 2025
1 parent 9bbbed7 commit caf1cc5
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
18 changes: 16 additions & 2 deletions bindings/python/pymongoarrow/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from decimal import Decimal

import numpy as np
import pandas as pd
Expand All @@ -25,6 +26,7 @@
import pymongo.errors
from bson import encode
from bson.codec_options import TypeEncoder, TypeRegistry
from bson.decimal128 import Decimal128
from bson.raw_bson import RawBSONDocument
from numpy import ndarray
from pyarrow import Schema as ArrowSchema
Expand Down Expand Up @@ -416,6 +418,18 @@ def transform_python(self, _):
return


class _DecimalCodec(TypeEncoder):
"""A custom type codec for Decimal objects."""

@property
def python_type(self):
return Decimal

def transform_python(self, value):
"""Transform an Decimal object into a BSON Decimal128 object"""
return Decimal128(value)


def write(collection, tabular, *, exclude_none: bool = False):
"""Write data from `tabular` into the given MongoDB `collection`.
Expand Down Expand Up @@ -469,9 +483,9 @@ def write(collection, tabular, *, exclude_none: bool = False):

tabular_gen = _tabular_generator(tabular, exclude_none=exclude_none)

# Handle Pandas NA objects.
# Add handling for special case types.
codec_options = collection.codec_options
type_registry = TypeRegistry([_PandasNACodec()])
type_registry = TypeRegistry([_PandasNACodec(), _DecimalCodec()])
codec_options = codec_options.with_options(type_registry=type_registry)

while cur_offset < tab_size:
Expand Down
1 change: 1 addition & 0 deletions bindings/python/pymongoarrow/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def get_numpy_type(type):
_atypes.is_date64: _BsonArrowTypes.date64.value,
_atypes.is_large_string: _BsonArrowTypes.string.value,
_atypes.is_large_list: _BsonArrowTypes.array.value,
_atypes.is_decimal128: _BsonArrowTypes.decimal128.value,
}


Expand Down
11 changes: 11 additions & 0 deletions bindings/python/test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,17 @@ def test_exclude_none(self):
col_data = list(self.coll.find({}))
assert "b" not in col_data[2]

def test_decimal128(self):
import decimal

a = decimal.Decimal("123.45")
arr = pa.array([a], pa.decimal128(5, 2))
data = Table.from_arrays([arr], names=["data"])
self.coll.drop()
write(self.coll, data)
coll_data = list(self.coll.find({}))
assert coll_data[0]["data"] == Decimal128(a)


class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
def run_find(self, *args, **kwargs):
Expand Down

0 comments on commit caf1cc5

Please sign in to comment.