168 lines
5.6 KiB
Python
168 lines
5.6 KiB
Python
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
|
|
from typing import Optional
|
|
from auth.models import TokenPayload
|
|
from auth.dependencies import require_permission
|
|
from mqtt.models import (
|
|
MqttCommandRequest, CommandSendResponse, MqttStatusResponse,
|
|
DeviceMqttStatus, LogListResponse, HeartbeatListResponse,
|
|
CommandListResponse, HeartbeatEntry,
|
|
)
|
|
from mqtt.client import mqtt_manager
|
|
from mqtt import database as db
|
|
from datetime import datetime, timezone
|
|
|
|
# Every MQTT endpoint in this module is mounted under /api/mqtt and grouped
# under the "mqtt" tag in the generated OpenAPI docs.
router = APIRouter(tags=["mqtt"], prefix="/api/mqtt")
|
|
|
|
|
|
# A device is reported online when its newest heartbeat is younger than this.
_ONLINE_THRESHOLD_SECONDS = 90
# Age reported when a heartbeat timestamp is missing or cannot be parsed.
_UNKNOWN_HEARTBEAT_AGE = 9999


def _seconds_since(received_at, now: datetime) -> int:
    """Return the whole number of seconds between *received_at* and *now*.

    *received_at* may be an ISO-8601 string or a ``datetime``. A trailing
    ``Z`` is normalised to ``+00:00`` because ``datetime.fromisoformat`` only
    accepts it from Python 3.11 on; naive timestamps are treated as UTC.
    Values that are neither, or strings that fail to parse, yield
    ``_UNKNOWN_HEARTBEAT_AGE``. Timestamps that lie in the future (clock skew
    between the heartbeat source and this server) are clamped to 0 rather
    than reported as a negative age.
    """
    if isinstance(received_at, datetime):
        received = received_at
    elif isinstance(received_at, str):
        text = received_at[:-1] + "+00:00" if received_at.endswith("Z") else received_at
        try:
            received = datetime.fromisoformat(text)
        except ValueError:
            return _UNKNOWN_HEARTBEAT_AGE
    else:
        return _UNKNOWN_HEARTBEAT_AGE

    if received.tzinfo is None:
        received = received.replace(tzinfo=timezone.utc)
    return max(0, int((now - received).total_seconds()))


@router.get("/status", response_model=MqttStatusResponse)
async def get_all_device_status(
    _user: TokenPayload = Depends(require_permission("mqtt", "view")),
):
    """Summarise MQTT connectivity for every device with a recorded heartbeat.

    For each row from ``db.get_latest_heartbeats()`` the age of the heartbeat
    is computed against the current UTC time; a device is flagged ``online``
    when that age is below ``_ONLINE_THRESHOLD_SECONDS``. The response also
    carries the broker connection flag from ``mqtt_manager.connected``.
    """
    heartbeats = await db.get_latest_heartbeats()
    now = datetime.now(timezone.utc)

    devices = []
    for hb in heartbeats:
        # Subscript access (not .get) so row types without .get keep working.
        seconds_ago = _seconds_since(hb["received_at"], now)
        devices.append(DeviceMqttStatus(
            device_serial=hb["device_serial"],
            online=seconds_ago < _ONLINE_THRESHOLD_SECONDS,
            last_heartbeat=HeartbeatEntry(**hb),
            seconds_since_heartbeat=seconds_ago,
        ))

    return MqttStatusResponse(
        devices=devices,
        broker_connected=mqtt_manager.connected,
    )
|
|
|
|
|
|
@router.post("/command/{device_serial}", response_model=CommandSendResponse)
async def send_command(
    device_serial: str,
    body: MqttCommandRequest,
    _user: TokenPayload = Depends(require_permission("mqtt", "view")),
):
    """Record a command for *device_serial* and publish it over MQTT.

    The command is persisted first so it receives a ``command_id``. If
    ``mqtt_manager.publish_command`` reports failure, the stored command is
    marked as ``"error"`` and an unsuccessful response is returned.

    NOTE(review): this endpoint sends commands to devices but only requires
    the "mqtt"/"view" permission, the same as the read-only endpoints —
    confirm whether a stronger action should be required here.
    """
    payload = {"cmd": body.cmd, "contents": body.contents}
    command_id = await db.insert_command(
        device_serial=device_serial,
        command_name=body.cmd,
        command_payload=payload,
    )

    published = mqtt_manager.publish_command(
        device_serial=device_serial,
        cmd=body.cmd,
        contents=body.contents,
    )
    if published:
        return CommandSendResponse(
            success=True,
            command_id=command_id,
            message=f"Command '{body.cmd}' sent to {device_serial}",
        )

    # Publish failed: mark the stored command as errored before reporting it.
    failure_reason = "MQTT broker not connected"
    await db.update_command_response(command_id, "error", {"error": failure_reason})
    return CommandSendResponse(
        success=False,
        command_id=command_id,
        message=failure_reason,
    )
|
|
|
|
|
|
@router.get("/logs/{device_serial}", response_model=LogListResponse)
async def get_device_logs(
    device_serial: str,
    level: Optional[str] = Query(None, description="Filter: INFO, WARN, ERROR"),
    search: Optional[str] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    _user: TokenPayload = Depends(require_permission("mqtt", "view")),
):
    """Return one page of stored log entries for *device_serial*.

    The optional ``level`` and ``search`` filters and the ``limit``/``offset``
    pagination window are passed through unchanged to ``db.get_logs``, whose
    ``(rows, total)`` result is wrapped in a ``LogListResponse``.
    """
    page = await db.get_logs(
        device_serial,
        level=level,
        search=search,
        limit=limit,
        offset=offset,
    )
    rows, row_count = page
    return LogListResponse(total=row_count, logs=rows)
