Skip to content

Commit

Permalink
Merge pull request #56 from tristanpenman/cache-session
Browse files Browse the repository at this point in the history
Allow session to be cached to disk
  • Loading branch information
itsjafer committed Jun 27, 2024
2 parents 4ec4d91 + 3929555 commit 104129e
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 38 deletions.
9 changes: 6 additions & 3 deletions example/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
totp_secret = os.getenv("SCHWAB_TOTP")

# Initialize our schwab instance
api = Schwab()
api = Schwab(
# Optional session cache - uncomment to enable:
# session_cache="session.json"
)

# Login using playwright
print("Logging into Schwab")
Expand Down Expand Up @@ -48,9 +51,9 @@
print("Placing a dry run trade for PFE stock")
# Place a dry run trade for each account
messages, success = api.trade_v2(
ticker="PFE",
ticker="PFE",
side="Buy", #or Sell
qty=1,
qty=1,
account_id=next(iter(account_info)), # Replace with your account number
dry_run=True # If dry_run=True, we won't place the order, we'll just verify it.
)
Expand Down
146 changes: 121 additions & 25 deletions schwab_api/authentication.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hashlib
import json
import requests
import pyotp
import re
Expand All @@ -14,15 +16,27 @@
VIEWPORT = { 'width': 1920, 'height': 1080 }

class SessionManager:
def __init__(self) -> None:
"""
This class is using asynchonous playwright mode.
def __init__(self, debug = False) -> None:
"""
self.headers = None
This class is using asynchronous playwright mode.
:type debug: boolean
:param debug: Enable debug logging
"""
self.headers = {}
self.session = requests.Session()
self.playwright = None
self.browser = None
self.page = None
self.debug = debug

# cached credentials
self.username = None
self.username_hash = None
self.password = None
self.password_hash = None
self.totp_secret = None
self.totp_secret_hash = None

def check_auth(self):
r = self.session.get(urls.account_info_v2())
Expand All @@ -32,30 +46,81 @@ def check_auth(self):

def get_session(self):
return self.session

def login(self, username, password, totp_secret=None):
""" This function will log the user into schwab using asynchronous Playwright and saving
the authentication cookies in the session header.

def login(self, username, password, totp_secret, lazy=False):
"""
Logs the user into the Schwab API using asynchronous Playwright, saving
the authentication cookies in the session header.
:type username: str
:param username: The username for the schwab account.
:type password: str
:param password: The password for the schwab account/
:type totp_secret: Optional[str]
:param totp_secret: The TOTP secret used to complete multi-factor authentication
through Symantec VIP. If this isn't given, sign in will use SMS.
:type totp_secret: str
:param totp_secret: The TOTP secret used to complete multi-factor authentication
:type lazy: boolean
:param lazy: Store credentials but don't login until necessary
:rtype: boolean
:returns: True if login was successful and no further action is needed or False
if login requires additional steps (i.e. SMS - no longer supported)
"""
result = asyncio.run(self._async_login(username, password, totp_secret))
return result

