Skip to content

Commit

Permalink
introduce more Enum types; move account_number from argument to field…
Browse files Browse the repository at this point in the history
… of TradingApi instance
  • Loading branch information
hzheng committed May 23, 2024
1 parent 5fedc1b commit f3a5a52
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 53 deletions.
61 changes: 34 additions & 27 deletions pyschwab/trading.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .utils import format_params, request, time_to_str, to_json_str
from .trading_models import Order, SecuritiesAccount, TradingData, Transaction, UserPreference
from .types import OrderStatus, TransactionType
from .types import AssetType, MarketSession, OrderDuration, OrderInstruction, OrderStatus, OrderStrategyType, OrderType, TransactionType


"""
Expand All @@ -24,39 +24,49 @@ def __init__(self, access_token: str, trading_config: Dict[str, Any]):
response = request(f'{self.base_trader_url}/accounts/accountNumbers', headers=self.auth).json()
self.accounts_hash = {account.get('accountNumber'): account.get('hashValue') for account in response}
self.accounts: Dict[str, SecuritiesAccount] = {}
self.current_account_num = None

def get_accounts_hash(self) -> Dict[str, str]:
return self.accounts_hash

def get_account_hash(self, account_number: str) -> str:
return self.accounts_hash.get(account_number, None)
def set_current_account_number(self, account_number: str) -> None:
self.current_account_num = account_number

def get_current_account_number(self) -> str:
return self.current_account_num

def get_account_hash(self, account_number: str=None) -> str:
return self.accounts_hash.get(account_number or self.current_account_num, None)

def get_accounts(self) -> Dict[str, SecuritiesAccount]:
return self.accounts

def get_account(self, account_number: str) -> SecuritiesAccount:
return self.accounts.get(account_number, None)
def get_account(self, account_number: str=None) -> SecuritiesAccount:
return self.accounts.get(account_number or self.current_account_num, None)

def _get_account_hash(self, account_num: int | str=None) -> str:
account_num = account_num or self.current_account_num
if not account_num:
raise ValueError("Account number not set")

def _get_account_hash(self, account_num: int | str) -> str:
account_hash = self.accounts_hash.get(str(account_num), None)
if account_hash is None:
raise ValueError(f"Account number {account_num} not found.")
return account_hash

def get_user_preference(self):
response = request(f'{self.base_trader_url}/userPreference', headers=self.auth).json()
return UserPreference.from_dict(response)

def fetch_trading_data(self, account_num: str, include_pos: bool=True) -> TradingData:
account_hash = self._get_account_hash(account_num)
def fetch_trading_data(self, include_pos: bool=True) -> TradingData:
account_hash = self._get_account_hash()
params = {'fields': ['positions'] if include_pos else []}
resp = request(f'{self.base_trader_url}/accounts/{account_hash}', headers=self.auth, params=params).json()
trading_data = TradingData.from_dict(resp)
account = trading_data.account
self.accounts[account.account_number] = account
return trading_data


def fetch_all_trading_data(self, include_pos: bool=True) -> Dict[str, TradingData]:
params = {'fields': ['positions'] if include_pos else []}
trading_data_map = {}
Expand All @@ -76,45 +86,42 @@ def get_all_orders(self, start_time: datetime=None, end_time: datetime=None, sta
orders = request(f'{self.base_trader_url}/orders', headers=self.auth, params=format_params(params)).json()
return [Order.from_dict(order) for order in orders]

def get_orders(self, account_num: str, start_time: datetime=None, end_time: datetime=None, status: OrderStatus=None, max_results: int=100) -> List[Order]:
account_hash = self._get_account_hash(account_num)
def get_orders(self, start_time: datetime=None, end_time: datetime=None, status: OrderStatus=None, max_results: int=100) -> List[Order]:
account_hash = self._get_account_hash()
now = datetime.now()
start = time_to_str(start_time or now - timedelta(days=30))
end = time_to_str(end_time or now)
params = {'maxResults': max_results, 'fromEnteredTime': start, 'toEnteredTime': end, 'status': status}
orders = request(f'{self.base_trader_url}/accounts/{account_hash}/orders', headers=self.auth, params=format_params(params)).json()
return [Order.from_dict(order) for order in orders]

def get_order(self, account_num: int | str, order_id: str) -> Order:
def get_order(self, order_id: str, account_num: int | str=None) -> Order:
account_hash = self._get_account_hash(account_num)
order = request(f'{self.base_trader_url}/accounts/{account_hash}/orders/{order_id}', headers=self.auth).json()
return Order.from_dict(order)

def place_order(self, order: Dict[str, Any] | Order, account_num: int | str) -> None:
account_hash = self._get_account_hash(account_num)
def place_order(self, order: Dict[str, Any] | Order) -> None:
account_hash = self._get_account_hash()
order_dict = self._convert_order(order)
request(f'{self.base_trader_url}/accounts/{account_hash}/orders', method='POST', headers=self.auth, json=order_dict)

def cancel_order(self, order_id: int | str, account_num: int | str) -> None:
account_hash = self._get_account_hash(account_num)
def cancel_order(self, order_id: int | str) -> None:
account_hash = self._get_account_hash()
request(f'{self.base_trader_url}/accounts/{account_hash}/orders/{order_id}', method='DELETE', headers=self.auth)

def replace_order(self, order: Dict[str, Any] | Order) -> None:
order_dict = self._convert_order(order)
order_id = order_dict.get('orderId', None)
if not order_id:
raise ValueError("Order ID not found in order data.")
account_num = order_dict.get('accountNumber', None)
if not account_num:
raise ValueError("Account number not found in order data.")

account_hash = self._get_account_hash(account_num)
account_hash = self._get_account_hash()
order_json = to_json_str(order_dict)
request(f'{self.base_trader_url}/accounts/{account_hash}/orders/{order_id}', method='PUT', headers=self.auth2, data=order_json)

def preview_order(self, order: Dict[str, Any] | Order, account_num: int | str) -> Dict[str, Any]:
def preview_order(self, order: Dict[str, Any] | Order) -> Dict[str, Any]:
"""Coming Soon as per official document"""
account_hash = self._get_account_hash(account_num)
account_hash = self._get_account_hash()
order_dict = self._convert_order(order)
order_json = to_json_str(order_dict)
return request(f'{self.base_trader_url}/accounts/{account_hash}/previewOrder', method='POST', headers=self.auth2, data=order_json).json()
Expand All @@ -128,16 +135,16 @@ def _convert_order(self, order: Dict[str, Any] | Order) -> Dict[str, Any]:

raise ValueError("Order must be a dictionary or Order object.")

def get_transactions(self, account_num: str, start_time: datetime=None, end_time: datetime=None, symbol: str=None, types: TransactionType=TransactionType.TRADE) -> List[Transaction]:
account_hash = self._get_account_hash(account_num)
def get_transactions(self, start_time: datetime=None, end_time: datetime=None, symbol: str=None, types: TransactionType=TransactionType.TRADE) -> List[Transaction]:
account_hash = self._get_account_hash()
now = datetime.now()
start = time_to_str(start_time or now - timedelta(days=30))
end = time_to_str(end_time or now)
params = {'startDate': start, 'endDate': end, 'symbol': symbol, 'types': types}
transactions = request(f'{self.base_trader_url}/accounts/{account_hash}/transactions', headers=self.auth, params=format_params(params)).json()
return [Transaction.from_dict(transaction) for transaction in transactions]

def get_transaction(self, account_num: int | str, transaction_id: str) -> Transaction:
account_hash = self._get_account_hash(account_num)
def get_transaction(self, transaction_id: str) -> Transaction:
account_hash = self._get_account_hash()
transaction = request(f'{self.base_trader_url}/accounts/{account_hash}/transactions/{transaction_id}', headers=self.auth).json()
return Transaction.from_dict(transaction)
49 changes: 33 additions & 16 deletions pyschwab/trading_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from datetime import datetime
from typing import Any, Dict, List, Optional

from .types import AssetType, ComplexOrderStrategyType, MarketSession, OptionActionType, \
OptionAssetType, OrderDuration, OrderInstruction, OrderStatus, OrderStrategyType, OrderType, \
PositionEffect, QuantityType, RequestedDestination
from .utils import camel_to_snake, dataclass_to_dict, to_time


Expand Down Expand Up @@ -36,7 +39,7 @@ def from_dict(cls, data: Dict[str, Any]) -> 'SecuritiesAccount':

@dataclass
class Deliverable:
asset_type: str
asset_type: OptionAssetType
status: str
symbol: str
instrument_id: int
Expand All @@ -48,6 +51,7 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Deliverable':
if data is None:
return None
converted_data = {camel_to_snake(key): value for key, value in data.items()}
converted_data['asset_type'] = OptionAssetType.from_str(converted_data['asset_type'])
return cls(**converted_data)


Expand All @@ -74,14 +78,14 @@ class Instrument:
Represents the financial instrument in a trading position.
Attributes:
asset_type (str): The type of the asset, e.g., 'EQUITY'.
asset_type (AssetType): The type of the asset, e.g., AssetType.EQUITY.
symbol (str): The trading symbol for the instrument.
cusip (str): The CUSIP identifier for the instrument.
net_change (float): The net change in the instrument's value since the last close.
instrument_id (int): The id of the instrument
"""
asset_type: str
symbol: str
asset_type: AssetType = AssetType.EQUITY
cusip: str = None
net_change: float = 0.0
instrument_id: int = 0
Expand All @@ -92,7 +96,7 @@ class Instrument:
expiration_date: datetime = None
option_deliverables: List[Deliverable] = None
option_premium_multiplier: float = 0.0
put_call: str = None
put_call: OptionActionType = OptionActionType.CALL
strike_price: float = None
underlying_symbol: str = None
underlying_cusip: str = None
Expand All @@ -109,6 +113,8 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Instrument':
deliverables = converted_data.get('option_deliverables', None)
if deliverables:
converted_data['option_deliverables'] = [OptionDeliverable.from_dict(deliverable) for deliverable in converted_data['option_deliverables']]
converted_data['asset_type'] = AssetType.from_str(converted_data['asset_type'])
converted_data['put_call'] = OptionActionType.from_str(converted_data.get('put_call', None))
return cls(**converted_data)


