223 lines
7.4 KiB
Python
223 lines
7.4 KiB
Python
import aiosqlite
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timedelta, timezone
|
|
from config import settings
|
|
|
|
# Named logger used by every function in this module.
logger: logging.Logger = logging.getLogger("mqtt.database")

# The single shared aiosqlite connection. It is None until init_db() runs
# (directly or lazily through get_db()) and is reset to None by close_db().
_db: aiosqlite.Connection | None = None
|
|
|
|
# CREATE TABLE statements, one per stored entity. The received_at / sent_at
# columns default to SQLite's datetime('now') text ('YYYY-MM-DD HH:MM:SS').
_TABLE_DDL = [
    """CREATE TABLE IF NOT EXISTS device_logs (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        device_serial TEXT NOT NULL,
        level TEXT NOT NULL,
        message TEXT NOT NULL,
        device_timestamp INTEGER,
        received_at TEXT NOT NULL DEFAULT (datetime('now'))
    )""",
    """CREATE TABLE IF NOT EXISTS heartbeats (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        device_serial TEXT NOT NULL,
        device_id TEXT,
        firmware_version TEXT,
        ip_address TEXT,
        gateway TEXT,
        uptime_ms INTEGER,
        uptime_display TEXT,
        received_at TEXT NOT NULL DEFAULT (datetime('now'))
    )""",
    """CREATE TABLE IF NOT EXISTS commands (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        device_serial TEXT NOT NULL,
        command_name TEXT NOT NULL,
        command_payload TEXT,
        status TEXT NOT NULL DEFAULT 'pending',
        response_payload TEXT,
        sent_at TEXT NOT NULL DEFAULT (datetime('now')),
        responded_at TEXT
    )""",
]

# (index name, table, indexed columns) for each secondary index.
_INDEX_SPECS = [
    ("idx_logs_serial_time", "device_logs", "device_serial, received_at"),
    ("idx_logs_level", "device_logs", "level"),
    ("idx_heartbeats_serial_time", "heartbeats", "device_serial, received_at"),
    ("idx_commands_serial_time", "commands", "device_serial, sent_at"),
    ("idx_commands_status", "commands", "status"),
]

# Idempotent DDL executed in order by init_db(): tables first, then indexes.
SCHEMA_STATEMENTS = _TABLE_DDL + [
    f"CREATE INDEX IF NOT EXISTS {name} ON {table}({columns})"
    for name, table, columns in _INDEX_SPECS
]
|
|
|
|
|
|
async def init_db():
    """Open the module's SQLite connection and make sure the schema exists.

    The database file comes from ``settings.sqlite_db_path``. The connection
    is stored in the module-global ``_db`` and configured to return
    ``aiosqlite.Row`` objects, so callers can turn rows into dicts. Every
    statement in ``SCHEMA_STATEMENTS`` is executed and then committed once.

    NOTE(review): if a connection is already open, calling this again replaces
    ``_db`` without closing the previous connection.
    """
    global _db
    # Publish the connection globally first, then configure it through a short local alias.
    _db = conn = await aiosqlite.connect(settings.sqlite_db_path)
    conn.row_factory = aiosqlite.Row
    for ddl in SCHEMA_STATEMENTS:
        await conn.execute(ddl)
    await conn.commit()
    logger.info(f"SQLite database initialized at {settings.sqlite_db_path}")
|
|
|
|
|
|
async def close_db():
    """Close the shared connection, if any, and reset ``_db`` to None.

    Does nothing when no connection is currently open.
    """
    global _db
    if not _db:
        return
    await _db.close()
    _db = None
|
|
|
|
|
|
async def get_db() -> aiosqlite.Connection:
    """Return the shared connection, opening it via init_db() on first use.

    NOTE(review): two coroutines that call this concurrently before any
    connection exists can both reach init_db() (it awaits internally). Each
    call then opens its own connection, and only the last one stays in ``_db``.
    """
    if _db is not None:
        return _db
    await init_db()
    return _db
|
|
|
|
|
|
# --- Insert Operations ---
|
|
|
|
async def insert_log(device_serial: str, level: str, message: str,
                     device_timestamp: int | None = None):
    """Store one device log entry, commit, and return the new row id.

    ``received_at`` is not supplied here; the column default fills it in.
    ``device_timestamp`` is stored as given (NULL when omitted).
    """
    values = (device_serial, level, message, device_timestamp)
    conn = await get_db()
    cur = await conn.execute(
        "INSERT INTO device_logs (device_serial, level, message, device_timestamp) VALUES (?, ?, ?, ?)",
        values,
    )
    await conn.commit()
    return cur.lastrowid
