from fastapi import Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from auth.utils import decode_access_token
from auth.models import TokenPayload, Role
from database.postgres import get_pg_session
from staff.orm import Staff
from shared.exceptions import AuthenticationError, AuthorizationError

security = HTTPBearer()


def _role_value(role: object) -> object:
    """Return the underlying value of *role*.

    Accepts either a ``Role`` member or an already-raw value, so role checks
    compare like with like regardless of how ``TokenPayload.role`` is typed.
    """
    return role.value if isinstance(role, Role) else role


# Raw values of the roles that bypass granular permission checks entirely.
_FULL_ACCESS_ROLE_VALUES = (Role.sysadmin.value, Role.admin.value)


def _has_full_access(role: object) -> bool:
    """Return True if *role* is sysadmin or admin."""
    return _role_value(role) in _FULL_ACCESS_ROLE_VALUES


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> TokenPayload:
    """Decode the bearer token and return its claims as a ``TokenPayload``.

    Raises:
        AuthenticationError: if ``decode_access_token`` raises ``JWTError``,
            the decoded payload lacks any of ``sub``/``email``/``role``/``name``
            (``KeyError``), is not subscriptable (``TypeError``), or its claims
            are rejected by ``TokenPayload`` (``ValueError`` — pydantic's
            ``ValidationError`` subclasses it, assuming TokenPayload is a
            pydantic model).
    """
    try:
        payload = decode_access_token(credentials.credentials)
        return TokenPayload(
            sub=payload["sub"],
            email=payload["email"],
            role=payload["role"],
            name=payload["name"],
        )
    except (JWTError, KeyError, TypeError, ValueError) as exc:
        # Any malformed or invalid token is an authentication failure, not a 500.
        raise AuthenticationError() from exc


def require_roles(*allowed_roles: Role):
    """Build a dependency that admits only users holding one of *allowed_roles*.

    sysadmin is always admitted, even when not listed. The returned async
    dependency yields the current ``TokenPayload`` on success.

    Raises (from the dependency):
        AuthorizationError: if the user's role is neither sysadmin nor allowed.
    """
    # Computed once per factory call instead of on every request.
    allowed_values = tuple(r.value for r in allowed_roles)
    sysadmin_value = Role.sysadmin.value

    async def role_checker(
        current_user: TokenPayload = Depends(get_current_user),
    ) -> TokenPayload:
        role = _role_value(current_user.role)
        if role != sysadmin_value and role not in allowed_values:
            raise AuthorizationError()
        return current_user

    return role_checker


async def _get_user_permissions(user: TokenPayload, db: AsyncSession) -> dict | None:
    """Fetch permissions from Postgres for the given user.

    Returns None for sysadmin/admin (full access, no lookup performed).
    Otherwise returns ``Staff.permissions`` for the row whose ``id`` equals
    ``user.sub``; that value itself may be None or empty.

    Raises:
        AuthorizationError: if no ``Staff`` row matches ``user.sub``.
    """
    if _has_full_access(user.role):
        return None  # Full access
    result = await db.execute(select(Staff).where(Staff.id == user.sub).limit(1))
    staff = result.scalar_one_or_none()
    if staff is None:
        raise AuthorizationError()
    return staff.permissions


def require_permission(section: str, action: str):
    """Check granular permission for a section and action.

    section: 'melodies', 'devices', 'app_users', 'equipment', 'mqtt'
    action: 'view', 'add', 'edit', 'delete' (or ignored for mqtt)

    sysadmin and admin always pass. Other users need a ``Staff.permissions``
    mapping that grants the action, e.g. ``{"devices": {"view": True}}``;
    for the 'mqtt' section the grant is read from ``{"mqtt": {"access": True}}``.

    Raises (from the dependency):
        AuthorizationError: if the permission is not granted, or the stored
            permission data is missing or not shaped as nested dicts.
    """
    # mqtt is a single on/off grant stored under "access"; *action* is ignored.
    grant_key = "access" if section == "mqtt" else action

    async def permission_checker(
        current_user: TokenPayload = Depends(get_current_user),
        db: AsyncSession = Depends(get_pg_session),
    ) -> TokenPayload:
        if _has_full_access(current_user.role):
            return current_user

        permissions = await _get_user_permissions(current_user, db)
        # Missing, empty, or non-dict permission data is a denial, never a crash.
        if not permissions or not isinstance(permissions, dict):
            raise AuthorizationError()

        section_perms = permissions.get(section)
        if not isinstance(section_perms, dict) or not section_perms.get(grant_key, False):
            raise AuthorizationError()
        return current_user

    return permission_checker


# Pre-built convenience dependencies.
# require_roles always admits sysadmin, so listing it below is redundant but explicit.
require_sysadmin = require_roles(Role.sysadmin)
require_admin_or_above = require_roles(Role.sysadmin, Role.admin)
# Same role set as require_admin_or_above; kept as a separate name for call-site intent.
require_staff_management = require_roles(Role.sysadmin, Role.admin)
require_any_authenticated = require_roles(
    Role.sysadmin,
    Role.admin,
    Role.editor,
    Role.user,
)