Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/app/api/v1/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Versioned API router."""
from fastapi import APIRouter
from .arxiv import router as arxiv_router
from app.api.v1.endpoints import health, papers, users, auth, academic, conversations
from app.api.v1.endpoints import health, papers, users, auth, academic, conversations, library

api_router = APIRouter()
api_router.include_router(health.router, prefix="/health", tags=["health"])
Expand All @@ -11,3 +11,4 @@
api_router.include_router(academic.router, prefix="/academic", tags=["academic"])
api_router.include_router(arxiv_router,prefix="/arxiv",tags=["arxiv"])
api_router.include_router(conversations.router, prefix="/conversations", tags=["conversations"])
api_router.include_router(library.router, prefix="/library", tags=["library"])
10 changes: 8 additions & 2 deletions backend/app/api/v1/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""API v1 router aggregation."""
from fastapi import APIRouter

from app.api.v1.endpoints import academic, auth, conversations, papers, users
from app.api.v1.endpoints import academic, auth, conversations, library, papers, users

# 创建主路由器
api_router = APIRouter()
Expand All @@ -25,6 +25,12 @@
tags=["论文检索"]
)

api_router.include_router(
library.router,
prefix="/library",
tags=["我的文库"]
)

api_router.include_router(
academic.router,
prefix="/academic",
Expand All @@ -35,4 +41,4 @@
conversations.router,
prefix="/conversations",
tags=["对话历史"]
)
)
210 changes: 206 additions & 4 deletions backend/app/api/v1/endpoints/auth.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,33 @@
"""认证授权相关API接口"""
from fastapi import APIRouter, Depends, HTTPException, status
from __future__ import annotations

import json
import secrets
from typing import cast
from urllib.parse import urlencode

import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from typing import cast

from app.core.auth import create_access_token
from app.core.security import verify_password
from app.core.config import get_settings
from app.core.security import hash_password, verify_password
from app.db import UserRepository
from app.db.session import get_db
from app.schemas.auth import Token # type: ignore[import-not-found]

router = APIRouter()
settings = get_settings()

GITHUB_AUTHORIZE_URL = "https://github.com/login/oauth/authorize"
GITHUB_TOKEN_URL = "https://github.com/login/oauth/access_token"
GITHUB_USER_API = "https://api.github.com/user"
GITHUB_EMAILS_API = "https://api.github.com/user/emails"
GITHUB_STATE_COOKIE = "github_oauth_state"
GITHUB_STATE_TTL = 600


@router.post(
Expand Down Expand Up @@ -99,4 +116,189 @@ async def login_for_access_token(
)

token = create_access_token(subject=cast(str, getattr(user, "email", "")))
return Token(access_token=token)
return Token(access_token=token)


@router.get("/github/login")
async def github_login(request: Request, next: str | None = None):
client_id = settings.github_client_id
client_secret = settings.github_client_secret
if not client_id or not client_secret:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="GitHub OAuth 未配置")

nonce = secrets.token_urlsafe(32)
cookie_payload = json.dumps(
{
"nonce": nonce,
"next": _sanitize_next_path(next),
}
)

callback_url = str(request.url_for("github_callback"))
params = {
"client_id": client_id,
"redirect_uri": callback_url,
"scope": "read:user user:email",
"state": nonce,
"allow_signup": "true",
}
authorize_url = f"{GITHUB_AUTHORIZE_URL}?{urlencode(params)}"
response = RedirectResponse(authorize_url, status_code=status.HTTP_302_FOUND)
response.set_cookie(
GITHUB_STATE_COOKIE,
cookie_payload,
max_age=GITHUB_STATE_TTL,
httponly=True,
secure=_is_cookie_secure(),
samesite="lax",
)
return response


@router.get("/github/callback", name="github_callback")
async def github_callback(
request: Request,
code: str | None = None,
state: str | None = None,
db: AsyncSession = Depends(get_db),
):
cookie_payload = request.cookies.get(GITHUB_STATE_COOKIE)
if not cookie_payload:
return _oauth_error_redirect("Missing OAuth state cookie.")

try:
payload = json.loads(cookie_payload)
except json.JSONDecodeError:
return _oauth_error_redirect("Invalid OAuth state payload.")

