diff --git a/pyschwab/trading.py b/pyschwab/trading.py index 1c16f98..33d3f66 100644 --- a/pyschwab/trading.py +++ b/pyschwab/trading.py @@ -5,7 +5,7 @@ from .utils import format_params, request, time_to_str, to_json_str from .trading_models import Order, SecuritiesAccount, TradingData, Transaction, UserPreference -from .types import OrderStatus, TransactionType +from .types import AssetType, MarketSession, OrderDuration, OrderInstruction, OrderStatus, OrderStrategyType, OrderType, TransactionType """ @@ -24,31 +24,42 @@ def __init__(self, access_token: str, trading_config: Dict[str, Any]): response = request(f'{self.base_trader_url}/accounts/accountNumbers', headers=self.auth).json() self.accounts_hash = {account.get('accountNumber'): account.get('hashValue') for account in response} self.accounts: Dict[str, SecuritiesAccount] = {} + self.current_account_num = None def get_accounts_hash(self) -> Dict[str, str]: return self.accounts_hash - def get_account_hash(self, account_number: str) -> str: - return self.accounts_hash.get(account_number, None) + def set_current_account_number(self, account_number: str) -> None: + self.current_account_num = account_number + + def get_current_account_number(self) -> str: + return self.current_account_num + + def get_account_hash(self, account_number: str=None) -> str: + return self.accounts_hash.get(account_number or self.current_account_num, None) def get_accounts(self) -> Dict[str, SecuritiesAccount]: return self.accounts - def get_account(self, account_number: str) -> SecuritiesAccount: - return self.accounts.get(account_number, None) + def get_account(self, account_number: str=None) -> SecuritiesAccount: + return self.accounts.get(account_number or self.current_account_num, None) + + def _get_account_hash(self, account_num: int | str=None) -> str: + account_num = account_num or self.current_account_num + if not account_num: + raise ValueError("Account number not set") - def _get_account_hash(self, account_num: int | str) -> str: account_hash = self.accounts_hash.get(str(account_num), None) if account_hash is None: raise ValueError(f"Account number {account_num} not found.") return account_hash - + def get_user_preference(self): response = request(f'{self.base_trader_url}/userPreference', headers=self.auth).json() return UserPreference.from_dict(response) - def fetch_trading_data(self, account_num: str, include_pos: bool=True) -> TradingData: - account_hash = self._get_account_hash(account_num) + def fetch_trading_data(self, include_pos: bool=True) -> TradingData: + account_hash = self._get_account_hash() params = {'fields': ['positions'] if include_pos else []} resp = request(f'{self.base_trader_url}/accounts/{account_hash}', headers=self.auth, params=params).json() trading_data = TradingData.from_dict(resp) @@ -56,7 +67,6 @@ def fetch_trading_data(self, account_num: str, include_pos: bool=True) -> Tradin self.accounts[account.account_number] = account return trading_data - def fetch_all_trading_data(self, include_pos: bool=True) -> Dict[str, TradingData]: params = {'fields': ['positions'] if include_pos else []} trading_data_map = {} @@ -76,8 +86,8 @@ def get_all_orders(self, start_time: datetime=None, end_time: datetime=None, sta orders = request(f'{self.base_trader_url}/orders', headers=self.auth, params=format_params(params)).json() return [Order.from_dict(order) for order in orders] - def get_orders(self, account_num: str, start_time: datetime=None, end_time: datetime=None, status: OrderStatus=None, max_results: int=100) -> List[Order]: - account_hash = self._get_account_hash(account_num) + def get_orders(self, start_time: datetime=None, end_time: datetime=None, status: OrderStatus=None, max_results: int=100) -> List[Order]: + account_hash = self._get_account_hash() now = datetime.now() start = time_to_str(start_time or now - timedelta(days=30)) end = time_to_str(end_time or now) @@ -85,18 +95,18 @@ def get_orders(self, account_num: str, start_time: datetime=None, end_time: date orders = request(f'{self.base_trader_url}/accounts/{account_hash}/orders', headers=self.auth, params=format_params(params)).json() return [Order.from_dict(order) for order in orders] - def get_order(self, account_num: int | str, order_id: str) -> Order: + def get_order(self, order_id: str, account_num: int | str=None) -> Order: account_hash = self._get_account_hash(account_num) order = request(f'{self.base_trader_url}/accounts/{account_hash}/orders/{order_id}', headers=self.auth).json() return Order.from_dict(order) - def place_order(self, order: Dict[str, Any] | Order, account_num: int | str) -> None: - account_hash = self._get_account_hash(account_num) + def place_order(self, order: Dict[str, Any] | Order) -> None: + account_hash = self._get_account_hash() order_dict = self._convert_order(order) request(f'{self.base_trader_url}/accounts/{account_hash}/orders', method='POST', headers=self.auth, json=order_dict) - def cancel_order(self, order_id: int | str, account_num: int | str) -> None: - account_hash = self._get_account_hash(account_num) + def cancel_order(self, order_id: int | str) -> None: + account_hash = self._get_account_hash() request(f'{self.base_trader_url}/accounts/{account_hash}/orders/{order_id}', method='DELETE', headers=self.auth) def replace_order(self, order: Dict[str, Any] | Order) -> None: @@ -104,17 +114,14 @@ def replace_order(self, order: Dict[str, Any] | Order) -> None: order_id = order_dict.get('orderId', None) if not order_id: raise ValueError("Order ID not found in order data.") - account_num = order_dict.get('accountNumber', None) - if not account_num: - raise ValueError("Account number not found in order data.") - account_hash = self._get_account_hash(account_num) + account_hash = self._get_account_hash() order_json = to_json_str(order_dict) request(f'{self.base_trader_url}/accounts/{account_hash}/orders/{order_id}', method='PUT', headers=self.auth2, data=order_json) - def preview_order(self, order: Dict[str, Any] | Order, account_num: int | str) -> Dict[str, Any]: + def preview_order(self, order: Dict[str, Any] | Order) -> Dict[str, Any]: """Coming Soon as per official document""" - account_hash = self._get_account_hash(account_num) + account_hash = self._get_account_hash() order_dict = self._convert_order(order) order_json = to_json_str(order_dict) return request(f'{self.base_trader_url}/accounts/{account_hash}/previewOrder', method='POST', headers=self.auth2, data=order_json).json() @@ -128,8 +135,8 @@ def _convert_order(self, order: Dict[str, Any] | Order) -> Dict[str, Any]: raise ValueError("Order must be a dictionary or Order object.") - def get_transactions(self, account_num: str, start_time: datetime=None, end_time: datetime=None, symbol: str=None, types: TransactionType=TransactionType.TRADE) -> List[Transaction]: - account_hash = self._get_account_hash(account_num) + def get_transactions(self, start_time: datetime=None, end_time: datetime=None, symbol: str=None, types: TransactionType=TransactionType.TRADE) -> List[Transaction]: + account_hash = self._get_account_hash() now = datetime.now() start = time_to_str(start_time or now - timedelta(days=30)) end = time_to_str(end_time or now) @@ -137,7 +144,7 @@ def get_transactions(self, account_num: str, start_time: datetime=None, end_time transactions = request(f'{self.base_trader_url}/accounts/{account_hash}/transactions', headers=self.auth, params=format_params(params)).json() return [Transaction.from_dict(transaction) for transaction in transactions] - def get_transaction(self, account_num: int | str, transaction_id: str) -> Transaction: - account_hash = self._get_account_hash(account_num) + def get_transaction(self, transaction_id: str) -> Transaction: + account_hash = self._get_account_hash() transaction = request(f'{self.base_trader_url}/accounts/{account_hash}/transactions/{transaction_id}', headers=self.auth).json() return Transaction.from_dict(transaction) diff --git a/pyschwab/trading_models.py b/pyschwab/trading_models.py index 3fa71c1..282ef04 100644 --- a/pyschwab/trading_models.py +++ b/pyschwab/trading_models.py @@ -2,6 +2,9 @@ from datetime import datetime from typing import Any, Dict, List, Optional +from .types import AssetType, ComplexOrderStrategyType, MarketSession, OptionActionType, \ + OptionAssetType, OrderDuration, OrderInstruction, OrderStatus, OrderStrategyType, OrderType, \ + PositionEffect, QuantityType, RequestedDestination from .utils import camel_to_snake, dataclass_to_dict, to_time @@ -36,7 +39,7 @@ def from_dict(cls, data: Dict[str, Any]) -> 'SecuritiesAccount': @dataclass class Deliverable: - asset_type: str + asset_type: OptionAssetType status: str symbol: str instrument_id: int @@ -48,6 +51,7 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Deliverable': if data is None: return None converted_data = {camel_to_snake(key): value for key, value in data.items()} + converted_data['asset_type'] = OptionAssetType.from_str(converted_data['asset_type']) return cls(**converted_data) @@ -74,14 +78,14 @@ class Instrument: Represents the financial instrument in a trading position. Attributes: - asset_type (str): The type of the asset, e.g., 'EQUITY'. + asset_type (AssetType): The type of the asset, e.g., AssetType.EQUITY. symbol (str): The trading symbol for the instrument. cusip (str): The CUSIP identifier for the instrument. net_change (float): The net change in the instrument's value since the last close. instrument_id (int): The id of the instrument """ - asset_type: str symbol: str + asset_type: AssetType = AssetType.EQUITY cusip: str = None net_change: float = 0.0 instrument_id: int = 0 @@ -92,7 +96,7 @@ class Instrument: expiration_date: datetime = None option_deliverables: List[Deliverable] = None option_premium_multiplier: float = 0.0 - put_call: str = None + put_call: OptionActionType = OptionActionType.CALL strike_price: float = None underlying_symbol: str = None underlying_cusip: str = None @@ -109,6 +113,8 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Instrument': deliverables = converted_data.get('option_deliverables', None) if deliverables: converted_data['option_deliverables'] = [OptionDeliverable.from_dict(deliverable) for deliverable in converted_data['option_deliverables']] + converted_data['asset_type'] = AssetType.from_str(converted_data['asset_type']) + converted_data['put_call'] = OptionActionType.from_str(converted_data.get('put_call', None)) return cls(**converted_data) @@ -302,12 +308,12 @@ def from_dict(cls, data: Dict[str, Any]) -> 'OrderActivity': @dataclass class OrderLeg: instrument: Instrument - instruction: str + instruction: OrderInstruction quantity: int - position_effect: str = 'OPENING' - order_leg_type: str = 'EQUITY' - leg_id: int = 0 - quantity_type: str = None + position_effect: PositionEffect = PositionEffect.OPENING + order_leg_type: AssetType = AssetType.EQUITY + quantity_type: QuantityType = QuantityType.ALL_SHARES + leg_id: int = None div_cap_gains: str = None to_symbol: str = None @@ -315,26 +321,31 @@ class OrderLeg: def from_dict(cls, data: Dict[str, Any]) -> 'OrderLeg': converted_data = {camel_to_snake(key): value for key, value in data.items()} converted_data['instrument'] = Instrument.from_dict(converted_data['instrument']) + converted_data['instruction'] = OrderInstruction.from_str(converted_data['instruction']) + converted_data['position_effect'] = PositionEffect.from_str(converted_data.get('position_effect', None)) + converted_data['order_leg_type'] = AssetType.from_str(converted_data.get('order_leg_type', None)) + converted_data['quantity_type'] = QuantityType.from_str(converted_data.get('quantity_type', None)) return cls(**converted_data) @dataclass class Order: - order_type: str - session: str price: float - duration: str entered_time: datetime close_time: datetime release_time: datetime + order_type: OrderType = OrderType.LIMIT + duration: OrderDuration = OrderDuration.DAY + session: MarketSession = MarketSession.NORMAL + order_strategy_type: OrderStrategyType = OrderStrategyType.SINGLE order_id: int = 0 account_number: int = 0 - status: str = 'AWAITING_PARENT_ORDER' + status: OrderStatus = OrderStatus.AWAITING_PARENT_ORDER quantity: int = 0 filled_quantity: int = 0 remaining_quantity: int = 0 - complex_order_strategy_type: str = "NONE" - requested_destination: str = "AUTO" + complex_order_strategy_type: ComplexOrderStrategyType = ComplexOrderStrategyType.NONE + requested_destination: RequestedDestination = RequestedDestination.AUTO destination_link_name: str = None order_leg_collection: List[OrderLeg] = None order_activity_collection: List[OrderActivity] = None @@ -348,7 +359,6 @@ class Order: tax_lot_method: str = None activation_price: float = 0.0 special_instruction: str = None - order_strategy_type: str = None cancelable: bool = False editable: bool = False cancel_time: datetime = None @@ -362,6 +372,12 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Order': converted_data[key] = to_time(converted_data.get(key, None)) converted_data['order_leg_collection'] = [OrderLeg.from_dict(leg) for leg in converted_data.get('order_leg_collection', [])] converted_data['order_activity_collection'] = [OrderActivity.from_dict(activity) for activity in converted_data.get('order_activity_collection', [])] + converted_data['order_type'] = OrderType.from_str(converted_data['order_type']) + converted_data['duration'] = OrderDuration.from_str(converted_data['duration']) + converted_data['session'] = MarketSession.from_str(converted_data['session']) + converted_data['order_strategy_type'] = OrderStrategyType.from_str(converted_data['order_strategy_type']) + converted_data['status'] = OrderStatus.from_str(converted_data.get('status', None)) + converted_data['complex_order_strategy_type'] = ComplexOrderStrategyType.from_str(converted_data.get('complex_order_strategy_type', None)) return cls(**converted_data) def to_dict(self, clean_keys: bool=False) -> Dict[str, Any]: @@ -374,6 +390,7 @@ def to_dict(self, clean_keys: bool=False) -> Dict[str, Any]: 'destinationLinkName', 'requestedDestination', # destination 'status', 'statusDescription', # status 'tag', # tag + 'orderActivityCollection', ]: del order_dict[key] return order_dict diff --git a/pyschwab/types.py b/pyschwab/types.py index 78622a8..8fd1a54 100644 --- a/pyschwab/types.py +++ b/pyschwab/types.py @@ -1,5 +1,5 @@ from enum import Enum, auto -from typing import Optional +from typing import Optional, Type, TypeVar from pydantic import BaseModel, Field, field_validator, model_validator @@ -96,11 +96,23 @@ def validate_frequency(cls, v, info): return v +T = TypeVar('T', bound='AutoName') + + class AutoName(Enum): @staticmethod def _generate_next_value_(name, start, count, last_values): return name + @classmethod + def from_str(cls: Type[T], s: str | T) -> T: + if s is None or isinstance(s, cls): + return s + try: + return cls[s] + except KeyError: + raise ValueError(f"{s} is not a valid {cls.__name__}") + class AutoNameLower(Enum): @staticmethod @@ -108,6 +120,134 @@ def _generate_next_value_(name, start, count, last_values): return name.lower().replace('_', '-') +class MarketSession(AutoName): + NORMAL = auto() + AM = auto() + PM = auto() + SEAMLESS = auto() + + +class OrderType(AutoName): + MARKET = auto() + LIMIT = auto() + STOP = auto() + STOP_LIMIT = auto() + TRAILING_STOP = auto() + CABINET = auto() + NON_MARKETABLE = auto() + MARKET_ON_CLOSE = auto() + EXERCISE = auto() + TRAILING_STOP_LIMIT = auto() + NET_DEBIT = auto() + NET_CREDIT = auto() + NET_ZERO = auto() + LIMIT_ON_CLOSE = auto() + UNKNOWN = auto() + + +class OrderInstruction(AutoName): + BUY = auto() + SELL = auto() + BUY_TO_COVER = auto() + SELL_SHORT = auto() + BUY_TO_OPEN = auto() + BUY_TO_CLOSE = auto() + SELL_TO_OPEN = auto() + SELL_TO_CLOSE = auto() + EXCHANGE = auto() + SELL_SHORT_EXEMPT = auto() + + +class OrderDuration(AutoName): + DAY = auto() + GOOD_TILL_CANCEL = auto() + FILL_OR_KILL = auto() + IMMEDIATE_OR_CANCEL = auto() + END_OF_WEEK = auto() + END_OF_MONTH = auto() + NEXT_END_OF_MONTH = auto() + UNKNOWN = auto + + +class OrderStrategyType(AutoName): + SINGLE = auto() + CANCEL = auto() + RECALL = auto() + PAIR = auto() + FLATTEN = auto() + TWO_DAY_SWAP = auto() + BLAST_ALL = auto() + OCO = auto() + TRIGGER = auto() + + +class ComplexOrderStrategyType(AutoName): + NONE = auto() + COVERED = auto() + VERTICAL = auto() + BACK_RATIO = auto() + CALENDAR = auto() + DIAGONAL = auto() + STRADDLE = auto() + STRANGLE = auto() + COLLAR_SYNTHETIC = auto() + BUTTERFLY = auto() + CONDOR = auto() + IRON_CONDOR = auto() + VERTICAL_ROLL = auto() + COLLAR_WITH_STOCK = auto() + DOUBLE_DIAGONAL = auto() + UNBALANCED_BUTTERFLY = auto() + UNBALANCED_CONDOR = auto() + UNBALANCED_IRON_CONDOR = auto() + UNBALANCED_VERTICAL_ROLL = auto() + MUTUAL_FUND_SWAP = auto() + CUSTOM = auto() + + +class PositionEffect(AutoName): + OPENING = auto() + CLOSING = auto() + AUTOMATIC = auto() + + +class QuantityType(AutoName): + ALL_SHARES = auto() + DOLLARS = auto() + SHARES = auto() + + +class AssetType(AutoName): + EQUITY = auto() + OPTION = auto() + INDEX = auto() + MUTUAL_FUND = auto() + CASH_EQUIVALENT = auto() + FIXED_INCOME = auto() + CURRENCY = auto() + COLLECTIVE_INVESTMENT = auto() + + +class OptionAssetType(AutoName): + EQUITY = AssetType.EQUITY.value + OPTION = AssetType.OPTION.value + INDEX = AssetType.INDEX.value + MUTUAL_FUND = AssetType.MUTUAL_FUND.value + CASH_EQUIVALENT = AssetType.CASH_EQUIVALENT.value + FIXED_INCOME = AssetType.FIXED_INCOME.value + CURRENCY = AssetType.CURRENCY.value + COLLECTIVE_INVESTMENT = AssetType.COLLECTIVE_INVESTMENT.value + FUTURE = auto() + FOREX = auto() + PRODUCT = auto() + + +class OptionActionType(AutoName): + PUT = auto() + CALL = auto() + UNKNOWN = auto() + + class OrderStatus(AutoName): AWAITING_PARENT_ORDER = auto() AWAITING_CONDITION = auto() @@ -132,6 +272,21 @@ class OrderStatus(AutoName): UNKNOWN = auto() +class RequestedDestination(AutoName): + INET = auto() + ECN_ARCA = auto() + CBOE = auto() + AMEX = auto() + PHLX = auto() + ISE = auto() + BOX = auto() + NYSE = auto() + NASDAQ = auto() + BATS = auto() + C2 = auto() + AUTO = auto() + + class TransactionType(AutoName): TRADE = auto() RECEIVE_AND_DELIVER = auto() diff --git a/pyschwab/utils.py b/pyschwab/utils.py index b18be58..495944b 100644 --- a/pyschwab/utils.py +++ b/pyschwab/utils.py @@ -37,6 +37,8 @@ def dataclass_to_dict(instance): """ if instance is None: return instance + if isinstance(instance, Enum): + return instance.value if isinstance(instance, list): return [dataclass_to_dict(item) for item in instance] if not hasattr(instance, '__dataclass_fields__'): diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index 0c8a10d..8bb3d9f 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -102,7 +102,7 @@ def check_orders(trading_api: TradingApi, orders: List[Order]): assert leg.time is not None, "Expected time to be fetched" assert isinstance(leg.time, datetime), "Expected execution leg time to be a datetime" - detailed_order = trading_api.get_order(order.account_number, order.order_id) + detailed_order = trading_api.get_order(order.order_id, order.account_number) assert detailed_order == order, "Expected detailed order to match order" @@ -140,7 +140,8 @@ def test_authentication_and_trading_data(app_config, logging_config): account_count = 0 for account_num in accounts_hash: assert len(account_num) > 0, "Expected a non-empty account number to be fetched" - trading_data = trading_api.fetch_trading_data(account_num) + trading_api.set_current_account_number(account_num) + trading_data = trading_api.fetch_trading_data() check_trading_data(trading_data) account_count += 1 assert len(trading_api.get_accounts()) == account_count, f"Expected {account_count} account(s)" @@ -152,7 +153,8 @@ def test_authentication_and_trading_data(app_config, logging_config): check_trading_data(trading_data) for account_num in accounts_hash: - transactions = trading_api.get_transactions(account_num) + trading_api.set_current_account_number(account_num) + transactions = trading_api.get_transactions() for transaction in transactions: assert transaction is not None, "Expected transaction to be fetched" assert transaction.activity_id is not None, "Expected activity id to be fetched" @@ -172,10 +174,10 @@ def test_authentication_and_trading_data(app_config, logging_config): assert transfer_item.cost is not None, "Expected cost to be fetched" assert transfer_item.price is not None, "Expected price to be fetched" - transaction_detail = trading_api.get_transaction(account_num, transaction.activity_id) + transaction_detail = trading_api.get_transaction(transaction.activity_id) assert transaction_detail == transaction, "Expected transaction detail to match transaction" - orders = trading_api.get_orders(account_num) + orders = trading_api.get_orders() check_orders(trading_api, orders) orders = trading_api.get_all_orders(status=OrderStatus.FILLED) @@ -184,25 +186,27 @@ def test_authentication_and_trading_data(app_config, logging_config): if not test_account_number or not test_order_type: # no order placement, change, cancellation, or preview return + trading_api.set_current_account_number(test_account_number) if test_order_type == 'place_dict': print("Testing place order by order dict") - trading_api.place_order(order_dict, test_account_number) + trading_api.place_order(order_dict) elif test_order_type == 'place_obj': print("Testing place order by order obj") order = Order.from_dict(order_dict) - trading_api.place_order(order, test_account_number) + trading_api.place_order(order) else: for order in orders: leg = order.order_leg_collection[0] if leg.instrument.symbol == 'TSLA': if test_order_type == 'cancel': print("Testing cancel order") - trading_api.cancel_order(test_order_id, test_account_number) + trading_api.cancel_order(test_order_id) elif test_order_type == 'replace': print("Testing replace order") order.order_id = test_order_id order.price = 102 - order.quantity = 2 + # order.quantity = 2 + order.order_leg_collection[0].quantity = 2 trading_api.replace_order(order) break