diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8e76212f7..821b28651 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,6 +14,7 @@ Changelog Added ^^^^^ - Implement __contains, __contained_by, __overlap and __len for ArrayField (#1877) +- Added default range validation to IntField, BigIntField, SmallIntField Fixed ^^^^^ diff --git a/tests/fields/test_int.py b/tests/fields/test_int.py index d5bd2b771..7421f3b31 100644 --- a/tests/fields/test_int.py +++ b/tests/fields/test_int.py @@ -1,6 +1,6 @@ from tests import testmodels from tortoise.contrib import test -from tortoise.exceptions import IntegrityError +from tortoise.exceptions import IntegrityError, ValidationError from tortoise.expressions import F @@ -38,6 +38,13 @@ async def test_min(self): obj2 = await testmodels.IntFields.get(id=obj.id) self.assertEqual(obj, obj2) + with self.assertRaises(ValidationError): + await testmodels.IntFields.create(intnum=-2147483649) + + async def test_max(self): + with self.assertRaises(ValidationError): + await testmodels.IntFields.create(intnum=2147483648) + async def test_cast(self): obj0 = await testmodels.IntFields.create(intnum="3") obj = await testmodels.IntFields.get(id=obj0.id) @@ -83,6 +90,13 @@ async def test_min(self): obj2 = await testmodels.SmallIntFields.get(id=obj.id) self.assertEqual(obj, obj2) + with self.assertRaises(ValidationError): + await testmodels.SmallIntFields.create(smallintnum=-32769) + + async def test_max(self): + with self.assertRaises(ValidationError): + await testmodels.SmallIntFields.create(smallintnum=32768) + async def test_values(self): obj0 = await testmodels.SmallIntFields.create(smallintnum=2) values = await testmodels.SmallIntFields.get(id=obj0.id).values("smallintnum") @@ -125,6 +139,13 @@ async def test_min(self): obj2 = await testmodels.BigIntFields.get(id=obj.id) self.assertEqual(obj, obj2) + with self.assertRaises(ValidationError): + await testmodels.BigIntFields.create(intnum=-9223372036854775809) + + async def test_max(self): + with self.assertRaises(ValidationError): + await testmodels.BigIntFields.create(intnum=9223372036854775808) + async def test_cast(self): obj0 = await testmodels.BigIntFields.create(intnum="3") obj = await testmodels.BigIntFields.get(id=obj0.id) diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index 307fb6df0..eed1d0687 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -13,7 +13,7 @@ from pypika_tortoise.terms import Term from tortoise import timezone -from tortoise.exceptions import ConfigurationError, FieldError +from tortoise.exceptions import ConfigurationError, FieldError, ValidationError from tortoise.fields.base import Field from tortoise.timezone import get_default_timezone, get_timezone, get_use_tz, localtime from tortoise.validators import MaxLengthValidator @@ -75,17 +75,32 @@ class IntField(Field[int], int): SQL_TYPE = "INT" allows_generated = True + GE = -2147483648 + LE = 2147483647 def __init__(self, primary_key: Optional[bool] = None, **kwargs: Any) -> None: if primary_key or kwargs.get("pk"): kwargs["generated"] = bool(kwargs.get("generated", True)) super().__init__(primary_key=primary_key, **kwargs) + def to_db_value(self, value: Any, instance: "Union[type[Model], Model]") -> Any: + if value is not None: + if not isinstance(value, int): + value = int(value) # pylint: disable=E1102 + if not self.GE <= value <= self.LE: + raise ValidationError( + f"{self.model_field_name}: " + f"Value should be less or equal to {self.LE} and greater or equal to {self.GE}" + ) + + self.validate(value) + return value + @property def constraints(self) -> dict: return { - "ge": -2147483648, - "le": 2147483647, + "ge": self.GE, + "le": self.LE, } class _db_postgres: @@ -113,13 +128,8 @@ class BigIntField(IntField): """ SQL_TYPE = "BIGINT" - - @property - def constraints(self) -> dict: - return { - "ge": -9223372036854775808, - "le": 9223372036854775807, - } + GE = -9223372036854775808 + LE = 9223372036854775807 class _db_postgres: GENERATED_SQL = "BIGSERIAL NOT NULL PRIMARY KEY" @@ -144,13 +154,8 @@ class SmallIntField(IntField): """ SQL_TYPE = "SMALLINT" - - @property - def constraints(self) -> dict: - return { - "ge": -32768, - "le": 32767, - } + GE = -32768 + LE = 32767 class _db_postgres: GENERATED_SQL = "SMALLSERIAL NOT NULL PRIMARY KEY"