110 lines
4.5 KiB
Python
110 lines
4.5 KiB
Python
from datetime import datetime, timedelta, timezone
|
|
from fastapi import Request, HTTPException, Depends, Response
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from jose import jwt, JWTError
|
|
from .models import Session as SessionModel, User
|
|
from .security import create_access_token, verify_access_token, SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES
|
|
from .database import get_db
|
|
import secrets
|
|
|
|
# Lifetime given to a rotated session, and how close to expiry a session must be
# before it is rotated.
_SESSION_LIFETIME = timedelta(hours=12)
_SESSION_REFRESH_WINDOW = timedelta(minutes=5)


def _clear_auth_cookies(response: "Response | None") -> None:
    """Delete both auth cookies (session and access token) when a response is supplied."""
    if response:
        response.delete_cookie("session_token")
        response.delete_cookie("access_token")


def _as_utc(moment: datetime) -> datetime:
    """Return *moment* as a timezone-aware UTC datetime.

    NOTE(review): naive values are assumed to already be UTC (typical for a
    ``DateTime`` column declared without ``timezone=True``) -- confirm against
    the ``Session`` model.
    """
    if moment.tzinfo is None:
        return moment.replace(tzinfo=timezone.utc)
    return moment.astimezone(timezone.utc)


async def get_current_user_by_session_token(
    request: Request,
    db: AsyncSession = Depends(get_db),
    response: Response = None,
):
    """Resolve the authenticated user from the ``session_token`` cookie.

    Looks up the session row whose ``sessionToken`` matches the cookie and
    rejects missing or expired sessions. When the session is within five
    minutes of expiry, it rotates the token: a new random token with a fresh
    12-hour lifetime, committed to the DB. It then loads the owning ``User``.

    Args:
        request: Incoming request; only its cookies are read.
        db: Async SQLAlchemy session used for the lookups and the rotation commit.
        response: Optional. When given, both auth cookies are cleared on failure,
            and the rotated ``session_token`` cookie is set on refresh.

    Returns:
        The ``User`` row referenced by the session's ``userId``.

    Raises:
        HTTPException: 401 when the cookie is missing, the session is unknown,
            expired or has no expiry, or the referenced user does not exist.
    """
    session_token = request.cookies.get("session_token")
    if not session_token:
        _clear_auth_cookies(response)
        raise HTTPException(status_code=401, detail="Session token missing")

    result = await db.execute(
        select(SessionModel).filter(SessionModel.sessionToken == session_token)
    )
    session = result.scalars().first()

    now = datetime.now(timezone.utc)
    # Normalise before comparing: naive vs aware comparison raises TypeError.
    expires_at = (
        _as_utc(session.expires)
        if session is not None and session.expires is not None
        else None
    )
    if expires_at is None or expires_at < now:
        _clear_auth_cookies(response)
        raise HTTPException(status_code=401, detail="Session expired or invalid")

    # Read the FK before committing. With expire_on_commit=True (the SQLAlchemy
    # default), touching attributes after commit triggers lazy IO, which an
    # AsyncSession cannot perform implicitly.
    user_id = session.userId

    if expires_at - _SESSION_REFRESH_WINDOW < now:
        new_session_token = secrets.token_hex(32)
        new_expiry = now + _SESSION_LIFETIME
        if session.expires.tzinfo is None:
            # Write back in the same form the column returned, so a naive
            # TIMESTAMP column is not handed an aware value (NOTE(review): confirm column type).
            new_expiry = new_expiry.replace(tzinfo=None)
        session.sessionToken = new_session_token
        session.expires = new_expiry
        await db.commit()

        if response:
            response.set_cookie(
                key="session_token",
                value=new_session_token,
                httponly=True,
                max_age=int(_SESSION_LIFETIME.total_seconds()),
            )

    result = await db.execute(select(User).filter(User.id == user_id))
    user = result.scalars().first()
    if not user:
        _clear_auth_cookies(response)
        raise HTTPException(status_code=401, detail="User not found")

    return user
|
|
|
|
# Middleware to refresh the JWT token if close to expiry
|
|
async def token_refresh_middleware(
    request: Request,
    db: AsyncSession = Depends(get_db),
    response: Response = None,
):
    """Authenticate the request and refresh the ``access_token`` JWT when it is near expiry.

    With no ``access_token`` cookie, this defers entirely to session-based
    authentication. Otherwise:

    1. The JWT is verified and decoded. Any ``JWTError`` clears both auth
       cookies and yields 401.
    2. The session is validated via ``get_current_user_by_session_token``.
       This runs before any refresh, so an invalid session never receives a
       freshly minted access token.
    3. If the token's ``exp`` falls within five minutes, a new access token is
       issued and set as an httponly cookie.

    Args:
        request: Incoming request; only its cookies are read.
        db: Async SQLAlchemy session, forwarded to the session check.
        response: Optional. Receives cookie updates (set or delete) when provided.

    Returns:
        The ``User`` resolved from the session cookie.

    Raises:
        HTTPException: 401 when JWT validation fails, or when the session check fails.
    """
    access_token = request.cookies.get("access_token")
    if not access_token:
        # No JWT: fall back to session-based authentication alone.
        return await get_current_user_by_session_token(request, db, response)

    try:
        user_id = verify_access_token(access_token)
        payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        # NOTE(review): python-jose's jwt.decode rejects expired tokens
        # (ExpiredSignatureError is a JWTError). An access token that expired
        # while the session is still valid therefore logs the user out rather
        # than falling back to the session -- confirm this is intended.
        if response:
            response.delete_cookie("access_token")
            response.delete_cookie("session_token")
        raise HTTPException(status_code=401, detail="Token validation failed") from err

    # Validate the session first; this raises 401 (and clears cookies) if it is invalid.
    # NOTE(review): user_id from the JWT is not compared with the session's user.
    user = await get_current_user_by_session_token(request, db, response)

    exp = payload.get("exp")
    if exp is not None:
        token_expiration = datetime.fromtimestamp(exp, tz=timezone.utc)
        if token_expiration - timedelta(minutes=5) < datetime.now(timezone.utc):
            new_access_token = create_access_token(
                {"user_id": user_id},
                timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
            )
            if response:
                response.set_cookie(
                    key="access_token",
                    value=new_access_token,
                    httponly=True,
                    max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
                )

    return user
|
|
|
|
# Logout: Clear both session and JWT tokens
|
|
async def logout(request: Request, db: AsyncSession = Depends(get_db), response: Response = None):
    """Log the caller out by deleting its session row and clearing both auth cookies.

    Args:
        request: Incoming request; the ``session_token`` cookie identifies the
            session row to delete.
        db: Async SQLAlchemy session used to find, delete and commit.
        response: Optional. When provided, the ``session_token`` and
            ``access_token`` cookies are deleted on it.

    Returns:
        ``{"message": "Logged out successfully"}``, whether or not a matching
        session existed.
    """
    session_token = request.cookies.get("session_token")
    if session_token:
        result = await db.execute(
            select(SessionModel).filter(SessionModel.sessionToken == session_token)
        )
        session = result.scalars().first()
        if session is not None:
            await db.delete(session)
            await db.commit()

    # `response` defaults to None (e.g. when called directly rather than injected
    # by FastAPI); guard it as the other auth helpers do instead of raising AttributeError.
    if response is not None:
        response.delete_cookie("session_token")
        response.delete_cookie("access_token")

    return {"message": "Logged out successfully"}
|