from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import Request, HTTPException, Depends, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from jose import jwt, JWTError
from .models import Session as SessionModel, User
from .security import create_access_token, verify_access_token, SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES
from .database import get_db
import secrets

# Lifetime given to a freshly rotated session; also used as the cookie max-age.
SESSION_LIFETIME = timedelta(hours=12)
# Sessions and access tokens this close to expiry are renewed proactively.
REFRESH_WINDOW = timedelta(minutes=5)


def _clear_auth_cookies(response: Optional[Response]) -> None:
    """Delete the ``session_token`` and ``access_token`` cookies on *response*, if given."""
    if response is None:
        return
    response.delete_cookie("session_token")
    response.delete_cookie("access_token")


def _as_utc(value: datetime) -> datetime:
    """Return *value* as a timezone-aware datetime, treating naive values as UTC."""
    # NOTE(review): assumes naive timestamps coming back from the DB are UTC
    # (this module writes them from datetime.now(timezone.utc)) — confirm.
    return value.replace(tzinfo=timezone.utc) if value.tzinfo is None else value


async def get_current_user_by_session_token(
    request: Request, db: AsyncSession = Depends(get_db), response: Response = None
):
    """Authenticate the request via its ``session_token`` cookie and return the user.

    Looks the token up in the sessions table and rejects missing or expired
    sessions. When the session is within ``REFRESH_WINDOW`` of expiry, it is
    rotated: a new token is issued, the expiry is extended by
    ``SESSION_LIFETIME``, and the new cookie is set on *response* (if given).

    Raises:
        HTTPException: 401 when the cookie is missing, the session is unknown
            or expired, or the session's user no longer exists.
    """
    # NOTE(review): FastAPI builds a fresh response for a raised HTTPException
    # and does not merge headers set on an injected ``Response`` into it, so the
    # cookie deletions before each 401 below are likely discarded — confirm, and
    # use a custom exception handler if clearing cookies on 401 is required.
    session_token = request.cookies.get("session_token")
    if not session_token:
        _clear_auth_cookies(response)
        raise HTTPException(status_code=401, detail="Session token missing")

    result = await db.execute(
        select(SessionModel).filter(SessionModel.sessionToken == session_token)
    )
    session = result.scalars().first()

    now = datetime.now(timezone.utc)
    if session is None or _as_utc(session.expires) < now:
        _clear_auth_cookies(response)
        raise HTTPException(status_code=401, detail="Session expired or invalid")

    # Read the foreign key before any commit: with SQLAlchemy's default
    # expire_on_commit=True the attribute would be expired afterwards, and an
    # AsyncSession cannot lazy-load it implicitly.
    user_id = session.userId

    # Rotate the session token when it is about to expire.
    if _as_utc(session.expires) - REFRESH_WINDOW < now:
        new_session_token = secrets.token_hex(32)
        session.sessionToken = new_session_token
        session.expires = now + SESSION_LIFETIME
        await db.commit()
        if response is not None:
            response.set_cookie(
                key="session_token",
                value=new_session_token,
                httponly=True,
                max_age=int(SESSION_LIFETIME.total_seconds()),
            )

    result = await db.execute(select(User).filter(User.id == user_id))
    user = result.scalars().first()
    if user is None:
        _clear_auth_cookies(response)
        raise HTTPException(status_code=401, detail="User not found")
    return user


async def token_refresh_middleware(
    request: Request, db: AsyncSession = Depends(get_db), response: Response = None
):
    """Authenticate via the session cookie, renewing a near-expiry JWT access cookie.

    If an ``access_token`` cookie is present, it must pass
    ``verify_access_token`` and ``jwt.decode`` and must carry an ``exp`` claim;
    otherwise both auth cookies are cleared and a 401 is raised. The session
    cookie is always validated through ``get_current_user_by_session_token``,
    and that user is returned. A replacement access token is set on *response*
    only after the session has validated.

    Raises:
        HTTPException: 401 on an invalid access token, or whatever
            ``get_current_user_by_session_token`` raises.
    """
    access_token = request.cookies.get("access_token")
    refreshed_token = None
    if access_token:
        try:
            user_id = verify_access_token(access_token)
            # Decode again to read the expiry claim.
            payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])
            exp = payload.get("exp")
            if exp is None:
                # A token without ``exp`` would never expire; treat it as invalid.
                raise JWTError("access token has no exp claim")
            token_expiration = datetime.fromtimestamp(exp, tz=timezone.utc)
            if token_expiration - REFRESH_WINDOW < datetime.now(timezone.utc):
                refreshed_token = create_access_token(
                    {"user_id": user_id},
                    timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
                )
        except JWTError as exc:
            _clear_auth_cookies(response)
            raise HTTPException(status_code=401, detail="Token validation failed") from exc

    # With or without a JWT, the session cookie is the source of truth.
    user = await get_current_user_by_session_token(request, db, response)

    if refreshed_token is not None and response is not None:
        response.set_cookie(
            key="access_token",
            value=refreshed_token,
            httponly=True,
            max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        )
    return user


async def logout(request: Request, db: AsyncSession = Depends(get_db), response: Response = None):
    """Delete the caller's session row (if any) and clear both auth cookies."""
    session_token = request.cookies.get("session_token")
    if session_token:
        result = await db.execute(
            select(SessionModel).filter(SessionModel.sessionToken == session_token)
        )
        session = result.scalars().first()
        if session is not None:
            await db.delete(session)
            await db.commit()
    _clear_auth_cookies(response)
    return {"message": "Logged out successfully"}