|
|
|
|
|
|
async def insert_heartbeat(device_serial: str, device_id: str,
                           firmware_version: str, ip_address: str,
                           gateway: str, uptime_ms: int, uptime_display: str):
    """Store one heartbeat snapshot, commit, and return the new row id.

    All fields are written verbatim. ``received_at`` is filled in by the
    column default.
    """
    values = (
        device_serial,
        device_id,
        firmware_version,
        ip_address,
        gateway,
        uptime_ms,
        uptime_display,
    )
    conn = await get_db()
    cur = await conn.execute(
        """INSERT INTO heartbeats
            (device_serial, device_id, firmware_version, ip_address, gateway, uptime_ms, uptime_display)
            VALUES (?, ?, ?, ?, ?, ?, ?)""",
        values,
    )
    await conn.commit()
    return cur.lastrowid
|
|
|
|
|
|
async def insert_command(device_serial: str, command_name: str,
                         command_payload: dict) -> int:
    """Record an outgoing command, commit, and return its row id.

    ``command_payload`` is serialized with ``json.dumps``. The ``status``
    ('pending') and ``sent_at`` columns come from their table defaults.
    """
    conn = await get_db()
    payload_json = json.dumps(command_payload)
    cur = await conn.execute(
        "INSERT INTO commands (device_serial, command_name, command_payload) VALUES (?, ?, ?)",
        (device_serial, command_name, payload_json),
    )
    await conn.commit()
    return cur.lastrowid
|
|
|
|
|
|
async def update_command_response(command_id: int, status: str,
                                  response_payload: dict | None = None):
    """Set a command's status and response, stamping ``responded_at``.

    Args:
        command_id: Row id returned by ``insert_command``.
        status: New status string to store.
        response_payload: Device response, stored as JSON text. Only ``None``
            is stored as NULL. An empty dict is a real response and is kept
            as ``"{}"``.
    """
    db = await get_db()
    # Test against None explicitly: a falsy-but-valid payload such as {}
    # must not be dropped to NULL.
    encoded = json.dumps(response_payload) if response_payload is not None else None
    await db.execute(
        """UPDATE commands SET status = ?, response_payload = ?,
            responded_at = datetime('now') WHERE id = ?""",
        (status, encoded, command_id),
    )
    await db.commit()
|
|
|
|
|
|
# --- Query Operations ---
|
|
|
|
async def get_logs(device_serial: str, level: str | None = None,
                   search: str | None = None,
                   limit: int = 100, offset: int = 0) -> tuple[list, int]:
    """Return one page of a device's logs (newest first) and the total match count.

    Args:
        device_serial: Device whose logs are read.
        level: If non-empty, only rows whose ``level`` equals this exactly.
        search: If non-empty, only rows whose ``message`` contains this text.
            The text is matched literally: ``%``, ``_`` and ``\\`` carry no
            wildcard meaning. SQLite's LIKE is case-insensitive for ASCII by
            default.
        limit: Maximum number of rows returned.
        offset: Number of matching rows skipped before the page starts.

    Returns:
        ``(rows, total)``. ``rows`` is a list of dicts for this page, and
        ``total`` is the number of rows matching the filters, ignoring
        limit/offset.
    """
    db = await get_db()
    where_clauses = ["device_serial = ?"]
    params: list = [device_serial]

    if level:
        where_clauses.append("level = ?")
        params.append(level)
    if search:
        # Escape LIKE metacharacters so user text is a literal substring.
        # Backslash is escaped first so the escapes added afterwards survive.
        escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        where_clauses.append("message LIKE ? ESCAPE '\\'")
        params.append(f"%{escaped}%")

    # Only fixed clause strings are interpolated; all values stay bound parameters.
    where = " AND ".join(where_clauses)

    count_rows = await db.execute_fetchall(
        f"SELECT COUNT(*) AS cnt FROM device_logs WHERE {where}", params
    )
    total = count_rows[0][0]

    # received_at has one-second resolution, so ties are common. id DESC makes
    # the order total, keeping LIMIT/OFFSET pages stable and non-overlapping.
    rows = await db.execute_fetchall(
        f"SELECT * FROM device_logs WHERE {where} "
        "ORDER BY received_at DESC, id DESC LIMIT ? OFFSET ?",
        params + [limit, offset],
    )
    return [dict(r) for r in rows], total
|
|
|
|
|
|
async def get_heartbeats(device_serial: str, limit: int = 100,
                         offset: int = 0) -> tuple[list, int]:
    """Return one page of a device's heartbeats (newest first) and their total count.

    Args:
        device_serial: Device whose heartbeats are read.
        limit: Maximum number of rows returned.
        offset: Number of rows skipped before the page starts.

    Returns:
        ``(rows, total)``. ``rows`` is a list of dicts for this page, and
        ``total`` counts all of the device's heartbeats.
    """
    db = await get_db()
    count_rows = await db.execute_fetchall(
        "SELECT COUNT(*) FROM heartbeats WHERE device_serial = ?", (device_serial,)
    )
    total = count_rows[0][0]
    # id DESC breaks ties between heartbeats received in the same second,
    # so consecutive pages neither repeat nor skip rows.
    rows = await db.execute_fetchall(
        "SELECT * FROM heartbeats WHERE device_serial = ? "
        "ORDER BY received_at DESC, id DESC LIMIT ? OFFSET ?",
        (device_serial, limit, offset),
    )
    return [dict(r) for r in rows], total
