62 lines
2.2 KiB
Python
62 lines
2.2 KiB
Python
import inspect
|
|
from functools import wraps
|
|
from typing import Callable, Any, Tuple, Optional
|
|
from fastapi import Request
|
|
from fastapi.responses import RedirectResponse
|
|
from passlib.context import CryptContext
|
|
|
|
# Password hashing policy. Argon2 is listed first, so it is the default scheme
# used for every new hash; bcrypt remains listed so legacy hashes still verify.
# deprecated="auto" makes passlib treat every non-default scheme (here bcrypt)
# as deprecated, which is what lets verify_and_update() hand back an Argon2
# replacement hash after a successful bcrypt verification.
pwd_context = CryptContext(schemes=["argon2", "bcrypt"], deprecated="auto")
|
|
|
|
def hash_password(p: str) -> str:
    """Hash a plaintext password using ``pwd_context``'s default scheme.

    :param p: the plaintext password.
    :returns: the encoded hash string produced by ``pwd_context.hash``
        (Argon2, the first entry in the context's ``schemes``).
    """
    encoded = pwd_context.hash(p)
    return encoded
|
|
|
|
def verify_password(p: str, hashed: str) -> bool:
    """Check a plaintext password against a stored hash.

    :param p: the plaintext password supplied by the user.
    :param hashed: the stored hash (Argon2 or legacy bcrypt).
    :returns: True if ``p`` matches ``hashed``. Returns False, rather than
        raising, when the stored hash is malformed, uses an unrecognised
        scheme, or is not a string. A corrupt record therefore fails
        authentication instead of surfacing as a server error, consistent
        with ``verify_and_update_password``, which also reports failure
        instead of raising.
    """
    try:
        return pwd_context.verify(p, hashed)
    except (ValueError, TypeError):
        # passlib signals unknown/malformed hashes and over-long secrets with
        # ValueError subclasses, and non-string input with TypeError.
        return False
|
|
|
|
def verify_and_update_password(p: str, hashed: str) -> Tuple[bool, Optional[str]]:
    """
    Returns (verified, new_hash). If verified is True and new_hash is not None,
    caller should persist the new_hash (Argon2) to upgrade legacy bcrypt.

    A stored hash that is malformed, uses an unrecognised scheme, or is not a
    string yields (False, None). Any other error (for example a missing
    Argon2/bcrypt backend) propagates to the caller rather than being reported
    as a wrong password, so misconfiguration is visible.
    """
    try:
        return pwd_context.verify_and_update(p, hashed)
    except (ValueError, TypeError):
        # passlib raises ValueError subclasses for unknown/malformed hashes and
        # over-long secrets, and TypeError for non-string input; treat these as
        # a failed verification instead of an error.
        return False, None
|
|
|
|
def _extract_request(args: Tuple[Any, ...], kwargs: dict) -> Optional[Request]:
    """Find the ``Request`` object among an endpoint call's arguments.

    Lookup order is:

    1. the ``request`` keyword (the conventional parameter name);
    2. the positional arguments;
    3. every other keyword value.

    FastAPI passes endpoint parameters by keyword, so a ``Request`` parameter
    with a different name (e.g. ``req``) is only found by the final scan.

    :param args: positional arguments the wrapped endpoint was called with.
    :param kwargs: keyword arguments the wrapped endpoint was called with.
    :returns: the first ``Request`` instance found, or None if there is none.
    """
    preferred = kwargs.get("request")
    if isinstance(preferred, Request):
        return preferred
    for candidate in (*args, *kwargs.values()):
        if isinstance(candidate, Request):
            return candidate
    return None
|
|
|
|
def login_required(endpoint: Callable[..., Any]):
    """Guard a FastAPI endpoint (sync or async) behind a session login.

    Each call looks up the ``Request`` among the endpoint's arguments. If
    there is no request, or its session has no truthy ``user_id``, the call
    returns a 303 redirect to ``/login`` and the endpoint is not invoked.
    Otherwise the endpoint runs normally.

    ``functools.wraps`` is applied, so the wrapper exposes the endpoint's own
    signature; the endpoint must declare a ``Request`` parameter for a
    request to be available here.
    """

    def _redirect_if_anonymous(args, kwargs):
        # Returns the redirect response to send, or None when access is allowed.
        request = _extract_request(args, kwargs)
        if request and request.session.get("user_id"):
            return None
        return RedirectResponse(url="/login", status_code=303)

    if not inspect.iscoroutinefunction(endpoint):

        @wraps(endpoint)
        def sync_wrapper(*args, **kwargs):
            denial = _redirect_if_anonymous(args, kwargs)
            if denial is not None:
                return denial
            return endpoint(*args, **kwargs)

        return sync_wrapper

    @wraps(endpoint)
    async def async_wrapper(*args, **kwargs):
        denial = _redirect_if_anonymous(args, kwargs)
        if denial is not None:
            return denial
        return await endpoint(*args, **kwargs)

    return async_wrapper
|
|
|
|
def get_current_user(request: Request):
    """Return a minimal user record taken from the request's session.

    :param request: the incoming request; its ``session`` mapping is read.
    :returns: ``{"id": <session "user_id">, "username": <session "username">}``;
        a value is None when the key is absent from the session. This does not
        verify that a user is logged in.
    """
    session = request.session
    return {
        "id": session.get("user_id"),
        "username": session.get("username"),
    }