async def _async_login(self, username, password, totp_secret=None):
""" This function runs in async mode to perform login.
Use with login function. See login function for details.
# update credentials
self.username = username or ""
self.password = password or ""
self.totp_secret = totp_secret or ""

# calculate hashes
username_hash = hashlib.md5(self.username.encode('utf-8')).hexdigest()
password_hash = hashlib.md5(self.password.encode('utf-8')).hexdigest()
totp_secret_hash = hashlib.md5(self.totp_secret.encode('utf-8')).hexdigest()

# attempt to load cached session
if self._load_session_cache():
# check hashed credentials
if self.username_hash == username_hash and self.password_hash == password_hash and self.totp_secret_hash == totp_secret_hash:
if self.debug:
print('DEBUG: hashed credentials okay')
try:
if self.update_token():
return True
except:
if self.debug:
print('DEBUG: update token failed, falling back to login')

# update hashed credentials
self.username_hash = username_hash
self.password_hash = password_hash
self.totp_secret_hash = totp_secret_hash

if lazy:
return True
else:
# attempt to login
return asyncio.run(self._async_login())

def update_token(self, token_type='api', login=True):
r = self.session.get(f"https://client.schwab.com/api/auth/authorize/scope/{token_type}")
if not r.ok:
if login:
if self.debug:
print("DEBUG: session invalid; logging in again")
result = asyncio.run(self._async_login())
return result
else:
raise ValueError(f"Error updating Bearer token: {r.reason}")

token = json.loads(r.text)['token']
self.headers['authorization'] = f"Bearer {token}"
self._save_session_cache()
return True

async def _async_login(self):
"""
Helper function to perform asynchronous login using Playwright
"""
self.playwright = await async_playwright().start()
if self.browserType == "firefox":
Expand All @@ -71,23 +136,23 @@ async def _async_login(self, username, password, totp_secret=None):
viewport=VIEWPORT
)
await stealth_async(self.page)

await self.page.goto("https://www.schwab.com/")

await self.page.goto("https://www.schwab.com/")
await self.page.route(re.compile(r".*balancespositions*"), self._asyncCaptureAuthToken)

login_frame = "schwablmslogin"
await self.page.wait_for_selector("#" + login_frame)

await self.page.frame(name=login_frame).select_option("select#landingPageOptions", index=3)

# enter username
await self.page.frame(name=login_frame).click("[placeholder=\"Login ID\"]")
await self.page.frame(name=login_frame).fill("[placeholder=\"Login ID\"]", username)
await self.page.frame(name=login_frame).fill("[placeholder=\"Login ID\"]", self.username)

if totp_secret is not None:
totp = pyotp.TOTP(totp_secret)
password += str(totp.now())
# append otp to passsword
totp = pyotp.TOTP(self.totp_secret)
password = self.password + str(totp.now())

# enter password
await self.page.frame(name=login_frame).press("[placeholder=\"Login ID\"]", "Tab")
await self.page.frame(name=login_frame).fill("[placeholder=\"Password\"]", password)

Expand All @@ -108,7 +173,38 @@ async def _async_save_and_close_session(self):
await self.page.close()
await self.browser.close()
await self.playwright.stop()

self._save_session_cache()

async def _asyncCaptureAuthToken(self, route):
self.headers = await route.request.all_headers()
await route.continue_()

def _load_session_cache(self):
if self.session_cache:
try:
with open(self.session_cache) as f:
data = f.read()
session = json.loads(data)
self.session.cookies = cookiejar_from_dict(session['cookies'])
self.headers = session['headers']
self.username_hash = session['username_hash']
self.password_hash = session['password_hash']
self.totp_secret_hash = session['totp_secret_hash']
return True
except:
# swallow exceptions
pass

return False

def _save_session_cache(self):
if self.session_cache:
with open(self.session_cache, 'w') as f:
session = {
'cookies': self.session.cookies.get_dict(),
'headers': self.headers,
'username_hash': self.username_hash,
'password_hash': self.password_hash,
'totp_secret_hash': self.totp_secret_hash
}
json.dump(session, f)
16 changes: 6 additions & 10 deletions schwab_api/schwab.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,21 @@
from .authentication import SessionManager

class Schwab(SessionManager):
def __init__(self, **kwargs):
def __init__(self, session_cache=None, **kwargs):
"""
The Schwab class. Used to interact with schwab.
The Schwab class. Used to interact with the Schwab API.
:type session_cache: str
:param session_cache: Path to an optional session file, used to save/restore credentials
"""
self.headless = kwargs.get("headless", True)
self.browserType = kwargs.get("browserType", "firefox")
self.session_cache = session_cache
super(Schwab, self).__init__()

def get_account_info(self):
"""
Returns a dictionary of Account objects where the key is the account number
Returns a dictionary of Account objects where the key is the account number
"""

account_info = dict()
Expand Down Expand Up @@ -789,10 +792,3 @@ def get_options_chains_v2(self, ticker, greeks = False):

response = json.loads(r.text)
return response

def update_token(self, token_type='api'):
r = self.session.get(f"https://client.schwab.com/api/auth/authorize/scope/{token_type}")
if not r.ok:
raise ValueError(f"Error updating Bearer token: {r.reason}")
token = json.loads(r.text)['token']
self.headers['authorization'] = f"Bearer {token}"

0 comments on commit 104129e

Please sign in to comment.