From dca14861de2ecb60fa7c73579590577f075a8705 Mon Sep 17 00:00:00 2001 From: Kaan Koc Date: Thu, 17 Oct 2024 05:24:56 +0000 Subject: [PATCH] Added and fixed Create auction API endpoint. Added middleware, session-based authentication and logout endpoint --- ..._added_basemodel_for_vehiclecreate_and_.py | 30 +++++ app/main.py | 3 +- app/middleware.py | 109 +++++++++++++++ app/models.py | 26 +++- app/routers/auctions.py | 127 ++++++++++++++++++ app/routers/auth.py | 53 +++++--- app/security.py | 52 ++++--- 7 files changed, 358 insertions(+), 42 deletions(-) create mode 100644 alembic/versions/e8f2d8b9dc30_added_basemodel_for_vehiclecreate_and_.py create mode 100644 app/middleware.py create mode 100644 app/routers/auctions.py diff --git a/alembic/versions/e8f2d8b9dc30_added_basemodel_for_vehiclecreate_and_.py b/alembic/versions/e8f2d8b9dc30_added_basemodel_for_vehiclecreate_and_.py new file mode 100644 index 0000000..4022d65 --- /dev/null +++ b/alembic/versions/e8f2d8b9dc30_added_basemodel_for_vehiclecreate_and_.py @@ -0,0 +1,30 @@ +"""Added BaseModel for VehicleCreate and AuctionCreate + +Revision ID: e8f2d8b9dc30 +Revises: 1a457651bb36 +Create Date: 2024-10-16 19:09:45.413335 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'e8f2d8b9dc30' +down_revision: Union[str, None] = '1a457651bb36' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/app/main.py b/app/main.py index a229f52..83c93b3 100644 --- a/app/main.py +++ b/app/main.py @@ -1,7 +1,7 @@ from fastapi import FastAPI from sqlalchemy.ext.asyncio import AsyncEngine from contextlib import asynccontextmanager -from app.routers import auth # Assuming you have a router for auth logic +from app.routers import auth, auctions # Assuming you have a router for auth logic from app.database import engine from app.models import Base @@ -23,6 +23,7 @@ app = FastAPI(lifespan=lifespan) # Register your API routes app.include_router(auth.router) +app.include_router(auctions.router) @app.get("/") async def root(): diff --git a/app/middleware.py b/app/middleware.py new file mode 100644 index 0000000..81c3498 --- /dev/null +++ b/app/middleware.py @@ -0,0 +1,109 @@ +from datetime import datetime, timedelta, timezone +from fastapi import Request, HTTPException, Depends, Response +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from jose import jwt, JWTError +from .models import Session as SessionModel, User +from .security import create_access_token, verify_access_token, SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES +from .database import get_db +import secrets + +async def get_current_user_by_session_token( + request: Request, + db: AsyncSession = Depends(get_db), + response: Response = None +): + session_token = request.cookies.get("session_token") + if not session_token: + # Clear both session and access tokens + if response: + response.delete_cookie("session_token") + response.delete_cookie("access_token") + raise HTTPException(status_code=401, detail="Session token missing") + + # Check if session exists in the database and is valid + query = select(SessionModel).filter(SessionModel.sessionToken == session_token) + result = await db.execute(query) + session = result.scalars().first() + + if not session or session.expires < datetime.now(timezone.utc)(): + # Session invalid or expired, clear cookies + if response: + response.delete_cookie("session_token") + response.delete_cookie("access_token") + raise HTTPException(status_code=401, detail="Session expired or invalid") + + # Optionally refresh session if about to expire + if session.expires - timedelta(minutes=5) < datetime.now(timezone.utc)(): + new_session_token = secrets.token_hex(32) + session.sessionToken = new_session_token + session.expires = datetime.now(timezone.utc) + timedelta(hours=12) + await db.commit() + + # Set new session token cookie if response is passed + if response: + response.set_cookie(key="session_token", value=new_session_token, httponly=True, max_age=12*60*60) + + # Fetch the user associated with the session + user_query = select(User).filter(User.id == session.userId) + result = await db.execute(user_query) + user = result.scalars().first() + + if not user: + # User not found, clear tokens + if response: + response.delete_cookie("session_token") + response.delete_cookie("access_token") + raise HTTPException(status_code=401, detail="User not found") + + return user + +# Middleware to refresh the JWT token if close to expiry +async def token_refresh_middleware( + request: Request, + db: AsyncSession = Depends(get_db), + response: Response = None +): + access_token = request.cookies.get("access_token") + + if access_token: + try: + user_id = verify_access_token(access_token) + # Retrieve token expiration details + payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM]) + token_expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc) + + if token_expiration - timedelta(minutes=5) < datetime.now(timezone.utc): + # Refresh token + new_access_token = create_access_token({"user_id": user_id}, timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)) + if response: + response.set_cookie(key="access_token", value=new_access_token, httponly=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60) + + # Validate the session token and return user + return await get_current_user_by_session_token(request, db, response) + + except JWTError: + if response: + response.delete_cookie("access_token") + response.delete_cookie("session_token") + raise HTTPException(status_code=401, detail="Token validation failed") + + # If no JWT token, fallback to session-based authentication + return await get_current_user_by_session_token(request, db, response) + +# Logout: Clear both session and JWT tokens +async def logout(request: Request, db: AsyncSession = Depends(get_db), response: Response = None): + session_token = request.cookies.get("session_token") + if session_token: + query = select(SessionModel).filter(SessionModel.sessionToken == session_token) + result = await db.execute(query) + session = result.scalars().first() + + if session: + await db.delete(session) + await db.commit() + + response.delete_cookie("session_token") + response.delete_cookie("access_token") + + return {"message": "Logged out successfully"} diff --git a/app/models.py b/app/models.py index eb6f8db..66bd7ea 100644 --- a/app/models.py +++ b/app/models.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.ext.declarative import as_declarative, declared_attr from sqlalchemy.sql import func from pydantic import BaseModel -from typing import Optional +from typing import Optional, List @as_declarative() class Base: @@ -167,4 +167,26 @@ class UserCreate(BaseModel): class UserLogin(BaseModel): email: str - password: str \ No newline at end of file + password: str + +class VehicleCreate(BaseModel): + brand: str + model: str + variant: Optional[str] + year: int + kilometers: int + condition: str + location: str + latitude: Optional[float] + longitude: Optional[float] + gasType: str + images: str + description: str + service: str + inspectedAt: Optional[str] # ISO format for datetime + equipment_ids: List[int] # List of equipment IDs + +class AuctionCreate(BaseModel): + askingPrice: float + description: Optional[str] + vehicle: VehicleCreate \ No newline at end of file diff --git a/app/routers/auctions.py b/app/routers/auctions.py new file mode 100644 index 0000000..847184d --- /dev/null +++ b/app/routers/auctions.py @@ -0,0 +1,127 @@ +from fastapi import APIRouter, Depends, HTTPException, status, Request +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from pydantic import BaseModel +from typing import List, Optional +from ..models import Auction, Vehicle, VehicleEquipment, Equipment, User +from ..database import get_db +from ..security import verify_access_token # Ensure this is imported correctly +from fastapi.logger import logger + + +router = APIRouter() + +# Define Pydantic models for data validation + +class VehicleCreate(BaseModel): + brand: str + model: str + variant: Optional[str] + year: int + kilometers: int + condition: str + location: str + latitude: Optional[float] + longitude: Optional[float] + gasType: str + images: str + description: str + service: str + inspectedAt: Optional[str] # ISO format for datetime + equipment_ids: List[int] # List of equipment IDs + + +class AuctionCreate(BaseModel): + askingPrice: float + description: Optional[str] + vehicle: VehicleCreate + +async def get_current_user_id(request: Request, db: AsyncSession = Depends(get_db)): + user_id = verify_access_token(request) + + # Fetch user from database to check their role + result = await db.execute(select(User).filter(User.id == user_id)) + user = result.scalars().first() + #print(f"\n user " + str(user.role)) + + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User not found." + ) + + if not user.role.PRIVATE: # Only allow private users to create auctions + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only private users can create an auction." + ) + + logger.info(f"\nCurrent user ID: {user.id}\n Current user role: {user.role}\n") + #logger.debug(f"\nCurrent user ID: {user.id}\n Current user role: {user.role}\n") + return user.id + + +# API route to create an auction +@router.post("/api/v1/test") +async def testFuncForDB(request: Request,user_id: int = Depends(get_current_user_id), db: AsyncSession = Depends(get_db)): + print("HIIIIIIIIIIIIIIIIIIT") + result = await db.execute(select(User).filter(User.id == user_id)) + user = result.scalars().first() + if user: + email = user.email + else: + email = "User not found" + + return {"message": "Test function for DB", "email": email} + + +# API route to create an auction +@router.post("/api/v1/auction") +async def create_auction(auction_data: AuctionCreate, db: AsyncSession = Depends(get_db), user_id: int = Depends(get_current_user_id)): + # Create Vehicle + vehicle_data = auction_data.vehicle + vehicle = Vehicle( + brand=vehicle_data.brand, + model=vehicle_data.model, + variant=vehicle_data.variant, + year=vehicle_data.year, + kilometers=vehicle_data.kilometers, + condition=vehicle_data.condition, + location=vehicle_data.location, + latitude=vehicle_data.latitude, + longitude=vehicle_data.longitude, + gasType=vehicle_data.gasType, + images=vehicle_data.images, + description=vehicle_data.description, + service=vehicle_data.service, + inspectedAt=vehicle_data.inspectedAt, + ) + + # Add vehicle to the database + db.add(vehicle) + await db.commit() + await db.refresh(vehicle) + + # Add vehicle equipment + for equipment_id in vehicle_data.equipment_ids: + result = await db.execute(select(Equipment).filter(Equipment.id == equipment_id)) + equipment = result.scalars().first() + if not equipment: + raise HTTPException(status_code=404, detail=f"Equipment with ID {equipment_id} not found") + vehicle_equipment = VehicleEquipment(vehicle_id=vehicle.id, equipment_id=equipment.id) + db.add(vehicle_equipment) + + # Create Auction + auction = Auction( + vehicleId=vehicle.id, + userId=user_id, # This comes from the authenticated user + askingPrice=auction_data.askingPrice, + description=auction_data.description, + ) + + # Add auction to the database + db.add(auction) + await db.commit() + await db.refresh(auction) + + return {"message": "Auction created successfully", "auction": auction, "vehicle": vehicle} \ No newline at end of file diff --git a/app/routers/auth.py b/app/routers/auth.py index 9e663f6..90fe294 100644 --- a/app/routers/auth.py +++ b/app/routers/auth.py @@ -1,11 +1,14 @@ -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Response, Request +from datetime import datetime, timedelta, timezone +from ..middleware import token_refresh_middleware, logout from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from ..database import get_db -from ..models import User, UserCreate, UserLogin -from ..security import create_access_token, verify_access_token +from ..models import User, UserCreate, UserLogin, Session as SessionModel +from ..security import create_access_token, verify_password from fastapi.security import OAuth2PasswordBearer import bcrypt +import secrets oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -53,30 +56,40 @@ async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)): # User login @router.post("/api/v1/login") -async def login(login_data: UserLogin, db: AsyncSession = Depends(get_db)): +async def login(login_data: UserLogin, db: AsyncSession = Depends(get_db), response: Response = None): async with db.begin(): - # Check if email is an email query = select(User).filter(User.email == login_data.email) - result = await db.execute(query) user = result.scalars().first() - if not user or not bcrypt.checkpw(login_data.password.encode('utf-8'), user.password.encode('utf-8')): + if not user or not verify_password(login_data.password, user.password): raise HTTPException(status_code=400, detail="Invalid credentials") + # Create session token + session_token = secrets.token_hex(32) + session_expiry = datetime.now(timezone.utc) + timedelta(hours=12) + + new_session = SessionModel(sessionToken=session_token, userId=user.id, expires=session_expiry) + db.add(new_session) + + # Create JWT access token access_token = create_access_token(data={"user_id": user.id}) - return {"access_token": access_token, "token_type": "bearer"} + + await db.commit() + + # Set session and access tokens after transaction is committed + response.set_cookie(key="session_token", value=session_token, httponly=True, max_age=12*60*60) + response.set_cookie(key="access_token", value=access_token, httponly=True, max_age=60*60) + + return {"message": "Login successful"} + +# Logout user +@router.post("/api/v1/logout") +async def logout_user(request: Request, response: Response, db: AsyncSession = Depends(get_db)): + return await logout(request, db, response) + -# Protected route example +# Protected route example using middleware @router.get("/api/v1/protected") -async def protected_route(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)): - user_id = verify_access_token(token) - async with db.begin(): - query = select(User).filter(User.id == user_id) - result = await db.execute(query) - user = result.scalars().first() - - if not user: - raise HTTPException(status_code=401, detail="User not found") - - return {"message": f"Hello, {user.name}"} +async def protected_route(user: User = Depends(token_refresh_middleware)): + return {"message": f"Hello, {user.name}"} \ No newline at end of file diff --git a/app/security.py b/app/security.py index e2fd628..6e0d812 100644 --- a/app/security.py +++ b/app/security.py @@ -1,13 +1,19 @@ import bcrypt from jose import jwt, JWTError -from datetime import datetime, timedelta -from fastapi import HTTPException, status +from datetime import datetime, timedelta, timezone +from fastapi import HTTPException, status, Request import os +from pydantic import BaseModel +from typing import Optional # Secret and algorithm for JWT SECRET_KEY = os.getenv('SECRET_KEY', 'your_jwt_secret_key') # Ensure this is set in your environment ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 +ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 14 + + +class TokenData(BaseModel): + user_id: Optional[int] = None # Hash password using bcrypt directly def get_password_hash(password: str) -> str: @@ -26,29 +32,37 @@ def create_access_token(data: dict, expires_delta: timedelta = None): """Creates a JWT token with expiration time.""" to_encode = data.copy() if expires_delta: - expire = datetime.utcnow() + expires_delta + expire = datetime.now(timezone.utc) + expires_delta else: - expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt -# Verify JWT token -def verify_access_token(token: str): - """Verifies the JWT token and returns the user_id if valid.""" +def verify_access_token(request: Request) -> int: + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # First, check Authorization header (for cases where the JWT is passed in headers) + auth_header = request.headers.get("Authorization") + token = None + if auth_header and auth_header.startswith("Bearer "): + token = auth_header.split(" ")[1] + + # If no Authorization header, fallback to cookies + if not token: + token = request.cookies.get("access_token") + if not token: + raise credentials_exception + try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - user_id: str = payload.get("user_id") + user_id: int = payload.get("user_id") if user_id is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token", - headers={"WWW-Authenticate": "Bearer"}, - ) + raise credentials_exception return user_id except JWTError: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token", - headers={"WWW-Authenticate": "Bearer"}, - ) \ No newline at end of file + raise credentials_exception \ No newline at end of file