95 lines
3.7 KiB
Python
95 lines
3.7 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, Response, Request
|
|
from datetime import datetime, timedelta, timezone
|
|
from ..middleware import token_refresh_middleware, logout
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from ..database import get_db
|
|
from ..models import User, UserCreate, UserLogin, Session as SessionModel
|
|
from ..security import create_access_token, verify_password
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
import bcrypt
|
|
import secrets
|
|
|
|
# Router collecting every authentication endpoint declared in this module.
router = APIRouter()

# Bearer-token security scheme. None of the handlers in this file as shown
# reference it.
# NOTE(review): tokenUrl="token" does not correspond to any route defined here
# (the login endpoint is /api/v1/login) — confirm what the OpenAPI
# "Authorize" flow is meant to target.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
|
# Password hashing helper used by register().
def _hash_password(plain_password: str) -> str:
    """Hash *plain_password* with a fresh bcrypt salt; return the hash as UTF-8 text.

    Synchronous and CPU-bound (bcrypt is deliberately slow), so async callers
    should run it in a worker thread rather than on the event loop.
    """
    # NOTE(review): bcrypt only consumes the first 72 bytes of its input;
    # depending on the installed bcrypt version, longer passwords are either
    # silently truncated or rejected with ValueError — confirm that UserCreate
    # bounds the password length.
    salt = bcrypt.gensalt()
    return bcrypt.hashpw(plain_password.encode("utf-8"), salt).decode("utf-8")


# Register a new user
@router.post("/api/v1/register")
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    """Create a new user account.

    The request is rejected if the e-mail is already registered, or if a
    ``BUSINESS`` user omits ``company``, ``privatePhone`` or ``cvr``. The
    password is stored only as a bcrypt hash.

    Args:
        user_data: Registration payload.
        db: Async SQLAlchemy session injected via ``get_db``.

    Returns:
        ``{"message": "User created successfully"}`` once the user row is committed.

    Raises:
        HTTPException: 400 if the e-mail exists or required business fields are missing.
    """
    # Local import keeps this change self-contained (stdlib only).
    import asyncio

    # ``db.begin()`` commits when the block exits normally and rolls back if an
    # exception escapes, so no explicit ``commit()`` belongs inside it.
    async with db.begin():
        # Check if email already exists.
        # NOTE(review): this check-then-insert is racy under concurrent requests;
        # only a unique constraint on User.email (not visible here) truly
        # guarantees uniqueness.
        query = select(User).filter(User.email == user_data.email)
        result = await db.execute(query)
        existing_user = result.scalars().first()
        if existing_user:
            raise HTTPException(status_code=400, detail="Email already exists")

        # Business accounts must supply company, private phone and CVR.
        if user_data.role == 'BUSINESS':
            if not user_data.company or not user_data.privatePhone or not user_data.cvr:
                raise HTTPException(status_code=400, detail="Company, Private Phone, and CVR are required for business users.")

        # bcrypt is CPU-bound: hash in the default thread pool so the event loop
        # (and every other in-flight request) is not stalled while it runs.
        loop = asyncio.get_running_loop()
        hashed_password = await loop.run_in_executor(
            None, _hash_password, user_data.password
        )

        # Create a new user with the provided information.
        new_user = User(
            email=user_data.email,
            password=hashed_password,
            name=user_data.name,
            role=user_data.role,
            phone=user_data.phone,
            address=user_data.address,
            postcode=user_data.postcode,
            city=user_data.city,
            company=user_data.company,
            privatePhone=user_data.privatePhone,
            cvr=user_data.cvr,
        )
        db.add(new_user)

    # Reaching this point means the transaction above was committed.
    return {"message": "User created successfully"}
|
|
|
|
# Lifetime of a server-side login session. It drives both the stored expiry and
# the session cookie's max-age, so the two cannot drift apart.
_SESSION_TTL = timedelta(hours=12)

# Max-age (seconds) of the access-token cookie.
# NOTE(review): presumably this should match the expiry that create_access_token
# encodes into the JWT — confirm in ..security.
_ACCESS_TOKEN_COOKIE_MAX_AGE = 60 * 60


# User login
@router.post("/api/v1/login")
async def login(login_data: UserLogin, db: AsyncSession = Depends(get_db), response: Response = None):
    """Authenticate a user and set the session and access-token cookies.

    On success, a new session row (random 64-hex-char token, expiring after
    ``_SESSION_TTL``) is committed. Then two HttpOnly cookies are set:
    ``session_token`` and a JWT ``access_token`` carrying ``{"user_id": ...}``.

    Args:
        login_data: Login payload (``email`` and ``password``).
        db: Async SQLAlchemy session injected via ``get_db``.
        response: Response object injected by FastAPI based on the ``Response``
            annotation. The ``None`` default only exists because the parameter
            follows a defaulted one; it is not expected to be None at runtime.

    Returns:
        ``{"message": "Login successful"}``.

    Raises:
        HTTPException: 400 if the e-mail is unknown or the password is wrong.
    """
    # ``db.begin()`` commits on normal exit and rolls back if an exception
    # escapes, so no explicit ``commit()`` is issued inside the block.
    async with db.begin():
        query = select(User).filter(User.email == login_data.email)
        result = await db.execute(query)
        user = result.scalars().first()

        # One message for both failure modes, so the error text does not reveal
        # whether the e-mail exists.
        # NOTE(review): if verify_password is a synchronous bcrypt check, it
        # blocks the event loop — confirm in ..security.
        if not user or not verify_password(login_data.password, user.password):
            raise HTTPException(status_code=400, detail="Invalid credentials")

        # Create the server-side session record.
        session_token = secrets.token_hex(32)
        session_expiry = datetime.now(timezone.utc) + _SESSION_TTL
        new_session = SessionModel(sessionToken=session_token, userId=user.id, expires=session_expiry)
        db.add(new_session)

        # Create JWT access token.
        access_token = create_access_token(data={"user_id": user.id})

    # The session row is committed at this point, so the cookies never reference
    # a session that failed to persist.
    # NOTE(review): no ``secure`` flag is set on these cookies — confirm this is
    # intended for the deployment (HTTPS).
    response.set_cookie(
        key="session_token",
        value=session_token,
        httponly=True,
        max_age=int(_SESSION_TTL.total_seconds()),
    )
    response.set_cookie(
        key="access_token",
        value=access_token,
        httponly=True,
        max_age=_ACCESS_TOKEN_COOKIE_MAX_AGE,
    )

    return {"message": "Login successful"}
|
|
|
|
# Logout user
@router.post("/api/v1/logout")
async def logout_user(request: Request, response: Response, db: AsyncSession = Depends(get_db)):
    """Log the current user out.

    Thin endpoint wrapper: all of the work is done by the ``logout`` helper
    imported from ``..middleware``. It is called as ``(request, db, response)``,
    and whatever it returns is passed straight back to FastAPI.
    """
    outcome = await logout(request, db, response)
    return outcome
|
|
|
|
|
|
# Protected route example using middleware
@router.get("/api/v1/protected")
async def protected_route(user: User = Depends(token_refresh_middleware)):
    """Greet the authenticated user by name.

    ``token_refresh_middleware`` (from ``..middleware``) is used as a FastAPI
    dependency that supplies the ``User``. Presumably it rejects
    unauthenticated requests and refreshes tokens — confirm there.
    """
    body = {"message": f"Hello, {user.name}"}
    return body