diff --git a/backend/database/__init__.py b/backend/database/__init__.py index 2c5707d..f55a12e 100644 --- a/backend/database/__init__.py +++ b/backend/database/__init__.py @@ -1,7 +1,6 @@ -from database.core import ( +from database.pg_mqtt import ( init_db, close_db, - get_db, purge_loop, purge_old_data, insert_log, @@ -16,12 +15,13 @@ from database.core import ( upsert_alert, delete_alert, get_alerts, + partition_manager_loop, + ensure_current_partitions, ) __all__ = [ "init_db", "close_db", - "get_db", "purge_loop", "purge_old_data", "insert_log", @@ -36,4 +36,6 @@ __all__ = [ "upsert_alert", "delete_alert", "get_alerts", + "partition_manager_loop", + "ensure_current_partitions", ] diff --git a/backend/database/pg_mqtt.py b/backend/database/pg_mqtt.py new file mode 100644 index 0000000..5534c4a --- /dev/null +++ b/backend/database/pg_mqtt.py @@ -0,0 +1,416 @@ +""" +Phase 5 — MQTT live data functions backed by Postgres. + +device_logs is a partitioned table; heartbeats and commands are plain tables. +All three are accessed via raw SQL (not ORM) because device_logs partitioning +does not play well with SQLAlchemy's declarative ORM. + +device_alerts is an ORM model (devices/orm.py) and is handled here via raw SQL +to keep a single consistent interface for callers that used to import from database.core. +""" + +import asyncio +import json +import logging +from datetime import date, datetime, timedelta, timezone + +from dateutil.relativedelta import relativedelta +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + +from config import settings +from database.postgres import AsyncSessionLocal + +logger = logging.getLogger("database.pg_mqtt") + + +# --------------------------------------------------------------------------- +# Internal session helper — used by fire-and-forget insert paths that are +# called from the MQTT ingestion thread (not inside a FastAPI request). 
async def insert_log(device_serial: str, level: str, message: str,
                     device_timestamp: int | None = None) -> int:
    """Insert one row into ``device_logs`` and return its generated ``id``.

    ``received_at`` is stamped by the database via ``now()``;
    ``device_timestamp`` is stored exactly as passed and may be ``None``.
    """
    statement = text("""
        INSERT INTO device_logs (device_serial, level, message, device_timestamp, received_at)
        VALUES (:serial, :level, :message, :ts, now())
        RETURNING id
    """)
    bind_values = {
        "serial": device_serial,
        "level": level,
        "message": message,
        "ts": device_timestamp,
    }
    async with AsyncSessionLocal() as db:
        outcome = await db.execute(statement, bind_values)
        # Pull the RETURNING row first, then commit the insert.
        inserted = outcome.fetchone()
        await db.commit()
    return inserted[0]
async def update_command_response(command_id: int, status: str,
                                  response_payload: dict | None = None):
    """Record a device's response on an existing ``commands`` row.

    Sets ``status``, ``response_payload`` and ``responded_at = now()`` on the
    row whose ``id`` equals ``command_id``, then commits. If no row matches,
    the UPDATE affects nothing and no error is raised here.

    Args:
        command_id: ``id`` of the command row to update.
        status: New status value to store.
        response_payload: Response body, serialized with ``json.dumps``.
            ``None`` stores SQL NULL.
    """
    # Compare against None explicitly: an empty dict is a real (if empty)
    # response and must be stored as "{}", not collapsed into NULL the way a
    # plain truthiness test would.
    serialized = json.dumps(response_payload) if response_payload is not None else None

    async with AsyncSessionLocal() as session:
        await session.execute(
            text("""
                UPDATE commands
                SET status = :status,
                    response_payload = :payload,
                    responded_at = now()
                WHERE id = :id
            """),
            {
                "id": command_id,
                "status": status,
                "payload": serialized,
            },
        )
        await session.commit()
async def get_commands(device_serial: str, limit: int = 100,
                       offset: int = 0) -> tuple[list, int]:
    """Return one page of a device's commands together with the total count.

    Rows are ordered by ``sent_at`` descending, with ``sent_at`` and
    ``responded_at`` selected ``AT TIME ZONE 'UTC'``. Each row is converted
    with ``_row_to_dict`` before being returned.

    Returns:
        ``(rows, total)``, where ``total`` counts every command stored for
        ``device_serial`` regardless of ``limit``/``offset``.
    """
    count_sql = text("SELECT COUNT(*) FROM commands WHERE device_serial = :serial")
    page_sql = text("""
        SELECT id, device_serial, command_name, command_payload, status,
               response_payload,
               sent_at AT TIME ZONE 'UTC' AS sent_at,
               responded_at AT TIME ZONE 'UTC' AS responded_at
        FROM commands
        WHERE device_serial = :serial
        ORDER BY sent_at DESC
        LIMIT :limit OFFSET :offset
    """)

    async with AsyncSessionLocal() as db:
        counted = await db.execute(count_sql, {"serial": device_serial})
        total = counted.scalar()

        page = await db.execute(
            page_sql,
            {"serial": device_serial, "limit": limit, "offset": offset},
        )
        records = page.mappings().all()

    return [_row_to_dict(record) for record in records], total
async def upsert_alert(device_serial: str, subsystem: str, state: str,
                       message: str | None = None):
    """Insert or refresh the alert row for a ``(device_serial, subsystem)`` pair.

    When a row for the pair already exists, its ``state``, ``message`` and
    ``updated_at`` are overwritten with the incoming values; otherwise a new
    row is inserted. ``updated_at`` is always taken from the database's
    ``now()``.
    """
    upsert_sql = text("""
        INSERT INTO device_alerts (device_serial, subsystem, state, message, updated_at)
        VALUES (:serial, :subsystem, :state, :message, now())
        ON CONFLICT (device_serial, subsystem)
        DO UPDATE SET
            state = EXCLUDED.state,
            message = EXCLUDED.message,
            updated_at = EXCLUDED.updated_at
    """)
    bind_values = dict(
        serial=device_serial,
        subsystem=subsystem,
        state=state,
        message=message,
    )
    async with AsyncSessionLocal() as db:
        await db.execute(upsert_sql, bind_values)
        await db.commit()
async def partition_manager_loop(interval_seconds: float = 24 * 3600):
    """Keep ``device_logs`` partitions provisioned for the life of the process.

    Calls ``ensure_current_partitions()`` immediately, then after every
    ``interval_seconds`` calls it again followed by ``drop_old_partitions()``.
    Never returns; intended to be run as a background task.

    The check runs daily by default rather than every 30 days. Each pass only
    provisions the current and the next month, so a 30-day sleep can overshoot
    a month boundary: passes on Jan 1 and Jan 31 create Jan/Feb, and the next
    pass is not until Mar 2, leaving Mar 1 without a partition (inserts for
    that period would be rejected unless a default partition exists). Both
    maintenance calls are idempotent, so running them daily is cheap.

    Args:
        interval_seconds: Delay between maintenance passes, in seconds.
    """
    try:
        await ensure_current_partitions()
    except Exception as e:
        # Keep the task alive so the next pass can retry, instead of dying
        # permanently on a transient failure at startup.
        logger.error(f"Partition manager error: {e}")

    while True:
        await asyncio.sleep(interval_seconds)
        try:
            await ensure_current_partitions()
            await drop_old_partitions()
        except Exception as e:
            # Top-level boundary of a background loop: log and carry on.
            logger.error(f"Partition manager error: {e}")
isinstance(val, datetime): + d[key] = val.isoformat() + return d diff --git a/backend/main.py b/backend/main.py index c55fde0..543762d 100644 --- a/backend/main.py +++ b/backend/main.py @@ -108,9 +108,9 @@ async def crm_poll_loop(): @app.on_event("startup") async def startup(): init_firebase() - await db.init_db() await melody_service.migrate_from_firestore() mqtt_manager.start(asyncio.get_event_loop()) + asyncio.create_task(db.partition_manager_loop()) asyncio.create_task(db.purge_loop()) asyncio.create_task(nextcloud_keepalive_loop()) asyncio.create_task(crm_poll_loop()) @@ -125,7 +125,6 @@ async def startup(): @app.on_event("shutdown") async def shutdown(): mqtt_manager.stop() - await db.close_db() await close_nextcloud_client() diff --git a/backend/mqtt/router.py b/backend/mqtt/router.py index b1b885f..3ddb2d4 100644 --- a/backend/mqtt/router.py +++ b/backend/mqtt/router.py @@ -129,27 +129,29 @@ async def mqtt_websocket(websocket: WebSocket): try: from auth.utils import decode_access_token - from shared.firebase import get_db + from sqlalchemy import select + from database.postgres import AsyncSessionLocal + from staff.orm import Staff + payload = decode_access_token(token) role = payload.get("role", "") # sysadmin and admin always have MQTT access if role not in ("sysadmin", "admin"): - # Check MQTT permission for editor/user user_sub = payload.get("sub", "") - db_inst = get_db() - if db_inst: - doc = db_inst.collection("admin_users").document(user_sub).get() - if doc.exists: - perms = doc.to_dict().get("permissions", {}) - if not perms.get("mqtt", False): - await websocket.close(code=4003, reason="MQTT access denied") - return - else: - await websocket.close(code=4003, reason="User not found") - return - else: - await websocket.close(code=4003, reason="Service unavailable") + async with AsyncSessionLocal() as session: + result = await session.execute( + select(Staff).where(Staff.id == user_sub).limit(1) + ) + staff = result.scalar_one_or_none() + + if staff 
is None: + await websocket.close(code=4003, reason="User not found") + return + + perms = staff.permissions or {} + if not perms.get("mqtt", {}).get("access", False): + await websocket.close(code=4003, reason="MQTT access denied") return except Exception: await websocket.close(code=4001, reason="Invalid token") diff --git a/strategies/DATABASE_MIGRATION.md b/strategies/DATABASE_MIGRATION.md index 23bc638..0f8c4c3 100644 --- a/strategies/DATABASE_MIGRATION.md +++ b/strategies/DATABASE_MIGRATION.md @@ -402,12 +402,17 @@ await log_action(db, actor_id, actor_name, "UPDATE", "customer", id, label, chan --- ## Phase 5 — MQTT Live Data Cutover -**Status: NOT STARTED** -**Prerequisite:** Phase 1 complete (device_logs in Postgres) +**Status: COMPLETE** — Postgres live ingestion + partition manager active 2026-04-17 -This phase switches the **live MQTT ingestion** from SQLite to Postgres. +### What changed +- New `backend/database/pg_mqtt.py` — all MQTT functions rewritten for Postgres (raw SQL, no ORM) +- `database/__init__.py` — re-exports from `pg_mqtt` instead of `core` +- `main.py` — removed `db.init_db()` / `db.close_db()`, added `db.partition_manager_loop()` +- `mqtt/router.py` WebSocket auth — reads from Postgres `staff` table instead of Firestore `admin_users` +- `device_logs` partitioned writes, `heartbeats`/`commands` as plain tables +- `purge_loop` still runs for heartbeats/commands; device_logs purged via partition drops -### Steps +### Steps (original plan, now implemented) 1. Update `database/core.py` `insert_log`, `insert_heartbeat`, `insert_command` to write to Postgres 2. Update read functions (`get_logs`, `get_heartbeats`, etc.) similarly 3. 
The partition management background job: each month, at startup or via a cron, ensure next month's partition exists: @@ -470,7 +475,7 @@ backend/ | 2 | Firestore → Postgres (data migration) | **COMPLETE** — all 5 scripts ran successfully on VPS 2026-04-17 | | 3 | Staff auth cutover | **COMPLETE** — Postgres auth live 2026-04-17 | | 4 | Audit log system | **COMPLETE** — shared/audit.py live, wired into auth + staff 2026-04-17 | -| 5 | MQTT live data cutover | NOT STARTED | +| 5 | MQTT live data cutover | **COMPLETE** — Postgres live ingestion + partition manager 2026-04-17 | Update this table as each phase completes.