From 7928e9b8669144c3aced7e0463c863e7f357e33b Mon Sep 17 00:00:00 2001 From: Graeme Holliday Date: Fri, 17 Jan 2025 17:52:55 -0500 Subject: [PATCH] bump version, annotate tests, fix #194 --- docs/conf.py | 2 +- docs/index.rst | 2 +- pyproject.toml | 2 +- tastytrade/__init__.py | 2 +- tastytrade/streamer.py | 2 + tests/conftest.py | 9 ++- tests/test_account.py | 114 ++++++++++++++++++++++---------------- tests/test_backtest.py | 3 +- tests/test_instruments.py | 69 +++++++++++------------ tests/test_metrics.py | 17 +++--- tests/test_search.py | 9 +-- tests/test_session.py | 14 ++--- tests/test_streamer.py | 10 ++-- tests/test_watchlists.py | 39 ++++++------- 14 files changed, 160 insertions(+), 134 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d609ecc..8ce9bf5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,7 +13,7 @@ project = "tastytrade" copyright = "2024, Graeme Holliday" author = "Graeme Holliday" -release = "9.7" +release = "9.8" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/index.rst b/docs/index.rst index 64b3c69..293bb63 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -22,7 +22,7 @@ A simple, reverse-engineered, sync/async SDK for Tastytrade built on their (now .. tip:: Want to see the SDK in action? Check out `tastytrade-cli `_, a CLI for Tastytrade that showcases many of the SDK's features. -.. tip:: +.. note:: Do you use TradeStation? We're building a `brand-new SDK `_ for TS users, with many of the same features! .. toctree:: diff --git a/pyproject.toml b/pyproject.toml index 2d166ee..a5ccd16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "tastytrade" -version = "9.7" +version = "9.8" description = "An unofficial, sync/async SDK for Tastytrade!" readme = "README.md" requires-python = ">=3.9" diff --git a/tastytrade/__init__.py b/tastytrade/__init__.py index 223122e..37b55f4 100644 --- a/tastytrade/__init__.py +++ b/tastytrade/__init__.py @@ -4,7 +4,7 @@ BACKTEST_URL = "https://backtester.vast.tastyworks.com" CERT_URL = "https://api.cert.tastyworks.com" VAST_URL = "https://vast.tastyworks.com" -VERSION = "9.7" +VERSION = "9.8" logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) diff --git a/tastytrade/streamer.py b/tastytrade/streamer.py index df048fc..1960e69 100644 --- a/tastytrade/streamer.py +++ b/tastytrade/streamer.py @@ -242,6 +242,7 @@ async def close(self) -> None: self._reconnect_task.cancel() tasks.append(self._reconnect_task) await asyncio.gather(*tasks) + await self._websocket.wait_closed() # type: ignore async def _connect(self) -> None: """ @@ -443,6 +444,7 @@ async def close(self) -> None: self._reconnect_task.cancel() tasks.append(self._reconnect_task) await asyncio.gather(*tasks) + await self._websocket.wait_closed() async def _connect(self) -> None: """ diff --git a/tests/conftest.py b/tests/conftest.py index cb6db43..65c8b3d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +from typing import AsyncGenerator from pytest import fixture @@ -7,12 +8,12 @@ # Run all tests with asyncio only @fixture(scope="session") -def aiolib(): +def aiolib() -> str: return "asyncio" @fixture(scope="session") -def credentials(): +def credentials() -> tuple[str, str]: username = os.getenv("TT_USERNAME") password = os.getenv("TT_PASSWORD") assert username is not None @@ -21,7 +22,9 @@ def credentials(): @fixture(scope="session") -async def session(credentials, aiolib): +async def session( + credentials: tuple[str, str], aiolib: str +) -> AsyncGenerator[Session, None]: session = Session(*credentials) yield session session.destroy() diff --git a/tests/test_account.py b/tests/test_account.py index a1312ff..cc04c58 100644 --- a/tests/test_account.py +++ b/tests/test_account.py @@ -5,7 +5,7 @@ from pytest import fixture -from tastytrade import Account +from tastytrade import Account, Session from tastytrade.instruments import Equity from tastytrade.order import ( NewComplexOrder, @@ -13,149 +13,154 @@ OrderAction, OrderTimeInForce, OrderType, + PlacedOrder, ) @fixture(scope="module") -def account_number(): +def account_number() -> str: account_number = os.getenv("TT_ACCOUNT") assert account_number is not None return account_number @fixture(scope="module") -async def account(session, account_number, aiolib): +async def account(session: Session, account_number: str, aiolib: str) -> Account: return Account.get_account(session, account_number) -def test_get_account(account): +def test_get_account(account: Account): pass -def test_get_accounts(session): +def test_get_accounts(session: Session): assert Account.get_accounts(session) != [] -def test_get_trading_status(session, account): +def test_get_trading_status(session: Session, account: Account): account.get_trading_status(session) -def test_get_balances(session, account): +def test_get_balances(session: Session, account: Account): account.get_balances(session) -def test_get_balance_snapshots(session, account): +def test_get_balance_snapshots(session: Session, account: Account): account.get_balance_snapshots(session) -def test_get_positions(session, account): +def test_get_positions(session: Session, account: Account): account.get_positions(session) -def test_get_history(session, account): +def test_get_history(session: Session, account: Account): account.get_history(session, page_offset=0) -def test_get_total_fees(session, account): +def test_get_total_fees(session: Session, account: Account): account.get_total_fees(session) -def test_get_position_limit(session, account): +def test_get_position_limit(session: Session, account: Account): account.get_position_limit(session) -def test_get_margin_requirements(session, account): +def test_get_margin_requirements(session: Session, account: Account): account.get_margin_requirements(session) -def test_get_net_liquidating_value_history(session, account): +def test_get_net_liquidating_value_history(session: Session, account: Account): account.get_net_liquidating_value_history(session, time_back="1y") -def test_get_effective_margin_requirements(session, account): +def test_get_effective_margin_requirements(session: Session, account: Account): account.get_effective_margin_requirements(session, "SPY") -def test_get_order_history(session, account): +def test_get_order_history(session: Session, account: Account): account.get_order_history(session, page_offset=0) -def test_get_complex_order_history(session, account): +def test_get_complex_order_history(session: Session, account: Account): account.get_complex_order_history(session, page_offset=0) -def test_get_live_orders(session, account): +def test_get_live_orders(session: Session, account: Account): account.get_live_orders(session) -async def test_get_account_async(session, account_number): +async def test_get_account_async(session: Session, account_number: str): await Account.a_get_account(session, account_number) -async def test_get_accounts_async(session): +async def test_get_accounts_async(session: Session): accounts = await Account.a_get_accounts(session) assert accounts != [] -async def test_get_trading_status_async(session, account): +async def test_get_trading_status_async(session: Session, account: Account): await account.a_get_trading_status(session) -async def test_get_balances_async(session, account): +async def test_get_balances_async(session: Session, account: Account): await account.a_get_balances(session) -async def test_get_balance_snapshots_async(session, account): +async def test_get_balance_snapshots_async(session: Session, account: Account): await account.a_get_balance_snapshots(session) -async def test_get_positions_async(session, account): +async def test_get_positions_async(session: Session, account: Account): await account.a_get_positions(session) -async def test_get_history_async(session, account): +async def test_get_history_async(session: Session, account: Account): await account.a_get_history(session, page_offset=0) -async def test_get_total_fees_async(session, account): +async def test_get_total_fees_async(session: Session, account: Account): await account.a_get_total_fees(session) -async def test_get_position_limit_async(session, account): +async def test_get_position_limit_async(session: Session, account: Account): await account.a_get_position_limit(session) -async def test_get_margin_requirements_async(session, account): +async def test_get_margin_requirements_async(session: Session, account: Account): await account.a_get_margin_requirements(session) -async def test_get_net_liquidating_value_history_async(session, account): +async def test_get_net_liquidating_value_history_async( + session: Session, account: Account +): await account.a_get_net_liquidating_value_history(session, time_back="1y") -async def test_get_effective_margin_requirements_async(session, account): +async def test_get_effective_margin_requirements_async( + session: Session, account: Account +): await account.a_get_effective_margin_requirements(session, "SPY") -async def test_get_order_history_async(session, account): +async def test_get_order_history_async(session: Session, account: Account): await account.a_get_order_history(session, page_offset=0) -async def test_get_complex_order_history_async(session, account): +async def test_get_complex_order_history_async(session: Session, account: Account): await account.a_get_complex_order_history(session, page_offset=0) -async def test_get_live_orders_async(session, account): +async def test_get_live_orders_async(session: Session, account: Account): await account.a_get_live_orders(session) -def test_get_order_chains(session, account): +def test_get_order_chains(session: Session, account: Account): start_time = datetime(2024, 1, 1, 0, 0, 0) end_time = datetime.now() account.get_order_chains(session, "F", start_time=start_time, end_time=end_time) -async def test_get_order_chains_async(session, account): +async def test_get_order_chains_async(session: Session, account: Account): start_time = datetime(2024, 1, 1, 0, 0, 0) end_time = datetime.now() await account.a_get_order_chains( @@ -164,7 +169,7 @@ async def test_get_order_chains_async(session, account): @fixture(scope="module") -def new_order(session): +def new_order(session: Session) -> NewOrder: symbol = Equity.get_equity(session, "F") leg = symbol.build_leg(Decimal(1), OrderAction.BUY_TO_OPEN) return NewOrder( @@ -176,20 +181,24 @@ def new_order(session): @fixture(scope="module") -def placed_order(session, account, new_order): +def placed_order( + session: Session, account: Account, new_order: NewOrder +) -> PlacedOrder: return account.place_order(session, new_order, dry_run=False).order -def test_place_order(placed_order): +def test_place_order(placed_order: PlacedOrder): pass -def test_get_order(session, account, placed_order): +def test_get_order(session: Session, account: Account, placed_order: PlacedOrder): sleep(3) assert account.get_order(session, placed_order.id).id == placed_order.id -def test_replace_and_delete_order(session, account, new_order, placed_order): +def test_replace_and_delete_order( + session: Session, account: Account, new_order: NewOrder, placed_order: PlacedOrder +): modified_order = new_order.model_copy() modified_order.price = Decimal("-2.01") replaced = account.replace_order(session, placed_order.id, modified_order) @@ -197,7 +206,7 @@ def test_replace_and_delete_order(session, account, new_order, placed_order): account.delete_order(session, replaced.id) -def test_place_oco_order(session, account): +def test_place_oco_order(session: Session, account: Account): # account must have a share of F for this to work symbol = Equity.get_equity(session, "F") closing = symbol.build_leg(Decimal(1), OrderAction.SELL_TO_CLOSE) @@ -224,7 +233,7 @@ def test_place_oco_order(session, account): account.delete_complex_order(session, resp2.complex_order.id) -def test_place_otoco_order(session, account): +def test_place_otoco_order(session: Session, account: Account): symbol = Equity.get_equity(session, "AAPL") opening = symbol.build_leg(Decimal(1), OrderAction.BUY_TO_OPEN) closing = symbol.build_leg(Decimal(1), OrderAction.SELL_TO_CLOSE) @@ -255,29 +264,36 @@ def test_place_otoco_order(session, account): account.delete_complex_order(session, resp.complex_order.id) -def test_get_live_complex_orders(session, account): +def test_get_live_complex_orders(session: Session, account: Account): orders = account.get_live_complex_orders(session) assert orders != [] @fixture(scope="module") -async def placed_order_async(session, account, new_order): +async def placed_order_async( + session: Session, account: Account, new_order: NewOrder +) -> PlacedOrder: res = await account.a_place_order(session, new_order, dry_run=False) return res.order -async def test_place_order_async(placed_order_async): +async def test_place_order_async(placed_order_async: PlacedOrder): pass -async def test_get_order_async(session, account, placed_order_async): +async def test_get_order_async( + session: Session, account: Account, placed_order_async: PlacedOrder +): sleep(3) placed = await account.a_get_order(session, placed_order_async.id) assert placed.id == placed_order_async.id async def test_replace_and_delete_order_async( - session, account, new_order, placed_order_async + session: Session, + account: Account, + new_order: NewOrder, + placed_order_async: PlacedOrder, ): modified_order = new_order.model_copy() modified_order.price = Decimal("-2.01") @@ -288,7 +304,7 @@ async def test_replace_and_delete_order_async( await account.a_delete_order(session, replaced.id) -async def test_place_complex_order_async(session, account): +async def test_place_complex_order_async(session: Session, account: Account): sleep(3) symbol = Equity.get_equity(session, "AAPL") opening = symbol.build_leg(Decimal(1), OrderAction.BUY_TO_OPEN) @@ -320,6 +336,6 @@ async def test_place_complex_order_async(session, account): await account.a_delete_complex_order(session, resp.complex_order.id) -async def test_get_live_complex_orders_async(session, account): +async def test_get_live_complex_orders_async(session: Session, account: Account): orders = await account.a_get_live_complex_orders(session) assert orders != [] diff --git a/tests/test_backtest.py b/tests/test_backtest.py index 01ce4be..2e6cbd0 100644 --- a/tests/test_backtest.py +++ b/tests/test_backtest.py @@ -1,5 +1,6 @@ from datetime import timedelta +from tastytrade import Session from tastytrade.backtest import ( Backtest, BacktestEntry, @@ -10,7 +11,7 @@ from tastytrade.utils import today_in_new_york -async def test_backtest_simple(session): +async def test_backtest_simple(session: Session): backtest_session = BacktestSession(session) backtest = Backtest( symbol="SPY", diff --git a/tests/test_instruments.py b/tests/test_instruments.py index 9ccf9e4..f71230f 100644 --- a/tests/test_instruments.py +++ b/tests/test_instruments.py @@ -1,3 +1,4 @@ +from tastytrade import Session from tastytrade.instruments import ( Cryptocurrency, Equity, @@ -18,131 +19,131 @@ ) -async def test_get_cryptocurrency_async(session): +async def test_get_cryptocurrency_async(session: Session): await Cryptocurrency.a_get_cryptocurrency(session, "ETH/USD") -def test_get_cryptocurrency(session): +def test_get_cryptocurrency(session: Session): Cryptocurrency.get_cryptocurrency(session, "ETH/USD") -async def test_get_cryptocurrencies_async(session): +async def test_get_cryptocurrencies_async(session: Session): await Cryptocurrency.a_get_cryptocurrencies(session) -def test_get_cryptocurrencies(session): +def test_get_cryptocurrencies(session: Session): Cryptocurrency.get_cryptocurrencies(session) -async def test_get_active_equities_async(session): +async def test_get_active_equities_async(session: Session): await Equity.a_get_active_equities(session, page_offset=0) -def test_get_active_equities(session): +def test_get_active_equities(session: Session): Equity.get_active_equities(session, page_offset=0) -async def test_get_equities_async(session): +async def test_get_equities_async(session: Session): await Equity.a_get_equities(session, ["AAPL", "SPY"]) -def test_get_equities(session): +def test_get_equities(session: Session): Equity.get_equities(session, ["AAPL", "SPY"]) -async def test_get_equity_async(session): +async def test_get_equity_async(session: Session): await Equity.a_get_equity(session, "AAPL") -def test_get_equity(session): +def test_get_equity(session: Session): Equity.get_equity(session, "AAPL") -async def test_get_futures_async(session): +async def test_get_futures_async(session: Session): futures = await Future.a_get_futures(session, product_codes=["ES"]) assert futures != [] await Future.a_get_future(session, futures[0].symbol) -def test_get_futures(session): +def test_get_futures(session: Session): futures = Future.get_futures(session, product_codes=["ES"]) assert futures != [] Future.get_future(session, futures[0].symbol) -async def test_get_future_product_async(session): +async def test_get_future_product_async(session: Session): await FutureProduct.a_get_future_product(session, "ZN") -def test_get_future_product(session): +def test_get_future_product(session: Session): FutureProduct.get_future_product(session, "ZN") -async def test_get_future_option_product_async(session): +async def test_get_future_option_product_async(session: Session): await FutureOptionProduct.a_get_future_option_product(session, "LO") -def test_get_future_option_product(session): +def test_get_future_option_product(session: Session): FutureOptionProduct.get_future_option_product(session, "LO") -async def test_get_future_option_products_async(session): +async def test_get_future_option_products_async(session: Session): await FutureOptionProduct.a_get_future_option_products(session) -def test_get_future_option_products(session): +def test_get_future_option_products(session: Session): FutureOptionProduct.get_future_option_products(session) -async def test_get_future_products_async(session): +async def test_get_future_products_async(session: Session): await FutureProduct.a_get_future_products(session) -def test_get_future_products(session): +def test_get_future_products(session: Session): FutureProduct.get_future_products(session) -async def test_get_nested_option_chain_async(session): +async def test_get_nested_option_chain_async(session: Session): await NestedOptionChain.a_get_chain(session, "SPY") -def test_get_nested_option_chain(session): +def test_get_nested_option_chain(session: Session): NestedOptionChain.get_chain(session, "SPY") -async def test_get_nested_future_option_chain_async(session): +async def test_get_nested_future_option_chain_async(session: Session): await NestedFutureOptionChain.a_get_chain(session, "ES") -def test_get_nested_future_option_chain(session): +def test_get_nested_future_option_chain(session: Session): NestedFutureOptionChain.get_chain(session, "ES") -async def test_get_warrants_async(session): +async def test_get_warrants_async(session: Session): await Warrant.a_get_warrants(session) -def test_get_warrants(session): +def test_get_warrants(session: Session): Warrant.get_warrants(session) -async def test_get_warrant_async(session): +async def test_get_warrant_async(session: Session): await Warrant.a_get_warrant(session, "NKLAW") -def test_get_warrant(session): +def test_get_warrant(session: Session): Warrant.get_warrant(session, "NKLAW") -async def test_get_quantity_decimal_precisions_async(session): +async def test_get_quantity_decimal_precisions_async(session: Session): await a_get_quantity_decimal_precisions(session) -def test_get_quantity_decimal_precisions(session): +def test_get_quantity_decimal_precisions(session: Session): get_quantity_decimal_precisions(session) -async def test_get_option_chain_async(session): +async def test_get_option_chain_async(session: Session): chain = await a_get_option_chain(session, "SPY") assert chain != {} for options in chain.values(): @@ -150,7 +151,7 @@ async def test_get_option_chain_async(session): break -def test_get_option_chain(session): +def test_get_option_chain(session: Session): chain = get_option_chain(session, "SPY") assert chain != {} for options in chain.values(): @@ -158,7 +159,7 @@ def test_get_option_chain(session): break -async def test_get_future_option_chain_async(session): +async def test_get_future_option_chain_async(session: Session): chain = await a_get_future_option_chain(session, "ES") assert chain != {} for options in chain.values(): @@ -168,7 +169,7 @@ async def test_get_future_option_chain_async(session): break -def test_get_future_option_chain(session): +def test_get_future_option_chain(session: Session): chain = get_future_option_chain(session, "ES") assert chain != {} for options in chain.values(): diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 010b76e..a59d8c2 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,5 +1,6 @@ from datetime import date +from tastytrade import Session from tastytrade.metrics import ( a_get_dividends, a_get_earnings, @@ -12,33 +13,33 @@ ) -async def test_get_dividends_async(session): +async def test_get_dividends_async(session: Session): await a_get_dividends(session, "SPY") -async def test_get_earnings_async(session): +async def test_get_earnings_async(session: Session): await a_get_earnings(session, "AAPL", date.today()) -async def test_get_market_metrics_async(session): +async def test_get_market_metrics_async(session: Session): await a_get_market_metrics(session, ["SPY", "AAPL"]) -async def test_get_risk_free_rate_async(session): +async def test_get_risk_free_rate_async(session: Session): await a_get_risk_free_rate(session) -def test_get_dividends(session): +def test_get_dividends(session: Session): get_dividends(session, "SPY") -def test_get_earnings(session): +def test_get_earnings(session: Session): get_earnings(session, "AAPL", date.today()) -def test_get_market_metrics(session): +def test_get_market_metrics(session: Session): get_market_metrics(session, ["SPY", "AAPL"]) -def test_get_risk_free_rate(session): +def test_get_risk_free_rate(session: Session): get_risk_free_rate(session) diff --git a/tests/test_search.py b/tests/test_search.py index b470e72..54e0eaf 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1,23 +1,24 @@ +from tastytrade import Session from tastytrade.search import a_symbol_search, symbol_search -async def test_symbol_search_valid_async(session): +async def test_symbol_search_valid_async(session: Session): results = await a_symbol_search(session, "AAP") symbols = [s.symbol for s in results] assert "AAPL" in symbols -async def test_symbol_search_invalid_async(session): +async def test_symbol_search_invalid_async(session: Session): results = await a_symbol_search(session, "ASDFGJKL") assert results == [] -def test_symbol_search_valid(session): +def test_symbol_search_valid(session: Session): results = symbol_search(session, "AAP") symbols = [s.symbol for s in results] assert "AAPL" in symbols -def test_symbol_search_invalid(session): +def test_symbol_search_invalid(session: Session): results = symbol_search(session, "ASDFGJKL") assert results == [] diff --git a/tests/test_session.py b/tests/test_session.py index 071b550..ce5ab76 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,33 +1,33 @@ from tastytrade import Session -def test_get_customer(session): +def test_get_customer(session: Session): session.get_customer() -async def test_get_customer_async(session): +async def test_get_customer_async(session: Session): await session.a_get_customer() -def test_get_2fa_info(session): +def test_get_2fa_info(session: Session): session.get_2fa_info() -async def test_get_2fa_info_async(session): +async def test_get_2fa_info_async(session: Session): await session.a_get_2fa_info() -def test_destroy(credentials): +def test_destroy(credentials: tuple[str, str]): session = Session(*credentials) session.destroy() -async def test_destroy_async(credentials): +async def test_destroy_async(credentials: tuple[str, str]): session = Session(*credentials) await session.a_destroy() -def test_serialize_deserialize(session): +def test_serialize_deserialize(session: Session): data = session.serialize() obj = Session.deserialize(data) assert set(obj.__dict__.keys()) == set(session.__dict__.keys()) diff --git a/tests/test_streamer.py b/tests/test_streamer.py index aa8a248..4a2fea5 100644 --- a/tests/test_streamer.py +++ b/tests/test_streamer.py @@ -1,11 +1,11 @@ import asyncio from datetime import datetime, timedelta -from tastytrade import Account, AlertStreamer, DXLinkStreamer +from tastytrade import Account, AlertStreamer, DXLinkStreamer, Session from tastytrade.dxfeed import Candle, Quote, Trade -async def test_account_streamer(session): +async def test_account_streamer(session: Session): async with AlertStreamer(session) as streamer: await streamer.subscribe_public_watchlists() await streamer.subscribe_quote_alerts() @@ -14,7 +14,7 @@ async def test_account_streamer(session): await streamer.subscribe_accounts(accounts) -async def test_dxlink_streamer(session): +async def test_dxlink_streamer(session: Session): async with DXLinkStreamer(session) as streamer: subs = ["SPY", "AAPL"] await streamer.subscribe(Quote, subs) @@ -37,7 +37,7 @@ async def reconnect_alerts(streamer: AlertStreamer, ref: dict[str, bool]): ref["test"] = True -async def test_account_streamer_reconnect(session): +async def test_account_streamer_reconnect(session: Session): ref = {} streamer = await AlertStreamer( session, reconnect_args=(ref,), reconnect_fn=reconnect_alerts @@ -56,7 +56,7 @@ async def reconnect_trades(streamer: DXLinkStreamer): await streamer.subscribe(Trade, ["SPX"]) -async def test_dxlink_streamer_reconnect(session): +async def test_dxlink_streamer_reconnect(session: Session): streamer = await DXLinkStreamer(session, reconnect_fn=reconnect_trades) await streamer.subscribe(Quote, ["SPY"]) _ = await streamer.get_event(Quote) diff --git a/tests/test_watchlists.py b/tests/test_watchlists.py index 27de564..f5dc2a5 100644 --- a/tests/test_watchlists.py +++ b/tests/test_watchlists.py @@ -2,95 +2,96 @@ from pytest import fixture +from tastytrade import Session from tastytrade.instruments import InstrumentType from tastytrade.watchlists import PairsWatchlist, Watchlist WATCHLIST_NAME = "TestWatchlist" -def test_get_pairs_watchlists(session): +def test_get_pairs_watchlists(session: Session): PairsWatchlist.get_pairs_watchlists(session) -def test_get_pairs_watchlist(session): +def test_get_pairs_watchlist(session: Session): PairsWatchlist.get_pairs_watchlist(session, "Stocks") -async def test_get_pairs_watchlists_async(session): +async def test_get_pairs_watchlists_async(session: Session): await PairsWatchlist.a_get_pairs_watchlists(session) -async def test_get_pairs_watchlist_async(session): +async def test_get_pairs_watchlist_async(session: Session): await PairsWatchlist.a_get_pairs_watchlist(session, "Stocks") -def test_get_public_watchlists(session): +def test_get_public_watchlists(session: Session): Watchlist.get_public_watchlists(session) -def test_get_public_watchlist(session): +def test_get_public_watchlist(session: Session): Watchlist.get_public_watchlist(session, "Crypto") -def test_get_private_watchlists(session): +def test_get_private_watchlists(session: Session): Watchlist.get_private_watchlists(session) -async def test_get_public_watchlists_async(session): +async def test_get_public_watchlists_async(session: Session): await Watchlist.a_get_public_watchlists(session) -async def test_get_public_watchlist_async(session): +async def test_get_public_watchlist_async(session: Session): await Watchlist.a_get_public_watchlist(session, "Crypto") -async def test_get_private_watchlists_async(session): +async def test_get_private_watchlists_async(session: Session): await Watchlist.a_get_private_watchlists(session) @fixture(scope="module") -def private_wl(): +def private_wl() -> Watchlist: wl = Watchlist(name=WATCHLIST_NAME) wl.add_symbol("MSFT", InstrumentType.EQUITY) wl.add_symbol("AAPL", InstrumentType.EQUITY) return wl -def test_upload_private_watchlist(session, private_wl): +def test_upload_private_watchlist(session: Session, private_wl: Watchlist): private_wl.upload_private_watchlist(session) -def test_get_private_watchlist(session): +def test_get_private_watchlist(session: Session): sleep(1) Watchlist.get_private_watchlist(session, WATCHLIST_NAME) -def test_update_private_watchlist(session, private_wl): +def test_update_private_watchlist(session: Session, private_wl: Watchlist): private_wl.remove_symbol("AAPL", InstrumentType.EQUITY) sleep(1) private_wl.update_private_watchlist(session) -def test_remove_private_watchlist(session): +def test_remove_private_watchlist(session: Session): sleep(1) Watchlist.remove_private_watchlist(session, WATCHLIST_NAME) -async def test_upload_private_watchlist_async(session, private_wl): +async def test_upload_private_watchlist_async(session: Session, private_wl: Watchlist): await private_wl.a_upload_private_watchlist(session) -async def test_get_private_watchlist_async(session): +async def test_get_private_watchlist_async(session: Session): sleep(1) await Watchlist.a_get_private_watchlist(session, WATCHLIST_NAME) -async def test_update_private_watchlist_async(session, private_wl): +async def test_update_private_watchlist_async(session: Session, private_wl: Watchlist): private_wl.remove_symbol("MSFT", InstrumentType.EQUITY) sleep(1) await private_wl.a_update_private_watchlist(session) -async def test_remove_private_watchlist_async(session): +async def test_remove_private_watchlist_async(session: Session): sleep(1) await Watchlist.a_remove_private_watchlist(session, WATCHLIST_NAME)