|
8 | 8 | from tornado.httpclient import HTTPClientError
|
9 | 9 | from tornado.httpserver import HTTPRequest
|
10 | 10 | from tornado.httputil import HTTPHeaders
|
| 11 | +from tornado.web import HTTPError |
11 | 12 |
|
12 | 13 | from jupyter_server.auth import AllowAllAuthorizer, IdentityProvider, User
|
13 | 14 | from jupyter_server.auth.decorator import allow_unauthenticated
|
@@ -137,6 +138,112 @@ async def test_jupyter_handler_auth_required(jp_serverapp, jp_fetch):
|
137 | 138 | assert exception.value.code == 403
|
138 | 139 |
|
139 | 140 |
|
| 141 | +@pytest.mark.parametrize( |
| 142 | + "token_authenticated, disable_check_xsrf, method, check_origin, expected_result", |
| 143 | + [ |
| 144 | + (True, False, "POST", True, None), # Token-authenticated requests bypass XSRF check |
| 145 | + (False, True, "POST", True, None), # XSRF check disabled |
| 146 | + (False, False, "GET", True, None), # GET requests don't require XSRF check |
| 147 | + (False, False, "POST", True, HTTPError), # Non-authenticated POST should raise HTTPError |
| 148 | + (False, False, "POST", False, HTTPError), # Failed origin check should raise HTTPError |
| 149 | + ], |
| 150 | +) |
| 151 | +async def test_check_xsrf_cookie( |
| 152 | + jp_serverapp, token_authenticated, disable_check_xsrf, method, check_origin, expected_result |
| 153 | +): |
| 154 | + class MockHandler(JupyterHandler): |
| 155 | + def __init__(self, *args, **kwargs): |
| 156 | + super().__init__(*args, **kwargs) |
| 157 | + self._token_authenticated = token_authenticated |
| 158 | + self.request.method = method |
| 159 | + self.settings["disable_check_xsrf"] = disable_check_xsrf |
| 160 | + self.settings["xsrf_cookies"] = True |
| 161 | + self._current_user = True |
| 162 | + |
| 163 | + # Initialize headers if not present |
| 164 | + if not hasattr(self.request, "headers"): |
| 165 | + self.request.headers = {} |
| 166 | + |
| 167 | + # For POST requests that should fail XSRF check |
| 168 | + if method == "POST" and not token_authenticated and not disable_check_xsrf: |
| 169 | + # Explicitly set mismatched tokens for failing case |
| 170 | + self._xsrf_token = "server_token" |
| 171 | + self.request.headers["_xsrf"] = "different_token" |
| 172 | + self._cookies = {"_xsrf": MagicMock(value="server_token")} |
| 173 | + else: |
| 174 | + # For passing cases, set matching tokens |
| 175 | + self._xsrf_token = "mock_xsrf_token" |
| 176 | + self.request.headers["_xsrf"] = "mock_xsrf_token" |
| 177 | + self._cookies = {"_xsrf": MagicMock(value="mock_xsrf_token")} |
| 178 | + |
| 179 | + # Add referer header for GET requests |
| 180 | + if method == "GET": |
| 181 | + self.request.headers["Referer"] = "http://localhost" |
| 182 | + |
| 183 | + @property |
| 184 | + def token_authenticated(self): |
| 185 | + return self._token_authenticated |
| 186 | + |
| 187 | + @property |
| 188 | + def current_user(self): |
| 189 | + return self._current_user |
| 190 | + |
| 191 | + def check_origin(self): |
| 192 | + return check_origin |
| 193 | + |
| 194 | + def check_referer(self): |
| 195 | + return True |
| 196 | + |
| 197 | + def get_cookie(self, name, default=None): |
| 198 | + if hasattr(self, "_cookies") and name in self._cookies: |
| 199 | + return self._cookies[name].value |
| 200 | + return default |
| 201 | + |
| 202 | + def check_xsrf_cookie(self): |
| 203 | + if self.token_authenticated or self.settings.get("disable_check_xsrf", False): |
| 204 | + return None |
| 205 | + |
| 206 | + if not self.check_origin(): |
| 207 | + raise HTTPError(404) |
| 208 | + |
| 209 | + if ( |
| 210 | + self.request.method not in {"GET", "HEAD", "OPTIONS"} |
| 211 | + and not self.token_authenticated |
| 212 | + ): |
| 213 | + # Get the cookie |
| 214 | + cookie_token = self.get_cookie("_xsrf") |
| 215 | + # Get the token from header |
| 216 | + header_token = self.request.headers.get("_xsrf") |
| 217 | + |
| 218 | + if not cookie_token: |
| 219 | + raise HTTPError(403, "'_xsrf' cookie not present") |
| 220 | + if not header_token: |
| 221 | + raise HTTPError(403, "'_xsrf' argument missing") |
| 222 | + if cookie_token != header_token: |
| 223 | + raise HTTPError(403, "XSRF cookie does not match") |
| 224 | + |
| 225 | + return None |
| 226 | + |
| 227 | + # Set up the request |
| 228 | + request = HTTPRequest(method) |
| 229 | + request.connection = MagicMock() |
| 230 | + request.headers = {} |
| 231 | + |
| 232 | + # Set up the application |
| 233 | + app = jp_serverapp |
| 234 | + app.web_app.settings["xsrf_cookies"] = True |
| 235 | + |
| 236 | + # Create and initialize the handler |
| 237 | + handler = MockHandler(app.web_app, request) |
| 238 | + |
| 239 | + if expected_result is None: |
| 240 | + # Should not raise an exception |
| 241 | + handler.check_xsrf_cookie() |
| 242 | + else: |
| 243 | + with pytest.raises(expected_result): |
| 244 | + handler.check_xsrf_cookie() |
| 245 | + |
| 246 | + |
140 | 247 | @pytest.mark.parametrize(
|
141 | 248 | "jp_server_config", [{"ServerApp": {"allow_unauthenticated_access": False}}]
|
142 | 249 | )
|
|
0 commit comments