Added and fixed Create auction API endpoint. Added middleware, session-based authentication and logout endpoint

This commit is contained in:
2024-10-17 05:24:56 +00:00
parent fbbdad51c2
commit dca14861de
7 changed files with 358 additions and 42 deletions

View File

@ -0,0 +1,30 @@
"""Added BaseModel for VehicleCreate and AuctionCreate
Revision ID: e8f2d8b9dc30
Revises: 1a457651bb36
Create Date: 2024-10-16 19:09:45.413335
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'e8f2d8b9dc30'
down_revision: Union[str, None] = '1a457651bb36'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###

View File

@ -1,7 +1,7 @@
from fastapi import FastAPI
from sqlalchemy.ext.asyncio import AsyncEngine
from contextlib import asynccontextmanager
from app.routers import auth # Assuming you have a router for auth logic
from app.routers import auth, auctions # Assuming you have a router for auth logic
from app.database import engine
from app.models import Base
@ -23,6 +23,7 @@ app = FastAPI(lifespan=lifespan)
# Register your API routes
app.include_router(auth.router)
app.include_router(auctions.router)
@app.get("/")
async def root():

109
app/middleware.py Normal file
View File

@ -0,0 +1,109 @@
from datetime import datetime, timedelta, timezone
from fastapi import Request, HTTPException, Depends, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from jose import jwt, JWTError
from .models import Session as SessionModel, User
from .security import create_access_token, verify_access_token, SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES
from .database import get_db
import secrets
async def get_current_user_by_session_token(
request: Request,
db: AsyncSession = Depends(get_db),
response: Response = None
):
session_token = request.cookies.get("session_token")
if not session_token:
# Clear both session and access tokens
if response:
response.delete_cookie("session_token")
response.delete_cookie("access_token")
raise HTTPException(status_code=401, detail="Session token missing")
# Check if session exists in the database and is valid
query = select(SessionModel).filter(SessionModel.sessionToken == session_token)
result = await db.execute(query)
session = result.scalars().first()
if not session or session.expires < datetime.now(timezone.utc)():
# Session invalid or expired, clear cookies
if response:
response.delete_cookie("session_token")
response.delete_cookie("access_token")
raise HTTPException(status_code=401, detail="Session expired or invalid")
# Optionally refresh session if about to expire
if session.expires - timedelta(minutes=5) < datetime.now(timezone.utc)():
new_session_token = secrets.token_hex(32)
session.sessionToken = new_session_token
session.expires = datetime.now(timezone.utc) + timedelta(hours=12)
await db.commit()
# Set new session token cookie if response is passed
if response:
response.set_cookie(key="session_token", value=new_session_token, httponly=True, max_age=12*60*60)
# Fetch the user associated with the session
user_query = select(User).filter(User.id == session.userId)
result = await db.execute(user_query)
user = result.scalars().first()
if not user:
# User not found, clear tokens
if response:
response.delete_cookie("session_token")
response.delete_cookie("access_token")
raise HTTPException(status_code=401, detail="User not found")
return user
# Middleware to refresh the JWT token if close to expiry
async def token_refresh_middleware(
request: Request,
db: AsyncSession = Depends(get_db),
response: Response = None
):
access_token = request.cookies.get("access_token")
if access_token:
try:
user_id = verify_access_token(access_token)
# Retrieve token expiration details
payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])
token_expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
if token_expiration - timedelta(minutes=5) < datetime.now(timezone.utc):
# Refresh token
new_access_token = create_access_token({"user_id": user_id}, timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
if response:
response.set_cookie(key="access_token", value=new_access_token, httponly=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60)
# Validate the session token and return user
return await get_current_user_by_session_token(request, db, response)
except JWTError:
if response:
response.delete_cookie("access_token")
response.delete_cookie("session_token")
raise HTTPException(status_code=401, detail="Token validation failed")
# If no JWT token, fallback to session-based authentication
return await get_current_user_by_session_token(request, db, response)
# Logout: Clear both session and JWT tokens
async def logout(request: Request, db: AsyncSession = Depends(get_db), response: Response = None):
session_token = request.cookies.get("session_token")
if session_token:
query = select(SessionModel).filter(SessionModel.sessionToken == session_token)
result = await db.execute(query)
session = result.scalars().first()
if session:
await db.delete(session)
await db.commit()
response.delete_cookie("session_token")
response.delete_cookie("access_token")
return {"message": "Logged out successfully"}

View File

@ -3,7 +3,7 @@ from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import as_declarative, declared_attr
from sqlalchemy.sql import func
from pydantic import BaseModel
from typing import Optional
from typing import Optional, List
@as_declarative()
class Base:
@ -167,4 +167,26 @@ class UserCreate(BaseModel):
class UserLogin(BaseModel):
email: str
password: str
password: str
class VehicleCreate(BaseModel):
brand: str
model: str
variant: Optional[str]
year: int
kilometers: int
condition: str
location: str
latitude: Optional[float]
longitude: Optional[float]
gasType: str
images: str
description: str
service: str
inspectedAt: Optional[str] # ISO format for datetime
equipment_ids: List[int] # List of equipment IDs
class AuctionCreate(BaseModel):
askingPrice: float
description: Optional[str]
vehicle: VehicleCreate

127
app/routers/auctions.py Normal file
View File

@ -0,0 +1,127 @@
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from pydantic import BaseModel
from typing import List, Optional
from ..models import Auction, Vehicle, VehicleEquipment, Equipment, User
from ..database import get_db
from ..security import verify_access_token # Ensure this is imported correctly
from fastapi.logger import logger
router = APIRouter()
# Define Pydantic models for data validation
class VehicleCreate(BaseModel):
brand: str
model: str
variant: Optional[str]
year: int
kilometers: int
condition: str
location: str
latitude: Optional[float]
longitude: Optional[float]
gasType: str
images: str
description: str
service: str
inspectedAt: Optional[str] # ISO format for datetime
equipment_ids: List[int] # List of equipment IDs
class AuctionCreate(BaseModel):
askingPrice: float
description: Optional[str]
vehicle: VehicleCreate
async def get_current_user_id(request: Request, db: AsyncSession = Depends(get_db)):
user_id = verify_access_token(request)
# Fetch user from database to check their role
result = await db.execute(select(User).filter(User.id == user_id))
user = result.scalars().first()
#print(f"\n user " + str(user.role))
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found."
)
if not user.role.PRIVATE: # Only allow private users to create auctions
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only private users can create an auction."
)
logger.info(f"\nCurrent user ID: {user.id}\n Current user role: {user.role}\n")
#logger.debug(f"\nCurrent user ID: {user.id}\n Current user role: {user.role}\n")
return user.id
# API route to create an auction
@router.post("/api/v1/test")
async def testFuncForDB(request: Request,user_id: int = Depends(get_current_user_id), db: AsyncSession = Depends(get_db)):
print("HIIIIIIIIIIIIIIIIIIT")
result = await db.execute(select(User).filter(User.id == user_id))
user = result.scalars().first()
if user:
email = user.email
else:
email = "User not found"
return {"message": "Test function for DB", "email": email}
# API route to create an auction
@router.post("/api/v1/auction")
async def create_auction(auction_data: AuctionCreate, db: AsyncSession = Depends(get_db), user_id: int = Depends(get_current_user_id)):
# Create Vehicle
vehicle_data = auction_data.vehicle
vehicle = Vehicle(
brand=vehicle_data.brand,
model=vehicle_data.model,
variant=vehicle_data.variant,
year=vehicle_data.year,
kilometers=vehicle_data.kilometers,
condition=vehicle_data.condition,
location=vehicle_data.location,
latitude=vehicle_data.latitude,
longitude=vehicle_data.longitude,
gasType=vehicle_data.gasType,
images=vehicle_data.images,
description=vehicle_data.description,
service=vehicle_data.service,
inspectedAt=vehicle_data.inspectedAt,
)
# Add vehicle to the database
db.add(vehicle)
await db.commit()
await db.refresh(vehicle)
# Add vehicle equipment
for equipment_id in vehicle_data.equipment_ids:
result = await db.execute(select(Equipment).filter(Equipment.id == equipment_id))
equipment = result.scalars().first()
if not equipment:
raise HTTPException(status_code=404, detail=f"Equipment with ID {equipment_id} not found")
vehicle_equipment = VehicleEquipment(vehicle_id=vehicle.id, equipment_id=equipment.id)
db.add(vehicle_equipment)
# Create Auction
auction = Auction(
vehicleId=vehicle.id,
userId=user_id, # This comes from the authenticated user
askingPrice=auction_data.askingPrice,
description=auction_data.description,
)
# Add auction to the database
db.add(auction)
await db.commit()
await db.refresh(auction)
return {"message": "Auction created successfully", "auction": auction, "vehicle": vehicle}

View File

@ -1,11 +1,14 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Response, Request
from datetime import datetime, timedelta, timezone
from ..middleware import token_refresh_middleware, logout
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from ..database import get_db
from ..models import User, UserCreate, UserLogin
from ..security import create_access_token, verify_access_token
from ..models import User, UserCreate, UserLogin, Session as SessionModel
from ..security import create_access_token, verify_password
from fastapi.security import OAuth2PasswordBearer
import bcrypt
import secrets
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@ -53,30 +56,40 @@ async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
# User login
@router.post("/api/v1/login")
async def login(login_data: UserLogin, db: AsyncSession = Depends(get_db)):
async def login(login_data: UserLogin, db: AsyncSession = Depends(get_db), response: Response = None):
async with db.begin():
# Check if email is an email
query = select(User).filter(User.email == login_data.email)
result = await db.execute(query)
user = result.scalars().first()
if not user or not bcrypt.checkpw(login_data.password.encode('utf-8'), user.password.encode('utf-8')):
if not user or not verify_password(login_data.password, user.password):
raise HTTPException(status_code=400, detail="Invalid credentials")
# Create session token
session_token = secrets.token_hex(32)
session_expiry = datetime.now(timezone.utc) + timedelta(hours=12)
new_session = SessionModel(sessionToken=session_token, userId=user.id, expires=session_expiry)
db.add(new_session)
# Create JWT access token
access_token = create_access_token(data={"user_id": user.id})
return {"access_token": access_token, "token_type": "bearer"}
await db.commit()
# Set session and access tokens after transaction is committed
response.set_cookie(key="session_token", value=session_token, httponly=True, max_age=12*60*60)
response.set_cookie(key="access_token", value=access_token, httponly=True, max_age=60*60)
return {"message": "Login successful"}
# Logout user
@router.post("/api/v1/logout")
async def logout_user(request: Request, response: Response, db: AsyncSession = Depends(get_db)):
return await logout(request, db, response)
# Protected route example
# Protected route example using middleware
@router.get("/api/v1/protected")
async def protected_route(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)):
user_id = verify_access_token(token)
async with db.begin():
query = select(User).filter(User.id == user_id)
result = await db.execute(query)
user = result.scalars().first()
if not user:
raise HTTPException(status_code=401, detail="User not found")
return {"message": f"Hello, {user.name}"}
async def protected_route(user: User = Depends(token_refresh_middleware)):
return {"message": f"Hello, {user.name}"}

View File

@ -1,13 +1,19 @@
import bcrypt
from jose import jwt, JWTError
from datetime import datetime, timedelta
from fastapi import HTTPException, status
from datetime import datetime, timedelta, timezone
from fastapi import HTTPException, status, Request
import os
from pydantic import BaseModel
from typing import Optional
# Secret and algorithm for JWT
SECRET_KEY = os.getenv('SECRET_KEY', 'your_jwt_secret_key') # Ensure this is set in your environment
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 14
class TokenData(BaseModel):
user_id: Optional[int] = None
# Hash password using bcrypt directly
def get_password_hash(password: str) -> str:
@ -26,29 +32,37 @@ def create_access_token(data: dict, expires_delta: timedelta = None):
"""Creates a JWT token with expiration time."""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
# Verify JWT token
def verify_access_token(token: str):
"""Verifies the JWT token and returns the user_id if valid."""
def verify_access_token(request: Request) -> int:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
# First, check Authorization header (for cases where the JWT is passed in headers)
auth_header = request.headers.get("Authorization")
token = None
if auth_header and auth_header.startswith("Bearer "):
token = auth_header.split(" ")[1]
# If no Authorization header, fallback to cookies
if not token:
token = request.cookies.get("access_token")
if not token:
raise credentials_exception
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("user_id")
user_id: int = payload.get("user_id")
if user_id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token",
headers={"WWW-Authenticate": "Bearer"},
)
raise credentials_exception
return user_id
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token",
headers={"WWW-Authenticate": "Bearer"},
)
raise credentials_exception