Files
scrapAPI/app/middleware.py

110 lines
4.5 KiB
Python

from datetime import datetime, timedelta, timezone
from fastapi import Request, HTTPException, Depends, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from jose import jwt, JWTError
from .models import Session as SessionModel, User
from .security import create_access_token, verify_access_token, SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES
from .database import get_db
import secrets
# Lifetime given to a freshly rotated session token and to its cookie.
_SESSION_LIFETIME = timedelta(hours=12)
# Sessions whose expiry falls within this window are rotated on the next request.
_SESSION_REFRESH_WINDOW = timedelta(minutes=5)


def _as_utc(value: datetime) -> datetime:
    """Return *value* as a timezone-aware datetime.

    Comparing a naive datetime with ``datetime.now(timezone.utc)`` raises
    ``TypeError``, so naive values are tagged as UTC before comparison.
    """
    # NOTE(review): assumes naive timestamps read from the DB are stored in
    # UTC — confirm against the Session model's ``expires`` column.
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value


def _clear_auth_cookies(response) -> None:
    """Delete the ``session_token`` and ``access_token`` cookies when possible.

    *response* may be ``None`` when the caller has no response object; then
    nothing is done.
    """
    # NOTE(review): FastAPI normally does not apply headers set on an injected
    # Response when an HTTPException propagates, so on the 401 paths these
    # deletions may never reach the client — confirm.
    if response is not None:
        response.delete_cookie("session_token")
        response.delete_cookie("access_token")


async def get_current_user_by_session_token(
    request: Request,
    db: AsyncSession = Depends(get_db),
    response: Response = None
):
    """Resolve the authenticated user from the ``session_token`` cookie.

    Looks up the session row matching the cookie and rejects missing or
    expired sessions. If the session expires within
    ``_SESSION_REFRESH_WINDOW``, it rotates the token, extends the expiry by
    ``_SESSION_LIFETIME`` and commits. Finally it loads the owning ``User``.

    Args:
        request: Incoming request; the ``session_token`` cookie is read from it.
        db: Async SQLAlchemy session used for the lookups and for persisting a
            rotated session token.
        response: Response whose cookies are updated (new session token on
            rotation, cleared auth cookies on failure). May be ``None`` when
            called directly; cookie updates are then skipped.

    Returns:
        The ``User`` row whose ``id`` equals the session's ``userId``.

    Raises:
        HTTPException: 401 if the cookie is missing, the session is unknown
            or expired, or the referenced user does not exist.
    """
    session_token = request.cookies.get("session_token")
    if not session_token:
        _clear_auth_cookies(response)
        raise HTTPException(status_code=401, detail="Session token missing")

    # Check that the session exists in the database and is still valid.
    result = await db.execute(
        select(SessionModel).filter(SessionModel.sessionToken == session_token)
    )
    session = result.scalars().first()

    now = datetime.now(timezone.utc)
    if session is None or _as_utc(session.expires) < now:
        _clear_auth_cookies(response)
        raise HTTPException(status_code=401, detail="Session expired or invalid")

    # Read the owner id before any commit: with AsyncSession's default
    # expire_on_commit=True, touching an expired attribute afterwards would
    # trigger an implicit lazy load, which fails under asyncio.
    # NOTE(review): unnecessary if get_db sets expire_on_commit=False.
    user_id = session.userId

    # Rotate the session token shortly before it lapses so active users stay
    # logged in; the old token stops matching once the commit lands.
    if _as_utc(session.expires) - _SESSION_REFRESH_WINDOW < now:
        new_session_token = secrets.token_hex(32)
        session.sessionToken = new_session_token
        session.expires = now + _SESSION_LIFETIME
        await db.commit()
        if response is not None:
            response.set_cookie(
                key="session_token",
                value=new_session_token,
                httponly=True,
                max_age=int(_SESSION_LIFETIME.total_seconds()),
            )

    # Fetch the user associated with the session.
    result = await db.execute(select(User).filter(User.id == user_id))
    user = result.scalars().first()
    if user is None:
        _clear_auth_cookies(response)
        raise HTTPException(status_code=401, detail="User not found")
    return user