Expand Down Expand Up @@ -302,39 +308,44 @@ def from_dict(cls, data: Dict[str, Any]) -> 'OrderActivity':
@dataclass
class OrderLeg:
instrument: Instrument
instruction: str
instruction: OrderInstruction
quantity: int
position_effect: str = 'OPENING'
order_leg_type: str = 'EQUITY'
leg_id: int = 0
quantity_type: str = None
position_effect: PositionEffect = PositionEffect.OPENING
order_leg_type: AssetType = AssetType.EQUITY
quantity_type: QuantityType = QuantityType.ALL_SHARES
leg_id: int = None
div_cap_gains: str = None
to_symbol: str = None

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'OrderLeg':
converted_data = {camel_to_snake(key): value for key, value in data.items()}
converted_data['instrument'] = Instrument.from_dict(converted_data['instrument'])
converted_data['instruction'] = OrderInstruction.from_str(converted_data['instruction'])
converted_data['position_effect'] = PositionEffect.from_str(converted_data.get('position_effect', None))
converted_data['order_leg_type'] = AssetType.from_str(converted_data.get('order_leg_type', None))
converted_data['quantity_type'] = QuantityType.from_str(converted_data.get('quantity_type', None))
return cls(**converted_data)


@dataclass
class Order:
order_type: str
session: str
price: float
duration: str
entered_time: datetime
close_time: datetime
release_time: datetime
order_type: OrderType = OrderType.LIMIT
duration: OrderDuration = OrderDuration.DAY
session: MarketSession = MarketSession.NORMAL
order_strategy_type: OrderStrategyType = OrderStrategyType.SINGLE
order_id: int = 0
account_number: int = 0
status: str = 'AWAITING_PARENT_ORDER'
status: OrderStatus = OrderStatus.AWAITING_PARENT_ORDER
quantity: int = 0
filled_quantity: int = 0
remaining_quantity: int = 0
complex_order_strategy_type: str = "NONE"
requested_destination: str = "AUTO"
complex_order_strategy_type: ComplexOrderStrategyType = ComplexOrderStrategyType.NONE
requested_destination: RequestedDestination = RequestedDestination.AUTO
destination_link_name: str = None
order_leg_collection: List[OrderLeg] = None
order_activity_collection: List[OrderActivity] = None
Expand All @@ -348,7 +359,6 @@ class Order:
tax_lot_method: str = None
activation_price: float = 0.0
special_instruction: str = None
order_strategy_type: str = None
cancelable: bool = False
editable: bool = False
cancel_time: datetime = None
Expand All @@ -362,6 +372,12 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Order':
converted_data[key] = to_time(converted_data.get(key, None))
converted_data['order_leg_collection'] = [OrderLeg.from_dict(leg) for leg in converted_data.get('order_leg_collection', [])]
converted_data['order_activity_collection'] = [OrderActivity.from_dict(activity) for activity in converted_data.get('order_activity_collection', [])]
converted_data['order_type'] = OrderType.from_str(converted_data['order_type'])
converted_data['duration'] = OrderDuration.from_str(converted_data['duration'])
converted_data['session'] = MarketSession.from_str(converted_data['session'])
converted_data['order_strategy_type'] = OrderStrategyType.from_str(converted_data['order_strategy_type'])
converted_data['status'] = OrderStatus.from_str(converted_data.get('status', None))
converted_data['complex_order_strategy_type'] = ComplexOrderStrategyType.from_str(converted_data.get('complex_order_strategy_type', None))
return cls(**converted_data)

def to_dict(self, clean_keys: bool=False) -> Dict[str, Any]:
Expand All @@ -374,6 +390,7 @@ def to_dict(self, clean_keys: bool=False) -> Dict[str, Any]:
'destinationLinkName', 'requestedDestination', # destination
'status', 'statusDescription', # status
'tag', # tag
'orderActivityCollection',
]:
del order_dict[key]
return order_dict
Expand Down
Loading

0 comments on commit f3a5a52

Please sign in to comment.