|
|
|
|
|
|
async def get_commands(device_serial: str, limit: int = 100,
                       offset: int = 0) -> tuple[list, int]:
    """Return one page of a device's commands (newest first) and their total count.

    Args:
        device_serial: Device whose commands are read.
        limit: Maximum number of rows returned.
        offset: Number of rows skipped before the page starts.

    Returns:
        ``(rows, total)``. ``rows`` is a list of dicts for this page, and
        ``total`` counts all of the device's commands.
    """
    db = await get_db()
    count_rows = await db.execute_fetchall(
        "SELECT COUNT(*) FROM commands WHERE device_serial = ?", (device_serial,)
    )
    total = count_rows[0][0]
    # sent_at has one-second resolution; id DESC gives a total order so
    # LIMIT/OFFSET pagination is deterministic.
    rows = await db.execute_fetchall(
        "SELECT * FROM commands WHERE device_serial = ? "
        "ORDER BY sent_at DESC, id DESC LIMIT ? OFFSET ?",
        (device_serial, limit, offset),
    )
    return [dict(r) for r in rows], total
|
|
|
|
|
|
async def get_latest_heartbeats() -> list:
    """Return exactly one heartbeat per device: the one with the latest received_at.

    Several heartbeats from one device can share a ``received_at`` value
    (one-second resolution). In that case the row with the highest id is
    chosen, so a device never appears twice. Result order is unspecified.
    """
    db = await get_db()
    rows = await db.execute_fetchall("""
        SELECT h.* FROM heartbeats h
        INNER JOIN (
            SELECT MAX(h1.id) AS id
            FROM heartbeats h1
            INNER JOIN (
                SELECT device_serial, MAX(received_at) AS max_time
                FROM heartbeats GROUP BY device_serial
            ) latest ON h1.device_serial = latest.device_serial
                    AND h1.received_at = latest.max_time
            GROUP BY h1.device_serial
        ) pick ON h.id = pick.id
    """)
    return [dict(r) for r in rows]
|
|
|
|
|
|
async def get_pending_command(device_serial: str) -> dict | None:
    """Return the most recently sent 'pending' command for a device, or None.

    When several pending commands share the same ``sent_at`` second, the one
    inserted last (highest id) is returned, so the choice is deterministic.
    """
    db = await get_db()
    rows = await db.execute_fetchall(
        """SELECT * FROM commands WHERE device_serial = ? AND status = 'pending'
            ORDER BY sent_at DESC, id DESC LIMIT 1""",
        (device_serial,),
    )
    return dict(rows[0]) if rows else None
|
|
|
|
|
|
# --- Cleanup ---
|
|
|
|
async def purge_old_data(retention_days: int | None = None):
    """Delete logs, heartbeats and commands older than the retention window.

    Args:
        retention_days: Age limit in days. A falsy value (None or 0) falls
            back to ``settings.mqtt_data_retention_days``.

    ``device_logs`` and ``heartbeats`` rows are compared on ``received_at``,
    and ``commands`` rows on ``sent_at``. All three deletes are committed
    together.
    """
    days = retention_days or settings.mqtt_data_retention_days
    # The timestamp columns are written by SQLite's datetime('now'), which
    # yields UTC text as 'YYYY-MM-DD HH:MM:SS'. The cutoff must use the same
    # layout for the TEXT comparison to be chronological. An isoformat()
    # cutoff ('YYYY-MM-DDTHH:MM:SS...+00:00') would sort after every row from
    # the same calendar day, because 'T' > ' ', and purge them too early.
    cutoff = (datetime.now(timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
    db = await get_db()
    await db.execute("DELETE FROM device_logs WHERE received_at < ?", (cutoff,))
    await db.execute("DELETE FROM heartbeats WHERE received_at < ?", (cutoff,))
    await db.execute("DELETE FROM commands WHERE sent_at < ?", (cutoff,))
    await db.commit()
    logger.info(f"Purged MQTT data older than {days} days")
|
|
|
|
|
|
async def purge_loop(interval_seconds: float = 86400):
    """Call purge_old_data() forever, once per interval (daily by default).

    The first purge runs after one full interval has elapsed, not at startup.
    A failing purge is logged with its traceback and the loop continues.
    ``asyncio.CancelledError`` is a BaseException and is not caught, so
    cancelling the task stops the loop.

    Args:
        interval_seconds: Delay before each purge, in seconds.
    """
    while True:
        await asyncio.sleep(interval_seconds)
        try:
            await purge_old_data()
        except Exception as e:
            # logger.exception keeps the same message and adds the traceback,
            # which logger.error(f"...{e}") used to drop.
            logger.exception("Purge failed: %s", e)
|