# Access tokens whose expiry falls within this window are re-issued.
_ACCESS_TOKEN_REFRESH_WINDOW = timedelta(minutes=5)


async def token_refresh_middleware(
    request: Request,
    db: AsyncSession = Depends(get_db),
    response: Response = None
):
    """Authenticate the request and re-issue the JWT cookie when near expiry.

    Without an ``access_token`` cookie this is plain session-based
    authentication via ``get_current_user_by_session_token``.

    With one, the token is first checked with ``verify_access_token`` and
    ``jwt.decode``. A ``JWTError`` from either clears both auth cookies and
    yields a 401. The session cookie is then validated. Only after that
    succeeds is a fresh access token issued: when the current one expires
    within ``_ACCESS_TOKEN_REFRESH_WINDOW`` or carries no ``exp`` claim.

    Args:
        request: Incoming request; ``access_token`` / ``session_token``
            cookies are read from it.
        db: Async SQLAlchemy session forwarded to the session lookup.
        response: Response receiving refreshed or cleared cookies. May be
            ``None``; cookie updates are then skipped.

    Returns:
        The user returned by ``get_current_user_by_session_token``.

    Raises:
        HTTPException: 401 if the JWT fails validation, or whatever
            ``get_current_user_by_session_token`` raises for a bad session.
    """
    access_token = request.cookies.get("access_token")
    if not access_token:
        # No JWT cookie: fall back to session-based authentication.
        return await get_current_user_by_session_token(request, db, response)

    # Keep the try limited to the JWT checks so only their failures map to
    # "Token validation failed".
    try:
        user_id = verify_access_token(access_token)
        payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        # NOTE(review): python-jose raises ExpiredSignatureError (a JWTError)
        # for an already-expired token, so such requests get a 401 here even
        # if the session is still valid — confirm this is intended.
        if response is not None:
            response.delete_cookie("access_token")
            response.delete_cookie("session_token")
        raise HTTPException(status_code=401, detail="Token validation failed")

    # Validate the session before minting anything, so no new access token is
    # produced for a request that is about to be rejected.
    user = await get_current_user_by_session_token(request, db, response)

    exp = payload.get("exp")
    # A token without an ``exp`` claim is treated as due for refresh, so it is
    # replaced by one with an expiry rather than failing on the missing key.
    needs_refresh = exp is None or (
        datetime.fromtimestamp(exp, tz=timezone.utc) - _ACCESS_TOKEN_REFRESH_WINDOW
        < datetime.now(timezone.utc)
    )
    if needs_refresh:
        new_access_token = create_access_token(
            {"user_id": user_id},
            timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        )
        if response is not None:
            response.set_cookie(
                key="access_token",
                value=new_access_token,
                httponly=True,
                max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            )
    return user
async def logout(request: Request, db: AsyncSession = Depends(get_db), response: Response = None):
    """Log out by deleting the server-side session and clearing auth cookies.

    If a ``session_token`` cookie is present, the matching session row is
    deleted and the deletion committed. The ``session_token`` and
    ``access_token`` cookies are then cleared on *response*. Cookie clearing
    is skipped when *response* is ``None`` (e.g. a direct call outside FastAPI
    injection) instead of raising ``AttributeError``.

    Args:
        request: Incoming request; the ``session_token`` cookie is read from it.
        db: Async SQLAlchemy session used to delete the session row.
        response: Response on which the auth cookies are deleted; may be ``None``.

    Returns:
        A confirmation payload: ``{"message": "Logged out successfully"}``.
    """
    session_token = request.cookies.get("session_token")
    if session_token:
        result = await db.execute(
            select(SessionModel).filter(SessionModel.sessionToken == session_token)
        )
        session = result.scalars().first()
        if session is not None:
            await db.delete(session)
            await db.commit()
    # Clear both cookies even when no session row was found, so the client
    # does not keep presenting stale credentials.
    if response is not None:
        response.delete_cookie("session_token")
        response.delete_cookie("access_token")
    return {"message": "Logged out successfully"}