|
|
|
|
|
|
@router.get("/heartbeats/{device_serial}", response_model=HeartbeatListResponse)
async def get_device_heartbeats(
    device_serial: str,
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    _user: TokenPayload = Depends(require_permission("mqtt", "view")),
):
    """Return one page of heartbeat records for *device_serial*.

    Pagination is delegated to ``db.get_heartbeats``; its ``(rows, total)``
    result is wrapped in a ``HeartbeatListResponse``.
    """
    page = await db.get_heartbeats(device_serial, limit=limit, offset=offset)
    rows, row_count = page
    return HeartbeatListResponse(total=row_count, heartbeats=rows)
|
|
|
|
|
|
@router.get("/commands/{device_serial}", response_model=CommandListResponse)
async def get_device_commands(
    device_serial: str,
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    _user: TokenPayload = Depends(require_permission("mqtt", "view")),
):
    """Return one page of recorded commands for *device_serial*.

    Pagination is delegated to ``db.get_commands``; its ``(rows, total)``
    result is wrapped in a ``CommandListResponse``.
    """
    page = await db.get_commands(device_serial, limit=limit, offset=offset)
    rows, row_count = page
    return CommandListResponse(total=row_count, commands=rows)
|
|
|
|
|
|
# WebSocket close codes used when rejecting a stream request.
_WS_CLOSE_INVALID_TOKEN = 4001
_WS_CLOSE_FORBIDDEN = 4003

# Roles that always have MQTT stream access, regardless of per-user flags.
_WS_PRIVILEGED_ROLES = ("sysadmin", "admin")


async def _ws_access_denial(token: str) -> Optional[tuple]:
    """Decide whether *token* may subscribe to the live MQTT stream.

    Returns ``None`` when access is granted, otherwise a ``(code, reason)``
    pair for ``WebSocket.close``:

    * ``4001 "Invalid token"``: the token cannot be decoded, or it has no
      ``sub`` claim to look the user up by.
    * ``4003 "Service unavailable"``: ``get_db()`` returned nothing or the
      user lookup raised.
    * ``4003 "User not found"``: no ``admin_users`` document for the subject.
    * ``4003 "MQTT access denied"``: the document's ``permissions.mqtt``
      flag is missing or falsy.

    Users whose ``role`` is in ``_WS_PRIVILEGED_ROLES`` skip the lookup.
    """
    # Imports stay local as in the original handler (presumably to avoid an
    # import cycle at module load; confirm). They sit outside the try blocks
    # so a broken deployment is not misreported as a bad token.
    import asyncio

    from auth.utils import decode_access_token
    from shared.firebase import get_db

    try:
        payload = decode_access_token(token)
        role = payload.get("role", "")
        user_sub = payload.get("sub", "")
    except Exception:
        # decode_access_token's exception types are not visible here, so any
        # failure to decode or inspect the token counts as an invalid token.
        return _WS_CLOSE_INVALID_TOKEN, "Invalid token"

    if role in _WS_PRIVILEGED_ROLES:
        return None

    if not user_sub:
        # Without a subject there is no user record to check permissions on.
        return _WS_CLOSE_INVALID_TOKEN, "Invalid token"

    try:
        db_inst = get_db()
        if not db_inst:
            return _WS_CLOSE_FORBIDDEN, "Service unavailable"
        doc_ref = db_inst.collection("admin_users").document(user_sub)
        # .get() is a synchronous call (the original never awaited it); run it
        # in the default executor so it does not stall the event loop.
        loop = asyncio.get_running_loop()
        doc = await loop.run_in_executor(None, doc_ref.get)
    except Exception:
        # A backend failure is not the token's fault; reporting it as
        # "Invalid token" could make a client needlessly re-authenticate.
        return _WS_CLOSE_FORBIDDEN, "Service unavailable"

    if not doc.exists:
        return _WS_CLOSE_FORBIDDEN, "User not found"

    perms = (doc.to_dict() or {}).get("permissions") or {}
    if not isinstance(perms, dict) or not perms.get("mqtt", False):
        return _WS_CLOSE_FORBIDDEN, "MQTT access denied"
    return None


@router.websocket("/ws")
async def mqtt_websocket(websocket: WebSocket):
    """Live MQTT data stream. Auth via query param: ?token=JWT

    After the token passes ``_ws_access_denial`` the socket is accepted and
    registered with ``mqtt_manager``. Incoming client text is read and
    discarded; reading only keeps the connection alive and detects
    disconnects. The subscriber is always unregistered on exit.

    NOTE(review): per the ASGI WebSocket spec, a close sent before ``accept()``
    rejects the handshake with HTTP 403, so clients likely never see the
    custom 4001/4003 codes or reasons. Confirm whether that is intended.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=_WS_CLOSE_INVALID_TOKEN, reason="Missing token")
        return

    denial = await _ws_access_denial(token)
    if denial is not None:
        code, reason = denial
        await websocket.close(code=code, reason=reason)
        return

    await websocket.accept()
    mqtt_manager.add_ws_subscriber(websocket)

    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        mqtt_manager.remove_ws_subscriber(websocket)
|