Skip to content

Commit

Permalink
Add rounding service
Browse files Browse the repository at this point in the history
  • Loading branch information
MDUYN committed Apr 4, 2024
1 parent 6aa127d commit 95cc0bf
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 42 deletions.
35 changes: 11 additions & 24 deletions investing_algorithm_framework/app/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import decimal
import inspect
import logging
from typing import List
import inspect

from investing_algorithm_framework.domain import OrderStatus, OrderFee, \
Position, Order, Portfolio, OrderType, OrderSide, \
BACKTESTING_FLAG, BACKTESTING_INDEX_DATETIME, MarketService, TimeUnit, \
OperationalException, random_string
OperationalException, random_string, RoundingService
from investing_algorithm_framework.services import MarketCredentialService, \
MarketDataSourceService, PortfolioService, PositionService, TradeService, \
OrderService, ConfigurationService, StrategyOrchestratorService, \
Expand Down Expand Up @@ -218,7 +217,7 @@ def create_limit_order(
amount = position.get_amount() * (percentage_of_position / 100)

if precision is not None:
amount = self.round_down(amount, precision)
amount = RoundingService.round_down(amount, precision)

order_data = {
"target_symbol": target_symbol,
Expand Down Expand Up @@ -594,7 +593,9 @@ def get_position_percentage_of_portfolio_by_net_size(
net_size = portfolio.get_net_size()
return (position.cost / net_size) * 100

def close_position(self, symbol, market=None, identifier=None):
def close_position(
self, symbol, market=None, identifier=None, precision=None
):
portfolio = self.portfolio_service.find(
{"market": market, "identifier": identifier}
)
Expand Down Expand Up @@ -623,6 +624,7 @@ def close_position(self, symbol, market=None, identifier=None):
amount=position.get_amount(),
order_side=OrderSide.SELL.value,
price=ticker["bid"],
precision=precision,
)

def add_strategies(self, strategies):
Expand Down Expand Up @@ -886,28 +888,13 @@ def get_trades(self, market=None):
def get_closed_trades(self):
return self.trade_service.get_closed_trades()

def round_down(self, value, amount_of_decimals):

if self.count_decimals(value) <= amount_of_decimals:
return value

with decimal.localcontext() as ctx:
d = decimal.Decimal(value)
ctx.rounding = decimal.ROUND_DOWN
return float(round(d, amount_of_decimals))

def count_decimals(self, number):
decimal_str = str(number)
if '.' in decimal_str:
return len(decimal_str.split('.')[1])
else:
return 0

def get_open_trades(self, target_symbol=None, market=None):
return self.trade_service.get_open_trades(target_symbol, market)

def close_trade(self, trade, market=None):
self.trade_service.close_trade(trade, market)
def close_trade(self, trade, market=None, precision=None) -> None:
self.trade_service.close_trade(
trade=trade, market=market, precision=precision
)

def get_number_of_positions(self):
"""
Expand Down
6 changes: 4 additions & 2 deletions investing_algorithm_framework/domain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from .decimal_parsing import parse_decimal_to_string, parse_string_to_decimal
from .services import TickerMarketDataSource, OrderBookMarketDataSource, \
OHLCVMarketDataSource, BacktestMarketDataSource, MarketDataSource, \
MarketService, MarketCredentialService, AbstractPortfolioSyncService
MarketService, MarketCredentialService, AbstractPortfolioSyncService, \
RoundingService
from .data_structures import PeekableQueue

__all__ = [
Expand Down Expand Up @@ -109,5 +110,6 @@
"RESERVED_BALANCES",
"AbstractPortfolioSyncService",
"APP_MODE",
"AppMode"
"AppMode",
"RoundingService",
]
4 changes: 3 additions & 1 deletion investing_algorithm_framework/domain/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .market_service import MarketService
from .market_credential_service import MarketCredentialService
from .portfolios import AbstractPortfolioSyncService
from .rounding_service import RoundingService

__all__ = [
"MarketDataSource",
Expand All @@ -12,5 +13,6 @@
"BacktestMarketDataSource",
"MarketService",
"MarketCredentialService",
"AbstractPortfolioSyncService"
"AbstractPortfolioSyncService",
"RoundingService",
]
27 changes: 27 additions & 0 deletions investing_algorithm_framework/domain/services/rounding_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import decimal


class RoundingService:
"""
Service to round numbers to a certain amount of decimals.
It will always round down.
"""

@staticmethod
def round_down(value, amount_of_decimals):

if RoundingService.count_decimals(value) <= amount_of_decimals:
return value

with decimal.localcontext() as ctx:
d = decimal.Decimal(value)
ctx.rounding = decimal.ROUND_DOWN
return float(round(d, amount_of_decimals))

@staticmethod
def count_decimals(number):
decimal_str = str(number)
if '.' in decimal_str:
return len(decimal_str.split('.')[1])
else:
return 0
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import logging
from typing import List
from queue import PriorityQueue
from typing import List

from investing_algorithm_framework.domain import OrderStatus, OrderSide, \
Trade, PeekableQueue, OrderType, TradeStatus, \
OperationalException, Order
from investing_algorithm_framework.services.position_service import \
PositionService
OperationalException, Order, RoundingService
from investing_algorithm_framework.services.market_data_source_service import \
MarketDataSourceService
from investing_algorithm_framework.services.position_service import \
PositionService

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -199,7 +199,7 @@ def get_closed_trades(self, portfolio_id=None) -> List[Trade]:
if order.get_trade_closed_at() is not None
]

def close_trade(self, trade, market=None) -> None:
def close_trade(self, trade, market=None, precision=None) -> None:
"""
Close trade method
Expand All @@ -210,6 +210,7 @@ def close_trade(self, trade, market=None) -> None:
return: None
"""

if trade.closed_at is not None:
raise OperationalException("Trade already closed.")

Expand All @@ -227,6 +228,9 @@ def close_trade(self, trade, market=None) -> None:
)
amount = order.get_amount()

if precision is not None:
amount = RoundingService.round_down(amount, precision)

if position.get_amount() < amount:
logger.warning(
f"Order amount {amount} is larger then amount "
Expand Down
21 changes: 11 additions & 10 deletions tests/app/algorithm/test_round_down.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from investing_algorithm_framework import create_app, RESOURCE_DIRECTORY, \
PortfolioConfiguration, Algorithm, MarketCredential
from investing_algorithm_framework.domain import RoundingService
from tests.resources import TestBase, MarketServiceStub


Expand Down Expand Up @@ -49,52 +50,52 @@ def setUp(self) -> None:
self.app.initialize()

def test_round_down(self):
new_value = self.app.algorithm.round_down(1, 3)
new_value = RoundingService.round_down(1, 3)
self.assertEqual(
0, self.count_decimals(new_value)
)
self.assertEqual(1, new_value)
new_value = self.app.algorithm.round_down(1.23456789, 2)
new_value = RoundingService.round_down(1.23456789, 2)
self.assertEqual(
2, self.count_decimals(new_value)
)
self.assertEqual(1.23, new_value)
new_value = self.app.algorithm.round_down(1.987654321, 3)
new_value = RoundingService.round_down(1.987654321, 3)
self.assertEqual(
3, self.count_decimals(new_value)
)
self.assertEqual(1.987, new_value)
new_value = self.app.algorithm.round_down(1.987654321, 4)
new_value = RoundingService.round_down(1.987654321, 4)
self.assertEqual(
4, self.count_decimals(new_value)
)
self.assertEqual(1.9876, new_value)
new_value = self.app.algorithm.round_down(1.987654321, 5)
new_value = RoundingService.round_down(1.987654321, 5)
self.assertEqual(
5, self.count_decimals(new_value)
)
self.assertEqual(1.98765, new_value)
new_value = self.app.algorithm.round_down(1.987654321, 6)
new_value = RoundingService.round_down(1.987654321, 6)
self.assertEqual(
6, self.count_decimals(new_value)
)
self.assertEqual(1.987654, new_value)
new_value = self.app.algorithm.round_down(1.987654321, 7)
new_value = RoundingService.round_down(1.987654321, 7)
self.assertEqual(
7, self.count_decimals(new_value)
)
self.assertEqual(1.9876543, new_value)
new_value = self.app.algorithm.round_down(1.987654321, 8)
new_value = RoundingService.round_down(1.987654321, 8)
self.assertEqual(
8, self.count_decimals(new_value)
)
self.assertEqual(1.98765432, new_value)
new_value = self.app.algorithm.round_down(1.987654321, 9)
new_value = RoundingService.round_down(1.987654321, 9)
self.assertEqual(
9, self.count_decimals(new_value)
)
self.assertEqual(1.987654321, new_value)
new_value = self.app.algorithm.round_down(1.987654321, 10)
new_value = RoundingService.round_down(1.987654321, 10)
self.assertEqual(
9, self.count_decimals(new_value)
)
Expand Down

0 comments on commit 95cc0bf

Please sign in to comment.