Files
scrapAPI/app/routers/auth.py

83 lines
3.1 KiB
Python

import bcrypt
from fastapi import APIRouter, Depends, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from ..database import get_db
from ..models import User, UserCreate, UserLogin
from ..security import create_access_token, verify_access_token
# Router that collects the auth endpoints defined below.
router = APIRouter()
# Dependency that extracts the bearer token for protected routes.
# NOTE(review): tokenUrl="token" does not match the login route in this file
# ("/api/v1/login"); tokenUrl is used for the OpenAPI docs — confirm intent.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Register a new user
@router.post("/api/v1/register")
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    """Create a new ``User`` from the submitted registration data.

    The password is stored as a bcrypt hash (UTF-8 encoded, fresh salt from
    ``bcrypt.gensalt()``).

    Returns:
        ``{"message": "User created successfully"}``.

    Raises:
        HTTPException(400): if a ``'BUSINESS'`` registration lacks company,
            privatePhone or cvr, or if a user with this email already exists.
    """
    # Pure request validation needs no database, so do it before any query.
    if user_data.role == 'BUSINESS':
        if not user_data.company or not user_data.privatePhone or not user_data.cvr:
            raise HTTPException(status_code=400, detail="Company, Private Phone, and CVR are required for business users.")

    # bcrypt is deliberately slow; calling it directly in this coroutine would
    # stall the event loop for every other request. Run it in the threadpool,
    # and before opening the transaction so the DB connection is not held.
    hashed_password = (
        await run_in_threadpool(
            bcrypt.hashpw, user_data.password.encode('utf-8'), bcrypt.gensalt()
        )
    ).decode('utf-8')

    async with db.begin():
        # NOTE(review): check-then-insert can race under concurrent requests
        # with the same email; only a unique constraint on User.email (not
        # visible here) fully prevents duplicates — confirm one exists.
        result = await db.execute(select(User).filter(User.email == user_data.email))
        if result.scalars().first():
            raise HTTPException(status_code=400, detail="Email already exists")

        db.add(
            User(
                email=user_data.email,
                password=hashed_password,
                name=user_data.name,
                role=user_data.role,
                phone=user_data.phone,
                address=user_data.address,
                postcode=user_data.postcode,
                city=user_data.city,
                company=user_data.company,
                privatePhone=user_data.privatePhone,
                cvr=user_data.cvr,
            )
        )
        # Leaving db.begin() commits the transaction (or rolls back on an
        # exception), so no explicit commit is needed here.
    return {"message": "User created successfully"}
# Hash of a throwaway password. It is checked when no account matches the
# email, so "unknown email" and "wrong password" cost the same bcrypt work.
# It uses bcrypt.gensalt()'s default cost, the same as the hashes written by register.
_DUMMY_PASSWORD_HASH = bcrypt.hashpw(b"timing-equalizer", bcrypt.gensalt())


# User login
@router.post("/api/v1/login")
async def login(login_data: UserLogin, db: AsyncSession = Depends(get_db)):
    """Exchange an email/password pair for a bearer access token.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``. The token is
        produced by ``create_access_token`` from ``{"user_id": <user id>}``.

    Raises:
        HTTPException(400): "Invalid credentials" if no user has this email
            or the password does not match the stored bcrypt hash.
    """
    async with db.begin():
        # Look up the account by email.
        result = await db.execute(select(User).filter(User.email == login_data.email))
        user = result.scalars().first()
        # Copy the fields needed later while the transaction is still open.
        # The commit on leaving db.begin() may expire ORM attributes; whether
        # it does depends on the session's expire_on_commit setting, which is
        # configured elsewhere.
        stored_hash = user.password.encode('utf-8') if user else None
        user_id = user.id if user else None

    # Always perform exactly one bcrypt check, against a dummy hash for
    # unknown emails, so response time does not reveal registered emails.
    # The check is slow CPU work, so it runs in the threadpool, off the event loop.
    password_ok = await run_in_threadpool(
        bcrypt.checkpw,
        login_data.password.encode('utf-8'),
        stored_hash if stored_hash is not None else _DUMMY_PASSWORD_HASH,
    )
    if stored_hash is None or not password_ok:
        raise HTTPException(status_code=400, detail="Invalid credentials")

    access_token = create_access_token(data={"user_id": user_id})
    return {"access_token": access_token, "token_type": "bearer"}
# Protected route example
@router.get("/api/v1/protected")
async def protected_route(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)):
    """Greet the user identified by the bearer token.

    The token is turned into a user id with ``verify_access_token``, and the
    matching ``User`` row is loaded.

    Returns:
        ``{"message": "Hello, <user name>"}``.

    Raises:
        HTTPException(401): "User not found" if no user has that id.
            How an invalid token is handled is up to ``verify_access_token``
            (not visible here); nothing in this function catches its errors.
    """
    user_id = verify_access_token(token)
    async with db.begin():
        result = await db.execute(select(User).filter(User.id == user_id))
        user = result.scalars().first()
        if not user:
            # RFC 6750 §3: a 401 from a bearer-protected resource should carry
            # a WWW-Authenticate challenge.
            raise HTTPException(
                status_code=401,
                detail="User not found",
                headers={"WWW-Authenticate": "Bearer"},
            )
        # Read user.name before the transaction closes; the commit on exit
        # may expire ORM attributes, depending on session configuration.
        return {"message": f"Hello, {user.name}"}