diff --git a/core/models/entity/_base.py b/core/models/entity/_base.py index fa7cb6b8..b33deb9f 100644 --- a/core/models/entity/_base.py +++ b/core/models/entity/_base.py @@ -1,13 +1,14 @@ import json -from dataclasses import fields +from dataclasses import dataclass, fields from enum import Enum from typing import Any, Dict, List -class Entity: +def Entity(cls): + cls = dataclass(frozen=True)(cls) + def to_dict(self) -> Dict[str, Any]: field_dict = {f.name: getattr(self, f.name) for f in fields(self)} - property_dict = { k: getattr(self, k) for k in dir(self) @@ -26,7 +27,6 @@ def to_dict(self) -> Dict[str, Any]: v.to_dict() if hasattr(v, "to_dict") and callable(v.to_dict) else v for v in value ] - return result def to_json(self) -> str: @@ -69,6 +69,8 @@ def format_value(value): return f"{value:.8f}" return str(value) + print(field_dict) + return ", ".join( f"{key}={format_value(value)}" if value is not None else f"{key}=NA" for key, value in field_dict.items() @@ -79,3 +81,21 @@ def __repr__(self) -> str: def __format__(self, format_spec: str) -> str: return self.to_json() if format_spec == "json" else self.__str__() + + cls_methods = { + "to_dict": to_dict, + "to_json": to_json, + "from_dict": from_dict, + "from_json": from_json, + "from_list": from_list, + } + + for method_name, method in cls_methods.items(): + if not hasattr(cls, method_name): + setattr(cls, method_name, method) + + cls.__str__ = __str__ + cls.__repr__ = __repr__ + cls.__format__ = __format__ + + return cls diff --git a/core/models/entity/bar.py b/core/models/entity/bar.py index 065440e8..c95ad01d 100644 --- a/core/models/entity/bar.py +++ b/core/models/entity/bar.py @@ -1,10 +1,8 @@ -from dataclasses import dataclass - from ._base import Entity from .ohlcv import OHLCV -@dataclass(frozen=True) -class Bar(Entity): +@Entity +class Bar: ohlcv: OHLCV closed: bool diff --git a/core/models/entity/ohlcv.py b/core/models/entity/ohlcv.py index fcb74696..7c958bbf 100644 --- a/core/models/entity/ohlcv.py +++ b/core/models/entity/ohlcv.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any, Dict, List from core.models.candle_type import CandleType @@ -6,8 +5,8 @@ from ._base import Entity -@dataclass(frozen=True) -class OHLCV(Entity): +@Entity +class OHLCV: timestamp: int open: float high: float diff --git a/core/models/entity/order.py b/core/models/entity/order.py index 9e3a988a..43280304 100644 --- a/core/models/entity/order.py +++ b/core/models/entity/order.py @@ -1,5 +1,5 @@ import uuid -from dataclasses import dataclass, field +from dataclasses import field from datetime import datetime from core.models.order_type import OrderStatus, OrderType @@ -7,8 +7,8 @@ from ._base import Entity -@dataclass(frozen=True) -class Order(Entity): +@Entity +class Order: status: OrderStatus price: float size: float