diff --git a/jupyter_server/_tz.py b/jupyter_server/_tz.py index df123ffe07..7027d80124 100644 --- a/jupyter_server/_tz.py +++ b/jupyter_server/_tz.py @@ -13,7 +13,7 @@ ZERO = timedelta(0) -class tzUTC(tzinfo): # noqa +class tzUTC(tzinfo): # noqa: N801 """tzinfo object for UTC (zero offset)""" def utcoffset(self, d: datetime | None) -> timedelta: @@ -30,7 +30,7 @@ def utcnow() -> datetime: return datetime.now(timezone.utc) -def utcfromtimestamp(timestamp): +def utcfromtimestamp(timestamp: float) -> datetime: return datetime.fromtimestamp(timestamp, timezone.utc) diff --git a/jupyter_server/auth/decorator.py b/jupyter_server/auth/decorator.py index a5d6c0543f..fd38cda1e7 100644 --- a/jupyter_server/auth/decorator.py +++ b/jupyter_server/auth/decorator.py @@ -73,6 +73,6 @@ def inner(self, *args, **kwargs): method = action action = None # no-arguments `@authorized` decorator called - return wrapper(method) + return cast(FuncT, wrapper(method)) return cast(FuncT, wrapper) diff --git a/jupyter_server/auth/identity.py b/jupyter_server/auth/identity.py index 2440710186..1374bc7430 100644 --- a/jupyter_server/auth/identity.py +++ b/jupyter_server/auth/identity.py @@ -13,10 +13,10 @@ import os import re import sys +import typing as t import uuid from dataclasses import asdict, dataclass from http.cookies import Morsel -from typing import TYPE_CHECKING, Any, Awaitable from tornado import escape, httputil, web from traitlets import Bool, Dict, Type, Unicode, default @@ -27,11 +27,6 @@ from .security import passwd_check, set_password from .utils import get_anonymous_username -# circular imports for type checking -if TYPE_CHECKING: - from jupyter_server.base.handlers import AuthenticatedHandler, JupyterHandler - from jupyter_server.serverapp import ServerApp - _non_alphanum = re.compile(r"[^A-Za-z0-9]") @@ -82,7 +77,7 @@ def fill_defaults(self): self.display_name = self.name -def _backward_compat_user(got_user: Any) -> User: +def _backward_compat_user(got_user: t.Any) -> User: """Backward-compatibility for LoginHandler.get_user Prior to 2.0, LoginHandler.get_user could return anything truthy. @@ -128,7 +123,7 @@ class IdentityProvider(LoggingConfigurable): .. versionadded:: 2.0 """ - cookie_name: str | Unicode = Unicode( + cookie_name: str | Unicode[str, str | bytes] = Unicode( "", config=True, help=_i18n("Name of the cookie to set for persisting login. Default: username-${Host}."), @@ -142,7 +137,7 @@ class IdentityProvider(LoggingConfigurable): ), ) - secure_cookie: bool | Bool = Bool( + secure_cookie: bool | Bool[bool | None, bool | int | None] = Bool( None, allow_none=True, config=True, @@ -160,7 +155,7 @@ class IdentityProvider(LoggingConfigurable): ), ) - token: str | Unicode = Unicode( + token: str | Unicode[str, str | bytes] = Unicode( "", help=_i18n( """Token used for authenticating first-time connections to the server. @@ -211,9 +206,9 @@ def _token_default(self): self.token_generated = True return binascii.hexlify(os.urandom(24)).decode("ascii") - need_token: bool | Bool = Bool(True) + need_token: bool | Bool[bool, t.Union[bool, int]] = Bool(True) - def get_user(self, handler: JupyterHandler) -> User | None | Awaitable[User | None]: + def get_user(self, handler: web.RequestHandler) -> User | None | t.Awaitable[User | None]: """Get the authenticated user for a request Must return a :class:`jupyter_server.auth.User`, @@ -228,17 +223,17 @@ def get_user(self, handler: JupyterHandler) -> User | None | Awaitable[User | No # not sure how to have optional-async type signature # on base class with `async def` without splitting it into two methods - async def _get_user(self, handler: JupyterHandler) -> User | None: + async def _get_user(self, handler: web.RequestHandler) -> User | None: """Get the user.""" if getattr(handler, "_jupyter_current_user", None): # already authenticated - return handler._jupyter_current_user - _token_user: User | None | Awaitable[User | None] = self.get_user_token(handler) - if isinstance(_token_user, Awaitable): + return t.cast(User, handler._jupyter_current_user) # type:ignore[attr-defined] + _token_user: User | None | t.Awaitable[User | None] = self.get_user_token(handler) + if isinstance(_token_user, t.Awaitable): _token_user = await _token_user token_user: User | None = _token_user # need second variable name to collapse type _cookie_user = self.get_user_cookie(handler) - if isinstance(_cookie_user, Awaitable): + if isinstance(_cookie_user, t.Awaitable): _cookie_user = await _cookie_user cookie_user: User | None = _cookie_user # prefer token to cookie if both given, @@ -273,12 +268,12 @@ async def _get_user(self, handler: JupyterHandler) -> User | None: return user - def identity_model(self, user: User) -> dict: + def identity_model(self, user: User) -> dict[str, t.Any]: """Return a User as an Identity model""" # TODO: validate? return asdict(user) - def get_handlers(self) -> list: + def get_handlers(self) -> list[tuple[str, object]]: """Return list of additional handlers for this identity provider For example, an OAuth callback handler. @@ -321,7 +316,7 @@ def user_from_cookie(self, cookie_value: str) -> User | None: user["color"], ) - def get_cookie_name(self, handler: AuthenticatedHandler) -> str: + def get_cookie_name(self, handler: web.RequestHandler) -> str: """Return the login cookie name Uses IdentityProvider.cookie_name, if defined. @@ -333,7 +328,7 @@ def get_cookie_name(self, handler: AuthenticatedHandler) -> str: else: return _non_alphanum.sub("-", f"username-{handler.request.host}") - def set_login_cookie(self, handler: AuthenticatedHandler, user: User) -> None: + def set_login_cookie(self, handler: web.RequestHandler, user: User) -> None: """Call this on handlers to set the login cookie for success""" cookie_options = {} cookie_options.update(self.cookie_options) @@ -345,12 +340,12 @@ def set_login_cookie(self, handler: AuthenticatedHandler, user: User) -> None: secure_cookie = handler.request.protocol == "https" if secure_cookie: cookie_options.setdefault("secure", True) - cookie_options.setdefault("path", handler.base_url) + cookie_options.setdefault("path", handler.base_url) # type:ignore[attr-defined] cookie_name = self.get_cookie_name(handler) handler.set_secure_cookie(cookie_name, self.user_to_cookie(user), **cookie_options) def _force_clear_cookie( - self, handler: AuthenticatedHandler, name: str, path: str = "/", domain: str | None = None + self, handler: web.RequestHandler, name: str, path: str = "/", domain: str | None = None ) -> None: """Deletes the cookie with the given name. @@ -368,7 +363,7 @@ def _force_clear_cookie( name = escape.native_str(name) expires = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=365) - morsel: Morsel = Morsel() + morsel: Morsel[t.Any] = Morsel() morsel.set(name, "", '""') morsel["expires"] = httputil.format_timestamp(expires) morsel["path"] = path @@ -376,11 +371,11 @@ def _force_clear_cookie( morsel["domain"] = domain handler.add_header("Set-Cookie", morsel.OutputString()) - def clear_login_cookie(self, handler: AuthenticatedHandler) -> None: + def clear_login_cookie(self, handler: web.RequestHandler) -> None: """Clear the login cookie, effectively logging out the session.""" cookie_options = {} cookie_options.update(self.cookie_options) - path = cookie_options.setdefault("path", handler.base_url) + path = cookie_options.setdefault("path", handler.base_url) # type:ignore[attr-defined] cookie_name = self.get_cookie_name(handler) handler.clear_cookie(cookie_name, path=path) if path and path != "/": @@ -390,7 +385,9 @@ def clear_login_cookie(self, handler: AuthenticatedHandler) -> None: # two cookies with the same name. See the method above. self._force_clear_cookie(handler, cookie_name) - def get_user_cookie(self, handler: JupyterHandler) -> User | None | Awaitable[User | None]: + def get_user_cookie( + self, handler: web.RequestHandler + ) -> User | None | t.Awaitable[User | None]: """Get user from a cookie Calls user_from_cookie to deserialize cookie value @@ -413,7 +410,7 @@ def get_user_cookie(self, handler: JupyterHandler) -> User | None | Awaitable[Us auth_header_pat = re.compile(r"(token|bearer)\s+(.+)", re.IGNORECASE) - def get_token(self, handler: JupyterHandler) -> str | None: + def get_token(self, handler: web.RequestHandler) -> str | None: """Get the user token from a request Default: @@ -429,14 +426,14 @@ def get_token(self, handler: JupyterHandler) -> str | None: user_token = m.group(2) return user_token - async def get_user_token(self, handler: JupyterHandler) -> User | None: + async def get_user_token(self, handler: web.RequestHandler) -> User | None: """Identify the user based on a token in the URL or Authorization header Returns: - uuid if authenticated - None if not """ - token = handler.token + token = t.cast("str | None", handler.token) # type:ignore[attr-defined] if not token: return None # check login token from URL argument or Authorization header @@ -455,7 +452,7 @@ async def get_user_token(self, handler: JupyterHandler) -> User | None: # which is stored in a cookie. # still check the cookie for the user id _user = self.get_user_cookie(handler) - if isinstance(_user, Awaitable): + if isinstance(_user, t.Awaitable): _user = await _user user: User | None = _user if user is None: @@ -464,7 +461,7 @@ async def get_user_token(self, handler: JupyterHandler) -> User | None: else: return None - def generate_anonymous_user(self, handler: JupyterHandler) -> User: + def generate_anonymous_user(self, handler: web.RequestHandler) -> User: """Generate a random anonymous user. For use when a single shared token is used, @@ -475,10 +472,10 @@ def generate_anonymous_user(self, handler: JupyterHandler) -> User: name = display_name = f"Anonymous {moon}" initials = f"A{moon[0]}" color = None - handler.log.debug(f"Generating new user for token-authenticated request: {user_id}") + handler.log.debug(f"Generating new user for token-authenticated request: {user_id}") # type:ignore[attr-defined] return User(user_id, name, display_name, initials, None, color) - def should_check_origin(self, handler: AuthenticatedHandler) -> bool: + def should_check_origin(self, handler: web.RequestHandler) -> bool: """Should the Handler check for CORS origin validation? Origin check should be skipped for token-authenticated requests. @@ -489,7 +486,7 @@ def should_check_origin(self, handler: AuthenticatedHandler) -> bool: """ return not self.is_token_authenticated(handler) - def is_token_authenticated(self, handler: AuthenticatedHandler) -> bool: + def is_token_authenticated(self, handler: web.RequestHandler) -> bool: """Returns True if handler has been token authenticated. Otherwise, False. Login with a token is used to signal certain things, such as: @@ -504,8 +501,8 @@ def is_token_authenticated(self, handler: AuthenticatedHandler) -> bool: def validate_security( self, - app: ServerApp, - ssl_options: dict | None = None, + app: t.Any, + ssl_options: dict[str, t.Any] | None = None, ) -> None: """Check the application's security. @@ -526,7 +523,7 @@ def validate_security( " Anyone who can connect to this server will be able to run code." ) - def process_login_form(self, handler: JupyterHandler) -> User | None: + def process_login_form(self, handler: web.RequestHandler) -> User | None: """Process login form data Return authenticated User if successful, None if not. @@ -538,7 +535,7 @@ def process_login_form(self, handler: JupyterHandler) -> User | None: return self.generate_anonymous_user(handler) if self.token and self.token == typed_password: - return self.user_for_token(typed_password) # type:ignore[attr-defined] + return t.cast(User, self.user_for_token(typed_password)) # type:ignore[attr-defined] return user @@ -633,7 +630,7 @@ def passwd_check(self, password): """Check password against our stored hashed password""" return passwd_check(self.hashed_password, password) - def process_login_form(self, handler: JupyterHandler) -> User | None: + def process_login_form(self, handler: web.RequestHandler) -> User | None: """Process login form data Return authenticated User if successful, None if not. @@ -659,8 +656,8 @@ def process_login_form(self, handler: JupyterHandler) -> User | None: def validate_security( self, - app: ServerApp, - ssl_options: dict | None = None, + app: t.Any, + ssl_options: dict[str, t.Any] | None = None, ) -> None: """Handle security validation.""" super().validate_security(app, ssl_options) @@ -700,7 +697,7 @@ def _default_login_handler_class(self): def auth_enabled(self): return self.login_available - def get_user(self, handler: JupyterHandler) -> User | None: + def get_user(self, handler: web.RequestHandler) -> User | None: """Get the user.""" user = self.login_handler_class.get_user(handler) # type:ignore[attr-defined] if user is None: @@ -708,23 +705,25 @@ def get_user(self, handler: JupyterHandler) -> User | None: return _backward_compat_user(user) @property - def login_available(self): - return self.login_handler_class.get_login_available( # type:ignore[attr-defined] - self.settings + def login_available(self) -> bool: + return bool( + self.login_handler_class.get_login_available( # type:ignore[attr-defined] + self.settings + ) ) - def should_check_origin(self, handler: AuthenticatedHandler) -> bool: + def should_check_origin(self, handler: web.RequestHandler) -> bool: """Whether we should check origin.""" - return self.login_handler_class.should_check_origin(handler) # type:ignore[attr-defined] + return bool(self.login_handler_class.should_check_origin(handler)) # type:ignore[attr-defined] - def is_token_authenticated(self, handler: AuthenticatedHandler) -> bool: + def is_token_authenticated(self, handler: web.RequestHandler) -> bool: """Whether we are token authenticated.""" - return self.login_handler_class.is_token_authenticated(handler) # type:ignore[attr-defined] + return bool(self.login_handler_class.is_token_authenticated(handler)) # type:ignore[attr-defined] def validate_security( self, - app: ServerApp, - ssl_options: dict | None = None, + app: t.Any, + ssl_options: dict[str, t.Any] | None = None, ) -> None: """Validate security.""" if self.password_required and (not self.hashed_password): @@ -734,6 +733,6 @@ def validate_security( self.log.critical(_i18n("Hint: run the following command to set a password")) self.log.critical(_i18n("\t$ python -m jupyter_server.auth password")) sys.exit(1) - return self.login_handler_class.validate_security( # type:ignore[attr-defined] + self.login_handler_class.validate_security( # type:ignore[attr-defined] app, ssl_options ) diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py index c4d080cf18..a45443619d 100644 --- a/jupyter_server/base/handlers.py +++ b/jupyter_server/base/handlers.py @@ -73,7 +73,7 @@ def json_sys_info(): def log() -> Logger: """Get the application log.""" if Application.initialized(): - return Application.instance().log + return cast(Logger, Application.instance().log) else: return app_log @@ -83,7 +83,7 @@ class AuthenticatedHandler(web.RequestHandler): @property def base_url(self) -> str: - return self.settings.get("base_url", "/") + return cast(str, self.settings.get("base_url", "/")) @property def content_security_policy(self) -> str: @@ -93,7 +93,7 @@ def content_security_policy(self) -> str: """ if "Content-Security-Policy" in self.settings.get("headers", {}): # user-specified, don't override - return self.settings["headers"]["Content-Security-Policy"] + return cast(str, self.settings["headers"]["Content-Security-Policy"]) return "; ".join( [ @@ -171,7 +171,7 @@ def get_current_user(self) -> str: DeprecationWarning, stacklevel=2, ) - return self._jupyter_current_user + return cast(str, self._jupyter_current_user) # haven't called get_user in prepare, raise raise RuntimeError(msg) @@ -195,7 +195,7 @@ def token_authenticated(self) -> bool: def logged_in(self) -> bool: """Is a user currently logged in?""" user = self.current_user - return user and user != "anonymous" + return bool(user and user != "anonymous") @property def login_handler(self) -> Any: @@ -222,7 +222,7 @@ def login_available(self) -> bool: whether the user is already logged in or not. """ - return self.identity_provider.login_available + return cast(bool, self.identity_provider.login_available) @property def authorizer(self) -> Authorizer: @@ -266,7 +266,7 @@ def identity_provider(self) -> IdentityProvider: self.settings["identity_provider"] = IdentityProvider( config=self.settings.get("config", None) ) - return self.settings["identity_provider"] + return cast("IdentityProvider", self.settings["identity_provider"]) class JupyterHandler(AuthenticatedHandler): @@ -277,7 +277,7 @@ class JupyterHandler(AuthenticatedHandler): @property def config(self) -> dict[str, Any] | None: - return self.settings.get("config", None) + return cast("dict[str, Any] | None", self.settings.get("config", None)) @property def log(self) -> Logger: @@ -287,11 +287,11 @@ def log(self) -> Logger: @property def jinja_template_vars(self) -> dict[str, Any]: """User-supplied values to supply to jinja templates.""" - return self.settings.get("jinja_template_vars", {}) + return cast("dict[str, Any]", self.settings.get("jinja_template_vars", {})) @property def serverapp(self) -> ServerApp | None: - return self.settings["serverapp"] + return cast("ServerApp | None", self.settings["serverapp"]) # --------------------------------------------------------------- # URLs @@ -300,26 +300,26 @@ def serverapp(self) -> ServerApp | None: @property def version_hash(self) -> str: """The version hash to use for cache hints for static files""" - return self.settings.get("version_hash", "") + return cast(str, self.settings.get("version_hash", "")) @property def mathjax_url(self) -> str: - url = self.settings.get("mathjax_url", "") + url = cast(str, self.settings.get("mathjax_url", "")) if not url or url_is_absolute(url): return url return url_path_join(self.base_url, url) @property def mathjax_config(self) -> str: - return self.settings.get("mathjax_config", "TeX-AMS-MML_HTMLorMML-full,Safe") + return cast(str, self.settings.get("mathjax_config", "TeX-AMS-MML_HTMLorMML-full,Safe")) @property def default_url(self) -> str: - return self.settings.get("default_url", "") + return cast(str, self.settings.get("default_url", "")) @property def ws_url(self) -> str: - return self.settings.get("websocket_url", "") + return cast(str, self.settings.get("websocket_url", "")) @property def contents_js_source(self) -> str: @@ -327,7 +327,7 @@ def contents_js_source(self) -> str: "Using contents: %s", self.settings.get("contents_js_source", "services/contents"), ) - return self.settings.get("contents_js_source", "services/contents") + return cast(str, self.settings.get("contents_js_source", "services/contents")) # --------------------------------------------------------------- # Manager objects @@ -335,31 +335,31 @@ def contents_js_source(self) -> str: @property def kernel_manager(self) -> AsyncMappingKernelManager: - return self.settings["kernel_manager"] + return cast("AsyncMappingKernelManager", self.settings["kernel_manager"]) @property def contents_manager(self) -> ContentsManager: - return self.settings["contents_manager"] + return cast("ContentsManager", self.settings["contents_manager"]) @property def session_manager(self) -> SessionManager: - return self.settings["session_manager"] + return cast("SessionManager", self.settings["session_manager"]) @property def terminal_manager(self) -> TerminalManager: - return self.settings["terminal_manager"] + return cast("TerminalManager", self.settings["terminal_manager"]) @property def kernel_spec_manager(self) -> KernelSpecManager: - return self.settings["kernel_spec_manager"] + return cast("KernelSpecManager", self.settings["kernel_spec_manager"]) @property def config_manager(self) -> ConfigManager: - return self.settings["config_manager"] + return cast("ConfigManager", self.settings["config_manager"]) @property def event_logger(self) -> EventLogger: - return self.settings["event_logger"] + return cast("EventLogger", self.settings["event_logger"]) # --------------------------------------------------------------- # CORS @@ -368,17 +368,17 @@ def event_logger(self) -> EventLogger: @property def allow_origin(self) -> str: """Normal Access-Control-Allow-Origin""" - return self.settings.get("allow_origin", "") + return cast(str, self.settings.get("allow_origin", "")) @property - def allow_origin_pat(self) -> str: + def allow_origin_pat(self) -> str | None: """Regular expression version of allow_origin""" - return self.settings.get("allow_origin_pat", None) + return cast("str | None", self.settings.get("allow_origin_pat", None)) @property def allow_credentials(self) -> bool: """Whether to set Access-Control-Allow-Credentials""" - return self.settings.get("allow_credentials", False) + return cast(bool, self.settings.get("allow_credentials", False)) def set_default_headers(self) -> None: """Add CORS headers, if defined""" @@ -462,7 +462,7 @@ def check_origin(self, origin_to_satisfy_tornado: str = "") -> bool: # Check CORS headers if self.allow_origin: - allow = self.allow_origin == origin + allow = bool(self.allow_origin == origin) elif self.allow_origin_pat: allow = bool(re.match(self.allow_origin_pat, origin)) else: @@ -682,7 +682,7 @@ def get_json_body(self) -> dict[str, Any] | None: self.log.debug("Bad JSON: %r", body) self.log.error("Couldn't parse JSON", exc_info=True) raise web.HTTPError(400, "Invalid JSON in body of request") from e - return model + return cast("dict[str, Any]", model) def write_error(self, status_code: int, **kwargs: Any) -> None: """render custom error pages""" @@ -736,7 +736,7 @@ def write_error(self, status_code: int, **kwargs: Any) -> None: """APIHandler errors are JSON, not human pages""" self.set_header("Content-Type", "application/json") message = responses.get(status_code, "Unknown HTTP Error") - reply: dict = { + reply: dict[str, Any] = { "message": message, } exc_info = kwargs.get("exc_info") @@ -962,7 +962,7 @@ class FileFindHandler(JupyterHandler, web.StaticFileHandler): """ # cache search results, don't search for files more than once - _static_paths: dict[str, Any] = {} + _static_paths: dict[str, str] = {} root: tuple[str] # type:ignore[assignment] def set_headers(self) -> None: @@ -1102,8 +1102,8 @@ async def redirect_to_files(self: Any, path: str) -> None: self.log.debug("Redirecting %s to %s", self.request.path, url) self.redirect(url) - def get(self, path: str = "") -> Awaitable: - return self.redirect_to_files(self, path) + async def get(self, path: str = "") -> None: + return await self.redirect_to_files(self, path) class RedirectWithParams(web.RequestHandler): diff --git a/jupyter_server/config_manager.py b/jupyter_server/config_manager.py index 76268d8a23..87480d7609 100644 --- a/jupyter_server/config_manager.py +++ b/jupyter_server/config_manager.py @@ -76,7 +76,7 @@ def directory(self, section_name: str) -> str: """Returns the directory name for the section name: {config_dir}/{section_name}.d""" return os.path.join(self.config_dir, section_name + ".d") - def get(self, section_name: str, include_root: bool = True) -> t.Any: + def get(self, section_name: str, include_root: bool = True) -> dict[str, t.Any]: """Retrieve the config data for the specified section. Returns the data as a dictionary, or an empty dictionary if the file @@ -99,7 +99,7 @@ def get(self, section_name: str, include_root: bool = True) -> t.Any: section_name, "\n\t".join(paths), ) - data: dict = {} + data: dict[str, t.Any] = {} for path in paths: if os.path.isfile(path): with open(path, encoding="utf-8") as f: @@ -123,7 +123,7 @@ def set(self, section_name: str, data: t.Any) -> None: with open(filename, "w", encoding="utf-8") as f: f.write(json_content) - def update(self, section_name: str, new_data: t.Any) -> None: + def update(self, section_name: str, new_data: t.Any) -> dict[str, t.Any]: """Modify the config section by recursively updating it with new_data. Returns the modified config data as a dictionary. diff --git a/jupyter_server/extension/application.py b/jupyter_server/extension/application.py index f0e47f9dd7..0bd4e8b018 100644 --- a/jupyter_server/extension/application.py +++ b/jupyter_server/extension/application.py @@ -144,7 +144,7 @@ class method. This method can be set as a entry_point in # A useful class property that subclasses can override to # configure the underlying Jupyter Server when this extension # is launched directly (using its `launch_instance` method). - serverapp_config: dict = {} + serverapp_config: dict[str, t.Any] = {} # Some subclasses will likely override this trait to flip # the default value to False if they don't offer a browser @@ -174,7 +174,7 @@ def config_file_paths(self): # file, jupyter_{name}_config. # This should also match the jupyter subcommand used to launch # this extension from the CLI, e.g. `jupyter {name}`. - name: str | Unicode = "ExtensionApp" # type:ignore[assignment] + name: str | Unicode[str, str] = "ExtensionApp" # type:ignore[assignment] @classmethod def get_extension_package(cls): @@ -336,7 +336,7 @@ def _prepare_handlers(self): handler = handler_items[1] # Get handler kwargs, if given - kwargs: dict = {} + kwargs: dict[str, t.Any] = {} if issubclass(handler, ExtensionHandlerMixin): kwargs["name"] = self.name diff --git a/jupyter_server/extension/handler.py b/jupyter_server/extension/handler.py index 3018aae1c2..55f5aff2c3 100644 --- a/jupyter_server/extension/handler.py +++ b/jupyter_server/extension/handler.py @@ -1,15 +1,14 @@ """An extension handler.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any +from logging import Logger +from typing import TYPE_CHECKING, Any, cast from jinja2.exceptions import TemplateNotFound from jupyter_server.base.handlers import FileFindHandler if TYPE_CHECKING: - from logging import Logger - from traitlets.config import Config from jupyter_server.extension.application import ExtensionApp @@ -25,9 +24,9 @@ def get_template(self, name: str) -> str: """Return the jinja template object for a given name""" try: env = f"{self.name}_jinja2_env" # type:ignore[attr-defined] - return self.settings[env].get_template(name) # type:ignore[attr-defined] + return cast(str, self.settings[env].get_template(name)) # type:ignore[attr-defined] except TemplateNotFound: - return super().get_template(name) # type:ignore[misc] + return cast(str, super().get_template(name)) # type:ignore[misc] class ExtensionHandlerMixin: @@ -41,6 +40,8 @@ class ExtensionHandlerMixin: other extensions. """ + settings: dict[str, Any] + def initialize(self, name: str, *args: Any, **kwargs: Any) -> None: self.name = name try: @@ -50,34 +51,34 @@ def initialize(self, name: str, *args: Any, **kwargs: Any) -> None: @property def extensionapp(self) -> ExtensionApp: - return self.settings[self.name] # type:ignore[attr-defined] + return cast("ExtensionApp", self.settings[self.name]) @property def serverapp(self) -> ServerApp: key = "serverapp" - return self.settings[key] # type:ignore[attr-defined] + return cast("ServerApp", self.settings[key]) @property def log(self) -> Logger: if not hasattr(self, "name"): - return super().log # type:ignore[misc] + return cast(Logger, super().log) # type:ignore[misc] # Attempt to pull the ExtensionApp's log, otherwise fall back to ServerApp. try: - return self.extensionapp.log + return cast(Logger, self.extensionapp.log) except AttributeError: - return self.serverapp.log + return cast(Logger, self.serverapp.log) @property def config(self) -> Config: - return self.settings[f"{self.name}_config"] # type:ignore[attr-defined] + return cast("Config", self.settings[f"{self.name}_config"]) @property def server_config(self) -> Config: - return self.settings["config"] # type:ignore[attr-defined] + return cast("Config", self.settings["config"]) @property def base_url(self) -> str: - return self.settings.get("base_url", "/") # type:ignore[attr-defined] + return cast(str, self.settings.get("base_url", "/")) @property def static_url_prefix(self) -> str: @@ -85,7 +86,7 @@ def static_url_prefix(self) -> str: @property def static_path(self) -> str: - return self.settings[f"{self.name}_static_paths"] # type:ignore[attr-defined] + return cast(str, self.settings[f"{self.name}_static_paths"]) def static_url(self, path: str, include_host: bool | None = None, **kwargs: Any) -> str: """Returns a static URL for the given relative static file path. @@ -108,7 +109,7 @@ def static_url(self, path: str, include_host: bool | None = None, **kwargs: Any) try: self.require_setting(key, "static_url") # type:ignore[attr-defined] except Exception as e: - if key in self.settings: # type:ignore[attr-defined] + if key in self.settings: msg = ( "This extension doesn't have any static paths listed. Check that the " "extension's `static_paths` trait is set." @@ -117,17 +118,14 @@ def static_url(self, path: str, include_host: bool | None = None, **kwargs: Any) else: raise e - get_url = self.settings.get( # type:ignore[attr-defined] - "static_handler_class", FileFindHandler - ).make_static_url + get_url = self.settings.get("static_handler_class", FileFindHandler).make_static_url if include_host is None: include_host = getattr(self, "include_host", False) - if include_host: # noqa + base = "" + if include_host: base = self.request.protocol + "://" + self.request.host # type:ignore[attr-defined] - else: - base = "" # Hijack settings dict to send extension templates to extension # static directory. @@ -136,4 +134,4 @@ def static_url(self, path: str, include_host: bool | None = None, **kwargs: Any) "static_url_prefix": self.static_url_prefix, } - return base + get_url(settings, path, **kwargs) + return base + cast(str, get_url(settings, path, **kwargs)) diff --git a/jupyter_server/extension/serverextension.py b/jupyter_server/extension/serverextension.py index 2d4359bd06..19f3a30709 100644 --- a/jupyter_server/extension/serverextension.py +++ b/jupyter_server/extension/serverextension.py @@ -381,7 +381,7 @@ class ServerExtensionApp(BaseExtensionApp): description: str = "Work with Jupyter server extensions" examples = _examples - subcommands: dict = { + subcommands: dict[str, t.Any] = { "enable": (EnableServerExtensionApp, "Enable a server extension"), "disable": (DisableServerExtensionApp, "Disable a server extension"), "list": (ListServerExtensionsApp, "List server extensions"), diff --git a/jupyter_server/gateway/connections.py b/jupyter_server/gateway/connections.py index 401fe86a21..9926644859 100644 --- a/jupyter_server/gateway/connections.py +++ b/jupyter_server/gateway/connections.py @@ -1,6 +1,7 @@ """Gateway connection classes.""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations import asyncio import logging @@ -41,11 +42,11 @@ async def connect(self): "channels", ) self.log.info(f"Connecting to {ws_url}") - kwargs: dict = {} + kwargs: dict[str, Any] = {} kwargs = GatewayClient.instance().load_connection_args(**kwargs) request = HTTPRequest(ws_url, **kwargs) - self.ws_future = cast(Future, tornado_websocket.websocket_connect(request)) + self.ws_future = cast("Future[Any]", tornado_websocket.websocket_connect(request)) self.ws_future.add_done_callback(self._connection_done) loop = IOLoop.current() diff --git a/jupyter_server/gateway/gateway_client.py b/jupyter_server/gateway/gateway_client.py index 395906177c..fb0562032b 100644 --- a/jupyter_server/gateway/gateway_client.py +++ b/jupyter_server/gateway/gateway_client.py @@ -1,6 +1,8 @@ """A kernel gateway client.""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import asyncio import json import logging @@ -538,7 +540,7 @@ def gateway_enabled(self): # Ensure KERNEL_LAUNCH_TIMEOUT has a default value. KERNEL_LAUNCH_TIMEOUT = int(os.environ.get("KERNEL_LAUNCH_TIMEOUT", 40)) - _connection_args: dict # initialized on first use + _connection_args: dict[str, ty.Any] # initialized on first use gateway_token_renewer: GatewayTokenRenewerBase @@ -549,7 +551,7 @@ def __init__(self, **kwargs): self.gateway_token_renewer = self.gateway_token_renewer_class(parent=self, log=self.log) # type:ignore[abstract] # store of cookies with store time - self._cookies: ty.Dict[str, ty.Tuple[Morsel, datetime]] = {} + self._cookies: dict[str, tuple[Morsel[ty.Any], datetime]] = {} def init_connection_args(self): """Initialize arguments used on every request. Since these are primarily static values, @@ -630,7 +632,7 @@ def load_connection_args(self, **kwargs): return kwargs - def update_cookies(self, cookie: SimpleCookie) -> None: + def update_cookies(self, cookie: SimpleCookie[ty.Any]) -> None: """Update cookies from existing requests for load balancers""" if not self.accept_cookies: return @@ -661,7 +663,7 @@ def _clear_expired_cookies(self) -> None: for key in expired_keys: self._cookies.pop(key) - def _update_cookie_header(self, connection_args: dict) -> None: + def _update_cookie_header(self, connection_args: dict[str, ty.Any]) -> None: """Update a cookie header.""" self._clear_expired_cookies() @@ -698,9 +700,9 @@ class RetryableHTTPClient: MAX_RETRIES_CAP = 10 # The upper limit to max_retries value. max_retries: int = int(os.getenv("JUPYTER_GATEWAY_MAX_REQUEST_RETRIES", MAX_RETRIES_DEFAULT)) max_retries = max(0, min(max_retries, MAX_RETRIES_CAP)) # Enforce boundaries - retried_methods: ty.Set[str] = {"GET", "DELETE"} - retried_errors: ty.Set[int] = {502, 503, 504, 599} - retried_exceptions: ty.Set[type] = {ConnectionError} + retried_methods: set[str] = {"GET", "DELETE"} + retried_errors: set[int] = {502, 503, 504, 599} + retried_exceptions: set[type] = {ConnectionError} backoff_factor: float = 0.1 def __init__(self): @@ -820,7 +822,7 @@ async def gateway_request(endpoint: str, **kwargs: ty.Any) -> HTTPResponse: # Update cookies on GatewayClient from server if configured. cookie_values = response.headers.get("Set-Cookie") if cookie_values: - cookie: SimpleCookie = SimpleCookie() + cookie: SimpleCookie[ty.Any] = SimpleCookie() cookie.load(cookie_values) GatewayClient.instance().update_cookies(cookie) return response diff --git a/jupyter_server/gateway/handlers.py b/jupyter_server/gateway/handlers.py index 952253ad8e..2dbbf3edfc 100644 --- a/jupyter_server/gateway/handlers.py +++ b/jupyter_server/gateway/handlers.py @@ -1,13 +1,15 @@ """Gateway API handlers.""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import asyncio import logging import mimetypes import os import random import warnings -from typing import Optional, cast +from typing import Any, Optional, cast from jupyter_client.session import Session from tornado import web @@ -159,7 +161,7 @@ def __init__(self, **kwargs): super().__init__() self.kernel_id = None self.ws = None - self.ws_future: Future = Future() + self.ws_future: Future[Any] = Future() self.disconnected = False self.retry = 0 @@ -178,11 +180,11 @@ async def _connect(self, kernel_id, message_callback): "channels", ) self.log.info(f"Connecting to {ws_url}") - kwargs: dict = {} + kwargs: dict[str, Any] = {} kwargs = client.load_connection_args(**kwargs) request = HTTPRequest(ws_url, **kwargs) - self.ws_future = cast(Future, websocket_connect(request)) + self.ws_future = cast("Future[Any]", websocket_connect(request)) self.ws_future.add_done_callback(self._connection_done) loop = IOLoop.current() diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py index c77e0edb14..21aaefb86d 100644 --- a/jupyter_server/gateway/managers.py +++ b/jupyter_server/gateway/managers.py @@ -1,6 +1,8 @@ """Kernel gateway managers.""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import asyncio import datetime import json @@ -9,7 +11,7 @@ from queue import Empty, Queue from threading import Thread from time import monotonic -from typing import Any, Dict, Optional +from typing import Any, Optional, cast import websocket from jupyter_client.asynchronous.client import AsyncKernelClient @@ -36,7 +38,7 @@ class GatewayMappingKernelManager(AsyncMappingKernelManager): """Kernel manager that supports remote kernels hosted by Jupyter Kernel or Enterprise Gateway.""" # We'll maintain our own set of kernel ids - _kernels: Dict[str, "GatewayKernelManager"] = {} # type:ignore[assignment] + _kernels: dict[str, GatewayKernelManager] = {} # type:ignore[assignment] @default("kernel_manager_class") def _default_kernel_manager_class(self): @@ -408,7 +410,7 @@ def has_kernel(self): def client(self, **kwargs): """Create a client configured to connect to our kernel""" - kw: dict = {} + kw: dict[str, Any] = {} kw.update(self.get_connection_info(session=True)) kw.update( { @@ -589,7 +591,7 @@ def cleanup_resources(self, restart=False): KernelManagerABC.register(GatewayKernelManager) -class ChannelQueue(Queue): +class ChannelQueue(Queue): # type:ignore[type-arg] """A queue for a named channel.""" channel_name: Optional[str] = None @@ -623,7 +625,7 @@ async def _async_get(self, timeout=None): raise await asyncio.sleep(0) - async def get_msg(self, *args: Any, **kwargs: Any) -> dict: + async def get_msg(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """Get a message from the queue.""" timeout = kwargs.get("timeout", 1) msg = await self._async_get(timeout=timeout) @@ -633,9 +635,9 @@ async def get_msg(self, *args: Any, **kwargs: Any) -> dict: ) ) self.task_done() - return msg + return cast("dict[str, Any]", msg) - def send(self, msg: dict) -> None: + def send(self, msg: dict[str, Any]) -> None: """Send a message to the queue.""" message = json.dumps(msg, default=ChannelQueue.serialize_datetime).replace(" str: @validate("ip") def _validate_ip(self, proposal: t.Any) -> str: - value = proposal["value"] + value = t.cast(str, proposal["value"]) if value == "*": value = "" return value @@ -987,7 +989,7 @@ def _port_retries_default(self) -> int: ) @validate("sock_mode") - def _validate_sock_mode(self, proposal: t.Any) -> int: + def _validate_sock_mode(self, proposal: t.Any) -> t.Any: value = proposal["value"] try: converted_value = int(value.encode(), 8) @@ -1403,7 +1405,7 @@ def _deprecated_cookie_config(self, change: t.Any) -> None: @validate("base_url") def _update_base_url(self, proposal: t.Any) -> str: - value = proposal["value"] + value = t.cast(str, proposal["value"]) if not value.startswith("/"): value = "/" + value if not value.endswith("/"): @@ -1420,14 +1422,14 @@ def _update_base_url(self, proposal: t.Any) -> str: ) @property - def static_file_path(self) -> t.List[str]: + def static_file_path(self) -> list[str]: """return extra paths + the default location""" return [*self.extra_static_paths, DEFAULT_STATIC_FILES_PATH] static_custom_path = List(Unicode(), help=_i18n("""Path to search for custom.js, css""")) @default("static_custom_path") - def _default_static_custom_path(self) -> t.List[str]: + def _default_static_custom_path(self) -> list[str]: return [os.path.join(d, "custom") for d in (self.config_dir, DEFAULT_STATIC_FILES_PATH)] extra_template_paths = List( @@ -1441,7 +1443,7 @@ def _default_static_custom_path(self) -> t.List[str]: ) @property - def template_file_path(self) -> t.List[str]: + def template_file_path(self) -> list[str]: """return extra paths + the default locations""" return self.extra_template_paths + DEFAULT_TEMPLATE_PATH_LIST @@ -1483,7 +1485,7 @@ def template_file_path(self) -> t.List[str]: ) @default("kernel_manager_class") - def _default_kernel_manager_class(self) -> t.Union[str, t.Type[AsyncMappingKernelManager]]: + def _default_kernel_manager_class(self) -> t.Union[str, type[AsyncMappingKernelManager]]: if self.gateway_config.gateway_enabled: return "jupyter_server.gateway.managers.GatewayMappingKernelManager" return AsyncMappingKernelManager @@ -1494,7 +1496,7 @@ def _default_kernel_manager_class(self) -> t.Union[str, t.Type[AsyncMappingKerne ) @default("session_manager_class") - def _default_session_manager_class(self) -> t.Union[str, t.Type[SessionManager]]: + def _default_session_manager_class(self) -> t.Union[str, type[SessionManager]]: if self.gateway_config.gateway_enabled: return "jupyter_server.gateway.managers.GatewaySessionManager" return SessionManager @@ -1508,7 +1510,7 @@ def _default_session_manager_class(self) -> t.Union[str, t.Type[SessionManager]] @default("kernel_websocket_connection_class") def _default_kernel_websocket_connection_class( self, - ) -> t.Union[str, t.Type[ZMQChannelsWebsocketConnection]]: + ) -> t.Union[str, type[ZMQChannelsWebsocketConnection]]: if self.gateway_config.gateway_enabled: return "jupyter_server.gateway.connections.GatewayWebSocketConnection" return ZMQChannelsWebsocketConnection @@ -1533,7 +1535,7 @@ def _default_kernel_websocket_connection_class( ) @default("kernel_spec_manager_class") - def _default_kernel_spec_manager_class(self) -> t.Union[str, t.Type[KernelSpecManager]]: + def _default_kernel_spec_manager_class(self) -> t.Union[str, type[KernelSpecManager]]: if self.gateway_config.gateway_enabled: return "jupyter_server.gateway.managers.GatewayKernelSpecManager" return KernelSpecManager @@ -1856,7 +1858,7 @@ def starter_app(self) -> t.Any: """Get the Extension that started this server.""" return self._starter_app - def parse_command_line(self, argv: t.Optional[t.List[str]] = None) -> None: + def parse_command_line(self, argv: t.Optional[list[str]] = None) -> None: """Parse the command line options.""" super().parse_command_line(argv) @@ -2148,7 +2150,9 @@ def init_resources(self) -> None: ) resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard)) - def _get_urlparts(self, path: t.Optional[str] = None, include_token: bool = False) -> t.Any: + def _get_urlparts( + self, path: t.Optional[str] = None, include_token: bool = False + ) -> urllib.parse.ParseResult: """Constructs a urllib named tuple, ParseResult, with default values set by server config. The returned tuple can be manipulated using the `_replace` method. @@ -2545,7 +2549,7 @@ def _init_asyncio_patch() -> None: @catch_config_error def initialize( self, - argv: t.Optional[t.List[str]] = None, + argv: t.Optional[list[str]] = None, find_extensions: bool = True, new_httpserver: bool = True, starter_extension: t.Any = None, @@ -2636,7 +2640,7 @@ async def cleanup_extensions(self) -> None: def running_server_info(self, kernel_count: bool = True) -> str: """Return the current working directory and the server url information""" - info = self.contents_manager.info_string() + "\n" + info = t.cast(str, self.contents_manager.info_string()) + "\n" if kernel_count: n_kernels = len(self.kernel_manager.list_kernel_ids()) kernel_msg = trans.ngettext("%d active kernel", "%d active kernels", n_kernels) @@ -2651,7 +2655,7 @@ def running_server_info(self, kernel_count: bool = True) -> str: ) return info - def server_info(self) -> t.Dict[str, t.Any]: + def server_info(self) -> dict[str, t.Any]: """Return a JSONable dict of information about this server.""" return { "url": self.connection_url, @@ -2784,7 +2788,7 @@ def remove_browser_open_file(self) -> None: if e.errno != errno.ENOENT: raise - def _prepare_browser_open(self) -> t.Tuple[str, t.Optional[str]]: + def _prepare_browser_open(self) -> tuple[str, t.Optional[str]]: """Prepare to open the browser.""" if not self.use_redirect_file: uri = self.default_url[len(self.base_url) :] diff --git a/jupyter_server/services/api/handlers.py b/jupyter_server/services/api/handlers.py index 9583732289..efb361186c 100644 --- a/jupyter_server/services/api/handlers.py +++ b/jupyter_server/services/api/handlers.py @@ -3,7 +3,7 @@ # Distributed under the terms of the Modified BSD License. import json import os -from typing import Dict, List +from typing import Any, Dict, List from jupyter_core.utils import ensure_async from tornado import web @@ -97,7 +97,7 @@ def get(self): if self.authorizer.is_authorized(self, user=user, resource=resource, action=action): allowed.append(action) - identity: Dict = self.identity_provider.identity_model(user) + identity: Dict[str, Any] = self.identity_provider.identity_model(user) model = { "identity": identity, "permissions": permissions, diff --git a/jupyter_server/services/config/manager.py b/jupyter_server/services/config/manager.py index bc42deb645..720c8e7bd7 100644 --- a/jupyter_server/services/config/manager.py +++ b/jupyter_server/services/config/manager.py @@ -3,6 +3,7 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. import os.path +import typing as t from jupyter_core.paths import jupyter_config_dir, jupyter_config_path from traitlets import Instance, List, Unicode, default, observe @@ -22,7 +23,7 @@ class ConfigManager(LoggingConfigurable): def get(self, section_name): """Get the config from all config sections.""" - config: dict = {} + config: t.Dict[str, t.Any] = {} # step through back to front, to ensure front of the list is top priority for p in self.read_config_path[::-1]: cm = BaseJSONConfigManager(config_dir=p) diff --git a/jupyter_server/services/contents/filemanager.py b/jupyter_server/services/contents/filemanager.py index fe12fb1b7a..64b5fc122a 100644 --- a/jupyter_server/services/contents/filemanager.py +++ b/jupyter_server/services/contents/filemanager.py @@ -1,6 +1,8 @@ """A contents manager that uses the local file system for storage.""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import errno import math import mimetypes @@ -10,6 +12,7 @@ import stat import subprocess import sys +import typing as t import warnings from datetime import datetime from pathlib import Path @@ -375,7 +378,7 @@ def _notebook_model(self, path, content=True): os_path = self._get_os_path(path) if content: - validation_error: dict = {} + validation_error: dict[str, t.Any] = {} nb = self._read_notebook( os_path, as_version=4, capture_validation_error=validation_error ) @@ -465,7 +468,7 @@ def save(self, model, path=""): self.log.debug("Saving %s", os_path) - validation_error: dict = {} + validation_error: dict[str, t.Any] = {} try: if model["type"] == "notebook": nb = nbformat.from_dict(model["content"]) @@ -824,7 +827,7 @@ async def _notebook_model(self, path, content=True): os_path = self._get_os_path(path) if content: - validation_error: dict = {} + validation_error: dict[str, t.Any] = {} nb = await self._read_notebook( os_path, as_version=4, capture_validation_error=validation_error ) @@ -906,7 +909,7 @@ async def save(self, model, path=""): os_path = self._get_os_path(path) self.log.debug("Saving %s", os_path) - validation_error: dict = {} + validation_error: dict[str, t.Any] = {} try: if model["type"] == "notebook": nb = nbformat.from_dict(model["content"]) @@ -1094,7 +1097,7 @@ async def copy(self, from_path, to_path=None): async def _copy_dir( self, from_path: str, to_path_original: str, to_name: str, to_path: str - ) -> dict: + ) -> dict[str, t.Any]: """ handles copying directories returns the model for the copied directory @@ -1111,7 +1114,7 @@ async def _copy_dir( f"Can't copy '{from_path}' into read-only Folder '{to_path}'", ) from err - return model + return model # type:ignore[no-any-return] async def check_folder_size(self, path: str) -> None: """ diff --git a/jupyter_server/services/events/handlers.py b/jupyter_server/services/events/handlers.py index 0ba11c0985..5c52e75ad8 100644 --- a/jupyter_server/services/events/handlers.py +++ b/jupyter_server/services/events/handlers.py @@ -2,6 +2,8 @@ .. versionadded:: 2.0 """ +from __future__ import annotations + import json from datetime import datetime from typing import Any, Dict, Optional, cast @@ -48,7 +50,7 @@ async def get(self, *args, **kwargs): await res async def event_listener( - self, logger: jupyter_events.logger.EventLogger, schema_id: str, data: dict + self, logger: jupyter_events.logger.EventLogger, schema_id: str, data: dict[str, Any] ) -> None: """Write an event message.""" capsule = dict(schema_id=schema_id, **data) @@ -65,7 +67,7 @@ def on_close(self): self.event_logger.remove_listener(listener=self.event_listener) -def validate_model(data: Dict[str, Any]) -> None: +def validate_model(data: dict[str, Any]) -> None: """Validates for required fields in the JSON request body""" required_keys = {"schema_id", "version", "data"} for key in required_keys: @@ -73,7 +75,7 @@ def validate_model(data: Dict[str, Any]) -> None: raise web.HTTPError(400, f"Missing `{key}` in the JSON request body.") -def get_timestamp(data: Dict[str, Any]) -> Optional[datetime]: +def get_timestamp(data: dict[str, Any]) -> Optional[datetime]: """Parses timestamp from the JSON request body""" try: if "timestamp" in data: diff --git a/jupyter_server/services/kernels/connection/abc.py b/jupyter_server/services/kernels/connection/abc.py index 4bdf6e3edc..bc98233a23 100644 --- a/jupyter_server/services/kernels/connection/abc.py +++ b/jupyter_server/services/kernels/connection/abc.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, List class KernelWebsocketConnectionABC(ABC): @@ -28,6 +28,6 @@ def handle_incoming_message(self, incoming_msg: str) -> None: ... @abstractmethod - def handle_outgoing_message(self, stream: str, outgoing_msg: list) -> None: + def handle_outgoing_message(self, stream: str, outgoing_msg: List[Any]) -> None: """Broker outgoing ZMQ messages to the kernel websocket.""" ... diff --git a/jupyter_server/services/kernels/connection/base.py b/jupyter_server/services/kernels/connection/base.py index 0f731f354f..1f6b2fdcf4 100644 --- a/jupyter_server/services/kernels/connection/base.py +++ b/jupyter_server/services/kernels/connection/base.py @@ -1,6 +1,7 @@ """Kernel connection helpers.""" import json import struct +from typing import Any, List from jupyter_client.session import Session from tornado.websocket import WebSocketHandler @@ -87,7 +88,7 @@ def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None): else: msg_list = msg_or_list channel = channel.encode("utf-8") - offsets: list = [] + offsets: List[Any] = [] offsets.append(8 * (1 + 1 + len(msg_list) + 1)) offsets.append(len(channel) + offsets[-1]) for msg in msg_list: @@ -171,7 +172,7 @@ def handle_incoming_message(self, incoming_msg: str) -> None: """Handle an incoming message.""" raise NotImplementedError() - def handle_outgoing_message(self, stream: str, outgoing_msg: list) -> None: + def handle_outgoing_message(self, stream: str, outgoing_msg: List[Any]) -> None: """Handle an outgoing message.""" raise NotImplementedError() diff --git a/jupyter_server/services/kernels/connection/channels.py b/jupyter_server/services/kernels/connection/channels.py index d8a84db47f..c103fe456d 100644 --- a/jupyter_server/services/kernels/connection/channels.py +++ b/jupyter_server/services/kernels/connection/channels.py @@ -99,8 +99,8 @@ def write_message(self): _open_sessions: dict[str, KernelWebsocketHandler] = {} _open_sockets: t.MutableSet[ZMQChannelsWebsocketConnection] = weakref.WeakSet() - _kernel_info_future: Future - _close_future: Future + _kernel_info_future: Future[t.Any] + _close_future: Future[t.Any] channels = Dict({}) kernel_info_channel = Any(allow_none=True) @@ -170,7 +170,7 @@ def nudge(self): # noqa # establishing its zmq subscriptions before processing the next request. if getattr(self.kernel_manager, "execution_state", None) == "busy": self.log.debug("Nudge: not nudging busy kernel %s", self.kernel_id) - f: Future = Future() + f: Future[t.Any] = Future() f.set_result(None) return _ensure_future(f) # Use a transient shell channel to prevent leaking @@ -182,8 +182,8 @@ def nudge(self): # noqa # The IOPub used by the client, whose subscriptions we are verifying. iopub_channel = self.channels["iopub"] - info_future: Future = Future() - iopub_future: Future = Future() + info_future: Future[t.Any] = Future() + iopub_future: Future[t.Any] = Future() both_done = gen.multi([info_future, iopub_future]) def finish(_=None): @@ -486,7 +486,7 @@ def handle_incoming_message(self, incoming_msg: str) -> None: else: self.session.send(stream, msg) - def handle_outgoing_message(self, stream: str, outgoing_msg: list) -> None: + def handle_outgoing_message(self, stream: str, outgoing_msg: list[t.Any]) -> None: """Handle the outgoing messages from ZMQ sockets to Websocket.""" msg_list = outgoing_msg _, fed_msg_list = self.session.feed_identities(msg_list) diff --git a/jupyter_server/services/kernels/kernelmanager.py b/jupyter_server/services/kernels/kernelmanager.py index fb3608d0b6..d0ed803b74 100644 --- a/jupyter_server/services/kernels/kernelmanager.py +++ b/jupyter_server/services/kernels/kernelmanager.py @@ -440,7 +440,7 @@ async def _async_restart_kernel(self, kernel_id, now=False): kernel = self.get_kernel(kernel_id) # return a Future that will resolve when the kernel has successfully restarted channel = kernel.connect_shell() - future: Future = Future() + future: Future[Any] = Future() def finish(): """Common cleanup when restart finishes/fails for any reason.""" @@ -710,7 +710,7 @@ def __init__(self, **kwargs): self.last_kernel_activity = utcnow() -def emit_kernel_action_event(success_msg: str = "") -> t.Callable: +def emit_kernel_action_event(success_msg: str = "") -> t.Callable[..., t.Any]: """Decorate kernel action methods to begin emitting jupyter kernel action events. diff --git a/jupyter_server/services/sessions/sessionmanager.py b/jupyter_server/services/sessions/sessionmanager.py index b20f4c98de..6e4ebd3cac 100644 --- a/jupyter_server/services/sessions/sessionmanager.py +++ b/jupyter_server/services/sessions/sessionmanager.py @@ -5,7 +5,7 @@ import os import pathlib import uuid -from typing import Any, Dict, List, NewType, Optional, Union +from typing import Any, Dict, List, NewType, Optional, Union, cast KernelName = NewType("KernelName", str) ModelName = NewType("ModelName", str) @@ -293,7 +293,7 @@ async def create_session( session_id, path=path, name=name, type=type, kernel_id=kernel_id ) self._pending_sessions.remove(record) - return result + return cast(Dict[str, Any], result) def get_kernel_env( self, path: Optional[str], name: Optional[ModelName] = None @@ -347,7 +347,7 @@ async def start_kernel_for_session( kernel_name=kernel_name, env=kernel_env, ) - return kernel_id + return cast(str, kernel_id) async def save_session(self, session_id, path=None, name=None, type=None, kernel_id=None): """Saves the items for the session with the given session_id diff --git a/jupyter_server/traittypes.py b/jupyter_server/traittypes.py index bd6f28a36b..cfa3a8720e 100644 --- a/jupyter_server/traittypes.py +++ b/jupyter_server/traittypes.py @@ -6,7 +6,7 @@ from traitlets.utils.descriptions import describe -class TypeFromClasses(ClassBasedTraitType): +class TypeFromClasses(ClassBasedTraitType): # type:ignore[type-arg] """A trait whose value must be a subclass of a class in a specified list of classes.""" default_value: Any @@ -125,7 +125,7 @@ def default_value_repr(self): return repr(f"{value.__module__}.{value.__name__}") -class InstanceFromClasses(ClassBasedTraitType): +class InstanceFromClasses(ClassBasedTraitType): # type:ignore[type-arg] """A trait whose value must be an instance of a class in a specified list of classes. The value can also be an instance of a subclass of the specified classes. Subclasses can declare default classes by overriding the klass attribute diff --git a/jupyter_server/utils.py b/jupyter_server/utils.py index 5801eb5f18..1ff4979cc3 100644 --- a/jupyter_server/utils.py +++ b/jupyter_server/utils.py @@ -157,7 +157,7 @@ def check_version(v: str, check: str) -> bool: Users on dev branches are responsible for keeping their own packages up to date. """ try: - return Version(v) >= Version(check) + return bool(Version(v) >= Version(check)) except TypeError: return True diff --git a/pyproject.toml b/pyproject.toml index 523e2fbab9..1d82870cb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "tornado>=6.2.0", "traitlets>=5.6.0", "websocket-client", - "jupyter_events>=0.6.0", + "jupyter_events>=0.9.0", "overrides" ] @@ -206,6 +206,8 @@ ignore = [ "PLR0912", # RUF012 Mutable class attributes should be annotated with `typing.ClassVar` "RUF012", + # Use `X | Y` for type annotations + "UP007", ] unfixable = [ # Don't touch print statements @@ -300,7 +302,7 @@ strict = true pretty = true show_error_codes = true warn_unreachable = true -disable_error_code = ["no-untyped-def", "no-untyped-call", "type-arg", "no-any-return"] +disable_error_code = ["no-untyped-def", "no-untyped-call"] enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] [tool.interrogate]