expected_state = payload.get("nonce")
next_path = _sanitize_next_path(payload.get("next"))
if not code or not state or expected_state != state:
return _oauth_error_redirect("OAuth state mismatch.", next_path=next_path)

try:
callback_url = str(request.url_for("github_callback"))
token_data = await _exchange_github_code_for_token(code, callback_url)
access_token = token_data.get("access_token")
if not access_token:
raise RuntimeError("GitHub did not return an access token.")

github_user, primary_email = await _fetch_github_profile(access_token)
if not primary_email:
raise RuntimeError("未能获取 GitHub 邮箱,请在 GitHub 账户开启公开邮箱或授权 email scope。")

github_user_id = str(github_user.get("id"))
repo = UserRepository(db)
user = await repo.get_by_oauth_account("github", github_user_id)

if user is None:
user = await repo.get_by_email(primary_email)
if user:
user = await repo.update(
user,
{
"oauth_provider": "github",
"oauth_account_id": github_user_id,
"avatar_url": user.avatar_url or github_user.get("avatar_url"),
},
)
else:
hashed_password = hash_password(secrets.token_urlsafe(32))
user = await repo.create(
email=primary_email,
hashed_password=hashed_password,
full_name=github_user.get("name") or github_user.get("login"),
avatar_url=github_user.get("avatar_url"),
oauth_provider="github",
oauth_account_id=github_user_id,
)
else:
updates: dict[str, str | None] = {}
if not user.avatar_url and github_user.get("avatar_url"):
updates["avatar_url"] = github_user.get("avatar_url")
updates["oauth_provider"] = "github"
updates["oauth_account_id"] = github_user_id
user = await repo.update(user, updates)

await db.commit()
await db.refresh(user)

token = create_access_token(subject=cast(str, getattr(user, "email", "")))
redirect_target = _build_frontend_redirect(token=token, next_path=next_path)
response = RedirectResponse(redirect_target, status_code=status.HTTP_302_FOUND)
response.delete_cookie(GITHUB_STATE_COOKIE)
return response
except Exception as exc: # pragma: no cover - defensive
return _oauth_error_redirect(str(exc), next_path=next_path)


def _sanitize_next_path(value: str | None) -> str:
if not value or not value.startswith("/"):
return "/"
return value


def _is_cookie_secure() -> bool:
return settings.environment not in {"local", "development"}


async def _exchange_github_code_for_token(code: str, redirect_uri: str) -> dict[str, str]:
client_id = settings.github_client_id
client_secret = settings.github_client_secret
if not client_id or not client_secret:
raise RuntimeError("GitHub OAuth 未配置。")

data = {
"client_id": client_id,
"client_secret": client_secret,
"code": code,
"redirect_uri": redirect_uri,
}

async with httpx.AsyncClient(timeout=15.0, headers={"Accept": "application/json"}) as client:
response = await client.post(GITHUB_TOKEN_URL, data=data)
response.raise_for_status()
return response.json()


async def _fetch_github_profile(access_token: str) -> tuple[dict[str, str], str | None]:
headers = {
"Accept": "application/json",
"Authorization": f"Bearer {access_token}",
}
async with httpx.AsyncClient(timeout=15.0) as client:
user_resp = await client.get(GITHUB_USER_API, headers=headers)
user_resp.raise_for_status()
user_data = user_resp.json()
email = user_data.get("email")
if not email:
emails_resp = await client.get(GITHUB_EMAILS_API, headers=headers)
emails_resp.raise_for_status()
emails = emails_resp.json()
email = next(
(
item.get("email")
for item in emails
if item.get("primary") and item.get("verified")
),
None,
) or next((item.get("email") for item in emails if item.get("verified")), None)
return user_data, email


def _build_frontend_redirect(*, token: str | None, next_path: str) -> str:
params = {"next": _sanitize_next_path(next_path)}
if token:
params["token"] = token
params["provider"] = "github"
return f"{settings.frontend_oauth_redirect_url}?{urlencode(params)}"


def _oauth_error_redirect(message: str, *, next_path: str = "/") -> RedirectResponse:
params = {
"error": message,
"next": _sanitize_next_path(next_path),
}
response = RedirectResponse(f"{settings.frontend_oauth_redirect_url}?{urlencode(params)}", status_code=status.HTTP_302_FOUND)
response.delete_cookie(GITHUB_STATE_COOKIE)
return response
Loading
Loading