Phase 5 of Migration
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from database.core import (
|
||||
from database.pg_mqtt import (
|
||||
init_db,
|
||||
close_db,
|
||||
get_db,
|
||||
purge_loop,
|
||||
purge_old_data,
|
||||
insert_log,
|
||||
@@ -16,12 +15,13 @@ from database.core import (
|
||||
upsert_alert,
|
||||
delete_alert,
|
||||
get_alerts,
|
||||
partition_manager_loop,
|
||||
ensure_current_partitions,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"init_db",
|
||||
"close_db",
|
||||
"get_db",
|
||||
"purge_loop",
|
||||
"purge_old_data",
|
||||
"insert_log",
|
||||
@@ -36,4 +36,6 @@ __all__ = [
|
||||
"upsert_alert",
|
||||
"delete_alert",
|
||||
"get_alerts",
|
||||
"partition_manager_loop",
|
||||
"ensure_current_partitions",
|
||||
]
|
||||
|
||||
416
backend/database/pg_mqtt.py
Normal file
416
backend/database/pg_mqtt.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
Phase 5 — MQTT live data functions backed by Postgres.
|
||||
|
||||
device_logs is a partitioned table; heartbeats and commands are plain tables.
|
||||
All three are accessed via raw SQL (not ORM) because device_logs partitioning
|
||||
does not play well with SQLAlchemy's declarative ORM.
|
||||
|
||||
device_alerts is an ORM model (devices/orm.py) and is handled here via raw SQL
|
||||
to keep a single consistent interface for callers that used to import from database.core.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from config import settings
|
||||
from database.postgres import AsyncSessionLocal
|
||||
|
||||
logger = logging.getLogger("database.pg_mqtt")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal session helper — used by fire-and-forget insert paths that are
|
||||
# called from the MQTT ingestion thread (not inside a FastAPI request).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _session() -> AsyncSession:
    """Create a new ``AsyncSession`` from ``AsyncSessionLocal`` and return it.

    This helper does not close the session; that is left to the caller.
    """
    new_session = AsyncSessionLocal()
    return new_session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Insert operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def insert_log(device_serial: str, level: str, message: str,
                     device_timestamp: int | None = None) -> int:
    """Insert one row into ``device_logs`` and return its generated ``id``.

    ``received_at`` is filled in by the database via ``now()``;
    ``device_timestamp`` is stored as given and may be ``None``.
    """
    statement = text("""
        INSERT INTO device_logs (device_serial, level, message, device_timestamp, received_at)
        VALUES (:serial, :level, :message, :ts, now())
        RETURNING id
    """)
    bind_values = {
        "serial": device_serial,
        "level": level,
        "message": message,
        "ts": device_timestamp,
    }

    async with AsyncSessionLocal() as db:
        outcome = await db.execute(statement, bind_values)
        # Read the RETURNING row before committing the transaction.
        inserted = outcome.fetchone()
        await db.commit()

    return inserted[0]
|
||||
|
||||
|
||||
async def insert_heartbeat(device_serial: str, device_id: str,
                           firmware_version: str, ip_address: str,
                           gateway: str, uptime_ms: int, uptime_display: str) -> int:
    """Insert one row into ``heartbeats`` and return its generated ``id``.

    ``received_at`` is filled in by the database via ``now()``.
    """
    statement = text("""
        INSERT INTO heartbeats
        (device_serial, device_id, firmware_version, ip_address,
        gateway, uptime_ms, uptime_display, received_at)
        VALUES
        (:serial, :device_id, :fw, :ip, :gw, :uptime_ms, :uptime_display, now())
        RETURNING id
    """)
    bind_values = {
        "serial": device_serial,
        "device_id": device_id,
        "fw": firmware_version,
        "ip": ip_address,
        "gw": gateway,
        "uptime_ms": uptime_ms,
        "uptime_display": uptime_display,
    }

    async with AsyncSessionLocal() as db:
        outcome = await db.execute(statement, bind_values)
        # Read the RETURNING row before committing the transaction.
        inserted = outcome.fetchone()
        await db.commit()

    return inserted[0]
|
||||
|
||||
|
||||
async def insert_command(device_serial: str, command_name: str,
                         command_payload: dict) -> int:
    """Insert one row into ``commands`` and return its generated ``id``.

    The payload dict is serialised with ``json.dumps`` before being bound;
    ``sent_at`` is filled in by the database via ``now()``.
    """
    statement = text("""
        INSERT INTO commands (device_serial, command_name, command_payload, sent_at)
        VALUES (:serial, :name, :payload, now())
        RETURNING id
    """)
    encoded_payload = json.dumps(command_payload)
    bind_values = {
        "serial": device_serial,
        "name": command_name,
        "payload": encoded_payload,
    }

    async with AsyncSessionLocal() as db:
        outcome = await db.execute(statement, bind_values)
        # Read the RETURNING row before committing the transaction.
        inserted = outcome.fetchone()
        await db.commit()

    return inserted[0]
|
||||
|
||||
|
||||
async def update_command_response(command_id: int, status: str,
                                  response_payload: dict | None = None):
    """Record a device's response on an existing ``commands`` row.

    Sets ``status``, ``response_payload`` and ``responded_at`` (database
    ``now()``) on the row whose ``id`` equals ``command_id``.

    Args:
        command_id: Primary key of the command row to update.
        status: New status value to store.
        response_payload: Response body to store as JSON text. ``None`` is
            stored as SQL NULL; an empty dict is stored as ``"{}"`` so that
            "empty response" stays distinguishable from "no response".
    """
    # Test against None explicitly: a plain truthiness check would silently
    # turn a legitimate empty payload ({}) into NULL.
    encoded_payload = (
        json.dumps(response_payload) if response_payload is not None else None
    )

    async with AsyncSessionLocal() as session:
        await session.execute(
            text("""
                UPDATE commands
                SET status = :status,
                response_payload = :payload,
                responded_at = now()
                WHERE id = :id
            """),
            {
                "id": command_id,
                "status": status,
                "payload": encoded_payload,
            },
        )
        await session.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Query operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def get_logs(device_serial: str, level: str | None = None,
                   search: str | None = None,
                   limit: int = 100, offset: int = 0) -> tuple[list, int]:
    """Return one page of ``device_logs`` rows for a device plus the total count.

    Args:
        device_serial: Serial of the device whose logs are fetched.
        level: When truthy, only rows with exactly this level are returned.
        search: When truthy, only rows whose message matches
            ``ILIKE '%<search>%'`` are returned. The value is bound as a
            parameter; ``%`` and ``_`` inside it still act as wildcards.
        limit: Maximum number of rows in the page.
        offset: Number of rows to skip before the page starts.

    Returns:
        ``(rows, total)`` where ``rows`` is a list of plain dicts (datetimes
        converted to ISO strings by ``_row_to_dict``), newest first, and
        ``total`` is the number of rows matching the filters.
    """
    # The WHERE clause is assembled only from fixed fragments; every
    # caller-supplied value goes through a bind parameter.
    conditions = ["device_serial = :serial"]
    filter_params: dict = {"serial": device_serial}

    if level:
        conditions.append("level = :level")
        filter_params["level"] = level
    if search:
        conditions.append("message ILIKE :search")
        filter_params["search"] = f"%{search}%"

    where = " AND ".join(conditions)
    page_params = {**filter_params, "limit": limit, "offset": offset}

    async with AsyncSessionLocal() as session:
        count_result = await session.execute(
            text(f"SELECT COUNT(*) FROM device_logs WHERE {where}"), filter_params
        )
        total = count_result.scalar()

        # `id DESC` is a tiebreaker: received_at alone is not unique, and
        # LIMIT/OFFSET over a non-unique ordering can repeat or skip rows
        # between pages.
        rows_result = await session.execute(
            text(f"""
                SELECT id, device_serial, level, message, device_timestamp,
                received_at AT TIME ZONE 'UTC' AS received_at
                FROM device_logs
                WHERE {where}
                ORDER BY received_at DESC, id DESC
                LIMIT :limit OFFSET :offset
            """),
            page_params,
        )
        rows = rows_result.mappings().all()

    return [_row_to_dict(r) for r in rows], total
|
||||
|
||||
|
||||
async def get_heartbeats(device_serial: str, limit: int = 100,
                         offset: int = 0) -> tuple[list, int]:
    """Return one page of ``heartbeats`` rows for a device plus the total count.

    Args:
        device_serial: Serial of the device whose heartbeats are fetched.
        limit: Maximum number of rows in the page.
        offset: Number of rows to skip before the page starts.

    Returns:
        ``(rows, total)`` where ``rows`` is a list of plain dicts (datetimes
        converted to ISO strings by ``_row_to_dict``), newest first, and
        ``total`` is the number of heartbeats stored for the device.
    """
    async with AsyncSessionLocal() as session:
        count_result = await session.execute(
            text("SELECT COUNT(*) FROM heartbeats WHERE device_serial = :serial"),
            {"serial": device_serial},
        )
        total = count_result.scalar()

        # `id DESC` is a tiebreaker: received_at alone is not unique, and
        # LIMIT/OFFSET over a non-unique ordering can repeat or skip rows
        # between pages.
        rows_result = await session.execute(
            text("""
                SELECT id, device_serial, device_id, firmware_version, ip_address,
                gateway, uptime_ms, uptime_display,
                received_at AT TIME ZONE 'UTC' AS received_at
                FROM heartbeats
                WHERE device_serial = :serial
                ORDER BY received_at DESC, id DESC
                LIMIT :limit OFFSET :offset
            """),
            {"serial": device_serial, "limit": limit, "offset": offset},
        )
        rows = rows_result.mappings().all()

    return [_row_to_dict(r) for r in rows], total
|
||||
|
||||
|
||||
async def get_commands(device_serial: str, limit: int = 100,
                       offset: int = 0) -> tuple[list, int]:
    """Return one page of ``commands`` rows for a device plus the total count.

    Args:
        device_serial: Serial of the device whose commands are fetched.
        limit: Maximum number of rows in the page.
        offset: Number of rows to skip before the page starts.

    Returns:
        ``(rows, total)`` where ``rows`` is a list of plain dicts (datetimes
        converted to ISO strings by ``_row_to_dict``), most recently sent
        first, and ``total`` is the number of commands stored for the device.
    """
    async with AsyncSessionLocal() as session:
        count_result = await session.execute(
            text("SELECT COUNT(*) FROM commands WHERE device_serial = :serial"),
            {"serial": device_serial},
        )
        total = count_result.scalar()

        # `id DESC` is a tiebreaker: sent_at alone is not unique, and
        # LIMIT/OFFSET over a non-unique ordering can repeat or skip rows
        # between pages.
        rows_result = await session.execute(
            text("""
                SELECT id, device_serial, command_name, command_payload, status,
                response_payload,
                sent_at AT TIME ZONE 'UTC' AS sent_at,
                responded_at AT TIME ZONE 'UTC' AS responded_at
                FROM commands
                WHERE device_serial = :serial
                ORDER BY sent_at DESC, id DESC
                LIMIT :limit OFFSET :offset
            """),
            {"serial": device_serial, "limit": limit, "offset": offset},
        )
        rows = rows_result.mappings().all()

    return [_row_to_dict(r) for r in rows], total
|
||||
|
||||
|
||||
async def get_latest_heartbeats() -> list:
    """Return the most recent heartbeat row of every device.

    Uses ``SELECT DISTINCT ON (device_serial)`` ordered by ``received_at``
    descending, so exactly one row per serial is produced. Each row is
    converted to a plain dict by ``_row_to_dict``.
    """
    latest_per_device = text("""
        SELECT DISTINCT ON (device_serial)
        id, device_serial, device_id, firmware_version, ip_address,
        gateway, uptime_ms, uptime_display,
        received_at AT TIME ZONE 'UTC' AS received_at
        FROM heartbeats
        ORDER BY device_serial, received_at DESC
    """)

    async with AsyncSessionLocal() as db:
        outcome = await db.execute(latest_per_device)
        records = outcome.mappings().all()

    heartbeats = []
    for record in records:
        heartbeats.append(_row_to_dict(record))
    return heartbeats
|
||||
|
||||
|
||||
async def get_pending_command(device_serial: str) -> dict | None:
    """Return the most recently sent ``pending`` command for a device.

    Returns the row as a plain dict (via ``_row_to_dict``), or ``None`` when
    the device has no command whose status is ``'pending'``.
    """
    newest_pending = text("""
        SELECT id, device_serial, command_name, command_payload, status,
        response_payload,
        sent_at AT TIME ZONE 'UTC' AS sent_at,
        responded_at AT TIME ZONE 'UTC' AS responded_at
        FROM commands
        WHERE device_serial = :serial AND status = 'pending'
        ORDER BY sent_at DESC
        LIMIT 1
    """)

    async with AsyncSessionLocal() as db:
        outcome = await db.execute(newest_pending, {"serial": device_serial})
        record = outcome.mappings().fetchone()

    if not record:
        return None
    return _row_to_dict(record)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Device alerts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def upsert_alert(device_serial: str, subsystem: str, state: str,
                       message: str | None = None):
    """Insert or update the alert for one ``(device_serial, subsystem)`` pair.

    On conflict with an existing row for the same pair, ``state``,
    ``message`` and ``updated_at`` are overwritten with the new values.
    """
    upsert = text("""
        INSERT INTO device_alerts (device_serial, subsystem, state, message, updated_at)
        VALUES (:serial, :subsystem, :state, :message, now())
        ON CONFLICT (device_serial, subsystem)
        DO UPDATE SET
        state = EXCLUDED.state,
        message = EXCLUDED.message,
        updated_at = EXCLUDED.updated_at
    """)
    bind_values = {
        "serial": device_serial,
        "subsystem": subsystem,
        "state": state,
        "message": message,
    }

    async with AsyncSessionLocal() as db:
        await db.execute(upsert, bind_values)
        await db.commit()
|
||||
|
||||
|
||||
async def delete_alert(device_serial: str, subsystem: str):
    """Delete the alert row for one ``(device_serial, subsystem)`` pair."""
    removal = text(
        "DELETE FROM device_alerts WHERE device_serial = :serial AND subsystem = :subsystem"
    )
    bind_values = {"serial": device_serial, "subsystem": subsystem}

    async with AsyncSessionLocal() as db:
        await db.execute(removal, bind_values)
        await db.commit()
|
||||
|
||||
|
||||
async def get_alerts(device_serial: str) -> list:
    """Return all alert rows of a device, most recently updated first.

    Each row is converted to a plain dict by ``_row_to_dict``.
    """
    alerts_for_device = text("""
        SELECT id, device_serial, subsystem, state, message,
        updated_at AT TIME ZONE 'UTC' AS updated_at
        FROM device_alerts
        WHERE device_serial = :serial
        ORDER BY updated_at DESC
    """)

    async with AsyncSessionLocal() as db:
        outcome = await db.execute(alerts_for_device, {"serial": device_serial})
        records = outcome.mappings().all()

    alerts = []
    for record in records:
        alerts.append(_row_to_dict(record))
    return alerts
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Partition management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def ensure_current_partitions():
    """Make sure this month's and next month's ``device_logs`` partitions exist.

    Partitions are named ``device_logs_YYYY_MM`` and cover the half-open
    range from the first day of that month to the first day of the following
    month. ``CREATE TABLE IF NOT EXISTS`` keeps the call idempotent. Both
    statements are committed together.
    """
    one_month = relativedelta(months=1)

    async with AsyncSessionLocal() as db:
        for months_ahead in (0, 1):
            month_start = date.today().replace(day=1) + relativedelta(months=months_ahead)
            month_end = month_start + one_month

            # Every interpolated value is derived from today's date, never
            # from caller input.
            table = f"device_logs_{month_start.strftime('%Y_%m')}"
            lower = month_start.isoformat()
            upper = month_end.isoformat()

            ddl = text(f"""
                CREATE TABLE IF NOT EXISTS {table}
                PARTITION OF device_logs
                FOR VALUES FROM ('{lower}') TO ('{upper}')
            """)
            await db.execute(ddl)

        await db.commit()

    logger.info("Partition check complete")
|
||||
|
||||
|
||||
async def drop_old_partitions(keep_months: int = 6):
    """Drop ``device_logs`` partitions older than ``keep_months``.

    A partition is dropped when the month encoded in its name is earlier than
    the first day of the current month minus ``keep_months`` months.

    Args:
        keep_months: Number of past months whose partitions are kept.
    """
    cutoff = date.today().replace(day=1) - relativedelta(months=keep_months)

    async with AsyncSessionLocal() as session:
        result = await session.execute(text("""
            SELECT tablename FROM pg_tables
            WHERE schemaname = 'public'
            AND tablename LIKE 'device_logs_%'
        """))
        candidates = [r[0] for r in result.fetchall()]

    # The listing session is closed before any DROP is issued, so no
    # connection is held idle while the drops run.
    for name in candidates:
        # In the LIKE pattern above `_` is a single-character wildcard, so
        # unrelated tables can be listed; only exact partition names pass.
        partition_date = _partition_month(name)
        if partition_date is None or partition_date >= cutoff:
            continue

        async with AsyncSessionLocal() as session:
            # `name` has been validated as device_logs_<4 digits>_<2 digits>,
            # so interpolating it into the statement is safe.
            await session.execute(text(f"DROP TABLE IF EXISTS {name}"))
            await session.commit()
        logger.info(f"Dropped old partition: {name}")


def _partition_month(table_name: str) -> date | None:
    """Return the first day of the month encoded in a partition table name.

    Only names of the exact form ``device_logs_YYYY_MM`` (four ASCII digits,
    two ASCII digits, valid month) are accepted; anything else yields ``None``.
    """
    parts = table_name.split("_")
    if len(parts) != 4 or parts[:2] != ["device", "logs"]:
        return None

    year, month = parts[2], parts[3]
    if len(year) != 4 or len(month) != 2:
        return None
    if not (year.isascii() and year.isdigit() and month.isascii() and month.isdigit()):
        return None

    try:
        return date(int(year), int(month), 1)
    except ValueError:
        # e.g. month "00" or "13"
        return None
|
||||
|
||||
|
||||
async def partition_manager_loop():
    """Keep ``device_logs`` partitions up to date for the life of the process.

    Runs a maintenance pass immediately on startup and then once every 24
    hours. Each pass creates the current and next month's partitions if they
    are missing and drops partitions past the retention window.

    The pass is daily rather than every 30 days because a fixed 30-day sleep
    can step over a month boundary (start Jan 1 -> wake Jan 31 -> wake Mar 2),
    which would leave the first days of a month without a partition. Both
    operations are idempotent, so frequent runs are cheap.

    Any error in a pass is logged and the loop continues, so a database
    problem at startup does not permanently stop partition maintenance.
    """
    check_interval_seconds = 24 * 3600

    while True:
        try:
            await ensure_current_partitions()
            await drop_old_partitions()
        except Exception as e:
            # Top-level boundary of a background task: log with traceback
            # and retry on the next cycle.
            logger.exception("Partition manager error: %s", e)
        await asyncio.sleep(check_interval_seconds)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cleanup (replaces the SQLite purge_loop). device_logs is no longer purged
# row-by-row — old log data is removed by dropping whole partitions — but
# heartbeats and commands are still purged by row deletion.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def purge_old_data(retention_days: int | None = None):
    """Delete ``heartbeats`` and ``commands`` rows older than the retention window.

    Args:
        retention_days: Age limit in days. Any falsy value (``None`` or ``0``)
            falls back to ``settings.mqtt_data_retention_days``.

    Both deletes run in one session and are committed together.
    """
    days = retention_days or settings.mqtt_data_retention_days
    cutoff = datetime.now(timezone.utc) - timedelta(days=days)
    bind_values = {"cutoff": cutoff}

    purge_statements = (
        "DELETE FROM heartbeats WHERE received_at < :cutoff",
        "DELETE FROM commands WHERE sent_at < :cutoff",
    )

    async with AsyncSessionLocal() as db:
        for sql in purge_statements:
            await db.execute(text(sql), bind_values)
        await db.commit()

    logger.info(f"Purged heartbeats and commands older than {days} days")
|
||||
|
||||
|
||||
async def purge_loop():
    """Call ``purge_old_data`` once every 24 hours, forever.

    The first purge happens after the first sleep, not at startup. A failed
    purge is logged and the loop keeps running.
    """
    seconds_per_day = 86400

    while True:
        await asyncio.sleep(seconds_per_day)
        try:
            await purge_old_data()
        except Exception as exc:
            logger.error(f"Purge failed: {exc}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stub — no longer needed but kept so nothing that imports init_db/close_db breaks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def init_db():
    """Compatibility stub: only logs that the Postgres backend is active.

    No schema work is done here; per the original note, the Postgres schema
    is managed by Alembic rather than at runtime.
    """
    startup_note = "Postgres MQTT backend active — no SQLite init needed"
    logger.info(startup_note)
|
||||
|
||||
|
||||
async def close_db():
    """Compatibility stub that does nothing.

    Kept so existing callers of ``close_db`` keep working; per the original
    note, the SQLAlchemy engine lifecycle is managed by the process.
    """
    return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _row_to_dict(row) -> dict:
|
||||
"""Convert a SQLAlchemy RowMapping to a plain dict with ISO string timestamps."""
|
||||
d = dict(row)
|
||||
for key, val in d.items():
|
||||
if isinstance(val, datetime):
|
||||
d[key] = val.isoformat()
|
||||
return d
|
||||
@@ -108,9 +108,9 @@ async def crm_poll_loop():
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
init_firebase()
|
||||
await db.init_db()
|
||||
await melody_service.migrate_from_firestore()
|
||||
mqtt_manager.start(asyncio.get_event_loop())
|
||||
asyncio.create_task(db.partition_manager_loop())
|
||||
asyncio.create_task(db.purge_loop())
|
||||
asyncio.create_task(nextcloud_keepalive_loop())
|
||||
asyncio.create_task(crm_poll_loop())
|
||||
@@ -125,7 +125,6 @@ async def startup():
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown():
|
||||
mqtt_manager.stop()
|
||||
await db.close_db()
|
||||
await close_nextcloud_client()
|
||||
|
||||
|
||||
|
||||
@@ -129,27 +129,29 @@ async def mqtt_websocket(websocket: WebSocket):
|
||||
|
||||
try:
|
||||
from auth.utils import decode_access_token
|
||||
from shared.firebase import get_db
|
||||
from sqlalchemy import select
|
||||
from database.postgres import AsyncSessionLocal
|
||||
from staff.orm import Staff
|
||||
|
||||
payload = decode_access_token(token)
|
||||
role = payload.get("role", "")
|
||||
|
||||
# sysadmin and admin always have MQTT access
|
||||
if role not in ("sysadmin", "admin"):
|
||||
# Check MQTT permission for editor/user
|
||||
user_sub = payload.get("sub", "")
|
||||
db_inst = get_db()
|
||||
if db_inst:
|
||||
doc = db_inst.collection("admin_users").document(user_sub).get()
|
||||
if doc.exists:
|
||||
perms = doc.to_dict().get("permissions", {})
|
||||
if not perms.get("mqtt", False):
|
||||
await websocket.close(code=4003, reason="MQTT access denied")
|
||||
return
|
||||
else:
|
||||
await websocket.close(code=4003, reason="User not found")
|
||||
return
|
||||
else:
|
||||
await websocket.close(code=4003, reason="Service unavailable")
|
||||
async with AsyncSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(Staff).where(Staff.id == user_sub).limit(1)
|
||||
)
|
||||
staff = result.scalar_one_or_none()
|
||||
|
||||
if staff is None:
|
||||
await websocket.close(code=4003, reason="User not found")
|
||||
return
|
||||
|
||||
perms = staff.permissions or {}
|
||||
if not perms.get("mqtt", {}).get("access", False):
|
||||
await websocket.close(code=4003, reason="MQTT access denied")
|
||||
return
|
||||
except Exception:
|
||||
await websocket.close(code=4001, reason="Invalid token")
|
||||
|
||||
Reference in New Issue
Block a user