Skip to content

Commit

Permalink
Add more typings (#1356)
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 authored Nov 15, 2023
1 parent a20fe64 commit 9f8ff28
Show file tree
Hide file tree
Showing 26 changed files with 212 additions and 192 deletions.
4 changes: 2 additions & 2 deletions jupyter_server/_tz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ZERO = timedelta(0)


class tzUTC(tzinfo): # noqa
class tzUTC(tzinfo): # noqa: N801
"""tzinfo object for UTC (zero offset)"""

def utcoffset(self, d: datetime | None) -> timedelta:
Expand All @@ -30,7 +30,7 @@ def utcnow() -> datetime:
return datetime.now(timezone.utc)


def utcfromtimestamp(timestamp):
def utcfromtimestamp(timestamp: float) -> datetime:
return datetime.fromtimestamp(timestamp, timezone.utc)


Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/auth/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,6 @@ def inner(self, *args, **kwargs):
method = action
action = None
# no-arguments `@authorized` decorator called
return wrapper(method)
return cast(FuncT, wrapper(method))

return cast(FuncT, wrapper)
105 changes: 52 additions & 53 deletions jupyter_server/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
import os
import re
import sys
import typing as t
import uuid
from dataclasses import asdict, dataclass
from http.cookies import Morsel
from typing import TYPE_CHECKING, Any, Awaitable

from tornado import escape, httputil, web
from traitlets import Bool, Dict, Type, Unicode, default
Expand All @@ -27,11 +27,6 @@
from .security import passwd_check, set_password
from .utils import get_anonymous_username

# circular imports for type checking
if TYPE_CHECKING:
from jupyter_server.base.handlers import AuthenticatedHandler, JupyterHandler
from jupyter_server.serverapp import ServerApp

_non_alphanum = re.compile(r"[^A-Za-z0-9]")


Expand Down Expand Up @@ -82,7 +77,7 @@ def fill_defaults(self):
self.display_name = self.name


def _backward_compat_user(got_user: Any) -> User:
def _backward_compat_user(got_user: t.Any) -> User:
"""Backward-compatibility for LoginHandler.get_user
Prior to 2.0, LoginHandler.get_user could return anything truthy.
Expand Down Expand Up @@ -128,7 +123,7 @@ class IdentityProvider(LoggingConfigurable):
.. versionadded:: 2.0
"""

cookie_name: str | Unicode = Unicode(
cookie_name: str | Unicode[str, str | bytes] = Unicode(
"",
config=True,
help=_i18n("Name of the cookie to set for persisting login. Default: username-${Host}."),
Expand All @@ -142,7 +137,7 @@ class IdentityProvider(LoggingConfigurable):
),
)

secure_cookie: bool | Bool = Bool(
secure_cookie: bool | Bool[bool | None, bool | int | None] = Bool(
None,
allow_none=True,
config=True,
Expand All @@ -160,7 +155,7 @@ class IdentityProvider(LoggingConfigurable):
),
)

token: str | Unicode = Unicode(
token: str | Unicode[str, str | bytes] = Unicode(
"<generated>",
help=_i18n(
"""Token used for authenticating first-time connections to the server.
Expand Down Expand Up @@ -211,9 +206,9 @@ def _token_default(self):
self.token_generated = True
return binascii.hexlify(os.urandom(24)).decode("ascii")

need_token: bool | Bool = Bool(True)
need_token: bool | Bool[bool, t.Union[bool, int]] = Bool(True)

def get_user(self, handler: JupyterHandler) -> User | None | Awaitable[User | None]:
def get_user(self, handler: web.RequestHandler) -> User | None | t.Awaitable[User | None]:
"""Get the authenticated user for a request
Must return a :class:`jupyter_server.auth.User`,
Expand All @@ -228,17 +223,17 @@ def get_user(self, handler: JupyterHandler) -> User | None | Awaitable[User | No
# not sure how to have optional-async type signature
# on base class with `async def` without splitting it into two methods

async def _get_user(self, handler: JupyterHandler) -> User | None:
async def _get_user(self, handler: web.RequestHandler) -> User | None:
"""Get the user."""
if getattr(handler, "_jupyter_current_user", None):
# already authenticated
return handler._jupyter_current_user
_token_user: User | None | Awaitable[User | None] = self.get_user_token(handler)
if isinstance(_token_user, Awaitable):
return t.cast(User, handler._jupyter_current_user) # type:ignore[attr-defined]
_token_user: User | None | t.Awaitable[User | None] = self.get_user_token(handler)
if isinstance(_token_user, t.Awaitable):
_token_user = await _token_user
token_user: User | None = _token_user # need second variable name to collapse type
_cookie_user = self.get_user_cookie(handler)
if isinstance(_cookie_user, Awaitable):
if isinstance(_cookie_user, t.Awaitable):
_cookie_user = await _cookie_user
cookie_user: User | None = _cookie_user
# prefer token to cookie if both given,
Expand Down Expand Up @@ -273,12 +268,12 @@ async def _get_user(self, handler: JupyterHandler) -> User | None:

return user

def identity_model(self, user: User) -> dict:
def identity_model(self, user: User) -> dict[str, t.Any]:
"""Return a User as an Identity model"""
# TODO: validate?
return asdict(user)

def get_handlers(self) -> list:
def get_handlers(self) -> list[tuple[str, object]]:
"""Return list of additional handlers for this identity provider
For example, an OAuth callback handler.
Expand Down Expand Up @@ -321,7 +316,7 @@ def user_from_cookie(self, cookie_value: str) -> User | None:
user["color"],
)

def get_cookie_name(self, handler: AuthenticatedHandler) -> str:
def get_cookie_name(self, handler: web.RequestHandler) -> str:
"""Return the login cookie name
Uses IdentityProvider.cookie_name, if defined.
Expand All @@ -333,7 +328,7 @@ def get_cookie_name(self, handler: AuthenticatedHandler) -> str:
else:
return _non_alphanum.sub("-", f"username-{handler.request.host}")

def set_login_cookie(self, handler: AuthenticatedHandler, user: User) -> None:
def set_login_cookie(self, handler: web.RequestHandler, user: User) -> None:
"""Call this on handlers to set the login cookie for success"""
cookie_options = {}
cookie_options.update(self.cookie_options)
Expand All @@ -345,12 +340,12 @@ def set_login_cookie(self, handler: AuthenticatedHandler, user: User) -> None:
secure_cookie = handler.request.protocol == "https"
if secure_cookie:
cookie_options.setdefault("secure", True)
cookie_options.setdefault("path", handler.base_url)
cookie_options.setdefault("path", handler.base_url) # type:ignore[attr-defined]
cookie_name = self.get_cookie_name(handler)
handler.set_secure_cookie(cookie_name, self.user_to_cookie(user), **cookie_options)

def _force_clear_cookie(
self, handler: AuthenticatedHandler, name: str, path: str = "/", domain: str | None = None
self, handler: web.RequestHandler, name: str, path: str = "/", domain: str | None = None
) -> None:
"""Deletes the cookie with the given name.
Expand All @@ -368,19 +363,19 @@ def _force_clear_cookie(
name = escape.native_str(name)
expires = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=365)

morsel: Morsel = Morsel()
morsel: Morsel[t.Any] = Morsel()
morsel.set(name, "", '""')
morsel["expires"] = httputil.format_timestamp(expires)
morsel["path"] = path
if domain:
morsel["domain"] = domain
handler.add_header("Set-Cookie", morsel.OutputString())

def clear_login_cookie(self, handler: AuthenticatedHandler) -> None:
def clear_login_cookie(self, handler: web.RequestHandler) -> None:
"""Clear the login cookie, effectively logging out the session."""
cookie_options = {}
cookie_options.update(self.cookie_options)
path = cookie_options.setdefault("path", handler.base_url)
path = cookie_options.setdefault("path", handler.base_url) # type:ignore[attr-defined]
cookie_name = self.get_cookie_name(handler)
handler.clear_cookie(cookie_name, path=path)
if path and path != "/":
Expand All @@ -390,7 +385,9 @@ def clear_login_cookie(self, handler: AuthenticatedHandler) -> None:
# two cookies with the same name. See the method above.
self._force_clear_cookie(handler, cookie_name)

def get_user_cookie(self, handler: JupyterHandler) -> User | None | Awaitable[User | None]:
def get_user_cookie(
self, handler: web.RequestHandler
) -> User | None | t.Awaitable[User | None]:
"""Get user from a cookie
Calls user_from_cookie to deserialize cookie value
Expand All @@ -413,7 +410,7 @@ def get_user_cookie(self, handler: JupyterHandler) -> User | None | Awaitable[Us

auth_header_pat = re.compile(r"(token|bearer)\s+(.+)", re.IGNORECASE)

def get_token(self, handler: JupyterHandler) -> str | None:
def get_token(self, handler: web.RequestHandler) -> str | None:
"""Get the user token from a request
Default:
Expand All @@ -429,14 +426,14 @@ def get_token(self, handler: JupyterHandler) -> str | None:
user_token = m.group(2)
return user_token

async def get_user_token(self, handler: JupyterHandler) -> User | None:
async def get_user_token(self, handler: web.RequestHandler) -> User | None:
"""Identify the user based on a token in the URL or Authorization header
Returns:
- uuid if authenticated
- None if not
"""
token = handler.token
token = t.cast("str | None", handler.token) # type:ignore[attr-defined]
if not token:
return None
# check login token from URL argument or Authorization header
Expand All @@ -455,7 +452,7 @@ async def get_user_token(self, handler: JupyterHandler) -> User | None:
# which is stored in a cookie.
# still check the cookie for the user id
_user = self.get_user_cookie(handler)
if isinstance(_user, Awaitable):
if isinstance(_user, t.Awaitable):
_user = await _user
user: User | None = _user
if user is None:
Expand All @@ -464,7 +461,7 @@ async def get_user_token(self, handler: JupyterHandler) -> User | None:
else:
return None

def generate_anonymous_user(self, handler: JupyterHandler) -> User:
def generate_anonymous_user(self, handler: web.RequestHandler) -> User:
"""Generate a random anonymous user.
For use when a single shared token is used,
Expand All @@ -475,10 +472,10 @@ def generate_anonymous_user(self, handler: JupyterHandler) -> User:
name = display_name = f"Anonymous {moon}"
initials = f"A{moon[0]}"
color = None
handler.log.debug(f"Generating new user for token-authenticated request: {user_id}")
handler.log.debug(f"Generating new user for token-authenticated request: {user_id}") # type:ignore[attr-defined]
return User(user_id, name, display_name, initials, None, color)

def should_check_origin(self, handler: AuthenticatedHandler) -> bool:
def should_check_origin(self, handler: web.RequestHandler) -> bool:
"""Should the Handler check for CORS origin validation?
Origin check should be skipped for token-authenticated requests.
Expand All @@ -489,7 +486,7 @@ def should_check_origin(self, handler: AuthenticatedHandler) -> bool:
"""
return not self.is_token_authenticated(handler)

def is_token_authenticated(self, handler: AuthenticatedHandler) -> bool:
def is_token_authenticated(self, handler: web.RequestHandler) -> bool:
"""Returns True if handler has been token authenticated. Otherwise, False.
Login with a token is used to signal certain things, such as:
Expand All @@ -504,8 +501,8 @@ def is_token_authenticated(self, handler: AuthenticatedHandler) -> bool:

def validate_security(
self,
app: ServerApp,
ssl_options: dict | None = None,
app: t.Any,
ssl_options: dict[str, t.Any] | None = None,
) -> None:
"""Check the application's security.
Expand All @@ -526,7 +523,7 @@ def validate_security(
" Anyone who can connect to this server will be able to run code."
)

def process_login_form(self, handler: JupyterHandler) -> User | None:
def process_login_form(self, handler: web.RequestHandler) -> User | None:
"""Process login form data
Return authenticated User if successful, None if not.
Expand All @@ -538,7 +535,7 @@ def process_login_form(self, handler: JupyterHandler) -> User | None:
return self.generate_anonymous_user(handler)

if self.token and self.token == typed_password:
return self.user_for_token(typed_password) # type:ignore[attr-defined]
return t.cast(User, self.user_for_token(typed_password)) # type:ignore[attr-defined]

return user

Expand Down Expand Up @@ -633,7 +630,7 @@ def passwd_check(self, password):
"""Check password against our stored hashed password"""
return passwd_check(self.hashed_password, password)

def process_login_form(self, handler: JupyterHandler) -> User | None:
def process_login_form(self, handler: web.RequestHandler) -> User | None:
"""Process login form data
Return authenticated User if successful, None if not.
Expand All @@ -659,8 +656,8 @@ def process_login_form(self, handler: JupyterHandler) -> User | None:

def validate_security(
self,
app: ServerApp,
ssl_options: dict | None = None,
app: t.Any,
ssl_options: dict[str, t.Any] | None = None,
) -> None:
"""Handle security validation."""
super().validate_security(app, ssl_options)
Expand Down Expand Up @@ -700,31 +697,33 @@ def _default_login_handler_class(self):
def auth_enabled(self):
return self.login_available

def get_user(self, handler: JupyterHandler) -> User | None:
def get_user(self, handler: web.RequestHandler) -> User | None:
"""Get the user."""
user = self.login_handler_class.get_user(handler) # type:ignore[attr-defined]
if user is None:
return None
return _backward_compat_user(user)

@property
def login_available(self):
return self.login_handler_class.get_login_available( # type:ignore[attr-defined]
self.settings
def login_available(self) -> bool:
return bool(
self.login_handler_class.get_login_available( # type:ignore[attr-defined]
self.settings
)
)

def should_check_origin(self, handler: AuthenticatedHandler) -> bool:
def should_check_origin(self, handler: web.RequestHandler) -> bool:
"""Whether we should check origin."""
return self.login_handler_class.should_check_origin(handler) # type:ignore[attr-defined]
return bool(self.login_handler_class.should_check_origin(handler)) # type:ignore[attr-defined]

def is_token_authenticated(self, handler: AuthenticatedHandler) -> bool:
def is_token_authenticated(self, handler: web.RequestHandler) -> bool:
"""Whether we are token authenticated."""
return self.login_handler_class.is_token_authenticated(handler) # type:ignore[attr-defined]
return bool(self.login_handler_class.is_token_authenticated(handler)) # type:ignore[attr-defined]

def validate_security(
self,
app: ServerApp,
ssl_options: dict | None = None,
app: t.Any,
ssl_options: dict[str, t.Any] | None = None,
) -> None:
"""Validate security."""
if self.password_required and (not self.hashed_password):
Expand All @@ -734,6 +733,6 @@ def validate_security(
self.log.critical(_i18n("Hint: run the following command to set a password"))
self.log.critical(_i18n("\t$ python -m jupyter_server.auth password"))
sys.exit(1)
return self.login_handler_class.validate_security( # type:ignore[attr-defined]
self.login_handler_class.validate_security( # type:ignore[attr-defined]
app, ssl_options
)
Loading

0 comments on commit 9f8ff28

Please sign in to comment.